maze_dataset.dataset
`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`.

When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class.
Furthermore, a dataset of mazes can be filtered to satisfy certain properties:

```python
dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3)
```

Custom filters can be specified, and several filters are included:

- `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`.
- `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls.
- `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance.
- `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset.
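To make the filter criteria concrete, the following standalone sketch reimplements two of them on plain tuples instead of `MazeDataset` objects. The helper names (`manhattan_distance`, `hamming_distance`, `remove_near_duplicates`), the `max_differences` parameter, and the flattened boolean layout are all hypothetical; the library's own filters may differ in detail.

```python
from typing import Sequence

Coord = tuple[int, int]


def manhattan_distance(start: Coord, end: Coord) -> int:
    """Grid distance between two cells, ignoring walls."""
    return abs(start[0] - end[0]) + abs(start[1] - end[1])


def hamming_distance(a: Sequence[bool], b: Sequence[bool]) -> int:
    """Number of positions at which two equal-length layouts differ."""
    if len(a) != len(b):
        raise ValueError("layouts must have the same length")
    return sum(x != y for x, y in zip(a, b))


def remove_near_duplicates(
    layouts: list[Sequence[bool]],
    max_differences: int,
) -> list[Sequence[bool]]:
    """Keep a layout only if it differs from every kept layout in more than `max_differences` positions."""
    kept: list[Sequence[bool]] = []
    for layout in layouts:
        if all(hamming_distance(layout, k) > max_differences for k in kept):
            kept.append(layout)
    return kept


# toy data: (start, end) pairs and flattened connection flags (assumed representation)
endpoints: list[tuple[Coord, Coord]] = [((0, 0), (0, 1)), ((0, 0), (2, 3))]
far_apart = [pair for pair in endpoints if manhattan_distance(*pair) >= 3]

layouts = [
    (True, False, True, True),
    (True, False, True, False),  # differs from the first in one position
    (False, True, False, False),
]
unique_enough = remove_near_duplicates(layouts, max_differences=1)
```

With `max_differences=0` the same helper only removes exact duplicates, which is the cheaper check that `remove_duplicates_fast()` is described as performing.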
All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, even when generating large datasets.
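The effect of a fixed seed can be illustrated without the library. The sketch below is an assumption-laden stand-in: `generate_layouts` and `count_exact_duplicates` are hypothetical helpers that draw random boolean layouts with the standard `random` module, not the library's maze generators.

```python
import random
from collections import Counter


def generate_layouts(seed: int, n_layouts: int, n_cells: int) -> list[tuple[bool, ...]]:
    """Draw `n_layouts` random boolean layouts from a generator seeded with `seed`."""
    rng = random.Random(seed)  # private generator, so global random state is untouched
    return [
        tuple(rng.random() < 0.5 for _ in range(n_cells))
        for _ in range(n_layouts)
    ]


def count_exact_duplicates(layouts: list[tuple[bool, ...]]) -> int:
    """Number of layouts that repeat an earlier one exactly."""
    counts = Counter(layouts)
    return sum(c - 1 for c in counts.values())


run_a = generate_layouts(seed=42, n_layouts=100, n_cells=50)
run_b = generate_layouts(seed=42, n_layouts=100, n_cells=50)  # same seed, same layouts
run_c = generate_layouts(seed=43, n_layouts=100, n_cells=50)  # different seed
```

Because the number of possible layouts grows exponentially with the number of cells, exact repeats become unlikely very quickly, which is consistent with the observation above about duplicate mazes.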
```python
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.dataset.maze_dataset import MazeDataset
from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig

__all__ = [
    # submodules
    "collected_dataset",
    "configs",
    "dataset",
    "filters",
    "maze_dataset_config",
    "maze_dataset",
    "rasterized",
    "success_predict_math",
    # dataset classes
    "MazeDataset",
    "MazeDatasetConfig",
    "MazeDatasetCollection",
    "MazeDatasetCollectionConfig",
]
```
```python
class MazeDataset(GPTDataset[MazeDatasetConfig]):
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override]
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            MazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        """get a maze by index"""
        return self.mazes[i]

    def __iter__(self) -> typing.Iterator[SolvedMaze]:
        """iterate over the mazes"""
        return iter(self.mazes)

    def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
        """deepcopy the dataset

        FIX: this isnt actually a deepcopy I think?
        """
        return MazeDataset.load(self._serialize_full())

    # TYPING: get type hints on the tokenizer here
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[False] = False,
    ) -> list[list[str]]: ...
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[True] = True,
    ) -> list[str]: ...
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        """return the number of mazes in the dataset"""
        return len(self.mazes)

    def __eq__(self, other: object) -> bool:
        """compare two datasets"""
        if not isinstance(other, MazeDataset):
            raise NotImplementedError(
                "can only compare with other MazeDataset objects",
            )
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    strict=False,
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"],
            tuple(),
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"],
            tuple(),
        )
        maze_solutions = np.split(
            maze_solutions_concat,
            np.cumsum(maze_solution_lengths)[:-1],
            axis=0,
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"],
                tuple(),
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                    strict=False,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            },
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "fname": self.cfg.to_fname(),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        maze_solutions: np.ndarray = np.empty(
            (n_mazes, max_solution_len, 2),
            dtype=np.int8,
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = np.sum(maze_solution_lengths)

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2),
            dtype=np.int8,
        )

        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solution_lengths[idx] = soln_len
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
```
a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config
```python
    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected
```
initialize a maze dataset from a config and a list of solved mazes
```python
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            MazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )
```
create a maze dataset from a config
priority of loading:
1. load from local
2. download
3. generate
```python
    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
```
return a hash of the data
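Python's built-in `hash()` for strings is randomized per process, so a hash that is meant to be compared across runs needs a deterministic digest. The sketch below shows the general idea with `hashlib`; `stable_hash_sketch` and `data_hash_sketch` are hypothetical names, and the choice of SHA-256 is an assumption, since the `stable_hash` helper used above may rely on a different digest.

```python
import hashlib


def stable_hash_sketch(text: str) -> int:
    """Deterministic integer hash of a string, independent of PYTHONHASHSEED."""
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return int.from_bytes(digest, byteorder="big")


def data_hash_sketch(serialized_items: list[dict]) -> int:
    """Hash the string form of a tuple of serialized items, mirroring the shape of `data_hash`."""
    return stable_hash_sketch(str(tuple(serialized_items)))


# toy serialized mazes (assumed structure, for illustration only)
items = [{"solution": [[0, 0], [0, 1]]}, {"solution": [[1, 1], [1, 0]]}]
hash_a = data_hash_sketch(items)
hash_b = data_hash_sketch(list(items))  # same content, same hash
hash_reordered = data_hash_sketch(list(reversed(items)))  # order matters
```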
```python
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output
```
return the dataset as tokens according to the passed `maze_tokenizer`

the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

if `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
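The post-processing step can be tried in isolation on already-tokenized mazes. In this sketch `tokens_output` is a hypothetical stand-in that takes the token lists directly instead of a dataset and tokenizer.

```python
from typing import Optional, Union


def tokens_output(
    token_lists: list[list[str]],
    limit: Optional[int] = None,
    join_tokens_individual_maze: bool = False,
) -> Union[list[list[str]], list[str]]:
    """Return per-maze token lists, optionally joined into one string per maze."""
    selected = token_lists[:limit]  # a limit of None keeps every maze
    if join_tokens_individual_maze:
        return [" ".join(tokens) for tokens in selected]
    return selected


tokenized = [["a", "b", "c"], ["d", "e", "f"]]
as_lists = tokens_output(tokenized)
as_strings = tokens_output(tokenized, join_tokens_individual_maze=True)
first_only = tokens_output(tokenized, limit=1, join_tokens_individual_maze=True)
```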
```python
    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
```
assert that two datasets are equal
```python
    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset
```
Generate a maze dataset given a config and some generation parameters
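The overall shape of `generate` (initialize per-worker state, map a helper over the maze indices, then drop failed results) can be reproduced in a few lines of plain Python. Everything in this sketch is a hypothetical stand-in: `init_worker`, `generate_one`, and `generate_all` are not library functions, and seeding from `seed + index` is an assumption about how per-maze determinism could be obtained.

```python
import random
from typing import Optional

_WORKER_SEED: Optional[int] = None  # per-process state set by the initializer


def init_worker(seed: int) -> None:
    """Store shared settings once per worker (or once in the main process when running serially)."""
    global _WORKER_SEED
    _WORKER_SEED = seed


def generate_one(index: int) -> Optional[str]:
    """Generate item `index`, or return None when generation fails for that index."""
    assert _WORKER_SEED is not None, "init_worker must be called first"
    rng = random.Random(_WORKER_SEED + index)  # depends on the index, not on call order
    if index % 4 == 3:
        return None  # stand-in for a maze that could not be generated
    return f"maze_{index}_{rng.randrange(1000)}"


def generate_all(seed: int, n_items: int) -> list[str]:
    """Serial generation: initialize, map over indices, then drop failed items."""
    init_worker(seed)
    results = list(map(generate_one, range(n_items)))
    return [item for item in results if item is not None]


dataset = generate_all(seed=0, n_items=8)  # indices 3 and 7 fail, 6 items remain
```

For the parallel branch, the same two helpers could be handed to `multiprocessing.Pool(initializer=init_worker, initargs=(seed,))` and `pool.imap`, which is the pattern the source above follows. Because some items may be dropped, the final count can be smaller than requested, which is why `generate` updates `n_mazes` afterwards.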
```python
    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")
```
(not implemented yet!) download a maze dataset from the internet
```python
    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )
```
load from zanj/json
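`load` is essentially a dispatch on a format tag stored in the serialized data. A table-driven version of that idea is sketched below; the key name `format_tag`, the loader functions, and the toy payloads are hypothetical, since the value of `_FORMAT_KEY` and the real payload layout are not shown here.

```python
from typing import Any, Callable

FORMAT_KEY = "format_tag"  # hypothetical key name standing in for `_FORMAT_KEY`


def _load_full(data: dict[str, Any]) -> list[Any]:
    """Full format: items are stored one by one."""
    return list(data["mazes"])


def _load_minimal(data: dict[str, Any]) -> list[Any]:
    """Minimal format: trim each padded solution back to its recorded length."""
    return [
        soln[:length]
        for soln, length in zip(data["maze_solutions"], data["maze_solution_lengths"])
    ]


LOADERS: dict[str, Callable[[dict[str, Any]], list[Any]]] = {
    "MazeDataset": _load_full,
    "MazeDataset:minimal": _load_minimal,
}


def load(data: dict[str, Any]) -> list[Any]:
    """Pick a loader from the format tag; unknown tags raise KeyError."""
    fmt = data[FORMAT_KEY]
    if fmt not in LOADERS:
        raise KeyError(f"{fmt!r} is not a recognized format")
    return LOADERS[fmt](data)


full = load({FORMAT_KEY: "MazeDataset", "mazes": [[1, 2], [3]]})
minimal = load(
    {
        FORMAT_KEY: "MazeDataset:minimal",
        "maze_solutions": [[1, 2, 0], [3, 0, 0]],  # padded to a common length
        "maze_solution_lengths": [2, 1],
    },
)
```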
```python
    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()
```
serialize to zanj/json
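Two ideas appear in the serialization code: a size threshold that selects the compact format, and storing variable-length solutions as one flat sequence plus a list of lengths. The pure-Python sketch below illustrates both. The function names are hypothetical, the threshold is passed in as a parameter because the value of `SERIALIZE_MINIMAL_THRESHOLD` is not shown here, and the library itself works on NumPy arrays rather than lists.

```python
from itertools import accumulate
from typing import Optional

Solution = list[tuple[int, int]]


def choose_format(n_mazes: int, minimal_threshold: Optional[int]) -> str:
    """Use the compact format once the dataset reaches the threshold (None disables it)."""
    if minimal_threshold is not None and n_mazes >= minimal_threshold:
        return "MazeDataset:minimal"
    return "MazeDataset"


def concat_solutions(solutions: list[Solution]) -> tuple[list[int], Solution]:
    """Store all solutions in one flat list plus a list of per-solution lengths."""
    lengths = [len(s) for s in solutions]
    flat = [step for s in solutions for step in s]
    return lengths, flat


def split_solutions(lengths: list[int], flat: Solution) -> list[Solution]:
    """Invert `concat_solutions` by cutting `flat` at the cumulative lengths."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [flat[a:b] for a, b in zip(starts, ends)]


solutions = [[(0, 0), (0, 1)], [(1, 1)], [(2, 2), (2, 1), (2, 0)]]
lengths, flat = concat_solutions(solutions)
restored = split_solutions(lengths, flat)
```

Cutting at the cumulative lengths plays the same role as the `np.split(..., np.cumsum(maze_solution_lengths)[:-1])` call in `_load_minimal_soln_cat`.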
```python
    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)
```
update the config to match the current state of the dataset (number of mazes, such as after filtering)
```python
    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
```
filter the dataset using a custom method
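The behavior of `custom_maze_filter` (a new dataset, a copied config, and a record of the applied filter) can be mimicked with small stand-in classes. `ToyConfig`, `ToyDataset`, and the predicate `solution_at_most` below are hypothetical; they only mirror the fields the method touches.

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyConfig:
    n_mazes: int
    applied_filters: list[dict[str, Any]] = field(default_factory=list)


@dataclass
class ToyDataset:
    cfg: ToyConfig
    mazes: list[list[tuple[int, int]]]  # each maze is represented by its solution path only

    def custom_maze_filter(self, method: Callable[..., bool], **kwargs: Any) -> "ToyDataset":
        """Return a new dataset with the mazes for which `method(maze, **kwargs)` is true."""
        output = ToyDataset(
            cfg=copy.deepcopy(self.cfg),  # the original config is left untouched
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        # record what was applied so the filtered dataset describes itself
        output.cfg.applied_filters.append(
            {"name": f"__custom__:{method.__name__}", "kwargs": kwargs},
        )
        output.cfg.n_mazes = len(output.mazes)
        return output


def solution_at_most(maze: list[tuple[int, int]], max_length: int) -> bool:
    """Hypothetical predicate: keep mazes whose solution has at most `max_length` steps."""
    return len(maze) <= max_length


original = ToyDataset(
    cfg=ToyConfig(n_mazes=3),
    mazes=[[(0, 0)], [(0, 0), (0, 1)], [(0, 0), (0, 1), (1, 1)]],
)
short_only = original.custom_maze_filter(solution_at_most, max_length=2)
```

Extra keyword arguments are forwarded to the predicate and stored alongside its name, so a named function is more informative here than a lambda, whose `__name__` is just `<lambda>`.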
```python
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset

    # Parameters:
    - `name : str`
        name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
    - `grid_n : int`
        grid size of the maze (number of rows/columns)
    - `n_mazes : int`
        number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
        see `EndpointKwargsType` for more details.
    - `maze_ctor : Callable`
        maze generator function. This should be a function that takes a grid size and returns a maze.
        This will usually be one of the functions in `LatticeMazeGenerators`.
    - `maze_ctor_kwargs : dict`
        keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
    - `endpoint_kwargs : EndpointKwargsType`
        keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
    - `applied_filters : list[dict]`
        list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
        but these are stored with the config in case you want to re-generate the dataset with the same filters.

    """

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

        used in predicting the success rate
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didnt train on this, but if either endpoint is not required to be in a dead end
            # then requiring the endpoints to be unique does not really affect the success rate
            # (except for very small percolation values, pure percolation generation)
            endpoints_unique_flag = 0

        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            Config corresponding to `arr`
        """
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail.
            if `False` (the default), return `1.0` in that case

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
        """
        try:
            return cfg_success_predict_fn(self)

        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
        - `except_if_all_success_expected : bool`
            if `True`, raise an error when the config is not expected to fail.
            this is passed to `MazeDatasetConfig.success_fraction_estimate`.
            if your config isn't expected to fail, passing this might mean you generate more mazes than needed
            since `safety_margin` is still applied.
```
466 (defaults to `False`) 467 - `epsilon : float` 468 raise `SuccessChanceTooSmallError` if the success fraction is below this threshold 469 (defaults to `1e-2`) 470 471 # Returns: 472 - `MazeDatasetConfig` 473 new config with adjusted `n_mazes` 474 475 # Raises: 476 - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` 477 """ 478 # compute and check the success fraction 479 success_fraction: float = self.success_fraction_estimate( 480 except_if_all_success_expected=except_if_all_success_expected, 481 ) 482 if success_fraction < epsilon: 483 err_msg: str = ( 484 f"{success_fraction = } is below the threshold of {epsilon = }" 485 ) 486 raise SuccessChanceTooSmallError( 487 err_msg, 488 ) 489 490 # compute the new number of mazes 491 n_mazes: int = self.n_mazes 492 new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 493 494 # put it in a new config and return 495 cfg_dict: dict = self.serialize() 496 cfg_dict["n_mazes"] = new_n_mazes 497 return MazeDatasetConfig.load(cfg_dict)
config object which is passed to `MazeDataset.from_config` to generate or load a dataset

Parameters:
- `name : str`
  name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
- `grid_n : int`
  grid size of the maze (number of rows/columns)
- `n_mazes : int`
  number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. See `EndpointKwargsType` for more details.
- `maze_ctor : Callable`
  maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in `LatticeMazeGenerators`.
- `maze_ctor_kwargs : dict`
  keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
- `endpoint_kwargs : EndpointKwargsType`
  keyword arguments passed to `LatticeMaze.generate_random_path()`. See `EndpointKwargsType` for more info.
- `applied_filters : list[dict]`
  list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
```python
@property
def config_version(self) -> str:
    """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
    return "1.0"
```
return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config
```python
@property
def versions(self) -> dict:
    """return the versions of the config and the maze_dataset"""
    return dict(
        config=self.config_version,
        maze_dataset=importlib.metadata.version("maze_dataset"),
    )
```
return the versions of the config and the maze_dataset
```python
def serialize(self) -> dict:
    "serialize the MazeDatasetConfig with all fields and fname"
    return {
        **self._serialize_base(
            applied_filters__skip__collect_generation_meta=False
        ),
        "fname": self.to_fname(),
        "versions": self.versions,
    }
```
serialize the MazeDatasetConfig with all fields and fname
```python
def summary(self) -> dict:
    """return a summary of the config"""
    # do we run this to make sure it doesn't error?
    super_summary: dict = super().summary()
    assert super_summary
    self_ser: dict = self.serialize()
    return dict(
        name=self.name,
        fname=self.to_fname(),
        sdc_hash=self.stable_hash_cfg(),
        seed=self.seed,
        seq_len_min=self.seq_len_min,
        seq_len_max=self.seq_len_max,
        applied_filters=self.applied_filters,
        grid_n=self_ser["grid_n"],
        n_mazes=self_ser["n_mazes"],
        maze_ctor_name=self_ser["maze_ctor"]["__name__"],
        maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
        endpoint_kwargs=self_ser["endpoint_kwargs"],
    )
```
return a summary of the config
```python
def success_fraction_estimate(
    self,
    except_if_all_success_expected: bool = False,
) -> float:
    """Estimate the success fraction of this config.

    only valid when the generator is a percolation generator,
    and endpoints are enforced to be dead ends

    more information on where this comes from can be found in
    - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
    - `estimate_dataset_fractions.ipynb`
    - `maze_dataset.benchmarks.sweep_fit`

    # Parameters:
    - `except_if_all_success_expected : bool`
        if `True`, re-raise `NoPercolationInConfigError` when the config is not a valid percolation config
        (and so is not expected to fail). if `False`, return `1.0` in that case instead.
        (defaults to `False`)

    # Returns:
    - `float`
        estimated success fraction

    # Raises:
    - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
    """
    try:
        return cfg_success_predict_fn(self)

    except NoPercolationInConfigError as e:
        if except_if_all_success_expected:
            raise e  # noqa: TRY201
        return 1.0
```
Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `except_if_all_success_expected : bool`
  if `True`, re-raise `NoPercolationInConfigError` when the config is not a valid percolation config (and so is not expected to fail). if `False`, return `1.0` in that case instead. (defaults to `False`)

Returns:
- `float`
  estimated success fraction

Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
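The fallback behaviour of this method can be sketched without the library. This is a minimal illustration under stated assumptions, not the library's implementation: `NotPercolationError`, `predict_success`, and `estimate_success_fraction` are hypothetical stand-ins for `NoPercolationInConfigError`, `cfg_success_predict_fn()`, and `success_fraction_estimate()`, and the predictor body is a placeholder rather than the fitted model.

```python
class NotPercolationError(ValueError):
    """Hypothetical stand-in for `NoPercolationInConfigError`."""


def predict_success(cfg: dict) -> float:
    """Placeholder predictor: only configs with a percolation value `p` are supported."""
    if "p" not in cfg:
        raise NotPercolationError(f"no percolation value in {cfg = }")
    # NOT the library's fitted model, just some value in [0, 1] for illustration
    return max(0.0, min(1.0, 1.0 - cfg["p"]))


def estimate_success_fraction(
    cfg: dict,
    except_if_all_success_expected: bool = False,
) -> float:
    """Return the predicted success fraction, falling back to 1.0 for unsupported configs."""
    try:
        return predict_success(cfg)
    except NotPercolationError:
        if except_if_all_success_expected:
            # caller asked to be told when the predictor does not apply
            raise
        # otherwise assume every maze generates successfully
        return 1.0
```

With the default flag, a config the predictor cannot handle is silently treated as always succeeding. Passing `except_if_all_success_expected=True` surfaces that case as an exception instead.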
```python
def success_fraction_compensate(
    self,
    safety_margin: float = 1.2,
    except_if_all_success_expected: bool = False,
    epsilon: float = 1e-2,
) -> "MazeDatasetConfig":
    """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

    calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
    computes the new number of mazes as `n_mazes = int(n_mazes * safety_margin / success_fraction) + 1`

    more information on where this comes from can be found in
    - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
    - `estimate_dataset_fractions.ipynb`
    - `maze_dataset.benchmarks.sweep_fit`

    # Parameters:
    - `safety_margin : float`
        safety margin to apply to the success fraction estimate
        (defaults to `1.2`, or 20% more mazes than estimated)
    - `except_if_all_success_expected : bool`
        if `True`, raise `NoPercolationInConfigError` when the config is not a valid percolation config
        (and so is not expected to fail).
        this is passed to `MazeDatasetConfig.success_fraction_estimate`.
        if `False` and your config isn't expected to fail, the success fraction is taken to be `1.0`,
        so you might generate more mazes than needed since `safety_margin` is still applied.
        (defaults to `False`)
    - `epsilon : float`
        raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
        (defaults to `1e-2`)

    # Returns:
    - `MazeDatasetConfig`
        new config with adjusted `n_mazes`

    # Raises:
    - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
    """
    # compute and check the success fraction
    success_fraction: float = self.success_fraction_estimate(
        except_if_all_success_expected=except_if_all_success_expected,
    )
    if success_fraction < epsilon:
        err_msg: str = (
            f"{success_fraction = } is below the threshold of {epsilon = }"
        )
        raise SuccessChanceTooSmallError(
            err_msg,
        )

    # compute the new number of mazes
    n_mazes: int = self.n_mazes
    new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

    # put it in a new config and return
    cfg_dict: dict = self.serialize()
    cfg_dict["n_mazes"] = new_n_mazes
    return MazeDatasetConfig.load(cfg_dict)
```
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then computes the new number of mazes as `n_mazes = int(n_mazes * safety_margin / success_fraction) + 1`

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `safety_margin : float`
  safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
  if `True`, raise `NoPercolationInConfigError` when the config is not a valid percolation config (and so is not expected to fail). this is passed to `MazeDatasetConfig.success_fraction_estimate`. if `False` and your config isn't expected to fail, the success fraction is taken to be `1.0`, so you might generate more mazes than needed since `safety_margin` is still applied. (defaults to `False`)
- `epsilon : float`
  raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`)

Returns:
- `MazeDatasetConfig`
  new config with adjusted `n_mazes`

Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
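The arithmetic behind the adjusted `n_mazes` can be checked in isolation. The sketch below reproduces only the formula and the `epsilon` check from the source listing. The function name `compensated_n_mazes` is hypothetical, and a plain `ValueError` stands in for `SuccessChanceTooSmallError`.

```python
def compensated_n_mazes(
    n_mazes: int,
    success_fraction: float,
    safety_margin: float = 1.2,
    epsilon: float = 1e-2,
) -> int:
    """Number of mazes to request so that roughly `n_mazes` survive generation."""
    if success_fraction < epsilon:
        # stand-in for `SuccessChanceTooSmallError`
        raise ValueError(
            f"{success_fraction = } is below the threshold of {epsilon = }"
        )
    # scale up by the safety margin, divide by the expected success rate,
    # truncate to an integer, and add one so the result never rounds down to too few
    return int((n_mazes * safety_margin) / success_fraction) + 1


if __name__ == "__main__":
    print(compensated_n_mazes(100, 0.5, safety_margin=1.5))
```

Because of the trailing `+ 1`, even a success fraction of `1.0` with `safety_margin=1.0` requests one more maze than the original `n_mazes`.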
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- `maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base`
  - `grid_n`
  - `n_mazes`
  - `maze_ctor`
  - `maze_ctor_kwargs`
  - `endpoint_kwargs`
  - `grid_shape`
  - `grid_shape_np`
  - `max_grid_n`
  - `stable_hash_cfg`
  - `to_fname`
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
  - `validate_field_type`
  - `diff`
  - `update_from_nested_dict`
```python
class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets"""

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )

    def __len__(self) -> int:
        """return the total number of mazes in the collection"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int) -> LatticeMaze:
        "get a maze by index"
        # find which dataset the index belongs to
        # we search for `index + 1`, since np.searchsorted (with the default `side="left"`)
        # returns the first position whose cumulative length is >= the target,
        # while we want the first dataset whose cumulative length is strictly greater than `index`
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        index_adjusted: int = index
        if dataset_idx > 0:
            # if `dataset_idx` is 0, `dataset_idx - 1` would be -1,
            # so for the first dataset we just use the base index
            index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][index_adjusted]

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
```
a collection of maze datasets
```python
def __init__(
    self,
    cfg: MazeDatasetCollectionConfig,
    maze_datasets: list[MazeDataset],
    generation_metadata_collected: dict | None = None,
) -> None:
    "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
    super().__init__()
    self.cfg: MazeDatasetCollectionConfig = cfg
    self.maze_datasets: list[MazeDataset] = list(maze_datasets)
    for c, ds in zip(
        self.cfg.maze_dataset_configs,
        self.maze_datasets,
        strict=False,
    ):
        assert c.name == ds.cfg.name
        assert c == ds.cfg

    self.generation_metadata_collected: dict | None = generation_metadata_collected
```
initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s
```python
@property
def dataset_lengths(self) -> list[int]:
    """return the lengths of each dataset in the collection"""
    return [len(dataset) for dataset in self.maze_datasets]
```
return the lengths of each dataset in the collection
```python
@property
def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
    """return the cumulative lengths of each dataset in the collection"""
    return np.array(list(itertools.accumulate(self.dataset_lengths)))
```
return the cumulative lengths of each dataset in the collection
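The cumulative lengths are what allow a flat index into the collection to be mapped to a particular dataset and a position within it. The sketch below shows that mapping using only the standard library. It is an illustration, not the library's code: `locate` is a hypothetical helper, `bisect.bisect_right(cum, index)` is used as an integer-equivalent stand-in for the `np.searchsorted(cum, index + 1)` call in `__getitem__`, and the explicit bounds check is an addition.

```python
import bisect
import itertools


def locate(dataset_lengths: list[int], index: int) -> tuple[int, int]:
    """Map a flat `index` to `(dataset_idx, index_within_dataset)`."""
    cum: list[int] = list(itertools.accumulate(dataset_lengths))
    total: int = cum[-1] if cum else 0
    if not 0 <= index < total:
        raise IndexError(f"{index = } out of range for {total = }")
    # first dataset whose cumulative length is strictly greater than `index`
    dataset_idx: int = bisect.bisect_right(cum, index)
    # subtract the number of mazes in all earlier datasets (none for the first one)
    offset: int = cum[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, index - offset


if __name__ == "__main__":
    # three datasets holding 3, 2, and 4 mazes
    for i in range(9):
        print(i, locate([3, 2, 4], i))
```

For integer targets, "first cumulative length `>= index + 1`" and "first cumulative length `> index`" select the same position, which is why the two search calls agree.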
```python
@cached_property
def mazes(self) -> list[LatticeMaze]:
    "single list of all mazes in the collection"
    return list(
        itertools.chain.from_iterable(
            dataset.mazes for dataset in self.maze_datasets
        ),
    )
```
single list of all mazes in the collection
```python
@classmethod
def generate(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
    """generate a dataset collection from a config"""
    datasets = [
        MazeDataset.generate(config, **kwargs)
        for config in cfg.maze_dataset_configs
    ]
    return cls(cfg, datasets)
```
generate a dataset collection from a config
```python
@classmethod
def download(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
    "(not implemented!) download a dataset collection from a config"
    datasets = [
        MazeDataset.download(config, **kwargs)
        for config in cfg.maze_dataset_configs
    ]
    return cls(cfg, datasets)
```
(not implemented!) download a dataset collection from a config
```python
def serialize(self) -> JSONdict:
    """serialize the dataset collection"""
    return {
        _FORMAT_KEY: "MazeDatasetCollection",
        "cfg": self.cfg.serialize(),
        "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
        "generation_metadata_collected": json_serialize(
            self.generation_metadata_collected,
        ),
    }
```
serialize the dataset collection
```python
@classmethod
def load(cls, data: JSONdict) -> "MazeDatasetCollection":
    """load the dataset collection from the representation created by `serialize`"""
    assert data[_FORMAT_KEY] == "MazeDatasetCollection"
    return cls(
        **{
            key: load_item_recursive(data[key], tuple())
            for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
        },
    )
```
load the dataset collection from the representation created by serialize
```python
def as_tokens(
    self,
    # TODO: MazeTokenizer
    maze_tokenizer,  # noqa: ANN001
    limit: int | None = None,
    join_tokens_individual_maze: bool = False,
) -> list[list[str]] | list[str]:
    """return the dataset as tokens

    if join_tokens_individual_maze is True, then the tokens of each maze are
    joined with a space, and the result is a list of strings.
    i.e.:
    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
    """
    output: list[list[str]] = [
        maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
    ]
    if join_tokens_individual_maze:
        return [" ".join(tokens) for tokens in output]
    else:
        return output
```
return the dataset as tokens
if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
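The interaction between `limit` and `join_tokens_individual_maze` can be tried out on plain token lists. This is a sketch under the assumption that each maze has already been tokenized: `tokens_for_mazes` is a hypothetical helper that takes the per-maze token lists directly instead of calling `maze.as_tokens(maze_tokenizer)`.

```python
from typing import Optional, Union


def tokens_for_mazes(
    per_maze_tokens: list[list[str]],
    limit: Optional[int] = None,
    join_tokens_individual_maze: bool = False,
) -> Union[list[list[str]], list[str]]:
    """Return tokens for the first `limit` mazes, optionally joined per maze."""
    # slicing with `None` keeps every maze, so `limit=None` means "no limit"
    output: list[list[str]] = [list(tokens) for tokens in per_maze_tokens[:limit]]
    if join_tokens_individual_maze:
        # one space-separated string per maze
        return [" ".join(tokens) for tokens in output]
    return output


if __name__ == "__main__":
    data = [["a", "b", "c"], ["d", "e", "f"]]
    print(tokens_for_mazes(data, join_tokens_individual_maze=True))
    print(tokens_for_mazes(data, limit=1))
```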
```python
def update_self_config(self) -> None:
    "update the config to match the number of mazes, and update the underlying configs of each dataset"
    # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
    self.cfg.__dict__["n_mazes"] = len(self)
    for dataset in self.maze_datasets:
        dataset.update_self_config()

    self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
```
update the config to match the number of mazes, and update the underlying configs of each dataset
```python
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # Attributes without a default cannot follow attributes with one  [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )

    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all datasets"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
```
maze dataset collection configuration, including tokenizers and shuffle
```python
def summary(self) -> dict:
    """return a summary of the config"""
    return dict(
        n_mazes=self.n_mazes,
        max_grid_n=self.max_grid_n,
        max_grid_shape=self.max_grid_shape,
        fname=self.to_fname(),
        cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
    )
```
return a summary of the config
```python
@property
def n_mazes(self) -> int:
    """return the total number of mazes in the collection across all datasets"""
    return sum(config.n_mazes for config in self.maze_dataset_configs)
```
return the total number of mazes in the collection across all datasets
```python
@property
def max_grid_n(self) -> int:
    """return the maximum grid size of the mazes in the collection"""
    return max(config.grid_n for config in self.maze_dataset_configs)
```
return the maximum grid size of the mazes in the collection
```python
@property
def max_grid_shape(self) -> CoordTup:
    """return the maximum grid shape of the mazes in the collection"""
    return (self.max_grid_n, self.max_grid_n)
```
return the maximum grid shape of the mazes in the collection
```python
@property
def max_grid_shape_np(self) -> Coord:
    """return the maximum grid shape of the mazes in the collection as a numpy array"""
    return np.array(self.max_grid_shape, dtype=np.int32)
```
return the maximum grid shape of the mazes in the collection as a numpy array
```python
def stable_hash_cfg(self) -> int:
    """return a stable hash of the config"""
    return stable_hash(json.dumps(self.serialize()))
```
return a stable hash of the config
```python
def to_fname(self) -> str:
    """convert config to a filename"""
    return sanitize_fname(
        f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
    )
```
convert config to a filename
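The filename combines the collection name, the maze count, and the last five decimal digits of a hash of the serialized config. The sketch below mimics that pattern with the standard library only. It rests on several assumptions: `stable_hash_sketch` uses MD5 as a stand-in because the implementation of `stable_hash` is not shown here, and the `shorten_numerical_to_str` and `sanitize_fname` steps are omitted. Python's built-in `hash()` is avoided because string hashes are salted per process by default, so they would not give the same filename across runs.

```python
import hashlib
import json


def stable_hash_sketch(s: str) -> int:
    """Process-independent integer hash of a string (MD5 is an assumption, not the library's choice)."""
    return int.from_bytes(hashlib.md5(s.encode("utf-8")).digest(), "big")


def collection_fname_sketch(name: str, n_mazes: int, serialized_cfg: dict) -> str:
    """Build a `collected-<name>-n<count>-h<hash suffix>` style filename."""
    cfg_hash: int = stable_hash_sketch(json.dumps(serialized_cfg))
    # keep only the last five decimal digits of the hash, as in `to_fname`
    return f"collected-{name}-n{n_mazes}-h{cfg_hash % 10**5}"


if __name__ == "__main__":
    cfg = {"name": "demo", "maze_dataset_configs": [{"grid_n": 5, "n_mazes": 30}]}
    print(collection_fname_sketch("demo", 30, cfg))
```

Since only five digits of the hash are kept, two different configs can share a suffix. The name and maze count in the filename reduce, but do not remove, that possibility.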
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using `@serializable_dataclass` decorator
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
  - `validate_field_type`
  - `diff`
  - `update_from_nested_dict`