docs for maze-dataset v1.4.0
View Source on GitHub

maze_dataset.dataset

MazeDatasetConfig objects are used to create a MazeDataset via MazeDataset.from_config(cfg)

When initializing mazes, further configuration options can be specified through the from_config() factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the LatticeMazeGenerators class.
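
For example, a small dataset can be created directly from a config (a minimal sketch: the sizes and paths are illustrative, and maze_ctor is left at its default, LatticeMazeGenerators.gen_dfs):

```python
from pathlib import Path

from maze_dataset.dataset.maze_dataset import MazeDataset
from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig

cfg: MazeDatasetConfig = MazeDatasetConfig(
	name="example",  # filesystem-safe name, used in the saved fname
	grid_n=5,        # 5x5 maze grid
	n_mazes=100,     # number of mazes to request
)

# generate at runtime if no local copy exists, then cache the result to disk
dataset: MazeDataset = MazeDataset.from_config(
	cfg,
	do_generate=True,
	load_local=True,
	save_local=True,
	local_base_path=Path("data/maze_dataset"),
	verbose=True,
)
```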

Furthermore, a dataset of mazes can be filtered to satisfy certain properties:

dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)

Custom filters can be specified via custom_maze_filter (see the sketch after this list), and several filters are included:

  • path_length(min_length: int): shortest length from the origin to target should be at least min_length.
  • start_end_distance(min_distance: int): Manhattan distance between start and end should be at least min_distance, ignoring walls.
  • remove_duplicates(...): remove mazes which are similar to others in the dataset, measured via Hamming distance.
  • remove_duplicates_fast(): remove mazes which are exactly identical to others in the dataset.
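
A custom filter is just a predicate over a SolvedMaze, applied via MazeDataset.custom_maze_filter. A minimal sketch (the predicate here is illustrative):

```python
from maze_dataset import SolvedMaze


def starts_at_origin(maze: SolvedMaze) -> bool:
	"""keep only mazes whose solution starts at the (0, 0) corner"""
	return tuple(maze.start_pos) == (0, 0)


dataset_corner: MazeDataset = dataset.custom_maze_filter(starts_at_origin)
```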

All implemented maze generation algorithms are stochastic by nature. For reproducibility, the seed parameter of MazeDatasetConfig may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, even when generating large datasets.
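
For example, under the assumption that generation is fully determined by the config (including its seed), regenerating from an identical config should produce an identical dataset:

```python
cfg_a = MazeDatasetConfig(name="repro", grid_n=5, n_mazes=10, seed=42)
cfg_b = MazeDatasetConfig(name="repro", grid_n=5, n_mazes=10, seed=42)

# assumes generation is fully determined by the config seed
assert (
	MazeDataset.generate(cfg_a).data_hash()
	== MazeDataset.generate(cfg_b).data_hash()
)
```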


 1"""`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`"
 2
 3When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class.
 4
 5Furthermore, a dataset of mazes can be filtered to satisfy certain properties:
 6
 7```python
 8dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)
 9```
10
11Custom filters can be specified, and several filters are included:
12
13- `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`.
14- `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls.
15- `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance.
16- `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset.
17
18All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency,
19even when generating large datasets.
20
21"""
22
23from maze_dataset.dataset.collected_dataset import (
24	MazeDatasetCollection,
25	MazeDatasetCollectionConfig,
26)
27from maze_dataset.dataset.maze_dataset import MazeDataset
28from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig
29
30__all__ = [
31	# submodules
32	"collected_dataset",
33	"configs",
34	"dataset",
35	"filters",
36	"maze_dataset_config",
37	"maze_dataset",
38	"rasterized",
39	"success_predict_math",
40	# dataset classes
41	"MazeDataset",
42	"MazeDatasetConfig",
43	"MazeDatasetCollection",
44	"MazeDatasetCollectionConfig",
45]

114class MazeDataset(GPTDataset[MazeDatasetConfig]):  # noqa: PLW1641
115	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""
116
117	def __init__(
118		self,
119		cfg: MazeDatasetConfig,
120		mazes: typing.Sequence[SolvedMaze],
121		generation_metadata_collected: dict | None = None,
122	) -> None:
123		"""initialize a maze dataset from a config and a list of solved mazes"""
124		super().__init__()
125		self.cfg: MazeDatasetConfig = cfg
126		self.mazes: list[SolvedMaze] = list(mazes)
127		self.generation_metadata_collected: dict | None = generation_metadata_collected
128
129	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
130	@classmethod
131	def from_config(  # type: ignore[override]
132		cls,
133		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
134		cfg: MazeDatasetConfig,  # type: ignore[override]
135		do_generate: bool = True,
136		load_local: bool = True,
137		save_local: bool = True,
138		zanj: ZANJ | None = None,
139		do_download: bool = True,
140		local_base_path: Path = Path("data/maze_dataset"),
141		except_on_config_mismatch: bool = True,
142		allow_generation_metadata_filter_mismatch: bool = True,
143		verbose: bool = False,
144		**kwargs,
145	) -> "MazeDataset":
146		"""create a maze dataset from a config
147
148		priority of loading:
149		1. load from local
150		2. download
151		3. generate
152
153		"""
154		return cast(
155			"MazeDataset",
156			super().from_config(
157				cfg=cfg,
158				do_generate=do_generate,
159				load_local=load_local,
160				save_local=save_local,
161				zanj=zanj,
162				do_download=do_download,
163				local_base_path=local_base_path,
164				except_on_config_mismatch=except_on_config_mismatch,
165				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
166				verbose=verbose,
167				**kwargs,
168			),
169		)
170
171	def data_hash(self) -> int:
172		"""return a hash of the data"""
173		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
174
175	def __getitem__(self, i: int) -> SolvedMaze:
176		"""get a maze by index"""
177		return self.mazes[i]
178
179	def __iter__(self) -> typing.Iterator[SolvedMaze]:
180		"""iterate over the mazes"""
181		return iter(self.mazes)
182
183	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
184		"""deepcopy the dataset
185
186		FIX: this isnt actually a deepcopy I think?
187		"""
188		return MazeDataset.load(self._serialize_full())
189
190	# TYPING: get type hints on the tokenizer here
191	@overload
192	def as_tokens(
193		self,
194		maze_tokenizer,  # noqa: ANN001
195		limit: int | None = None,
196		join_tokens_individual_maze: Literal[False] = False,
197	) -> list[list[str]]: ...
198	@overload
199	def as_tokens(
200		self,
201		maze_tokenizer,  # noqa: ANN001
202		limit: int | None = None,
203		join_tokens_individual_maze: Literal[True] = True,
204	) -> list[str]: ...
205	def as_tokens(
206		self,
207		maze_tokenizer,  # TODO: MazeTokenizer
208		limit: int | None = None,
209		join_tokens_individual_maze: bool = False,
210	) -> list[list[str]] | list[str]:
211		"""return the dataset as tokens according to the passed `maze_tokenizer`
212
213		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
214
215		if `join_tokens_individual_maze` is True, then the tokens of each maze are
216		joined with a space, and the result is a list of strings.
217		i.e.:
218
219			>>> dataset.as_tokens(join_tokens_individual_maze=False)
220			[["a", "b", "c"], ["d", "e", "f"]]
221			>>> dataset.as_tokens(join_tokens_individual_maze=True)
222			["a b c", "d e f"]
223		"""
224		output: list[list[str]] = [
225			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
226		]
227		if join_tokens_individual_maze:
228			return [" ".join(tokens) for tokens in output]
229		else:
230			return output
231
232	def __len__(self) -> int:
233		"""return the number of mazes in the dataset"""
234		return len(self.mazes)
235
236	def __eq__(self, other: object) -> bool:
237		"""compare two datasets"""
238		if not isinstance(other, MazeDataset):
239			raise NotImplementedError(
240				"can only compare with other MazeDataset objects",
241			)
242		# TODO: compare hashes of data instead of the data itself?
243		return self.cfg == other.cfg and self.mazes == other.mazes
244
245	def assert_equal(self, other: "MazeDataset") -> None:
246		"""assert that two datasets are equal"""
247		assert isinstance(other, MazeDataset)
248		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
249		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
250
251	@classmethod
252	def generate(
253		cls,
254		cfg: MazeDatasetConfig,
255		gen_parallel: bool = False,
256		pool_kwargs: dict | None = None,
257		verbose: bool = False,
258		# TODO: what to do when unexpected kwargs are passed?
259		**kwargs,  # noqa: ARG003
260	) -> "MazeDataset":
261		"""Generate a maze dataset given a config and some generation parameters"""
262		# Copy the config to avoid modifying the original
263		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
264			json.loads(json.dumps(cfg.serialize())),
265		)
266
267		if pool_kwargs is None:
268			pool_kwargs = dict()
269		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
270
271		solved_mazes: list[SolvedMaze | None]
272		# Configure tqdm for progress bar
273		tqdm_kwargs: dict = dict(
274			total=cfg_cpy.n_mazes,
275			unit="maze",
276			desc="generating & solving mazes",
277			disable=not verbose,
278		)
279		# TODO: don't use the global unless generating in parallel!
280		if gen_parallel:
281			with multiprocessing.Pool(
282				**pool_kwargs,
283				initializer=_maze_gen_init_worker,
284				initargs=(cfg_cpy,),
285			) as pool:
286				solved_mazes = list(
287					tqdm.tqdm(
288						pool.imap(_generate_maze_helper, maze_indexes),
289						**tqdm_kwargs,
290					),
291				)
292
293		else:
294			_maze_gen_init_worker(cfg_cpy)
295			solved_mazes = list(
296				tqdm.tqdm(
297					map(
298						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
299						# why does it think tolist() returns a string?
300						_generate_maze_helper,  # type: ignore[arg-type]
301						maze_indexes.tolist(),
302					),
303					**tqdm_kwargs,
304				),
305			)
306
307		# Filter out None values explicitly after ensuring all results are collected
308		solved_mazes_: list[SolvedMaze] = [
309			maze for maze in solved_mazes if maze is not None
310		]
311		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
312
313		# Update the config with the actual number of mazes
314		cfg_cpy.n_mazes = len(solved_mazes_)
315
316		dataset: MazeDataset = cls(
317			cfg=cfg_cpy,
318			mazes=solved_mazes_,
319		)
320
321		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
322
323		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
324
325		return dataset
326
327	@classmethod
328	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
329		"(not implemented yet!) download a maze dataset from the internet"
330		raise NotImplementedError("not implemented yet")
331
332	@classmethod
333	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
334		"""load from zanj/json"""
335		if data[_FORMAT_KEY] == "MazeDataset:minimal":
336			return cls._load_minimal(data)
337		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
338			return cls._load_minimal_soln_cat(data)
339		elif data[_FORMAT_KEY] == "MazeDataset":
340			if (
341				SERIALIZE_MINIMAL_THRESHOLD == -1
342			):  # Allow access to `_load_legacy` for profiling
343				return cls._load_legacy(data)
344			return cls._load_full(data)
345		else:
346			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
347			raise KeyError(
348				err_msg,
349			)
350
351	@classmethod
352	def _load_full(cls, data: JSONdict) -> "MazeDataset":
353		assert data[_FORMAT_KEY] == "MazeDataset"
354		return cls(
355			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
356			mazes=load_item_recursive(data["mazes"], tuple()),
357			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
358		)
359
360	@classmethod
361	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
362		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
363		return cls(
364			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
365			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
366			mazes=[
367				SolvedMaze(
368					clist,
369					soln[:slen, ...],
370				)
371				for clist, slen, soln in zip(
372					load_item_recursive(data["maze_connection_lists"], tuple()),
373					load_item_recursive(data["maze_solution_lengths"], tuple()),
374					load_item_recursive(data["maze_solutions"], tuple()),
375					strict=False,
376					# load_item_recursive(data["maze_endpoints"], tuple()),
377				)
378			],
379		)
380
381	@classmethod
382	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
383		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"
384
385		maze_solution_lengths = load_item_recursive(
386			data["maze_solution_lengths"],
387			tuple(),
388		)
389		maze_solutions_concat = load_item_recursive(
390			data["maze_solutions_concat"],
391			tuple(),
392		)
393		maze_solutions = np.split(
394			maze_solutions_concat,
395			np.cumsum(maze_solution_lengths)[:-1],
396			axis=0,
397		)
398
399		return cls(
400			cfg=load_item_recursive(data["cfg"], tuple()),
401			generation_metadata_collected=load_item_recursive(
402				data["generation_metadata_collected"],
403				tuple(),
404			),
405			mazes=[
406				SolvedMaze(
407					connection_list=clist,
408					solution=soln,
409				)
410				for clist, soln in zip(
411					load_item_recursive(data["maze_connection_lists"], tuple()),
412					# load_item_recursive(data["maze_endpoints"], tuple()),
413					maze_solutions,
414					strict=False,
415				)
416			],
417		)
418
419	@classmethod
420	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
421		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
422		assert data[_FORMAT_KEY] == "MazeDataset"
423		return cls(
424			**{
425				key: load_item_recursive(data[key], tuple())
426				for key in ["cfg", "mazes", "generation_metadata_collected"]
427			},
428		)
429
430	def serialize(self) -> JSONdict:
431		"""serialize to zanj/json"""
432		if (
433			SERIALIZE_MINIMAL_THRESHOLD is not None
434			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
435		):
436			return self._serialize_minimal()
437		return self._serialize_full()
438
439	def _serialize_full(self) -> JSONdict:
440		return {
441			_FORMAT_KEY: "MazeDataset",
442			"cfg": json_serialize(self.cfg),
443			"fname": self.cfg.to_fname(),
444			"mazes": json_serialize(self.mazes),
445			"generation_metadata_collected": json_serialize(
446				self.generation_metadata_collected,
447			),
448		}
449
450	def _serialize_minimal(self) -> JSONdict:
451		"alternate serialization where metadata is collected and mazes are stored in concatenated form"
452		filtered_meta: MazeDataset
453		if self.generation_metadata_collected is None:
454			filtered_meta = self.filter_by.collect_generation_meta()
455		else:
456			filtered_meta = self
457
458		max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
459		n_mazes: int = len(filtered_meta.mazes)
460		grid_n: int = filtered_meta.cfg.grid_n
461
462		maze_connection_lists: np.ndarray = np.empty(
463			(n_mazes, 2, grid_n, grid_n),
464			dtype=np.bool_,
465		)
466		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
467		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
468		maze_solutions: np.ndarray = np.empty(
469			(n_mazes, max_solution_len, 2),
470			dtype=np.int8,
471		)
472
473		for idx, maze in enumerate(filtered_meta.mazes):
474			maze_connection_lists[idx] = maze.connection_list
475			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
476			maze_solution_lengths[idx] = maze.solution.shape[0]
477			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution
478
479		return {
480			_FORMAT_KEY: "MazeDataset:minimal",
481			"cfg": json_serialize(filtered_meta.cfg),
482			"fname": filtered_meta.cfg.to_fname(),
483			"generation_metadata_collected": json_serialize(
484				filtered_meta.generation_metadata_collected,
485			),
486			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
487			# "maze_endpoints": maze_endpoints,
488			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
489			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
490		}
491
492	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
493		"alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
494		filtered_meta: MazeDataset
495		if self.generation_metadata_collected is None:
496			filtered_meta = self.filter_by.collect_generation_meta()
497		else:
498			filtered_meta = self
499
500		maze_solution_lengths: np.ndarray = np.array(
501			[m.solution.shape[0] for m in filtered_meta.mazes],
502			dtype=np.int32,
503		)
504		n_mazes: int = len(filtered_meta.mazes)
505		grid_n: int = filtered_meta.cfg.grid_n
506		total_solution_len: int = np.sum(maze_solution_lengths)
507
508		maze_connection_lists: np.ndarray = np.empty(
509			(n_mazes, 2, grid_n, grid_n),
510			dtype=np.bool_,
511		)
512		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
513		maze_solutions_concat: np.ndarray = np.empty(
514			(total_solution_len, 2),
515			dtype=np.int8,
516		)
517
518		solutions_running_idx: int = 0
519		for idx, maze in enumerate(filtered_meta.mazes):
520			maze_connection_lists[idx] = maze.connection_list
521			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
522			soln_len: int = maze.solution.shape[0]
523			maze_solution_lengths[idx] = soln_len
524			maze_solutions_concat[
525				solutions_running_idx : solutions_running_idx + soln_len
526			] = maze.solution
527			solutions_running_idx += soln_len
528
529		return {
530			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
531			"cfg": json_serialize(filtered_meta.cfg),
532			"fname": filtered_meta.cfg.to_fname(),
533			"generation_metadata_collected": json_serialize(
534				filtered_meta.generation_metadata_collected,
535			),
536			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
537			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
538			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
539			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
540		}
541
542	def update_self_config(self) -> None:
543		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
544		if self.cfg.n_mazes != len(self.mazes):
545			warnings.warn(
546				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
547			)
548			self.cfg.n_mazes = len(self.mazes)
549
550	def custom_maze_filter(
551		self,
552		method: typing.Callable[[SolvedMaze], bool],
553		**kwargs,
554	) -> "MazeDataset":
555		"""filter the dataset using a custom method"""
556		output: MazeDataset = MazeDataset(
557			cfg=copy.deepcopy(self.cfg),
558			mazes=[m for m in self.mazes if method(m, **kwargs)],
559		)
560		output.cfg.applied_filters.append(
561			{
562				"name": f"__custom__:{method.__name__}",
563				"kwargs": kwargs,
564			},
565		)
566		output.update_self_config()
567		return output

a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config

MazeDataset( cfg: MazeDatasetConfig, mazes: Sequence[maze_dataset.SolvedMaze], generation_metadata_collected: dict | None = None)
117	def __init__(
118		self,
119		cfg: MazeDatasetConfig,
120		mazes: typing.Sequence[SolvedMaze],
121		generation_metadata_collected: dict | None = None,
122	) -> None:
123		"""initialize a maze dataset from a config and a list of solved mazes"""
124		super().__init__()
125		self.cfg: MazeDatasetConfig = cfg
126		self.mazes: list[SolvedMaze] = list(mazes)
127		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize a maze dataset from a config and a list of solved mazes

generation_metadata_collected: dict | None
@classmethod
def from_config( cls, cfg: MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib._local.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> MazeDataset:
130	@classmethod
131	def from_config(  # type: ignore[override]
132		cls,
133		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
134		cfg: MazeDatasetConfig,  # type: ignore[override]
135		do_generate: bool = True,
136		load_local: bool = True,
137		save_local: bool = True,
138		zanj: ZANJ | None = None,
139		do_download: bool = True,
140		local_base_path: Path = Path("data/maze_dataset"),
141		except_on_config_mismatch: bool = True,
142		allow_generation_metadata_filter_mismatch: bool = True,
143		verbose: bool = False,
144		**kwargs,
145	) -> "MazeDataset":
146		"""create a maze dataset from a config
147
148		priority of loading:
149		1. load from local
150		2. download
151		3. generate
152
153		"""
154		return cast(
155			"MazeDataset",
156			super().from_config(
157				cfg=cfg,
158				do_generate=do_generate,
159				load_local=load_local,
160				save_local=save_local,
161				zanj=zanj,
162				do_download=do_download,
163				local_base_path=local_base_path,
164				except_on_config_mismatch=except_on_config_mismatch,
165				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
166				verbose=verbose,
167				**kwargs,
168			),
169		)

create a maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
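
For example, to restrict this chain to the local cache only (a sketch using the flags from the signature above):

```python
dataset = MazeDataset.from_config(
	cfg,
	load_local=True,    # 1. try to load from `local_base_path`
	do_download=False,  # 2. skip downloading
	do_generate=False,  # 3. skip runtime generation
)
```
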
def data_hash(self) -> int:
171	def data_hash(self) -> int:
172		"""return a hash of the data"""
173		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

return a hash of the data

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
205	def as_tokens(
206		self,
207		maze_tokenizer,  # TODO: MazeTokenizer
208		limit: int | None = None,
209		join_tokens_individual_maze: bool = False,
210	) -> list[list[str]] | list[str]:
211		"""return the dataset as tokens according to the passed `maze_tokenizer`
212
213		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`
214
215		if `join_tokens_individual_maze` is True, then the tokens of each maze are
216		joined with a space, and the result is a list of strings.
217		i.e.:
218
219			>>> dataset.as_tokens(join_tokens_individual_maze=False)
220			[["a", "b", "c"], ["d", "e", "f"]]
221			>>> dataset.as_tokens(join_tokens_individual_maze=True)
222			["a b c", "d e f"]
223		"""
224		output: list[list[str]] = [
225			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
226		]
227		if join_tokens_individual_maze:
228			return [" ".join(tokens) for tokens in output]
229		else:
230			return output

return the dataset as tokens according to the passed maze_tokenizer

the maze_tokenizer should be either a MazeTokenizer or a MazeTokenizerModular

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
def assert_equal(self, other: MazeDataset) -> None:
245	def assert_equal(self, other: "MazeDataset") -> None:
246		"""assert that two datasets are equal"""
247		assert isinstance(other, MazeDataset)
248		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
249		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

assert that two datasets are equal

@classmethod
def generate( cls, cfg: MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, **kwargs) -> MazeDataset:
251	@classmethod
252	def generate(
253		cls,
254		cfg: MazeDatasetConfig,
255		gen_parallel: bool = False,
256		pool_kwargs: dict | None = None,
257		verbose: bool = False,
258		# TODO: what to do when unexpected kwargs are passed?
259		**kwargs,  # noqa: ARG003
260	) -> "MazeDataset":
261		"""Generate a maze dataset given a config and some generation parameters"""
262		# Copy the config to avoid modifying the original
263		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
264			json.loads(json.dumps(cfg.serialize())),
265		)
266
267		if pool_kwargs is None:
268			pool_kwargs = dict()
269		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]
270
271		solved_mazes: list[SolvedMaze | None]
272		# Configure tqdm for progress bar
273		tqdm_kwargs: dict = dict(
274			total=cfg_cpy.n_mazes,
275			unit="maze",
276			desc="generating & solving mazes",
277			disable=not verbose,
278		)
279		# TODO: don't use the global unless generating in parallel!
280		if gen_parallel:
281			with multiprocessing.Pool(
282				**pool_kwargs,
283				initializer=_maze_gen_init_worker,
284				initargs=(cfg_cpy,),
285			) as pool:
286				solved_mazes = list(
287					tqdm.tqdm(
288						pool.imap(_generate_maze_helper, maze_indexes),
289						**tqdm_kwargs,
290					),
291				)
292
293		else:
294			_maze_gen_init_worker(cfg_cpy)
295			solved_mazes = list(
296				tqdm.tqdm(
297					map(
298						# TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
299						# why does it think tolist() returns a string?
300						_generate_maze_helper,  # type: ignore[arg-type]
301						maze_indexes.tolist(),
302					),
303					**tqdm_kwargs,
304				),
305			)
306
307		# Filter out None values explicitly after ensuring all results are collected
308		solved_mazes_: list[SolvedMaze] = [
309			maze for maze in solved_mazes if maze is not None
310		]
311		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))
312
313		# Update the config with the actual number of mazes
314		cfg_cpy.n_mazes = len(solved_mazes_)
315
316		dataset: MazeDataset = cls(
317			cfg=cfg_cpy,
318			mazes=solved_mazes_,
319		)
320
321		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes
322
323		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy
324
325		return dataset

Generate a maze dataset given a config and some generation parameters
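
For example, to generate in parallel with a bounded worker pool (pool_kwargs is forwarded to multiprocessing.Pool, so processes is a standard argument):

```python
dataset: MazeDataset = MazeDataset.generate(
	cfg,
	gen_parallel=True,              # generate in a multiprocessing.Pool
	pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
	verbose=True,                   # show the tqdm progress bar
)
```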

@classmethod
def download( cls, cfg: MazeDatasetConfig, **kwargs) -> MazeDataset:
327	@classmethod
328	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
329		"(not implemented yet!) download a maze dataset from the internet"
330		raise NotImplementedError("not implemented yet")

(not implemented yet!) download a maze dataset from the internet

@classmethod
def load( cls: type[MazeDataset], data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDataset:
332	@classmethod
333	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
334		"""load from zanj/json"""
335		if data[_FORMAT_KEY] == "MazeDataset:minimal":
336			return cls._load_minimal(data)
337		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
338			return cls._load_minimal_soln_cat(data)
339		elif data[_FORMAT_KEY] == "MazeDataset":
340			if (
341				SERIALIZE_MINIMAL_THRESHOLD == -1
342			):  # Allow access to `_load_legacy` for profiling
343				return cls._load_legacy(data)
344			return cls._load_full(data)
345		else:
346			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
347			raise KeyError(
348				err_msg,
349			)

load from zanj/json

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
430	def serialize(self) -> JSONdict:
431		"""serialize to zanj/json"""
432		if (
433			SERIALIZE_MINIMAL_THRESHOLD is not None
434			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
435		):
436			return self._serialize_minimal()
437		return self._serialize_full()

serialize to zanj/json

def update_self_config(self) -> None:
542	def update_self_config(self) -> None:
543		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
544		if self.cfg.n_mazes != len(self.mazes):
545			warnings.warn(
546				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
547			)
548			self.cfg.n_mazes = len(self.mazes)

update the config to match the current state of the dataset (number of mazes, such as after filtering)

def custom_maze_filter( self, method: Callable[[maze_dataset.SolvedMaze], bool], **kwargs) -> MazeDataset:
550	def custom_maze_filter(
551		self,
552		method: typing.Callable[[SolvedMaze], bool],
553		**kwargs,
554	) -> "MazeDataset":
555		"""filter the dataset using a custom method"""
556		output: MazeDataset = MazeDataset(
557			cfg=copy.deepcopy(self.cfg),
558			mazes=[m for m in self.mazes if method(m, **kwargs)],
559		)
560		output.cfg.applied_filters.append(
561			{
562				"name": f"__custom__:{method.__name__}",
563				"kwargs": kwargs,
564			},
565		)
566		output.update_self_config()
567		return output

filter the dataset using a custom method
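
Extra keyword arguments are forwarded to the predicate. A sketch (the minimum-length predicate is illustrative and overlaps with the built-in path_length filter):

```python
dataset_long: MazeDataset = dataset.custom_maze_filter(
	lambda maze, min_len: maze.solution.shape[0] >= min_len,
	min_len=5,
)
```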

@serializable_dataclass(kw_only=True, methods_no_override=['serialize'])
class MazeDatasetConfig(maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base):
257@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
258class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
259	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset
260
261	# Parameters:
262	- `name : str`
263		name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
264	- `grid_n : int`
265		grid size of the maze (number of rows/columns)
266	- `n_mazes : int`
267		number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
268		see `EndpointKwargsType` for more details.
269	- `maze_ctor : Callable`
270		maze generator function. This should be a function that takes a grid size and returns a maze.
271		This will usually be one of the functions in `LatticeMazeGenerators`.
272	- `maze_ctor_kwargs : dict`
273		keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
274	- `endpoint_kwargs : EndpointKwargsType`
275		keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
276	- `applied_filters : list[dict]`
277		list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
278		but these are stored with the config in case you want to re-generate the dataset with the same filters.
279
280	"""
281
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
285		return "1.0"
286
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)
294
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}
304
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# do we run this to make sure it doesn't error?
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)
325
326	def _to_ps_array(self) -> _PercolationSuccessArray:
327		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
328
329		used in predicting the success rate
330		"""
331		try:
332			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
333				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
334			)
335			assert "p" in self.maze_ctor_kwargs, (
336				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
337			)
338			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
339				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
340			)
341		except AssertionError as e:
342			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
343			raise NoPercolationInConfigError(
344				err_msg,
345			) from e
346
347		endpoints_unique_flag: int = int(
348			# we are pretty sure it will be an int or bool here
349			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
350		)
351
352		# adjustment for bknutson0
353		if not (
354			self.endpoint_kwargs.get("deadend_start", False)
355			and self.endpoint_kwargs.get("deadend_end", False)
356		):
357			# we didnt train on this, but if either endpoint is not required to be in a dead end
358			# then  requiring the endpoints to be unique does not really affect the success rate
359			# (except for very small percolation values, pure percolation generation)
360			endpoints_unique_flag = 0
361
362		return np.array(
363			[
364				float(self.maze_ctor_kwargs["p"]),
365				float(self.grid_n),
366				float(
367					int(
368						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
369						or self.endpoint_kwargs.get("deadend_end", False),
370					),
371				),
372				float(endpoints_unique_flag),
373				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
374			],
375			dtype=np.float64,
376		)
377
378	@classmethod
379	def _from_ps_array(
380		cls,
381		arr: _PercolationSuccessArray,
382		name: str = "predict",
383		n_mazes: int = 100,
384		**kwargs,
385	) -> "MazeDatasetConfig":
386		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
387
388		# Returns:
389		- `MazeDatasetConfig`
390			Config corresponding to `arr`
391		"""
392		return cls(
393			name=name,
394			grid_n=int(arr[1]),
395			n_mazes=n_mazes,
396			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
397			maze_ctor_kwargs={"p": float(arr[0])},
398			endpoint_kwargs=dict(
399				deadend_start=bool(arr[2]),
400				deadend_end=bool(arr[2]),
401				endpoints_not_equal=bool(arr[3]),
402				except_on_no_valid_endpoint=False,
403			),
404			**kwargs,
405		)
406
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, don't raise an error if the success fraction is below the threshold.
424			will always return `1.0` if the config is not expected to fail
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0
440
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			if `True`, don't raise an error if the success fraction is below the threshold.
463			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
464			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
465			since `safety_margin` is still applied.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

config object which is passed to MazeDataset.from_config to generate or load a dataset

Parameters:

  • name : str name of the dataset -- this can be anything, but should be filesystem safe since we use it in the fname
  • grid_n : int grid size of the maze (number of rows/columns)
  • n_mazes : int number of mazes to request. For some combinations of endpoint_kwargs and maze_ctor, not all mazes might successfully generate. see EndpointKwargsType for more details.
  • maze_ctor : Callable maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in LatticeMazeGenerators.
  • maze_ctor_kwargs : dict keyword arguments to pass to the maze generator function. Specific to the maze_ctor you are using.
  • endpoint_kwargs : EndpointKwargsType keyword arguments passed to LatticeMaze.generate_random_path(). see EndpointKwargsType for more info.
  • applied_filters : list[dict] list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
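
Putting these parameters together (a sketch; the endpoint constraints are a subset of the keys in the signature below, and the values are illustrative):

```python
cfg = MazeDatasetConfig(
	name="demo",
	grid_n=8,
	n_mazes=1000,
	# maze_ctor defaults to LatticeMazeGenerators.gen_dfs
	maze_ctor_kwargs=dict(),  # options specific to the chosen generator
	endpoint_kwargs=dict(
		deadend_start=True,        # start must be a dead end
		deadend_end=True,          # end must be a dead end
		endpoints_not_equal=True,  # start and end must differ
	),
)
```
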
MazeDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
config_version: str
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
285		return "1.0"

return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config

versions: dict
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)

return the versions of the config and the maze_dataset

def serialize(self) -> dict:
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}

serialize the MazeDatasetConfig with all fields and fname

def summary(self) -> dict:
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# do we run this to make sure it doesn't error?
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)

return a summary of the config

def success_fraction_estimate(self, except_if_all_success_expected: bool = False) -> float:
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, don't raise an error if the success fraction is below the threshold.
424			will always return `1.0` if the config is not expected to fail
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0

Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in:

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • except_if_all_success_expected : bool if True, don't raise an error if the success fraction is below the threshold. will always return 1.0 if the config is not expected to fail

Returns:

  • float estimated success fraction

Raises:

  • NoPercolationInConfigError : if the config is not expected to fail, and except_if_all_success_expected is False
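
A sketch, assuming gen_percolation is one of the percolation generators in LatticeMazeGenerators and is importable from the top-level package (the p value and sizes are illustrative; except_on_no_valid_endpoint=False is required, per the assertions in _to_ps_array):

```python
from maze_dataset import LatticeMazeGenerators  # assumed top-level import

cfg_perc = MazeDatasetConfig(
	name="perc",
	grid_n=8,
	n_mazes=1000,
	maze_ctor=LatticeMazeGenerators.gen_percolation,  # assumed percolation generator
	maze_ctor_kwargs=dict(p=0.4),  # percolation probability
	endpoint_kwargs=dict(
		deadend_start=True,
		deadend_end=True,
		endpoints_not_equal=True,
		except_on_no_valid_endpoint=False,  # required by the estimator
	),
)
estimated: float = cfg_perc.success_fraction_estimate()
```
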
def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 0.01) -> MazeDatasetConfig:
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			if `True`, don't raise an error if the success fraction is below the threshold.
463			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
464			if your config isn't expected to fail, passing this might mean you generate more mazes than needed
465			since `safety_margin` is still applied.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

return a new MazeDatasetConfig like this one with n_mazes adjusted to compensate for the success fraction

calls MazeDatasetConfig.success_fraction_estimate() to get the success fraction, and then computes the new number of mazes as n_mazes = n_mazes * safety_margin / success_fraction + 1

more information on where this comes from can be found in:

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • safety_margin : float safety margin to apply to the success fraction estimate (defaults to 1.2, or 20% more mazes than estimated)
  • except_if_all_success_expected : bool if True, don't raise an error if the success fraction is below the threshold. this is passed to MazeDatasetConfig.success_fraction_estimate. if your config isn't expected to fail, passing this might mean you generate more mazes than needed since safety_margin is still applied. (defaults to False)
  • epsilon : float raise SuccessChanceTooSmallError if the success fraction is below this threshold (defaults to 1e-2)

Returns:

  • MazeDatasetConfig new config with adjusted n_mazes

Raises:

  • SuccessChanceTooSmallError : if the computed success fraction is below epsilon
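
Continuing the percolation sketch above, compensation simply rescales n_mazes by the formula given: requesting 1000 mazes with an estimated success fraction of 0.8 and the default safety_margin of 1.2 yields int(1000 * 1.2 / 0.8) + 1 == 1501.

```python
cfg_comp: MazeDatasetConfig = cfg_perc.success_fraction_compensate(safety_margin=1.2)
dataset = MazeDataset.from_config(cfg_comp)  # over-request so enough mazes survive
```
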
@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

load a serialized config (or any SerializableDataclass) from a dict; this implementation is inherited from muutils.json_serialize.serializable_dataclass.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

class MazeDatasetCollection(typing.Generic[~T_DatasetConfig]):
 84class MazeDatasetCollection(GPTDataset):
 85	"""a collection of maze datasets"""
 86
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected
106
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]
111
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))
116
117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)
125
126	def __len__(self) -> int:
127		"""return the total number of mazes in the collection"""
128		return sum(len(dataset) for dataset in self.maze_datasets)
129
130	def __getitem__(self, index: int) -> LatticeMaze:
131		"get a maze by index"
132		# find which dataset the index belongs to
133		# we add 1, since np.searchsorted returns the
134		# index of the last element that is strictly less than the target
135		# while we want the index of the last element less than or equal to the target
136		dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
137		index_adjusted: int = index
138		if dataset_idx > 0:
139			# if the index is 0, `dataset_idx - 1` will be -1.
140			# We just want to use the base index
141			index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
142		return self.maze_datasets[dataset_idx][index_adjusted]
143
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)
156
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)
169
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}
180
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)
191
192	# TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output
217
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

a collection of maze datasets
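
To make the collection interface concrete, here is a minimal usage sketch. The config names, grid sizes, and maze counts are hypothetical placeholders, and the `MazeDatasetConfig` arguments follow the usual maze-dataset pattern:

```python
from maze_dataset import LatticeMazeGenerators, MazeDatasetConfig
from maze_dataset.dataset.collected_dataset import (
	MazeDatasetCollection,
	MazeDatasetCollectionConfig,
)

# two sub-datasets with different grid sizes (all values are arbitrary)
collection_cfg = MazeDatasetCollectionConfig(
	name="demo_collection",
	maze_dataset_configs=[
		MazeDatasetConfig(
			name="small",
			grid_n=3,
			n_mazes=5,
			maze_ctor=LatticeMazeGenerators.gen_dfs,
		),
		MazeDatasetConfig(
			name="large",
			grid_n=5,
			n_mazes=5,
			maze_ctor=LatticeMazeGenerators.gen_dfs,
		),
	],
)

# generate each underlying MazeDataset and wrap them in a collection
collection = MazeDatasetCollection.generate(collection_cfg)
assert len(collection) == 10  # summed across both datasets
first_maze = collection[0]  # indexing spans dataset boundaries transparently
```

Indexing uses `dataset_cum_lengths` internally, so a flat index is looked up in whichever sub-dataset contains it.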

MazeDatasetCollection( cfg: MazeDatasetCollectionConfig, maze_datasets: list[MazeDataset], generation_metadata_collected: dict | None = None)
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize the dataset collection from a MazeDatasetCollectionConfig and a list of MazeDatasets

maze_datasets: list[MazeDataset]
generation_metadata_collected: dict | None
dataset_lengths: list[int]
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]

return the lengths of each dataset in the collection

dataset_cum_lengths: jaxtyping.Int[ndarray, 'indices']
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))

return the cumulative lengths of each dataset in the collection

117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)

single list of all mazes in the collection

@classmethod
def generate( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)

generate a dataset collection from a config

@classmethod
def download( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)

(not implemented!) download a dataset collection from a config

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}

serialize the dataset collection

@classmethod
def load( cls, data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDatasetCollection:
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)

load the dataset collection from the representation created by serialize
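
Continuing the hypothetical `collection` from the sketch above, `serialize` and `load` are intended as inverses at the collection level:

```python
# round-trip the whole collection through its JSON-serializable form
data = collection.serialize()
collection_2 = MazeDatasetCollection.load(data)
assert len(collection_2) == len(collection)
```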

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space, and the result is a list of strings.
204		i.e.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output

return the dataset as tokens

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
def update_self_config(self) -> None:
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

update the config to match the number of mazes, and update the underlying configs of each dataset

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(maze_dataset.dataset.dataset.GPTDatasetConfig):
31@serializable_dataclass(kw_only=True)
32class MazeDatasetCollectionConfig(GPTDatasetConfig):
33	"""maze dataset collection configuration, including tokenizers and shuffle"""
34
35	# Attributes without a default cannot follow attributes with one  [misc]
36	maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
37		serialization_fn=lambda configs: [config.serialize() for config in configs],
38		loading_fn=lambda data: [
39			MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
40		],
41	)
42
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)
52
53	@property
54	def n_mazes(self) -> int:
55		"""return the total number of mazes in the collection across all dataset"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)
57
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)
62
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)
67
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)
72
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))
76
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

maze dataset collection configuration, including tokenizers and shuffle

MazeDatasetCollectionConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, maze_dataset_configs: list[MazeDatasetConfig])
maze_dataset_configs: list[MazeDatasetConfig]
def summary(self) -> dict:
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)

return a summary of the config

n_mazes: int
53	@property
54	def n_mazes(self) -> int:
55		"""return the total number of mazes in the collection across all dataset"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)

return the total number of mazes in the collection across all datasets

max_grid_n: int
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)

return the maximum grid size of the mazes in the collection

max_grid_shape: tuple[int, int]
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)

return the maximum grid shape of the mazes in the collection

max_grid_shape_np: jaxtyping.Int8[ndarray, 'row_col=2']
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)

return the maximum grid shape of the mazes in the collection as a numpy array

def stable_hash_cfg(self) -> int:
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))

return a stable hash of the config

def to_fname(self) -> str:
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

convert config to a filename
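
As the source above shows, the filename combines the collection name, a shortened maze count, and a truncated stable hash. A sketch (the hash digits below are placeholders, not real output):

```python
# pattern: "collected-<name>-n<maze count>-h<5-digit hash>"
fname = collection_cfg.to_fname()
# e.g. something like "collected-demo_collection-n10-h12345"
```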

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

serialize the dataclass into a JSON-serializable dict, tagged with a format key naming the class

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

load an instance of the class from the dict produced by serialize, then validate field types unless typechecking is set to ignore
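
These inherited methods also give configs a round trip; a minimal sketch using the hypothetical `collection_cfg` from earlier:

```python
# config round-trip via the inherited SerializableDataclass machinery
cfg_data = collection_cfg.serialize()
cfg_loaded = MazeDatasetCollectionConfig.load(cfg_data)
assert cfg_loaded == collection_cfg
# field types can be checked explicitly as well
assert cfg_loaded.validate_fields_types()
```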

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict