Coverage for maze_dataset/dataset/maze_dataset_config.py: 24%
118 statements
coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"implements `MazeDatasetConfig` which is used to generate or load a dataset"
3import hashlib
4import importlib.metadata
5import json
6import typing
7import warnings
8from typing import Callable
10import numpy as np
11from jaxtyping import Float
12from muutils.json_serialize import (
13 serializable_dataclass,
14 serializable_field,
15)
16from muutils.json_serialize.util import (
17 safe_getsource,
18 string_as_lines,
19)
20from muutils.misc import sanitize_fname, shorten_numerical_to_str
22from maze_dataset.constants import Coord, CoordTup
23from maze_dataset.dataset.dataset import (
24 GPTDatasetConfig,
25)
26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn
27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP

SERIALIZE_MINIMAL_THRESHOLD: int | None = 100
"""If `n_mazes >= SERIALIZE_MINIMAL_THRESHOLD`, then the `MazeDataset` will use `serialize_minimal`.
Setting this to `None` means that `serialize_minimal` will never be used.
Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only."""

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
"length, in characters, of the hash in the fname of a `MazeDatasetConfig`"

_PercolationSuccessArray = Float[
    np.ndarray,
    "p/grid_n/deadends/endpoints_not_equal/generator_func=5",
]


class NoPercolationInConfigError(ValueError):
    """raised when trying to predict the success fraction of a config that doesn't have percolation"""

    pass


class SuccessChanceTooSmallError(ValueError):
    """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""

    pass


def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global `SERIALIZE_MINIMAL_THRESHOLD`"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
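
# a minimal sketch of usage (hypothetical, not part of this module):
#     set_serialize_minimal_threshold(None)  # never use `serialize_minimal`
#     set_serialize_minimal_threshold(0)     # use `serialize_minimal` for every dataset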


def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
    "get the maze constructor from `GENERATORS_MAP`"
    if isinstance(maze_ctor_serialized, dict):
        # this is both the new and old version of the serialization
        return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
    elif isinstance(maze_ctor_serialized, str):
        # this is a version I switched to for a while, but now we are switching back
        warnings.warn(
            "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: "
            "https://github.com/understanding-search/maze-dataset/issues/new",
        )
        return GENERATORS_MAP[maze_ctor_serialized]
    else:
        err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }"
        raise TypeError(err_msg)
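
# a hypothetical example of both accepted forms, assuming "gen_dfs" is a key in `GENERATORS_MAP`:
#     _load_maze_ctor({"__name__": "gen_dfs"})  # current dict form
#     _load_maze_ctor("gen_dfs")  # legacy string form -- warns, but still resolves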


EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
"""type hint for `MazeDatasetConfig.endpoint_kwargs`

- `except_on_no_valid_endpoint : bool` (default: `True`)
    some conditions (e.g. requiring dead ends when a maze is very open, or no path existing between a given start and end) can cause maze generation to fail.
    if `except_on_no_valid_endpoint` is `True`, then maze generation will raise an error when it fails to generate a valid maze.
    if `False`, then the generated dataset may contain fewer mazes than requested.
    If you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()`,
    which uses a pysr-created function to roughly estimate the success fraction of the dataset.
- `allowed_start : list[tuple[int, int]]` (default: `None`)
    list of allowed starting position coordinates
- `allowed_end : list[tuple[int, int]]` (default: `None`)
    list of allowed ending position coordinates
- `deadend_start : bool` (default: `False`)
    if `True`, the starting position must be a dead end
- `deadend_end : bool` (default: `False`)
    if `True`, the ending position must be a dead end
- `endpoints_not_equal : bool` (default: `True`)
    if `True`, the starting and ending positions must be different
"""


def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
    if data.get("endpoint_kwargs") is None:
        return dict()
    else:
        return {
            k: (
                # bools and Nones are fine
                v
                if (isinstance(v, bool) or v is None)
                # assume it's a CoordList
                else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
            )
            for k, v in data["endpoint_kwargs"].items()
        }
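
# a hypothetical round-trip, showing how coordinate lists are converted back to tuples
# (muutils/zanj saves tuples as lists):
#     _load_endpoint_kwargs({"endpoint_kwargs": {"allowed_start": [[0, 0], [1, 0]], "deadend_end": True}})
#     # -> {"allowed_start": [(0, 0), (1, 0)], "deadend_end": True}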


# not private because we need this to show up in docs
@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
    """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""

    # NOTE: `type: ignore[misc]` is because mypy complains that non-default attributes
    # aren't allowed after ones with defaults, but everything here is kw_only
    grid_n: int = serializable_field()  # type: ignore[misc]

    # n_mazes is not compared, primarily to avoid conflicts which happen during `from_config` when we have applied filters
    n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]

    maze_ctor: Callable = serializable_field(
        default=GENERATORS_MAP["gen_dfs"],
        serialization_fn=lambda gen_func: {
            "__name__": gen_func.__name__,
            "__module__": gen_func.__module__,
            # NOTE: serializing `__doc__` verbatim was causing hashing issues on 3.13 vs
            # older versions, because somehow the `__doc__` attribute treats whitespace
            # differently across versions. so we strip every line and drop empty lines.
            # see:
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
            # https://www.diffchecker.com/tqIMSevy/
            "__doc__": [
                line.strip()
                for line in string_as_lines(gen_func.__doc__)
                if line.strip()
            ],
            "source_code": safe_getsource(gen_func),
        },
        loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    maze_ctor_kwargs: dict = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=lambda data: (
            dict()
            if data.get("maze_ctor_kwargs", None) is None  # this should handle the backwards compatibility
            else data["maze_ctor_kwargs"]
        ),
    )

    endpoint_kwargs: EndpointKwargsType = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=_load_endpoint_kwargs,
        assert_type=False,
    )

    # NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
    # so we need to save a `None` here, or the `fname` field won't be loaded on load.
    # this is a total mess, and very confusing, and entirely my fault
    _fname_loaded: str | None = serializable_field(
        default=None,
        compare=False,
        serialization_fn=lambda _: None,
        loading_fn=lambda data: data.get("fname", None),
    )

    @property
    def grid_shape(self) -> CoordTup:
        """return the shape of the grid as a tuple"""
        return (self.grid_n, self.grid_n)

    @property
    def grid_shape_np(self) -> Coord:
        """return the shape of the grid as a numpy array"""
        return np.array(self.grid_shape)

    @property
    def max_grid_n(self) -> int:
        """return the maximum of the grid shape"""
        return max(self.grid_shape)

    def _serialize_base(
        self,
        applied_filters__skip__collect_generation_meta: bool = True,
    ) -> dict:
        """serialize the base config for use in `stable_hash_cfg()` and `to_fname()`

        - note that `_fname_loaded` will always be `None`, to avoid infinite recursion
        - note that we **do not** by default include information about metadata collection here,
            since otherwise loading a dataset that we minified by collecting the metadata would be impossible.
            for comparing things, however, we do store it when serializing properly, by setting
            `applied_filters__skip__collect_generation_meta=False`
        """
        serialized: dict = MazeDatasetConfig_base.serialize(self)
        if applied_filters__skip__collect_generation_meta:
            serialized["applied_filters"] = [
                x
                for x in serialized["applied_filters"]
                if x.get("name", None) != "collect_generation_meta"
            ]
        return serialized

    def _stable_str_dump(self) -> str:
        return json.dumps(
            self._serialize_base(),
            sort_keys=True,
            indent=None,
        )

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return int.from_bytes(
            hashlib.md5(  # noqa: S324
                bytes(self._stable_str_dump(), "ascii"),
            ).digest(),
            "big",
        )

    def to_fname(self) -> str:
        """return a unique identifier (valid as a filename) for this config"""
        n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
        maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
        hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
        return sanitize_fname(
            f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
        )
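
    # sketch of the resulting fname format (shortened count and hash digits are
    # illustrative): a config with name="demo", grid_n=5, n_mazes=1000, and the
    # default `gen_dfs` generator would yield something like
    #     "demo-g5-n1K-a_dfs-h12345"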


# NOTE: `type: ignore[misc]` is because mypy complains that non-default attributes
# aren't allowed after ones with defaults, but everything here is kw_only
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset

    # Parameters:
    - `name : str`
        name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
    - `grid_n : int`
        grid size of the maze (number of rows/columns)
    - `n_mazes : int`
        number of mazes to request. for some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might generate successfully.
        see `EndpointKwargsType` for more details.
    - `maze_ctor : Callable`
        maze generator function. this should be a function that takes a grid size and returns a maze.
        it will usually be one of the functions in `LatticeMazeGenerators`.
    - `maze_ctor_kwargs : dict`
        keyword arguments to pass to the maze generator function. specific to the `maze_ctor` you are using.
    - `endpoint_kwargs : EndpointKwargsType`
        keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
    - `applied_filters : list[dict]`
        list of filters that have been applied to the dataset. we recommend applying filters to datasets directly,
        but these are stored with the config in case you want to re-generate the dataset with the same filters.
    """

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0; previous versions had no config version"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and of maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False,
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # we run this to make sure it doesn't error
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a `[p, grid_n, deadends, endpoints_not_equal, generator_func]` vector.

        used in predicting the success rate
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didn't train on this case, but if either endpoint is not required to be
            # a dead end, then requiring the endpoints to be unique does not really affect
            # the success rate (except for very small percolation values with pure
            # percolation generation)
            endpoints_unique_flag = 0

        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )
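
    # a hypothetical example: a percolation config with p=0.5, grid_n=10, dead-end
    # endpoints required, unique endpoints, and generator index 0 maps to
    #     array([0.5, 10.0, 1.0, 1.0, 0.0])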

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array `[p, grid_n, deadends, endpoints_not_equal, generator_func]` and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            config corresponding to `arr`
        """
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )
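
    # note: `_to_ps_array` and `_from_ps_array` are only approximate inverses -- the
    # single `deadends` entry of the array sets *both* `deadend_start` and `deadend_end`
    # on reconstruction, so configs where the two differ do not round-trip exactly.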

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail
            (i.e. it has no percolation). if `False`, return `1.0` in that case.

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
        """
        try:
            return cfg_success_predict_fn(self)
        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one, with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = int(n_mazes * safety_margin / success_fraction) + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, i.e. 20% more mazes than the estimate suggests)
        - `except_if_all_success_expected : bool`
            passed to `MazeDatasetConfig.success_fraction_estimate`.
            if your config isn't expected to fail and this is `False`, you might generate
            more mazes than needed, since `safety_margin` is still applied.
            (defaults to `False`)
        - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the estimated success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
        - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
        - `SuccessChanceTooSmallError` : if the estimated success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)
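
    # a hypothetical worked example: with n_mazes=100, safety_margin=1.2, and an
    # estimated success fraction of 0.5, the compensated config requests
    #     int((100 * 1.2) / 0.5) + 1 == 241
    # mazes, so that roughly 100 of them are expected to generate successfully.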