docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.maze_dataset

MazeDatasetConfig is where you decide what your dataset should look like; you then pass it to MazeDataset.from_config to generate or load the dataset.

See the demo_dataset notebook.


  1"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset.
  2
  3see [demo_dataset notebook](../../notebooks/demo_dataset)
  4
  5"""
  6
  7import copy
  8import json
  9import multiprocessing
 10import typing
 11import warnings
 12from pathlib import Path
 13from typing import Literal, Optional, cast, overload
 14
 15import numpy as np
 16import tqdm
 17from jaxtyping import Int
 18from muutils.json_serialize import (
 19	json_serialize,
 20)
 21from muutils.json_serialize.util import (
 22	_FORMAT_KEY,
 23	JSONdict,
 24)
 25from muutils.misc import stable_hash
 26from zanj import ZANJ
 27from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler
 28
 29from maze_dataset.constants import CoordArray
 30from maze_dataset.dataset.dataset import (
 31	GPTDataset,
 32)
 33from maze_dataset.dataset.maze_dataset_config import (
 34	SERIALIZE_MINIMAL_THRESHOLD,
 35	EndpointKwargsType,
 36	MazeDatasetConfig,
 37)
 38from maze_dataset.generation.seed import GLOBAL_SEED
 39from maze_dataset.maze import LatticeMaze, SolvedMaze
 40
 41_GLOBAL_WORKER_CONFIG: MazeDatasetConfig
 42
 43
 44def _generate_maze_helper(index: int) -> Optional[SolvedMaze]:  # noqa: ARG001
 45	"""Helper function for generating mazes in parallel.
 46
 47	> [!CAUTION]
 48	> don't use this unless generating in parallel!
 49	"""
 50	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0602
 51	# TODO: don't use this unless generating in parallel!
 52	maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
 53		grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
 54		**_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
 55	)
 56
 57	endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()
 58
 59	# Generate the solution
 60	# mypy doesnt realize EndpointKwargsType has only string keys: `Keywords must be strings  [misc]`
 61	# TYPING: error: No overload variant of "generate_random_path" of "LatticeMaze" matches argument type "dict[Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | list[tuple[int, int]] | None]"  [call-overload]
 62	solution: Optional[CoordArray] = maze.generate_random_path(**endpoint_kwargs)  # type: ignore[misc, call-overload]
 63
 64	# Validate the solution
 65	if (
 66		solution is None
 67		or len(solution) == 0
 68		or not isinstance(solution, np.ndarray)
 69		# magic value is fine here
 70		or len(solution.shape) != 2  # noqa: PLR2004
 71	):
 72		return None  # Return None if the solution is invalid
 73
 74	return SolvedMaze.from_lattice_maze(
 75		lattice_maze=maze,
 76		solution=solution,
 77	)
 78
 79
 80def _maze_gen_init_worker(config: MazeDatasetConfig) -> None:
 81	"""special worker helper
 82
 83	> [!CAUTION]
 84	> this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad!
 85
 86	"""
 87	# TODO: dont use globals here!
 88	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0603
 89	_GLOBAL_WORKER_CONFIG = config
 90
 91	process_id: tuple[int, ...] = multiprocessing.current_process()._identity
 92	if len(process_id) == 0:
 93		# no multiprocessing, seed was already set
 94		pass
 95	elif len(process_id) == 1:
 96		# multiprocessing, adjust seed based on process id
 97		# only set numpy seed, since we do not use other random gens
 98		np.random.seed(
 99			_GLOBAL_WORKER_CONFIG.seed
100			or GLOBAL_SEED  # if the seed is None, use the global seed
101			+ process_id[0]
102		)
103	else:
104		err_msg = (
105			f"unexpected process id: {process_id = }\n{multiprocessing.Process() = }"
106		)
107		raise ValueError(
108			err_msg,
109		)
110
111
class MazeDataset(GPTDataset[MazeDatasetConfig]):
	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`

	# Attributes:
	- `cfg : MazeDatasetConfig`
		the config of this dataset. `update_self_config` keeps `cfg.n_mazes` in sync
		with `len(mazes)`
	- `mazes : list[SolvedMaze]`
		the solved mazes in the dataset
	- `generation_metadata_collected : dict | None`
		collected generation metadata. when `None`, the minimal serializers first call
		`self.filter_by.collect_generation_meta()`
	"""

	def __init__(
		self,
		cfg: MazeDatasetConfig,
		mazes: typing.Sequence[SolvedMaze],
		generation_metadata_collected: dict | None = None,
	) -> None:
		"""initialize a maze dataset from a config and a list of solved mazes

		`mazes` is copied into a new list (the `SolvedMaze` objects themselves are shared)
		"""
		super().__init__()
		self.cfg: MazeDatasetConfig = cfg
		self.mazes: list[SolvedMaze] = list(mazes)
		self.generation_metadata_collected: dict | None = generation_metadata_collected

	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
	@classmethod
	def from_config(  # type: ignore[override]
		cls,
		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
		cfg: MazeDatasetConfig,  # type: ignore[override]
		do_generate: bool = True,
		load_local: bool = True,
		save_local: bool = True,
		zanj: ZANJ | None = None,
		do_download: bool = True,
		local_base_path: Path = Path("data/maze_dataset"),
		except_on_config_mismatch: bool = True,
		allow_generation_metadata_filter_mismatch: bool = True,
		verbose: bool = False,
		**kwargs,
	) -> "MazeDataset":
		"""create a maze dataset from a config

		priority of loading:
		1. load from local
		2. download
		3. generate

		every argument (including extra `**kwargs`) is forwarded unchanged to
		`GPTDataset.from_config`; this override only narrows the return type.
		"""
		return cast(
			MazeDataset,
			super().from_config(
				cfg=cfg,
				do_generate=do_generate,
				load_local=load_local,
				save_local=save_local,
				zanj=zanj,
				do_download=do_download,
				local_base_path=local_base_path,
				except_on_config_mismatch=except_on_config_mismatch,
				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
				verbose=verbose,
				**kwargs,
			),
		)

	def data_hash(self) -> int:
		"""return a hash of the data (a `stable_hash` of the serialized mazes; the config is not included)"""
		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

	def __getitem__(self, i: int) -> SolvedMaze:
		"""get a maze by index"""
		return self.mazes[i]

	def __iter__(self) -> typing.Iterator[SolvedMaze]:
		"""iterate over the mazes"""
		return iter(self.mazes)

	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
		"""copy the dataset by round-tripping it through `_serialize_full` and `MazeDataset.load`

		`memo` is not used.
		NOTE(review): this yields an independent copy as long as loading builds fresh
		objects from the serialized data -- confirm against `load_item_recursive`.
		"""
		return MazeDataset.load(self._serialize_full())

	# TYPING: get type hints on the tokenizer here
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[False] = False,
	) -> list[list[str]]: ...
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[True] = True,
	) -> list[str]: ...
	def as_tokens(
		self,
		maze_tokenizer,  # TODO: MazeTokenizer
		limit: int | None = None,
		join_tokens_individual_maze: bool = False,
	) -> list[list[str]] | list[str]:
		"""return the dataset as tokens according to the passed `maze_tokenizer`

		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

		only the first `limit` mazes are tokenized (all of them if `limit` is `None`).

		if `join_tokens_individual_maze` is True, then the tokens of each maze are
		joined with a space, and the result is a list of strings.
		i.e.:

			>>> dataset.as_tokens(join_tokens_individual_maze=False)
			[["a", "b", "c"], ["d", "e", "f"]]
			>>> dataset.as_tokens(join_tokens_individual_maze=True)
			["a b c", "d e f"]
		"""
		output: list[list[str]] = [
			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
		]
		if join_tokens_individual_maze:
			return [" ".join(tokens) for tokens in output]
		else:
			return output

	def __len__(self) -> int:
		"""return the number of mazes in the dataset"""
		return len(self.mazes)

	def __eq__(self, other: object) -> bool:
		"""compare two datasets by config and by mazes

		for non-`MazeDataset` operands this returns `NotImplemented` (per the Python
		data model) rather than raising, so Python tries the reflected comparison and
		`==` ultimately evaluates to `False`.
		"""
		if not isinstance(other, MazeDataset):
			return NotImplemented
		# TODO: compare hashes of data instead of the data itself?
		return self.cfg == other.cfg and self.mazes == other.mazes

	def assert_equal(self, other: "MazeDataset") -> None:
		"""assert that two datasets are equal, with a config diff in the failure message"""
		assert isinstance(other, MazeDataset)
		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

	@classmethod
	def generate(
		cls,
		cfg: MazeDatasetConfig,
		gen_parallel: bool = False,
		pool_kwargs: dict | None = None,
		verbose: bool = False,
		# TODO: what to do when unexpected kwargs are passed?
		**kwargs,  # noqa: ARG003
	) -> "MazeDataset":
		"""Generate a maze dataset given a config and some generation parameters

		# Parameters:
		- `cfg : MazeDatasetConfig`
			config to generate from. it is copied via a serialize/JSON/load round trip
			and is never modified
		- `gen_parallel : bool`
			generate in a `multiprocessing.Pool` instead of serially
			(defaults to `False`)
		- `pool_kwargs : dict | None`
			extra keyword arguments for `multiprocessing.Pool`
			(defaults to `None`, meaning none)
		- `verbose : bool`
			show a `tqdm` progress bar
			(defaults to `False`)
		- `**kwargs`
			accepted and ignored

		# Returns:
		- `MazeDataset`
			the mazes whose generated solution was valid. the returned config's
			`n_mazes` is set to the number actually kept

		numpy's global RNG is re-seeded with the config seed before returning.
		"""
		# Copy the config to avoid modifying the original
		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
			json.loads(json.dumps(cfg.serialize())),
		)

		if pool_kwargs is None:
			pool_kwargs = dict()
		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

		solved_mazes: list[SolvedMaze | None]
		# Configure tqdm for progress bar
		tqdm_kwargs: dict = dict(
			total=cfg_cpy.n_mazes,
			unit="maze",
			desc="generating & solving mazes",
			disable=not verbose,
		)
		# TODO: don't use the global unless generating in parallel!
		if gen_parallel:
			with multiprocessing.Pool(
				**pool_kwargs,
				initializer=_maze_gen_init_worker,
				initargs=(cfg_cpy,),
			) as pool:
				solved_mazes = list(
					tqdm.tqdm(
						pool.imap(_generate_maze_helper, maze_indexes),
						**tqdm_kwargs,
					),
				)

		else:
			_maze_gen_init_worker(cfg_cpy)
			solved_mazes = list(
				tqdm.tqdm(
					map(
						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
						# why does it think tolist() returns a string?
						_generate_maze_helper,  # type: ignore[arg-type]
						maze_indexes.tolist(),
					),
					**tqdm_kwargs,
				),
			)

		# Filter out None values explicitly after ensuring all results are collected
		solved_mazes_: list[SolvedMaze] = [
			maze for maze in solved_mazes if maze is not None
		]
		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

		# Update the config with the actual number of mazes
		cfg_cpy.n_mazes = len(solved_mazes_)

		dataset: MazeDataset = cls(
			cfg=cfg_cpy,
			mazes=solved_mazes_,
		)

		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

		return dataset

	@classmethod
	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
		"(not implemented yet!) download a maze dataset from the internet"
		raise NotImplementedError("not implemented yet")

	@classmethod
	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
		"""load from zanj/json

		dispatches on `data[_FORMAT_KEY]`:
		- `"MazeDataset:minimal"` -> `_load_minimal`
		- `"MazeDataset:minimal_soln_cat"` -> `_load_minimal_soln_cat`
		- `"MazeDataset"` -> `_load_full`, or `_load_legacy` when
		  `SERIALIZE_MINIMAL_THRESHOLD == -1` (for profiling)

		# Raises:
		- `KeyError` : if the format string is not one of the above
		"""
		if data[_FORMAT_KEY] == "MazeDataset:minimal":
			return cls._load_minimal(data)
		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
			return cls._load_minimal_soln_cat(data)
		elif data[_FORMAT_KEY] == "MazeDataset":
			if (
				SERIALIZE_MINIMAL_THRESHOLD == -1
			):  # Allow access to `_load_legacy` for profiling
				return cls._load_legacy(data)
			return cls._load_full(data)
		else:
			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
			raise KeyError(
				err_msg,
			)

	@classmethod
	def _load_full(cls, data: JSONdict) -> "MazeDataset":
		"""load the `"MazeDataset"` format, as written by `_serialize_full`"""
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			mazes=load_item_recursive(data["mazes"], tuple()),
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
		)

	@classmethod
	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
		"""load the `"MazeDataset:minimal"` format, as written by `_serialize_minimal`

		each padded row of `maze_solutions` is truncated to its stored solution length
		"""
		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
			mazes=[
				SolvedMaze(
					clist,
					soln[:slen, ...],
				)
				for clist, slen, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					load_item_recursive(data["maze_solution_lengths"], tuple()),
					load_item_recursive(data["maze_solutions"], tuple()),
					strict=False,
					# load_item_recursive(data["maze_endpoints"], tuple()),
				)
			],
		)

	@classmethod
	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
		"""load the `"MazeDataset:minimal_soln_cat"` format, as written by `_serialize_minimal_soln_cat`

		the concatenated solutions are split back apart at the cumulative solution
		lengths. the stored `maze_endpoints` are not read.
		"""
		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

		maze_solution_lengths = load_item_recursive(
			data["maze_solution_lengths"],
			tuple(),
		)
		maze_solutions_concat = load_item_recursive(
			data["maze_solutions_concat"],
			tuple(),
		)
		# split points are the running totals of the lengths, excluding the final total
		maze_solutions = np.split(
			maze_solutions_concat,
			np.cumsum(maze_solution_lengths)[:-1],
			axis=0,
		)

		return cls(
			cfg=load_item_recursive(data["cfg"], tuple()),
			generation_metadata_collected=load_item_recursive(
				data["generation_metadata_collected"],
				tuple(),
			),
			mazes=[
				SolvedMaze(
					connection_list=clist,
					solution=soln,
				)
				for clist, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					# load_item_recursive(data["maze_endpoints"], tuple()),
					maze_solutions,
					strict=False,
				)
			],
		)

	@classmethod
	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			**{
				key: load_item_recursive(data[key], tuple())
				for key in ["cfg", "mazes", "generation_metadata_collected"]
			},
		)

	def serialize(self) -> JSONdict:
		"""serialize to zanj/json

		uses `_serialize_minimal` when `SERIALIZE_MINIMAL_THRESHOLD` is not `None` and the
		dataset has at least that many mazes, otherwise `_serialize_full`
		"""
		if (
			SERIALIZE_MINIMAL_THRESHOLD is not None
			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
		):
			return self._serialize_minimal()
		return self._serialize_full()

	def _serialize_full(self) -> JSONdict:
		"""serialize every maze individually (format `"MazeDataset"`)"""
		return {
			_FORMAT_KEY: "MazeDataset",
			"cfg": json_serialize(self.cfg),
			"fname": self.cfg.to_fname(),
			"mazes": json_serialize(self.mazes),
			"generation_metadata_collected": json_serialize(
				self.generation_metadata_collected,
			),
		}

	@staticmethod
	def _coord_dtype(grid_n: int) -> type:
		"""smallest of `int8`/`int16`/`int32`/`int64` that holds coordinates in `[0, grid_n)`

		`int8` is chosen whenever it suffices (`grid_n <= 128`), matching the original
		layout; larger grids get a wider dtype instead of silently wrapping around.
		"""
		for dtype in (np.int8, np.int16, np.int32):
			if grid_n - 1 <= np.iinfo(dtype).max:
				return dtype
		return np.int64

	def _serialize_minimal(self) -> JSONdict:
		"""alternate serialization where metadata is collected and solutions are stored in one zero-padded array

		`maze_solutions` has shape `(n_mazes, max_solution_len, 2)`; row `i` holds the
		solution of maze `i` followed by zero padding, and `maze_solution_lengths[i]`
		gives its true length.
		"""
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		# `default=0` keeps an empty dataset serializable (bare `max` raises on empty input)
		max_solution_len: int = max(
			(m.solution.shape[0] for m in filtered_meta.mazes),
			default=0,
		)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n
		coord_dtype: type = self._coord_dtype(grid_n)

		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
		# zeros, not empty: the padding is written out, so it must be deterministic
		maze_solutions: np.ndarray = np.zeros(
			(n_mazes, max_solution_len, 2),
			dtype=coord_dtype,
		)

		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			maze_solution_lengths[idx] = maze.solution.shape[0]
			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

		return {
			_FORMAT_KEY: "MazeDataset:minimal",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			# "maze_endpoints": maze_endpoints,
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
		}

	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
		"""alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form

		all solutions are stacked into `maze_solutions_concat` of shape
		`(sum of solution lengths, 2)`, split again on load using `maze_solution_lengths`.
		"""
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		maze_solution_lengths: np.ndarray = np.array(
			[m.solution.shape[0] for m in filtered_meta.mazes],
			dtype=np.int32,
		)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n
		coord_dtype: type = self._coord_dtype(grid_n)
		total_solution_len: int = int(np.sum(maze_solution_lengths))

		# every row of these arrays is overwritten in the loop below, so `empty` is safe
		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=coord_dtype)
		maze_solutions_concat: np.ndarray = np.empty(
			(total_solution_len, 2),
			dtype=coord_dtype,
		)

		solutions_running_idx: int = 0
		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			soln_len: int = maze.solution.shape[0]
			maze_solutions_concat[
				solutions_running_idx : solutions_running_idx + soln_len
			] = maze.solution
			solutions_running_idx += soln_len

		return {
			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
		}

	def update_self_config(self) -> None:
		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)

		warns whenever `cfg.n_mazes` actually changes
		"""
		if self.cfg.n_mazes != len(self.mazes):
			warnings.warn(
				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
			)
			self.cfg.n_mazes = len(self.mazes)

	def custom_maze_filter(
		self,
		method: typing.Callable[..., bool],
		**kwargs,
	) -> "MazeDataset":
		"""filter the dataset using a custom method

		keeps each maze `m` for which `method(m, **kwargs)` is truthy. returns a new
		dataset with a deep-copied config, recording the filter (by `method.__name__`
		and `kwargs`) in `cfg.applied_filters`. the original dataset is not modified.
		"""
		output: MazeDataset = MazeDataset(
			cfg=copy.deepcopy(self.cfg),
			mazes=[m for m in self.mazes if method(m, **kwargs)],
		)
		output.cfg.applied_filters.append(
			{
				"name": f"__custom__:{method.__name__}",
				"kwargs": kwargs,
			},
		)
		output.update_self_config()
		return output
567
# Link the config class back to its dataset class. `MazeDatasetConfig` lives in a module
# that this one imports, so (presumably to avoid a circular import) the link is attached
# here as a read-only property that always yields `MazeDataset`.
MazeDatasetConfig._dataset_class = property(  # type: ignore[method-assign, assignment]
	fget=lambda _cfg: MazeDataset,
)
571
def _is_maze_dataset_json(json_item: object) -> bool:
	"""return whether `json_item` is a serialized `MazeDataset` that `MazeDataset.load` can handle

	Matches the format string `"MazeDataset"` exactly, or any `"MazeDataset:<variant>"`.
	A bare `startswith("MazeDataset")` would also claim other formats that merely share
	the prefix (for example one beginning with `"MazeDatasetConfig"`), which
	`MazeDataset.load` would then reject with a `KeyError`. Non-string format values are
	rejected instead of raising `AttributeError`.
	"""
	if not isinstance(json_item, typing.Mapping) or _FORMAT_KEY not in json_item:
		return False
	format_str = json_item[_FORMAT_KEY]
	return isinstance(format_str, str) and (
		format_str == "MazeDataset" or format_str.startswith("MazeDataset:")
	)


# register things with zanj
register_loader_handler(
	LoaderHandler(
		check=lambda json_item, path=None, z=None: _is_maze_dataset_json(json_item),  # type: ignore[misc] # noqa: ARG005
		load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),  # type: ignore[misc] # noqa: ARG005
		uid="MazeDataset",
		# module path of this file (`maze_dataset.dataset.maze_dataset`)
		source_pckg="maze_dataset.dataset.maze_dataset",
		desc="MazeDataset",
	),
)
586
587
# TODO: the sketch below is for doing some smarter collecting and type checking of
# generation metadata. Probably will delete.
# It is kept as comments rather than as a module-level string literal: it is never
# executed, and it references names (`new_dataset`, `Counter`, `StatCounter`) that are
# not defined or imported in this module.
#
# collect either the type at the field, or the shape of the field if it is an array
# metadata_types: dict[str, set[type, tuple]] = dict()
# for maze in new_dataset:
# 	for key, value in maze.generation_meta.items():
# 		if key not in metadata_types:
# 			metadata_types[key] = set()
#
# 		if isinstance(value, np.ndarray):
# 			metadata_types[key].add(value.shape)
# 		else:
# 			metadata_types[key].add(type(value))
#
# # figure out what to do for each field
# metadata_actions: dict[str, typing.Callable] = dict()
# for key, key_type in metadata_types.items():
# 	if all(isinstance(kt, tuple) for kt in key_type):
# 		if all(kt == (2,) for kt in key_type):
# 			# its all coords, do a statcounter on those coords
# 			metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
# 		elif all(
# 			(len(kt) == 2) and (kt[1] == 2)
# 			for kt in key_type
# 		):
# 			# its a list of coords, do a statcounter on those coords
# 			metadata_actions[key] = lambda vals: Counter(
# 				tuple(x) for x in np.concatenate(vals)
# 			)
# 		else:
# 			# its a list of something else, do a counter on those
# 			# TODO: throw except here?
# 			metadata_actions[key] = Counter
#
# 	elif all(kt in (bool, int, float) for kt in key_type):
# 		# statcounter for numeric types
# 		metadata_actions[key] = StatCounter
# 	elif all(kt == str for kt in key_type):
# 		# counter for string types
# 		metadata_actions[key] = Counter
# 	else:
# 		# counter for everything else
# 		# TODO: throw except here?
# 		metadata_actions[key] = Counter

113class MazeDataset(GPTDataset[MazeDatasetConfig]):
114	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""
115
116	def __init__(
117		self,
118		cfg: MazeDatasetConfig,
119		mazes: typing.Sequence[SolvedMaze],
120		generation_metadata_collected: dict | None = None,
121	) -> None:
122		"""initialize a maze dataset from a config and a list of solved mazes"""
123		super().__init__()
124		self.cfg: MazeDatasetConfig = cfg
125		self.mazes: list[SolvedMaze] = list(mazes)
126		self.generation_metadata_collected: dict | None = generation_metadata_collected
127
128	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
129	@classmethod
130	def from_config(  # type: ignore[override]
131		cls,
132		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
133		cfg: MazeDatasetConfig,  # type: ignore[override]
134		do_generate: bool = True,
135		load_local: bool = True,
136		save_local: bool = True,
137		zanj: ZANJ | None = None,
138		do_download: bool = True,
139		local_base_path: Path = Path("data/maze_dataset"),
140		except_on_config_mismatch: bool = True,
141		allow_generation_metadata_filter_mismatch: bool = True,
142		verbose: bool = False,
143		**kwargs,
144	) -> "MazeDataset":
145		"""create a maze dataset from a config
146
147		priority of loading:
148		1. load from local
149		2. download
150		3. generate
151
152		"""
153		return cast(
154			MazeDataset,
155			super().from_config(
156				cfg=cfg,
157				do_generate=do_generate,
158				load_local=load_local,
159				save_local=save_local,
160				zanj=zanj,
161				do_download=do_download,
162				local_base_path=local_base_path,
163				except_on_config_mismatch=except_on_config_mismatch,
164				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
165				verbose=verbose,
166				**kwargs,
167			),
168		)
169
170	def data_hash(self) -> int:
171		"""return a hash of the data"""
172		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
173
174	def __getitem__(self, i: int) -> SolvedMaze:
175		"""get a maze by index"""
176		return self.mazes[i]
177
178	def __iter__(self) -> typing.Iterator[SolvedMaze]:
179		"""iterate over the mazes"""
180		return iter(self.mazes)
181
182	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
183		"""deepcopy the dataset
184
185		FIX: this isnt actually a deepcopy I think?
186		"""
187		return MazeDataset.load(self._serialize_full())
188
189	# TYPING: get type hints on the tokenizer here
190	@overload
191	def as_tokens(
192		self,
193		maze_tokenizer,  # noqa: ANN001
194		limit: int | None = None,
195		join_tokens_individual_maze: Literal[False] = False,
196	) -> list[list[str]]: ...
197	@overload
198	def as_tokens(
199		self,
200		maze_tokenizer,  # noqa: ANN001
201		limit: int | None = None,
202		join_tokens_individual_maze: Literal[True] = True,
203	) -> list[str]: ...
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output
230
231	def __len__(self) -> int:
232		"""return the number of mazes in the dataset"""
233		return len(self.mazes)
234
235	def __eq__(self, other: object) -> bool:
236		"""compare two datasets"""
237		if not isinstance(other, MazeDataset):
238			raise NotImplementedError(
239				"can only compare with other MazeDataset objects",
240			)
241		# TODO: compare hashes of data instead of the data itself?
242		return self.cfg == other.cfg and self.mazes == other.mazes
243
244	def assert_equal(self, other: "MazeDataset") -> None:
245		"""assert that two datasets are equal"""
246		assert isinstance(other, MazeDataset)
247		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
248		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
249
250	@classmethod
251	def generate(
252		cls,
253		cfg: MazeDatasetConfig,
254		gen_parallel: bool = False,
255		pool_kwargs: dict | None = None,
256		verbose: bool = False,
257		# TODO: what to do when unexpected kwargs are passed?
258		**kwargs,  # noqa: ARG003
259	) -> "MazeDataset":
260		"""Generate a maze dataset given a config and some generation parameters"""
261		# Copy the config to avoid modifying the original
262		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
263			json.loads(json.dumps(cfg.serialize())),
264		)
265
266		if pool_kwargs is None:
267			pool_kwargs = dict()
268		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
269
270		solved_mazes: list[SolvedMaze | None]
271		# Configure tqdm for progress bar
272		tqdm_kwargs: dict = dict(
273			total=cfg_cpy.n_mazes,
274			unit="maze",
275			desc="generating & solving mazes",
276			disable=not verbose,
277		)
278		# TODO: don't use the global unless generating in parallel!
279		if gen_parallel:
280			with multiprocessing.Pool(
281				**pool_kwargs,
282				initializer=_maze_gen_init_worker,
283				initargs=(cfg_cpy,),
284			) as pool:
285				solved_mazes = list(
286					tqdm.tqdm(
287						pool.imap(_generate_maze_helper, maze_indexes),
288						**tqdm_kwargs,
289					),
290				)
291
292		else:
293			_maze_gen_init_worker(cfg_cpy)
294			solved_mazes = list(
295				tqdm.tqdm(
296					map(
297						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
298						# why does it think tolist() returns a string?
299						_generate_maze_helper,  # type: ignore[arg-type]
300						maze_indexes.tolist(),
301					),
302					**tqdm_kwargs,
303				),
304			)
305
306		# Filter out None values explicitly after ensuring all results are collected
307		solved_mazes_: list[SolvedMaze] = [
308			maze for maze in solved_mazes if maze is not None
309		]
310		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
311
312		# Update the config with the actual number of mazes
313		cfg_cpy.n_mazes = len(solved_mazes_)
314
315		dataset: MazeDataset = cls(
316			cfg=cfg_cpy,
317			mazes=solved_mazes_,
318		)
319
320		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
321
322		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
323
324		return dataset
325
326	@classmethod
327	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
328		"(not implemented yet!) download a maze dataset from the internet"
329		raise NotImplementedError("not implemented yet")
330
331	@classmethod
332	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
333		"""load from zanj/json"""
334		if data[_FORMAT_KEY] == "MazeDataset:minimal":
335			return cls._load_minimal(data)
336		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
337			return cls._load_minimal_soln_cat(data)
338		elif data[_FORMAT_KEY] == "MazeDataset":
339			if (
340				SERIALIZE_MINIMAL_THRESHOLD == -1
341			):  # Allow access to `_load_legacy` for profiling
342				return cls._load_legacy(data)
343			return cls._load_full(data)
344		else:
345			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
346			raise KeyError(
347				err_msg,
348			)
349
350	@classmethod
351	def _load_full(cls, data: JSONdict) -> "MazeDataset":
352		assert data[_FORMAT_KEY] == "MazeDataset"
353		return cls(
354			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
355			mazes=load_item_recursive(data["mazes"], tuple()),
356			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
357		)
358
359	@classmethod
360	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
361		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
362		return cls(
363			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
364			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
365			mazes=[
366				SolvedMaze(
367					clist,
368					soln[:slen, ...],
369				)
370				for clist, slen, soln in zip(
371					load_item_recursive(data["maze_connection_lists"], tuple()),
372					load_item_recursive(data["maze_solution_lengths"], tuple()),
373					load_item_recursive(data["maze_solutions"], tuple()),
374					strict=False,
375					# load_item_recursive(data["maze_endpoints"], tuple()),
376				)
377			],
378		)
379
@classmethod
def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
	"""Load the `"MazeDataset:minimal_soln_cat"` format written by `_serialize_minimal_soln_cat`.

	All solutions are stored back-to-back in `maze_solutions_concat`. The
	running totals of `maze_solution_lengths` give the split points that
	recover each maze's solution. The stored `maze_endpoints` are not read.
	"""
	assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

	lengths = load_item_recursive(data["maze_solution_lengths"], tuple())
	concatenated = load_item_recursive(data["maze_solutions_concat"], tuple())
	# cumulative sums minus the final total are exactly the boundaries between solutions
	split_points = np.cumsum(lengths)[:-1]
	per_maze_solutions = np.split(concatenated, split_points, axis=0)

	config = load_item_recursive(data["cfg"], tuple())
	collected_meta = load_item_recursive(data["generation_metadata_collected"], tuple())
	connection_lists = load_item_recursive(data["maze_connection_lists"], tuple())

	rebuilt: list = []
	for clist, soln in zip(connection_lists, per_maze_solutions, strict=False):
		rebuilt.append(SolvedMaze(connection_list=clist, solution=soln))

	return cls(
		cfg=config,
		generation_metadata_collected=collected_meta,
		mazes=rebuilt,
	)
417
@classmethod
def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
	"""Legacy `load` method from <0.5.2, retained only for profiling comparisons.

	Every constructor field is reconstructed via `load_item_recursive`.
	"""
	assert data[_FORMAT_KEY] == "MazeDataset"
	ctor_kwargs: dict = {}
	for field_name in ("cfg", "mazes", "generation_metadata_collected"):
		ctor_kwargs[field_name] = load_item_recursive(data[field_name], tuple())
	return cls(**ctor_kwargs)
428
def serialize(self) -> JSONdict:
	"""Serialize to zanj/json, choosing the format by dataset size.

	When `SERIALIZE_MINIMAL_THRESHOLD` is `None`, or the dataset holds fewer
	mazes than it, `_serialize_full` is used. Otherwise `_serialize_minimal`
	is used.
	"""
	threshold = SERIALIZE_MINIMAL_THRESHOLD
	if threshold is None or len(self) < threshold:
		return self._serialize_full()
	return self._serialize_minimal()
437
def _serialize_full(self) -> JSONdict:
	"""Serialize in the `"MazeDataset"` format: config, file name, every maze, and collected metadata."""
	cfg_json = json_serialize(self.cfg)
	fname = self.cfg.to_fname()
	mazes_json = json_serialize(self.mazes)
	meta_json = json_serialize(self.generation_metadata_collected)
	# key order is kept stable so the serialized layout does not change
	return {
		_FORMAT_KEY: "MazeDataset",
		"cfg": cfg_json,
		"fname": fname,
		"mazes": mazes_json,
		"generation_metadata_collected": meta_json,
	}
448
def _serialize_minimal(self) -> JSONdict:
	"""Alternate serialization where metadata is collected and mazes are stored as padded arrays.

	Produces the `"MazeDataset:minimal"` format read by `_load_minimal`:
	- `maze_connection_lists`: bool array, shape `(n_mazes, 2, grid_n, grid_n)`
	- `maze_solution_lengths`: int32 array, shape `(n_mazes,)`
	- `maze_solutions`: int8 array, shape `(n_mazes, max_solution_len, 2)`.
	  Rows are zero-padded past each maze's solution length.

	If `generation_metadata_collected` is `None`, the dataset returned by
	`self.filter_by.collect_generation_meta()` is serialized instead of `self`.
	An empty dataset serializes to zero-length arrays instead of raising.
	"""
	filtered_meta: MazeDataset
	if self.generation_metadata_collected is None:
		filtered_meta = self.filter_by.collect_generation_meta()
	else:
		filtered_meta = self

	# `default=0` keeps an empty dataset serializable (plain `max()` raises on empty input)
	max_solution_len: int = max(
		(m.solution.shape[0] for m in filtered_meta.mazes),
		default=0,
	)
	n_mazes: int = len(filtered_meta.mazes)
	grid_n: int = filtered_meta.cfg.grid_n

	# these two are fully overwritten in the loop below, so `np.empty` is safe
	maze_connection_lists: np.ndarray = np.empty(
		(n_mazes, 2, grid_n, grid_n),
		dtype=np.bool_,
	)
	maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
	# zero-filled rather than `np.empty`: the padding after each solution is never
	# overwritten, and uninitialized memory would make the saved file non-deterministic
	maze_solutions: np.ndarray = np.zeros(
		(n_mazes, max_solution_len, 2),
		dtype=np.int8,
	)

	for idx, maze in enumerate(filtered_meta.mazes):
		soln_len: int = maze.solution.shape[0]
		maze_connection_lists[idx] = maze.connection_list
		maze_solution_lengths[idx] = soln_len
		maze_solutions[idx, :soln_len] = maze.solution

	return {
		_FORMAT_KEY: "MazeDataset:minimal",
		"cfg": json_serialize(filtered_meta.cfg),
		"fname": filtered_meta.cfg.to_fname(),
		"generation_metadata_collected": json_serialize(
			filtered_meta.generation_metadata_collected,
		),
		"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
		"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
		"maze_solutions": maze_solutions,  # type: ignore[dict-item]
	}
490
def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
	"""Alternate serialization where metadata is collected and all solutions share one concatenated array.

	Produces the `"MazeDataset:minimal_soln_cat"` format read by
	`_load_minimal_soln_cat`:
	- connection lists and `[start_pos, end_pos]` endpoints, stacked per maze
	- every solution stored back-to-back in `maze_solutions_concat`; the
	  boundaries are recoverable from `maze_solution_lengths`

	NOTE(review): `maze_endpoints` is written here but not read back by
	`_load_minimal_soln_cat`.
	"""
	if self.generation_metadata_collected is None:
		source: MazeDataset = self.filter_by.collect_generation_meta()
	else:
		source = self

	lengths: np.ndarray = np.array(
		[m.solution.shape[0] for m in source.mazes],
		dtype=np.int32,
	)
	count: int = len(source.mazes)
	side: int = source.cfg.grid_n
	# offsets[i]:offsets[i + 1] is the slice holding maze i's solution
	offsets: np.ndarray = np.concatenate(([0], np.cumsum(lengths)))

	connections: np.ndarray = np.empty((count, 2, side, side), dtype=np.bool_)
	endpoints: np.ndarray = np.empty((count, 2, 2), dtype=np.int8)
	solutions_flat: np.ndarray = np.empty((int(offsets[-1]), 2), dtype=np.int8)

	for i, maze in enumerate(source.mazes):
		connections[i] = maze.connection_list
		endpoints[i] = np.array([maze.start_pos, maze.end_pos])
		solutions_flat[offsets[i] : offsets[i + 1]] = maze.solution

	return {
		_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
		"cfg": json_serialize(source.cfg),
		"fname": source.cfg.to_fname(),
		"generation_metadata_collected": json_serialize(
			source.generation_metadata_collected,
		),
		"maze_connection_lists": connections,  # type: ignore[dict-item]
		"maze_endpoints": endpoints,  # type: ignore[dict-item]
		"maze_solution_lengths": lengths,  # type: ignore[dict-item]
		"maze_solutions_concat": solutions_flat,  # type: ignore[dict-item]
	}
540
def update_self_config(self) -> None:
	"""Sync `cfg.n_mazes` with the number of mazes actually held (e.g. after filtering).

	Emits a warning, then updates the config, only when the two differ.
	"""
	actual_count: int = len(self.mazes)
	if self.cfg.n_mazes == actual_count:
		return
	warnings.warn(
		f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
	)
	self.cfg.n_mazes = actual_count
548
def custom_maze_filter(
	self,
	method: typing.Callable[..., bool],
	**kwargs,
) -> "MazeDataset":
	"""Filter the dataset using a custom predicate.

	`method` is called as `method(maze, **kwargs)` for every maze. Mazes for
	which it returns a truthy value are kept.

	The returned dataset has:
	- a deep copy of the config, with `{"name": "__custom__:<name>", "kwargs": kwargs}`
	  appended to `cfg.applied_filters`
	- `n_mazes` updated to the filtered count

	`<name>` is `method.__name__` when it exists. Otherwise (e.g.
	`functools.partial` objects, or instances defining `__call__`) it is the
	name of `method`'s type, rather than raising `AttributeError`.

	NOTE(review): `generation_metadata_collected` is not propagated to the
	returned dataset (it falls back to `None`); confirm this is intended.
	"""
	# partials and callable instances have no `__name__`; fall back to the type name
	method_name: str = getattr(method, "__name__", type(method).__name__)
	output: MazeDataset = MazeDataset(
		cfg=copy.deepcopy(self.cfg),
		mazes=[m for m in self.mazes if method(m, **kwargs)],
	)
	output.cfg.applied_filters.append(
		{
			"name": f"__custom__:{method_name}",
			"kwargs": kwargs,
		},
	)
	output.update_self_config()
	return output

A maze dataset class: a collection of solved mazes. It should be initialized via MazeDataset.from_config.

MazeDataset( cfg: maze_dataset.MazeDatasetConfig, mazes: Sequence[maze_dataset.SolvedMaze], generation_metadata_collected: dict | None = None)
def __init__(
	self,
	cfg: MazeDatasetConfig,
	mazes: typing.Sequence[SolvedMaze],
	generation_metadata_collected: dict | None = None,
) -> None:
	"""Build a dataset that holds `mazes` and is described by `cfg`.

	`mazes` is shallow-copied into a new list, so later changes to the
	caller's sequence do not affect the dataset. `generation_metadata_collected`
	is stored as given (default `None`).
	"""
	super().__init__()
	self.cfg: MazeDatasetConfig = cfg
	self.generation_metadata_collected: dict | None = generation_metadata_collected
	self.mazes: list[SolvedMaze] = [*mazes]

initialize a maze dataset from a config and a list of solved mazes

generation_metadata_collected: dict | None
@classmethod
def from_config( cls, cfg: maze_dataset.MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> MazeDataset:
@classmethod
def from_config(  # type: ignore[override]
	cls,
	# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
	cfg: MazeDatasetConfig,  # type: ignore[override]
	do_generate: bool = True,
	load_local: bool = True,
	save_local: bool = True,
	zanj: ZANJ | None = None,
	do_download: bool = True,
	local_base_path: Path = Path("data/maze_dataset"),
	except_on_config_mismatch: bool = True,
	allow_generation_metadata_filter_mismatch: bool = True,
	verbose: bool = False,
	**kwargs,
) -> "MazeDataset":
	"""Obtain a maze dataset matching `cfg`.

	Forwards every argument unchanged to the parent `from_config` and only
	narrows the return type to `MazeDataset`. Sources are tried in this order:
	1. load from local
	2. download
	3. generate
	"""
	forwarded: dict = dict(
		do_generate=do_generate,
		load_local=load_local,
		save_local=save_local,
		zanj=zanj,
		do_download=do_download,
		local_base_path=local_base_path,
		except_on_config_mismatch=except_on_config_mismatch,
		allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
		verbose=verbose,
	)
	result = super().from_config(cfg=cfg, **forwarded, **kwargs)
	return cast(MazeDataset, result)

create a maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
def data_hash(self) -> int:
def data_hash(self) -> int:
	"""Hash the dataset's mazes (the config is not included) using `stable_hash`."""
	serialized_mazes = tuple(maze.serialize() for maze in self.mazes)
	return stable_hash(str(serialized_mazes))

return a hash of the data

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output

return the dataset as tokens according to the passed maze_tokenizer

the maze_tokenizer should be either a MazeTokenizer or a MazeTokenizerModular

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

    >>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=True)
    ["a b c", "d e f"]
def assert_equal(self, other: MazeDataset) -> None:
def assert_equal(self, other: "MazeDataset") -> None:
	"""Fail with `AssertionError` unless `other` is a `MazeDataset` with an equal config and equal mazes.

	A config mismatch reports `self.cfg.diff(other.cfg)`. A maze mismatch
	reports both maze lists. The checks are plain `assert`s, so they are
	skipped under `python -O`.
	"""
	assert isinstance(other, MazeDataset)
	assert self.cfg == other.cfg, (
		f"{self.cfg.diff(other.cfg) = }"
	)
	assert self.mazes == other.mazes, (
		f"{self.mazes = }, {other.mazes = }"
	)

assert that two datasets are equal

@classmethod
def generate( cls, cfg: maze_dataset.MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, **kwargs) -> MazeDataset:
@classmethod
def generate(
	cls,
	cfg: MazeDatasetConfig,
	gen_parallel: bool = False,
	pool_kwargs: dict | None = None,
	verbose: bool = False,
	# TODO: what to do when unexpected kwargs are passed?
	**kwargs,  # noqa: ARG003
) -> "MazeDataset":
	"""Generate a maze dataset from `cfg`.

	# Parameters:
	- `cfg`: config to generate from. It is not mutated; a copy is made by
	  round-tripping it through JSON.
	- `gen_parallel`: generate with a `multiprocessing.Pool` if True, otherwise
	  serially in this process
	- `pool_kwargs`: extra keyword arguments for `multiprocessing.Pool`
	- `verbose`: show a tqdm progress bar
	- `**kwargs`: accepted and ignored

	Helper results that are `None` are dropped, and the copied config's
	`n_mazes` is set to the number of mazes actually produced. Before
	returning, the global numpy RNG is re-seeded with the copied config's `seed`.
	"""
	# JSON round-trip yields an independent copy, so the caller's config is untouched
	cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
		json.loads(json.dumps(cfg.serialize())),
	)
	pool_kwargs = dict() if pool_kwargs is None else pool_kwargs
	maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
	progress_kwargs: dict = dict(
		total=cfg_cpy.n_mazes,
		unit="maze",
		desc="generating & solving mazes",
		disable=not verbose,
	)

	results: list[SolvedMaze | None]
	# TODO: don't use the global unless generating in parallel!
	if not gen_parallel:
		# serial path: run the worker initializer in this process
		# (presumably it sets the module-level `_GLOBAL_WORKER_CONFIG` read by `_generate_maze_helper`)
		_maze_gen_init_worker(cfg_cpy)
		results = list(
			tqdm.tqdm(
				# TYPING: mypy infers `tolist()` as yielding `str` here, hence the ignore
				map(_generate_maze_helper, maze_indexes.tolist()),  # type: ignore[arg-type]
				**progress_kwargs,
			),
		)
	else:
		with multiprocessing.Pool(
			**pool_kwargs,
			initializer=_maze_gen_init_worker,
			initargs=(cfg_cpy,),
		) as pool:
			results = list(
				tqdm.tqdm(
					pool.imap(_generate_maze_helper, maze_indexes),
					**progress_kwargs,
				),
			)

	solved: list[SolvedMaze] = [m for m in results if m is not None]
	cfg_cpy.n_mazes = len(solved)

	dataset: MazeDataset = cls(cfg=cfg_cpy, mazes=solved)
	# `n_mazes` was synced just above; this keeps the config consistent with the data regardless
	dataset.update_self_config()

	np.random.seed(cfg_cpy.seed)
	return dataset

Generate a maze dataset given a config and some generation parameters

@classmethod
def download( cls, cfg: maze_dataset.MazeDatasetConfig, **kwargs) -> MazeDataset:
@classmethod
def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
	"""Download a maze dataset from the internet (not implemented yet).

	Always raises `NotImplementedError`.
	"""
	message = "not implemented yet"
	raise NotImplementedError(message)

(not implemented yet!) download a maze dataset from the internet

@classmethod
def load( cls: type[MazeDataset], data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDataset:
@classmethod
def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
	"""Reconstruct a `MazeDataset` from its zanj/json representation.

	Dispatches on the string stored under `_FORMAT_KEY`:
	- `"MazeDataset:minimal"` -> `_load_minimal` (padded solution arrays)
	- `"MazeDataset:minimal_soln_cat"` -> `_load_minimal_soln_cat` (concatenated solutions)
	- `"MazeDataset"` -> `_load_full`, or `_load_legacy` when
	  `SERIALIZE_MINIMAL_THRESHOLD == -1` (a switch kept for profiling)

	# Raises:
	- `KeyError`: if the format string is not one of the above
	"""
	fmt = data[_FORMAT_KEY]
	if fmt == "MazeDataset:minimal":
		return cls._load_minimal(data)
	if fmt == "MazeDataset:minimal_soln_cat":
		return cls._load_minimal_soln_cat(data)
	if fmt == "MazeDataset":
		# a threshold of -1 exposes the legacy loader so it can be profiled
		use_legacy: bool = SERIALIZE_MINIMAL_THRESHOLD == -1
		return cls._load_legacy(data) if use_legacy else cls._load_full(data)

	err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
	raise KeyError(err_msg)

load from zanj/json

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
def serialize(self) -> JSONdict:
	"""Serialize to zanj/json, choosing the format by dataset size.

	When `SERIALIZE_MINIMAL_THRESHOLD` is `None`, or the dataset holds fewer
	mazes than it, `_serialize_full` is used. Otherwise `_serialize_minimal`
	is used.
	"""
	threshold = SERIALIZE_MINIMAL_THRESHOLD
	if threshold is None or len(self) < threshold:
		return self._serialize_full()
	return self._serialize_minimal()

serialize to zanj/json

def update_self_config(self) -> None:
def update_self_config(self) -> None:
	"""Sync `cfg.n_mazes` with the number of mazes actually held (e.g. after filtering).

	Emits a warning, then updates the config, only when the two differ.
	"""
	actual_count: int = len(self.mazes)
	if self.cfg.n_mazes == actual_count:
		return
	warnings.warn(
		f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
	)
	self.cfg.n_mazes = actual_count

update the config to match the current state of the dataset (number of mazes, such as after filtering)

def custom_maze_filter( self, method: Callable[[maze_dataset.SolvedMaze], bool], **kwargs) -> MazeDataset:
def custom_maze_filter(
	self,
	method: typing.Callable[..., bool],
	**kwargs,
) -> "MazeDataset":
	"""Filter the dataset using a custom predicate.

	`method` is called as `method(maze, **kwargs)` for every maze. Mazes for
	which it returns a truthy value are kept.

	The returned dataset has:
	- a deep copy of the config, with `{"name": "__custom__:<name>", "kwargs": kwargs}`
	  appended to `cfg.applied_filters`
	- `n_mazes` updated to the filtered count

	`<name>` is `method.__name__` when it exists. Otherwise (e.g.
	`functools.partial` objects, or instances defining `__call__`) it is the
	name of `method`'s type, rather than raising `AttributeError`.

	NOTE(review): `generation_metadata_collected` is not propagated to the
	returned dataset (it falls back to `None`); confirm this is intended.
	"""
	# partials and callable instances have no `__name__`; fall back to the type name
	method_name: str = getattr(method, "__name__", type(method).__name__)
	output: MazeDataset = MazeDataset(
		cfg=copy.deepcopy(self.cfg),
		mazes=[m for m in self.mazes if method(m, **kwargs)],
	)
	output.cfg.applied_filters.append(
		{
			"name": f"__custom__:{method_name}",
			"kwargs": kwargs,
		},
	)
	output.update_self_config()
	return output

filter the dataset using a custom method