maze_dataset.dataset.maze_dataset_config
implements MazeDatasetConfig
which is used to generate or load a dataset
1"implements `MazeDatasetConfig` which is used to generate or load a dataset" 2 3import hashlib 4import importlib.metadata 5import json 6import typing 7import warnings 8from typing import Callable 9 10import numpy as np 11from jaxtyping import Float 12from muutils.json_serialize import ( 13 serializable_dataclass, 14 serializable_field, 15) 16from muutils.json_serialize.util import ( 17 safe_getsource, 18 string_as_lines, 19) 20from muutils.misc import sanitize_fname, shorten_numerical_to_str 21 22from maze_dataset.constants import Coord, CoordTup 23from maze_dataset.dataset.dataset import ( 24 GPTDatasetConfig, 25) 26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn 27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP 28 29SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 30"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. 31Setting to None means that `serialize_minimal` will never be used. 32Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.""" 33 34MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5 35"length of the has, in characters, of the hash in the fname of a `MazeDatasetConfig`" 36 37_PercolationSuccessArray = Float[ 38 np.ndarray, 39 "p/grid_n/deadends/endpoints_not_equal/generator_func=5", 40] 41 42 43class NoPercolationInConfigError(ValueError): 44 """raised when trying to predict the success fraction of a config that doesn't have percolation""" 45 46 pass 47 48 49class SuccessChanceTooSmallError(ValueError): 50 """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`""" 51 52 pass 53 54 55def set_serialize_minimal_threshold(threshold: int | None) -> None: 56 "get the global SERIALIZE_MINIMAL_THRESHOLD" 57 global SERIALIZE_MINIMAL_THRESHOLD # noqa: PLW0603 58 SERIALIZE_MINIMAL_THRESHOLD = threshold 59 60 61def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable: 62 "get the maze constructor from `GENERATORS_MAP`" 63 if isinstance(maze_ctor_serialized, dict): 64 # this is both the new and old version of the serialization 65 return GENERATORS_MAP[maze_ctor_serialized["__name__"]] 66 elif isinstance(maze_ctor_serialized, str): 67 # this is a version I switched to for a while but now we are switching back 68 warnings.warn( 69 "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: " 70 "https://github.com/understanding-search/maze-dataset/issues/new", 71 ) 72 return GENERATORS_MAP[maze_ctor_serialized] 73 else: 74 err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }" 75 raise TypeError(err_msg) 76 77 78EndpointKwargsType = dict[ 79 typing.Literal[ 80 "allowed_start", 81 "allowed_end", 82 "deadend_start", 83 "deadend_end", 84 "endpoints_not_equal", 85 "except_on_no_valid_endpoint", 86 ], 87 bool | None | list[tuple[int, int]], 88] 89"""type hint for `MazeDatasetConfig.endpoint_kwargs` 90 91- `except_on_no_valid_endpoint : bool` (default: `True`) 92 some of the conditions (dead ends if a maze is very open, no path between given start and end) can cause the maze generation to fail. 93 if `except_on_no_valid_endpoint` is `True`, then the maze generation will raise an error if it fails to generate a valid maze. 94 however, if `False`, then the maze generation will return a dataset with fewer mazes than requested. 
95 If you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()` 96 this uses a pysr-created function to roughly estimate the success fraction of the dataset. 97- `allowed_start : list[tuple[int, int]]` (default: `None`) 98 list of allowed starting position coordinates 99- `allowed_end : list[tuple[int, int]]` (default: `None`) 100 list of allowed ending position coordinates 101- `deadend_start : bool` (default: `False`) 102 if `True`, the starting position must be a dead end 103- `deadend_end : bool` (default: `False`) 104 if `True`, the ending position must be a dead end 105- `endpoints_not_equal : bool` (default: `True`) 106 if `True`, the starting and ending positions must be different 107 108 109 110""" 111 112 113def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType: 114 if data.get("endpoint_kwargs") is None: 115 return dict() 116 117 else: 118 return { 119 k: ( 120 # bools and Nones are fine 121 v 122 if (isinstance(v, bool) or v is None) 123 # assume its a CoordList 124 else [tuple(x) for x in v] # muutils/zanj saves tuples as lists 125 ) 126 for k, v in data["endpoint_kwargs"].items() 127 } 128 129 130# not private because we need this to show up in docs 131@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"]) 132class MazeDatasetConfig_base(GPTDatasetConfig): # noqa: N801 133 """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here""" 134 135 # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only 136 137 grid_n: int = serializable_field() # type: ignore[misc] 138 139 # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters 140 n_mazes: int = serializable_field(compare=False) # type: ignore[misc] 141 142 maze_ctor: Callable = serializable_field( 143 default=GENERATORS_MAP["gen_dfs"], 144 serialization_fn=lambda gen_func: { 145 "__name__": gen_func.__name__, 146 "__module__": gen_func.__module__, 147 # NOTE: this was causing hashing issues on 3.13 vs older versions because somehow, 148 # the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY 149 # so we just uh. strip it all now. 150 # see: 151 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53 152 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53 153 # https://www.diffchecker.com/tqIMSevy/ 154 # update: we also need to filter for empty lines. 
B) 155 "__doc__": [ 156 line.strip() 157 for line in string_as_lines(gen_func.__doc__) 158 if line.strip() 159 ], 160 "source_code": safe_getsource(gen_func), 161 }, 162 loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]), 163 assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures 164 ) 165 166 maze_ctor_kwargs: dict = serializable_field( 167 default_factory=dict, 168 serialization_fn=lambda kwargs: kwargs, 169 loading_fn=lambda data: ( 170 dict() 171 if data.get("maze_ctor_kwargs", None) 172 is None # this should handle the backwards compatibility 173 else data["maze_ctor_kwargs"] 174 ), 175 ) 176 177 endpoint_kwargs: EndpointKwargsType = serializable_field( 178 default_factory=dict, 179 serialization_fn=lambda kwargs: kwargs, 180 loading_fn=_load_endpoint_kwargs, 181 assert_type=False, 182 ) 183 184 # NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*, 185 # and so we need to save an `None` here or this wont load the `fname` field on load 186 # this is a total mess, and very confusing, and entirely my fault 187 _fname_loaded: str | None = serializable_field( 188 default=None, 189 compare=False, 190 serialization_fn=lambda _: None, 191 loading_fn=lambda data: data.get("fname", None), 192 ) 193 194 @property 195 def grid_shape(self) -> CoordTup: 196 """return the shape of the grid as a tuple""" 197 return (self.grid_n, self.grid_n) 198 199 @property 200 def grid_shape_np(self) -> Coord: 201 """return the shape of the grid as a numpy array""" 202 return np.array(self.grid_shape) 203 204 @property 205 def max_grid_n(self) -> int: 206 """return the maximum of the grid shape""" 207 return max(self.grid_shape) 208 209 def _serialize_base( 210 self, applied_filters__skip__collect_generation_meta: bool = True 211 ) -> dict: 212 """serialize the base config for user in `stable_hash_cfg()` and `to_fname()` 213 214 - note that the _fname_loaded will always be `None` to avoid infinite recursion 215 - note that we **do not** by default include information about metadata collection here, 216 since otherwise loading a dataset that we minified by collecting the metadata would be impossible 217 but for comparing things, we do store it when serializing properly by setting 218 `applied_filters__skip__collect_generation_meta=False` 219 """ 220 serialized: dict = MazeDatasetConfig_base.serialize(self) 221 if applied_filters__skip__collect_generation_meta: 222 serialized["applied_filters"] = [ 223 x 224 for x in serialized["applied_filters"] 225 if x.get("name", None) != "collect_generation_meta" 226 ] 227 return serialized 228 229 def _stable_str_dump(self) -> str: 230 return json.dumps( 231 self._serialize_base(), 232 sort_keys=True, 233 indent=None, 234 ) 235 236 def stable_hash_cfg(self) -> int: 237 """return a stable hash of the config""" 238 return int.from_bytes( 239 hashlib.md5( # noqa: S324 240 bytes(self._stable_str_dump(), "ascii") 241 ).digest(), 242 "big", 243 ) 244 245 def to_fname(self) -> str: 246 """return a unique identifier (valid as a filename) for this config""" 247 n_mazes_str: str = shorten_numerical_to_str(self.n_mazes) 248 maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_") 249 hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH 250 return sanitize_fname( 251 f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}", 252 ) 253 254 255# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't 
allowed after ones with defaults, but everything is kw_only 256@serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) 257class MazeDatasetConfig(MazeDatasetConfig_base): # type: ignore[misc] 258 """config object which is passed to `MazeDataset.from_config` to generate or load a dataset 259 260 # Parameters: 261 - `name : str` 262 name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname` 263 - `grid_n : int` 264 grid size of the maze (number of rows/columns) 265 - `n_mazes : int` 266 number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. 267 see `EndpointKwargsType` for more details. 268 - `maze_ctor : Callable` 269 maze generator function. This should be a function that takes a grid size and returns a maze. 270 This will usually be one of the functions in `LatticeMazeGenerators`. 271 - `maze_ctor_kwargs : dict` 272 keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using. 273 - `endpoint_kwargs : EndpointKwargsType` 274 keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info. 275 - `applied_filters : list[dict]` 276 list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, 277 but these are stored with the config in case you want to re-generate the dataset with the same filters. 278 279 """ 280 281 @property 282 def config_version(self) -> str: 283 """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" 284 return "1.0" 285 286 @property 287 def versions(self) -> dict: 288 """return the versions of the config and the maze_dataset""" 289 return dict( 290 config=self.config_version, 291 maze_dataset=importlib.metadata.version("maze_dataset"), 292 ) 293 294 def serialize(self) -> dict: 295 "serialize the MazeDatasetConfig with all fields and fname" 296 return { 297 **self._serialize_base( 298 applied_filters__skip__collect_generation_meta=False 299 ), 300 "fname": self.to_fname(), 301 "versions": self.versions, 302 } 303 304 def summary(self) -> dict: 305 """return a summary of the config""" 306 # do we run this to make sure it doesn't error? 307 super_summary: dict = super().summary() 308 assert super_summary 309 self_ser: dict = self.serialize() 310 return dict( 311 name=self.name, 312 fname=self.to_fname(), 313 sdc_hash=self.stable_hash_cfg(), 314 seed=self.seed, 315 seq_len_min=self.seq_len_min, 316 seq_len_max=self.seq_len_max, 317 applied_filters=self.applied_filters, 318 grid_n=self_ser["grid_n"], 319 n_mazes=self_ser["n_mazes"], 320 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 321 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 322 endpoint_kwargs=self_ser["endpoint_kwargs"], 323 ) 324 325 def _to_ps_array(self) -> _PercolationSuccessArray: 326 """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. 
327 328 used in predicting the success rate 329 """ 330 try: 331 assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( 332 f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" 333 ) 334 assert "p" in self.maze_ctor_kwargs, ( 335 f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" 336 ) 337 assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( 338 f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" 339 ) 340 except AssertionError as e: 341 err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" 342 raise NoPercolationInConfigError( 343 err_msg, 344 ) from e 345 346 endpoints_unique_flag: int = int( 347 # we are pretty sure it will be an int or bool here 348 self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] 349 ) 350 351 # adjustment for bknutson0 352 if not ( 353 self.endpoint_kwargs.get("deadend_start", False) 354 and self.endpoint_kwargs.get("deadend_end", False) 355 ): 356 # we didnt train on this, but if either endpoint is not required to be in a dead end 357 # then requiring the endpoints to be unique does not really affect the success rate 358 # (except for very small percolation values, pure percolation generation) 359 endpoints_unique_flag = 0 360 361 return np.array( 362 [ 363 float(self.maze_ctor_kwargs["p"]), 364 float(self.grid_n), 365 float( 366 int( 367 self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] 368 or self.endpoint_kwargs.get("deadend_end", False), 369 ), 370 ), 371 float(endpoints_unique_flag), 372 float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), 373 ], 374 dtype=np.float64, 375 ) 376 377 @classmethod 378 def _from_ps_array( 379 cls, 380 arr: _PercolationSuccessArray, 381 name: str = "predict", 382 n_mazes: int = 100, 383 **kwargs, 384 ) -> "MazeDatasetConfig": 385 """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters. 386 387 # Returns: 388 - `MazeDatasetConfig` 389 Config corresponding to `arr` 390 """ 391 return cls( 392 name=name, 393 grid_n=int(arr[1]), 394 n_mazes=n_mazes, 395 maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], 396 maze_ctor_kwargs={"p": float(arr[0])}, 397 endpoint_kwargs=dict( 398 deadend_start=bool(arr[2]), 399 deadend_end=bool(arr[2]), 400 endpoints_not_equal=bool(arr[3]), 401 except_on_no_valid_endpoint=False, 402 ), 403 **kwargs, 404 ) 405 406 def success_fraction_estimate( 407 self, 408 except_if_all_success_expected: bool = False, 409 ) -> float: 410 """Estimate the success fraction of this config. 411 412 only valid when the generator is a percolation generator, 413 and endpoints are enforced to be dead ends 414 415 more information on where this comes from can be found in 416 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 417 - `estimate_dataset_fractions.ipynb` 418 - `maze_dataset.benchmarks.sweep_fit` 419 420 # Parameters: 421 - `except_if_all_success_expected : bool` 422 if `True`, don't raise an error if the success fraction is below the threshold. 
423 will always return `1.0` if the config is not expected to fail 424 425 # Returns: 426 - `float` 427 estimated success fraction 428 429 # Raises: 430 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 431 """ 432 try: 433 return cfg_success_predict_fn(self) 434 435 except NoPercolationInConfigError as e: 436 if except_if_all_success_expected: 437 raise e # noqa: TRY201 438 return 1.0 439 440 def success_fraction_compensate( 441 self, 442 safety_margin: float = 1.2, 443 except_if_all_success_expected: bool = False, 444 epsilon: float = 1e-2, 445 ) -> "MazeDatasetConfig": 446 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 447 448 calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then 449 computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1` 450 451 more information on where this comes from can be found in 452 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 453 - `estimate_dataset_fractions.ipynb` 454 - `maze_dataset.benchmarks.sweep_fit` 455 456 # Parameters: 457 - `safety_margin : float` 458 safety margin to apply to the success fraction estimate 459 (defaults to `1.2`, or 20% more mazes than estimated) 460 - `except_if_all_success_expected : bool` 461 if `True`, don't raise an error if the success fraction is below the threshold. 462 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 463 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 464 since `safety_margin` is still applied. 465 (defaults to `False`) 466 - `epsilon : float` 467 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 468 (defaults to `1e-2`) 469 470 # Returns: 471 - `MazeDatasetConfig` 472 new config with adjusted `n_mazes` 473 474 # Raises: 475 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 476 """ 477 # compute and check the success fraction 478 success_fraction: float = self.success_fraction_estimate( 479 except_if_all_success_expected=except_if_all_success_expected, 480 ) 481 if success_fraction < epsilon: 482 err_msg: str = ( 483 f"{success_fraction = } is below the threshold of {epsilon = }" 484 ) 485 raise SuccessChanceTooSmallError( 486 err_msg, 487 ) 488 489 # compute the new number of mazes 490 n_mazes: int = self.n_mazes 491 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 492 493 # put it in a new config and return 494 cfg_dict: dict = self.serialize() 495 cfg_dict["n_mazes"] = new_n_mazes 496 return MazeDatasetConfig.load(cfg_dict)
If `n_mazes >= SERIALIZE_MINIMAL_THRESHOLD`, then the `MazeDataset` will use `serialize_minimal`. Setting to `None` means that `serialize_minimal` will never be used. Set to `-1` to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.
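A minimal sketch of the gating logic this implies; the helper name here is hypothetical, and the real check presumably lives inside `MazeDataset`'s serialization path:

```python
# hypothetical helper, for illustration only -- not part of the library
def _should_serialize_minimal(n_mazes: int, threshold: int | None) -> bool:
    # minimal serialization applies only when a threshold is set
    # and the dataset is at least that large
    return threshold is not None and n_mazes >= threshold

assert _should_serialize_minimal(150, threshold=100)
assert not _should_serialize_minimal(50, threshold=100)
assert not _should_serialize_minimal(150, threshold=None)
```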
length, in characters, of the hash in the fname of a `MazeDatasetConfig`
```python
class NoPercolationInConfigError(ValueError):
    """raised when trying to predict the success fraction of a config that doesn't have percolation"""

    pass
```
raised when trying to predict the success fraction of a config that doesn't have percolation
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
class SuccessChanceTooSmallError(ValueError):
    """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""

    pass
```
raised when the success fraction is below the threshold in MazeDatasetConfig.success_fraction_compensate
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
```python
def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global SERIALIZE_MINIMAL_THRESHOLD"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
```
set the global `SERIALIZE_MINIMAL_THRESHOLD`
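Usage is a simple module-level toggle; a sketch, assuming the module path at the top of this page:

```python
from maze_dataset.dataset.maze_dataset_config import set_serialize_minimal_threshold

set_serialize_minimal_threshold(None)  # never use serialize_minimal
set_serialize_minimal_threshold(1000)  # use it only for datasets with >= 1000 mazes
set_serialize_minimal_threshold(-1)    # profiling only: make `read` use `MazeDataset._load_legacy`
```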
type hint for MazeDatasetConfig.endpoint_kwargs
- `except_on_no_valid_endpoint : bool` (default: `True`)
    some of the conditions (dead ends if a maze is very open, no path between given start and end) can cause the maze generation to fail. if `except_on_no_valid_endpoint` is `True`, then the maze generation will raise an error if it fails to generate a valid maze. however, if `False`, then the maze generation will return a dataset with fewer mazes than requested. If you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()`, which uses a pysr-created function to roughly estimate the success fraction of the dataset.
- `allowed_start : list[tuple[int, int]]` (default: `None`)
    list of allowed starting position coordinates
- `allowed_end : list[tuple[int, int]]` (default: `None`)
    list of allowed ending position coordinates
- `deadend_start : bool` (default: `False`)
    if `True`, the starting position must be a dead end
- `deadend_end : bool` (default: `False`)
    if `True`, the ending position must be a dead end
- `endpoints_not_equal : bool` (default: `True`)
    if `True`, the starting and ending positions must be different
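Since `EndpointKwargsType` is a plain `dict` type alias, a value is just a literal dict; a sketch combining several of the keys above:

```python
# require dead-end endpoints, and drop failed mazes instead of raising
endpoint_kwargs: dict = {
    "deadend_start": True,
    "deadend_end": True,
    "endpoints_not_equal": True,
    "except_on_no_valid_endpoint": False,
}

# coordinate restrictions use lists of (row, col) tuples:
endpoint_kwargs_restricted: dict = {
    "allowed_start": [(0, 0)],
    "allowed_end": [(4, 4)],
}
```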
```python
@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class MazeDatasetConfig_base(GPTDatasetConfig):
    """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""
```
base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here
```python
@staticmethod
def gen_dfs(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: float | None = None,
    max_tree_depth: float | None = None,
    do_forks: bool = True,
    randomized_stack: bool = False,
    start_coord: Coord | None = None,
) -> LatticeMaze: ...
```
generate a lattice maze using depth first search, iterative
Arguments

- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells** (default: `None`)
- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape** (default: `None`)
- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
algorithm

1. Choose the initial cell, mark it as visited and push it to the stack
2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
        1. Push the current cell to the stack
        2. Choose one of the unvisited neighbours
        3. Remove the wall between the current cell and the chosen cell
        4. Mark the chosen cell as visited and push it to the stack
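A usage sketch of this default generator, assuming `LatticeMazeGenerators` is exported from `maze_dataset.generation` as the parameter docs above suggest:

```python
import numpy as np
from maze_dataset.generation import LatticeMazeGenerators

# a 5x5 maze with no forks: a single winding hallway from a fixed corner
maze = LatticeMazeGenerators.gen_dfs(
    grid_shape=np.array([5, 5]),
    do_forks=False,
    start_coord=np.array([0, 0]),
)
print(maze.generation_meta["fully_connected"])  # recorded by the generator
```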
```python
@property
def grid_shape(self) -> CoordTup:
    """return the shape of the grid as a tuple"""
    return (self.grid_n, self.grid_n)
```
return the shape of the grid as a tuple
```python
@property
def grid_shape_np(self) -> Coord:
    """return the shape of the grid as a numpy array"""
    return np.array(self.grid_shape)
```
return the shape of the grid as a numpy array
```python
@property
def max_grid_n(self) -> int:
    """return the maximum of the grid shape"""
    return max(self.grid_shape)
```
return the maximum of the grid shape
```python
def stable_hash_cfg(self) -> int:
    """return a stable hash of the config"""
    return int.from_bytes(
        hashlib.md5(  # noqa: S324
            bytes(self._stable_str_dump(), "ascii")
        ).digest(),
        "big",
    )
```
return a stable hash of the config
```python
def to_fname(self) -> str:
    """return a unique identifier (valid as a filename) for this config"""
    n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
    maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
    hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
    return sanitize_fname(
        f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
    )
```
return a unique identifier (valid as a filename) for this config
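A sketch of the resulting identifier, given the format string in the source above; the exact hash digits and the numerical abbreviation depend on `stable_hash_cfg()` and `shorten_numerical_to_str`:

```python
from maze_dataset import MazeDatasetConfig

cfg = MazeDatasetConfig(name="test", grid_n=5, n_mazes=10_000)
print(cfg.to_fname())
# something like: test-g5-n10K-a_dfs-h12345
# "a_dfs" is the maze_ctor name with the "gen_" prefix removed,
# "h12345" is stable_hash_cfg() modulo 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
```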
```python
def serialize(self) -> dict[str, Any]: ...
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
@classmethod
def load(cls, data: dict[str, Any] | T) -> Type[T]: ...
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
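These two inherited methods round-trip; a sketch, with arbitrary config values:

```python
from maze_dataset import MazeDatasetConfig

cfg = MazeDatasetConfig(name="demo", grid_n=4, n_mazes=50)
data = cfg.serialize()  # a plain dict, safe to dump as JSON
cfg_loaded = MazeDatasetConfig.load(data)
assert cfg_loaded == cfg  # note: n_mazes is excluded from comparison (compare=False)
```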
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- maze_dataset.dataset.dataset.GPTDatasetConfig
- name
- seq_len_min
- seq_len_max
- seed
- applied_filters
- summary
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset"""
```
config object which is passed to `MazeDataset.from_config` to generate or load a dataset
Parameters:
- `name : str`
    name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
- `grid_n : int`
    grid size of the maze (number of rows/columns)
- `n_mazes : int`
    number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. see `EndpointKwargsType` for more details.
- `maze_ctor : Callable`
    maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in `LatticeMazeGenerators`.
- `maze_ctor_kwargs : dict`
    keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
- `endpoint_kwargs : EndpointKwargsType`
    keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
- `applied_filters : list[dict]`
    list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
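Putting the parameters together, a typical construction looks like the sketch below; `MazeDataset.from_config` is documented in `maze_dataset.dataset.dataset`:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg = MazeDatasetConfig(
    name="demo",  # used in the fname, keep it filesystem safe
    grid_n=5,
    n_mazes=100,
    maze_ctor=LatticeMazeGenerators.gen_dfs,  # the default generator
    endpoint_kwargs=dict(
        deadend_start=True,
        except_on_no_valid_endpoint=False,  # drop failures instead of raising
    ),
)
dataset = MazeDataset.from_config(cfg)
```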
```python
@property
def config_version(self) -> str:
    """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
    return "1.0"
```
return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config
```python
@property
def versions(self) -> dict:
    """return the versions of the config and the maze_dataset"""
    return dict(
        config=self.config_version,
        maze_dataset=importlib.metadata.version("maze_dataset"),
    )
```
return the versions of the config and the maze_dataset
```python
def serialize(self) -> dict:
    "serialize the MazeDatasetConfig with all fields and fname"
    return {
        **self._serialize_base(
            applied_filters__skip__collect_generation_meta=False
        ),
        "fname": self.to_fname(),
        "versions": self.versions,
    }
```
serialize the MazeDatasetConfig with all fields and fname
```python
def summary(self) -> dict:
    """return a summary of the config"""
    # we run this to make sure it doesn't error
    super_summary: dict = super().summary()
    assert super_summary
    self_ser: dict = self.serialize()
    return dict(
        name=self.name,
        fname=self.to_fname(),
        sdc_hash=self.stable_hash_cfg(),
        seed=self.seed,
        seq_len_min=self.seq_len_min,
        seq_len_max=self.seq_len_max,
        applied_filters=self.applied_filters,
        grid_n=self_ser["grid_n"],
        n_mazes=self_ser["n_mazes"],
        maze_ctor_name=self_ser["maze_ctor"]["__name__"],
        maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
        endpoint_kwargs=self_ser["endpoint_kwargs"],
    )
```
return a summary of the config
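For example, picking a few of the keys listed in the source above (a sketch; values depend on the config):

```python
from maze_dataset import MazeDatasetConfig

s = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=100).summary()
print(s["maze_ctor_name"])  # "gen_dfs" for the default generator
print(s["fname"], s["sdc_hash"])
```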
```python
def success_fraction_estimate(
    self,
    except_if_all_success_expected: bool = False,
) -> float: ...
```
Estimate the success fraction of this config.
only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends
more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`
Parameters:
- `except_if_all_success_expected : bool`
    if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail (i.e. has no percolation). if `False`, return `1.0` for such configs.
Returns:
- `float`
    estimated success fraction
Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
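A sketch of estimating the success fraction for a percolation config, assuming `gen_percolation` is among the generators in `_GENERATORS_PERCOLATED`; the exact estimate varies:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg = MazeDatasetConfig(
    name="perc-demo",
    grid_n=8,
    n_mazes=1000,
    maze_ctor=LatticeMazeGenerators.gen_percolation,
    maze_ctor_kwargs=dict(p=0.4),
    endpoint_kwargs=dict(
        deadend_start=True,
        deadend_end=True,
        except_on_no_valid_endpoint=False,  # required, see `_to_ps_array`
    ),
)
frac: float = cfg.success_fraction_estimate()
print(f"estimated success fraction: {frac:.3f}")
```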
```python
def success_fraction_compensate(
    self,
    safety_margin: float = 1.2,
    except_if_all_success_expected: bool = False,
    epsilon: float = 1e-2,
) -> "MazeDatasetConfig": ...
```
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction
calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`
more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`
Parameters:
- `safety_margin : float`
    safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
    if `True`, raise an error if the config is not expected to fail. this is passed to `MazeDatasetConfig.success_fraction_estimate`. if your config isn't expected to fail and this is `False`, you may generate more mazes than needed, since `safety_margin` is still applied. (defaults to `False`)
- `epsilon : float`
    raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`)
Returns:
- `MazeDatasetConfig`
    new config with adjusted `n_mazes`
Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
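Continuing the estimation sketch above, compensation is then a one-liner; with a hypothetical estimate of 0.8, the arithmetic works out as in the comment:

```python
cfg_compensated = cfg.success_fraction_compensate(safety_margin=1.2)
# if success_fraction_estimate() returned 0.8:
#   new_n_mazes = int(1000 * 1.2 / 0.8) + 1 == 1501
print(cfg_compensated.n_mazes)
```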
```python
@classmethod
def load(cls, data: dict[str, Any] | T) -> Type[T]: ...
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool: ...
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- MazeDatasetConfig_base
- grid_n
- n_mazes
- maze_ctor
- maze_ctor_kwargs
- endpoint_kwargs
- grid_shape
- grid_shape_np
- max_grid_n
- stable_hash_cfg
- to_fname
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict