docs for maze-dataset v1.4.0

maze_dataset.dataset.maze_dataset

MazeDatasetConfig is where you decide what your dataset should look like, then pass it to MazeDataset.from_config to generate or load the dataset.

see demo_dataset notebook
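For example, a minimal sketch of the workflow (the `MazeDatasetConfig` fields and the `LatticeMazeGenerators.gen_dfs` generator shown here come from the package's top-level exports, not from this module):

    from maze_dataset import MazeDataset, MazeDatasetConfig
    from maze_dataset.generation import LatticeMazeGenerators

    # describe the dataset: 100 mazes on a 5x5 grid, generated with depth-first search
    cfg = MazeDatasetConfig(
        name="example",
        grid_n=5,
        n_mazes=100,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )

    # load a matching dataset from data/maze_dataset if one exists,
    # otherwise generate it and save it locally
    dataset = MazeDataset.from_config(cfg)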


  1"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset.
  2
  3see [demo_dataset notebook](../../notebooks/demo_dataset)
  4
  5"""
  6
  7import copy
  8import json
  9import multiprocessing
 10import typing
 11import warnings
 12from pathlib import Path
 13from typing import Literal, cast, overload
 14
 15import numpy as np
 16import tqdm
 17from jaxtyping import Int
 18from muutils.json_serialize import (
 19	json_serialize,
 20)
 21from muutils.json_serialize.util import (
 22	_FORMAT_KEY,
 23	JSONdict,
 24)
 25from muutils.misc import stable_hash
 26from zanj import ZANJ
 27from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler
 28
 29from maze_dataset.constants import CoordArray
 30from maze_dataset.dataset.dataset import (
 31	GPTDataset,
 32)
 33from maze_dataset.dataset.maze_dataset_config import (
 34	SERIALIZE_MINIMAL_THRESHOLD,
 35	EndpointKwargsType,
 36	MazeDatasetConfig,
 37)
 38from maze_dataset.generation.seed import GLOBAL_SEED
 39from maze_dataset.maze import LatticeMaze, SolvedMaze
 40
 41_GLOBAL_WORKER_CONFIG: MazeDatasetConfig
 42
 43
 44def _generate_maze_helper(index: int) -> SolvedMaze | None:  # noqa: ARG001
 45	"""Helper function for generating mazes in parallel.
 46
 47	> [!CAUTION]
 48	> don't use this unless generating in parallel!
 49	"""
 50	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0602
 51	# TODO: don't use this unless generating in parallel!
 52	maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
 53		grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
 54		**_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
 55	)
 56
 57	endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()
 58
 59	# Generate the solution
 60	# mypy doesnt realize EndpointKwargsType has only string keys: `Keywords must be strings  [misc]`
 61	# TYPING: error: No overload variant of "generate_random_path" of "LatticeMaze" matches argument type "dict[Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | list[tuple[int, int]] | None]"  [call-overload]
 62	solution: CoordArray | None = maze.generate_random_path(**endpoint_kwargs)  # type: ignore[misc, call-overload]
 63
 64	# Validate the solution
 65	if (
 66		solution is None
 67		or len(solution) == 0
 68		or not isinstance(solution, np.ndarray)
 69		# magic value is fine here
 70		or len(solution.shape) != 2  # noqa: PLR2004
 71	):
 72		return None  # Return None if the solution is invalid
 73
 74	return SolvedMaze.from_lattice_maze(
 75		lattice_maze=maze,
 76		solution=solution,
 77	)


def _maze_gen_init_worker(config: MazeDatasetConfig) -> None:
	"""special worker helper

	> [!CAUTION]
	> this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad!

	"""
	# TODO: don't use globals here!
	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0603
	_GLOBAL_WORKER_CONFIG = config

	process_id: tuple[int, ...] = multiprocessing.current_process()._identity
	if len(process_id) == 0:
		# no multiprocessing, seed was already set
		pass
	elif len(process_id) == 1:
		# multiprocessing, adjust seed based on process id
		# only set numpy seed, since we do not use other random gens
		# fall back to GLOBAL_SEED if the seed is None, then offset by the
		# process id so that each worker process seeds numpy differently
		np.random.seed(
			(_GLOBAL_WORKER_CONFIG.seed or GLOBAL_SEED) + process_id[0],
		)
	else:
		err_msg = (
			f"unexpected process id: {process_id = }\n{multiprocessing.Process() = }"
		)
		raise ValueError(
			err_msg,
		)


# TODO: we probably don't need to hash datasets, right?
class MazeDataset(GPTDataset[MazeDatasetConfig]):  # noqa: PLW1641
	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

	def __init__(
		self,
		cfg: MazeDatasetConfig,
		mazes: typing.Sequence[SolvedMaze],
		generation_metadata_collected: dict | None = None,
	) -> None:
		"""initialize a maze dataset from a config and a list of solved mazes"""
		super().__init__()
		self.cfg: MazeDatasetConfig = cfg
		self.mazes: list[SolvedMaze] = list(mazes)
		self.generation_metadata_collected: dict | None = generation_metadata_collected

	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
	@classmethod
	def from_config(  # type: ignore[override]
		cls,
		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
		cfg: MazeDatasetConfig,  # type: ignore[override]
		do_generate: bool = True,
		load_local: bool = True,
		save_local: bool = True,
		zanj: ZANJ | None = None,
		do_download: bool = True,
		local_base_path: Path = Path("data/maze_dataset"),
		except_on_config_mismatch: bool = True,
		allow_generation_metadata_filter_mismatch: bool = True,
		verbose: bool = False,
		**kwargs,
	) -> "MazeDataset":
		"""create a maze dataset from a config

		priority of loading:
		1. load from local
		2. download
		3. generate

		"""
		return cast(
			"MazeDataset",
			super().from_config(
				cfg=cfg,
				do_generate=do_generate,
				load_local=load_local,
				save_local=save_local,
				zanj=zanj,
				do_download=do_download,
				local_base_path=local_base_path,
				except_on_config_mismatch=except_on_config_mismatch,
				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
				verbose=verbose,
				**kwargs,
			),
		)

	def data_hash(self) -> int:
		"""return a hash of the data"""
		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

	def __getitem__(self, i: int) -> SolvedMaze:
		"""get a maze by index"""
		return self.mazes[i]

	def __iter__(self) -> typing.Iterator[SolvedMaze]:
		"""iterate over the mazes"""
		return iter(self.mazes)

	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
		"""deepcopy the dataset

		FIX: this isn't actually a deepcopy, I think?
		"""
		return MazeDataset.load(self._serialize_full())
	# TYPING: get type hints on the tokenizer here
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[False] = False,
	) -> list[list[str]]: ...
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[True] = True,
	) -> list[str]: ...
	def as_tokens(
		self,
		maze_tokenizer,  # TODO: MazeTokenizer
		limit: int | None = None,
		join_tokens_individual_maze: bool = False,
	) -> list[list[str]] | list[str]:
		"""return the dataset as tokens according to the passed `maze_tokenizer`

		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

		if `join_tokens_individual_maze` is True, then the tokens of each maze are
		joined with a space, and the result is a list of strings.
		i.e.:

			>>> dataset.as_tokens(join_tokens_individual_maze=False)
			[["a", "b", "c"], ["d", "e", "f"]]
			>>> dataset.as_tokens(join_tokens_individual_maze=True)
			["a b c", "d e f"]
		"""
		output: list[list[str]] = [
			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
		]
		if join_tokens_individual_maze:
			return [" ".join(tokens) for tokens in output]
		else:
			return output

	def __len__(self) -> int:
		"""return the number of mazes in the dataset"""
		return len(self.mazes)

	def __eq__(self, other: object) -> bool:
		"""compare two datasets"""
		if not isinstance(other, MazeDataset):
			raise NotImplementedError(
				"can only compare with other MazeDataset objects",
			)
		# TODO: compare hashes of data instead of the data itself?
		return self.cfg == other.cfg and self.mazes == other.mazes

	def assert_equal(self, other: "MazeDataset") -> None:
		"""assert that two datasets are equal"""
		assert isinstance(other, MazeDataset)
		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

	@classmethod
	def generate(
		cls,
		cfg: MazeDatasetConfig,
		gen_parallel: bool = False,
		pool_kwargs: dict | None = None,
		verbose: bool = False,
		# TODO: what to do when unexpected kwargs are passed?
		**kwargs,  # noqa: ARG003
	) -> "MazeDataset":
		"""Generate a maze dataset given a config and some generation parameters"""
		# Copy the config to avoid modifying the original
		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
			json.loads(json.dumps(cfg.serialize())),
		)

		if pool_kwargs is None:
			pool_kwargs = dict()
		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

		solved_mazes: list[SolvedMaze | None]
		# Configure tqdm for progress bar
		tqdm_kwargs: dict = dict(
			total=cfg_cpy.n_mazes,
			unit="maze",
			desc="generating & solving mazes",
			disable=not verbose,
		)
		# TODO: don't use the global unless generating in parallel!
		if gen_parallel:
			with multiprocessing.Pool(
				**pool_kwargs,
				initializer=_maze_gen_init_worker,
				initargs=(cfg_cpy,),
			) as pool:
				solved_mazes = list(
					tqdm.tqdm(
						pool.imap(_generate_maze_helper, maze_indexes),
						**tqdm_kwargs,
					),
				)

		else:
			_maze_gen_init_worker(cfg_cpy)
			solved_mazes = list(
				tqdm.tqdm(
					map(
						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
						# why does it think tolist() returns a string?
						_generate_maze_helper,  # type: ignore[arg-type]
						maze_indexes.tolist(),
					),
					**tqdm_kwargs,
				),
			)

		# Filter out None values explicitly after ensuring all results are collected
		solved_mazes_: list[SolvedMaze] = [
			maze for maze in solved_mazes if maze is not None
		]
		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

		# Update the config with the actual number of mazes
		cfg_cpy.n_mazes = len(solved_mazes_)

		dataset: MazeDataset = cls(
			cfg=cfg_cpy,
			mazes=solved_mazes_,
		)

		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

		return dataset

	@classmethod
	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
		"(not implemented yet!) download a maze dataset from the internet"
		raise NotImplementedError("not implemented yet")

	@classmethod
	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
		"""load from zanj/json"""
		if data[_FORMAT_KEY] == "MazeDataset:minimal":
			return cls._load_minimal(data)
		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
			return cls._load_minimal_soln_cat(data)
		elif data[_FORMAT_KEY] == "MazeDataset":
			if (
				SERIALIZE_MINIMAL_THRESHOLD == -1
			):  # Allow access to `_load_legacy` for profiling
				return cls._load_legacy(data)
			return cls._load_full(data)
		else:
			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
			raise KeyError(
				err_msg,
			)

	@classmethod
	def _load_full(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			mazes=load_item_recursive(data["mazes"], tuple()),
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
		)

	@classmethod
	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
			mazes=[
				SolvedMaze(
					clist,
					soln[:slen, ...],
				)
				for clist, slen, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					load_item_recursive(data["maze_solution_lengths"], tuple()),
					load_item_recursive(data["maze_solutions"], tuple()),
					strict=False,
					# load_item_recursive(data["maze_endpoints"], tuple()),
				)
			],
		)

	@classmethod
	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

		maze_solution_lengths = load_item_recursive(
			data["maze_solution_lengths"],
			tuple(),
		)
		maze_solutions_concat = load_item_recursive(
			data["maze_solutions_concat"],
			tuple(),
		)
		maze_solutions = np.split(
			maze_solutions_concat,
			np.cumsum(maze_solution_lengths)[:-1],
			axis=0,
		)

		return cls(
			cfg=load_item_recursive(data["cfg"], tuple()),
			generation_metadata_collected=load_item_recursive(
				data["generation_metadata_collected"],
				tuple(),
			),
			mazes=[
				SolvedMaze(
					connection_list=clist,
					solution=soln,
				)
				for clist, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					# load_item_recursive(data["maze_endpoints"], tuple()),
					maze_solutions,
					strict=False,
				)
			],
		)

	@classmethod
	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			**{
				key: load_item_recursive(data[key], tuple())
				for key in ["cfg", "mazes", "generation_metadata_collected"]
			},
		)

	def serialize(self) -> JSONdict:
		"""serialize to zanj/json"""
		if (
			SERIALIZE_MINIMAL_THRESHOLD is not None
			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
		):
			return self._serialize_minimal()
		return self._serialize_full()

	def _serialize_full(self) -> JSONdict:
		return {
			_FORMAT_KEY: "MazeDataset",
			"cfg": json_serialize(self.cfg),
			"fname": self.cfg.to_fname(),
			"mazes": json_serialize(self.mazes),
			"generation_metadata_collected": json_serialize(
				self.generation_metadata_collected,
			),
		}

	def _serialize_minimal(self) -> JSONdict:
		"alternate serialization where metadata is collected and mazes are stored in concatenated form"
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n

		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
		maze_solutions: np.ndarray = np.empty(
			(n_mazes, max_solution_len, 2),
			dtype=np.int8,
		)

		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			maze_solution_lengths[idx] = maze.solution.shape[0]
			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

		return {
			_FORMAT_KEY: "MazeDataset:minimal",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			# "maze_endpoints": maze_endpoints,
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
		}

	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
		"alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		maze_solution_lengths: np.ndarray = np.array(
			[m.solution.shape[0] for m in filtered_meta.mazes],
			dtype=np.int32,
		)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n
		total_solution_len: int = np.sum(maze_solution_lengths)

		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
		maze_solutions_concat: np.ndarray = np.empty(
			(total_solution_len, 2),
			dtype=np.int8,
		)

		solutions_running_idx: int = 0
		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			soln_len: int = maze.solution.shape[0]
			maze_solution_lengths[idx] = soln_len
			maze_solutions_concat[
				solutions_running_idx : solutions_running_idx + soln_len
			] = maze.solution
			solutions_running_idx += soln_len

		return {
			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
		}

	def update_self_config(self) -> None:
		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
		if self.cfg.n_mazes != len(self.mazes):
			warnings.warn(
				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
			)
			self.cfg.n_mazes = len(self.mazes)

	def custom_maze_filter(
		self,
		method: typing.Callable[[SolvedMaze], bool],
		**kwargs,
	) -> "MazeDataset":
		"""filter the dataset using a custom method"""
		output: MazeDataset = MazeDataset(
			cfg=copy.deepcopy(self.cfg),
			mazes=[m for m in self.mazes if method(m, **kwargs)],
		)
		output.cfg.applied_filters.append(
			{
				"name": f"__custom__:{method.__name__}",
				"kwargs": kwargs,
			},
		)
		output.update_self_config()
		return output

MazeDatasetConfig._dataset_class = property(  # type: ignore[method-assign, assignment]
	lambda self: MazeDataset,  # noqa: ARG005
)

# register things with zanj
register_loader_handler(
	LoaderHandler(
		check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
			isinstance(json_item, typing.Mapping)
			and _FORMAT_KEY in json_item
			and json_item[_FORMAT_KEY].startswith("MazeDataset")
		),
		load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),  # type: ignore[misc] # noqa: ARG005
		uid="MazeDataset",
		source_pckg="maze_dataset.generation.maze_dataset",
		desc="MazeDataset",
	),
)


# TODO: the code below is for doing some smarter collecting and type checking. Probably will delete.
"""
collect either the type at the field, or the shape of the field if it is an array
metadata_types: dict[str, set[type, tuple]] = dict()
for maze in new_dataset:
	for key, value in maze.generation_meta.items():
		if key not in metadata_types:
			metadata_types[key] = set()

		if isinstance(value, np.ndarray):
			metadata_types[key].add(value.shape)
		else:
			metadata_types[key].add(type(value))

# figure out what to do for each field
metadata_actions: dict[str, typing.Callable] = dict()
for key, key_type in metadata_types.items():
	if all(isinstance(kt, tuple) for kt in key_type):
		if all(kt == (2,) for kt in key_type):
			# its all coords, do a statcounter on those coords
			metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
		elif all(
			(len(kt) == 2) and (kt[1] == 2)
			for kt in key_type
		):
			# its a list of coords, do a statcounter on those coords
			metadata_actions[key] = lambda vals: Counter(
				tuple(x) for x in np.concatenate(vals)
			)
		else:
			# its a list of something else, do a counter on those
			# TODO: throw except here?
			metadata_actions[key] = Counter

	elif all(kt in (bool, int, float) for kt in key_type):
		# statcounter for numeric types
		metadata_actions[key] = StatCounter
	elif all(kt == str for kt in key_type):
		# counter for string types
		metadata_actions[key] = Counter
	else:
		# counter for everything else
		# TODO: throw except here?
		metadata_actions[key] = Counter
"""

class MazeDataset(GPTDataset[MazeDatasetConfig]):

a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config

MazeDataset( cfg: maze_dataset.MazeDatasetConfig, mazes: Sequence[maze_dataset.SolvedMaze], generation_metadata_collected: dict | None = None)

initialize a maze dataset from a config and a list of solved mazes

generation_metadata_collected: dict | None
@classmethod
def from_config( cls, cfg: maze_dataset.MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = Path('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> MazeDataset:

create a maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
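For instance, to force fresh generation while skipping both local files and the (unimplemented) download step, a sketch using only the flags from the signature above:

    dataset = MazeDataset.from_config(
        cfg,
        load_local=False,   # skip 1: ignore any existing local dataset
        do_download=False,  # skip 2: do not attempt a download
        do_generate=True,   # 3: generate from scratch
        save_local=False,   # do not write the result to local_base_path
    )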
def data_hash(self) -> int:

return a hash of the data

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:

return the dataset as tokens according to the passed maze_tokenizer

the maze_tokenizer should be either a MazeTokenizer or a MazeTokenizerModular

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
def assert_equal(self, other: MazeDataset) -> None:

assert that two datasets are equal

@classmethod
def generate( cls, cfg: maze_dataset.MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, **kwargs) -> MazeDataset:

Generate a maze dataset given a config and some generation parameters
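For example, a sketch of parallel generation (`processes` is a standard `multiprocessing.Pool` argument, forwarded via `pool_kwargs`; note the caution on `_maze_gen_init_worker` that results can depend on whether parallelism is used and on the process count):

    dataset = MazeDataset.generate(
        cfg,
        gen_parallel=True,
        pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
        verbose=True,                   # show the tqdm progress bar
    )

Mazes whose endpoint selection fails are dropped, so the result may contain fewer than `cfg.n_mazes` mazes; the config's `n_mazes` is updated to the actual count.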

@classmethod
def download( cls, cfg: maze_dataset.MazeDatasetConfig, **kwargs) -> MazeDataset:

(not implemented yet!) download a maze dataset from the internet

@classmethod
def load( cls: type[MazeDataset], data: JSONdict) -> MazeDataset:

load from zanj/json

def serialize( self) -> JSONdict:

serialize to zanj/json
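`load` accepts exactly what `serialize` produces, so a round trip can serve as a smoke test. Note that for datasets at or above `SERIALIZE_MINIMAL_THRESHOLD` the minimal format first collects generation metadata, which can modify the config; the exact-match check below is a sketch that assumes the full format:

    data = dataset.serialize()         # JSONdict
    restored = MazeDataset.load(data)
    dataset.assert_equal(restored)     # raises AssertionError on mismatch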

def update_self_config(self) -> None:

update the config to match the current state of the dataset (number of mazes, such as after filtering)
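For example, truncating the maze list by hand (a hypothetical mutation, purely to illustrate the warning) leaves `cfg.n_mazes` stale until this is called:

    dataset.mazes = dataset.mazes[:10]
    dataset.update_self_config()  # warns, then sets cfg.n_mazes = 10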

def custom_maze_filter( self, method: Callable[[maze_dataset.SolvedMaze], bool], **kwargs) -> MazeDataset:

filter the dataset using a custom method
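A sketch of a custom filter; `long_solution` and its `min_len` kwarg are illustrative names, with any extra kwargs forwarded to the predicate for every maze:

    def long_solution(maze: SolvedMaze, min_len: int = 5) -> bool:
        # keep only mazes whose solution path has at least min_len coordinates
        return len(maze.solution) >= min_len

    filtered = dataset.custom_maze_filter(long_solution, min_len=7)
    # the filter is recorded in the new config's applied_filters as
    # {"name": "__custom__:long_solution", "kwargs": {"min_len": 7}}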