maze_dataset.dataset.maze_dataset
`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset.
"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset.

see [demo_dataset notebook](../../notebooks/demo_dataset)

"""

import copy
import json
import multiprocessing
import typing
import warnings
from pathlib import Path
from typing import Literal, cast, overload

import numpy as np
import tqdm
from jaxtyping import Int
from muutils.json_serialize import (
    json_serialize,
)
from muutils.json_serialize.util import (
    _FORMAT_KEY,
    JSONdict,
)
from muutils.misc import stable_hash
from zanj import ZANJ
from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler

from maze_dataset.constants import CoordArray
from maze_dataset.dataset.dataset import (
    GPTDataset,
)
from maze_dataset.dataset.maze_dataset_config import (
    SERIALIZE_MINIMAL_THRESHOLD,
    EndpointKwargsType,
    MazeDatasetConfig,
)
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.maze import LatticeMaze, SolvedMaze

# config shared with `_generate_maze_helper`; assigned by `_maze_gen_init_worker`
_GLOBAL_WORKER_CONFIG: MazeDatasetConfig


def _generate_maze_helper(index: int) -> SolvedMaze | None:  # noqa: ARG001
    """Helper function for generating mazes in parallel.

    Builds one maze from the module-level `_GLOBAL_WORKER_CONFIG` (which must
    have been set by `_maze_gen_init_worker`) and solves it with a random path.

    > [!CAUTION]
    > don't use this unless generating in parallel!

    # Parameters:
    - `index : int`
        unused; only present so the function can be mapped over maze indices

    # Returns:
    - `SolvedMaze | None`
        the solved maze, or `None` if the generated path failed validation
    """
    global _GLOBAL_WORKER_CONFIG  # noqa: PLW0602
    # TODO: don't use this unless generating in parallel!
    maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
        grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
        **_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
    )

    # copy so the shared config's kwargs are never mutated by the path generator
    endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()

    # Generate the solution
    # mypy doesnt realize EndpointKwargsType has only string keys: `Keywords must be strings [misc]`
    # TYPING: error: No overload variant of "generate_random_path" of "LatticeMaze" matches argument type "dict[Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | list[tuple[int, int]] | None]"  [call-overload]
    solution: CoordArray | None = maze.generate_random_path(**endpoint_kwargs)  # type: ignore[misc, call-overload]

    # Validate the solution: it must be a non-empty 2-dimensional numpy array
    if (
        solution is None
        or len(solution) == 0
        or not isinstance(solution, np.ndarray)
        # magic value is fine here
        or len(solution.shape) != 2  # noqa: PLR2004
    ):
        return None  # Return None if the solution is invalid

    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=solution,
    )


def _maze_gen_init_worker(config: MazeDatasetConfig) -> None:
    """special worker helper: store `config` globally and seed numpy per worker process

    When called in the main process (empty process identity) only the global
    config is set. When called in a pool worker, numpy's global RNG is seeded
    with the base seed plus the worker's id so that each worker draws a
    different random stream.

    > [!CAUTION]
    > this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad!

    # Raises:
    - `ValueError` : if the process identity tuple has more than one element
    """
    # TODO: dont use globals here!
    global _GLOBAL_WORKER_CONFIG  # noqa: PLW0603
    _GLOBAL_WORKER_CONFIG = config

    process_id: tuple[int, ...] = multiprocessing.current_process()._identity
    if len(process_id) == 0:
        # no multiprocessing, seed was already set
        pass
    elif len(process_id) == 1:
        # multiprocessing, adjust seed based on process id
        # only set numpy seed, since we do not use other random gens
        # NOTE: the base seed is resolved *before* adding the process id.
        # `seed or GLOBAL_SEED + pid` parses as `seed or (GLOBAL_SEED + pid)`,
        # which would hand every worker the identical seed whenever a seed is set.
        base_seed: int = (
            _GLOBAL_WORKER_CONFIG.seed
            if _GLOBAL_WORKER_CONFIG.seed is not None
            else GLOBAL_SEED  # if the seed is None, use the global seed
        )
        np.random.seed(base_seed + process_id[0])
    else:
        err_msg = (
            f"unexpected process id: {process_id = }\n{multiprocessing.Process() = }"
        )
        raise ValueError(
            err_msg,
        )


# TODO: we probably don't need to hash datasets, right?
class MazeDataset(GPTDataset[MazeDatasetConfig]):  # noqa: PLW1641
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`

    # Attributes:
    - `cfg : MazeDatasetConfig` config describing the dataset
    - `mazes : list[SolvedMaze]` the solved mazes
    - `generation_metadata_collected : dict | None` collected generation metadata, if any
    """

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes

        `mazes` is copied into a new list; `cfg` is stored by reference.
        """
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        All arguments are forwarded unchanged to `GPTDataset.from_config`; this
        override only exists to narrow the return type to `MazeDataset`.
        """
        return cast(
            "MazeDataset",
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    def data_hash(self) -> int:
        """return a hash of the data (stable hash of the string form of every serialized maze)"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        """get a maze by index"""
        return self.mazes[i]

    def __iter__(self) -> typing.Iterator[SolvedMaze]:
        """iterate over the mazes"""
        return iter(self.mazes)

    def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
        """deepcopy the dataset by round-tripping it through the full serialization

        `memo` is ignored.

        FIX: this isnt actually a deepcopy I think?
        """
        return MazeDataset.load(self._serialize_full())

    # TYPING: get type hints on the tokenizer here
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[False] = False,
    ) -> list[list[str]]: ...
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[True] = True,
    ) -> list[str]: ...
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        only the first `limit` mazes are tokenized (all of them if `limit` is `None`).

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        """return the number of mazes in the dataset"""
        return len(self.mazes)

    def __eq__(self, other: object) -> bool:
        """compare two datasets: equal iff both the configs and the maze lists are equal

        # Raises:
        - `NotImplementedError` : if `other` is not a `MazeDataset`
        """
        if not isinstance(other, MazeDataset):
            raise NotImplementedError(
                "can only compare with other MazeDataset objects",
            )
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal, with a config diff / maze dump in the failure message"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters

        # Parameters:
        - `cfg : MazeDatasetConfig`
            config to generate from; it is copied and never modified
        - `gen_parallel : bool`
            if `True`, generate in a `multiprocessing.Pool`, otherwise in this process
        - `pool_kwargs : dict | None`
            extra keyword arguments for `multiprocessing.Pool` (parallel mode only)
        - `verbose : bool`
            show a tqdm progress bar
        - `**kwargs`
            ignored

        # Returns:
        - `MazeDataset`
            dataset holding every successfully generated maze; its config copy has
            `n_mazes` set to the number of mazes actually kept
        """
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            # serial path still goes through the global config used by the helper
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json, dispatching on the format string stored under `_FORMAT_KEY`

        # Raises:
        - `KeyError` : if the format string is not one of the recognized `MazeDataset` formats
        """
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_full`"""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_minimal`

        each padded solution is cut back to its recorded length before the maze is rebuilt
        """
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    strict=False,
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_minimal_soln_cat`

        the concatenated solutions array is split back into per-maze solutions
        at the cumulative sums of the recorded solution lengths
        """
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"],
            tuple(),
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"],
            tuple(),
        )
        # the last cumulative sum is the total length, which is not a split point
        maze_solutions = np.split(
            maze_solutions_concat,
            np.cumsum(maze_solution_lengths)[:-1],
            axis=0,
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"],
                tuple(),
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                    strict=False,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            },
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json

        uses the minimal format when `SERIALIZE_MINIMAL_THRESHOLD` is not `None`
        and the dataset has at least that many mazes, otherwise the full format
        """
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        """serialize the config, every maze individually, and the collected metadata"""
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "fname": self.cfg.to_fname(),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            # NOTE(review): `filter_by` is defined outside this file; presumably this
            # returns a dataset with `generation_metadata_collected` populated -- confirm
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        # `default=0` keeps an empty dataset serializable instead of raising in `max()`
        max_solution_len: int = max(
            (m.solution.shape[0] for m in filtered_meta.mazes),
            default=0,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        # zeros, not empty: solutions shorter than the longest one leave padding
        # rows that are never written, and those rows are serialized too
        maze_solutions: np.ndarray = np.zeros(
            (n_mazes, max_solution_len, 2),
            dtype=np.int8,
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            # NOTE(review): `filter_by` is defined outside this file; presumably this
            # returns a dataset with `generation_metadata_collected` populated -- confirm
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = int(np.sum(maze_solution_lengths))

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2),
            dtype=np.int8,
        )

        # every row of the arrays above is written exactly once in this loop
        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)

        emits a warning when `cfg.n_mazes` had to be changed
        """
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method

        # Parameters:
        - `method : Callable[[SolvedMaze], bool]`
            called as `method(maze, **kwargs)`; mazes for which it is truthy are kept
        - `**kwargs`
            passed to `method` and recorded in the new config's `applied_filters`

        # Returns:
        - `MazeDataset`
            a new dataset with a deep-copied config; `self` is not modified
        """
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        # callables such as `functools.partial` objects have no `__name__`
        method_name: str = getattr(method, "__name__", repr(method))
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method_name}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output


# let configs find their dataset class without a circular import at definition time
MazeDatasetConfig._dataset_class = property(  # type: ignore[method-assign, assignment]
    lambda self: MazeDataset,  # noqa: ARG005
)

# register things with zanj
register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDataset")
        ),
        load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),  # type: ignore[misc] # noqa: ARG005
        uid="MazeDataset",
        source_pckg="maze_dataset.generation.maze_dataset",
        desc="MazeDataset",
    ),
)


# TODO: the code below is for doing some smarter collecting and type checking. Probably will delete.
"""
collect either the type at the field, or the shape of the field if it is an array
metadata_types: dict[str, set[type, tuple]] = dict()
for maze in new_dataset:
    for key, value in maze.generation_meta.items():
        if key not in metadata_types:
            metadata_types[key] = set()

        if isinstance(value, np.ndarray):
            metadata_types[key].add(value.shape)
        else:
            metadata_types[key].add(type(value))

# figure out what to do for each field
metadata_actions: dict[str, typing.Callable] = dict()
for key, key_type in metadata_types.items():
    if all(isinstance(kt, tuple) for kt in key_type):
        if all(kt == (2,) for kt in key_type):
            # its all coords, do a statcounter on those coords
            metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
        elif all(
            (len(kt) == 2) and (kt[1] == 2)
            for kt in key_type
        ):
            # its a list of coords, do a statcounter on those coords
            metadata_actions[key] = lambda vals: Counter(
                tuple(x) for x in np.concatenate(vals)
            )
        else:
            # its a list of something else, do a counter on those
            # TODO: throw except here?
            metadata_actions[key] = Counter

    elif all(kt in (bool, int, float) for kt in key_type):
        # statcounter for numeric types
        metadata_actions[key] = StatCounter
    elif all(kt == str for kt in key_type):
        # counter for string types
        metadata_actions[key] = Counter
    else:
        # counter for everything else
        # TODO: throw except here?
        metadata_actions[key] = Counter
"""
class MazeDataset(GPTDataset[MazeDatasetConfig]):  # noqa: PLW1641
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`

    # Attributes:
    - `cfg : MazeDatasetConfig` config describing the dataset
    - `mazes : list[SolvedMaze]` the solved mazes
    - `generation_metadata_collected : dict | None` collected generation metadata, if any
    """

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes

        `mazes` is copied into a new list; `cfg` is stored by reference.
        """
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        All arguments are forwarded unchanged to `GPTDataset.from_config`; this
        override only exists to narrow the return type to `MazeDataset`.
        """
        return cast(
            "MazeDataset",
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    def data_hash(self) -> int:
        """return a hash of the data (stable hash of the string form of every serialized maze)"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        """get a maze by index"""
        return self.mazes[i]

    def __iter__(self) -> typing.Iterator[SolvedMaze]:
        """iterate over the mazes"""
        return iter(self.mazes)

    def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
        """deepcopy the dataset by round-tripping it through the full serialization

        `memo` is ignored.

        FIX: this isnt actually a deepcopy I think?
        """
        return MazeDataset.load(self._serialize_full())

    # TYPING: get type hints on the tokenizer here
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[False] = False,
    ) -> list[list[str]]: ...
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[True] = True,
    ) -> list[str]: ...
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        only the first `limit` mazes are tokenized (all of them if `limit` is `None`).

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        """return the number of mazes in the dataset"""
        return len(self.mazes)

    def __eq__(self, other: object) -> bool:
        """compare two datasets: equal iff both the configs and the maze lists are equal

        # Raises:
        - `NotImplementedError` : if `other` is not a `MazeDataset`
        """
        if not isinstance(other, MazeDataset):
            raise NotImplementedError(
                "can only compare with other MazeDataset objects",
            )
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal, with a config diff / maze dump in the failure message"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters

        # Parameters:
        - `cfg : MazeDatasetConfig`
            config to generate from; it is copied and never modified
        - `gen_parallel : bool`
            if `True`, generate in a `multiprocessing.Pool`, otherwise in this process
        - `pool_kwargs : dict | None`
            extra keyword arguments for `multiprocessing.Pool` (parallel mode only)
        - `verbose : bool`
            show a tqdm progress bar
        - `**kwargs`
            ignored

        # Returns:
        - `MazeDataset`
            dataset holding every successfully generated maze; its config copy has
            `n_mazes` set to the number of mazes actually kept
        """
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            # serial path still goes through the global config used by the helper
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING:  error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json, dispatching on the format string stored under `_FORMAT_KEY`

        # Raises:
        - `KeyError` : if the format string is not one of the recognized `MazeDataset` formats
        """
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_full`"""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_minimal`

        each padded solution is cut back to its recorded length before the maze is rebuilt
        """
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    strict=False,
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        """load a dataset written by `_serialize_minimal_soln_cat`

        the concatenated solutions array is split back into per-maze solutions
        at the cumulative sums of the recorded solution lengths
        """
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"],
            tuple(),
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"],
            tuple(),
        )
        # the last cumulative sum is the total length, which is not a split point
        maze_solutions = np.split(
            maze_solutions_concat,
            np.cumsum(maze_solution_lengths)[:-1],
            axis=0,
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"],
                tuple(),
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                    strict=False,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            },
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json

        uses the minimal format when `SERIALIZE_MINIMAL_THRESHOLD` is not `None`
        and the dataset has at least that many mazes, otherwise the full format
        """
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        """serialize the config, every maze individually, and the collected metadata"""
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "fname": self.cfg.to_fname(),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            # NOTE(review): `filter_by` is defined outside this file; presumably this
            # returns a dataset with `generation_metadata_collected` populated -- confirm
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        # `default=0` keeps an empty dataset serializable instead of raising in `max()`
        max_solution_len: int = max(
            (m.solution.shape[0] for m in filtered_meta.mazes),
            default=0,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        # zeros, not empty: solutions shorter than the longest one leave padding
        # rows that are never written, and those rows are serialized too
        maze_solutions: np.ndarray = np.zeros(
            (n_mazes, max_solution_len, 2),
            dtype=np.int8,
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            # NOTE(review): `filter_by` is defined outside this file; presumably this
            # returns a dataset with `generation_metadata_collected` populated -- confirm
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = int(np.sum(maze_solution_lengths))

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2),
            dtype=np.int8,
        )

        # every row of the arrays above is written exactly once in this loop
        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)

        emits a warning when `cfg.n_mazes` had to be changed
        """
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method

        # Parameters:
        - `method : Callable[[SolvedMaze], bool]`
            called as `method(maze, **kwargs)`; mazes for which it is truthy are kept
        - `**kwargs`
            passed to `method` and recorded in the new config's `applied_filters`

        # Returns:
        - `MazeDataset`
            a new dataset with a deep-copied config; `self` is not modified
        """
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        # callables such as `functools.partial` objects have no `__name__`
        method_name: str = getattr(method, "__name__", repr(method))
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method_name}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config
def __init__(
    self,
    cfg: MazeDatasetConfig,
    mazes: typing.Sequence[SolvedMaze],
    generation_metadata_collected: dict | None = None,
) -> None:
    """build a maze dataset from a config and an already-solved collection of mazes

    - `cfg`: the config stored on the dataset as `self.cfg`
    - `mazes`: any sequence of solved mazes; it is copied into a new list
    - `generation_metadata_collected`: optional metadata dict, stored as-is
      (defaults to `None`)
    """
    super().__init__()
    self.cfg: MazeDatasetConfig = cfg
    # unpack into a fresh list so the dataset never aliases the caller's sequence
    self.mazes: list[SolvedMaze] = [*mazes]
    self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize a maze dataset from a config and a list of solved mazes
@classmethod
def from_config(  # type: ignore[override]
    cls,
    # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
    cfg: MazeDatasetConfig,  # type: ignore[override]
    do_generate: bool = True,
    load_local: bool = True,
    save_local: bool = True,
    zanj: ZANJ | None = None,
    do_download: bool = True,
    local_base_path: Path = Path("data/maze_dataset"),
    except_on_config_mismatch: bool = True,
    allow_generation_metadata_filter_mismatch: bool = True,
    verbose: bool = False,
    **kwargs,
) -> "MazeDataset":
    """create a maze dataset from a config

    priority of loading:
    1. load from local
    2. download
    3. generate

    every argument (including any extra `**kwargs`) is forwarded unchanged to the
    parent class's `from_config`; this override only narrows the return type.
    """
    # the parent implementation does all the work; zero-arg `super()` keeps `cls` bound
    dataset = super().from_config(
        cfg=cfg,
        do_generate=do_generate,
        load_local=load_local,
        save_local=save_local,
        zanj=zanj,
        do_download=do_download,
        local_base_path=local_base_path,
        except_on_config_mismatch=except_on_config_mismatch,
        allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
        verbose=verbose,
        **kwargs,
    )
    # `cast` is a no-op at runtime, it only informs the type checker
    return cast("MazeDataset", dataset)
create a maze dataset from a config
priority of loading:
- load from local
- download
- generate
def data_hash(self) -> int:
    """return a hash of the data

    computed as `stable_hash` of the string form of a tuple holding
    each maze's `serialize()` output, in dataset order.
    """
    serialized_mazes: tuple = tuple(maze.serialize() for maze in self.mazes)
    return stable_hash(str(serialized_mazes))
return a hash of the data
205 def as_tokens( 206 self, 207 maze_tokenizer, # TODO: MazeTokenizer 208 limit: int | None = None, 209 join_tokens_individual_maze: bool = False, 210 ) -> list[list[str]] | list[str]: 211 """return the dataset as tokens according to the passed `maze_tokenizer` 212 213 the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` 214 215 if `join_tokens_individual_maze` is True, then the tokens of each maze are 216 joined with a space, and the result is a list of strings. 217 i.e.: 218 219 >>> dataset.as_tokens(join_tokens_individual_maze=False) 220 [["a", "b", "c"], ["d", "e", "f"]] 221 >>> dataset.as_tokens(join_tokens_individual_maze=True) 222 ["a b c", "d e f"] 223 """ 224 output: list[list[str]] = [ 225 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 226 ] 227 if join_tokens_individual_maze: 228 return [" ".join(tokens) for tokens in output] 229 else: 230 return output
return the dataset as tokens according to the passed maze_tokenizer
The `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`.
If `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings.
i.e.:
>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=True)
["a b c", "d e f"]
def assert_equal(self, other: "MazeDataset") -> None:
    """assert that two datasets are equal

    raises `AssertionError` on the first check that fails:
    `other` is not a `MazeDataset`, the configs differ (the message shows
    `self.cfg.diff(other.cfg)`), or the maze lists differ (the message shows both lists).
    `generation_metadata_collected` is not compared.
    """
    # type check first, so the attribute accesses below are meaningful
    assert isinstance(other, MazeDataset)
    # the f-string message is only evaluated when the comparison fails
    assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
    assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
assert that two datasets are equal
@classmethod
def generate(
    cls,
    cfg: MazeDatasetConfig,
    gen_parallel: bool = False,
    pool_kwargs: dict | None = None,
    verbose: bool = False,
    # TODO: what to do when unexpected kwargs are passed?
    **kwargs,  # noqa: ARG003
) -> "MazeDataset":
    """Generate a maze dataset given a config and some generation parameters

    - `cfg`: generation config; it is copied, the caller's object is not modified
    - `gen_parallel`: generate in a `multiprocessing.Pool` instead of in-process
    - `pool_kwargs`: extra keyword arguments for `multiprocessing.Pool`
      (only used when `gen_parallel` is True)
    - `verbose`: show a tqdm progress bar
    - `**kwargs`: accepted and ignored

    mazes for which the helper returns `None` are dropped, so the resulting
    dataset (and its config's `n_mazes`) may be smaller than requested.
    before returning, numpy's global RNG is re-seeded with the config's `seed`.
    """
    # a JSON round trip yields an independent config, leaving the original untouched
    cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
        json.loads(json.dumps(cfg.serialize())),
    )

    worker_pool_kwargs: dict = {} if pool_kwargs is None else pool_kwargs
    maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

    # shared progress-bar settings for both code paths
    progress_kwargs: dict = {
        "total": cfg_cpy.n_mazes,
        "unit": "maze",
        "desc": "generating & solving mazes",
        "disable": not verbose,
    }

    solved_mazes: list[SolvedMaze | None]
    # TODO: don't use the global unless generating in parallel!
    if not gen_parallel:
        # serial path: install the worker config in this process, then map in-process
        _maze_gen_init_worker(cfg_cpy)
        solved_mazes = list(
            tqdm.tqdm(
                map(
                    # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
                    # why does it think tolist() returns a string?
                    _generate_maze_helper,  # type: ignore[arg-type]
                    maze_indexes.tolist(),
                ),
                **progress_kwargs,
            ),
        )
    else:
        # parallel path: every worker process runs the initializer with the config copy
        with multiprocessing.Pool(
            **worker_pool_kwargs,
            initializer=_maze_gen_init_worker,
            initargs=(cfg_cpy,),
        ) as pool:
            # results must be fully collected while the pool is still open
            solved_mazes = list(
                tqdm.tqdm(
                    pool.imap(_generate_maze_helper, maze_indexes),
                    **progress_kwargs,
                ),
            )

    # drop failed generations only after every result has been collected
    valid_mazes: list[SolvedMaze] = [
        candidate for candidate in solved_mazes if candidate is not None
    ]

    # record how many mazes were actually produced
    cfg_cpy.n_mazes = len(valid_mazes)

    dataset: MazeDataset = cls(
        cfg=cfg_cpy,
        mazes=valid_mazes,
    )

    # keep the dataset's config in sync with its contents
    dataset.update_self_config()

    # reset the seed to the value in the config copy
    np.random.seed(cfg_cpy.seed)

    return dataset
Generate a maze dataset given a config and some generation parameters
@classmethod
def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
    """(not implemented yet!) download a maze dataset from the internet

    always raises `NotImplementedError`; `cfg` and `**kwargs` are not used.
    """
    err_msg: str = "not implemented yet"
    raise NotImplementedError(err_msg)
(not implemented yet!) download a maze dataset from the internet
@classmethod
def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
    """load from zanj/json

    dispatches on the format tag stored under `_FORMAT_KEY` in `data`:

    - `"MazeDataset:minimal"` -> `_load_minimal`
    - `"MazeDataset:minimal_soln_cat"` -> `_load_minimal_soln_cat`
    - `"MazeDataset"` -> `_load_full`, or `_load_legacy` when
      `SERIALIZE_MINIMAL_THRESHOLD == -1`

    raises `KeyError` for any other format tag.
    """
    fmt = data[_FORMAT_KEY]

    if fmt == "MazeDataset:minimal":
        return cls._load_minimal(data)

    if fmt == "MazeDataset:minimal_soln_cat":
        return cls._load_minimal_soln_cat(data)

    if fmt == "MazeDataset":
        # Allow access to `_load_legacy` for profiling
        use_legacy_loader: bool = SERIALIZE_MINIMAL_THRESHOLD == -1
        return cls._load_legacy(data) if use_legacy_loader else cls._load_full(data)

    # nothing matched: the data was not written by any known `MazeDataset` serializer
    err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
    raise KeyError(
        err_msg,
    )
load from zanj/json
def serialize(self) -> JSONdict:
    """serialize to zanj/json

    uses `_serialize_minimal` when `SERIALIZE_MINIMAL_THRESHOLD` is set (not `None`)
    and the dataset holds at least that many mazes; otherwise `_serialize_full`.
    """
    # short-circuit keeps the comparison from running against `None`
    use_minimal: bool = (
        SERIALIZE_MINIMAL_THRESHOLD is not None
        and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
    )
    if use_minimal:
        return self._serialize_minimal()
    return self._serialize_full()
serialize to zanj/json
def update_self_config(self) -> None:
    """sync `self.cfg.n_mazes` with the number of mazes actually held (e.g. after filtering)

    emits a warning and overwrites `cfg.n_mazes` when the two disagree; does nothing otherwise.
    """
    actual_count: int = len(self.mazes)
    if self.cfg.n_mazes == actual_count:
        # already consistent, nothing to update
        return

    warnings.warn(
        f"updating config n_mazes from {self.cfg.n_mazes} to {actual_count}",
    )
    self.cfg.n_mazes = actual_count
update the config to match the current state of the dataset (number of mazes, such as after filtering)
def custom_maze_filter(
    self,
    method: typing.Callable[[SolvedMaze], bool],
    **kwargs,
) -> "MazeDataset":
    """filter the dataset using a custom method

    - `method`: called as `method(maze, **kwargs)` for every maze; mazes for which
      it returns a truthy value are kept
    - `**kwargs`: forwarded to `method` and recorded alongside the filter name

    returns a new `MazeDataset` with a deep copy of this dataset's config, to which
    an entry `{"name": "__custom__:<method name>", "kwargs": kwargs}` is appended in
    `applied_filters`. the new config's `n_mazes` is updated to the filtered count.
    `self` is not modified.
    """
    # `functools.partial` objects and callable instances have no `__name__`;
    # fall back to the type name (stable across runs) instead of raising `AttributeError`
    method_name: str = getattr(method, "__name__", type(method).__name__)

    output: MazeDataset = MazeDataset(
        cfg=copy.deepcopy(self.cfg),
        mazes=[m for m in self.mazes if method(m, **kwargs)],
    )
    output.cfg.applied_filters.append(
        {
            "name": f"__custom__:{method_name}",
            "kwargs": kwargs,
        },
    )
    # filtering usually changes the number of mazes; keep the copied config in sync
    output.update_self_config()
    return output
filter the dataset using a custom method