# maze-dataset
This package provides utilities for generating, filtering, solving, visualizing, and processing mazes for training or evaluating ML systems. It was primarily built for the maze-transformer interpretability project; our paper describing it is available at http://arxiv.org/abs/2309.10498
This package includes a variety of maze generation algorithms, including randomized depth-first search, Wilson's algorithm for uniform spanning trees, and percolation. Datasets can be filtered to select mazes of a certain length or complexity, to remove duplicates, or to satisfy custom properties. A variety of output formats for visualization and for training ML models are provided.
You can view and search through a wide variety of example mazes here: understanding-search.github.io/maze-dataset/examples/maze_examples
## Citing
If you use this code in your research, please cite our paper:
```bibtex
@misc{maze-dataset,
    title={A Configurable Library for Generating and Manipulating Maze Datasets},
    author={Michael Igorevich Ivanitskiy and Rusheb Shah and Alex F. Spies and Tilman Räuker and Dan Valentine and Can Rager and Lucia Quirke and Chris Mathwin and Guillaume Corlouer and Cecilia Diniz Behn and Samy Wu Fung},
    year={2023},
    eprint={2309.10498},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={http://arxiv.org/abs/2309.10498}
}
```
## Installation
This package is available on PyPI and can be installed via

```
pip install maze-dataset
```
## Docs
The full hosted documentation is available at https://understanding-search.github.io/maze-dataset/.
Additionally, our notebooks serve as a good starting point for understanding the package.
## Usage
### Creating a dataset
To create a `MazeDataset`, which inherits from `torch.utils.data.Dataset`, you first create a `MazeDatasetConfig`:
```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="test",  # name is only for you to keep track of things
    grid_n=5,  # number of rows/columns in the lattice
    n_mazes=4,  # number of mazes to generate
    maze_ctor=LatticeMazeGenerators.gen_dfs,  # algorithm to generate the maze
    maze_ctor_kwargs=dict(do_forks=False),  # additional parameters to pass to the maze generation algorithm
)
```
and then pass this config to the `MazeDataset.from_config` method:
```python
dataset: MazeDataset = MazeDataset.from_config(cfg)
```
This method checks whether a dataset with a matching config hash already exists at the expected location on your filesystem, and loads it if so. Otherwise, it generates the dataset on the fly.
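If you want to control this behavior explicitly, `from_config` exposes it through keyword arguments. A minimal sketch; the keyword names and defaults below are taken from the `MazeDataset.from_config` signature reproduced later in this document:

```python
from pathlib import Path

dataset: MazeDataset = MazeDataset.from_config(
    cfg,
    load_local=True,   # 1. try to load a dataset with a matching config from disk
    do_download=True,  # 2. failing that, try to download it (not yet implemented)
    do_generate=True,  # 3. failing that, generate it from scratch
    save_local=True,   # save a newly generated dataset for reuse
    local_base_path=Path("data/maze_dataset"),  # where datasets live on disk
)
```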
### Conversions to useful formats
The elements of the dataset are `SolvedMaze` objects:
```python
>>> m = dataset[0]
>>> type(m)
SolvedMaze
```
These can be converted to a variety of formats:
```python
# visual representation as ASCII art
m.as_ascii()

# RGB image, optionally without solution or endpoints, suitable for CNNs
m.as_pixels()

# text format for autoregressive transformers
from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode
m.as_tokens(maze_tokenizer=MazeTokenizerModular(
    tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100,
))

# advanced visualization with many features
from maze_dataset.plotting import MazePlot
MazePlot(m).plot()
```
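These conversions also exist at the dataset level; for example, `MazeDataset.as_tokens` (see its source later in this document) tokenizes every maze at once. A minimal sketch, reusing the tokenizer construction from the snippet above:

```python
# one token sequence per maze; join_tokens_individual_maze=True joins each
# maze's tokens into a single space-separated string
tokenizer = MazeTokenizerModular(
    tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100,
)
token_strings: list[str] = dataset.as_tokens(
    tokenizer,
    join_tokens_individual_maze=True,
)
```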
## Development
We use this makefile template, with slight modifications, for our development workflow:
- clone with `git clone https://github.com/understanding-search/maze-dataset`
- `make dep` to install all dependencies
- `make help` will print all available commands
- `make test` will run basic tests to ensure the package is working
- `make format` will run ruff to format and check the code
1""".. include:: ../README.md""" 2 3from maze_dataset.constants import ( 4 SPECIAL_TOKENS, 5 VOCAB, 6 VOCAB_LIST, 7 VOCAB_TOKEN_TO_INDEX, 8 Connection, 9 ConnectionArray, 10 ConnectionList, 11 Coord, 12 CoordArray, 13 CoordList, 14 CoordTup, 15) 16from maze_dataset.dataset.collected_dataset import ( 17 MazeDatasetCollection, 18 MazeDatasetCollectionConfig, 19) 20from maze_dataset.dataset.filters import register_maze_filter 21from maze_dataset.dataset.maze_dataset import ( 22 MazeDataset, 23 MazeDatasetConfig, 24) 25from maze_dataset.dataset.maze_dataset_config import set_serialize_minimal_threshold 26from maze_dataset.generation.generators import LatticeMazeGenerators 27from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze, TargetedLatticeMaze 28 29__all__ = [ 30 # submodules (with sub-submodules) 31 "benchmark", 32 "dataset", 33 "generation", 34 "maze", 35 "plotting", 36 "tokenization", 37 # submodules 38 "constants", 39 "testing_utils", 40 "token_utils", 41 "utils", 42 # main 43 "SolvedMaze", 44 "MazeDatasetConfig", 45 "MazeDataset", 46 # dataset classes 47 "MazeDatasetCollection", 48 "MazeDatasetCollectionConfig", 49 # maze classes 50 "TargetedLatticeMaze", 51 "LatticeMaze", 52 # other 53 "set_serialize_minimal_threshold", 54 "register_maze_filter", 55 "LatticeMazeGenerators", 56 # types 57 "Coord", 58 "CoordTup", 59 "CoordList", 60 "CoordArray", 61 "Connection", 62 "ConnectionList", 63 "ConnectionArray", 64 # constants 65 "SPECIAL_TOKENS", 66 "VOCAB", 67 "VOCAB_LIST", 68 "VOCAB_TOKEN_TO_INDEX", 69]
The full source of `SolvedMaze` (from `maze_dataset.maze.lattice_maze`):

```python
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution"""

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        > DOCS: better documentation for this init method
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
                solution_valid = True

        if not solution_valid and not allow_invalid:
            err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]" [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze"""
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
```
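A minimal sketch of how these pieces fit together, assuming that `gen_dfs` accepts a grid shape as its first argument and that `LatticeMaze.generate_random_path()` returns a full solution path (it is the method `endpoint_kwargs` is forwarded to during dataset generation, per the `MazeDatasetConfig` docs below):

```python
import numpy as np

from maze_dataset import LatticeMazeGenerators, SolvedMaze

# generate an unsolved maze, pick a random path through it, and wrap both
maze = LatticeMazeGenerators.gen_dfs(np.array([5, 5]))
solution = maze.generate_random_path()
solved: SolvedMaze = SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution)

# decision points along the solution, and their complement
fork_idxs, fork_coords = solved.get_solution_forking_points()
follow_idxs, follow_coords = solved.get_solution_path_following_points()
```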
`SolvedMaze` also inherits generic serialization machinery from `muutils`. `serialize` returns the class as a dict, implemented by the `@serializable_dataclass` decorator:

```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
`load` takes in an appropriately structured dict and returns an instance of the class, likewise implemented by the `@serializable_dataclass` decorator:

```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
`SerializableDataclass__validate_fields_types` validates the types of all the fields on a `SerializableDataclass`, calling `SerializableDataclass__validate_field_type` for each field:

```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
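One practical consequence of this machinery is that configs survive a serialize/load round trip; `MazeDatasetConfig.success_fraction_compensate` (shown later in this document) uses exactly this serialize-mutate-load pattern. A sketch, reusing the `cfg` from the README above:

```python
# round-trip the config through its dict form; extra keys like "fname"
# and "versions" are ignored by load(), which only reads dataclass fields
cfg_dict: dict = cfg.serialize()
cfg_restored: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)
assert cfg_restored == cfg  # these should compare equal field-by-field
```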
Inherited Members
- `LatticeMaze`: `connection_list`, `generation_meta`, `lattice_dim`, `grid_shape`, `n_connections`, `grid_n`, `heuristic`, `nodes_connected`, `is_valid_path`, `coord_degrees`, `get_coord_neighbors`, `gen_connected_component_from`, `find_shortest_path`, `get_nodes`, `get_connected_component`, `generate_random_path`, `as_adj_list`, `from_adj_list`, `as_adj_list_tokens`, `as_tokens`, `from_tokens`, `as_pixels`, `from_pixels`, `as_ascii`, `from_ascii`
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`: `validate_field_type`, `diff`, `update_from_nested_dict`
The full source of `MazeDatasetConfig`:

```python
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset

    # Parameters:
    - `name : str`
        name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
    - `grid_n : int`
        grid size of the maze (number of rows/columns)
    - `n_mazes : int`
        number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
        see `EndpointKwargsType` for more details.
    - `maze_ctor : Callable`
        maze generator function. This should be a function that takes a grid size and returns a maze.
        This will usually be one of the functions in `LatticeMazeGenerators`.
    - `maze_ctor_kwargs : dict`
        keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
    - `endpoint_kwargs : EndpointKwargsType`
        keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
    - `applied_filters : list[dict]`
        list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
        but these are stored with the config in case you want to re-generate the dataset with the same filters.
    """

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

        used in predicting the success rate
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didnt train on this, but if either endpoint is not required to be in a dead end
            # then requiring the endpoints to be unique does not really affect the success rate
            # (except for very small percolation values, pure percolation generation)
            endpoints_unique_flag = 0

        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            Config corresponding to `arr`
        """
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, don't raise an error if the success fraction is below the threshold.
            will always return `1.0` if the config is not expected to fail

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
        """
        try:
            return cfg_success_predict_fn(self)

        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
        - `except_if_all_success_expected : bool`
            if `True`, don't raise an error if the success fraction is below the threshold.
            this is passed to `MazeDatasetConfig.success_fraction_estimate`.
            if your config isn't expected to fail, passing this might mean you generate more mazes than needed
            since `safety_margin` is still applied.
            (defaults to `False`)
        - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
        - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
        - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)
```
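The assertions in `_to_ps_array` above define when the success-fraction estimate applies: the generator must be a percolation generator with a `p` kwarg, and `except_on_no_valid_endpoint` must be `False` so that failed mazes are dropped rather than raising. A hedged sketch of putting this together; it assumes `gen_percolation` is among the supported percolation generators, and the specific parameter values are illustrative:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg_perc = MazeDatasetConfig(
    name="percolation_demo",
    grid_n=8,
    n_mazes=1000,
    maze_ctor=LatticeMazeGenerators.gen_percolation,  # assumption: a percolation generator
    maze_ctor_kwargs=dict(p=0.4),  # percolation probability
    endpoint_kwargs=dict(
        deadend_start=True,
        deadend_end=True,
        endpoints_not_equal=True,
        except_on_no_valid_endpoint=False,  # required: failed mazes are dropped, not fatal
    ),
)

# some mazes will fail to generate, so request extra to compensate
cfg_adjusted = cfg_perc.success_fraction_compensate(safety_margin=1.2)
dataset = MazeDataset.from_config(cfg_adjusted)
```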
Inherited Members
- `maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base`: `grid_n`, `n_mazes`, `maze_ctor`, `maze_ctor_kwargs`, `endpoint_kwargs`, `grid_shape`, `grid_shape_np`, `max_grid_n`, `stable_hash_cfg`, `to_fname`
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`: `validate_field_type`, `diff`, `update_from_nested_dict`
The full source of `MazeDataset`:

```python
class MazeDataset(GPTDataset[MazeDatasetConfig]):
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override]
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            MazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        """get a maze by index"""
        return self.mazes[i]

    def __iter__(self) -> typing.Iterator[SolvedMaze]:
        """iterate over the mazes"""
        return iter(self.mazes)

    def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
        """deepcopy the dataset

        FIX: this isnt actually a deepcopy I think?
        """
        return MazeDataset.load(self._serialize_full())

    # TYPING: get type hints on the tokenizer here
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[False] = False,
    ) -> list[list[str]]: ...
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[True] = True,
    ) -> list[str]: ...
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        """return the number of mazes in the dataset"""
        return len(self.mazes)

    def __eq__(self, other: object) -> bool:
        """compare two datasets"""
        if not isinstance(other, MazeDataset):
            raise NotImplementedError(
                "can only compare with other MazeDataset objects",
            )
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    strict=False,
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"],
            tuple(),
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"],
            tuple(),
        )
        maze_solutions = np.split(
            maze_solutions_concat,
            np.cumsum(maze_solution_lengths)[:-1],
            axis=0,
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"],
                tuple(),
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                    strict=False,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            },
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "fname": self.cfg.to_fname(),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        maze_solutions: np.ndarray = np.empty(
            (n_mazes, max_solution_len, 2),
            dtype=np.int8,
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = np.sum(maze_solution_lengths)

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2),
            dtype=np.int8,
        )

        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solution_lengths[idx] = soln_len
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
```
a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`
```python
def __init__(
    self,
    cfg: MazeDatasetConfig,
    mazes: typing.Sequence[SolvedMaze],
    generation_metadata_collected: dict | None = None,
) -> None:
    """initialize a maze dataset from a config and a list of solved mazes"""
    super().__init__()
    self.cfg: MazeDatasetConfig = cfg
    self.mazes: list[SolvedMaze] = list(mazes)
    self.generation_metadata_collected: dict | None = generation_metadata_collected
```
initialize a maze dataset from a config and a list of solved mazes
```python
@classmethod
def from_config(  # type: ignore[override]
    cls,
    # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
    cfg: MazeDatasetConfig,  # type: ignore[override]
    do_generate: bool = True,
    load_local: bool = True,
    save_local: bool = True,
    zanj: ZANJ | None = None,
    do_download: bool = True,
    local_base_path: Path = Path("data/maze_dataset"),
    except_on_config_mismatch: bool = True,
    allow_generation_metadata_filter_mismatch: bool = True,
    verbose: bool = False,
    **kwargs,
) -> "MazeDataset":
    """create a maze dataset from a config

    priority of loading:
    1. load from local
    2. download
    3. generate
    """
    return cast(
        MazeDataset,
        super().from_config(
            cfg=cfg,
            do_generate=do_generate,
            load_local=load_local,
            save_local=save_local,
            zanj=zanj,
            do_download=do_download,
            local_base_path=local_base_path,
            except_on_config_mismatch=except_on_config_mismatch,
            allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
            verbose=verbose,
            **kwargs,
        ),
    )
```
create a maze dataset from a config
priority of loading:
- load from local
- download
- generate
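The keyword flags control each stage of that fallback. A minimal sketch, given a `MazeDatasetConfig` named `cfg` (flag names are taken from the signature above; the values here are only illustrative):

```python
# force regeneration: skip loading from disk and skip saving the result
dataset: MazeDataset = MazeDataset.from_config(
    cfg,
    load_local=False,   # don't look for an existing dataset under `local_base_path`
    do_download=False,  # skip the (not yet implemented) download step
    do_generate=True,   # fall through to generation
    save_local=False,   # don't write the generated dataset to disk
)
```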
```python
def data_hash(self) -> int:
    """return a hash of the data"""
    return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
```
return a hash of the data
```python
def as_tokens(
    self,
    maze_tokenizer,  # TODO: MazeTokenizer
    limit: int | None = None,
    join_tokens_individual_maze: bool = False,
) -> list[list[str]] | list[str]:
    """return the dataset as tokens according to the passed `maze_tokenizer`

    the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

    if `join_tokens_individual_maze` is True, then the tokens of each maze are
    joined with a space, and the result is a list of strings.
    i.e.:

    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
    """
    output: list[list[str]] = [
        maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
    ]
    if join_tokens_individual_maze:
        return [" ".join(tokens) for tokens in output]
    else:
        return output
```
return the dataset as tokens according to the passed `maze_tokenizer`

the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

if `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings, i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
```python
def assert_equal(self, other: "MazeDataset") -> None:
    """assert that two datasets are equal"""
    assert isinstance(other, MazeDataset)
    assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
    assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
```
assert that two datasets are equal
```python
@classmethod
def generate(
    cls,
    cfg: MazeDatasetConfig,
    gen_parallel: bool = False,
    pool_kwargs: dict | None = None,
    verbose: bool = False,
    # TODO: what to do when unexpected kwargs are passed?
    **kwargs,  # noqa: ARG003
) -> "MazeDataset":
    """Generate a maze dataset given a config and some generation parameters"""
    # Copy the config to avoid modifying the original
    cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
        json.loads(json.dumps(cfg.serialize())),
    )

    if pool_kwargs is None:
        pool_kwargs = dict()
    maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

    solved_mazes: list[SolvedMaze | None]
    # Configure tqdm for progress bar
    tqdm_kwargs: dict = dict(
        total=cfg_cpy.n_mazes,
        unit="maze",
        desc="generating & solving mazes",
        disable=not verbose,
    )
    # TODO: don't use the global unless generating in parallel!
    if gen_parallel:
        with multiprocessing.Pool(
            **pool_kwargs,
            initializer=_maze_gen_init_worker,
            initargs=(cfg_cpy,),
        ) as pool:
            solved_mazes = list(
                tqdm.tqdm(
                    pool.imap(_generate_maze_helper, maze_indexes),
                    **tqdm_kwargs,
                ),
            )

    else:
        _maze_gen_init_worker(cfg_cpy)
        solved_mazes = list(
            tqdm.tqdm(
                map(
                    # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
                    # why does it think tolist() returns a string?
                    _generate_maze_helper,  # type: ignore[arg-type]
                    maze_indexes.tolist(),
                ),
                **tqdm_kwargs,
            ),
        )

    # Filter out None values explicitly after ensuring all results are collected
    solved_mazes_: list[SolvedMaze] = [
        maze for maze in solved_mazes if maze is not None
    ]
    # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

    # Update the config with the actual number of mazes
    cfg_cpy.n_mazes = len(solved_mazes_)

    dataset: MazeDataset = cls(
        cfg=cfg_cpy,
        mazes=solved_mazes_,
    )

    dataset.update_self_config()  # ensure the dataset's config reflects changes

    np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

    return dataset
```
Generate a maze dataset given a config and some generation parameters
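A short sketch of calling `generate` directly, which bypasses the load/save logic in `from_config` (`pool_kwargs` is forwarded to `multiprocessing.Pool`):

```python
dataset: MazeDataset = MazeDataset.generate(
    cfg,
    gen_parallel=True,              # generate mazes in a worker pool
    pool_kwargs=dict(processes=4),  # forwarded to multiprocessing.Pool
    verbose=True,                   # show the tqdm progress bar
)
```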
```python
@classmethod
def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
    "(not implemented yet!) download a maze dataset from the internet"
    raise NotImplementedError("not implemented yet")
```
(not implemented yet!) download a maze dataset from the internet
```python
@classmethod
def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
    """load from zanj/json"""
    if data[_FORMAT_KEY] == "MazeDataset:minimal":
        return cls._load_minimal(data)
    elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
        return cls._load_minimal_soln_cat(data)
    elif data[_FORMAT_KEY] == "MazeDataset":
        if (
            SERIALIZE_MINIMAL_THRESHOLD == -1
        ):  # Allow access to `_load_legacy` for profiling
            return cls._load_legacy(data)
        return cls._load_full(data)
    else:
        err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
        raise KeyError(
            err_msg,
        )
```
load from zanj/json
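Serialization round-trips through ZANJ. A minimal sketch, assuming the `ZANJ.save`/`ZANJ.read` interface of the `zanj` package:

```python
from zanj import ZANJ

z = ZANJ()
path: str = "data/my_dataset.zanj"  # hypothetical path
z.save(dataset, path)               # uses MazeDataset.serialize under the hood
loaded: MazeDataset = z.read(path)  # dispatches on _FORMAT_KEY via MazeDataset.load
loaded.assert_equal(dataset)        # should not raise
```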
```python
def serialize(self) -> JSONdict:
    """serialize to zanj/json"""
    if (
        SERIALIZE_MINIMAL_THRESHOLD is not None
        and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
    ):
        return self._serialize_minimal()
    return self._serialize_full()
```
serialize to zanj/json
```python
def update_self_config(self) -> None:
    """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
    if self.cfg.n_mazes != len(self.mazes):
        warnings.warn(
            f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
        )
        self.cfg.n_mazes = len(self.mazes)
```
update the config to match the current state of the dataset (number of mazes, such as after filtering)
```python
def custom_maze_filter(
    self,
    method: typing.Callable[[SolvedMaze], bool],
    **kwargs,
) -> "MazeDataset":
    """filter the dataset using a custom method"""
    output: MazeDataset = MazeDataset(
        cfg=copy.deepcopy(self.cfg),
        mazes=[m for m in self.mazes if method(m, **kwargs)],
    )
    output.cfg.applied_filters.append(
        {
            "name": f"__custom__:{method.__name__}",
            "kwargs": kwargs,
        },
    )
    output.update_self_config()
    return output
```
filter the dataset using a custom method
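A sketch of a custom filter; the predicate receives each `SolvedMaze` plus the forwarded kwargs (`min_len` is a hypothetical parameter name):

```python
long_mazes: MazeDataset = dataset.custom_maze_filter(
    lambda m, min_len: len(m.solution) >= min_len,  # keep mazes with long solutions
    min_len=5,
)
```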
```python
class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets"""

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )

    def __len__(self) -> int:
        """return the total number of mazes in the collection"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int) -> LatticeMaze:
        "get a maze by index"
        # find which dataset the index belongs to
        # we add 1, since np.searchsorted returns the
        # index of the last element that is strictly less than the target
        # while we want the index of the last element less than or equal to the target
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        index_adjusted: int = index
        if dataset_idx > 0:
            # if the index is 0, `dataset_idx - 1` will be -1.
            # We just want to use the base index
            index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][index_adjusted]

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
```
a collection of maze datasets
```python
def __init__(
    self,
    cfg: MazeDatasetCollectionConfig,
    maze_datasets: list[MazeDataset],
    generation_metadata_collected: dict | None = None,
) -> None:
    "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
    super().__init__()
    self.cfg: MazeDatasetCollectionConfig = cfg
    self.maze_datasets: list[MazeDataset] = list(maze_datasets)
    for c, ds in zip(
        self.cfg.maze_dataset_configs,
        self.maze_datasets,
        strict=False,
    ):
        assert c.name == ds.cfg.name
        assert c == ds.cfg

    self.generation_metadata_collected: dict | None = generation_metadata_collected
```
initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s
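A hedged sketch of assembling a collection by hand, assuming two `MazeDatasetConfig`s named `cfg_small` and `cfg_large`, and that `name` is the only required base-config field; note that `__init__` asserts each sub-config matches its dataset's config:

```python
collection = MazeDatasetCollection(
    cfg=MazeDatasetCollectionConfig(
        name="mixed",
        maze_dataset_configs=[cfg_small, cfg_large],
    ),
    maze_datasets=[
        MazeDataset.from_config(cfg_small),
        MazeDataset.from_config(cfg_large),
    ],
)
```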
```python
@property
def dataset_lengths(self) -> list[int]:
    """return the lengths of each dataset in the collection"""
    return [len(dataset) for dataset in self.maze_datasets]
```
return the lengths of each dataset in the collection
```python
@property
def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
    """return the cumulative lengths of each dataset in the collection"""
    return np.array(list(itertools.accumulate(self.dataset_lengths)))
```
return the cumulative lengths of each dataset in the collection
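`__getitem__` uses these cumulative lengths to route a flat index to the owning sub-dataset. A standalone sketch of that lookup, with hypothetical sizes:

```python
import numpy as np

dataset_cum_lengths = np.array([3, 7, 12])  # datasets of sizes 3, 4, and 5
index = 8
# searchsorted with `index + 1` finds the first dataset whose cumulative length covers the index
dataset_idx = int(np.searchsorted(dataset_cum_lengths, index + 1))
index_adjusted = index - (dataset_cum_lengths[dataset_idx - 1] if dataset_idx > 0 else 0)
assert (dataset_idx, index_adjusted) == (2, 1)  # maze 8 overall is maze 1 of dataset 2
```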
```python
@cached_property
def mazes(self) -> list[LatticeMaze]:
    "single list of all mazes in the collection"
    return list(
        itertools.chain.from_iterable(
            dataset.mazes for dataset in self.maze_datasets
        ),
    )
```
single list of all mazes in the collection
```python
@classmethod
def generate(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
    """generate a dataset collection from a config"""
    datasets = [
        MazeDataset.generate(config, **kwargs)
        for config in cfg.maze_dataset_configs
    ]
    return cls(cfg, datasets)
```
generate a dataset collection from a config
```python
@classmethod
def download(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
    "(not implemented!) download a dataset collection from a config"
    datasets = [
        MazeDataset.download(config, **kwargs)
        for config in cfg.maze_dataset_configs
    ]
    return cls(cfg, datasets)
```
(not implemented!) download a dataset collection from a config
```python
def serialize(self) -> JSONdict:
    """serialize the dataset collection"""
    return {
        _FORMAT_KEY: "MazeDatasetCollection",
        "cfg": self.cfg.serialize(),
        "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
        "generation_metadata_collected": json_serialize(
            self.generation_metadata_collected,
        ),
    }
```
serialize the dataset collection
```python
@classmethod
def load(cls, data: JSONdict) -> "MazeDatasetCollection":
    """load the dataset collection from the representation created by `serialize`"""
    assert data[_FORMAT_KEY] == "MazeDatasetCollection"
    return cls(
        **{
            key: load_item_recursive(data[key], tuple())
            for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
        },
    )
```
load the dataset collection from the representation created by `serialize`
```python
def as_tokens(
    self,
    # TODO: MazeTokenizer
    maze_tokenizer,  # noqa: ANN001
    limit: int | None = None,
    join_tokens_individual_maze: bool = False,
) -> list[list[str]] | list[str]:
    """return the dataset as tokens

    if join_tokens_individual_maze is True, then the tokens of each maze are
    joined with a space, and the result is a list of strings.
    i.e.:
    >>> dataset.as_tokens(join_tokens_individual_maze=False)
    [["a", "b", "c"], ["d", "e", "f"]]
    >>> dataset.as_tokens(join_tokens_individual_maze=True)
    ["a b c", "d e f"]
    """
    output: list[list[str]] = [
        maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
    ]
    if join_tokens_individual_maze:
        return [" ".join(tokens) for tokens in output]
    else:
        return output
```
return the dataset as tokens
if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
```python
def update_self_config(self) -> None:
    "update the config to match the number of mazes, and update the underlying configs of each dataset"
    # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
    self.cfg.__dict__["n_mazes"] = len(self)
    for dataset in self.maze_datasets:
        dataset.update_self_config()

    self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
```
update the config to match the number of mazes, and update the underlying configs of each dataset
```python
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # Attributes without a default cannot follow attributes with one  [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )

    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all datasets"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
    }
```
maze dataset collection configuration, including tokenizers and shuffle
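The derived properties documented below aggregate over the sub-configs. A brief hedged sketch, assuming a `collection_cfg` built as in the collection example above:

```python
print(collection_cfg.n_mazes)         # total mazes across all sub-configs
print(collection_cfg.max_grid_shape)  # (max_grid_n, max_grid_n)
# filename format: "collected-<name>-n<count>-h<5-digit hash>"
print(collection_cfg.to_fname())
```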
```python
def summary(self) -> dict:
    """return a summary of the config"""
    return dict(
        n_mazes=self.n_mazes,
        max_grid_n=self.max_grid_n,
        max_grid_shape=self.max_grid_shape,
        fname=self.to_fname(),
        cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
    )
```
return a summary of the config
```python
@property
def n_mazes(self) -> int:
    """return the total number of mazes in the collection across all datasets"""
    return sum(config.n_mazes for config in self.maze_dataset_configs)
```
return the total number of mazes in the collection across all datasets
```python
@property
def max_grid_n(self) -> int:
    """return the maximum grid size of the mazes in the collection"""
    return max(config.grid_n for config in self.maze_dataset_configs)
```
return the maximum grid size of the mazes in the collection
```python
@property
def max_grid_shape(self) -> CoordTup:
    """return the maximum grid shape of the mazes in the collection"""
    return (self.max_grid_n, self.max_grid_n)
```
return the maximum grid shape of the mazes in the collection
```python
@property
def max_grid_shape_np(self) -> Coord:
    """return the maximum grid shape of the mazes in the collection as a numpy array"""
    return np.array(self.max_grid_shape, dtype=np.int32)
```
return the maximum grid shape of the mazes in the collection as a numpy array
```python
def stable_hash_cfg(self) -> int:
    """return a stable hash of the config"""
    return stable_hash(json.dumps(self.serialize()))
```
return a stable hash of the config
```python
def to_fname(self) -> str:
    """convert config to a filename"""
    return sanitize_fname(
        f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
    )
```
convert config to a filename
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by the `@serializable_dataclass` decorator
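A small hedged sketch of the serialize/load round trip on a toy dataclass, using the public `muutils.json_serialize` exports (the exact contents of the emitted dict beyond the fields and format key are an implementation detail):

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass
class Point(SerializableDataclass):
    x: int
    y: int = serializable_field(default=0)

p = Point(x=1, y=2)
data: dict = p.serialize()    # field values plus a format key
p2: Point = Point.load(data)  # reconstructs and type-checks the fields
assert p == p2
```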
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by the `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position"""

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        "post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        assert self.start_pos is not None
        assert self.end_pos is not None
        # check that start and end are in bounds
        if (
            self.start_pos[0] >= self.grid_shape[0]
            or self.start_pos[1] >= self.grid_shape[1]
        ):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if (
            self.end_pos[0] >= self.grid_shape[0]
            or self.end_pos[1] >= self.grid_shape[1]
        ):
            err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
```
A LatticeMaze with a start and end position
```python
def get_start_pos_tokens(self) -> list[str | CoordTup]:
    "(deprecated!) return the start position as a list of tokens"
    warnings.warn(
        "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return self._get_start_pos_tokens()
```
(deprecated!) return the start position as a list of tokens
```python
def get_end_pos_tokens(self) -> list[str | CoordTup]:
    "(deprecated!) return the end position as a list of tokens"
    warnings.warn(
        "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return self._get_end_pos_tokens()
```
(deprecated!) return the end position as a list of tokens
```python
@classmethod
def from_lattice_maze(
    cls,
    lattice_maze: LatticeMaze,
    start_pos: Coord | CoordTup,
    end_pos: Coord | CoordTup,
) -> "TargetedLatticeMaze":
    "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
    return cls(
        connection_list=lattice_maze.connection_list,
        start_pos=np.array(start_pos),
        end_pos=np.array(end_pos),
        generation_meta=lattice_maze.generation_meta,
    )
```
get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions
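For instance, a hedged sketch of attaching endpoints to a freshly generated maze (this assumes `gen_dfs` takes the grid shape as its first argument and that `TargetedLatticeMaze` is exported at the package root):

```python
import numpy as np
from maze_dataset import TargetedLatticeMaze
from maze_dataset.generation import LatticeMazeGenerators

maze = LatticeMazeGenerators.gen_dfs(np.array([5, 5]))  # plain LatticeMaze
tgt_maze = TargetedLatticeMaze.from_lattice_maze(
    lattice_maze=maze,
    start_pos=(0, 0),  # converted to numpy arrays in __post_init__
    end_pos=(4, 4),
)
```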
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
@serializable_dataclass(
    frozen=True,
    kw_only=True,
    properties_to_serialize=["lattice_dim", "generation_meta"],
)
class LatticeMaze(SerializableDataclass):
    """lattice maze (nodes on a lattice, connections only to neighboring nodes)

    Connection List represents which nodes (N) are connected in each direction.

    First and second elements represent downward and rightward connections,
    respectively.

    Example:
      Connection list:
        [
          [ # down
            [F T],
            [F F]
          ],
          [ # right
            [T F],
            [T F]
          ]
        ]

      Nodes with connections
        N T N F
        F   T
        N T N F
        F   F

      Graph:
        N - N
            |
        N - N

    Note: the bottom row connections going down, and the
    right-hand connections going right, will always be False.
    """

    connection_list: ConnectionList
    generation_meta: dict | None = serializable_field(default=None, compare=False)

    lattice_dim = property(lambda self: self.connection_list.shape[0])
    grid_shape = property(lambda self: self.connection_list.shape[1:])
    n_connections = property(lambda self: self.connection_list.sum())

    @property
    def grid_n(self) -> int:
        "grid size as int, raises `AssertionError` if not square"
        assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
        return self.grid_shape[0]

    # ============================================================
    # basic methods
    # ============================================================

    def __eq__(self, other: object) -> bool:
        "equality check calls super"
        return super().__eq__(other)

    @staticmethod
    def heuristic(a: CoordTup, b: CoordTup) -> float:
        """return manhattan distance between two points"""
        return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])

    def __hash__(self) -> int:
        """hash the connection list by converting connection list to bytes"""
        return hash(self.connection_list.tobytes())

    def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
        """returns whether two nodes are connected"""
        delta: Coord = b - a
        if np.abs(delta).sum() != 1:
            # return false if not even adjacent
            return False
        else:
            # test for wall
            dim: int = int(np.argmax(np.abs(delta)))
            clist_node: Coord = a if (delta.sum() > 0) else b
            return self.connection_list[dim, clist_node[0], clist_node[1]]

    def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
        """check if a path is valid"""
        # check path is not empty
        if len(path) == 0:
            return empty_is_valid

        # check all coords in bounds of maze
        if not np.all((path >= 0) & (path < self.grid_shape)):
            return False

        # check all nodes connected
        for i in range(len(path) - 1):
            if not self.nodes_connected(path[i], path[i + 1]):
                return False
        return True

    def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
        """Returns an array with the connectivity degree of each coord.

        I.e., how many neighbors each coord has.
        """
        int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
            self.connection_list.astype(np.int8)
        )
        degrees: Int8[np.ndarray, "row col"] = np.sum(
            int_conn,
            axis=0,
        )  # Connections to east and south
        degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
        degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
        return degrees

    def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
        """Returns an array of the neighboring, connected coords of `c`."""
        c = np.array(c)  # type: ignore[assignment]
        neighbors: list[Coord] = [
            neighbor
            for neighbor in (c + NEIGHBORS_MASK)
            if (
                (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
                and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
                and self.nodes_connected(c, neighbor)  # connected
            )
        ]

        output: CoordArray = np.array(neighbors)
        if len(neighbors) > 0:
            assert output.shape == (
                len(neighbors),
                2,
            ), (
                f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
            )
        return output

    def gen_connected_component_from(self, c: Coord) -> CoordArray:
        """return the connected component from a given coordinate"""
        # Stack for DFS
        stack: list[Coord] = [c]

        # Set to store visited nodes
        visited: set[CoordTup] = set()

        while stack:
            current_node: Coord = stack.pop()
            # this is fine since we know current_node is a coord and thus of length 2
            visited.add(tuple(current_node))  # type: ignore[arg-type]

            # Get the neighbors of the current node
            neighbors = self.get_coord_neighbors(current_node)

            # Iterate over neighbors
            for neighbor in neighbors:
                if tuple(neighbor) not in visited:
                    stack.append(neighbor)

        return np.array(list(visited))

    def find_shortest_path(
        self,
        c_start: CoordTup | Coord,
        c_end: CoordTup | Coord,
    ) -> CoordArray:
        """find the shortest path between two coordinates, using A*"""
        c_start = tuple(c_start)  # type: ignore[assignment]
        c_end = tuple(c_end)  # type: ignore[assignment]

        g_score: dict[CoordTup, float] = (
            dict()
        )  # cost of cheapest path to node from start currently known
        f_score: dict[CoordTup, float] = {
            c_start: 0.0,
        }  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

        # init
        g_score[c_start] = 0.0
        f_score[c_start] = self.heuristic(c_start, c_end)

        closed_vtx: set[CoordTup] = set()  # nodes already evaluated
        # nodes to be evaluated
        # we need a set of the tuples, dont place the ints in the set
        open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
        source: dict[CoordTup, CoordTup] = (
            dict()
        )  # node immediately preceding each node in the path (currently known shortest path)

        while open_vtx:
            # get lowest f_score node
            # mypy cant tell that c is of length 2
            c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
            # f_current: float = f_score[c_current]

            # check if goal is reached
            if c_end == c_current:
                path: list[CoordTup] = [c_current]
                p_current: CoordTup = c_current
                while p_current in source:
                    p_current = source[p_current]
                    path.append(p_current)
                # ----------------------------------------------------------------------
                # this is the only return statement
                return np.array(path[::-1])
                # ----------------------------------------------------------------------

            # close current node
            closed_vtx.add(c_current)
            open_vtx.remove(c_current)

            # update g_score of neighbors
            _np_neighbor: Coord
            for _np_neighbor in self.get_coord_neighbors(c_current):
                neighbor: CoordTup = tuple(_np_neighbor)

                if neighbor in closed_vtx:
                    # already checked
                    continue
                g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

                if neighbor not in open_vtx:
                    # found new vtx, so add
                    open_vtx.add(neighbor)

                elif g_temp >= g_score[neighbor]:
                    # if already knew about this one, but current g_score is worse, skip
                    continue

                # store g_score and source
                source[neighbor] = c_current
                g_score[neighbor] = g_temp
                f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

        raise ValueError(
            "A solution could not be found!",
            f"{c_start = }, {c_end = }",
            self.as_ascii(),
        )

    def get_nodes(self) -> CoordArray:
        """return a list of all nodes in the maze"""
        rows: Int[np.ndarray, "x y"]
        cols: Int[np.ndarray, "x y"]
        rows, cols = np.meshgrid(
            range(self.grid_shape[0]),
            range(self.grid_shape[1]),
            indexing="ij",
        )
        nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
        return nodes

    def get_connected_component(self) -> CoordArray:
        """get the largest (and assumed only nonsingular) connected component of the maze

        TODO: other connected components?
        """
        if (self.generation_meta is None) or (
            self.generation_meta.get("fully_connected", False)
        ):
            # for fully connected case, pick any two positions
            return self.get_nodes()
        else:
            # if metadata provided, use visited cells
            visited_cells: set[CoordTup] | None = self.generation_meta.get(
                "visited_cells",
                None,
            )
            if visited_cells is None:
                # TODO: dynamically generate visited_cells?
                err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
                raise ValueError(
                    err_msg,
                )
            visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
            return visited_cells_np

    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[True] = True,
    ) -> CoordArray: ...
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[False] = False,
    ) -> typing.Optional[CoordArray]: ...
    def generate_random_path(  # noqa: C901
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: bool = True,
    ) -> typing.Optional[CoordArray]:
        """return a path between randomly chosen start and end nodes within the connected component

        Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

        # Parameters:
        - `allowed_start : CoordList | None`
            a list of allowed start positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `allowed_end : CoordList | None`
            a list of allowed end positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `deadend_start : bool`
            whether to ***force*** the start position to be a deadend
            (defaults to `False`)
        - `deadend_end : bool`
            whether to ***force*** the end position to be a deadend
            (defaults to `False`)
        - `endpoints_not_equal : bool`
            whether to ensure that the start and end point are not the same
            (defaults to `False`)
        - `except_on_no_valid_endpoint : bool`
            whether to raise an error if no valid start or end positions are found
            if this is `False`, the function might return `None` and this must be handled by the caller
            (defaults to `True`)

        # Returns:
        - `CoordArray`
            a path between the selected start and end positions

        # Raises:
        - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
        """
        # we can't create a "path" in a single-node maze
        assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
            f"can't create path in single-node maze: {self.as_ascii()}"
        )

        # get connected component
        connected_component: CoordArray = self.get_connected_component()

        # initialize start and end positions
        positions: Int[np.int8, "2 2"]

        # if no special conditions on start and end positions
        if (allowed_start, allowed_end, deadend_start, deadend_end) == (
            None,
            None,
            False,
            False,
        ):
            try:
                positions = connected_component[  # type: ignore[assignment]
                    np.random.choice(
                        len(connected_component),
                        size=2,
                        replace=False,
                    )
                ]
            except ValueError as e:
                if except_on_no_valid_endpoint:
                    err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                    raise NoValidEndpointException(
                        err_msg,
                    ) from e
                return None

            return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

        # handle special conditions
        connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
        # copy connected component set
        allowed_start_set: set[CoordTup] = connected_component_set.copy()
        allowed_end_set: set[CoordTup] = connected_component_set.copy()

        # filter by explicitly allowed start and end positions
        # '# type: ignore[assignment]' here because the returned tuple can be of any length
        if allowed_start is not None:
            allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

        if allowed_end is not None:
            allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

        # filter by forcing deadends
        if deadend_start:
            allowed_start_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_start_set,
                ),
            )

        if deadend_end:
            allowed_end_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_end_set,
                ),
            )

        # check we have valid positions
        if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                )
            return None

        # randomly select start and end positions
        try:
            # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
            start_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
            )
            if endpoints_not_equal:
                # remove start position from end positions
                allowed_end_set.discard(start_pos)
            end_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
            )
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(start_pos, end_pos)

    # ============================================================
    # to and from adjacency list
    # ============================================================
    def as_adj_list(
        self,
        shuffle_d0: bool = True,
        shuffle_d1: bool = True,
    ) -> Int8[np.ndarray, "conn start_end coord"]:
        """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
        return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)

    @classmethod
    def from_adj_list(
        cls,
        adj_list: Int8[np.ndarray, "conn start_end coord"],
    ) -> "LatticeMaze":
        """create a LatticeMaze from a list of connections

        > [!NOTE]
        > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
        """
        # this is where it would probably break for rectangular mazes
        grid_n: int = adj_list.max() + 1

        connection_list: ConnectionList = np.zeros(
            (2, grid_n, grid_n),
            dtype=np.bool_,
        )

        for c_start, c_end in adj_list:
            # check that exactly 1 coordinate matches
            if (c_start == c_end).sum() != 1:
                raise ValueError("invalid connection")

            # get the direction
            d: int = (c_start != c_end).argmax()

            x: int
            y: int
            # pick whichever has the lesser value in the direction `d`
            if c_start[d] < c_end[d]:
                x, y = c_start
            else:
                x, y = c_end

            connection_list[d, x, y] = True

        return LatticeMaze(
            connection_list=connection_list,
        )

    def as_adj_list_tokens(self) -> list[str | CoordTup]:
        """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
        warnings.warn(
            "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_adj_list_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
        """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples"""
        output: list[CoordTup | str] = self._as_adj_list_tokens()
        # if getattr(self, "start_pos", None) is not None:
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_start_pos_tokens()
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_end_pos_tokens()
        if isinstance(self, SolvedMaze):
            output += self._get_solution_tokens()
        return output

    def _as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode",
    ) -> list[str]:
        # type ignores here fine since we check the instance
        if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
            maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
        if (
            isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
            and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
        ):
            coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
            coords_processed: list[str] = maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
                coords=coords_raw,
                when_noncoord="include",
            )
            return coords_processed
        else:
            err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
            raise NotImplementedError(err_msg)

    def as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> list[str]:
        """serialize maze and solution to tokens"""
        if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
            return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
        else:
            return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

    @classmethod
    def _from_tokens_AOTP(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """create a LatticeMaze from a list of tokens"""
        # figure out what input format
        # ========================================
        if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
            adj_list_tokens = get_adj_list_tokens(tokens)
        else:
            # if we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
            adj_list_tokens = tokens
            warnings.warn(
                "Assuming input is just adjacency list tokens, no special tokens found",
            )

        # process edges for adjacency list
        # ========================================
        edges: list[list[str]] = list_split(
            adj_list_tokens,
            SPECIAL_TOKENS.ADJACENCY_ENDLINE,
        )

        coordinates: list[tuple[CoordTup, CoordTup]] = list()
        for e in edges:
            # skip last endline
            if len(e) != 0:
                # convert to coords, split start and end
                e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
                    e,
                    when_noncoord="include",
                )
                # this assertion depends on the tokenizer having exactly one token for the connector
                # which is also why we "include" above
                # the connector token is discarded below
                assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
                assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                    f"invalid edge: {e = } {e_coords = }"
                )
                e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
                e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
                coordinates.append((e_coords_first, e_coords_last))

        assert all(len(c) == DIM_2 for c in coordinates), (
            f"invalid coordinates: {coordinates = }"
        )
        adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
        assert tuple(adj_list.shape) == (
            len(coordinates),
            2,
            2,
        ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

        output_maze: LatticeMaze = cls.from_adj_list(adj_list)

        # add start and end positions
        # ========================================
        is_targeted: bool = False
        if all(
            x in tokens
            for x in (
                SPECIAL_TOKENS.ORIGIN_START,
                SPECIAL_TOKENS.ORIGIN_END,
                SPECIAL_TOKENS.TARGET_START,
                SPECIAL_TOKENS.TARGET_END,
            )
        ):
            start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_origin_tokens(tokens),
                when_noncoord="error",
            )
            end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_target_tokens(tokens),
                when_noncoord="error",
            )
            assert len(start_pos_list) == 1, (
                f"invalid start_pos_list: {start_pos_list = }"
            )
            assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

            start_pos: CoordTup = start_pos_list[0]
            end_pos: CoordTup = end_pos_list[0]

            output_maze = TargetedLatticeMaze.from_lattice_maze(
                lattice_maze=output_maze,
                start_pos=start_pos,
                end_pos=end_pos,
            )

            is_targeted = True

        if all(
            x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
        ):
            assert is_targeted, "maze must be targeted to have a solution"
            solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_path_tokens(tokens, trim_end=True),
                when_noncoord="error",
            )
            output_maze = SolvedMaze.from_targeted_lattice_maze(
                # HACK: I think this is fine, but im not sure
                targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
                solution=solution,
            )

        return output_maze

    # TODO: any way to get return type hinting working for this?
    @classmethod
    def from_tokens(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """Constructs a maze from a tokenization.

        Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
```
809 """ 810 # HACK: type ignores here fine since we check the instance 811 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 812 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 813 if ( 814 isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular") 815 and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr] 816 ): 817 err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}." 818 raise NotImplementedError( 819 err_msg, 820 ) 821 822 if isinstance(tokens, str): 823 tokens = tokens.split() 824 825 if maze_tokenizer.is_AOTP(): # type: ignore[union-attr] 826 return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type] 827 else: 828 raise NotImplementedError("only AOTP tokenization is supported") 829 830 # ============================================================ 831 # to and from pixels 832 # ============================================================ 833 def _as_pixels_bw(self) -> BinaryPixelGrid: 834 assert self.lattice_dim == DIM_2, "only 2D mazes are supported" 835 # Create an empty pixel grid with walls 836 pixel_grid: Int[np.ndarray, "x y"] = np.full( 837 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1), 838 False, 839 dtype=np.bool_, 840 ) 841 842 # Set white nodes 843 pixel_grid[1::2, 1::2] = True 844 845 # Set white connections (downward) 846 for i, row in enumerate(self.connection_list[0]): 847 for j, connected in enumerate(row): 848 if connected: 849 pixel_grid[i * 2 + 2, j * 2 + 1] = True 850 851 # Set white connections (rightward) 852 for i, row in enumerate(self.connection_list[1]): 853 for j, connected in enumerate(row): 854 if connected: 855 pixel_grid[i * 2 + 1, j * 2 + 2] = True 856 857 return pixel_grid 858 859 def as_pixels( 860 self, 861 show_endpoints: bool = True, 862 show_solution: bool = True, 863 ) -> PixelGrid: 864 """convert the maze to a pixel grid 865 866 - useful as a simpler way of plotting the maze than the more complex `MazePlot` 867 - the same underlying representation as `as_ascii` but as an image 868 - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data 869 """ 870 # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze` 871 # but solution, start_pos, end_pos not always defined 872 # but its fine since we explicitly check the type 873 if show_solution and not show_endpoints: 874 raise ValueError("show_solution=True requires show_endpoints=True") 875 # convert original bool pixel grid to RGB 876 pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw() 877 pixel_grid: PixelGrid = np.full( 878 (*pixel_grid_bw.shape, 3), 879 PixelColors.WALL, 880 dtype=np.uint8, 881 ) 882 pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712 883 884 if self.__class__ == LatticeMaze: 885 return pixel_grid 886 887 # set endpoints for TargetedLatticeMaze 888 if self.__class__ == TargetedLatticeMaze: 889 if show_endpoints: 890 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 891 PixelColors.START 892 ) 893 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 894 PixelColors.END 895 ) 896 return pixel_grid 897 898 # set solution -- we only reach this part if `self.__class__ == SolvedMaze` 899 if show_solution: 900 for coord in self.solution: # type: ignore[attr-defined] 901 pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH 902 
903 # set pixels between coords 904 for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined] 905 next_coord = self.solution[index + 1] # type: ignore[attr-defined] 906 # check they are adjacent using norm 907 assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, ( 908 f"Coords {coord} and {next_coord} are not adjacent" 909 ) 910 # set pixel between them 911 pixel_grid[ 912 coord[0] * 2 + 1 + next_coord[0] - coord[0], 913 coord[1] * 2 + 1 + next_coord[1] - coord[1], 914 ] = PixelColors.PATH 915 916 # set endpoints (again, since path would overwrite them) 917 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 918 PixelColors.START 919 ) 920 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 921 PixelColors.END 922 ) 923 924 return pixel_grid 925 926 @classmethod 927 def _from_pixel_grid_bw( 928 cls, 929 pixel_grid: BinaryPixelGrid, 930 ) -> tuple[ConnectionList, tuple[int, int]]: 931 grid_shape: tuple[int, int] = ( 932 pixel_grid.shape[0] // 2, 933 pixel_grid.shape[1] // 2, 934 ) 935 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_) 936 937 # Extract downward connections 938 connection_list[0] = pixel_grid[2::2, 1::2] 939 940 # Extract rightward connections 941 connection_list[1] = pixel_grid[1::2, 2::2] 942 943 return connection_list, grid_shape 944 945 @classmethod 946 def _from_pixel_grid_with_positions( 947 cls, 948 pixel_grid: PixelGrid | BinaryPixelGrid, 949 marked_positions: dict[str, RGB], 950 ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]: 951 # Convert RGB pixel grid to Bool pixel grid 952 # error: Incompatible types in assignment (expression has type 953 # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]", 954 # variable has type "ndarray[Any, Any]") [assignment] 955 pixel_grid_bw: BinaryPixelGrid = ~np.all( # type: ignore[assignment] 956 pixel_grid == PixelColors.WALL, 957 axis=-1, 958 ) 959 connection_list: ConnectionList 960 grid_shape: tuple[int, int] 961 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw) 962 963 # Find any marked positions 964 out_positions: dict[str, CoordArray] = dict() 965 for key, color in marked_positions.items(): 966 pos_temp: Int[np.ndarray, "x y"] = np.argwhere( 967 np.all(pixel_grid == color, axis=-1), 968 ) 969 pos_save: list[CoordTup] = list() 970 for pos in pos_temp: 971 # if it is a coordinate and not connection (transform position, %2==1) 972 if pos[0] % 2 == 1 and pos[1] % 2 == 1: 973 pos_save.append((pos[0] // 2, pos[1] // 2)) 974 975 out_positions[key] = np.array(pos_save) 976 977 return connection_list, grid_shape, out_positions 978 979 @classmethod 980 def from_pixels( 981 cls, 982 pixel_grid: PixelGrid, 983 ) -> "LatticeMaze": 984 """create a LatticeMaze from a pixel grid. 
reverse of `as_pixels` 985 986 # Raises: 987 - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze` 988 """ 989 connection_list: ConnectionList 990 grid_shape: tuple[int, int] 991 992 # if a binary pixel grid, return regular LatticeMaze 993 if len(pixel_grid.shape) == 2: # noqa: PLR2004 994 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid) 995 return LatticeMaze(connection_list=connection_list) 996 997 # otherwise, detect and check it's valid 998 cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid) 999 if cls not in cls_detected.__mro__: 1000 err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }" 1001 raise ValueError( 1002 err_msg, 1003 ) 1004 1005 ( 1006 connection_list, 1007 grid_shape, 1008 marked_pos, 1009 ) = cls._from_pixel_grid_with_positions( 1010 pixel_grid=pixel_grid, 1011 marked_positions=dict( 1012 start=PixelColors.START, 1013 end=PixelColors.END, 1014 solution=PixelColors.PATH, 1015 ), 1016 ) 1017 # if we wanted a LatticeMaze, return it 1018 if cls == LatticeMaze: 1019 return LatticeMaze(connection_list=connection_list) 1020 1021 # otherwise, keep going 1022 temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list) 1023 1024 # start and end pos 1025 start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"] 1026 assert start_pos_arr.shape == ( 1027 1, 1028 2, 1029 ), ( 1030 f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1031 ) 1032 assert end_pos_arr.shape == ( 1033 1, 1034 2, 1035 ), ( 1036 f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1037 ) 1038 1039 start_pos: Coord = start_pos_arr[0] 1040 end_pos: Coord = end_pos_arr[0] 1041 1042 # return a TargetedLatticeMaze if that's what we wanted 1043 if cls == TargetedLatticeMaze: 1044 return TargetedLatticeMaze( 1045 connection_list=connection_list, 1046 start_pos=start_pos, 1047 end_pos=end_pos, 1048 ) 1049 1050 # raw solution, only contains path elements and not start or end 1051 solution_raw: CoordArray = marked_pos["solution"] 1052 if len(solution_raw.shape) == 2: # noqa: PLR2004 1053 assert solution_raw.shape[1] == 2, ( # noqa: PLR2004 1054 f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)" 1055 ) 1056 elif solution_raw.shape == (0,): 1057 # the solution and end should be immediately adjacent 1058 assert np.sum(np.abs(start_pos - end_pos)) == 1, ( 1059 f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given" 1060 ) 1061 1062 # order the solution, by creating a list from the start to the end 1063 # add end pos, since we will iterate over all these starting from the start pos 1064 solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [ 1065 tuple(end_pos), 1066 ] 1067 # solution starts with start point 1068 solution: list[CoordTup] = [tuple(start_pos)] 1069 while solution[-1] != tuple(end_pos): 1070 # use `get_coord_neighbors` to find connected neighbors 1071 neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1]) 1072 # TODO: make this less ugly 1073 assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), ( # noqa: PT018, PLR2004 1074 f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}" 1075 ) 1076 # neighbors = neighbors[:, [1, 0]] 
1077 # filter out neighbors that are not in the raw solution 1078 neighbors_filtered: CoordArray = np.array( 1079 [ 1080 coord 1081 for coord in neighbors 1082 if ( 1083 tuple(coord) in solution_raw_list 1084 and tuple(coord) not in solution 1085 ) 1086 ], 1087 ) 1088 # assert only one element is left, and then add it to the solution 1089 assert neighbors_filtered.shape == ( 1090 1, 1091 2, 1092 ), ( 1093 f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}" 1094 ) 1095 solution.append(tuple(neighbors_filtered[0])) 1096 1097 # assert the solution is complete 1098 assert solution[0] == tuple(start_pos), ( 1099 f"solution {solution} does not start at start_pos {start_pos}" 1100 ) 1101 assert solution[-1] == tuple(end_pos), ( 1102 f"solution {solution} does not end at end_pos {end_pos}" 1103 ) 1104 1105 return cls( 1106 connection_list=np.array(connection_list), 1107 solution=np.array(solution), # type: ignore[call-arg] 1108 ) 1109 1110 # ============================================================ 1111 # to and from ASCII 1112 # ============================================================ 1113 def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]: 1114 # Get the pixel grid using to_pixels(). 1115 pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw() 1116 1117 # Replace pixel values with ASCII characters. 1118 ascii_grid: Shaped[np.ndarray, "x y"] = np.full( 1119 pixel_grid.shape, 1120 AsciiChars.WALL, 1121 dtype=str, 1122 ) 1123 ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 1124 1125 return ascii_grid 1126 1127 def as_ascii( 1128 self, 1129 show_endpoints: bool = True, 1130 show_solution: bool = True, 1131 ) -> str: 1132 """return an ASCII grid of the maze 1133 1134 useful for debugging in the terminal, or as it's own format 1135 1136 can be reversed with `LatticeMaze.from_ascii()` 1137 """ 1138 ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() 1139 pixel_grid: PixelGrid = self.as_pixels( 1140 show_endpoints=show_endpoints, 1141 show_solution=show_solution, 1142 ) 1143 1144 chars_replace: tuple = tuple() 1145 if show_endpoints: 1146 chars_replace += (AsciiChars.START, AsciiChars.END) 1147 if show_solution: 1148 chars_replace += (AsciiChars.PATH,) 1149 1150 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1151 if ascii_char in chars_replace: 1152 ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char 1153 1154 return "\n".join("".join(row) for row in ascii_grid) 1155 1156 @classmethod 1157 def from_ascii(cls, ascii_str: str) -> "LatticeMaze": 1158 "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" 1159 lines: list[str] = ascii_str.strip().split("\n") 1160 lines = [line.strip() for line in lines] 1161 ascii_grid: Shaped[np.ndarray, "x y"] = np.array( 1162 [list(line) for line in lines], 1163 dtype=str, 1164 ) 1165 pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) 1166 1167 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1168 pixel_grid[ascii_grid == ascii_char] = pixel_color 1169 1170 return cls.from_pixels(pixel_grid)
lattice maze (nodes on a lattice, connections only to neighboring nodes)
Connection List represents which nodes (N) are connected in each direction.
First and second elements represent downward and rightward connections, respectively.

Example: Connection list:

    [
      [ # down
        [F T],
        [F F]
      ],
      [ # right
        [T F],
        [T F]
      ]
    ]
Nodes with connections:

    N T N F
    F   T
    N T N F
    F   F
Graph:

    N - N
        |
    N - N
Note: the bottom row connections going down, and the right-hand connections going right, will always be False.
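As a minimal sketch (assuming `LatticeMaze` is importable from the package root), the example connection list above can be built directly and rendered:

```python
import numpy as np
from maze_dataset import LatticeMaze

# connection list for the example above: axis 0 is "down", axis 1 is "right"
connection_list = np.array(
    [
        [[False, True], [False, False]],  # down: (0,1)-(1,1) connected
        [[True, False], [True, False]],  # right: (0,0)-(0,1) and (1,0)-(1,1) connected
    ]
)
maze = LatticeMaze(connection_list=connection_list)
print(maze.as_ascii())
```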
```python
@property
def grid_n(self) -> int:
    "grid size as int, raises `AssertionError` if not square"
    assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
    return self.grid_shape[0]
```
grid size as int, raises `AssertionError` if not square
```python
@staticmethod
def heuristic(a: CoordTup, b: CoordTup) -> float:
    """return manhattan distance between two points"""
    return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
```
return manhattan distance between two points
```python
def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """returns whether two nodes are connected"""
    delta: Coord = b - a
    if np.abs(delta).sum() != 1:
        # return false if not even adjacent
        return False
    else:
        # test for wall
        dim: int = int(np.argmax(np.abs(delta)))
        clist_node: Coord = a if (delta.sum() > 0) else b
        return self.connection_list[dim, clist_node[0], clist_node[1]]
```
returns whether two nodes are connected
```python
def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """check if a path is valid"""
    # check path is not empty
    if len(path) == 0:
        return empty_is_valid

    # check all coords in bounds of maze
    if not np.all((path >= 0) & (path < self.grid_shape)):
        return False

    # check all nodes connected
    for i in range(len(path) - 1):
        if not self.nodes_connected(path[i], path[i + 1]):
            return False
    return True
```
check if a path is valid
```python
def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the connectivity degree of each coord.

    I.e., how many neighbors each coord has.
    """
    int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
        self.connection_list.astype(np.int8)
    )
    degrees: Int8[np.ndarray, "row col"] = np.sum(
        int_conn,
        axis=0,
    )  # Connections to east and south
    degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
    degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
    return degrees
```
Returns an array with the connectivity degree of each coord.
I.e., how many neighbors each coord has.
```python
def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """Returns an array of the neighboring, connected coords of `c`."""
    c = np.array(c)  # type: ignore[assignment]
    neighbors: list[Coord] = [
        neighbor
        for neighbor in (c + NEIGHBORS_MASK)
        if (
            (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
            and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
            and self.nodes_connected(c, neighbor)  # connected
        )
    ]

    output: CoordArray = np.array(neighbors)
    if len(neighbors) > 0:
        assert output.shape == (
            len(neighbors),
            2,
        ), (
            f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
        )
    return output
```
Returns an array of the neighboring, connected coords of `c`.
```python
def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """return the connected component from a given coordinate"""
    # Stack for DFS
    stack: list[Coord] = [c]

    # Set to store visited nodes
    visited: set[CoordTup] = set()

    while stack:
        current_node: Coord = stack.pop()
        # this is fine since we know current_node is a coord and thus of length 2
        visited.add(tuple(current_node))  # type: ignore[arg-type]

        # Get the neighbors of the current node
        neighbors = self.get_coord_neighbors(current_node)

        # Iterate over neighbors
        for neighbor in neighbors:
            if tuple(neighbor) not in visited:
                stack.append(neighbor)

    return np.array(list(visited))
```
return the connected component from a given coordinate
```python
def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*"""
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    g_score: dict[CoordTup, float] = (
        dict()
    )  # cost of cheapest path to node from start currently known
    f_score: dict[CoordTup, float] = (
        dict()
    )  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

    # init -- note that the heuristic estimate belongs in f_score, not g_score
    g_score[c_start] = 0.0
    f_score[c_start] = self.heuristic(c_start, c_end)

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    # nodes to be evaluated
    # we need a set of the tuples, dont place the ints in the set
    open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
    source: dict[CoordTup, CoordTup] = (
        dict()
    )  # node immediately preceding each node in the path (currently known shortest path)

    while open_vtx:
        # get lowest f_score node
        # mypy cant tell that c is of length 2
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]

        # check if goal is reached
        if c_end == c_current:
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # this is the only return statement
            return np.array(path[::-1])

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)
            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )
```
find the shortest path between two coordinates, using A*
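Continuing the small example maze built above, a usage sketch (the exact path shown assumes the connections given there):

```python
# shortest path between opposite corners of the 2x2 example maze;
# expected to be [[0, 0], [0, 1], [1, 1]] given its connections
path = maze.find_shortest_path((0, 0), (1, 1))
print(path)
```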
```python
def get_nodes(self) -> CoordArray:
    """return a list of all nodes in the maze"""
    rows: Int[np.ndarray, "x y"]
    cols: Int[np.ndarray, "x y"]
    rows, cols = np.meshgrid(
        range(self.grid_shape[0]),
        range(self.grid_shape[1]),
        indexing="ij",
    )
    nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
    return nodes
```
return a list of all nodes in the maze
```python
def get_connected_component(self) -> CoordArray:
    """get the largest (and assumed only nonsingular) connected component of the maze

    TODO: other connected components?
    """
    if (self.generation_meta is None) or (
        self.generation_meta.get("fully_connected", False)
    ):
        # for fully connected case, pick any two positions
        return self.get_nodes()
    else:
        # if metadata provided, use visited cells
        visited_cells: set[CoordTup] | None = self.generation_meta.get(
            "visited_cells",
            None,
        )
        if visited_cells is None:
            # TODO: dynamically generate visited_cells?
            err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
            raise ValueError(
                err_msg,
            )
        visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
        return visited_cells_np
```
get the largest (and assumed only nonsingular) connected component of the maze
TODO: other connected components?
```python
def generate_random_path(  # noqa: C901
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> typing.Optional[CoordArray]:
    """return a path between randomly chosen start and end nodes within the connected component

    Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

    # Parameters:
    - `allowed_start : CoordList | None`
        a list of allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        a list of allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to be a deadend
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to be a deadend
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to ensure that the start and end point are not the same
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise an error if no valid start or end positions are found;
        if this is `False`, the function might return `None` and this must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray`
        a path between the selected start and end positions

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """
    # we can't create a "path" in a single-node maze
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    # get connected component
    connected_component: CoordArray = self.get_connected_component()

    # initialize start and end positions
    positions: Int[np.ndarray, "2 2"]

    # if no special conditions on start and end positions
    if (allowed_start, allowed_end, deadend_start, deadend_end) == (
        None,
        None,
        False,
        False,
    ):
        try:
            positions = connected_component[  # type: ignore[assignment]
                np.random.choice(
                    len(connected_component),
                    size=2,
                    replace=False,
                )
            ]
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # handle special conditions
    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
    # copy connected component set
    allowed_start_set: set[CoordTup] = connected_component_set.copy()
    allowed_end_set: set[CoordTup] = connected_component_set.copy()

    # filter by explicitly allowed start and end positions
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    if allowed_start is not None:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    if allowed_end is not None:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    # filter by forcing deadends
    if deadend_start:
        allowed_start_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_start_set,
            ),
        )

    if deadend_end:
        allowed_end_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_end_set,
            ),
        )

    # check we have valid positions
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            )
        return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
        )
        if endpoints_not_equal:
            # remove start position from end positions
            allowed_end_set.discard(start_pos)
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            ) from e
        return None

    return self.find_shortest_path(start_pos, end_pos)
```
return a path between randomly chosen start and end nodes within the connected component
Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
Parameters:
- `allowed_start : CoordList | None`
    a list of allowed start positions. If `None`, any position in the connected component is allowed (defaults to `None`)
- `allowed_end : CoordList | None`
    a list of allowed end positions. If `None`, any position in the connected component is allowed (defaults to `None`)
- `deadend_start : bool`
    whether to **force** the start position to be a deadend (defaults to `False`)
- `deadend_end : bool`
    whether to **force** the end position to be a deadend (defaults to `False`)
- `endpoints_not_equal : bool`
    whether to ensure that the start and end point are not the same (defaults to `False`)
- `except_on_no_valid_endpoint : bool`
    whether to raise an error if no valid start or end positions are found; if this is `False`, the function might return `None` and this must be handled by the caller (defaults to `True`)

Returns:
- `CoordArray`
    a path between the selected start and end positions

Raises:
- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
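A minimal usage sketch against the example maze built earlier (any maze whose connected component is known will do):

```python
# sample a random path that must start at a deadend; with
# except_on_no_valid_endpoint=False the result may be None
path = maze.generate_random_path(
    deadend_start=True,
    endpoints_not_equal=True,
    except_on_no_valid_endpoint=False,
)
if path is None:
    print("no valid start/end positions found")
```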
```python
def as_adj_list(
    self,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end coord"]:
    """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
    return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
```
return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`
```python
@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """create a LatticeMaze from a list of connections

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
    """
    # this is where it would probably break for rectangular mazes
    grid_n: int = adj_list.max() + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for c_start, c_end in adj_list:
        # check that exactly 1 coordinate matches
        if (c_start == c_end).sum() != 1:
            raise ValueError("invalid connection")

        # get the direction
        d: int = (c_start != c_end).argmax()

        x: int
        y: int
        # pick whichever has the lesser value in the direction `d`
        if c_start[d] < c_end[d]:
            x, y = c_start
        else:
            x, y = c_end

        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )
```
create a LatticeMaze from a list of connections
This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
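A short sketch of the expected input shape, building the same 2x2 example maze from its edges:

```python
import numpy as np
from maze_dataset import LatticeMaze

# three connections on a 2x2 grid; shape is (conn, start_end, coord) = (3, 2, 2)
adj_list = np.array(
    [
        [[0, 0], [0, 1]],  # rightward connection
        [[0, 1], [1, 1]],  # downward connection
        [[1, 0], [1, 1]],  # rightward connection
    ],
    dtype=np.int8,
)
maze = LatticeMaze.from_adj_list(adj_list)
```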
```python
def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return [
        SPECIAL_TOKENS.ADJLIST_START,
        *chain.from_iterable(  # type: ignore[list-item]
            [
                [
                    tuple(c_s),
                    SPECIAL_TOKENS.CONNECTOR,
                    tuple(c_e),
                    SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                ]
                for c_s, c_e in self.as_adj_list()
            ],
        ),
        SPECIAL_TOKENS.ADJLIST_END,
    ]
```
(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead
```python
def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """serialize maze and solution to tokens"""
    if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
    else:
        return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
```
serialize maze and solution to tokens
```python
@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Constructs a maze from a tokenization.

    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
    """
    # HACK: type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
    if (
        isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
        and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
    ):
        err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        raise NotImplementedError(
            err_msg,
        )

    if isinstance(tokens, str):
        tokens = tokens.split()

    if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
    else:
        raise NotImplementedError("only AOTP tokenization is supported")
```
Constructs a maze from a tokenization.
Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
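A round-trip sketch, assuming `m` is a `SolvedMaze` (e.g. an element of a dataset) and that the legacy `MazeTokenizer` is importable from `maze_dataset.tokenization`:

```python
from maze_dataset import LatticeMaze
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

tokenizer = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=10,
)
tokens = m.as_tokens(tokenizer)
# the returned type depends on which special tokens are present;
# a full AOTP tokenization should round-trip to a SolvedMaze
m_recovered = LatticeMaze.from_tokens(tokens, tokenizer)
```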
```python
def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """convert the maze to a pixel grid

    - useful as a simpler way of plotting the maze than the more complex `MazePlot`
    - the same underlying representation as `as_ascii` but as an image
    - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")
    # convert original bool pixel grid to RGB
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3),
        PixelColors.WALL,
        dtype=np.uint8,
    )
    pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set endpoints for TargetedLatticeMaze
    if self.__class__ == TargetedLatticeMaze:
        if show_endpoints:
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )
        return pixel_grid

    # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
    if show_solution:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
            next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # set pixel between them
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

        # set endpoints (again, since path would overwrite them)
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    return pixel_grid
```
convert the maze to a pixel grid

- useful as a simpler way of plotting the maze than the more complex `MazePlot`
- the same underlying representation as `as_ascii` but as an image
- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
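A display sketch, assuming `m` is a `SolvedMaze` and matplotlib is available:

```python
import matplotlib.pyplot as plt

pixels = m.as_pixels(show_endpoints=True, show_solution=True)  # (x, y, 3) uint8 RGB
plt.imshow(pixels)
plt.axis("off")
plt.show()
```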
```python
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    # Raises:
    - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(
            err_msg,
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START,
            end=PixelColors.END,
            solution=PixelColors.PATH,
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the solution and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
        tuple(end_pos),
    ]
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    while solution[-1] != tuple(end_pos):
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # filter out neighbors that are not in the raw solution
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (
                    tuple(coord) in solution_raw_list
                    and tuple(coord) not in solution
                )
            ],
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        solution.append(tuple(neighbors_filtered[0]))

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == tuple(end_pos), (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
```
create a LatticeMaze from a pixel grid. reverse of `as_pixels`

Raises:
- `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
```python
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    can be reversed with `LatticeMaze.from_ascii()`
    """
    ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    pixel_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    chars_replace: tuple = tuple()
    if show_endpoints:
        chars_replace += (AsciiChars.START, AsciiChars.END)
    if show_solution:
        chars_replace += (AsciiChars.PATH,)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char in chars_replace:
            ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char

    return "\n".join("".join(row) for row in ascii_grid)
```
return an ASCII grid of the maze
useful for debugging in the terminal, or as its own format

can be reversed with `LatticeMaze.from_ascii()`
```python
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    "get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)"
    lines: list[str] = ascii_str.strip().split("\n")
    lines = [line.strip() for line in lines]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
        [list(line) for line in lines],
        dtype=str,
    )
    pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        pixel_grid[ascii_grid == ascii_char] = pixel_color

    return cls.from_pixels(pixel_grid)
```
get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)
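A round-trip sketch, assuming `m` is an existing `SolvedMaze`:

```python
from maze_dataset import SolvedMaze

# render to ASCII, then parse it back into a SolvedMaze
ascii_str = m.as_ascii()
m_recovered = SolvedMaze.from_ascii(ascii_str)
```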
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
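A round-trip sketch, assuming `m` is a `SolvedMaze` (any `SerializableDataclass` in the package works the same way):

```python
from maze_dataset import SolvedMaze

# round-trip a maze through a plain, JSON-serializable dict
data = m.serialize()
m_recovered = SolvedMaze.load(data)
assert m_recovered == m  # field-wise equality should hold after the round-trip
```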
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
    - validate_field_type
    - diff
    - update_from_nested_dict
```python
def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global SERIALIZE_MINIMAL_THRESHOLD"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
```
set the global `SERIALIZE_MINIMAL_THRESHOLD`
```python
def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
    """register a maze filter, casting it to operate over the whole list of mazes

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            ),
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
```
register a maze filter, casting it to operate over the whole list of mazes

method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
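A hypothetical sketch of a custom filter, assuming `dataset` is an existing `MazeDataset` and that `register_maze_filter` is importable from the dataset filters module (the import path below is an assumption, check your installed version):

```python
from maze_dataset import SolvedMaze
# NOTE: import path assumed here
from maze_dataset.dataset.filters import register_maze_filter

@register_maze_filter
def path_at_least(maze: SolvedMaze, min_length: int) -> bool:
    # hypothetical filter: keep mazes whose solution has at least `min_length` coords
    return len(maze.solution) >= min_length

# the decorator lifts the per-maze predicate to operate over (and return) a whole dataset
filtered = path_at_least(dataset, min_length=5)
```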
`class LatticeMazeGenerators`: namespace for lattice maze generation algorithms.

Examples of generated mazes can be found here: https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
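These generators can also be called directly, outside of the `MazeDataset` pipeline. A minimal sketch (the `connection_list` shape convention is taken from the source below):

```python
from maze_dataset.generation import LatticeMazeGenerators

# generate a single 5x5 maze without constructing a dataset
maze = LatticeMazeGenerators.gen_dfs(grid_shape=(5, 5))

# the maze is stored as a boolean connection list of shape (lattice_dim, rows, cols):
# index 0 holds downward connections, index 1 holds rightward connections
print(maze.connection_list.shape)  # (2, 5, 5)
```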
```python
@staticmethod
def gen_dfs(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: int | float | None = None,
    max_tree_depth: int | float | None = None,
    do_forks: bool = True,
    randomized_stack: bool = False,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using depth first search, iterative

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `lattice_dim: int`: the dimension of the lattice
      (default: `2`)
    - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. If a float, it must be <= 1 and is treated as a proportion of **total cells**
      (default: `None`)
    - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. If a float, it must be <= 1 and is treated as a proportion of the **sum of the grid shape**
      (default: `None`)
    - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
    - `randomized_stack: bool`: whether to pop a random cell from the stack instead of the most recent one
      (default: `False`)
    - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

    # Algorithm
    1. Choose the initial cell, mark it as visited and push it to the stack
    2. While the stack is not empty
        1. Pop a cell from the stack and make it a current cell
        2. If the current cell has any neighbours which have not been visited
            1. Push the current cell to the stack
            2. Choose one of the unvisited neighbours
            3. Remove the wall between the current cell and the chosen cell
            4. Mark the chosen cell as visited and push it to the stack
    """
    # default values if no constraints have been passed
    grid_shape_: Coord = np.array(grid_shape)
    n_total_cells: int = int(np.prod(grid_shape_))

    n_accessible_cells: int
    if accessible_cells is None:
        n_accessible_cells = n_total_cells
    elif isinstance(accessible_cells, float):
        assert accessible_cells <= 1, (
            f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
        )
        n_accessible_cells = int(accessible_cells * n_total_cells)
    else:
        assert isinstance(accessible_cells, int)
        n_accessible_cells = accessible_cells

    if max_tree_depth is None:
        # max tree depth is counted from the start coord in two directions,
        # so we multiply by two here and divide by two in the neighbor check below
        max_tree_depth = 2 * n_total_cells
    elif isinstance(max_tree_depth, float):
        assert max_tree_depth <= 1, (
            f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
        )
        max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

    # choose a random start coord
    start_coord = _random_start_coord(grid_shape_, start_coord)

    # initialize the maze with no connections
    connection_list: ConnectionList = np.zeros(
        (lattice_dim, grid_shape_[0], grid_shape_[1]),
        dtype=np.bool_,
    )

    # initialize the stack with the start coord, marking it as visited
    visited_cells: set[tuple[int, int]] = set()
    visited_cells.add(tuple(start_coord))
    stack: list[Coord] = [start_coord]

    # initialize tree_depth_counter
    current_tree_depth: int = 1

    # loop until the stack is empty or n_accessible_cells is reached
    while stack and (len(visited_cells) < n_accessible_cells):
        # get the current coord from the stack
        current_coord: Coord
        if randomized_stack:
            current_coord = stack.pop(random.randint(0, len(stack) - 1))
        else:
            current_coord = stack.pop()

        # filter neighbors by being within grid bounds and being unvisited
        unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
            (neighbor, delta)
            for neighbor, delta in zip(
                current_coord + NEIGHBORS_MASK,
                NEIGHBORS_MASK,
                strict=False,
            )
            if (
                (tuple(neighbor) not in visited_cells)
                and (0 <= neighbor[0] < grid_shape_[0])
                and (0 <= neighbor[1] < grid_shape_[1])
            )
        ]

        # don't continue if max_tree_depth/2 is already reached
        # (divide by 2 because we can branch in multiple directions)
        if unvisited_neighbors_deltas and (
            current_tree_depth <= max_tree_depth / 2
        ):
            # if we want a maze without forks, simply don't add the current coord back to the stack
            if do_forks and (len(unvisited_neighbors_deltas) > 1):
                stack.append(current_coord)

            # choose one of the unvisited neighbors
            chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

            # add connection
            dim: int = int(np.argmax(np.abs(delta)))
            # if positive, down/right from current coord
            # if negative, up/left from current coord (down/right from neighbor)
            clist_node: Coord = (
                current_coord if (delta.sum() > 0) else chosen_neighbor
            )
            connection_list[dim, clist_node[0], clist_node[1]] = True

            # add to visited cells and stack
            visited_cells.add(tuple(chosen_neighbor))
            stack.append(chosen_neighbor)

            # update current tree depth
            current_tree_depth += 1
        else:
            current_tree_depth -= 1

    return LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_dfs",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            n_accessible_cells=int(n_accessible_cells),
            max_tree_depth=int(max_tree_depth),
            # NOTE: this previously checked `len(visited_cells) == n_accessible_cells`,
            # which treated partially generated mazes as fully connected and broke maze solving
            fully_connected=bool(len(visited_cells) == n_total_cells),
            visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
        ),
    )
```
generate a lattice maze using depth first search, iterative

Arguments

- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. If a float, it must be <= 1 and is treated as a proportion of **total cells** (default: `None`)
- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. If a float, it must be <= 1 and is treated as a proportion of the **sum of the grid shape** (default: `None`)
- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
- `randomized_stack: bool`: whether to pop a random cell from the stack instead of the most recent one (default: `False`)
- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

Algorithm

1. Choose the initial cell, mark it as visited, and push it to the stack.
2. While the stack is not empty:
    1. Pop a cell from the stack and make it the current cell.
    2. If the current cell has any neighbours which have not been visited:
        1. Push the current cell to the stack.
        2. Choose one of the unvisited neighbours.
        3. Remove the wall between the current cell and the chosen cell.
        4. Mark the chosen cell as visited and push it to the stack.
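To make the constraint arguments concrete, here is a short usage sketch (the parameter values are illustrative):

```python
import numpy as np

from maze_dataset.generation import LatticeMazeGenerators

# only half of the cells are reachable: a float accessible_cells is a proportion
maze_partial = LatticeMazeGenerators.gen_dfs(
    grid_shape=(10, 10),
    accessible_cells=0.5,
)

# a forkless maze: one winding hallway starting from a fixed coordinate
maze_hallway = LatticeMazeGenerators.gen_dfs(
    grid_shape=(10, 10),
    do_forks=False,
    start_coord=np.array([0, 0]),
)
```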
```python
@staticmethod
def gen_prim(
    grid_shape: Coord | CoordTup,
    lattice_dim: int = 2,
    accessible_cells: int | float | None = None,
    max_tree_depth: int | float | None = None,
    do_forks: bool = True,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    "(broken!) generate a lattice maze using Prim's algorithm"
    warnings.warn(
        "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
    )
    return LatticeMazeGenerators.gen_dfs(
        grid_shape=grid_shape,
        lattice_dim=lattice_dim,
        accessible_cells=accessible_cells,
        max_tree_depth=max_tree_depth,
        do_forks=do_forks,
        start_coord=start_coord,
        randomized_stack=True,
    )
```
(broken!) generate a lattice maze using Prim's algorithm
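Since the source above shows `gen_prim` delegating to `gen_dfs` with `randomized_stack=True`, the two calls below draw from the same distribution; the only difference is the warning:

```python
import warnings

from maze_dataset.generation import LatticeMazeGenerators

# gen_prim warns that it is not a correct implementation of Prim's algorithm
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    maze_a = LatticeMazeGenerators.gen_prim((5, 5))

# equivalent generation without the warning
maze_b = LatticeMazeGenerators.gen_dfs((5, 5), randomized_stack=True)
```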
```python
@staticmethod
def gen_wilson(
    grid_shape: Coord | CoordTup,
    **kwargs,
) -> LatticeMaze:
    """Generate a lattice maze using Wilson's algorithm.

    # Algorithm
    Wilson's algorithm generates an unbiased (random) maze
    sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
    acyclic and all cells are part of a unique connected space.
    https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
    """
    assert not kwargs, (
        f"gen_wilson does not take any additional arguments, got {kwargs = }"
    )

    grid_shape_: Coord = np.array(grid_shape)

    # initialize grid and visited cells
    connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
    visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

    # choose a random cell and mark it as visited
    start_coord: Coord = _random_start_coord(grid_shape_, None)
    visited[start_coord[0], start_coord[1]] = True
    del start_coord

    while not visited.all():
        # perform a loop-erased random walk from another random cell

        # choose walk_start only from unvisited cells
        unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
        walk_start: Coord = unvisited_coords[
            np.random.choice(unvisited_coords.shape[0])
        ]

        # perform the random walk
        path: list[Coord] = [walk_start]
        current: Coord = walk_start

        # exit the loop once the current path hits a visited cell
        while not visited[current[0], current[1]]:
            # find a valid neighbor (one always exists on a lattice)
            neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
            next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

            # check for a loop
            loop_exit: int | None = None
            for i, p in enumerate(path):
                if np.array_equal(next_cell, p):
                    loop_exit = i
                    break

            # erase the loop, or continue the walk
            if loop_exit is not None:
                # this removes everything after and including the loop start
                path = path[: loop_exit + 1]
                # reset current cell to the end of the path
                current = path[-1]
            else:
                path.append(next_cell)
                current = next_cell

        # add the path to the maze
        for i in range(len(path) - 1):
            c_1: Coord = path[i]
            c_2: Coord = path[i + 1]

            # find the dimension of the connection
            delta: Coord = c_2 - c_1
            dim: int = int(np.argmax(np.abs(delta)))

            # if positive, down/right from current coord
            # if negative, up/left from current coord (down/right from neighbor)
            clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
            connection_list[dim, clist_node[0], clist_node[1]] = True
            visited[c_1[0], c_1[1]] = True
            # we don't add c_2, since the last c_2 will have already been visited

    return LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_wilson",
            grid_shape=grid_shape_,
            fully_connected=True,
        ),
    )
```
Generate a lattice maze using Wilson's algorithm.
Algorithm
Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
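A minimal usage sketch; as the assertion in the source enforces, `gen_wilson` takes no options beyond the grid shape:

```python
from maze_dataset.generation import LatticeMazeGenerators

# uniformly sampled 8x8 maze
maze = LatticeMazeGenerators.gen_wilson((8, 8))

# Wilson's algorithm always yields a spanning tree, so the maze is fully connected
assert maze.generation_meta["fully_connected"]
```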
```python
@staticmethod
def gen_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """generate a lattice maze using simple percolation

    note that p in the range (0.4, 0.7) gives the most interesting mazes

    # Arguments
    - `grid_shape: Coord`: the shape of the grid
    - `lattice_dim: int`: the dimension of the lattice (default: `2`)
    - `p: float`: the probability of a connection being open (default: `0.4`)
    - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
    """
    assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
    grid_shape_: Coord = np.array(grid_shape)

    start_coord = _random_start_coord(grid_shape_, start_coord)

    # each possible connection is open independently with probability p
    connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

    connection_list = _fill_edges_with_walls(connection_list)

    output: LatticeMaze = LatticeMaze(
        connection_list=connection_list,
        generation_meta=dict(
            func_name="gen_percolation",
            grid_shape=grid_shape_,
            percolation_p=p,
            start_coord=start_coord,
        ),
    )

    # generation_meta is typed as optional, but it is always a dict here since we just constructed it
    output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
        start_coord,
    )

    return output
```
generate a lattice maze using simple percolation

note that `p` in the range (0.4, 0.7) gives the most interesting mazes

Arguments

- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `p: float`: the probability of a connection being open (default: `0.4`)
- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
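A short sketch of how `p` affects the output (the specific values are illustrative, chosen from the suggested range):

```python
from maze_dataset.generation import LatticeMazeGenerators

# sparser maze: fewer open connections, smaller connected component
maze_sparse = LatticeMazeGenerators.gen_percolation((10, 10), p=0.4)

# denser maze: most cells reachable, with many cycles
maze_dense = LatticeMazeGenerators.gen_percolation((10, 10), p=0.7)

# the connected component reachable from start_coord is stored in the metadata
print(len(maze_dense.generation_meta["visited_cells"]))
```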
```python
@staticmethod
def gen_dfs_percolation(
    grid_shape: Coord | CoordTup,
    p: float = 0.4,
    lattice_dim: int = 2,
    accessible_cells: int | None = None,
    max_tree_depth: int | None = None,
    start_coord: Coord | None = None,
) -> LatticeMaze:
    """dfs and then percolation (adds cycles)"""
    grid_shape_: Coord = np.array(grid_shape)
    start_coord = _random_start_coord(grid_shape_, start_coord)

    # generate initial maze via dfs
    maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
        grid_shape=grid_shape_,
        lattice_dim=lattice_dim,
        accessible_cells=accessible_cells,
        max_tree_depth=max_tree_depth,
        start_coord=start_coord,
    )

    # percolate: union extra random connections onto the dfs spanning tree
    connection_list_perc: np.ndarray = (
        np.random.rand(*maze.connection_list.shape) < p
    )
    connection_list_perc = _fill_edges_with_walls(connection_list_perc)

    maze.__dict__["connection_list"] = np.logical_or(
        maze.connection_list,
        connection_list_perc,
    )

    # generation_meta is typed as optional, but gen_dfs always sets it to a dict
    maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
    maze.generation_meta["percolation_p"] = p  # type: ignore[index]
    maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
        start_coord,
    )

    return maze
```
dfs and then percolation (adds cycles)
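A usage sketch; here `p` controls how many extra connections, and hence cycles, are unioned onto the DFS spanning tree (the value is illustrative):

```python
from maze_dataset.generation import LatticeMazeGenerators

# DFS maze with percolation connections OR-ed on top, introducing cycles
maze = LatticeMazeGenerators.gen_dfs_percolation(
    grid_shape=(10, 10),
    p=0.2,
)
print(maze.generation_meta["func_name"])  # "gen_dfs_percolation"
```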
````python
@staticmethod
def gen_kruskal(
    grid_shape: "Coord | CoordTup",
    lattice_dim: int = 2,
    start_coord: "Coord | None" = None,
) -> "LatticeMaze":
    """Generate a maze using Kruskal's algorithm.

    This function generates a random spanning tree over a grid using Kruskal's algorithm.
    Each cell is treated as a node, and all valid adjacent edges are listed and processed
    in random order. An edge is added (i.e. its passage carved) only if it connects two cells
    that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
    without cycles.

    https://en.wikipedia.org/wiki/Kruskal's_algorithm

    # Parameters:
    - `grid_shape : Coord | CoordTup`
        The shape of the maze grid (for example, `(n_rows, n_cols)`).
    - `lattice_dim : int`
        The lattice dimension (default is `2`).
    - `start_coord : Coord | None`
        Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.

    # Returns:
    - `LatticeMaze`
        A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

    # Usage:
    ```python
    maze = LatticeMazeGenerators.gen_kruskal((10, 10))
    ```
    """
    assert lattice_dim == 2, (  # noqa: PLR2004
        "Kruskal's algorithm is only implemented for 2D lattices."
    )
    # convert grid_shape to a tuple of ints
    grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
    n_rows, n_cols = grid_shape_

    # initialize union-find data structure
    parent: dict[tuple[int, int], tuple[int, int]] = {}

    def find(cell: tuple[int, int]) -> tuple[int, int]:
        # find the root of the cell's set, compressing the path as we go
        while parent[cell] != cell:
            parent[cell] = parent[parent[cell]]
            cell = parent[cell]
        return cell

    def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
        root1 = find(cell1)
        root2 = find(cell2)
        parent[root2] = root1

    # initialize each cell as its own set
    for i in range(n_rows):
        for j in range(n_cols):
            parent[(i, j)] = (i, j)

    # list all possible edges
    edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
    # rightward edges (connecting a cell to its right neighbor), direction 1:
    for i in range(n_rows):
        for j in range(n_cols - 1):
            edges.append(((i, j), (i, j + 1), 1))
    # downward edges (connecting a cell to its bottom neighbor), direction 0:
    for i in range(n_rows - 1):
        for j in range(n_cols):
            edges.append(((i, j), (i + 1, j), 0))

    # shuffle the list of edges
    random.shuffle(edges)

    # initialize connection_list with no connections
    # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j))
    # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1))
    connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

    # process each edge; if it connects two different trees, union them and carve the passage
    for cell1, cell2, direction in edges:
        if find(cell1) != find(cell2):
            union(cell1, cell2)
            if direction == 0:
                # downward edge: connection is stored in connection_list[0] at cell1
                connection_list[0, cell1[0], cell1[1]] = True
            else:
                # rightward edge: connection is stored in connection_list[1] at cell1
                connection_list[1, cell1[0], cell1[1]] = True

    if start_coord is None:
        start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

    generation_meta: dict = dict(
        func_name="gen_kruskal",
        grid_shape=grid_shape_,
        start_coord=start_coord,
        algorithm="kruskal",
        fully_connected=True,
    )
    return LatticeMaze(
        connection_list=connection_list, generation_meta=generation_meta
    )
````
Generate a maze using Kruskal's algorithm.
This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.
https://en.wikipedia.org/wiki/Kruskal's_algorithm
Parameters:

- `grid_shape : Coord | CoordTup` The shape of the maze grid (for example, `(n_rows, n_cols)`).
- `lattice_dim : int` The lattice dimension (default is `2`).
- `start_coord : Coord | None` Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.

Returns:

- `LatticeMaze` A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

Usage:

```python
maze = LatticeMazeGenerators.gen_kruskal((10, 10))
```
````python
@staticmethod
def gen_recursive_division(
    grid_shape: "Coord | CoordTup",
    lattice_dim: int = 2,
    start_coord: "Coord | None" = None,
) -> "LatticeMaze":
    """Generate a maze using the recursive division algorithm.

    This function generates a maze by recursively dividing the grid with walls and carving a single
    passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
    cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage.
    The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

    # Parameters:
    - `grid_shape : Coord | CoordTup`
        The shape of the maze grid (e.g., `(n_rows, n_cols)`).
    - `lattice_dim : int`
        The lattice dimension (default is `2`).
    - `start_coord : Coord | None`
        Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.

    # Returns:
    - `LatticeMaze`
        A maze represented by a connection list, generated using recursive division.

    # Usage:
    ```python
    maze = LatticeMazeGenerators.gen_recursive_division((10, 10))
    ```
    """
    assert lattice_dim == 2, (  # noqa: PLR2004
        "Recursive division algorithm is only implemented for 2D lattices."
    )
    # convert grid_shape to a tuple of ints
    grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
    n_rows, n_cols = grid_shape_

    # initialize connection_list as a fully connected grid
    # downward connections: set for each cell (i,j) with i in [0, n_rows-2]
    # rightward connections: set for each cell (i,j) with j in [0, n_cols-2]
    connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
    connection_list[0, : n_rows - 1, :] = True
    connection_list[1, :, : n_cols - 1] = True

    def divide(x: int, y: int, width: int, height: int) -> None:
        """Recursively divide the region starting at (x, y) with the given width and height.

        Removes connections along the chosen division line except for one randomly chosen gap.
        """
        if width < 2 or height < 2:  # noqa: PLR2004
            return

        if width > height:
            # vertical division
            wall_col = random.randint(x + 1, x + width - 1)
            gap_row = random.randint(y, y + height - 1)
            for row in range(y, y + height):
                if row == gap_row:
                    continue
                # remove the rightward connection between (row, wall_col-1) and (row, wall_col)
                if wall_col - 1 < n_cols - 1:
                    connection_list[1, row, wall_col - 1] = False
            # recurse on the left and right subregions
            divide(x, y, wall_col - x, height)
            divide(wall_col, y, x + width - wall_col, height)
        else:
            # horizontal division
            wall_row = random.randint(y + 1, y + height - 1)
            gap_col = random.randint(x, x + width - 1)
            for col in range(x, x + width):
                if col == gap_col:
                    continue
                # remove the downward connection between (wall_row-1, col) and (wall_row, col)
                if wall_row - 1 < n_rows - 1:
                    connection_list[0, wall_row - 1, col] = False
            # recurse on the top and bottom subregions
            divide(x, y, width, wall_row - y)
            divide(x, wall_row, width, y + height - wall_row)

    # begin the division on the full grid
    divide(0, 0, n_cols, n_rows)

    if start_coord is None:
        start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

    generation_meta: dict = dict(
        func_name="gen_recursive_division",
        grid_shape=grid_shape_,
        start_coord=start_coord,
        algorithm="recursive_division",
        fully_connected=True,
    )
    return LatticeMaze(
        connection_list=connection_list, generation_meta=generation_meta
    )
````
Generate a maze using the recursive division algorithm.
This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

Parameters:

- `grid_shape : Coord | CoordTup` The shape of the maze grid (e.g., `(n_rows, n_cols)`).
- `lattice_dim : int` The lattice dimension (default is `2`).
- `start_coord : Coord | None` Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.

Returns:

- `LatticeMaze` A maze represented by a connection list, generated using recursive division.

Usage:

```python
maze = LatticeMazeGenerators.gen_recursive_division((10, 10))
```