docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.maze_dataset_config

implements MazeDatasetConfig which is used to generate or load a dataset


  1"implements `MazeDatasetConfig` which is used to generate or load a dataset"
  2
  3import hashlib
  4import importlib.metadata
  5import json
  6import typing
  7import warnings
  8from typing import Callable
  9
 10import numpy as np
 11from jaxtyping import Float
 12from muutils.json_serialize import (
 13	serializable_dataclass,
 14	serializable_field,
 15)
 16from muutils.json_serialize.util import (
 17	safe_getsource,
 18	string_as_lines,
 19)
 20from muutils.misc import sanitize_fname, shorten_numerical_to_str
 21
 22from maze_dataset.constants import Coord, CoordTup
 23from maze_dataset.dataset.dataset import (
 24	GPTDatasetConfig,
 25)
 26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn
 27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP
 28
 29SERIALIZE_MINIMAL_THRESHOLD: int | None = 100
 30"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`.
 31Setting to None means that `serialize_minimal` will never be used.
 32Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only."""
 33
 34MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
 35"length of the has, in characters, of the hash in the fname of a `MazeDatasetConfig`"
 36
 37_PercolationSuccessArray = Float[
 38	np.ndarray,
 39	"p/grid_n/deadends/endpoints_not_equal/generator_func=5",
 40]
 41
 42
 43class NoPercolationInConfigError(ValueError):
 44	"""raised when trying to predict the success fraction of a config that doesn't have percolation"""
 45
 46	pass
 47
 48
 49class SuccessChanceTooSmallError(ValueError):
 50	"""raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""
 51
 52	pass
 53
 54
 55def set_serialize_minimal_threshold(threshold: int | None) -> None:
  56	"set the global SERIALIZE_MINIMAL_THRESHOLD"
 57	global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
 58	SERIALIZE_MINIMAL_THRESHOLD = threshold
 59
 60
 61def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
 62	"get the maze constructor from `GENERATORS_MAP`"
 63	if isinstance(maze_ctor_serialized, dict):
 64		# this is both the new and old version of the serialization
 65		return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
 66	elif isinstance(maze_ctor_serialized, str):
 67		# this is a version I switched to for a while but now we are switching back
 68		warnings.warn(
 69			"you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: "
 70			"https://github.com/understanding-search/maze-dataset/issues/new",
 71		)
 72		return GENERATORS_MAP[maze_ctor_serialized]
 73	else:
 74		err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }"
 75		raise TypeError(err_msg)
 76
 77
 78EndpointKwargsType = dict[
 79	typing.Literal[
 80		"allowed_start",
 81		"allowed_end",
 82		"deadend_start",
 83		"deadend_end",
 84		"endpoints_not_equal",
 85		"except_on_no_valid_endpoint",
 86	],
 87	bool | None | list[tuple[int, int]],
 88]
 89"""type hint for `MazeDatasetConfig.endpoint_kwargs`
 90
 91- `except_on_no_valid_endpoint : bool` (default: `True`)
 92	some of the conditions (dead ends if a maze is very open, no path between given start and end) can cause the maze generation to fail.
 93	if `except_on_no_valid_endpoint` is `True`, then the maze generation will raise an error if it fails to generate a valid maze.
 94	however, if `False`, then the maze generation will return a dataset with fewer mazes than requested.
 95	If you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()`,
 96	which uses a PySR-created function to roughly estimate the success fraction of the dataset.
 97- `allowed_start : list[tuple[int, int]]` (default: `None`)
 98	list of allowed starting position coordinates
 99- `allowed_end : list[tuple[int, int]]` (default: `None`)
100	list of allowed ending position coordinates
101- `deadend_start : bool` (default: `False`)
102	if `True`, the starting position must be a dead end
103- `deadend_end : bool` (default: `False`)
104	if `True`, the ending position must be a dead end
105- `endpoints_not_equal : bool` (default: `True`)
106	if `True`, the starting and ending positions must be different
107
108
109
110"""
111
112
113def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
114	if data.get("endpoint_kwargs") is None:
115		return dict()
116
117	else:
118		return {
119			k: (
120				# bools and Nones are fine
121				v
122				if (isinstance(v, bool) or v is None)
123				# assume it's a CoordList
124				else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
125			)
126			for k, v in data["endpoint_kwargs"].items()
127		}
128
129
130# not private because we need this to show up in docs
131@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
132class MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
133	"""base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""
134
135	# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
136
137	grid_n: int = serializable_field()  # type: ignore[misc]
138
139	# not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters
140	n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]
141
142	maze_ctor: Callable = serializable_field(
143		default=GENERATORS_MAP["gen_dfs"],
144		serialization_fn=lambda gen_func: {
145			"__name__": gen_func.__name__,
146			"__module__": gen_func.__module__,
147			# NOTE: this was causing hashing issues on 3.13 vs older versions because somehow,
148			# the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY
149			# so we just uh. strip it all now.
150			# see:
151			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
152			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
153			# https://www.diffchecker.com/tqIMSevy/
154			# update: we also need to filter for empty lines. B)
155			"__doc__": [
156				line.strip()
157				for line in string_as_lines(gen_func.__doc__)
158				if line.strip()
159			],
160			"source_code": safe_getsource(gen_func),
161		},
162		loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
163		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
164	)
165
166	maze_ctor_kwargs: dict = serializable_field(
167		default_factory=dict,
168		serialization_fn=lambda kwargs: kwargs,
169		loading_fn=lambda data: (
170			dict()
171			if data.get("maze_ctor_kwargs", None)
172			is None  # this should handle the backwards compatibility
173			else data["maze_ctor_kwargs"]
174		),
175	)
176
177	endpoint_kwargs: EndpointKwargsType = serializable_field(
178		default_factory=dict,
179		serialization_fn=lambda kwargs: kwargs,
180		loading_fn=_load_endpoint_kwargs,
181		assert_type=False,
182	)
183
184	# NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
185	# and so we need to save a `None` here or this won't load the `fname` field on load
186	# this is a total mess, and very confusing, and entirely my fault
187	_fname_loaded: str | None = serializable_field(
188		default=None,
189		compare=False,
190		serialization_fn=lambda _: None,
191		loading_fn=lambda data: data.get("fname", None),
192	)
193
194	@property
195	def grid_shape(self) -> CoordTup:
196		"""return the shape of the grid as a tuple"""
197		return (self.grid_n, self.grid_n)
198
199	@property
200	def grid_shape_np(self) -> Coord:
201		"""return the shape of the grid as a numpy array"""
202		return np.array(self.grid_shape)
203
204	@property
205	def max_grid_n(self) -> int:
206		"""return the maximum of the grid shape"""
207		return max(self.grid_shape)
208
209	def _serialize_base(
210		self, applied_filters__skip__collect_generation_meta: bool = True
211	) -> dict:
212		"""serialize the base config for use in `stable_hash_cfg()` and `to_fname()`
213
214		- note that the _fname_loaded will always be `None` to avoid infinite recursion
215		- note that we **do not** by default include information about metadata collection here,
216		since otherwise loading a dataset that we minified by collecting the metadata would be impossible
217		but for comparing things, we do store it when serializing properly by setting
218		`applied_filters__skip__collect_generation_meta=False`
219		"""
220		serialized: dict = MazeDatasetConfig_base.serialize(self)
221		if applied_filters__skip__collect_generation_meta:
222			serialized["applied_filters"] = [
223				x
224				for x in serialized["applied_filters"]
225				if x.get("name", None) != "collect_generation_meta"
226			]
227		return serialized
228
229	def _stable_str_dump(self) -> str:
230		return json.dumps(
231			self._serialize_base(),
232			sort_keys=True,
233			indent=None,
234		)
235
236	def stable_hash_cfg(self) -> int:
237		"""return a stable hash of the config"""
238		return int.from_bytes(
239			hashlib.md5(  # noqa: S324
240				bytes(self._stable_str_dump(), "ascii")
241			).digest(),
242			"big",
243		)
244
245	def to_fname(self) -> str:
246		"""return a unique identifier (valid as a filename) for this config"""
247		n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
248		maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
249		hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
250		return sanitize_fname(
251			f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
252		)
253
254
255# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
256@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
257class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
258	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset
259
260	# Parameters:
261	- `name : str`
262		name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
263	- `grid_n : int`
264		grid size of the maze (number of rows/columns)
265	- `n_mazes : int`
266		number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
267		see `EndpointKwargsType` for more details.
268	- `maze_ctor : Callable`
269		maze generator function. This should be a function that takes a grid size and returns a maze.
270		This will usually be one of the functions in `LatticeMazeGenerators`.
271	- `maze_ctor_kwargs : dict`
272		keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
273	- `endpoint_kwargs : EndpointKwargsType`
274		keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
275	- `applied_filters : list[dict]`
276		list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
277		but these are stored with the config in case you want to re-generate the dataset with the same filters.
278
279	"""
280
281	@property
282	def config_version(self) -> str:
283		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
284		return "1.0"
285
286	@property
287	def versions(self) -> dict:
288		"""return the versions of the config and the maze_dataset"""
289		return dict(
290			config=self.config_version,
291			maze_dataset=importlib.metadata.version("maze_dataset"),
292		)
293
294	def serialize(self) -> dict:
295		"serialize the MazeDatasetConfig with all fields and fname"
296		return {
297			**self._serialize_base(
298				applied_filters__skip__collect_generation_meta=False
299			),
300			"fname": self.to_fname(),
301			"versions": self.versions,
302		}
303
304	def summary(self) -> dict:
305		"""return a summary of the config"""
306		# do we run this to make sure it doesn't error?
307		super_summary: dict = super().summary()
308		assert super_summary
309		self_ser: dict = self.serialize()
310		return dict(
311			name=self.name,
312			fname=self.to_fname(),
313			sdc_hash=self.stable_hash_cfg(),
314			seed=self.seed,
315			seq_len_min=self.seq_len_min,
316			seq_len_max=self.seq_len_max,
317			applied_filters=self.applied_filters,
318			grid_n=self_ser["grid_n"],
319			n_mazes=self_ser["n_mazes"],
320			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
321			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
322			endpoint_kwargs=self_ser["endpoint_kwargs"],
323		)
324
325	def _to_ps_array(self) -> _PercolationSuccessArray:
326		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
327
328		used in predicting the success rate
329		"""
330		try:
331			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
332				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
333			)
334			assert "p" in self.maze_ctor_kwargs, (
335				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
336			)
337			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
338				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
339			)
340		except AssertionError as e:
341			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
342			raise NoPercolationInConfigError(
343				err_msg,
344			) from e
345
346		endpoints_unique_flag: int = int(
347			# we are pretty sure it will be an int or bool here
348			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
349		)
350
351		# adjustment for bknutson0
352		if not (
353			self.endpoint_kwargs.get("deadend_start", False)
354			and self.endpoint_kwargs.get("deadend_end", False)
355		):
356			# we didn't train on this, but if either endpoint is not required to be in a dead end
357			# then requiring the endpoints to be unique does not really affect the success rate
358			# (except for very small percolation values, pure percolation generation)
359			endpoints_unique_flag = 0
360
361		return np.array(
362			[
363				float(self.maze_ctor_kwargs["p"]),
364				float(self.grid_n),
365				float(
366					int(
367						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
368						or self.endpoint_kwargs.get("deadend_end", False),
369					),
370				),
371				float(endpoints_unique_flag),
372				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
373			],
374			dtype=np.float64,
375		)
376
377	@classmethod
378	def _from_ps_array(
379		cls,
380		arr: _PercolationSuccessArray,
381		name: str = "predict",
382		n_mazes: int = 100,
383		**kwargs,
384	) -> "MazeDatasetConfig":
385		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
386
387		# Returns:
388		- `MazeDatasetConfig`
389			Config corresponding to `arr`
390		"""
391		return cls(
392			name=name,
393			grid_n=int(arr[1]),
394			n_mazes=n_mazes,
395			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
396			maze_ctor_kwargs={"p": float(arr[0])},
397			endpoint_kwargs=dict(
398				deadend_start=bool(arr[2]),
399				deadend_end=bool(arr[2]),
400				endpoints_not_equal=bool(arr[3]),
401				except_on_no_valid_endpoint=False,
402			),
403			**kwargs,
404		)
405
406	def success_fraction_estimate(
407		self,
408		except_if_all_success_expected: bool = False,
409	) -> float:
410		"""Estimate the success fraction of this config.
411
412		only valid when the generator is a percolation generator,
413		and endpoints are enforced to be dead ends
414
415		more information on where this comes from can be found in
416		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
417		- `estimate_dataset_fractions.ipynb`
418		- `maze_dataset.benchmarks.sweep_fit`
419
420		# Parameters:
421		- `except_if_all_success_expected : bool`
422		if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail
423		(i.e. has no percolation); if `False` (the default), return `1.0` for such configs instead
424
425		# Returns:
426		- `float`
427			estimated success fraction
428
429		# Raises:
430	- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
431		"""
432		try:
433			return cfg_success_predict_fn(self)
434
435		except NoPercolationInConfigError as e:
436			if except_if_all_success_expected:
437				raise e  # noqa: TRY201
438			return 1.0
439
440	def success_fraction_compensate(
441		self,
442		safety_margin: float = 1.2,
443		except_if_all_success_expected: bool = False,
444		epsilon: float = 1e-2,
445	) -> "MazeDatasetConfig":
446		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
447
448		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
449		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
450
451		more information on where this comes from can be found in
452		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
453		- `estimate_dataset_fractions.ipynb`
454		- `maze_dataset.benchmarks.sweep_fit`
455
456		# Parameters:
457		- `safety_margin : float`
458			safety margin to apply to the success fraction estimate
459			(defaults to `1.2`, or 20% more mazes than estimated)
460		- `except_if_all_success_expected : bool`
461		passed through to `MazeDatasetConfig.success_fraction_estimate()`:
462		if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail.
463		if `False` and your config isn't expected to fail, you may generate more mazes than needed,
464		since `safety_margin` is still applied to the estimated success fraction of `1.0`.
465			(defaults to `False`)
466		- `epsilon : float`
467			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
468			(defaults to `1e-2`)
469
470		# Returns:
471		- `MazeDatasetConfig`
472			new config with adjusted `n_mazes`
473
474		# Raises:
475		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
476		"""
477		# compute and check the success fraction
478		success_fraction: float = self.success_fraction_estimate(
479			except_if_all_success_expected=except_if_all_success_expected,
480		)
481		if success_fraction < epsilon:
482			err_msg: str = (
483				f"{success_fraction = } is below the threshold of {epsilon = }"
484			)
485			raise SuccessChanceTooSmallError(
486				err_msg,
487			)
488
489		# compute the new number of mazes
490		n_mazes: int = self.n_mazes
491		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
492
493		# put it in a new config and return
494		cfg_dict: dict = self.serialize()
495		cfg_dict["n_mazes"] = new_n_mazes
496		return MazeDatasetConfig.load(cfg_dict)

SERIALIZE_MINIMAL_THRESHOLD: int | None = 100

If n_mazes>=SERIALIZE_MINIMAL_THRESHOLD, then the MazeDataset will use serialize_minimal. Setting to None means that serialize_minimal will never be used. Set to -1 to make calls to read use MazeDataset._load_legacy. Used for profiling only.

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5

length, in characters, of the hash in the fname of a MazeDatasetConfig

class NoPercolationInConfigError(builtins.ValueError):
44class NoPercolationInConfigError(ValueError):
45	"""raised when trying to predict the success fraction of a config that doesn't have percolation"""
46
47	pass

raised when trying to predict the success fraction of a config that doesn't have percolation
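
A minimal sketch of when this is raised (assuming `MazeDatasetConfig` is re-exported at the package top level; the default `maze_ctor`, `gen_dfs`, has no percolation parameter):

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.maze_dataset_config import NoPercolationInConfigError

# the default maze_ctor (gen_dfs) is not a percolation generator,
# so success-fraction prediction does not apply to this config
cfg = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=10)

try:
    cfg.success_fraction_estimate(except_if_all_success_expected=True)
except NoPercolationInConfigError:
    print("no percolation in config -- all mazes are expected to generate")
```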

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class SuccessChanceTooSmallError(builtins.ValueError):
50class SuccessChanceTooSmallError(ValueError):
51	"""raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""
52
53	pass

raised when the success fraction is below the threshold in MazeDatasetConfig.success_fraction_compensate
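
A sketch of how this can surface (assuming `"gen_percolation"` is a key in `GENERATORS_MAP`; the `epsilon` here is deliberately strict, for illustration only):

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.maze_dataset_config import SuccessChanceTooSmallError
from maze_dataset.generation.generators import GENERATORS_MAP

cfg = MazeDatasetConfig(
    name="demo",
    grid_n=8,
    n_mazes=100,
    maze_ctor=GENERATORS_MAP["gen_percolation"],
    maze_ctor_kwargs={"p": 0.1},  # very sparse -- most endpoint pairs unreachable
    endpoint_kwargs={
        "deadend_start": True,
        "deadend_end": True,
        "except_on_no_valid_endpoint": False,
    },
)

try:
    cfg_bigger = cfg.success_fraction_compensate(epsilon=0.5)
except SuccessChanceTooSmallError:
    print("predicted success fraction below epsilon")
```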

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
def set_serialize_minimal_threshold(threshold: int | None) -> None:
56def set_serialize_minimal_threshold(threshold: int | None) -> None:
57	"set the global SERIALIZE_MINIMAL_THRESHOLD"
58	global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
59	SERIALIZE_MINIMAL_THRESHOLD = threshold

set the global SERIALIZE_MINIMAL_THRESHOLD
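
Example usage:

```python
from maze_dataset.dataset.maze_dataset_config import set_serialize_minimal_threshold

set_serialize_minimal_threshold(None)  # never use serialize_minimal
set_serialize_minimal_threshold(10)    # use serialize_minimal for datasets of 10+ mazes
```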

EndpointKwargsType = dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]]

type hint for MazeDatasetConfig.endpoint_kwargs

  • except_on_no_valid_endpoint : bool (default: True) some of the conditions (dead ends if a maze is very open, no path between given start and end) can cause the maze generation to fail. if except_on_no_valid_endpoint is True, then the maze generation will raise an error if it fails to generate a valid maze. however, if False, then the maze generation will return a dataset with fewer mazes than requested. If you are generating large datasets, consider using MazeDatasetConfig.success_fraction_compensate(), which uses a PySR-created function to roughly estimate the success fraction of the dataset.
  • allowed_start : list[tuple[int, int]] (default: None) list of allowed starting position coordinates
  • allowed_end : list[tuple[int, int]] (default: None) list of allowed ending position coordinates
  • deadend_start : bool (default: False) if True, the starting position must be a dead end
  • deadend_end : bool (default: False) if True, the ending position must be a dead end
  • endpoints_not_equal : bool (default: True) if True, the starting and ending positions must be different
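
Two example values of this type (the coordinates below are hypothetical):

```python
from maze_dataset.dataset.maze_dataset_config import EndpointKwargsType

# require both endpoints to be dead ends, and skip (rather than raise on)
# mazes where no valid endpoint pair exists
deadend_kwargs: EndpointKwargsType = {
    "deadend_start": True,
    "deadend_end": True,
    "endpoints_not_equal": True,
    "except_on_no_valid_endpoint": False,
}

# restrict endpoints to explicit coordinate lists
fixed_kwargs: EndpointKwargsType = {
    "allowed_start": [(0, 0)],
    "allowed_end": [(4, 4)],
}
```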
@serializable_dataclass(kw_only=True, properties_to_serialize=['grid_shape'])
class MazeDatasetConfig_base(maze_dataset.dataset.dataset.GPTDatasetConfig):
132@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
133class MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
134	"""base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""
135
136	# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
137
138	grid_n: int = serializable_field()  # type: ignore[misc]
139
140	# not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters
141	n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]
142
143	maze_ctor: Callable = serializable_field(
144		default=GENERATORS_MAP["gen_dfs"],
145		serialization_fn=lambda gen_func: {
146			"__name__": gen_func.__name__,
147			"__module__": gen_func.__module__,
148			# NOTE: this was causing hashing issues on 3.13 vs older versions because somehow,
149			# the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY
150			# so we just uh. strip it all now.
151			# see:
152			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
153			# https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
154			# https://www.diffchecker.com/tqIMSevy/
155			# update: we also need to filter for empty lines. B)
156			"__doc__": [
157				line.strip()
158				for line in string_as_lines(gen_func.__doc__)
159				if line.strip()
160			],
161			"source_code": safe_getsource(gen_func),
162		},
163		loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
164		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
165	)
166
167	maze_ctor_kwargs: dict = serializable_field(
168		default_factory=dict,
169		serialization_fn=lambda kwargs: kwargs,
170		loading_fn=lambda data: (
171			dict()
172			if data.get("maze_ctor_kwargs", None)
173			is None  # this should handle the backwards compatibility
174			else data["maze_ctor_kwargs"]
175		),
176	)
177
178	endpoint_kwargs: EndpointKwargsType = serializable_field(
179		default_factory=dict,
180		serialization_fn=lambda kwargs: kwargs,
181		loading_fn=_load_endpoint_kwargs,
182		assert_type=False,
183	)
184
185	# NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
186	# and so we need to save a `None` here or this won't load the `fname` field on load
187	# this is a total mess, and very confusing, and entirely my fault
188	_fname_loaded: str | None = serializable_field(
189		default=None,
190		compare=False,
191		serialization_fn=lambda _: None,
192		loading_fn=lambda data: data.get("fname", None),
193	)
194
195	@property
196	def grid_shape(self) -> CoordTup:
197		"""return the shape of the grid as a tuple"""
198		return (self.grid_n, self.grid_n)
199
200	@property
201	def grid_shape_np(self) -> Coord:
202		"""return the shape of the grid as a numpy array"""
203		return np.array(self.grid_shape)
204
205	@property
206	def max_grid_n(self) -> int:
207		"""return the maximum of the grid shape"""
208		return max(self.grid_shape)
209
210	def _serialize_base(
211		self, applied_filters__skip__collect_generation_meta: bool = True
212	) -> dict:
213		"""serialize the base config for use in `stable_hash_cfg()` and `to_fname()`
214
215		- note that the _fname_loaded will always be `None` to avoid infinite recursion
216		- note that we **do not** by default include information about metadata collection here,
217		since otherwise loading a dataset that we minified by collecting the metadata would be impossible
218		but for comparing things, we do store it when serializing properly by setting
219		`applied_filters__skip__collect_generation_meta=False`
220		"""
221		serialized: dict = MazeDatasetConfig_base.serialize(self)
222		if applied_filters__skip__collect_generation_meta:
223			serialized["applied_filters"] = [
224				x
225				for x in serialized["applied_filters"]
226				if x.get("name", None) != "collect_generation_meta"
227			]
228		return serialized
229
230	def _stable_str_dump(self) -> str:
231		return json.dumps(
232			self._serialize_base(),
233			sort_keys=True,
234			indent=None,
235		)
236
237	def stable_hash_cfg(self) -> int:
238		"""return a stable hash of the config"""
239		return int.from_bytes(
240			hashlib.md5(  # noqa: S324
241				bytes(self._stable_str_dump(), "ascii")
242			).digest(),
243			"big",
244		)
245
246	def to_fname(self) -> str:
247		"""return a unique identifier (valid as a filename) for this config"""
248		n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
249		maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
250		hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
251		return sanitize_fname(
252			f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
253		)

base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here

MazeDatasetConfig_base( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
grid_n: int
n_mazes: int
@staticmethod
def maze_ctor( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
 61	@staticmethod
 62	def gen_dfs(
 63		grid_shape: Coord | CoordTup,
 64		lattice_dim: int = 2,
 65		accessible_cells: float | None = None,
 66		max_tree_depth: float | None = None,
 67		do_forks: bool = True,
 68		randomized_stack: bool = False,
 69		start_coord: Coord | None = None,
 70	) -> LatticeMaze:
 71		"""generate a lattice maze using depth first search, iterative
 72
 73		# Arguments
 74		- `grid_shape: Coord`: the shape of the grid
 75		- `lattice_dim: int`: the dimension of the lattice
 76			(default: `2`)
 77		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
 78			(default: `None`)
 79		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
 80			(default: `None`)
 81		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
 82		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
 83
 84		# algorithm
 85		1. Choose the initial cell, mark it as visited and push it to the stack
 86		2. While the stack is not empty
 87			1. Pop a cell from the stack and make it a current cell
 88			2. If the current cell has any neighbours which have not been visited
 89				1. Push the current cell to the stack
 90				2. Choose one of the unvisited neighbours
 91				3. Remove the wall between the current cell and the chosen cell
 92				4. Mark the chosen cell as visited and push it to the stack
 93		"""
 94		# Default values if no constraints have been passed
 95		grid_shape_: Coord = np.array(grid_shape)
 96		n_total_cells: int = int(np.prod(grid_shape_))
 97
 98		n_accessible_cells: int
 99		if accessible_cells is None:
100			n_accessible_cells = n_total_cells
101		elif isinstance(accessible_cells, float):
102			assert accessible_cells <= 1, (
103				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
104			)
105
106			n_accessible_cells = int(accessible_cells * n_total_cells)
107		else:
108			assert isinstance(accessible_cells, int)
109			n_accessible_cells = accessible_cells
110
111		if max_tree_depth is None:
112			max_tree_depth = (
113				2 * n_total_cells
114			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
115		elif isinstance(max_tree_depth, float):
116			assert max_tree_depth <= 1, (
117				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
118			)
119
120			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))
121
122		# choose a random start coord
123		start_coord = _random_start_coord(grid_shape_, start_coord)
124
125		# initialize the maze with no connections
126		connection_list: ConnectionList = np.zeros(
127			(lattice_dim, grid_shape_[0], grid_shape_[1]),
128			dtype=np.bool_,
129		)
130
131		# initialize the stack with the target coord
132		visited_cells: set[tuple[int, int]] = set()
133		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
134		stack: list[Coord] = [start_coord]
135
136		# initialize tree_depth_counter
137		current_tree_depth: int = 1
138
139		# loop until the stack is empty or n_connected_cells is reached
140		while stack and (len(visited_cells) < n_accessible_cells):
141			# get the current coord from the stack
142			current_coord: Coord
143			if randomized_stack:
144				current_coord = stack.pop(random.randint(0, len(stack) - 1))
145			else:
146				current_coord = stack.pop()
147
148			# filter neighbors by being within grid bounds and being unvisited
149			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
150				(neighbor, delta)
151				for neighbor, delta in zip(
152					current_coord + NEIGHBORS_MASK,
153					NEIGHBORS_MASK,
154					strict=False,
155				)
156				if (
157					(tuple(neighbor) not in visited_cells)
158					and (0 <= neighbor[0] < grid_shape_[0])
159					and (0 <= neighbor[1] < grid_shape_[1])
160				)
161			]
162
163			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
164			if unvisited_neighbors_deltas and (
165				current_tree_depth <= max_tree_depth / 2
166			):
167				# if we want a maze without forks, simply don't add the current coord back to the stack
168				if do_forks and (len(unvisited_neighbors_deltas) > 1):
169					stack.append(current_coord)
170
171				# choose one of the unvisited neighbors
172				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)
173
174				# add connection
175				dim: int = int(np.argmax(np.abs(delta)))
176				# if positive, down/right from current coord
177				# if negative, up/left from current coord (down/right from neighbor)
178				clist_node: Coord = (
179					current_coord if (delta.sum() > 0) else chosen_neighbor
180				)
181				connection_list[dim, clist_node[0], clist_node[1]] = True
182
183				# add to visited cells and stack
184				visited_cells.add(tuple(chosen_neighbor))
185				stack.append(chosen_neighbor)
186
187				# Update current tree depth
188				current_tree_depth += 1
189			else:
190				current_tree_depth -= 1
191
192		return LatticeMaze(
193			connection_list=connection_list,
194			generation_meta=dict(
195				func_name="gen_dfs",
196				grid_shape=grid_shape_,
197				start_coord=start_coord,
198				n_accessible_cells=int(n_accessible_cells),
199				max_tree_depth=int(max_tree_depth),
200				# oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
201				# it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
202				# treated as fully connected even when it is most certainly not, causing solving the maze to break
203				fully_connected=bool(len(visited_cells) == n_total_cells),
204				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
205			),
206		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float | None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to 2 * accessible_cells. if a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway.
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate.

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
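
A short usage sketch (assuming `LatticeMazeGenerators` is re-exported from the package top level):

```python
from maze_dataset import LatticeMazeGenerators

# default: fully-connected 5x5 DFS maze
maze = LatticeMazeGenerators.gen_dfs((5, 5))

# constrained: only half of the cells reachable, and no forks (a single hallway)
hallway = LatticeMazeGenerators.gen_dfs(
    (5, 5),
    accessible_cells=0.5,  # a float <= 1 is treated as a proportion of total cells
    do_forks=False,
)
```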
maze_ctor_kwargs: dict
endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]]
grid_shape: tuple[int, int]
195	@property
196	def grid_shape(self) -> CoordTup:
197		"""return the shape of the grid as a tuple"""
198		return (self.grid_n, self.grid_n)

return the shape of the grid as a tuple

grid_shape_np: jaxtyping.Int8[ndarray, 'row_col=2']
200	@property
201	def grid_shape_np(self) -> Coord:
202		"""return the shape of the grid as a numpy array"""
203		return np.array(self.grid_shape)

return the shape of the grid as a numpy array

max_grid_n: int
205	@property
206	def max_grid_n(self) -> int:
207		"""return the maximum of the grid shape"""
208		return max(self.grid_shape)

return the maximum of the grid shape

def stable_hash_cfg(self) -> int:
237	def stable_hash_cfg(self) -> int:
238		"""return a stable hash of the config"""
239		return int.from_bytes(
240			hashlib.md5(  # noqa: S324
241				bytes(self._stable_str_dump(), "ascii")
242			).digest(),
243			"big",
244		)

return a stable hash of the config
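
The hash is an MD5 digest of a sorted-key JSON dump of the base config, so equal configs hash equally across runs; a quick sketch:

```python
from maze_dataset import MazeDatasetConfig

cfg_a = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=100)
cfg_b = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=100)

# identical configs always produce the same (large) integer hash
assert cfg_a.stable_hash_cfg() == cfg_b.stable_hash_cfg()
```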

def to_fname(self) -> str:
246	def to_fname(self) -> str:
247		"""return a unique identifier (valid as a filename) for this config"""
248		n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
249		maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
250		hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
251		return sanitize_fname(
252			f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
253		)

return a unique identifier (valid as a filename) for this config
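
The fname combines the name, grid size, shortened maze count, the generator name with its `gen_` prefix dropped, and the stable hash reduced modulo `10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH`. A sketch (the hash digits shown are hypothetical):

```python
from maze_dataset import MazeDatasetConfig

cfg = MazeDatasetConfig(name="test", grid_n=5, n_mazes=100)
print(cfg.to_fname())
# something like: "test-g5-n100-a_dfs-h12345"
```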

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
summary
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(kw_only=True, methods_no_override=['serialize'])
class MazeDatasetConfig(MazeDatasetConfig_base):
257@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
258class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
259	"""config object which is passed to `MazeDataset.from_config` to generate or load a dataset
260
261	# Parameters:
262	- `name : str`
263		name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
264	- `grid_n : int`
265		grid size of the maze (number of rows/columns)
266	- `n_mazes : int`
267		number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
268		see `EndpointKwargsType` for more details.
269	- `maze_ctor : Callable`
270		maze generator function. This should be a function that takes a grid size and returns a maze.
271		This will usually be one of the functions in `LatticeMazeGenerators`.
272	- `maze_ctor_kwargs : dict`
273		keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
274	- `endpoint_kwargs : EndpointKwargsType`
275		keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
276	- `applied_filters : list[dict]`
277		list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
278		but these are stored with the config in case you want to re-generate the dataset with the same filters.
279
280	"""
281
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
285		return "1.0"
286
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)
294
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}
304
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# do we run this to make sure it doesn't error?
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)
325
326	def _to_ps_array(self) -> _PercolationSuccessArray:
327		"""Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.
328
329		used in predicting the success rate
330		"""
331		try:
332			assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
333				f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
334			)
335			assert "p" in self.maze_ctor_kwargs, (
336				f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
337			)
338			assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
339				f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
340			)
341		except AssertionError as e:
342			err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
343			raise NoPercolationInConfigError(
344				err_msg,
345			) from e
346
347		endpoints_unique_flag: int = int(
348			# we are pretty sure it will be an int or bool here
349			self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
350		)
351
352		# adjustment for bknutson0
353		if not (
354			self.endpoint_kwargs.get("deadend_start", False)
355			and self.endpoint_kwargs.get("deadend_end", False)
356		):
357			# we didn't train on this, but if either endpoint is not required to be in a dead end
358			# then requiring the endpoints to be unique does not really affect the success rate
359			# (except for very small percolation values, pure percolation generation)
360			endpoints_unique_flag = 0
361
362		return np.array(
363			[
364				float(self.maze_ctor_kwargs["p"]),
365				float(self.grid_n),
366				float(
367					int(
368						self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
369						or self.endpoint_kwargs.get("deadend_end", False),
370					),
371				),
372				float(endpoints_unique_flag),
373				float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
374			],
375			dtype=np.float64,
376		)
377
378	@classmethod
379	def _from_ps_array(
380		cls,
381		arr: _PercolationSuccessArray,
382		name: str = "predict",
383		n_mazes: int = 100,
384		**kwargs,
385	) -> "MazeDatasetConfig":
386		"""Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.
387
388		# Returns:
389		- `MazeDatasetConfig`
390			Config corresponding to `arr`
391		"""
392		return cls(
393			name=name,
394			grid_n=int(arr[1]),
395			n_mazes=n_mazes,
396			maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
397			maze_ctor_kwargs={"p": float(arr[0])},
398			endpoint_kwargs=dict(
399				deadend_start=bool(arr[2]),
400				deadend_end=bool(arr[2]),
401				endpoints_not_equal=bool(arr[3]),
402				except_on_no_valid_endpoint=False,
403			),
404			**kwargs,
405		)
406
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail
424			(i.e. has no percolation); if `False` (the default), return `1.0` for such configs instead
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0
440
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			passed through to `MazeDatasetConfig.success_fraction_estimate()`:
463			if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail.
464			if `False` and your config isn't expected to fail, you may generate more mazes than needed,
465			since `safety_margin` is still applied to the estimated success fraction of `1.0`.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

config object which is passed to MazeDataset.from_config to generate or load a dataset

Parameters:

  • name : str name of the dataset -- this can be anything, but should be filesystem safe since we use it in the fname
  • grid_n : int grid size of the maze (number of rows/columns)
  • n_mazes : int number of mazes to request. For some combinations of endpoint_kwargs and maze_ctor, not all mazes might successfully generate. see EndpointKwargsType for more details.
  • maze_ctor : Callable maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in LatticeMazeGenerators.
  • maze_ctor_kwargs : dict keyword arguments to pass to the maze generator function. Specific to the maze_ctor you are using.
  • endpoint_kwargs : EndpointKwargsType keyword arguments passed to LatticeMaze.generate_random_path(). see EndpointKwargsType for more info.
  • applied_filters : list[dict] list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
MazeDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
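
A minimal construction sketch (not part of the generated docs; the imports assume the package's usual public re-exports):

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

# a 5x5 grid of 100 mazes, generated via randomized depth-first search
cfg = MazeDatasetConfig(
	name="demo",  # filesystem-safe, since it becomes part of the fname
	grid_n=5,
	n_mazes=100,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)
dataset = MazeDataset.from_config(cfg)
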
config_version: str
282	@property
283	def config_version(self) -> str:
284		"""return the version of the config. added in maze_dataset v1.3.0; previous versions had no config version"""
285		return "1.0"

return the version of the config. added in maze_dataset v1.3.0; previous versions had no config version

versions: dict
287	@property
288	def versions(self) -> dict:
289		"""return the versions of the config and the maze_dataset"""
290		return dict(
291			config=self.config_version,
292			maze_dataset=importlib.metadata.version("maze_dataset"),
293		)

return the versions of the config and the maze_dataset
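
For illustration, on this release the returned mapping would look roughly like the following (a sketch; the maze_dataset entry tracks the installed package version):

cfg.versions
# {'config': '1.0', 'maze_dataset': '1.3.2'}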

def serialize(self) -> dict:
295	def serialize(self) -> dict:
296		"serialize the MazeDatasetConfig with all fields and fname"
297		return {
298			**self._serialize_base(
299				applied_filters__skip__collect_generation_meta=False
300			),
301			"fname": self.to_fname(),
302			"versions": self.versions,
303		}

serialize the MazeDatasetConfig with all fields and fname
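
A sketch of the extra keys this adds on top of the base serialization, reusing the cfg from the construction sketch above:

data = cfg.serialize()
data["versions"]  # {'config': '1.0', 'maze_dataset': '1.3.2'}
data["fname"]     # derived filename: encodes name, grid size, maze count, generator, and a short hash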

def summary(self) -> dict:
305	def summary(self) -> dict:
306		"""return a summary of the config"""
307		# call the parent summary to make sure it doesn't error
308		super_summary: dict = super().summary()
309		assert super_summary
310		self_ser: dict = self.serialize()
311		return dict(
312			name=self.name,
313			fname=self.to_fname(),
314			sdc_hash=self.stable_hash_cfg(),
315			seed=self.seed,
316			seq_len_min=self.seq_len_min,
317			seq_len_max=self.seq_len_max,
318			applied_filters=self.applied_filters,
319			grid_n=self_ser["grid_n"],
320			n_mazes=self_ser["n_mazes"],
321			maze_ctor_name=self_ser["maze_ctor"]["__name__"],
322			maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
323			endpoint_kwargs=self_ser["endpoint_kwargs"],
324		)

return a summary of the config
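
A brief usage sketch, again reusing cfg from above:

info = cfg.summary()
info["maze_ctor_name"]           # 'gen_dfs'
info["grid_n"], info["n_mazes"]  # (5, 100)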

def success_fraction_estimate(self, except_if_all_success_expected: bool = False) -> float:
407	def success_fraction_estimate(
408		self,
409		except_if_all_success_expected: bool = False,
410	) -> float:
411		"""Estimate the success fraction of this config.
412
413		only valid when the generator is a percolation generator,
414		and endpoints are enforced to be dead ends
415
416		more information on where this comes from can be found in
417		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
418		- `estimate_dataset_fractions.ipynb`
419		- `maze_dataset.benchmarks.sweep_fit`
420
421		# Parameters:
422		- `except_if_all_success_expected : bool`
423			if `True`, raise `NoPercolationInConfigError` when the config has no percolation
424			(i.e. every maze is expected to succeed); if `False`, return `1.0` in that case
425
426		# Returns:
427		- `float`
428			estimated success fraction
429
430		# Raises:
431		- `NoPercolationInConfigError` : if the config has no percolation (all mazes expected to succeed) and `except_if_all_success_expected` is `True`
432		"""
433		try:
434			return cfg_success_predict_fn(self)
435
436		except NoPercolationInConfigError as e:
437			if except_if_all_success_expected:
438				raise e  # noqa: TRY201
439			return 1.0

Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • except_if_all_success_expected : bool if True, raise NoPercolationInConfigError when the config has no percolation (i.e. every maze is expected to succeed); if False, return 1.0 in that case

Returns:

  • float estimated success fraction

Raises:

  • NoPercolationInConfigError : if the config has no percolation (all mazes expected to succeed) and except_if_all_success_expected is True

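A hedged sketch of both branches; gen_dfs_percolation and its p kwarg are assumed from LatticeMazeGenerators:

perc_cfg = MazeDatasetConfig(
	name="perc-demo",
	grid_n=8,
	n_mazes=1000,
	maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
	maze_ctor_kwargs={"p": 0.4},
	endpoint_kwargs={"deadend_start": True, "deadend_end": True, "except_on_no_valid_endpoint": False},
)
frac = perc_cfg.success_fraction_estimate()  # some value in (0, 1]

# a plain gen_dfs config has no percolation:
cfg.success_fraction_estimate()  # returns 1.0 by default
# cfg.success_fraction_estimate(except_if_all_success_expected=True)  # raises NoPercolationInConfigError
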
def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 0.01) -> MazeDatasetConfig:
441	def success_fraction_compensate(
442		self,
443		safety_margin: float = 1.2,
444		except_if_all_success_expected: bool = False,
445		epsilon: float = 1e-2,
446	) -> "MazeDatasetConfig":
447		"""return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
448
449		calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
450		computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
451
452		more information on where this comes from can be found in
453		- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
454		- `estimate_dataset_fractions.ipynb`
455		- `maze_dataset.benchmarks.sweep_fit`
456
457		# Parameters:
458		- `safety_margin : float`
459			safety margin to apply to the success fraction estimate
460			(defaults to `1.2`, or 20% more mazes than estimated)
461		- `except_if_all_success_expected : bool`
462			if `True`, raise `NoPercolationInConfigError` when the config has no percolation.
463			this is passed to `MazeDatasetConfig.success_fraction_estimate`.
464			if left `False` and your config isn't expected to fail, you may generate more mazes than needed,
465			since `safety_margin` is still applied to the returned estimate of `1.0`.
466			(defaults to `False`)
467		- `epsilon : float`
468			raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
469			(defaults to `1e-2`)
470
471		# Returns:
472		- `MazeDatasetConfig`
473			new config with adjusted `n_mazes`
474
475		# Raises:
476		- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
477		"""
478		# compute and check the success fraction
479		success_fraction: float = self.success_fraction_estimate(
480			except_if_all_success_expected=except_if_all_success_expected,
481		)
482		if success_fraction < epsilon:
483			err_msg: str = (
484				f"{success_fraction = } is below the threshold of {epsilon = }"
485			)
486			raise SuccessChanceTooSmallError(
487				err_msg,
488			)
489
490		# compute the new number of mazes
491		n_mazes: int = self.n_mazes
492		new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1
493
494		# put it in a new config and return
495		cfg_dict: dict = self.serialize()
496		cfg_dict["n_mazes"] = new_n_mazes
497		return MazeDatasetConfig.load(cfg_dict)

return a new MazeDatasetConfig like this one with n_mazes adjusted to compensate for the success fraction

calls MazeDatasetConfig.success_fraction_estimate() to get the success fraction, and then computes the new number of mazes as n_mazes = n_mazes * safety_margin / success_fraction + 1

more information on where this comes from can be found in

  • cfg_success_predict_fn() from maze_dataset.dataset.success_predict_math
  • estimate_dataset_fractions.ipynb
  • maze_dataset.benchmarks.sweep_fit

Parameters:

  • safety_margin : float safety margin to apply to the success fraction estimate (defaults to 1.2, or 20% more mazes than estimated)
  • except_if_all_success_expected : bool if True, raise NoPercolationInConfigError when the config has no percolation. this is passed to MazeDatasetConfig.success_fraction_estimate. if left False and your config isn't expected to fail, you may generate more mazes than needed, since safety_margin is still applied to the returned estimate of 1.0. (defaults to False)
  • epsilon : float raise SuccessChanceTooSmallError if the success fraction is below this threshold (defaults to 1e-2)

Returns:

  • MazeDatasetConfig new config with adjusted n_mazes

Raises:

  • SuccessChanceTooSmallError : if the computed success fraction is below epsilon

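To make the formula concrete (a sketch; the 0.8 estimate is hypothetical):

compensated = perc_cfg.success_fraction_compensate()  # safety_margin defaults to 1.2
# if success_fraction_estimate() returned 0.8 for perc_cfg with n_mazes=1000:
#   new_n_mazes = int(1000 * 1.2 / 0.8) + 1 == 1501
assert compensated.n_mazes > perc_cfg.n_mazes
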
@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
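
In practice this enables a serialize/load round trip on the config; a sketch, reusing cfg from above:

data = cfg.serialize()
cfg2 = MazeDatasetConfig.load(data)
assert isinstance(cfg2, MazeDatasetConfig)
assert cfg2.to_fname() == cfg.to_fname()  # the derived fname survives the round trip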

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
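
A one-line sketch of the check, reusing cfg from above:

assert cfg.validate_fields_types()  # True when every field matches its declared type hint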

Inherited Members
MazeDatasetConfig_base
grid_n
n_mazes
maze_ctor
maze_ctor_kwargs
endpoint_kwargs
grid_shape
grid_shape_np
max_grid_n
stable_hash_cfg
to_fname
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict