maze_dataset.dataset

`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`
When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class.
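As a rough sketch of that workflow: the `gen_dfs` generator name, the `maze_dataset.generation` import path, and the forwarding of `gen_parallel` through `from_config` are assumptions here, not taken from this page.

```python
from pathlib import Path

from maze_dataset.dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators  # assumed import path

cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=100,
    maze_ctor=LatticeMazeGenerators.gen_dfs,  # assumed generator; any LatticeMazeGenerators static method works
)

dataset: MazeDataset = MazeDataset.from_config(
    cfg,
    do_generate=True,   # 1) allow generation at runtime if nothing can be loaded
    load_local=True,    #    ...but prefer an existing local copy
    save_local=True,    # 3) store the generated dataset for later runs
    local_base_path=Path("data/maze_dataset"),
    gen_parallel=True,  # 2) assumed to be forwarded to MazeDataset.generate to parallelize generation
)
```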
Furthermore, a dataset of mazes can be filtered to satisfy certain properties:
dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)
Custom filters can be specified, and several filters are included (a chaining sketch follows the list below):

- `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`.
- `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls.
- `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance.
- `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset.
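Assuming each built-in filter, like `custom_maze_filter`, returns a new `MazeDataset`, filters can be chained:

```python
dataset_filtered: MazeDataset = (
    dataset
    .filter_by.path_length(min_length=3)
    .filter_by.start_end_distance(min_distance=2)
    .filter_by.remove_duplicates_fast()
)
```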
All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, even when generating large datasets.
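A sketch of the intended reproducibility, assuming generation is fully determined by the config and its seed (and again assuming the `gen_dfs` generator):

```python
cfg_a = MazeDatasetConfig(name="repro", grid_n=5, n_mazes=10, seed=42,
                          maze_ctor=LatticeMazeGenerators.gen_dfs)
cfg_b = MazeDatasetConfig(name="repro", grid_n=5, n_mazes=10, seed=42,
                          maze_ctor=LatticeMazeGenerators.gen_dfs)
# with identical configs and seeds, the two datasets are expected to compare equal
assert MazeDataset.generate(cfg_a) == MazeDataset.generate(cfg_b)
```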
1"""`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`" 2 3When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class. 4 5Furthermore, a dataset of mazes can be filtered to satisfy certain properties: 6 7```python 8dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3) 9``` 10 11Custom filters can be specified, and several filters are included: 12 13- `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`. 14- `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls. 15- `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance. 16- `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset. 17 18All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, 19even when generating large datasets. 20 21""" 22 23from maze_dataset.dataset.collected_dataset import ( 24 MazeDatasetCollection, 25 MazeDatasetCollectionConfig, 26) 27from maze_dataset.dataset.maze_dataset import MazeDataset 28from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig 29 30__all__ = [ 31 # submodules 32 "collected_dataset", 33 "configs", 34 "dataset", 35 "filters", 36 "maze_dataset_config", 37 "maze_dataset", 38 "rasterized", 39 "success_predict_math", 40 # dataset classes 41 "MazeDataset", 42 "MazeDatasetConfig", 43 "MazeDatasetCollection", 44 "MazeDatasetCollectionConfig", 45]
114class MazeDataset(GPTDataset[MazeDatasetConfig]): # noqa: PLW1641 115 """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`""" 116 117 def __init__( 118 self, 119 cfg: MazeDatasetConfig, 120 mazes: typing.Sequence[SolvedMaze], 121 generation_metadata_collected: dict | None = None, 122 ) -> None: 123 """initialize a maze dataset from a config and a list of solved mazes""" 124 super().__init__() 125 self.cfg: MazeDatasetConfig = cfg 126 self.mazes: list[SolvedMaze] = list(mazes) 127 self.generation_metadata_collected: dict | None = generation_metadata_collected 128 129 # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override] 130 @classmethod 131 def from_config( # type: ignore[override] 132 cls, 133 # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override] 134 cfg: MazeDatasetConfig, # type: ignore[override] 135 do_generate: bool = True, 136 load_local: bool = True, 137 save_local: bool = True, 138 zanj: ZANJ | None = None, 139 do_download: bool = True, 140 local_base_path: Path = Path("data/maze_dataset"), 141 except_on_config_mismatch: bool = True, 142 allow_generation_metadata_filter_mismatch: bool = True, 143 verbose: bool = False, 144 **kwargs, 145 ) -> "MazeDataset": 146 """create a maze dataset from a config 147 148 priority of loading: 149 1. load from local 150 2. download 151 3. generate 152 153 """ 154 return cast( 155 "MazeDataset", 156 super().from_config( 157 cfg=cfg, 158 do_generate=do_generate, 159 load_local=load_local, 160 save_local=save_local, 161 zanj=zanj, 162 do_download=do_download, 163 local_base_path=local_base_path, 164 except_on_config_mismatch=except_on_config_mismatch, 165 allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, 166 verbose=verbose, 167 **kwargs, 168 ), 169 ) 170 171 def data_hash(self) -> int: 172 """return a hash of the data""" 173 return stable_hash(str(tuple([x.serialize() for x in self.mazes]))) 174 175 def __getitem__(self, i: int) -> SolvedMaze: 176 """get a maze by index""" 177 return self.mazes[i] 178 179 def __iter__(self) -> typing.Iterator[SolvedMaze]: 180 """iterate over the mazes""" 181 return iter(self.mazes) 182 183 def __deepcopy__(self, memo) -> "MazeDataset": # noqa: ANN001 184 """deepcopy the dataset 185 186 FIX: this isnt actually a deepcopy I think? 187 """ 188 return MazeDataset.load(self._serialize_full()) 189 190 # TYPING: get type hints on the tokenizer here 191 @overload 192 def as_tokens( 193 self, 194 maze_tokenizer, # noqa: ANN001 195 limit: int | None = None, 196 join_tokens_individual_maze: Literal[False] = False, 197 ) -> list[list[str]]: ... 198 @overload 199 def as_tokens( 200 self, 201 maze_tokenizer, # noqa: ANN001 202 limit: int | None = None, 203 join_tokens_individual_maze: Literal[True] = True, 204 ) -> list[str]: ... 205 def as_tokens( 206 self, 207 maze_tokenizer, # TODO: MazeTokenizer 208 limit: int | None = None, 209 join_tokens_individual_maze: bool = False, 210 ) -> list[list[str]] | list[str]: 211 """return the dataset as tokens according to the passed `maze_tokenizer` 212 213 the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` 214 215 if `join_tokens_individual_maze` is True, then the tokens of each maze are 216 joined with a space, and the result is a list of strings. 
217 i.e.: 218 219 >>> dataset.as_tokens(join_tokens_individual_maze=False) 220 [["a", "b", "c"], ["d", "e", "f"]] 221 >>> dataset.as_tokens(join_tokens_individual_maze=True) 222 ["a b c", "d e f"] 223 """ 224 output: list[list[str]] = [ 225 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 226 ] 227 if join_tokens_individual_maze: 228 return [" ".join(tokens) for tokens in output] 229 else: 230 return output 231 232 def __len__(self) -> int: 233 """return the number of mazes in the dataset""" 234 return len(self.mazes) 235 236 def __eq__(self, other: object) -> bool: 237 """compare two datasets""" 238 if not isinstance(other, MazeDataset): 239 raise NotImplementedError( 240 "can only compare with other MazeDataset objects", 241 ) 242 # TODO: compare hashes of data instead of the data itself? 243 return self.cfg == other.cfg and self.mazes == other.mazes 244 245 def assert_equal(self, other: "MazeDataset") -> None: 246 """assert that two datasets are equal""" 247 assert isinstance(other, MazeDataset) 248 assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }" 249 assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }" 250 251 @classmethod 252 def generate( 253 cls, 254 cfg: MazeDatasetConfig, 255 gen_parallel: bool = False, 256 pool_kwargs: dict | None = None, 257 verbose: bool = False, 258 # TODO: what to do when unexpected kwargs are passed? 259 **kwargs, # noqa: ARG003 260 ) -> "MazeDataset": 261 """Generate a maze dataset given a config and some generation parameters""" 262 # Copy the config to avoid modifying the original 263 cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load( 264 json.loads(json.dumps(cfg.serialize())), 265 ) 266 267 if pool_kwargs is None: 268 pool_kwargs = dict() 269 maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes) # type: ignore[assignment] 270 271 solved_mazes: list[SolvedMaze | None] 272 # Configure tqdm for progress bar 273 tqdm_kwargs: dict = dict( 274 total=cfg_cpy.n_mazes, 275 unit="maze", 276 desc="generating & solving mazes", 277 disable=not verbose, 278 ) 279 # TODO: don't use the global unless generating in parallel! 280 if gen_parallel: 281 with multiprocessing.Pool( 282 **pool_kwargs, 283 initializer=_maze_gen_init_worker, 284 initargs=(cfg_cpy,), 285 ) as pool: 286 solved_mazes = list( 287 tqdm.tqdm( 288 pool.imap(_generate_maze_helper, maze_indexes), 289 **tqdm_kwargs, 290 ), 291 ) 292 293 else: 294 _maze_gen_init_worker(cfg_cpy) 295 solved_mazes = list( 296 tqdm.tqdm( 297 map( 298 # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type] 299 # why does it think tolist() returns a string? 
300 _generate_maze_helper, # type: ignore[arg-type] 301 maze_indexes.tolist(), 302 ), 303 **tqdm_kwargs, 304 ), 305 ) 306 307 # Filter out None values explicitly after ensuring all results are collected 308 solved_mazes_: list[SolvedMaze] = [ 309 maze for maze in solved_mazes if maze is not None 310 ] 311 # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes)) 312 313 # Update the config with the actual number of mazes 314 cfg_cpy.n_mazes = len(solved_mazes_) 315 316 dataset: MazeDataset = cls( 317 cfg=cfg_cpy, 318 mazes=solved_mazes_, 319 ) 320 321 dataset.update_self_config() # Call `update_self_config()` to ensure the dataset's config reflects changes 322 323 np.random.seed(cfg_cpy.seed) # Reset the seed to the value in the config copy 324 325 return dataset 326 327 @classmethod 328 def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset": 329 "(not implemented yet!) download a maze dataset from the internet" 330 raise NotImplementedError("not implemented yet") 331 332 @classmethod 333 def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset": 334 """load from zanj/json""" 335 if data[_FORMAT_KEY] == "MazeDataset:minimal": 336 return cls._load_minimal(data) 337 elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat": 338 return cls._load_minimal_soln_cat(data) 339 elif data[_FORMAT_KEY] == "MazeDataset": 340 if ( 341 SERIALIZE_MINIMAL_THRESHOLD == -1 342 ): # Allow access to `_load_legacy` for profiling 343 return cls._load_legacy(data) 344 return cls._load_full(data) 345 else: 346 err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })" 347 raise KeyError( 348 err_msg, 349 ) 350 351 @classmethod 352 def _load_full(cls, data: JSONdict) -> "MazeDataset": 353 assert data[_FORMAT_KEY] == "MazeDataset" 354 return cls( 355 cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] 356 mazes=load_item_recursive(data["mazes"], tuple()), 357 generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] 358 ) 359 360 @classmethod 361 def _load_minimal(cls, data: JSONdict) -> "MazeDataset": 362 assert data[_FORMAT_KEY] == "MazeDataset:minimal" 363 return cls( 364 cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] 365 generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] 366 mazes=[ 367 SolvedMaze( 368 clist, 369 soln[:slen, ...], 370 ) 371 for clist, slen, soln in zip( 372 load_item_recursive(data["maze_connection_lists"], tuple()), 373 load_item_recursive(data["maze_solution_lengths"], tuple()), 374 load_item_recursive(data["maze_solutions"], tuple()), 375 strict=False, 376 # load_item_recursive(data["maze_endpoints"], tuple()), 377 ) 378 ], 379 ) 380 381 @classmethod 382 def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset": 383 assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat" 384 385 maze_solution_lengths = load_item_recursive( 386 data["maze_solution_lengths"], 387 tuple(), 388 ) 389 maze_solutions_concat = load_item_recursive( 390 data["maze_solutions_concat"], 391 tuple(), 392 ) 393 maze_solutions = np.split( 394 maze_solutions_concat, 395 np.cumsum(maze_solution_lengths)[:-1], 396 axis=0, 397 ) 398 399 return cls( 400 cfg=load_item_recursive(data["cfg"], tuple()), 401 generation_metadata_collected=load_item_recursive( 402 data["generation_metadata_collected"], 403 tuple(), 404 ), 405 mazes=[ 406 SolvedMaze( 407 connection_list=clist, 408 solution=soln, 409 ) 410 for clist, 
soln in zip( 411 load_item_recursive(data["maze_connection_lists"], tuple()), 412 # load_item_recursive(data["maze_endpoints"], tuple()), 413 maze_solutions, 414 strict=False, 415 ) 416 ], 417 ) 418 419 @classmethod 420 def _load_legacy(cls, data: JSONdict) -> "MazeDataset": 421 """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison.""" 422 assert data[_FORMAT_KEY] == "MazeDataset" 423 return cls( 424 **{ 425 key: load_item_recursive(data[key], tuple()) 426 for key in ["cfg", "mazes", "generation_metadata_collected"] 427 }, 428 ) 429 430 def serialize(self) -> JSONdict: 431 """serialize to zanj/json""" 432 if ( 433 SERIALIZE_MINIMAL_THRESHOLD is not None 434 and len(self) >= SERIALIZE_MINIMAL_THRESHOLD 435 ): 436 return self._serialize_minimal() 437 return self._serialize_full() 438 439 def _serialize_full(self) -> JSONdict: 440 return { 441 _FORMAT_KEY: "MazeDataset", 442 "cfg": json_serialize(self.cfg), 443 "fname": self.cfg.to_fname(), 444 "mazes": json_serialize(self.mazes), 445 "generation_metadata_collected": json_serialize( 446 self.generation_metadata_collected, 447 ), 448 } 449 450 def _serialize_minimal(self) -> JSONdict: 451 "alternate serialization where metadata is collected and mazes are stored in concatenated form" 452 filtered_meta: MazeDataset 453 if self.generation_metadata_collected is None: 454 filtered_meta = self.filter_by.collect_generation_meta() 455 else: 456 filtered_meta = self 457 458 max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes) 459 n_mazes: int = len(filtered_meta.mazes) 460 grid_n: int = filtered_meta.cfg.grid_n 461 462 maze_connection_lists: np.ndarray = np.empty( 463 (n_mazes, 2, grid_n, grid_n), 464 dtype=np.bool_, 465 ) 466 # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) 467 maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32) 468 maze_solutions: np.ndarray = np.empty( 469 (n_mazes, max_solution_len, 2), 470 dtype=np.int8, 471 ) 472 473 for idx, maze in enumerate(filtered_meta.mazes): 474 maze_connection_lists[idx] = maze.connection_list 475 # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) 476 maze_solution_lengths[idx] = maze.solution.shape[0] 477 maze_solutions[idx, : maze.solution.shape[0]] = maze.solution 478 479 return { 480 _FORMAT_KEY: "MazeDataset:minimal", 481 "cfg": json_serialize(filtered_meta.cfg), 482 "fname": filtered_meta.cfg.to_fname(), 483 "generation_metadata_collected": json_serialize( 484 filtered_meta.generation_metadata_collected, 485 ), 486 "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] 487 # "maze_endpoints": maze_endpoints, 488 "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] 489 "maze_solutions": maze_solutions, # type: ignore[dict-item] 490 } 491 492 def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict: 493 "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form" 494 filtered_meta: MazeDataset 495 if self.generation_metadata_collected is None: 496 filtered_meta = self.filter_by.collect_generation_meta() 497 else: 498 filtered_meta = self 499 500 maze_solution_lengths: np.ndarray = np.array( 501 [m.solution.shape[0] for m in filtered_meta.mazes], 502 dtype=np.int32, 503 ) 504 n_mazes: int = len(filtered_meta.mazes) 505 grid_n: int = filtered_meta.cfg.grid_n 506 total_solution_len: int = np.sum(maze_solution_lengths) 507 508 maze_connection_lists: np.ndarray = np.empty( 509 (n_mazes, 2, grid_n, 
grid_n), 510 dtype=np.bool_, 511 ) 512 maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) 513 maze_solutions_concat: np.ndarray = np.empty( 514 (total_solution_len, 2), 515 dtype=np.int8, 516 ) 517 518 solutions_running_idx: int = 0 519 for idx, maze in enumerate(filtered_meta.mazes): 520 maze_connection_lists[idx] = maze.connection_list 521 maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) 522 soln_len: int = maze.solution.shape[0] 523 maze_solution_lengths[idx] = soln_len 524 maze_solutions_concat[ 525 solutions_running_idx : solutions_running_idx + soln_len 526 ] = maze.solution 527 solutions_running_idx += soln_len 528 529 return { 530 _FORMAT_KEY: "MazeDataset:minimal_soln_cat", 531 "cfg": json_serialize(filtered_meta.cfg), 532 "fname": filtered_meta.cfg.to_fname(), 533 "generation_metadata_collected": json_serialize( 534 filtered_meta.generation_metadata_collected, 535 ), 536 "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] 537 "maze_endpoints": maze_endpoints, # type: ignore[dict-item] 538 "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] 539 "maze_solutions_concat": maze_solutions_concat, # type: ignore[dict-item] 540 } 541 542 def update_self_config(self) -> None: 543 """update the config to match the current state of the dataset (number of mazes, such as after filtering)""" 544 if self.cfg.n_mazes != len(self.mazes): 545 warnings.warn( 546 f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}", 547 ) 548 self.cfg.n_mazes = len(self.mazes) 549 550 def custom_maze_filter( 551 self, 552 method: typing.Callable[[SolvedMaze], bool], 553 **kwargs, 554 ) -> "MazeDataset": 555 """filter the dataset using a custom method""" 556 output: MazeDataset = MazeDataset( 557 cfg=copy.deepcopy(self.cfg), 558 mazes=[m for m in self.mazes if method(m, **kwargs)], 559 ) 560 output.cfg.applied_filters.append( 561 { 562 "name": f"__custom__:{method.__name__}", 563 "kwargs": kwargs, 564 }, 565 ) 566 output.update_self_config() 567 return output
a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config
117 def __init__( 118 self, 119 cfg: MazeDatasetConfig, 120 mazes: typing.Sequence[SolvedMaze], 121 generation_metadata_collected: dict | None = None, 122 ) -> None: 123 """initialize a maze dataset from a config and a list of solved mazes""" 124 super().__init__() 125 self.cfg: MazeDatasetConfig = cfg 126 self.mazes: list[SolvedMaze] = list(mazes) 127 self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize a maze dataset from a config and a list of solved mazes
130 @classmethod 131 def from_config( # type: ignore[override] 132 cls, 133 # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override] 134 cfg: MazeDatasetConfig, # type: ignore[override] 135 do_generate: bool = True, 136 load_local: bool = True, 137 save_local: bool = True, 138 zanj: ZANJ | None = None, 139 do_download: bool = True, 140 local_base_path: Path = Path("data/maze_dataset"), 141 except_on_config_mismatch: bool = True, 142 allow_generation_metadata_filter_mismatch: bool = True, 143 verbose: bool = False, 144 **kwargs, 145 ) -> "MazeDataset": 146 """create a maze dataset from a config 147 148 priority of loading: 149 1. load from local 150 2. download 151 3. generate 152 153 """ 154 return cast( 155 "MazeDataset", 156 super().from_config( 157 cfg=cfg, 158 do_generate=do_generate, 159 load_local=load_local, 160 save_local=save_local, 161 zanj=zanj, 162 do_download=do_download, 163 local_base_path=local_base_path, 164 except_on_config_mismatch=except_on_config_mismatch, 165 allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, 166 verbose=verbose, 167 **kwargs, 168 ), 169 )
create a maze dataset from a config
priority of loading:
- load from local
- download
- generate
171 def data_hash(self) -> int: 172 """return a hash of the data""" 173 return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
return a hash of the data
205 def as_tokens( 206 self, 207 maze_tokenizer, # TODO: MazeTokenizer 208 limit: int | None = None, 209 join_tokens_individual_maze: bool = False, 210 ) -> list[list[str]] | list[str]: 211 """return the dataset as tokens according to the passed `maze_tokenizer` 212 213 the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` 214 215 if `join_tokens_individual_maze` is True, then the tokens of each maze are 216 joined with a space, and the result is a list of strings. 217 i.e.: 218 219 >>> dataset.as_tokens(join_tokens_individual_maze=False) 220 [["a", "b", "c"], ["d", "e", "f"]] 221 >>> dataset.as_tokens(join_tokens_individual_maze=True) 222 ["a b c", "d e f"] 223 """ 224 output: list[list[str]] = [ 225 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 226 ] 227 if join_tokens_individual_maze: 228 return [" ".join(tokens) for tokens in output] 229 else: 230 return output
return the dataset as tokens according to the passed `maze_tokenizer`

the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

if `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings, i.e.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
245 def assert_equal(self, other: "MazeDataset") -> None: 246 """assert that two datasets are equal""" 247 assert isinstance(other, MazeDataset) 248 assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }" 249 assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
assert that two datasets are equal
251 @classmethod 252 def generate( 253 cls, 254 cfg: MazeDatasetConfig, 255 gen_parallel: bool = False, 256 pool_kwargs: dict | None = None, 257 verbose: bool = False, 258 # TODO: what to do when unexpected kwargs are passed? 259 **kwargs, # noqa: ARG003 260 ) -> "MazeDataset": 261 """Generate a maze dataset given a config and some generation parameters""" 262 # Copy the config to avoid modifying the original 263 cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load( 264 json.loads(json.dumps(cfg.serialize())), 265 ) 266 267 if pool_kwargs is None: 268 pool_kwargs = dict() 269 maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes) # type: ignore[assignment] 270 271 solved_mazes: list[SolvedMaze | None] 272 # Configure tqdm for progress bar 273 tqdm_kwargs: dict = dict( 274 total=cfg_cpy.n_mazes, 275 unit="maze", 276 desc="generating & solving mazes", 277 disable=not verbose, 278 ) 279 # TODO: don't use the global unless generating in parallel! 280 if gen_parallel: 281 with multiprocessing.Pool( 282 **pool_kwargs, 283 initializer=_maze_gen_init_worker, 284 initargs=(cfg_cpy,), 285 ) as pool: 286 solved_mazes = list( 287 tqdm.tqdm( 288 pool.imap(_generate_maze_helper, maze_indexes), 289 **tqdm_kwargs, 290 ), 291 ) 292 293 else: 294 _maze_gen_init_worker(cfg_cpy) 295 solved_mazes = list( 296 tqdm.tqdm( 297 map( 298 # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type] 299 # why does it think tolist() returns a string? 300 _generate_maze_helper, # type: ignore[arg-type] 301 maze_indexes.tolist(), 302 ), 303 **tqdm_kwargs, 304 ), 305 ) 306 307 # Filter out None values explicitly after ensuring all results are collected 308 solved_mazes_: list[SolvedMaze] = [ 309 maze for maze in solved_mazes if maze is not None 310 ] 311 # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes)) 312 313 # Update the config with the actual number of mazes 314 cfg_cpy.n_mazes = len(solved_mazes_) 315 316 dataset: MazeDataset = cls( 317 cfg=cfg_cpy, 318 mazes=solved_mazes_, 319 ) 320 321 dataset.update_self_config() # Call `update_self_config()` to ensure the dataset's config reflects changes 322 323 np.random.seed(cfg_cpy.seed) # Reset the seed to the value in the config copy 324 325 return dataset
Generate a maze dataset given a config and some generation parameters
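For example, generation can be run in parallel with a progress bar; `pool_kwargs` is passed straight to `multiprocessing.Pool`:

```python
dataset: MazeDataset = MazeDataset.generate(
    cfg,
    gen_parallel=True,
    pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
    verbose=True,                   # show the tqdm progress bar
)
```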
327 @classmethod 328 def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset": 329 "(not implemented yet!) download a maze dataset from the internet" 330 raise NotImplementedError("not implemented yet")
(not implemented yet!) download a maze dataset from the internet
332 @classmethod 333 def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset": 334 """load from zanj/json""" 335 if data[_FORMAT_KEY] == "MazeDataset:minimal": 336 return cls._load_minimal(data) 337 elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat": 338 return cls._load_minimal_soln_cat(data) 339 elif data[_FORMAT_KEY] == "MazeDataset": 340 if ( 341 SERIALIZE_MINIMAL_THRESHOLD == -1 342 ): # Allow access to `_load_legacy` for profiling 343 return cls._load_legacy(data) 344 return cls._load_full(data) 345 else: 346 err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })" 347 raise KeyError( 348 err_msg, 349 )
load from zanj/json
430 def serialize(self) -> JSONdict: 431 """serialize to zanj/json""" 432 if ( 433 SERIALIZE_MINIMAL_THRESHOLD is not None 434 and len(self) >= SERIALIZE_MINIMAL_THRESHOLD 435 ): 436 return self._serialize_minimal() 437 return self._serialize_full()
serialize to zanj/json
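A round-trip sketch; the `ZANJ().save()` / `ZANJ().read()` names for on-disk storage are assumed from the `zanj` dependency, not shown on this page:

```python
# in-memory round trip through the JSON-serializable form
data = dataset.serialize()
dataset_again = MazeDataset.load(data)
dataset.assert_equal(dataset_again)  # expected to hold for a full (non-minimal) serialization

# on-disk storage goes through zanj (method names assumed)
from zanj import ZANJ

ZANJ().save(dataset, "my_dataset.zanj")
dataset_from_disk = ZANJ().read("my_dataset.zanj")
```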
542 def update_self_config(self) -> None: 543 """update the config to match the current state of the dataset (number of mazes, such as after filtering)""" 544 if self.cfg.n_mazes != len(self.mazes): 545 warnings.warn( 546 f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}", 547 ) 548 self.cfg.n_mazes = len(self.mazes)
update the config to match the current state of the dataset (number of mazes, such as after filtering)
550 def custom_maze_filter( 551 self, 552 method: typing.Callable[[SolvedMaze], bool], 553 **kwargs, 554 ) -> "MazeDataset": 555 """filter the dataset using a custom method""" 556 output: MazeDataset = MazeDataset( 557 cfg=copy.deepcopy(self.cfg), 558 mazes=[m for m in self.mazes if method(m, **kwargs)], 559 ) 560 output.cfg.applied_filters.append( 561 { 562 "name": f"__custom__:{method.__name__}", 563 "kwargs": kwargs, 564 }, 565 ) 566 output.update_self_config() 567 return output
filter the dataset using a custom method
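A sketch of a custom filter; the `SolvedMaze` re-export from the top-level package is assumed:

```python
from maze_dataset import SolvedMaze  # assumed re-export


def solution_shorter_than(maze: SolvedMaze, max_len: int) -> bool:
    """keep only mazes whose solution has at most `max_len` steps"""
    return maze.solution.shape[0] <= max_len


# extra kwargs are forwarded to the filter function and recorded in cfg.applied_filters
dataset_short: MazeDataset = dataset.custom_maze_filter(solution_shorter_than, max_len=10)
```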
257@serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) 258class MazeDatasetConfig(MazeDatasetConfig_base): # type: ignore[misc] 259 """config object which is passed to `MazeDataset.from_config` to generate or load a dataset 260 261 # Parameters: 262 - `name : str` 263 name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname` 264 - `grid_n : int` 265 grid size of the maze (number of rows/columns) 266 - `n_mazes : int` 267 number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. 268 see `EndpointKwargsType` for more details. 269 - `maze_ctor : Callable` 270 maze generator function. This should be a function that takes a grid size and returns a maze. 271 This will usually be one of the functions in `LatticeMazeGenerators`. 272 - `maze_ctor_kwargs : dict` 273 keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using. 274 - `endpoint_kwargs : EndpointKwargsType` 275 keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info. 276 - `applied_filters : list[dict]` 277 list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, 278 but these are stored with the config in case you want to re-generate the dataset with the same filters. 279 280 """ 281 282 @property 283 def config_version(self) -> str: 284 """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" 285 return "1.0" 286 287 @property 288 def versions(self) -> dict: 289 """return the versions of the config and the maze_dataset""" 290 return dict( 291 config=self.config_version, 292 maze_dataset=importlib.metadata.version("maze_dataset"), 293 ) 294 295 def serialize(self) -> dict: 296 "serialize the MazeDatasetConfig with all fields and fname" 297 return { 298 **self._serialize_base( 299 applied_filters__skip__collect_generation_meta=False 300 ), 301 "fname": self.to_fname(), 302 "versions": self.versions, 303 } 304 305 def summary(self) -> dict: 306 """return a summary of the config""" 307 # do we run this to make sure it doesn't error? 308 super_summary: dict = super().summary() 309 assert super_summary 310 self_ser: dict = self.serialize() 311 return dict( 312 name=self.name, 313 fname=self.to_fname(), 314 sdc_hash=self.stable_hash_cfg(), 315 seed=self.seed, 316 seq_len_min=self.seq_len_min, 317 seq_len_max=self.seq_len_max, 318 applied_filters=self.applied_filters, 319 grid_n=self_ser["grid_n"], 320 n_mazes=self_ser["n_mazes"], 321 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 322 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 323 endpoint_kwargs=self_ser["endpoint_kwargs"], 324 ) 325 326 def _to_ps_array(self) -> _PercolationSuccessArray: 327 """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. 
328 329 used in predicting the success rate 330 """ 331 try: 332 assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( 333 f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" 334 ) 335 assert "p" in self.maze_ctor_kwargs, ( 336 f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" 337 ) 338 assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( 339 f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" 340 ) 341 except AssertionError as e: 342 err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" 343 raise NoPercolationInConfigError( 344 err_msg, 345 ) from e 346 347 endpoints_unique_flag: int = int( 348 # we are pretty sure it will be an int or bool here 349 self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] 350 ) 351 352 # adjustment for bknutson0 353 if not ( 354 self.endpoint_kwargs.get("deadend_start", False) 355 and self.endpoint_kwargs.get("deadend_end", False) 356 ): 357 # we didnt train on this, but if either endpoint is not required to be in a dead end 358 # then requiring the endpoints to be unique does not really affect the success rate 359 # (except for very small percolation values, pure percolation generation) 360 endpoints_unique_flag = 0 361 362 return np.array( 363 [ 364 float(self.maze_ctor_kwargs["p"]), 365 float(self.grid_n), 366 float( 367 int( 368 self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] 369 or self.endpoint_kwargs.get("deadend_end", False), 370 ), 371 ), 372 float(endpoints_unique_flag), 373 float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), 374 ], 375 dtype=np.float64, 376 ) 377 378 @classmethod 379 def _from_ps_array( 380 cls, 381 arr: _PercolationSuccessArray, 382 name: str = "predict", 383 n_mazes: int = 100, 384 **kwargs, 385 ) -> "MazeDatasetConfig": 386 """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters. 387 388 # Returns: 389 - `MazeDatasetConfig` 390 Config corresponding to `arr` 391 """ 392 return cls( 393 name=name, 394 grid_n=int(arr[1]), 395 n_mazes=n_mazes, 396 maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], 397 maze_ctor_kwargs={"p": float(arr[0])}, 398 endpoint_kwargs=dict( 399 deadend_start=bool(arr[2]), 400 deadend_end=bool(arr[2]), 401 endpoints_not_equal=bool(arr[3]), 402 except_on_no_valid_endpoint=False, 403 ), 404 **kwargs, 405 ) 406 407 def success_fraction_estimate( 408 self, 409 except_if_all_success_expected: bool = False, 410 ) -> float: 411 """Estimate the success fraction of this config. 412 413 only valid when the generator is a percolation generator, 414 and endpoints are enforced to be dead ends 415 416 more information on where this comes from can be found in 417 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 418 - `estimate_dataset_fractions.ipynb` 419 - `maze_dataset.benchmarks.sweep_fit` 420 421 # Parameters: 422 - `except_if_all_success_expected : bool` 423 if `True`, don't raise an error if the success fraction is below the threshold. 
424 will always return `1.0` if the config is not expected to fail 425 426 # Returns: 427 - `float` 428 estimated success fraction 429 430 # Raises: 431 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 432 """ 433 try: 434 return cfg_success_predict_fn(self) 435 436 except NoPercolationInConfigError as e: 437 if except_if_all_success_expected: 438 raise e # noqa: TRY201 439 return 1.0 440 441 def success_fraction_compensate( 442 self, 443 safety_margin: float = 1.2, 444 except_if_all_success_expected: bool = False, 445 epsilon: float = 1e-2, 446 ) -> "MazeDatasetConfig": 447 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 448 449 calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then 450 computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1` 451 452 more information on where this comes from can be found in 453 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 454 - `estimate_dataset_fractions.ipynb` 455 - `maze_dataset.benchmarks.sweep_fit` 456 457 # Parameters: 458 - `safety_margin : float` 459 safety margin to apply to the success fraction estimate 460 (defaults to `1.2`, or 20% more mazes than estimated) 461 - `except_if_all_success_expected : bool` 462 if `True`, don't raise an error if the success fraction is below the threshold. 463 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 464 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 465 since `safety_margin` is still applied. 466 (defaults to `False`) 467 - `epsilon : float` 468 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 469 (defaults to `1e-2`) 470 471 # Returns: 472 - `MazeDatasetConfig` 473 new config with adjusted `n_mazes` 474 475 # Raises: 476 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 477 """ 478 # compute and check the success fraction 479 success_fraction: float = self.success_fraction_estimate( 480 except_if_all_success_expected=except_if_all_success_expected, 481 ) 482 if success_fraction < epsilon: 483 err_msg: str = ( 484 f"{success_fraction = } is below the threshold of {epsilon = }" 485 ) 486 raise SuccessChanceTooSmallError( 487 err_msg, 488 ) 489 490 # compute the new number of mazes 491 n_mazes: int = self.n_mazes 492 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 493 494 # put it in a new config and return 495 cfg_dict: dict = self.serialize() 496 cfg_dict["n_mazes"] = new_n_mazes 497 return MazeDatasetConfig.load(cfg_dict)
config object which is passed to `MazeDataset.from_config` to generate or load a dataset

Parameters:
- `name : str`
  name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
- `grid_n : int`
  grid size of the maze (number of rows/columns)
- `n_mazes : int`
  number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. see `EndpointKwargsType` for more details.
- `maze_ctor : Callable`
  maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in `LatticeMazeGenerators`.
- `maze_ctor_kwargs : dict`
  keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
- `endpoint_kwargs : EndpointKwargsType`
  keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
- `applied_filters : list[dict]`
  list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
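A configuration sketch exercising these parameters; the `gen_dfs_percolation` generator name is an assumption, while the `endpoint_kwargs` keys match those used by `_to_ps_array` below:

```python
cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="percolation-demo",
    grid_n=8,
    n_mazes=1000,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,  # assumed generator name
    maze_ctor_kwargs=dict(p=0.3),
    endpoint_kwargs=dict(
        deadend_start=True,
        deadend_end=True,
        endpoints_not_equal=True,
        except_on_no_valid_endpoint=False,  # failed mazes are dropped instead of raising
    ),
)
```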
282 @property 283 def config_version(self) -> str: 284 """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" 285 return "1.0"
return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config
287 @property 288 def versions(self) -> dict: 289 """return the versions of the config and the maze_dataset""" 290 return dict( 291 config=self.config_version, 292 maze_dataset=importlib.metadata.version("maze_dataset"), 293 )
return the versions of the config and the maze_dataset
295 def serialize(self) -> dict: 296 "serialize the MazeDatasetConfig with all fields and fname" 297 return { 298 **self._serialize_base( 299 applied_filters__skip__collect_generation_meta=False 300 ), 301 "fname": self.to_fname(), 302 "versions": self.versions, 303 }
serialize the MazeDatasetConfig with all fields and fname
305 def summary(self) -> dict: 306 """return a summary of the config""" 307 # do we run this to make sure it doesn't error? 308 super_summary: dict = super().summary() 309 assert super_summary 310 self_ser: dict = self.serialize() 311 return dict( 312 name=self.name, 313 fname=self.to_fname(), 314 sdc_hash=self.stable_hash_cfg(), 315 seed=self.seed, 316 seq_len_min=self.seq_len_min, 317 seq_len_max=self.seq_len_max, 318 applied_filters=self.applied_filters, 319 grid_n=self_ser["grid_n"], 320 n_mazes=self_ser["n_mazes"], 321 maze_ctor_name=self_ser["maze_ctor"]["__name__"], 322 maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], 323 endpoint_kwargs=self_ser["endpoint_kwargs"], 324 )
return a summary of the config
407 def success_fraction_estimate( 408 self, 409 except_if_all_success_expected: bool = False, 410 ) -> float: 411 """Estimate the success fraction of this config. 412 413 only valid when the generator is a percolation generator, 414 and endpoints are enforced to be dead ends 415 416 more information on where this comes from can be found in 417 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 418 - `estimate_dataset_fractions.ipynb` 419 - `maze_dataset.benchmarks.sweep_fit` 420 421 # Parameters: 422 - `except_if_all_success_expected : bool` 423 if `True`, don't raise an error if the success fraction is below the threshold. 424 will always return `1.0` if the config is not expected to fail 425 426 # Returns: 427 - `float` 428 estimated success fraction 429 430 # Raises: 431 - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` 432 """ 433 try: 434 return cfg_success_predict_fn(self) 435 436 except NoPercolationInConfigError as e: 437 if except_if_all_success_expected: 438 raise e # noqa: TRY201 439 return 1.0
Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `except_if_all_success_expected : bool`
  if `True`, raise a `NoPercolationInConfigError` when the config is not a percolation config (i.e. every maze is expected to generate successfully); if `False` (the default), return `1.0` in that case

Returns:
- `float`
  estimated success fraction

Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
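For a percolation config like the sketch above, the estimate is a plain float:

```python
frac: float = cfg.success_fraction_estimate()
print(f"expected fraction of mazes that generate successfully: {frac:.2f}")
```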
441 def success_fraction_compensate( 442 self, 443 safety_margin: float = 1.2, 444 except_if_all_success_expected: bool = False, 445 epsilon: float = 1e-2, 446 ) -> "MazeDatasetConfig": 447 """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction 448 449 calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then 450 computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1` 451 452 more information on where this comes from can be found in 453 - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` 454 - `estimate_dataset_fractions.ipynb` 455 - `maze_dataset.benchmarks.sweep_fit` 456 457 # Parameters: 458 - `safety_margin : float` 459 safety margin to apply to the success fraction estimate 460 (defaults to `1.2`, or 20% more mazes than estimated) 461 - `except_if_all_success_expected : bool` 462 if `True`, don't raise an error if the success fraction is below the threshold. 463 this is passed to `MazeDatasetConfig.success_fraction_estimate`. 464 if your config isn't expected to fail, passing this might mean you generate more mazes than needed 465 since `safety_margin` is still applied. 466 (defaults to `False`) 467 - `epsilon : float` 468 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 469 (defaults to `1e-2`) 470 471 # Returns: 472 - `MazeDatasetConfig` 473 new config with adjusted `n_mazes` 474 475 # Raises: 476 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 477 """ 478 # compute and check the success fraction 479 success_fraction: float = self.success_fraction_estimate( 480 except_if_all_success_expected=except_if_all_success_expected, 481 ) 482 if success_fraction < epsilon: 483 err_msg: str = ( 484 f"{success_fraction = } is below the threshold of {epsilon = }" 485 ) 486 raise SuccessChanceTooSmallError( 487 err_msg, 488 ) 489 490 # compute the new number of mazes 491 n_mazes: int = self.n_mazes 492 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 493 494 # put it in a new config and return 495 cfg_dict: dict = self.serialize() 496 cfg_dict["n_mazes"] = new_n_mazes 497 return MazeDatasetConfig.load(cfg_dict)
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `safety_margin : float`
  safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
  passed to `MazeDatasetConfig.success_fraction_estimate()`. if your config isn't expected to fail (no percolation), leaving this as `False` makes the estimate `1.0`, which can mean generating more mazes than needed since `safety_margin` is still applied; setting it to `True` raises an error instead. (defaults to `False`)
- `epsilon : float`
  raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`)

Returns:
- `MazeDatasetConfig`
  new config with adjusted `n_mazes`

Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
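A sketch of the intended workflow: scale up `n_mazes`, then generate, so that roughly the originally requested number of mazes survive:

```python
cfg_big: MazeDatasetConfig = cfg.success_fraction_compensate(safety_margin=1.2)
print(cfg.n_mazes, "->", cfg_big.n_mazes)  # n_mazes scaled by safety_margin / success_fraction
dataset: MazeDataset = MazeDataset.from_config(cfg_big)
```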
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
283def SerializableDataclass__validate_fields_types( 284 self: SerializableDataclass, 285 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 286) -> bool: 287 """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" 288 return all( 289 SerializableDataclass__validate_fields_types__dict( 290 self, on_typecheck_error=on_typecheck_error 291 ).values() 292 )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base
- grid_n
- n_mazes
- maze_ctor
- maze_ctor_kwargs
- endpoint_kwargs
- grid_shape
- grid_shape_np
- max_grid_n
- stable_hash_cfg
- to_fname
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
84class MazeDatasetCollection(GPTDataset): 85 """a collection of maze datasets""" 86 87 def __init__( 88 self, 89 cfg: MazeDatasetCollectionConfig, 90 maze_datasets: list[MazeDataset], 91 generation_metadata_collected: dict | None = None, 92 ) -> None: 93 "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s" 94 super().__init__() 95 self.cfg: MazeDatasetCollectionConfig = cfg 96 self.maze_datasets: list[MazeDataset] = list(maze_datasets) 97 for c, ds in zip( 98 self.cfg.maze_dataset_configs, 99 self.maze_datasets, 100 strict=False, 101 ): 102 assert c.name == ds.cfg.name 103 assert c == ds.cfg 104 105 self.generation_metadata_collected: dict | None = generation_metadata_collected 106 107 @property 108 def dataset_lengths(self) -> list[int]: 109 """return the lengths of each dataset in the collection""" 110 return [len(dataset) for dataset in self.maze_datasets] 111 112 @property 113 def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]: 114 """return the cumulative lengths of each dataset in the collection""" 115 return np.array(list(itertools.accumulate(self.dataset_lengths))) 116 117 @cached_property 118 def mazes(self) -> list[LatticeMaze]: 119 "single list of all mazes in the collection" 120 return list( 121 itertools.chain.from_iterable( 122 dataset.mazes for dataset in self.maze_datasets 123 ), 124 ) 125 126 def __len__(self) -> int: 127 """return the total number of mazes in the collection""" 128 return sum(len(dataset) for dataset in self.maze_datasets) 129 130 def __getitem__(self, index: int) -> LatticeMaze: 131 "get a maze by index" 132 # find which dataset the index belongs to 133 # we add 1, since np.searchsorted returns the 134 # index of the last element that is strictly less than the target 135 # while we want the index of the last element less than or equal to the target 136 dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1)) 137 index_adjusted: int = index 138 if dataset_idx > 0: 139 # if the index is 0, `dataset_idx - 1` will be -1. 140 # We just want to use the base index 141 index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1] 142 return self.maze_datasets[dataset_idx][index_adjusted] 143 144 @classmethod 145 def generate( 146 cls, 147 cfg: MazeDatasetCollectionConfig, 148 **kwargs, 149 ) -> "MazeDatasetCollection": 150 """generate a dataset collection from a config""" 151 datasets = [ 152 MazeDataset.generate(config, **kwargs) 153 for config in cfg.maze_dataset_configs 154 ] 155 return cls(cfg, datasets) 156 157 @classmethod 158 def download( 159 cls, 160 cfg: MazeDatasetCollectionConfig, 161 **kwargs, 162 ) -> "MazeDatasetCollection": 163 "(not implemented!) 
download a dataset collection from a config" 164 datasets = [ 165 MazeDataset.download(config, **kwargs) 166 for config in cfg.maze_dataset_configs 167 ] 168 return cls(cfg, datasets) 169 170 def serialize(self) -> JSONdict: 171 """serialize the dataset collection""" 172 return { 173 _FORMAT_KEY: "MazeDatasetCollection", 174 "cfg": self.cfg.serialize(), 175 "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets], 176 "generation_metadata_collected": json_serialize( 177 self.generation_metadata_collected, 178 ), 179 } 180 181 @classmethod 182 def load(cls, data: JSONdict) -> "MazeDatasetCollection": 183 """load the dataset collection from the representation created by `serialize`""" 184 assert data[_FORMAT_KEY] == "MazeDatasetCollection" 185 return cls( 186 **{ 187 key: load_item_recursive(data[key], tuple()) 188 for key in ["cfg", "maze_datasets", "generation_metadata_collected"] 189 }, 190 ) 191 192 # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow? 193 def as_tokens( 194 self, 195 # TODO: MazeTokenizer 196 maze_tokenizer, # noqa: ANN001 197 limit: int | None = None, 198 join_tokens_individual_maze: bool = False, 199 ) -> list[list[str]] | list[str]: 200 """return the dataset as tokens 201 202 if join_tokens_individual_maze is True, then the tokens of each maze are 203 joined with a space, and the result is a list of strings. 204 i.e.: 205 >>> dataset.as_tokens(join_tokens_individual_maze=False) 206 [["a", "b", "c"], ["d", "e", "f"]] 207 >>> dataset.as_tokens(join_tokens_individual_maze=True) 208 ["a b c", "d e f"] 209 """ 210 output: list[list[str]] = [ 211 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 212 ] 213 if join_tokens_individual_maze: 214 return [" ".join(tokens) for tokens in output] 215 else: 216 return output 217 218 def update_self_config(self) -> None: 219 "update the config to match the number of mazes, and update the underlying configs of each dataset" 220 # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset 221 self.cfg.__dict__["n_mazes"] = len(self) 222 for dataset in self.maze_datasets: 223 dataset.update_self_config() 224 225 self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
a collection of maze datasets
87 def __init__( 88 self, 89 cfg: MazeDatasetCollectionConfig, 90 maze_datasets: list[MazeDataset], 91 generation_metadata_collected: dict | None = None, 92 ) -> None: 93 "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s" 94 super().__init__() 95 self.cfg: MazeDatasetCollectionConfig = cfg 96 self.maze_datasets: list[MazeDataset] = list(maze_datasets) 97 for c, ds in zip( 98 self.cfg.maze_dataset_configs, 99 self.maze_datasets, 100 strict=False, 101 ): 102 assert c.name == ds.cfg.name 103 assert c == ds.cfg 104 105 self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s
107 @property 108 def dataset_lengths(self) -> list[int]: 109 """return the lengths of each dataset in the collection""" 110 return [len(dataset) for dataset in self.maze_datasets]
return the lengths of each dataset in the collection
112 @property 113 def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]: 114 """return the cumulative lengths of each dataset in the collection""" 115 return np.array(list(itertools.accumulate(self.dataset_lengths)))
return the cumulative lengths of each dataset in the collection
117 @cached_property 118 def mazes(self) -> list[LatticeMaze]: 119 "single list of all mazes in the collection" 120 return list( 121 itertools.chain.from_iterable( 122 dataset.mazes for dataset in self.maze_datasets 123 ), 124 )
single list of all mazes in the collection
144 @classmethod 145 def generate( 146 cls, 147 cfg: MazeDatasetCollectionConfig, 148 **kwargs, 149 ) -> "MazeDatasetCollection": 150 """generate a dataset collection from a config""" 151 datasets = [ 152 MazeDataset.generate(config, **kwargs) 153 for config in cfg.maze_dataset_configs 154 ] 155 return cls(cfg, datasets)
generate a dataset collection from a config
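A sketch of building a collection from several per-grid-size configs (again assuming the `gen_dfs` generator); the import path matches the module's `__init__` above:

```python
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)

collection_cfg = MazeDatasetCollectionConfig(
    name="multi-size",
    maze_dataset_configs=[
        MazeDatasetConfig(
            name=f"grid{n}",
            grid_n=n,
            n_mazes=50,
            maze_ctor=LatticeMazeGenerators.gen_dfs,  # assumed generator
        )
        for n in (4, 6, 8)
    ],
)
collection = MazeDatasetCollection.generate(collection_cfg)
print(len(collection))  # 150 if every maze generates successfully; indexing spans all datasets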
157 @classmethod 158 def download( 159 cls, 160 cfg: MazeDatasetCollectionConfig, 161 **kwargs, 162 ) -> "MazeDatasetCollection": 163 "(not implemented!) download a dataset collection from a config" 164 datasets = [ 165 MazeDataset.download(config, **kwargs) 166 for config in cfg.maze_dataset_configs 167 ] 168 return cls(cfg, datasets)
(not implemented!) download a dataset collection from a config
170 def serialize(self) -> JSONdict: 171 """serialize the dataset collection""" 172 return { 173 _FORMAT_KEY: "MazeDatasetCollection", 174 "cfg": self.cfg.serialize(), 175 "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets], 176 "generation_metadata_collected": json_serialize( 177 self.generation_metadata_collected, 178 ), 179 }
serialize the dataset collection
181 @classmethod 182 def load(cls, data: JSONdict) -> "MazeDatasetCollection": 183 """load the dataset collection from the representation created by `serialize`""" 184 assert data[_FORMAT_KEY] == "MazeDatasetCollection" 185 return cls( 186 **{ 187 key: load_item_recursive(data[key], tuple()) 188 for key in ["cfg", "maze_datasets", "generation_metadata_collected"] 189 }, 190 )
load the dataset collection from the representation created by `serialize`
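`serialize` and `load` are intended to round-trip through a JSON-compatible dict; a minimal sketch, assuming `collection` already exists:

```python
data = collection.serialize()                # JSON-compatible dict
restored = MazeDatasetCollection.load(data)  # rebuild the collection from that dict
assert restored.cfg.to_fname() == collection.cfg.to_fname()
```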
```python
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output
```
return the dataset as tokens
If `join_tokens_individual_maze` is True, the tokens of each maze are joined with a space and the result is a list of strings; otherwise, a list of token lists is returned. For example:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
```python
    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
```
update the config to match the number of mazes, and update the underlying configs of each dataset
```python
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # Attributes without a default cannot follow attributes with one [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )

    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all dataset"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
```
maze dataset collection configuration, including tokenizers and shuffle
```python
    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )
```
return a summary of the config
```python
    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all dataset"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)
```
return the total number of mazes in the collection across all datasets
```python
    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)
```
return the maximum grid size of the mazes in the collection
```python
    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)
```
return the maximum grid shape of the mazes in the collection
```python
    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)
```
return the maximum grid shape of the mazes in the collection as a numpy array
```python
    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))
```
return a stable hash of the config
```python
    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
```
convert config to a filename
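A hedged illustration of these config helpers, reusing the hypothetical `collection_cfg` from the `generate` example (the hash suffix in the filename varies per config):

```python
print(collection_cfg.summary()["n_mazes"])  # total mazes across all sub-configs, e.g. 20
print(collection_cfg.to_fname())            # e.g. "collected-demo_collection-n20-h12345"
```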
```python
    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
        }
        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # need it to be our special SerializableField
            if not isinstance(field, SerializableField):
                raise NotSerializableFieldException(
                    f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                    f"but a {type(field)} "
                    "this state should be inaccessible, please report this bug!"
                )

            # try to save it
            if field.serialize:
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

        # store each property if we can get it
        for prop in self._properties_to_serialize:
            if hasattr(cls, prop):
                value = getattr(self, prop)
                result[prop] = value
            else:
                raise AttributeError(
                    f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                    + f"but it is in {self._properties_to_serialize = }"
                    + f"\n{self = }"
                )

        return result
```
serialize the dataclass fields (and any registered properties) into a JSON-compatible dict; inherited from `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
```python
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
```
load an instance of the dataclass from a dict produced by `serialize`, applying per-field loading functions and optional type validation; inherited from `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
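Since the config classes are `SerializableDataclass`es, this validator can be invoked on an instance; a minimal sketch, assuming the method is exposed as `validate_fields_types` (as listed above) and reusing the hypothetical `collection_cfg` from the `generate` example:

```python
# returns True when every field's value matches its declared type hint
assert collection_cfg.validate_fields_types()
```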
Inherited Members
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
  - `validate_field_type`
  - `diff`
  - `update_from_nested_dict`