docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset

MazeDatasetConfigs are used to create a MazeDataset via MazeDataset.from_config(cfg)

When initializing mazes, further configuration options can be specified through the from_config() factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the LatticeMazeGenerators class.
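
For example, a small dataset of 10x10 depth-first-search mazes could be requested as follows. This is a minimal sketch: it assumes MazeDatasetConfig, MazeDataset, and LatticeMazeGenerators are importable from the package root, and it uses only keyword arguments that appear in the from_config() signature documented below.

    from pathlib import Path

    from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig

    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="example",      # used in the filename, so keep it filesystem safe
        grid_n=10,           # 10x10 lattice
        n_mazes=100,         # number of mazes to request
        maze_ctor=LatticeMazeGenerators.gen_dfs,  # depth-first search generator (the default)
    )

    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_generate=True,    # generate at runtime if nothing is found locally
        load_local=True,     # reuse a previously saved dataset if one exists
        save_local=True,     # save the generated dataset to disk
        local_base_path=Path("data/maze_dataset"),
        verbose=True,
    )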

Furthermore, a dataset of mazes can be filtered to satisfy certain properties:

dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)

Custom filters can be specified, and several filters are included:

  • path_length(min_length: int): shortest length from the origin to target should be at least min_length.
  • start_end_distance(min_distance: int): Manhattan distance between start and end should be at least min_distance, ignoring walls.
  • remove_duplicates(...): remove mazes which are similar to others in the dataset, measured via Hamming distance.
  • remove_duplicates_fast(): remove mazes which are exactly identical to others in the dataset.

All implemented maze generation algorithms are stochastic by nature. For reproducibility, the seed parameter of MazeDatasetConfig may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, even when generating large datasets.
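
An illustrative sketch of the seeding and filtering described above; the chained calls assume that each built-in filter returns a new MazeDataset, as the filters listed above do.

    cfg_seeded: MazeDatasetConfig = MazeDatasetConfig(
        name="example",
        grid_n=10,
        n_mazes=100,
        seed=42,  # fix the RNG seed for reproducible generation
    )
    dataset: MazeDataset = MazeDataset.from_config(cfg_seeded)

    # chain several of the built-in filters listed above
    dataset_filtered: MazeDataset = (
        dataset.filter_by.path_length(min_length=3)
        .filter_by.start_end_distance(min_distance=2)
        .filter_by.remove_duplicates_fast()
    )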


 1"""`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`"
 2
 3When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class.
 4
 5Furthermore, a dataset of mazes can be filtered to satisfy certain properties:
 6
 7```python
 8dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)
 9```
10
11Custom filters can be specified, and several filters are included:
12
13- `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`.
14- `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls.
15- `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance.
16- `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset.
17
18All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency,
19even when generating large datasets.
20
21"""
22
23from maze_dataset.dataset.collected_dataset import (
24	MazeDatasetCollection,
25	MazeDatasetCollectionConfig,
26)
27from maze_dataset.dataset.maze_dataset import MazeDataset
28from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig
29
30__all__ = [
31	# submodules
32	"collected_dataset",
33	"configs",
34	"dataset",
35	"filters",
36	"maze_dataset_config",
37	"maze_dataset",
38	"rasterized",
39	"success_predict_math",
40	# dataset classes
41	"MazeDataset",
42	"MazeDatasetConfig",
43	"MazeDatasetCollection",
44	"MazeDatasetCollectionConfig",
45]

113class MazeDataset(GPTDataset[MazeDatasetConfig]):
114	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""
115
116	def __init__(
117		self,
118		cfg: MazeDatasetConfig,
119		mazes: typing.Sequence[SolvedMaze],
120		generation_metadata_collected: dict | None = None,
121	) -> None:
122		"""initialize a maze dataset from a config and a list of solved mazes"""
123		super().__init__()
124		self.cfg: MazeDatasetConfig = cfg
125		self.mazes: list[SolvedMaze] = list(mazes)
126		self.generation_metadata_collected: dict | None = generation_metadata_collected
127
128	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
129	@classmethod
130	def from_config(  # type: ignore[override]
131		cls,
132		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
133		cfg: MazeDatasetConfig,  # type: ignore[override]
134		do_generate: bool = True,
135		load_local: bool = True,
136		save_local: bool = True,
137		zanj: ZANJ | None = None,
138		do_download: bool = True,
139		local_base_path: Path = Path("data/maze_dataset"),
140		except_on_config_mismatch: bool = True,
141		allow_generation_metadata_filter_mismatch: bool = True,
142		verbose: bool = False,
143		**kwargs,
144	) -> "MazeDataset":
145		"""create a maze dataset from a config
146
147		priority of loading:
148		1. load from local
149		2. download
150		3. generate
151
152		"""
153		return cast(
154			MazeDataset,
155			super().from_config(
156				cfg=cfg,
157				do_generate=do_generate,
158				load_local=load_local,
159				save_local=save_local,
160				zanj=zanj,
161				do_download=do_download,
162				local_base_path=local_base_path,
163				except_on_config_mismatch=except_on_config_mismatch,
164				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
165				verbose=verbose,
166				**kwargs,
167			),
168		)
169
170	def data_hash(self) -> int:
171		"""return a hash of the data"""
172		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
173
174	def __getitem__(self, i: int) -> SolvedMaze:
175		"""get a maze by index"""
176		return self.mazes[i]
177
178	def __iter__(self) -> typing.Iterator[SolvedMaze]:
179		"""iterate over the mazes"""
180		return iter(self.mazes)
181
182	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
183		"""deepcopy the dataset
184
185		FIX: this isn't actually a deepcopy, I think?
186		"""
187		return MazeDataset.load(self._serialize_full())
188
189	# TYPING: get type hints on the tokenizer here
190	@overload
191	def as_tokens(
192		self,
193		maze_tokenizer,  # noqa: ANN001
194		limit: int | None = None,
195		join_tokens_individual_maze: Literal[False] = False,
196	) -> list[list[str]]: ...
197	@overload
198	def as_tokens(
199		self,
200		maze_tokenizer,  # noqa: ANN001
201		limit: int | None = None,
202		join_tokens_individual_maze: Literal[True] = True,
203	) -> list[str]: ...
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output
230
231	def __len__(self) -> int:
232		"""return the number of mazes in the dataset"""
233		return len(self.mazes)
234
235	def __eq__(self, other: object) -> bool:
236		"""compare two datasets"""
237		if not isinstance(other, MazeDataset):
238			raise NotImplementedError(
239				"can only compare with other MazeDataset objects",
240			)
241		# TODO: compare hashes of data instead of the data itself?
242		return self.cfg == other.cfg and self.mazes == other.mazes
243
244	def assert_equal(self, other: "MazeDataset") -> None:
245		"""assert that two datasets are equal"""
246		assert isinstance(other, MazeDataset)
247		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
248		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
249
250	@classmethod
251	def generate(
252		cls,
253		cfg: MazeDatasetConfig,
254		gen_parallel: bool = False,
255		pool_kwargs: dict | None = None,
256		verbose: bool = False,
257		# TODO: what to do when unexpected kwargs are passed?
258		**kwargs,  # noqa: ARG003
259	) -> "MazeDataset":
260		"""Generate a maze dataset given a config and some generation parameters"""
261		# Copy the config to avoid modifying the original
262		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
263			json.loads(json.dumps(cfg.serialize())),
264		)
265
266		if pool_kwargs is None:
267			pool_kwargs = dict()
268		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
269
270		solved_mazes: list[SolvedMaze | None]
271		# Configure tqdm for progress bar
272		tqdm_kwargs: dict = dict(
273			total=cfg_cpy.n_mazes,
274			unit="maze",
275			desc="generating & solving mazes",
276			disable=not verbose,
277		)
278		# TODO: don't use the global unless generating in parallel!
279		if gen_parallel:
280			with multiprocessing.Pool(
281				**pool_kwargs,
282				initializer=_maze_gen_init_worker,
283				initargs=(cfg_cpy,),
284			) as pool:
285				solved_mazes = list(
286					tqdm.tqdm(
287						pool.imap(_generate_maze_helper, maze_indexes),
288						**tqdm_kwargs,
289					),
290				)
291
292		else:
293			_maze_gen_init_worker(cfg_cpy)
294			solved_mazes = list(
295				tqdm.tqdm(
296					map(
297						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
298						# why does it think tolist() returns a string?
299						_generate_maze_helper,  # type: ignore[arg-type]
300						maze_indexes.tolist(),
301					),
302					**tqdm_kwargs,
303				),
304			)
305
306		# Filter out None values explicitly after ensuring all results are collected
307		solved_mazes_: list[SolvedMaze] = [
308			maze for maze in solved_mazes if maze is not None
309		]
310		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
311
312		# Update the config with the actual number of mazes
313		cfg_cpy.n_mazes = len(solved_mazes_)
314
315		dataset: MazeDataset = cls(
316			cfg=cfg_cpy,
317			mazes=solved_mazes_,
318		)
319
320		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
321
322		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
323
324		return dataset
325
326	@classmethod
327	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
328		"(not implemented yet!) download a maze dataset from the internet"
329		raise NotImplementedError("not implemented yet")
330
331	@classmethod
332	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
333		"""load from zanj/json"""
334		if data[_FORMAT_KEY] == "MazeDataset:minimal":
335			return cls._load_minimal(data)
336		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
337			return cls._load_minimal_soln_cat(data)
338		elif data[_FORMAT_KEY] == "MazeDataset":
339			if (
340				SERIALIZE_MINIMAL_THRESHOLD == -1
341			):  # Allow access to `_load_legacy` for profiling
342				return cls._load_legacy(data)
343			return cls._load_full(data)
344		else:
345			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
346			raise KeyError(
347				err_msg,
348			)
349
350	@classmethod
351	def _load_full(cls, data: JSONdict) -> "MazeDataset":
352		assert data[_FORMAT_KEY] == "MazeDataset"
353		return cls(
354			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
355			mazes=load_item_recursive(data["mazes"], tuple()),
356			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
357		)
358
359	@classmethod
360	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
361		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
362		return cls(
363			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
364			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
365			mazes=[
366				SolvedMaze(
367					clist,
368					soln[:slen, ...],
369				)
370				for clist, slen, soln in zip(
371					load_item_recursive(data["maze_connection_lists"], tuple()),
372					load_item_recursive(data["maze_solution_lengths"], tuple()),
373					load_item_recursive(data["maze_solutions"], tuple()),
374					strict=False,
375					# load_item_recursive(data["maze_endpoints"], tuple()),
376				)
377			],
378		)
379
380	@classmethod
381	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
382		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"
383
384		maze_solution_lengths = load_item_recursive(
385			data["maze_solution_lengths"],
386			tuple(),
387		)
388		maze_solutions_concat = load_item_recursive(
389			data["maze_solutions_concat"],
390			tuple(),
391		)
392		maze_solutions = np.split(
393			maze_solutions_concat,
394			np.cumsum(maze_solution_lengths)[:-1],
395			axis=0,
396		)
397
398		return cls(
399			cfg=load_item_recursive(data["cfg"], tuple()),
400			generation_metadata_collected=load_item_recursive(
401				data["generation_metadata_collected"],
402				tuple(),
403			),
404			mazes=[
405				SolvedMaze(
406					connection_list=clist,
407					solution=soln,
408				)
409				for clist, soln in zip(
410					load_item_recursive(data["maze_connection_lists"], tuple()),
411					# load_item_recursive(data["maze_endpoints"], tuple()),
412					maze_solutions,
413					strict=False,
414				)
415			],
416		)
417
418	@classmethod
419	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
420		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
421		assert data[_FORMAT_KEY] == "MazeDataset"
422		return cls(
423			**{
424				key: load_item_recursive(data[key], tuple())
425				for key in ["cfg", "mazes", "generation_metadata_collected"]
426			},
427		)
428
429	def serialize(self) -> JSONdict:
430		"""serialize to zanj/json"""
431		if (
432			SERIALIZE_MINIMAL_THRESHOLD is not None
433			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
434		):
435			return self._serialize_minimal()
436		return self._serialize_full()
437
438	def _serialize_full(self) -> JSONdict:
439		return {
440			_FORMAT_KEY: "MazeDataset",
441			"cfg": json_serialize(self.cfg),
442			"fname": self.cfg.to_fname(),
443			"mazes": json_serialize(self.mazes),
444			"generation_metadata_collected": json_serialize(
445				self.generation_metadata_collected,
446			),
447		}
448
449	def _serialize_minimal(self) -> JSONdict:
450		"alternate serialization where metadata is collected and mazes are stored in concatenated form"
451		filtered_meta: MazeDataset
452		if self.generation_metadata_collected is None:
453			filtered_meta = self.filter_by.collect_generation_meta()
454		else:
455			filtered_meta = self
456
457		max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
458		n_mazes: int = len(filtered_meta.mazes)
459		grid_n: int = filtered_meta.cfg.grid_n
460
461		maze_connection_lists: np.ndarray = np.empty(
462			(n_mazes, 2, grid_n, grid_n),
463			dtype=np.bool_,
464		)
465		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
466		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
467		maze_solutions: np.ndarray = np.empty(
468			(n_mazes, max_solution_len, 2),
469			dtype=np.int8,
470		)
471
472		for idx, maze in enumerate(filtered_meta.mazes):
473			maze_connection_lists[idx] = maze.connection_list
474			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
475			maze_solution_lengths[idx] = maze.solution.shape[0]
476			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution
477
478		return {
479			_FORMAT_KEY: "MazeDataset:minimal",
480			"cfg": json_serialize(filtered_meta.cfg),
481			"fname": filtered_meta.cfg.to_fname(),
482			"generation_metadata_collected": json_serialize(
483				filtered_meta.generation_metadata_collected,
484			),
485			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
486			# "maze_endpoints": maze_endpoints,
487			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
488			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
489		}
490
491	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
492		"alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
493		filtered_meta: MazeDataset
494		if self.generation_metadata_collected is None:
495			filtered_meta = self.filter_by.collect_generation_meta()
496		else:
497			filtered_meta = self
498
499		maze_solution_lengths: np.ndarray = np.array(
500			[m.solution.shape[0] for m in filtered_meta.mazes],
501			dtype=np.int32,
502		)
503		n_mazes: int = len(filtered_meta.mazes)
504		grid_n: int = filtered_meta.cfg.grid_n
505		total_solution_len: int = np.sum(maze_solution_lengths)
506
507		maze_connection_lists: np.ndarray = np.empty(
508			(n_mazes, 2, grid_n, grid_n),
509			dtype=np.bool_,
510		)
511		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
512		maze_solutions_concat: np.ndarray = np.empty(
513			(total_solution_len, 2),
514			dtype=np.int8,
515		)
516
517		solutions_running_idx: int = 0
518		for idx, maze in enumerate(filtered_meta.mazes):
519			maze_connection_lists[idx] = maze.connection_list
520			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
521			soln_len: int = maze.solution.shape[0]
522			maze_solution_lengths[idx] = soln_len
523			maze_solutions_concat[
524				solutions_running_idx : solutions_running_idx + soln_len
525			] = maze.solution
526			solutions_running_idx += soln_len
527
528		return {
529			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
530			"cfg": json_serialize(filtered_meta.cfg),
531			"fname": filtered_meta.cfg.to_fname(),
532			"generation_metadata_collected": json_serialize(
533				filtered_meta.generation_metadata_collected,
534			),
535			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
536			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
537			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
538			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
539		}
540
541	def update_self_config(self) -> None:
542		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
543		if self.cfg.n_mazes != len(self.mazes):
544			warnings.warn(
545				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
546			)
547			self.cfg.n_mazes = len(self.mazes)
548
549	def custom_maze_filter(
550		self,
551		method: typing.Callable[[SolvedMaze], bool],
552		**kwargs,
553	) -> "MazeDataset":
554		"""filter the dataset using a custom method"""
555		output: MazeDataset = MazeDataset(
556			cfg=copy.deepcopy(self.cfg),
557			mazes=[m for m in self.mazes if method(m, **kwargs)],
558		)
559		output.cfg.applied_filters.append(
560			{
561				"name": f"__custom__:{method.__name__}",
562				"kwargs": kwargs,
563			},
564		)
565		output.update_self_config()
566		return output

a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config

MazeDataset( cfg: MazeDatasetConfig, mazes: Sequence[maze_dataset.SolvedMaze], generation_metadata_collected: dict | None = None)
116	def __init__(
117		self,
118		cfg: MazeDatasetConfig,
119		mazes: typing.Sequence[SolvedMaze],
120		generation_metadata_collected: dict | None = None,
121	) -> None:
122		"""initialize a maze dataset from a config and a list of solved mazes"""
123		super().__init__()
124		self.cfg: MazeDatasetConfig = cfg
125		self.mazes: list[SolvedMaze] = list(mazes)
126		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize a maze dataset from a config and a list of solved mazes

generation_metadata_collected: dict | None
@classmethod
def from_config( cls, cfg: MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> MazeDataset:
129	@classmethod
130	def from_config(  # type: ignore[override]
131		cls,
132		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
133		cfg: MazeDatasetConfig,  # type: ignore[override]
134		do_generate: bool = True,
135		load_local: bool = True,
136		save_local: bool = True,
137		zanj: ZANJ | None = None,
138		do_download: bool = True,
139		local_base_path: Path = Path("data/maze_dataset"),
140		except_on_config_mismatch: bool = True,
141		allow_generation_metadata_filter_mismatch: bool = True,
142		verbose: bool = False,
143		**kwargs,
144	) -> "MazeDataset":
145		"""create a maze dataset from a config
146
147		priority of loading:
148		1. load from local
149		2. download
150		3. generate
151
152		"""
153		return cast(
154			MazeDataset,
155			super().from_config(
156				cfg=cfg,
157				do_generate=do_generate,
158				load_local=load_local,
159				save_local=save_local,
160				zanj=zanj,
161				do_download=do_download,
162				local_base_path=local_base_path,
163				except_on_config_mismatch=except_on_config_mismatch,
164				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
165				verbose=verbose,
166				**kwargs,
167			),
168		)

create a maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
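
Each stage of this priority can be toggled through the corresponding keyword argument. For example, to ignore any cached copy and force regeneration (a sketch using only parameters from the signature above):

    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        load_local=False,   # skip step 1: ignore any local copy
        do_download=False,  # skip step 2: downloading is not yet implemented anyway
        do_generate=True,   # step 3: generate from the config
        save_local=True,    # cache the freshly generated dataset
    )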
def data_hash(self) -> int:
170	def data_hash(self) -> int:
171		"""return a hash of the data"""
172		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

return a hash of the data

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
204	def as_tokens(
205		self,
206		maze_tokenizer,  # TODO: MazeTokenizer
207		limit: int | None = None,
208		join_tokens_individual_maze: bool = False,
209	) -> list[list[str]] | list[str]:
210		"""return the dataset as tokens according to the passed `maze_tokenizer`
211
212		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
213
214		if `join_tokens_individual_maze` is True, then the tokens of each maze are
215		joined with a space, and the result is a list of strings.
216		i.e.:
217
218			>>> dataset.as_tokens(join_tokens_individual_maze=False)
219			[["a", "b", "c"], ["d", "e", "f"]]
220			>>> dataset.as_tokens(join_tokens_individual_maze=True)
221			["a b c", "d e f"]
222		"""
223		output: list[list[str]] = [
224			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
225		]
226		if join_tokens_individual_maze:
227			return [" ".join(tokens) for tokens in output]
228		else:
229			return output

return the dataset as tokens according to the passed maze_tokenizer

the maze_tokenizer should be either a MazeTokenizer or a MazeTokenizerModular

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
def assert_equal(self, other: MazeDataset) -> None:
244	def assert_equal(self, other: "MazeDataset") -> None:
245		"""assert that two datasets are equal"""
246		assert isinstance(other, MazeDataset)
247		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
248		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

assert that two datasets are equal

@classmethod
def generate( cls, cfg: MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, **kwargs) -> MazeDataset:
250	@classmethod
251	def generate(
252		cls,
253		cfg: MazeDatasetConfig,
254		gen_parallel: bool = False,
255		pool_kwargs: dict | None = None,
256		verbose: bool = False,
257		# TODO: what to do when unexpected kwargs are passed?
258		**kwargs,  # noqa: ARG003
259	) -> "MazeDataset":
260		"""Generate a maze dataset given a config and some generation parameters"""
261		# Copy the config to avoid modifying the original
262		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
263			json.loads(json.dumps(cfg.serialize())),
264		)
265
266		if pool_kwargs is None:
267			pool_kwargs = dict()
268		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
269
270		solved_mazes: list[SolvedMaze | None]
271		# Configure tqdm for progress bar
272		tqdm_kwargs: dict = dict(
273			total=cfg_cpy.n_mazes,
274			unit="maze",
275			desc="generating & solving mazes",
276			disable=not verbose,
277		)
278		# TODO: don't use the global unless generating in parallel!
279		if gen_parallel:
280			with multiprocessing.Pool(
281				**pool_kwargs,
282				initializer=_maze_gen_init_worker,
283				initargs=(cfg_cpy,),
284			) as pool:
285				solved_mazes = list(
286					tqdm.tqdm(
287						pool.imap(_generate_maze_helper, maze_indexes),
288						**tqdm_kwargs,
289					),
290				)
291
292		else:
293			_maze_gen_init_worker(cfg_cpy)
294			solved_mazes = list(
295				tqdm.tqdm(
296					map(
297						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
298						# why does it think tolist() returns a string?
299						_generate_maze_helper,  # type: ignore[arg-type]
300						maze_indexes.tolist(),
301					),
302					**tqdm_kwargs,
303				),
304			)
305
306		# Filter out None values explicitly after ensuring all results are collected
307		solved_mazes_: list[SolvedMaze] = [
308			maze for maze in solved_mazes if maze is not None
309		]
310		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
311
312		# Update the config with the actual number of mazes
313		cfg_cpy.n_mazes = len(solved_mazes_)
314
315		dataset: MazeDataset = cls(
316			cfg=cfg_cpy,
317			mazes=solved_mazes_,
318		)
319
320		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
321
322		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
323
324		return dataset

Generate a maze dataset given a config and some generation parameters
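
A sketch of calling generate() directly with parallel generation enabled; pool_kwargs is forwarded to multiprocessing.Pool, so only standard Pool arguments belong there.

    dataset: MazeDataset = MazeDataset.generate(
        cfg,
        gen_parallel=True,              # generate mazes in a multiprocessing.Pool
        pool_kwargs=dict(processes=4),  # forwarded directly to multiprocessing.Pool
        verbose=True,                   # show the tqdm progress bar
    )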

@classmethod
def download( cls, cfg: MazeDatasetConfig, **kwargs) -> MazeDataset:
326	@classmethod
327	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
328		"(not implemented yet!) download a maze dataset from the internet"
329		raise NotImplementedError("not implemented yet")

(not implemented yet!) download a maze dataset from the internet

@classmethod
def load( cls: type[MazeDataset], data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDataset:
331	@classmethod
332	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
333		"""load from zanj/json"""
334		if data[_FORMAT_KEY] == "MazeDataset:minimal":
335			return cls._load_minimal(data)
336		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
337			return cls._load_minimal_soln_cat(data)
338		elif data[_FORMAT_KEY] == "MazeDataset":
339			if (
340				SERIALIZE_MINIMAL_THRESHOLD == -1
341			):  # Allow access to `_load_legacy` for profiling
342				return cls._load_legacy(data)
343			return cls._load_full(data)
344		else:
345			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
346			raise KeyError(
347				err_msg,
348			)

load from zanj/json

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
429	def serialize(self) -> JSONdict:
430		"""serialize to zanj/json"""
431		if (
432			SERIALIZE_MINIMAL_THRESHOLD is not None
433			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
434		):
435			return self._serialize_minimal()
436		return self._serialize_full()

serialize to zanj/json
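
Since serialize() and load() are inverses, a dataset can be round-tripped through its JSON-serializable form, and data_hash() gives a quick check that nothing was lost. A minimal sketch:

    data = dataset.serialize()          # JSON-serializable dict (full or minimal format)
    restored = MazeDataset.load(data)   # rebuild the dataset from that dict
    assert dataset.data_hash() == restored.data_hash()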

def update_self_config(self) -> None:
541	def update_self_config(self) -> None:
542		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
543		if self.cfg.n_mazes != len(self.mazes):
544			warnings.warn(
545				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
546			)
547			self.cfg.n_mazes = len(self.mazes)

update the config to match the current state of the dataset (number of mazes, such as after filtering)

def custom_maze_filter( self, method: Callable[[maze_dataset.SolvedMaze], bool], **kwargs) -> MazeDataset:
549	def custom_maze_filter(
550		self,
551		method: typing.Callable[[SolvedMaze], bool],
552		**kwargs,
553	) -> "MazeDataset":
554		"""filter the dataset using a custom method"""
555		output: MazeDataset = MazeDataset(
556			cfg=copy.deepcopy(self.cfg),
557			mazes=[m for m in self.mazes if method(m, **kwargs)],
558		)
559		output.cfg.applied_filters.append(
560			{
561				"name": f"__custom__:{method.__name__}",
562				"kwargs": kwargs,
563			},
564		)
565		output.update_self_config()
566		return output

filter the dataset using a custom method
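
A sketch of a custom filter: any callable that takes a SolvedMaze (plus optional keyword arguments) and returns a bool can be passed, and the filter is recorded in the config's applied_filters. The long_solution function below is purely illustrative.

    def long_solution(maze, min_len: int = 5) -> bool:
        # keep only mazes whose solution has at least min_len steps
        return maze.solution.shape[0] >= min_len

    dataset_long: MazeDataset = dataset.custom_maze_filter(long_solution, min_len=7)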

@serializable_dataclass(kw_only=True, methods_no_override=['serialize'])
class MazeDatasetConfig(maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base):
257@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
258class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
259	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset
260
261	# Parameters:
262	- `name : str`
263		name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
264	- `grid_n : int`
265		grid size of the maze (number of rows/columns)
266	- `n_mazes : int`
267		number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
268		see `EndpointKwargsType` for more details.
269	- `maze_ctor : Callable`
270		maze generator function. This should be a function that takes a grid size and returns a maze.
271		This will usually be one of the functions in `LatticeMazeGenerators`.
272	- `maze_ctor_kwargs : dict`
273		keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
274	- `endpoint_kwargs : EndpointKwargsType`
275		keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
276	- `applied_filters : list[dict]`
277		list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
278		but these are stored with the config in case you want to re-generate the dataset with the same filters.
279
280	"""
281
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
285		return "1.0"
286
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)
294
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}
304
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# do we run this to make sure it doesn't error?
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)
325
326	def _to_ps_array(self) -> _PercolationSuccessArray:
327		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
328
329		used in predicting the success rate
330		"""
331		try:
332			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
333				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
334			)
335			assert "p" in self.maze_ctor_kwargs, (
336				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
337			)
338			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
339				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
340			)
341		except AssertionError as e:
342			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
343			raise NoPercolationInConfigError(
344				err_msg,
345			) from e
346
347		endpoints_unique_flag: int = int(
348			# we are pretty sure it will be an int or bool here
349			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
350		)
351
352		# adjustment for bknutson0
353		if not (
354			self.endpoint_kwargs.get("deadend_start", False)
355			and self.endpoint_kwargs.get("deadend_end", False)
356		):
357			# we didnt train on this, but if either endpoint is not required to be in a dead end
358			# then  requiring the endpoints to be unique does not really affect the success rate
359			# (except for very small percolation values, pure percolation generation)
360			endpoints_unique_flag = 0
361
362		return np.array(
363			[
364				float(self.maze_ctor_kwargs["p"]),
365				float(self.grid_n),
366				float(
367					int(
368						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
369						or self.endpoint_kwargs.get("deadend_end", False),
370					),
371				),
372				float(endpoints_unique_flag),
373				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
374			],
375			dtype=np.float64,
376		)
377
378	@classmethod
379	def _from_ps_array(
380		cls,
381		arr: _PercolationSuccessArray,
382		name: str = "predict",
383		n_mazes: int = 100,
384		**kwargs,
385	) -> "MazeDatasetConfig":
386		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
387
388		# Returns:
389		- `MazeDatasetConfig`
390			Config corresponding to `arr`
391		"""
392		return cls(
393			name=name,
394			grid_n=int(arr[1]),
395			n_mazes=n_mazes,
396			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
397			maze_ctor_kwargs={"p": float(arr[0])},
398			endpoint_kwargs=dict(
399				deadend_start=bool(arr[2]),
400				deadend_end=bool(arr[2]),
401				endpoints_not_equal=bool(arr[3]),
402				except_on_no_valid_endpoint=False,
403			),
404			**kwargs,
405		)
406
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, don't raise an error if the success fraction is below the threshold.
424			will always return `1.0` if the config is not expected to fail
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0
440
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			if `True`, don't raise an error if the success fraction is below the threshold.
463			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
464			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
465			since `safety_margin` is still applied.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

config object which is passed to MazeDataset.from_config to generate or load a dataset

Parameters:

  • name : str name of the dataset -- this can be anything, but should be filesystem safe since we use it in the fname
  • grid_n : int grid size of the maze (number of rows/columns)
  • n_mazes : int number of mazes to request. For some combinations of endpoint_kwargs and maze_ctor, not all mazes might successfully generate. see EndpointKwargsType for more details.
  • maze_ctor : Callable maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in LatticeMazeGenerators.
  • maze_ctor_kwargs : dict keyword arguments to pass to the maze generator function. Specific to the maze_ctor you are using.
  • endpoint_kwargs : EndpointKwargsType keyword arguments passed to LatticeMaze.generate_random_path(). see EndpointKwargsType for more info.
  • applied_filters : list[dict] list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
MazeDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
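
Putting these parameters together, a config might look like the following sketch. The generator name gen_dfs_percolation and its p keyword argument are assumptions about the percolation generators in LatticeMazeGenerators; the endpoint_kwargs keys are those listed in the constructor signature above.

    cfg = MazeDatasetConfig(
        name="percolation_demo",  # filesystem-safe, since it appears in the fname
        grid_n=8,
        n_mazes=1000,
        maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,  # assumed generator name
        maze_ctor_kwargs=dict(p=0.4),  # generator-specific kwargs (percolation probability)
        endpoint_kwargs=dict(
            deadend_start=True,
            deadend_end=True,
            except_on_no_valid_endpoint=False,
        ),
    )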
config_version: str
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
285		return "1.0"

return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config

versions: dict
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)

return the versions of the config and the maze_dataset

def serialize(self) -> dict:
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}

serialize the MazeDatasetConfig with all fields and fname

def summary(self) -> dict:
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# do we run this to make sure it doesn't error?
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)

return a summary of the config

def success_fraction_estimate(self, except_if_all_success_expected: bool = False) -> float:
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, don't raise an error if the success fraction is below the threshold.
424			will always return `1.0` if the config is not expected to fail
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0

Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in:

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • except_if_all_success_expected : bool if True, raise NoPercolationInConfigError when the config does not use a percolation generator (i.e. every maze is expected to generate successfully). if False (the default), return 1.0 for such configs

Returns:

  • float estimated success fraction

Raises:

  • NoPercolationInConfigError : if the config does not use a percolation generator (so no failures are expected) and except_if_all_success_expected is True
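
A sketch of querying the estimate for a config that uses a percolation generator, as required above:

    frac: float = cfg.success_fraction_estimate()
    print(f"estimated fraction of mazes expected to generate successfully: {frac:.3f}")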
def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 0.01) -> MazeDatasetConfig:
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			if `True`, don't raise an error if the success fraction is below the threshold.
463			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
464			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
465			since `safety_margin` is still applied.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

return a new MazeDatasetConfig like this one with n_mazes adjusted to compensate for the success fraction

calls MazeDatasetConfig.success_fraction_estimate() to get the success fraction, and then computes the new number of mazes as n_mazes = n_mazes * safety_margin / success_fraction + 1

more information on where this comes from can be found in:

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • safety_margin : float safety margin to apply to the success fraction estimate (defaults to 1.2, or 20% more mazes than estimated)
  • except_if_all_success_expected : bool passed through to MazeDatasetConfig.success_fraction_estimate. if False (the default) and the config is not a percolation config, the estimated success fraction is 1.0, so safety_margin still slightly inflates n_mazes; if True, such configs raise NoPercolationInConfigError. (defaults to False)
  • epsilon : float raise SuccessChanceTooSmallError if the success fraction is below this threshold (defaults to 1e-2)

Returns:

  • MazeDatasetConfig new config with adjusted n_mazes

Raises:

  • SuccessChanceTooSmallError : if the computed success fraction is below epsilon
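
A worked sketch of the compensation: with n_mazes=100, an estimated success fraction of 0.8, and the default safety_margin of 1.2, the adjusted count is int(100 * 1.2 / 0.8) + 1 = 151. In code, assuming cfg uses a percolation generator with except_on_no_valid_endpoint=False as required by the estimator:

    cfg_compensated: MazeDatasetConfig = cfg.success_fraction_compensate(safety_margin=1.2)
    # e.g. n_mazes=100 with an estimated success fraction of 0.8
    # becomes int(100 * 1.2 / 0.8) + 1 = 151 requested mazes
    dataset = MazeDataset.from_config(cfg_compensated)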
@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class; this method is implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

class MazeDatasetCollection(typing.Generic[~T_DatasetConfig]):
 84class MazeDatasetCollection(GPTDataset):
 85	"""a collection of maze datasets"""
 86
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected
106
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]
111
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))
116
117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)
125
126	def __len__(self) -> int:
127		"""return the total number of mazes in the collection"""
128		return sum(len(dataset) for dataset in self.maze_datasets)
129
130	def __getitem__(self, index: int) -> LatticeMaze:
131		"get a maze by index"
132		# find which dataset the index belongs to
133		# we add 1, since np.searchsorted returns the
134		# index of the last element that is strictly less than the target
135		# while we want the index of the last element less than or equal to the target
136		dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
137		index_adjusted: int = index
138		if dataset_idx > 0:
139			# if the index is 0, `dataset_idx - 1` will be -1.
140			# We just want to use the base index
141			index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
142		return self.maze_datasets[dataset_idx][index_adjusted]
143
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)
156
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)
169
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}
180
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)
191
192	# TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output
217
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

a collection of maze datasets
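
Example usage, as a minimal end-to-end sketch (not taken from the library's own examples): it assumes the usual `MazeDatasetConfig` fields (`grid_n`, `n_mazes`, `maze_ctor`) and the import paths shown, which may differ slightly across versions.

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)

# two constituent datasets with different grid sizes and counts
collection_cfg = MazeDatasetCollectionConfig(
    name="demo_collection",
    maze_dataset_configs=[
        MazeDatasetConfig(
            name="small", grid_n=4, n_mazes=10,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        MazeDatasetConfig(
            name="large", grid_n=7, n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
    ],
)

collection = MazeDatasetCollection.generate(collection_cfg)

print(len(collection))  # 15 -- summed over both datasets
# indices past the first dataset are mapped into the next one via
# `dataset_cum_lengths` and `np.searchsorted` (see `__getitem__` below)
maze = collection[12]  # third maze of the second dataset
```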

MazeDatasetCollection( cfg: MazeDatasetCollectionConfig, maze_datasets: list[MazeDataset], generation_metadata_collected: dict | None = None)
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize the dataset collection from a MazeDatasetCollectionConfig and a list of MazeDatasets
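
When constructing a collection directly rather than via `generate`, the constructor pairs `cfg.maze_dataset_configs` with the supplied datasets in order and asserts that each config matches the corresponding dataset's `cfg`. A minimal sketch, continuing the hypothetical `collection_cfg` from the example above:

```python
from maze_dataset import MazeDataset

# generate the constituent datasets individually, then wrap them in a collection;
# the configs must line up with `collection_cfg.maze_dataset_configs` in order
datasets = [MazeDataset.generate(c) for c in collection_cfg.maze_dataset_configs]
collection_manual = MazeDatasetCollection(collection_cfg, datasets)
```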

maze_datasets: list[MazeDataset]
generation_metadata_collected: dict | None
dataset_lengths: list[int]
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]

return the lengths of each dataset in the collection

dataset_cum_lengths: jaxtyping.Int[ndarray, 'indices']
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))

return the cumulative lengths of each dataset in the collection

117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)

single list of all mazes in the collection

@classmethod
def generate( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)

generate a dataset collection from a config

@classmethod
def download( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)

(not implemented!) download a dataset collection from a config

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}

serialize the dataset collection

@classmethod
def load( cls, data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDatasetCollection:
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)

load the dataset collection from the representation created by serialize
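
A minimal round-trip sketch, continuing the hypothetical `collection` from the class-level example above:

```python
data = collection.serialize()                # JSON-style dict (see `JSONdict`)
restored = MazeDatasetCollection.load(data)  # rebuild from that representation
print(len(restored), len(collection))        # expected to match
```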

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output

return the dataset as tokens

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
def update_self_config(self) -> None:
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

update the config to match the number of mazes, and update the underlying configs of each dataset

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(maze_dataset.dataset.dataset.GPTDatasetConfig):
31@serializable_dataclass(kw_only=True)
32class MazeDatasetCollectionConfig(GPTDatasetConfig):
33	"""maze dataset collection configuration, including tokenizers and shuffle"""
34
35	# Attributes without a default cannot follow attributes with one  [misc]
36	maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
37		serialization_fn=lambda configs: [config.serialize() for config in configs],
38		loading_fn=lambda data: [
39			MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
40		],
41	)
42
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)
52
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)
57
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)
62
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)
67
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)
72
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))
76
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

maze dataset collection configuration, including tokenizers and shuffle

MazeDatasetCollectionConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, maze_dataset_configs: list[MazeDatasetConfig])
maze_dataset_configs: list[MazeDatasetConfig]
def summary(self) -> dict:
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)

return a summary of the config
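
Continuing the hypothetical `collection_cfg` from the `MazeDatasetCollection` example above, the summary bundles the aggregate properties documented below:

```python
collection_cfg.n_mazes         # 15 -- sum of `n_mazes` over the sub-configs
collection_cfg.max_grid_n      # 7  -- largest `grid_n` among the sub-configs
collection_cfg.max_grid_shape  # (7, 7)
collection_cfg.summary()       # dict of the above plus per-config summaries and the fname
```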

n_mazes: int
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)

return the total number of mazes in the collection across all datasets

max_grid_n: int
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)

return the maximum grid size of the mazes in the collection

max_grid_shape: tuple[int, int]
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)

return the maximum grid shape of the mazes in the collection

max_grid_shape_np: jaxtyping.Int8[ndarray, 'row_col=2']
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)

return the maximum grid shape of the mazes in the collection as a numpy array

def stable_hash_cfg(self) -> int:
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))

return a stable hash of the config

def to_fname(self) -> str:
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

convert config to a filename
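
The filename combines the config name, a shortened maze count, and the last five digits of `stable_hash_cfg()`. A sketch continuing the hypothetical `collection_cfg` above; the hash suffix shown is illustrative only.

```python
collection_cfg.to_fname()
# e.g. 'collected-demo_collection-n15-h12345'  (suffix depends on the serialized config)
```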

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the instance as a dict; implemented by the @serializable_dataclass decorator
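
A minimal sketch of the recursive behaviour described above, using hypothetical classes and the usual `muutils.json_serialize` imports:

```python
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

@serializable_dataclass(kw_only=True)
class Inner(SerializableDataclass):  # hypothetical class for illustration
    x: int

@serializable_dataclass(kw_only=True)
class Outer(SerializableDataclass):  # hypothetical class for illustration
    label: str
    inner: Inner

data = Outer(label="demo", inner=Inner(x=3)).serialize()
# `data` is a plain dict: a format key identifying the class, 'label': 'demo',
# and 'inner' serialized recursively, since it is itself a SerializableDataclass
```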

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator
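
A minimal sketch of the two field-loading hooks referenced above, using a hypothetical class: `deserialize_fn` receives only the field's serialized value, while `loading_fn` receives the entire serialized dict (the pattern `maze_dataset_configs` uses in `MazeDatasetCollectionConfig`).

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass(kw_only=True)
class Weights(SerializableDataclass):  # hypothetical class for illustration
    # `loading_fn` gets the whole serialized dict and picks out what it needs
    values: list[float] = serializable_field(
        serialization_fn=lambda vals: list(vals),
        loading_fn=lambda data: [float(v) for v in data["values"]],
    )
    # `deserialize_fn` gets only this field's serialized value
    scale: float = serializable_field(
        default=1.0,
        deserialize_fn=lambda v: float(v),
    )

w = Weights(values=[0.1, 0.2], scale=2.0)
assert Weights.load(w.serialize()) == w  # round trip via the hooks above
```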

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field
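
A small sketch with a hypothetical class; whether a mismatch returns `False` or raises depends on the error-mode arguments.

```python
from muutils.json_serialize import SerializableDataclass, serializable_dataclass

@serializable_dataclass(kw_only=True)
class Tiny(SerializableDataclass):  # hypothetical class for illustration
    n: int

t = Tiny(n=3)
assert t.validate_fields_types()  # annotation and value agree

t.n = "three"  # type: ignore[assignment]
t.validate_fields_types()  # now reports the mismatch on `n`
```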

Inherited Members
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict