maze_dataset
maze-dataset
This package provides utilities for generation, filtering, solving, visualizing, and processing of mazes for training or evaluating ML systems. Primarily built for the maze-transformer interpretability project. You can find our paper on it here: http://arxiv.org/abs/2309.10498
This package includes a variety of maze generation algorithms, including randomized depth first search, Wilson's algorithm for uniform spanning trees, and percolation. Datasets can be filtered to select mazes of a certain length or complexity, remove duplicates, and satisfy custom properties. A variety of output formats for visualization and training ML models are provided.
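As a rough illustration of the first of these algorithms, the sketch below carves a maze with randomized depth-first search using only the standard library. The function name `gen_dfs_maze` and the adjacency-dict representation are assumptions made for this example; this is not the package's `LatticeMazeGenerators.gen_dfs` implementation or its connection list format.

```python
import random

Coord = tuple[int, int]


def gen_dfs_maze(grid_n: int, seed: int = 0) -> dict[Coord, set[Coord]]:
    """Carve a spanning tree over a grid_n x grid_n lattice with randomized DFS.

    Hypothetical helper: returns an adjacency dict mapping each cell to the
    set of neighbouring cells it is connected to (no wall in between).
    """
    rng = random.Random(seed)
    adjacency: dict[Coord, set[Coord]] = {
        (r, c): set() for r in range(grid_n) for c in range(grid_n)
    }
    start: Coord = (rng.randrange(grid_n), rng.randrange(grid_n))
    visited: set[Coord] = {start}
    stack: list[Coord] = [start]
    while stack:
        r, c = stack[-1]
        # lattice neighbours that have not been visited yet
        unvisited = [
            (r + dr, c + dc)
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1))
            if (r + dr, c + dc) in adjacency and (r + dr, c + dc) not in visited
        ]
        if not unvisited:
            stack.pop()  # dead end: backtrack
            continue
        nxt = rng.choice(unvisited)
        # remove the wall between the current cell and the chosen neighbour
        adjacency[(r, c)].add(nxt)
        adjacency[nxt].add((r, c))
        visited.add(nxt)
        stack.append(nxt)
    return adjacency


maze = gen_dfs_maze(5)
n_edges = sum(len(nbrs) for nbrs in maze.values()) // 2
print(f"{len(maze)} cells, {n_edges} connections")
```

Because every cell is visited exactly once and each visit adds one connection, the result is a spanning tree: a maze with exactly one path between any two cells.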
You can view and search through a wide variety of example mazes here: understanding-search.github.io/maze-dataset/examples/maze_examples
Citing
If you use this code in your research, please cite our paper:
@misc{maze-dataset,
title={A Configurable Library for Generating and Manipulating Maze Datasets},
author={Michael Igorevich Ivanitskiy and Rusheb Shah and Alex F. Spies and Tilman Räuker and Dan Valentine and Can Rager and Lucia Quirke and Chris Mathwin and Guillaume Corlouer and Cecilia Diniz Behn and Samy Wu Fung},
year={2023},
eprint={2309.10498},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={http://arxiv.org/abs/2309.10498}
}
Installation
This package is available on PyPI, and can be installed via
pip install maze-dataset
Please note that due to an issue with the rust-fst package, some tokenization features are not available on macOS. Please see #57.
Docs
The full hosted documentation is available at https://understanding-search.github.io/maze-dataset/.
Additionally, our notebooks serve as a good starting point for understanding the package.
Usage
Creating a dataset
To create a MazeDataset, you first create a MazeDatasetConfig:
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="test",  # name is only for you to keep track of things
    grid_n=5,  # number of rows/columns in the lattice
    n_mazes=4,  # number of mazes to generate
    maze_ctor=LatticeMazeGenerators.gen_dfs,  # algorithm to generate the maze
    maze_ctor_kwargs=dict(do_forks=False),  # additional parameters to pass to the maze generation algorithm
)
and then pass this config to the MazeDataset.from_config method:
dataset: MazeDataset = MazeDataset.from_config(cfg)
This method can search for whether a dataset with matching config hash already exists on your filesystem in the expected location, and load it if so. It can also generate a dataset on the fly if needed.
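The load-or-generate behavior described above can be sketched in plain Python. This is not the package's implementation: `config_hash`, `load_or_generate`, `fake_generate`, and the JSON file layout are hypothetical names and choices, used only to show how a hash of the config can decide between loading a cached file and generating new data.

```python
import hashlib
import json
import tempfile
from pathlib import Path


def config_hash(cfg: dict) -> str:
    """Stable hash of a JSON-serializable config (hypothetical helper)."""
    blob = json.dumps(cfg, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:12]


def load_or_generate(cfg: dict, base_path: Path, generate) -> tuple[list, bool]:
    """Return (data, was_loaded), reusing a cached file whose name matches the config hash."""
    path = base_path / f"{cfg['name']}-{config_hash(cfg)}.json"
    if path.exists():
        # a dataset with a matching config hash is already on disk
        return json.loads(path.read_text()), True
    # otherwise generate on the fly and save for next time
    data = generate(cfg)
    base_path.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(data))
    return data, False


def fake_generate(cfg: dict) -> list:
    """Stand-in for maze generation: one placeholder record per requested maze."""
    return [[i, cfg["grid_n"]] for i in range(cfg["n_mazes"])]


cfg = {"name": "test", "grid_n": 5, "n_mazes": 4}
with tempfile.TemporaryDirectory() as tmp:
    first, loaded_first = load_or_generate(cfg, Path(tmp), fake_generate)
    second, loaded_second = load_or_generate(cfg, Path(tmp), fake_generate)

print(loaded_first, loaded_second)
```

The first call generates and saves, the second finds the file by its hash and loads it. Changing any config value changes the hash, so a stale file is never picked up by mistake.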
Conversions to useful formats
The elements of the dataset are SolvedMaze objects:
>>> m = dataset[0]
>>> type(m)
SolvedMaze
These can be converted to a variety of formats:
# visual representation as ascii art
print(m.as_ascii())
# RGB image, optionally without solution or endpoints, suitable for CNNs
import matplotlib.pyplot as plt
plt.imshow(m.as_pixels())
# text format for autoregressive transformers
from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode, PromptSequencers
m.as_tokens(maze_tokenizer=MazeTokenizerModular(
    prompt_sequencer=PromptSequencers.AOTP(),  # many options here
))
# advanced visualization with many features
from maze_dataset.plotting import MazePlot
MazePlot(m).plot()
Development
We use this makefile template with slight modifications for our development workflow. This project uses uv for dependency and virtual environment management.
- clone with git clone https://github.com/understanding-search/maze-dataset
- if you don't already have uv, install it. We only guarantee compatibility with uv newer than 0.8.0
- make dep to install all dependencies
- make help will print all available commands
- make test will run basic tests to ensure the package is working
  - run just the unit tests with make test-unit
  - see all tests with explanations using make help or make help | grep test
- make format will run ruff to format and check the code
Note: due to compatibility issues between the rust_fst package and Darwin/macOS systems, not all tests will pass on these systems. However, make test-unit and make test-notebooks-muutils should still pass. Please see #57 for updates on resolving this problem.
Contributing
We welcome contributions! We use GitHub issues to track bugs and feature requests. If you have a bug fix or a new feature to contribute, please open a pull request. We are also happy to provide usage support and answer questions about the package via issues!
While we expect that the core interface of the package is stable, we are very open to adding new features. We're particularly excited about adding new maze generation algorithms and new output formats. Please feel free to both suggest new formats or algorithms, and to implement them and open PRs! For more info on how to add a new maze generation algorithm, see the documentation on generators.
We are also aware that, like any piece of software, maze-dataset is not without bugs. If something isn't working as expected, please open an issue and we will do our best to fix it. It helps us keep things tidy if you first search existing bug reports to see if your issue has already been reported.
""".. include:: ../README.md"""

from maze_dataset.constants import (
    SPECIAL_TOKENS,
    VOCAB,
    VOCAB_LIST,
    VOCAB_TOKEN_TO_INDEX,
    Connection,
    ConnectionArray,
    ConnectionList,
    Coord,
    CoordArray,
    CoordList,
    CoordTup,
)
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.dataset.filters import register_maze_filter
from maze_dataset.dataset.maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
)
from maze_dataset.dataset.maze_dataset_config import set_serialize_minimal_threshold
from maze_dataset.generation.generators import LatticeMazeGenerators
from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze, TargetedLatticeMaze

__all__ = [
    # submodules (with sub-submodules)
    "benchmark",
    "dataset",
    "generation",
    "maze",
    "plotting",
    "tokenization",
    # submodules
    "constants",
    "testing_utils",
    "token_utils",
    "utils",
    # main
    "SolvedMaze",
    "MazeDatasetConfig",
    "MazeDataset",
    # dataset classes
    "MazeDatasetCollection",
    "MazeDatasetCollectionConfig",
    # maze classes
    "TargetedLatticeMaze",
    "LatticeMaze",
    # other
    "set_serialize_minimal_threshold",
    "register_maze_filter",
    "LatticeMazeGenerators",
    # types
    "Coord",
    "CoordTup",
    "CoordList",
    "CoordArray",
    "Connection",
    "ConnectionList",
    "ConnectionArray",
    # constants
    "SPECIAL_TOKENS",
    "VOCAB",
    "VOCAB_LIST",
    "VOCAB_TOKEN_TO_INDEX",
]
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution"""

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        > DOCS: better documentation for this init method
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
                solution_valid = True

        if not solution_valid and not allow_invalid:
            err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze"""
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
Stores a maze and a solution
    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        > DOCS: better documentation for this init method
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
                solution_valid = True

        if not solution_valid and not allow_invalid:
            err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?
Create a SolvedMaze from a connection list and a solution
DOCS: better documentation for this init method
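Since the init method is only lightly documented, here is a minimal pure-Python sketch of the checks it appears to perform: the solution must contain at least one coordinate, every coordinate must have two components, and any explicitly given endpoints must match the ends of the path. The helper name `check_solution` is hypothetical, and it raises `ValueError` throughout, whereas the listing above uses `assert` for the endpoint checks.

```python
from typing import Optional, Sequence

Coord = tuple[int, int]


def check_solution(
    solution: Optional[Sequence[Sequence[int]]],
    start_pos: Optional[Sequence[int]] = None,
    end_pos: Optional[Sequence[int]] = None,
) -> list[Coord]:
    """Validate a solution path and return it as a list of coordinate tuples.

    Hypothetical helper mirroring the shape and endpoint checks only; it does
    not check that the path respects the maze walls.
    """
    if solution is None or len(solution) == 0:
        raise ValueError("invalid solution: expected at least one coordinate")
    path = [tuple(c) for c in solution]
    if any(len(c) != 2 for c in path):
        raise ValueError(f"invalid solution: every coordinate needs 2 entries, got {path}")
    # a path of length 1 is valid: start and end may be the same cell
    if start_pos is not None and tuple(start_pos) != path[0]:
        raise ValueError(f"start_pos mismatch: given={tuple(start_pos)}, solution={path[0]}")
    if end_pos is not None and tuple(end_pos) != path[-1]:
        raise ValueError(f"end_pos mismatch: given={tuple(end_pos)}, solution={path[-1]}")
    return path


path = check_solution([(0, 0), (0, 1), (1, 1)], start_pos=(0, 0), end_pos=(1, 1))
print(path)
```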
    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()
(deprecated!) return the solution as a list of tokens
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)
(deprecated!) return the maze without the solution
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )
get a SolvedMaze from a LatticeMaze by specifying a solution
    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze"""
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )
solves the given targeted lattice maze and returns a SolvedMaze
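To make the "solve" step concrete, the sketch below finds a shortest path between two cells with breadth-first search over an adjacency dict. Breadth-first search is swapped in here for simplicity and is not necessarily the algorithm behind `find_shortest_path`; the `shortest_path` function and the adjacency-dict representation are assumptions of this example.

```python
from collections import deque
from typing import Optional

Coord = tuple[int, int]


def shortest_path(
    adjacency: dict[Coord, set[Coord]], start: Coord, end: Coord
) -> list[Coord]:
    """Breadth-first search over connected cells; returns the path from start to end."""
    parents: dict[Coord, Optional[Coord]] = {start: None}
    queue: deque = deque([start])
    while queue:
        node = queue.popleft()
        if node == end:
            break
        for nbr in adjacency[node]:
            if nbr not in parents:
                parents[nbr] = node
                queue.append(nbr)
    if end not in parents:
        raise ValueError(f"no path from {start} to {end}")
    # walk back from the end to the start using the parent links
    path: list[Coord] = []
    cursor: Optional[Coord] = end
    while cursor is not None:
        path.append(cursor)
        cursor = parents[cursor]
    return path[::-1]


# a 2x2 maze shaped like a "U": (0,0)-(0,1)-(1,1)-(1,0)
adjacency = {
    (0, 0): {(0, 1)},
    (0, 1): {(0, 0), (1, 1)},
    (1, 1): {(0, 1), (1, 0)},
    (1, 0): {(1, 1)},
}
solution = shortest_path(adjacency, (0, 0), (1, 0))
print(solution)
```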
    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)
coordinates and their indices from the solution where a fork is present
- if the start point is not a dead end, this counts as a fork
- if the end point is not a dead end, this counts as a fork
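The fork rule above can be tried out on a tiny maze without NumPy. In this sketch, which assumes an adjacency-dict representation and a hypothetical `forking_points` function, an interior path cell is a fork when it has more than two neighbours, and an endpoint is a fork when it has more than one.

```python
Coord = tuple[int, int]


def forking_points(
    adjacency: dict[Coord, set[Coord]],
    solution: list[Coord],
    always_include_endpoints: bool = False,
) -> tuple[list[int], list[Coord]]:
    """Indices and coordinates along the solution where more than one move is possible."""
    idxs: list[int] = []
    coords: list[Coord] = []
    last = len(solution) - 1
    for idx, coord in enumerate(solution):
        is_endpoint = idx == 0 or idx == last
        # endpoints have no "previous" or "next" cell, so one neighbour fewer is expected
        threshold = 1 if is_endpoint else 2
        if len(adjacency[coord]) > threshold or (is_endpoint and always_include_endpoints):
            idxs.append(idx)
            coords.append(coord)
    return idxs, coords


# a "T" junction: the path runs along the top row, with a side branch at (0, 1)
adjacency = {
    (0, 0): {(0, 1)},
    (0, 1): {(0, 0), (0, 2), (1, 1)},
    (0, 2): {(0, 1)},
    (1, 1): {(0, 1)},
}
solution = [(0, 0), (0, 1), (0, 2)]
fork_idxs, fork_coords = forking_points(adjacency, solution)
print(fork_idxs, fork_coords)
```

Only the middle cell is reported, because both endpoints are dead ends. Passing `always_include_endpoints=True` adds them regardless of their degree.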
    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
coordinates from the solution where there is only a single (non-backtracking) point to move to
returns the complement of get_solution_forking_points from the path
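Taking the complement amounts to dropping the fork indices from the path. A minimal pure-Python stand-in for the `np.delete` calls in the listing, with a hypothetical function name and a hand-written fork index list, might look like this:

```python
Coord = tuple[int, int]


def path_following_points(
    solution: list[Coord], fork_idxs: list[int]
) -> tuple[list[int], list[Coord]]:
    """Indices and coordinates of the solution that are NOT forking points."""
    fork_set = set(fork_idxs)
    keep = [i for i in range(len(solution)) if i not in fork_set]
    return keep, [solution[i] for i in keep]


solution = [(0, 0), (0, 1), (0, 2)]
# assume index 1 was reported as a fork by a forking-point computation
follow_idxs, follow_coords = path_following_points(solution, fork_idxs=[1])
print(follow_idxs, follow_coords)
```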
    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
        }
        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # need it to be our special SerializableField
            if not isinstance(field, SerializableField):
                raise NotSerializableFieldException(
                    f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                    f"but a {type(field)} "
                    "this state should be inaccessible, please report this bug!"
                )

            # try to save it
            if field.serialize:
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

        # store each property if we can get it
        for prop in self._properties_to_serialize:
            if hasattr(cls, prop):
                value = getattr(self, prop)
                result[prop] = value
            else:
                raise AttributeError(
                    f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                    + f"but it is in {self._properties_to_serialize = }"
                    + f"\n{self = }"
                )

        return result
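The dispatch order in this listing (use the value's own `serialize()` if it has one, otherwise a per-field serialization function, otherwise the raw value) can be reproduced on plain dataclasses. The sketch below is a simplified illustration, not the muutils implementation: `serialize_fields`, the `"__format__"` key, and the `Inner`/`Outer` classes are all hypothetical.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass
class Inner:
    x: int

    def serialize(self) -> dict[str, Any]:
        return {"x": self.x}


@dataclasses.dataclass
class Outer:
    name: str
    inner: Inner
    tags: set


def serialize_fields(
    obj: Any,
    serialization_fns: Optional[dict[str, Callable[[Any], Any]]] = None,
) -> dict[str, Any]:
    """Serialize a dataclass field by field (hypothetical helper)."""
    fns = serialization_fns or {}
    # "__format__" is a made-up key standing in for the library's format marker
    result: dict[str, Any] = {"__format__": f"{type(obj).__name__}(sketch)"}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        if hasattr(value, "serialize") and callable(value.serialize):
            # the value knows how to serialize itself
            value = value.serialize()
        elif field.name in fns:
            # otherwise fall back to a per-field serialization function
            value = fns[field.name](value)
        result[field.name] = value
    return result


out = serialize_fields(Outer("a", Inner(1), {"t"}), {"tags": sorted})
print(out)
```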
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset

    # Parameters:
    - `name : str`
        name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
    - `grid_n : int`
        grid size of the maze (number of rows/columns)
    - `n_mazes : int`
        number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
        see `EndpointKwargsType` for more details.
    - `maze_ctor : Callable`
        maze generator function. This should be a function that takes a grid size and returns a maze.
        This will usually be one of the functions in `LatticeMazeGenerators`.
    - `maze_ctor_kwargs : dict`
        keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
    - `endpoint_kwargs : EndpointKwargsType`
        keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
    - `applied_filters : list[dict]`
        list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
        but these are stored with the config in case you want to re-generate the dataset with the same filters.

    """

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

        used in predicting the success rate
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didn't train on this, but if either endpoint is not required to be in a dead end
            # then requiring the endpoints to be unique does not really affect the success rate
            # (except for very small percolation values, pure percolation generation)
            endpoints_unique_flag = 0

        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            Config corresponding to `arr`
        """
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail.
            if `False`, return `1.0` in that case
            (defaults to `False`)

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
        """
        try:
            return cfg_success_predict_fn(self)

        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
        - `except_if_all_success_expected : bool`
            passed to `MazeDatasetConfig.success_fraction_estimate`:
            if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail.
            if `False` and your config isn't expected to fail, you might generate more mazes than needed
            since `safety_margin` is still applied.
            (defaults to `False`)
        - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
        - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
        - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)
config object which is passed to `MazeDataset.from_config` to generate or load a dataset

Parameters:
- `name : str`
  name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
- `grid_n : int`
  grid size of the maze (number of rows/columns)
- `n_mazes : int`
  number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. See `EndpointKwargsType` for more details.
- `maze_ctor : Callable`
  maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in `LatticeMazeGenerators`.
- `maze_ctor_kwargs : dict`
  keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
- `endpoint_kwargs : EndpointKwargsType`
  keyword arguments passed to `LatticeMaze.generate_random_path()`. See `EndpointKwargsType` for more info.
- `applied_filters : list[dict]`
  list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters.
    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config"""
        return "1.0"
return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config
    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )
return the versions of the config and the maze_dataset
    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }
serialize the MazeDatasetConfig with all fields and fname
    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )
return a summary of the config
    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
         - `except_if_all_success_expected : bool`
            if `True`, don't raise an error if the success fraction is below the threshold.
            will always return `1.0` if the config is not expected to fail

        # Returns:
         - `float`
            estimated success fraction

        # Raises:
         - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False`
        """
        try:
            return cfg_success_predict_fn(self)

        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0
Estimate the success fraction of this config.

only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `except_if_all_success_expected : bool`
  if `True`, re-raise `NoPercolationInConfigError` when the config is not expected to fail. if `False`, return `1.0` in that case instead of raising. (defaults to `False`)

Returns:
- `float`
  estimated success fraction

Raises:
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
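The fallback behaviour in the source listing for `success_fraction_estimate` can be sketched in isolation. This is a minimal illustration, not the library's code: `estimate_success_fraction`, `predict_fn`, `no_percolation`, and the local `NoPercolationInConfigError` class are hypothetical stand-ins that only mirror the try/except shape of the listing.

```python
from typing import Callable


class NoPercolationInConfigError(Exception):
    """Local stand-in for the library's exception of the same name (assumption)."""


def estimate_success_fraction(
    predict_fn: Callable[[], float],
    except_if_all_success_expected: bool = False,
) -> float:
    """Return predict_fn(), falling back to 1.0 when no prediction applies."""
    try:
        return predict_fn()
    except NoPercolationInConfigError:
        if except_if_all_success_expected:
            # the caller asked to be told when the config is not expected to fail
            raise
        # otherwise treat "not expected to fail" as certain success
        return 1.0


def no_percolation() -> float:
    """Hypothetical predictor for a config without a percolation generator."""
    raise NoPercolationInConfigError("config has no percolation parameter")


print(estimate_success_fraction(lambda: 0.37))
print(estimate_success_fraction(no_percolation))
```

With the flag left at `False` the second call falls back to `1.0`; passing `except_if_all_success_expected=True` would let the exception propagate instead.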
    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
         - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
         - `except_if_all_success_expected : bool`
            if `True`, don't raise an error if the success fraction is below the threshold.
            this is passed to `MazeDatasetConfig.success_fraction_estimate`.
            if your config isn't expected to fail, passing this might mean you generate more mazes than needed
            since `safety_margin` is still applied.
            (defaults to `False`)
         - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
         - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
         - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)
return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction

calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then computes the new number of mazes as `new_n_mazes = int(n_mazes * safety_margin / success_fraction) + 1`

more information on where this comes from can be found in
- `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
- `estimate_dataset_fractions.ipynb`
- `maze_dataset.benchmarks.sweep_fit`

Parameters:
- `safety_margin : float`
  safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated)
- `except_if_all_success_expected : bool`
  passed to `MazeDatasetConfig.success_fraction_estimate`. if `True`, `NoPercolationInConfigError` is raised when the config is not expected to fail. if `False`, such a config is treated as having a success fraction of `1.0`, so you might generate more mazes than needed since `safety_margin` is still applied. (defaults to `False`)
- `epsilon : float`
  raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`)

Returns:
- `MazeDatasetConfig`
  new config with adjusted `n_mazes`

Raises:
- `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
- `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `True`
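The arithmetic behind the adjusted `n_mazes` is easy to check by hand. The sketch below reproduces only the formula and the `epsilon` check from the source listing; the function name `compensated_n_mazes` is hypothetical, and a plain `ValueError` stands in for `SuccessChanceTooSmallError`.

```python
def compensated_n_mazes(
    n_mazes: int,
    success_fraction: float,
    safety_margin: float = 1.2,
    epsilon: float = 1e-2,
) -> int:
    """Number of mazes to request so that roughly n_mazes survive generation."""
    if success_fraction < epsilon:
        # stand-in for SuccessChanceTooSmallError (assumption)
        raise ValueError(
            f"{success_fraction = } is below the threshold of {epsilon = }"
        )
    # scale up by the safety margin, divide by the expected success rate,
    # truncate to an integer, and add one so the result is never rounded down to the bare estimate
    return int((n_mazes * safety_margin) / success_fraction) + 1


# half of the mazes are expected to fail: 100 * 1.2 / 0.5 = 240, plus 1
print(compensated_n_mazes(100, 0.5))  # 241
# a config that never fails still gets the safety margin: 10 * 1.2 / 1.0 = 12, plus 1
print(compensated_n_mazes(10, 1.0))  # 13
```

The second call shows why a config that is not expected to fail can end up requesting more mazes than needed: the safety margin is applied even when the success fraction is `1.0`.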
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base
- grid_n
- n_mazes
- maze_ctor
- maze_ctor_kwargs
- endpoint_kwargs
- grid_shape
- grid_shape_np
- max_grid_n
- stable_hash_cfg
- to_fname
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class MazeDataset(GPTDataset[MazeDatasetConfig]):  # noqa: PLW1641
    """a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected

    # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override]
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            "MazeDataset",
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

    def __getitem__(self, i: int) -> SolvedMaze:
        """get a maze by index"""
        return self.mazes[i]

    def __iter__(self) -> typing.Iterator[SolvedMaze]:
        """iterate over the mazes"""
        return iter(self.mazes)

    def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
        """deepcopy the dataset

        FIX: this isnt actually a deepcopy I think?
        """
        return MazeDataset.load(self._serialize_full())

    # TYPING: get type hints on the tokenizer here
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[False] = False,
    ) -> list[list[str]]: ...
    @overload
    def as_tokens(
        self,
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: Literal[True] = True,
    ) -> list[str]: ...
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def __len__(self) -> int:
        """return the number of mazes in the dataset"""
        return len(self.mazes)

    def __eq__(self, other: object) -> bool:
        """compare two datasets"""
        if not isinstance(other, MazeDataset):
            raise NotImplementedError(
                "can only compare with other MazeDataset objects",
            )
        # TODO: compare hashes of data instead of the data itself?
        return self.cfg == other.cfg and self.mazes == other.mazes

    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset

    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")

    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )

    @classmethod
    def _load_full(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            mazes=load_item_recursive(data["mazes"], tuple()),
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
        )

    @classmethod
    def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal"
        return cls(
            cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
            generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
            mazes=[
                SolvedMaze(
                    clist,
                    soln[:slen, ...],
                )
                for clist, slen, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    load_item_recursive(data["maze_solution_lengths"], tuple()),
                    load_item_recursive(data["maze_solutions"], tuple()),
                    strict=False,
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                )
            ],
        )

    @classmethod
    def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
        assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

        maze_solution_lengths = load_item_recursive(
            data["maze_solution_lengths"],
            tuple(),
        )
        maze_solutions_concat = load_item_recursive(
            data["maze_solutions_concat"],
            tuple(),
        )
        maze_solutions = np.split(
            maze_solutions_concat,
            np.cumsum(maze_solution_lengths)[:-1],
            axis=0,
        )

        return cls(
            cfg=load_item_recursive(data["cfg"], tuple()),
            generation_metadata_collected=load_item_recursive(
                data["generation_metadata_collected"],
                tuple(),
            ),
            mazes=[
                SolvedMaze(
                    connection_list=clist,
                    solution=soln,
                )
                for clist, soln in zip(
                    load_item_recursive(data["maze_connection_lists"], tuple()),
                    # load_item_recursive(data["maze_endpoints"], tuple()),
                    maze_solutions,
                    strict=False,
                )
            ],
        )

    @classmethod
    def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
        """Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
        assert data[_FORMAT_KEY] == "MazeDataset"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "mazes", "generation_metadata_collected"]
            },
        )

    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()

    def _serialize_full(self) -> JSONdict:
        return {
            _FORMAT_KEY: "MazeDataset",
            "cfg": json_serialize(self.cfg),
            "fname": self.cfg.to_fname(),
            "mazes": json_serialize(self.mazes),
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    def _serialize_minimal(self) -> JSONdict:
        "alternate serialization where metadata is collected and mazes are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
        maze_solutions: np.ndarray = np.empty(
            (n_mazes, max_solution_len, 2),
            dtype=np.int8,
        )

        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            maze_solution_lengths[idx] = maze.solution.shape[0]
            maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

        return {
            _FORMAT_KEY: "MazeDataset:minimal",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            # "maze_endpoints": maze_endpoints,
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions": maze_solutions,  # type: ignore[dict-item]
        }

    def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
        "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
        filtered_meta: MazeDataset
        if self.generation_metadata_collected is None:
            filtered_meta = self.filter_by.collect_generation_meta()
        else:
            filtered_meta = self

        maze_solution_lengths: np.ndarray = np.array(
            [m.solution.shape[0] for m in filtered_meta.mazes],
            dtype=np.int32,
        )
        n_mazes: int = len(filtered_meta.mazes)
        grid_n: int = filtered_meta.cfg.grid_n
        total_solution_len: int = np.sum(maze_solution_lengths)

        maze_connection_lists: np.ndarray = np.empty(
            (n_mazes, 2, grid_n, grid_n),
            dtype=np.bool_,
        )
        maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
        maze_solutions_concat: np.ndarray = np.empty(
            (total_solution_len, 2),
            dtype=np.int8,
        )

        solutions_running_idx: int = 0
        for idx, maze in enumerate(filtered_meta.mazes):
            maze_connection_lists[idx] = maze.connection_list
            maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
            soln_len: int = maze.solution.shape[0]
            maze_solution_lengths[idx] = soln_len
            maze_solutions_concat[
                solutions_running_idx : solutions_running_idx + soln_len
            ] = maze.solution
            solutions_running_idx += soln_len

        return {
            _FORMAT_KEY: "MazeDataset:minimal_soln_cat",
            "cfg": json_serialize(filtered_meta.cfg),
            "fname": filtered_meta.cfg.to_fname(),
            "generation_metadata_collected": json_serialize(
                filtered_meta.generation_metadata_collected,
            ),
            "maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
            "maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
            "maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
            "maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
        }

    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)

    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
a maze dataset class. This is a collection of solved mazes, and should be initialized via MazeDataset.from_config
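The class source above includes a `MazeDataset:minimal_soln_cat` format, which stores all solution paths in one concatenated array plus a list of per-maze lengths, and recovers the individual paths by splitting at the cumulative lengths. The sketch below illustrates that idea with plain Python lists; the library itself works on NumPy arrays with `np.split` and `np.cumsum`, and the helper names `concat_solutions` and `split_solutions` are hypothetical.

```python
from itertools import accumulate

Coord = tuple[int, int]


def concat_solutions(solutions: list[list[Coord]]) -> tuple[list[int], list[Coord]]:
    """Flatten per-maze solution paths into (lengths, one long coordinate list)."""
    lengths = [len(path) for path in solutions]
    flat = [coord for path in solutions for coord in path]
    return lengths, flat


def split_solutions(lengths: list[int], flat: list[Coord]) -> list[list[Coord]]:
    """Invert concat_solutions by slicing at the cumulative lengths."""
    ends = list(accumulate(lengths))  # running totals: where each path ends
    starts = [0] + ends[:-1]  # each path starts where the previous one ended
    return [flat[a:b] for a, b in zip(starts, ends)]


paths = [[(0, 0), (0, 1)], [(2, 2)], [(1, 0), (1, 1), (1, 2)]]
lengths, flat = concat_solutions(paths)
assert split_solutions(lengths, flat) == paths
print(lengths)
```

Storing one flat array avoids padding every solution to the longest path, which is what the `MazeDataset:minimal` format in the listing does with its `(n_mazes, max_solution_len, 2)` array.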
    def __init__(
        self,
        cfg: MazeDatasetConfig,
        mazes: typing.Sequence[SolvedMaze],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        """initialize a maze dataset from a config and a list of solved mazes"""
        super().__init__()
        self.cfg: MazeDatasetConfig = cfg
        self.mazes: list[SolvedMaze] = list(mazes)
        self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize a maze dataset from a config and a list of solved mazes
    @classmethod
    def from_config(  # type: ignore[override]
        cls,
        # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override]
        cfg: MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "MazeDataset":
        """create a maze dataset from a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return cast(
            "MazeDataset",
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )
create a maze dataset from a config
priority of loading:
1. load from local
2. download
3. generate
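How the parent class resolves this priority is not shown on this page, so the following is only a hedged sketch of one way such a fallback chain could look. The flag names `load_local`, `do_download`, and `do_generate` come from the `from_config` signature; the function `resolve_dataset`, its callable arguments, and the choice of which exceptions count as "source unavailable" are assumptions made for illustration.

```python
from typing import Any, Callable


def resolve_dataset(
    load_local_fn: Callable[[], Any],
    download_fn: Callable[[], Any],
    generate_fn: Callable[[], Any],
    load_local: bool = True,
    do_download: bool = True,
    do_generate: bool = True,
) -> tuple[str, Any]:
    """Try each enabled source in priority order and return the first that works."""
    steps = [
        (load_local, "local", load_local_fn),
        (do_download, "download", download_fn),
        (do_generate, "generate", generate_fn),
    ]
    for enabled, label, fn in steps:
        if not enabled:
            continue
        try:
            return label, fn()
        except (FileNotFoundError, NotImplementedError):
            # assumption: a missing file or an unimplemented source means "try the next one"
            continue
    raise ValueError("no enabled source could provide the dataset")


def missing_local() -> list[str]:
    raise FileNotFoundError("no cached dataset on disk")


def no_download() -> list[str]:
    raise NotImplementedError("not implemented yet")


print(resolve_dataset(missing_local, no_download, lambda: ["maze_0", "maze_1"]))
```

In this sketch the first two sources fail and the result comes from generation, which matches the situation where nothing is cached locally and `MazeDataset.download` raises `NotImplementedError`, as in the class source.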
    def data_hash(self) -> int:
        """return a hash of the data"""
        return stable_hash(str(tuple([x.serialize() for x in self.mazes])))
return a hash of the data
    def as_tokens(
        self,
        maze_tokenizer,  # TODO: MazeTokenizer
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens according to the passed `maze_tokenizer`

        the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

        if `join_tokens_individual_maze` is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:

            >>> dataset.as_tokens(join_tokens_individual_maze=False)
            [["a", "b", "c"], ["d", "e", "f"]]
            >>> dataset.as_tokens(join_tokens_individual_maze=True)
            ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output
return the dataset as tokens according to the passed `maze_tokenizer`

the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`. if `limit` is given, only the first `limit` mazes are tokenized.

if `join_tokens_individual_maze` is `True`, then the tokens of each maze are joined with a space, and the result is a list of strings, i.e.:

>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(maze_tokenizer, join_tokens_individual_maze=True)
["a b c", "d e f"]
    def assert_equal(self, other: "MazeDataset") -> None:
        """assert that two datasets are equal"""
        assert isinstance(other, MazeDataset)
        assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
        assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"
assert that two datasets are equal
    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetConfig,
        gen_parallel: bool = False,
        pool_kwargs: dict | None = None,
        verbose: bool = False,
        # TODO: what to do when unexpected kwargs are passed?
        **kwargs,  # noqa: ARG003
    ) -> "MazeDataset":
        """Generate a maze dataset given a config and some generation parameters"""
        # Copy the config to avoid modifying the original
        cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
            json.loads(json.dumps(cfg.serialize())),
        )

        if pool_kwargs is None:
            pool_kwargs = dict()
        maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

        solved_mazes: list[SolvedMaze | None]
        # Configure tqdm for progress bar
        tqdm_kwargs: dict = dict(
            total=cfg_cpy.n_mazes,
            unit="maze",
            desc="generating & solving mazes",
            disable=not verbose,
        )
        # TODO: don't use the global unless generating in parallel!
        if gen_parallel:
            with multiprocessing.Pool(
                **pool_kwargs,
                initializer=_maze_gen_init_worker,
                initargs=(cfg_cpy,),
            ) as pool:
                solved_mazes = list(
                    tqdm.tqdm(
                        pool.imap(_generate_maze_helper, maze_indexes),
                        **tqdm_kwargs,
                    ),
                )

        else:
            _maze_gen_init_worker(cfg_cpy)
            solved_mazes = list(
                tqdm.tqdm(
                    map(
                        # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
                        # why does it think tolist() returns a string?
                        _generate_maze_helper,  # type: ignore[arg-type]
                        maze_indexes.tolist(),
                    ),
                    **tqdm_kwargs,
                ),
            )

        # Filter out None values explicitly after ensuring all results are collected
        solved_mazes_: list[SolvedMaze] = [
            maze for maze in solved_mazes if maze is not None
        ]
        # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

        # Update the config with the actual number of mazes
        cfg_cpy.n_mazes = len(solved_mazes_)

        dataset: MazeDataset = cls(
            cfg=cfg_cpy,
            mazes=solved_mazes_,
        )

        dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

        np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

        return dataset
Generate a maze dataset given a config and some generation parameters
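Two details of the source above are easy to miss: the config is copied through a JSON round-trip so the caller's object is never mutated, and results that come back as `None` are dropped before `n_mazes` is updated. A minimal sequential sketch of just those two steps follows, assuming a plain dict as the config; `try_generate` and its "every third attempt fails" rule are hypothetical and stand in for the real per-maze helper.

```python
import json


def try_generate(index: int) -> "str | None":
    """Hypothetical per-maze helper: pretend every third attempt fails."""
    if index % 3 == 2:
        return None
    return f"maze_{index}"


def generate(cfg: dict) -> tuple[dict, list[str]]:
    # a JSON round-trip gives an independent copy, leaving the caller's config untouched
    cfg_cpy = json.loads(json.dumps(cfg))
    results = [try_generate(i) for i in range(cfg_cpy["n_mazes"])]
    # drop failed attempts, then record how many mazes were actually produced
    kept = [m for m in results if m is not None]
    cfg_cpy["n_mazes"] = len(kept)
    return cfg_cpy, kept


requested = {"name": "toy", "n_mazes": 6, "seed": 42}
actual_cfg, mazes = generate(requested)
print(requested["n_mazes"], actual_cfg["n_mazes"], mazes)
```

The practical consequence is that the returned dataset may hold fewer mazes than the config asked for, and its own config reports the actual count.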
    @classmethod
    def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
        "(not implemented yet!) download a maze dataset from the internet"
        raise NotImplementedError("not implemented yet")
(not implemented yet!) download a maze dataset from the internet
    @classmethod
    def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
        """load from zanj/json"""
        if data[_FORMAT_KEY] == "MazeDataset:minimal":
            return cls._load_minimal(data)
        elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
            return cls._load_minimal_soln_cat(data)
        elif data[_FORMAT_KEY] == "MazeDataset":
            if (
                SERIALIZE_MINIMAL_THRESHOLD == -1
            ):  # Allow access to `_load_legacy` for profiling
                return cls._load_legacy(data)
            return cls._load_full(data)
        else:
            err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
            raise KeyError(
                err_msg,
            )
load from zanj/json
    def serialize(self) -> JSONdict:
        """serialize to zanj/json"""
        if (
            SERIALIZE_MINIMAL_THRESHOLD is not None
            and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
        ):
            return self._serialize_minimal()
        return self._serialize_full()
serialize to zanj/json
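`serialize` and `load` work as a pair: the serializer picks a format based on a size threshold and tags its output, and the loader dispatches on that tag. The sketch below shows the pattern on a toy list of integers. `FORMAT_KEY`, `MINIMAL_THRESHOLD`, and the `"Toy"` format names are assumptions made for illustration; they are not the library's actual key, threshold, or formats.

```python
FORMAT_KEY = "format"  # hypothetical key name, standing in for _FORMAT_KEY
MINIMAL_THRESHOLD = 3  # hypothetical threshold, standing in for SERIALIZE_MINIMAL_THRESHOLD


def serialize(items: list[int]) -> dict:
    """Use the compact format once the data reaches the threshold."""
    if MINIMAL_THRESHOLD is not None and len(items) >= MINIMAL_THRESHOLD:
        return {FORMAT_KEY: "Toy:minimal", "packed": ",".join(map(str, items))}
    return {FORMAT_KEY: "Toy", "items": list(items)}


def load(data: dict) -> list[int]:
    """Dispatch on the format tag written by serialize."""
    fmt = data[FORMAT_KEY]
    if fmt == "Toy:minimal":
        return [int(x) for x in data["packed"].split(",")]
    elif fmt == "Toy":
        return list(data["items"])
    raise KeyError(f"unrecognized format: {fmt!r}")


small = serialize([1, 2])
large = serialize([1, 2, 3])
print(small[FORMAT_KEY], large[FORMAT_KEY], load(small), load(large))
```

Because the tag travels with the data, a reader never needs to know which threshold was in effect when the file was written.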
    def update_self_config(self) -> None:
        """update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
        if self.cfg.n_mazes != len(self.mazes):
            warnings.warn(
                f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
            )
            self.cfg.n_mazes = len(self.mazes)
update the config to match the current state of the dataset (number of mazes, such as after filtering)
    def custom_maze_filter(
        self,
        method: typing.Callable[[SolvedMaze], bool],
        **kwargs,
    ) -> "MazeDataset":
        """filter the dataset using a custom method"""
        output: MazeDataset = MazeDataset(
            cfg=copy.deepcopy(self.cfg),
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        output.cfg.applied_filters.append(
            {
                "name": f"__custom__:{method.__name__}",
                "kwargs": kwargs,
            },
        )
        output.update_self_config()
        return output
filter the dataset using a custom method
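The filter returns a new dataset, leaves the original untouched, and records what was applied in the copied config. A self-contained sketch of that pattern is shown below. `ToyConfig`, `ToyDataset`, and `path_longer_than` are hypothetical stand-ins, with each "maze" represented simply as a list of coordinates.

```python
import copy
from dataclasses import dataclass, field


@dataclass
class ToyConfig:
    n_mazes: int
    applied_filters: list = field(default_factory=list)


@dataclass
class ToyDataset:
    cfg: ToyConfig
    mazes: list

    def custom_filter(self, method, **kwargs) -> "ToyDataset":
        """Keep only the mazes for which method(maze, **kwargs) is true."""
        output = ToyDataset(
            cfg=copy.deepcopy(self.cfg),  # the original config is not modified
            mazes=[m for m in self.mazes if method(m, **kwargs)],
        )
        # record the filter by function name, with the kwargs it was called with
        output.cfg.applied_filters.append(
            {"name": f"__custom__:{method.__name__}", "kwargs": kwargs}
        )
        output.cfg.n_mazes = len(output.mazes)
        return output


def path_longer_than(maze: list, min_length: int) -> bool:
    """Hypothetical predicate: a 'maze' here is just its solution path."""
    return len(maze) > min_length


paths = [[(0, 0)], [(0, 0), (0, 1), (1, 1)], [(0, 0), (1, 0)]]
dataset = ToyDataset(cfg=ToyConfig(n_mazes=len(paths)), mazes=paths)
filtered = dataset.custom_filter(path_longer_than, min_length=1)
print(filtered.cfg)
```

Since the recorded name comes from `method.__name__`, a named function gives a more useful record than a lambda, whose name is always `<lambda>`.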
class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets"""

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )

    def __len__(self) -> int:
        """return the total number of mazes in the collection"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int) -> LatticeMaze:
        "get a maze by index"
        # find which dataset the index belongs to
        # we add 1, since np.searchsorted returns the
        # index of the last element that is strictly less than the target
        # while we want the index of the last element less than or equal to the target
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        index_adjusted: int = index
        if dataset_idx > 0:
            # if the index is 0, `dataset_idx - 1` will be -1.
            # We just want to use the base index
            index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][index_adjusted]

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)

    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output

    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
a collection of maze datasets
    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected
initialize the dataset collection from a MazeDatasetCollectionConfig and a list of MazeDatasets
    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]
return the lengths of each dataset in the collection
    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))
return the cumulative lengths of each dataset in the collection
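The cumulative lengths are what allow a single flat index to be mapped to a particular dataset and an offset within it, as `__getitem__` does in the class source. The sketch below reproduces that lookup with the standard library only. It uses `bisect.bisect_left` in place of `np.searchsorted` (an assumption made to keep the example dependency-free; both return the first position whose value is greater than or equal to the target), and the toy `datasets` are illustrative.

```python
import bisect
import itertools

# three toy "datasets" of lengths 3, 2 and 4
datasets = [["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2", "c3"]]
cum_lengths = list(itertools.accumulate(len(d) for d in datasets))  # [3, 5, 9]


def get_item(index: int) -> str:
    """Map a flat index to (dataset, offset) using the cumulative lengths."""
    # first dataset whose cumulative length is at least index + 1
    dataset_idx = bisect.bisect_left(cum_lengths, index + 1)
    index_adjusted = index
    if dataset_idx > 0:
        # subtract the number of items in all earlier datasets
        index_adjusted -= cum_lengths[dataset_idx - 1]
    return datasets[dataset_idx][index_adjusted]


flat = [get_item(i) for i in range(cum_lengths[-1])]
print(flat)
```

Searching for `index + 1` rather than `index` is what places an index equal to a cumulative length (for example, index 3 here) in the next dataset rather than the previous one.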
    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )
single list of all mazes in the collection
    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
generate a dataset collection from a config
    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
(not implemented!) download a dataset collection from a config
    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }
serialize the dataset collection
    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )
load the dataset collection from the representation created by serialize
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output
return the dataset as tokens
if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
update the config to match the number of mazes, and update the underlying configs of each dataset
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # Attributes without a default cannot follow attributes with one  [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )

    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all datasets"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
maze dataset collection configuration, including tokenizers and shuffle
    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )
return a summary of the config
    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all datasets"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)
return the total number of mazes in the collection across all datasets
    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)
return the maximum grid size of the mazes in the collection
    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)
return the maximum grid shape of the mazes in the collection
    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)
return the maximum grid shape of the mazes in the collection as a numpy array
    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))
return a stable hash of the config
    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
convert config to a filename
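The filename combines the collection name, the maze count, and the last five decimal digits of a hash of the serialized config, so two different configs with the same name are unlikely to collide. The sketch below illustrates why a digest-based hash is used rather than Python's built-in `hash()`, which is randomized per process for strings. The `stable_hash` shown here is an assumed stand-in built on SHA-256, not necessarily how the library's function is implemented, and the sketch omits the `sanitize_fname` and `shorten_numerical_to_str` steps.

```python
import hashlib
import json


def stable_hash(s: str) -> int:
    """Assumed stand-in: a digest-based hash that is identical across runs."""
    return int.from_bytes(hashlib.sha256(s.encode("utf-8")).digest(), "big")


def to_fname(name: str, n_mazes: int, serialized_cfg: dict) -> str:
    """Build a filename from the name, the maze count, and a 5-digit hash suffix."""
    cfg_hash = stable_hash(json.dumps(serialized_cfg))
    return f"collected-{name}-n{n_mazes}-h{cfg_hash % 10**5}"


cfg = {"name": "demo", "maze_dataset_configs": [{"grid_n": 3, "n_mazes": 100}]}
fname_a = to_fname("demo", 100, cfg)
fname_b = to_fname("demo", 100, cfg)
print(fname_a)
```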
    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
        }
        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # need it to be our special SerializableField
            if not isinstance(field, SerializableField):
                raise NotSerializableFieldException(
                    f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                    f"but a {type(field)} "
                    "this state should be inaccessible, please report this bug!"
                )

            # try to save it
            if field.serialize:
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

        # store each property if we can get it
        for prop in self._properties_to_serialize:
            if hasattr(cls, prop):
                value = getattr(self, prop)
                result[prop] = value
            else:
                raise AttributeError(
                    f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                    + f"but it is in {self._properties_to_serialize = }"
                    + f"\n{self = }"
                )

        return result
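serialize the dataclass to a dict: each serializable field is stored using the value's own serialize method or the field's serialization_fn, followed by any properties listed in _properties_to_serialize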
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
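load an instance from a mapping: each init field present in the data is passed through its deserialize_fn, its loading_fn, or the load method of its type hint (in that order of precedence), and field types are validated afterwards unless type checking is set to ignore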
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position"""

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        "post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        assert self.start_pos is not None
        assert self.end_pos is not None
        # check that start and end are in bounds
        if (
            self.start_pos[0] >= self.grid_shape[0]
            or self.start_pos[1] >= self.grid_shape[1]
        ):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if (
            self.end_pos[0] >= self.grid_shape[0]
            or self.end_pos[1] >= self.grid_shape[1]
        ):
            err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `TargetedLatticeMaze` by hashing a tuple of the connection list and start and end positions as bytes"
        return hash(
            (
                self.connection_list.tobytes(),
                self.start_pos.tobytes(),
                self.end_pos.tobytes(),
            )
        )

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
A LatticeMaze with a start and end position
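The bounds check performed in `__post_init__` compares each coordinate against the grid shape on its own axis. A minimal sketch of the same check on plain tuples is given below; `check_in_bounds` is a hypothetical helper, and like the source shown above it only tests the upper edge of the grid, so negative coordinates are not rejected here.

```python
def check_in_bounds(
    name: str, pos: tuple[int, int], grid_shape: tuple[int, int]
) -> None:
    """Raise ValueError if pos lies beyond the upper edge of the grid on either axis."""
    if pos[0] >= grid_shape[0] or pos[1] >= grid_shape[1]:
        raise ValueError(f"{name} {pos} is out of bounds for grid shape {grid_shape}")


# (0, 0) is inside a 3x3 grid, so this call passes silently
check_in_bounds("start_pos", (0, 0), (3, 3))

# row index 3 is one past the last valid row (0, 1, 2)
try:
    check_in_bounds("end_pos", (3, 1), (3, 3))
    out_of_bounds_caught = False
    message = ""
except ValueError as err:
    out_of_bounds_caught = True
    message = str(err)

print(out_of_bounds_caught, message)
```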
    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()
(deprecated!) return the start position as a list of tokens
    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()
(deprecated!) return the end position as a list of tokens
    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
get a TargetedLatticeMaze from a LatticeMaze by specifying start and end positions
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 
741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
The type of the None singleton.
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
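The recursive loading pattern in `load` can be illustrated with plain standard-library dataclasses. This is only a minimal sketch, not the muutils implementation: `load_dataclass`, `Inner`, and `Outer` are hypothetical names, and custom `loading_fn` / `deserialize_fn` hooks and type validation are left out.

```python
import dataclasses
import typing
from collections.abc import Mapping


def load_dataclass(cls, data):
    """Hypothetical helper: build `cls` from a mapping, recursing into nested fields."""
    # already an instance: return as-is (mirrors the early return in `load`)
    if isinstance(data, cls):
        return data
    if not isinstance(data, Mapping):
        raise TypeError(f"expected a Mapping, got {type(data)}")

    hints = typing.get_type_hints(cls)
    ctor_kwargs = {}
    for field in dataclasses.fields(cls):
        if field.name in data and field.init:
            value = data[field.name]
            hint = hints.get(field.name)
            # if the annotated type has a callable `load`, use it for dict values
            if callable(getattr(hint, "load", None)) and isinstance(value, dict):
                value = hint.load(value)
            ctor_kwargs[field.name] = value
    return cls(**ctor_kwargs)


@dataclasses.dataclass
class Inner:
    x: int

    @classmethod
    def load(cls, data):
        return load_dataclass(cls, data)


@dataclasses.dataclass
class Outer:
    name: str
    inner: Inner

    @classmethod
    def load(cls, data):
        return load_dataclass(cls, data)


outer = Outer.load({"name": "a", "inner": {"x": 3}})
```

Here the nested `{"x": 3}` dict is turned into an `Inner` because the type hint for `inner` exposes a `load` method, which is the same dispatch the third branch of the real `load` performs.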
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
Validate the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.
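The idea of per-field validation followed by an `all(...)` reduction can be sketched with the standard library alone. This is an illustration under simplifying assumptions, not the muutils code: the helper names are hypothetical, and only plain class annotations (such as `int` or `str`) are handled, whereas a real validator also has to deal with generics and unions.

```python
import dataclasses
import typing


def validate_fields_types_dict(obj) -> dict[str, bool]:
    """Hypothetical helper: map each field name to whether its value matches its hint."""
    hints = typing.get_type_hints(type(obj))
    return {
        f.name: isinstance(getattr(obj, f.name), hints[f.name])
        for f in dataclasses.fields(obj)
    }


def validate_fields_types(obj) -> bool:
    """True only if every field passes its type check."""
    return all(validate_fields_types_dict(obj).values())


@dataclasses.dataclass
class Point:
    x: int
    label: str


good = Point(x=1, label="a")
bad = Point(x="1", label="a")  # wrong type for `x`; dataclasses do not check this
```

Keeping the dict-returning variant separate makes it possible to report which fields failed, which is how the error message in `load` is assembled.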
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
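The `LatticeMaze` source that follows stores the maze as a connection list indexed as `[direction, row, col]`, where direction 0 is "down" and direction 1 is "right". The sketch below mirrors the adjacency test of `nodes_connected` using nested lists instead of NumPy arrays; it is a simplified illustration with a hypothetical standalone function, using the 2x2 example from the class docstring.

```python
# connection_list[0][row][col]: node (row, col) is connected to (row + 1, col)  (down)
# connection_list[1][row][col]: node (row, col) is connected to (row, col + 1)  (right)
CONNECTION_LIST = [
    [[False, True], [False, False]],  # down
    [[True, False], [True, False]],  # right
]


def nodes_connected(connection_list, a, b) -> bool:
    """Standalone analogue of `LatticeMaze.nodes_connected` for tuple coords."""
    delta = (b[0] - a[0], b[1] - a[1])
    # nodes that are not lattice neighbors are never connected
    if abs(delta[0]) + abs(delta[1]) != 1:
        return False
    # which axis the step is along: 0 = rows (down), 1 = cols (right)
    dim = 0 if delta[0] != 0 else 1
    # the connection is stored on the node with the smaller coordinate
    node = a if sum(delta) > 0 else b
    return connection_list[dim][node[0]][node[1]]
```

Because each edge is stored once, on its upper or left endpoint, a lookup from the lower or right endpoint has to swap to the other node first; that is what the `a if ... else b` selection does.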
@serializable_dataclass(
    frozen=True,
    kw_only=True,
    properties_to_serialize=["lattice_dim", "generation_meta"],
)
class LatticeMaze(SerializableDataclass):
    """lattice maze (nodes on a lattice, connections only to neighboring nodes)

    Connection List represents which nodes (N) are connected in each direction.

    First and second elements represent downward and rightward connections,
    respectively.

    Example:
        Connection list:
            [
                [ # down
                    [F T],
                    [F F]
                ],
                [ # right
                    [T F],
                    [T F]
                ]
            ]

        Nodes with connections
            N T N F
            F   T
            N T N F
            F   F

        Graph:
            N - N
                |
            N - N

    Note: the bottom row connections going down, and the
    right-hand connections going right, will always be False.

    """

    connection_list: ConnectionList
    generation_meta: dict | None = serializable_field(default=None, compare=False)

    lattice_dim = property(lambda self: self.connection_list.shape[0])
    grid_shape = property(lambda self: self.connection_list.shape[1:])
    n_connections = property(lambda self: self.connection_list.sum())

    @property
    def grid_n(self) -> int:
        "grid size as int, raises `AssertionError` if not square"
        assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
        return self.grid_shape[0]

    # ============================================================
    # basic methods
    # ============================================================

    def __eq__(self, other: object) -> bool:
        "equality check calls super"
        return super().__eq__(other)

    @staticmethod
    def heuristic(a: CoordTup, b: CoordTup) -> float:
        """return manhattan distance between two points"""
        return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])

    def __hash__(self) -> int:
        """hash the connection list by converting connection list to bytes"""
        return hash(self.connection_list.tobytes())

    def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
        """returns whether two nodes are connected"""
        delta: Coord = b - a
        if np.abs(delta).sum() != 1:
            # return false if not even adjacent
            return False
        else:
            # test for wall
            dim: int = int(np.argmax(np.abs(delta)))
            clist_node: Coord = a if (delta.sum() > 0) else b
            return self.connection_list[dim, clist_node[0], clist_node[1]]

    def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
        """check if a path is valid"""
        # check path is not empty
        if len(path) == 0:
            return empty_is_valid

        # check all coords in bounds of maze
        if not np.all((path >= 0) & (path < self.grid_shape)):
            return False

        # check all nodes connected
        for i in range(len(path) - 1):
            if not self.nodes_connected(path[i], path[i + 1]):
                return False
        return True

    def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
        """Returns an array with the connectivity degree of each coord.

        I.e., how many neighbors each coord has.
        """
        int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
            self.connection_list.astype(np.int8)
        )
        degrees: Int8[np.ndarray, "row col"] = np.sum(
            int_conn,
            axis=0,
        )  # Connections to east and south
        degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
        degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
        return degrees

    def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
        """Returns an array of the neighboring, connected coords of `c`."""
        c = np.array(c)  # type: ignore[assignment]
        neighbors: list[Coord] = [
            neighbor
            for neighbor in (c + NEIGHBORS_MASK)
            if (
                (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
                and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
                and self.nodes_connected(c, neighbor)  # connected
            )
        ]

        output: CoordArray = np.array(neighbors)
        if len(neighbors) > 0:
            assert output.shape == (
                len(neighbors),
                2,
            ), (
                f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
            )
        return output

    def gen_connected_component_from(self, c: Coord) -> CoordArray:
        """return the connected component from a given coordinate"""
        # Stack for DFS
        stack: list[Coord] = [c]

        # Set to store visited nodes
        visited: set[CoordTup] = set()

        while stack:
            current_node: Coord = stack.pop()
            # this is fine since we know current_node is a coord and thus of length 2
            visited.add(tuple(current_node))  # type: ignore[arg-type]

            # Get the neighbors of the current node
            neighbors = self.get_coord_neighbors(current_node)

            # Iterate over neighbors
            for neighbor in neighbors:
                if tuple(neighbor) not in visited:
                    stack.append(neighbor)

        return np.array(list(visited))

    def find_shortest_path(
        self,
        c_start: CoordTup | Coord,
        c_end: CoordTup | Coord,
    ) -> CoordArray:
        """find the shortest path between two coordinates, using A*"""
        c_start = tuple(c_start)  # type: ignore[assignment]
        c_end = tuple(c_end)  # type: ignore[assignment]

        g_score: dict[CoordTup, float] = (
            dict()
        )  # cost of cheapest path to node from start currently known
        f_score: dict[CoordTup, float] = {
            c_start: 0.0,
        }  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

        # init
        g_score[c_start] = 0.0
        f_score[c_start] = self.heuristic(c_start, c_end)

        closed_vtx: set[CoordTup] = set()  # nodes already evaluated
        # nodes to be evaluated
        # we need a set of the tuples, dont place the ints in the set
        open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
        source: dict[CoordTup, CoordTup] = (
            dict()
        )  # node immediately preceding each node in the path (currently known shortest path)

        while open_vtx:
            # get lowest f_score node
            # mypy cant tell that c is of length 2
            c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
            # f_current: float = f_score[c_current]

            # check if goal is reached
            if c_end == c_current:
                path: list[CoordTup] = [c_current]
                p_current: CoordTup = c_current
                while p_current in source:
                    p_current = source[p_current]
                    path.append(p_current)
                # ----------------------------------------------------------------------
                # this is the only return statement
                return np.array(path[::-1])
                # ----------------------------------------------------------------------

            # close current node
            closed_vtx.add(c_current)
            open_vtx.remove(c_current)

            # update g_score of neighbors
            _np_neighbor: Coord
            for _np_neighbor in self.get_coord_neighbors(c_current):
                neighbor: CoordTup = tuple(_np_neighbor)

                if neighbor in closed_vtx:
                    # already checked
                    continue
                g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

                if neighbor not in open_vtx:
                    # found new vtx, so add
                    open_vtx.add(neighbor)

                elif g_temp >= g_score[neighbor]:
                    # if already knew about this one, but current g_score is worse, skip
                    continue

                # store g_score and source
                source[neighbor] = c_current
                g_score[neighbor] = g_temp
                f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

        raise ValueError(
            "A solution could not be found!",
            f"{c_start = }, {c_end = }",
            self.as_ascii(),
        )

    def get_nodes(self) -> CoordArray:
        """return a list of all nodes in the maze"""
        rows: Int[np.ndarray, "x y"]
        cols: Int[np.ndarray, "x y"]
        rows, cols = np.meshgrid(
            range(self.grid_shape[0]),
            range(self.grid_shape[1]),
            indexing="ij",
        )
        nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
        return nodes

    def get_connected_component(self) -> CoordArray:
        """get the largest (and assumed only nonsingular) connected component of the maze

        TODO: other connected components?
        """
        if (self.generation_meta is None) or (
            self.generation_meta.get("fully_connected", False)
        ):
            # for fully connected case, pick any two positions
            return self.get_nodes()
        else:
            # if metadata provided, use visited cells
            visited_cells: set[CoordTup] | None = self.generation_meta.get(
                "visited_cells",
                None,
            )
            if visited_cells is None:
                # TODO: dynamically generate visited_cells?
                err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
                raise ValueError(
                    err_msg,
                )
            visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
            return visited_cells_np

    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[True] = True,
    ) -> CoordArray: ...
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[False] = False,
    ) -> CoordArray | None: ...
    def generate_random_path(  # noqa: C901
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: bool = True,
    ) -> CoordArray | None:
        """return a path between randomly chosen start and end nodes within the connected component

        Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

        # Parameters:
        - `allowed_start : CoordList | None`
            a list of allowed start positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `allowed_end : CoordList | None`
            a list of allowed end positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `deadend_start : bool`
            whether to ***force*** the start position to be a deadend
            (defaults to `False`)
        - `deadend_end : bool`
            whether to ***force*** the end position to be a deadend
            (defaults to `False`)
        - `endpoints_not_equal : bool`
            whether to ensure that the start and end point are not the same
            (defaults to `False`)
        - `except_on_no_valid_endpoint : bool`
            whether to raise an error if no valid start or end positions are found
            if this is `False`, the function might return `None` and this must be handled by the caller
            (defaults to `True`)

        # Returns:
        - `CoordArray`
            a path between the selected start and end positions

        # Raises:
        - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
        """
        # we can't create a "path" in a single-node maze
        assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
            f"can't create path in single-node maze: {self.as_ascii()}"
        )

        # get connected component
        connected_component: CoordArray = self.get_connected_component()

        # initialize start and end positions
        positions: Int[np.int8, "2 2"]

        # if no special conditions on start and end positions
        if (allowed_start, allowed_end, deadend_start, deadend_end) == (
            None,
            None,
            False,
            False,
        ):
            try:
                positions = connected_component[  # type: ignore[assignment]
                    np.random.choice(
                        len(connected_component),
                        size=2,
                        replace=False,
                    )
                ]
            except ValueError as e:
                if except_on_no_valid_endpoint:
                    err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                    raise NoValidEndpointException(
                        err_msg,
                    ) from e
                return None

            return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

        # handle special conditions
        connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
        # copy connected component set
        allowed_start_set: set[CoordTup] = connected_component_set.copy()
        allowed_end_set: set[CoordTup] = connected_component_set.copy()

        # filter by explicitly allowed start and end positions
        # '# type: ignore[assignment]' here because the returned tuple can be of any length
        if allowed_start is not None:
            allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

        if allowed_end is not None:
            allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

        # filter by forcing deadends
        if deadend_start:
            allowed_start_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_start_set,
                ),
            )

        if deadend_end:
            allowed_end_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_end_set,
                ),
            )

        # check we have valid positions
        if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                )
            return None

        # randomly select start and end positions
        try:
            # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
            start_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
            )
            if endpoints_not_equal:
                # remove start position from end positions
                allowed_end_set.discard(start_pos)
            end_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
            )
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(start_pos, end_pos)

    # ============================================================
    # to and from adjacency list
    # ============================================================
    def as_adj_list(
        self,
        shuffle_d0: bool = True,
        shuffle_d1: bool = True,
    ) -> Int8[np.ndarray, "conn start_end coord"]:
        """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
        return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)

    @classmethod
    def from_adj_list(
        cls,
        adj_list: Int8[np.ndarray, "conn start_end coord"],
    ) -> "LatticeMaze":
        """create a LatticeMaze from a list of connections

        > [!NOTE]
        > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
        """
        # this is where it would probably break for rectangular mazes
        grid_n: int = adj_list.max() + 1

        connection_list: ConnectionList = np.zeros(
            (2, grid_n, grid_n),
            dtype=np.bool_,
        )

        for c_start, c_end in adj_list:
            # check that exactly 1 coordinate matches
            if (c_start == c_end).sum() != 1:
                raise ValueError("invalid connection")

            # get the direction
            d: int = (c_start != c_end).argmax()

            x: int
            y: int
            # pick whichever has the lesser value in the direction `d`
            if c_start[d] < c_end[d]:
                x, y = c_start
            else:
                x, y = c_end

            connection_list[d, x, y] = True

        return LatticeMaze(
            connection_list=connection_list,
        )

    def as_adj_list_tokens(self) -> list[str | CoordTup]:
        """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
        warnings.warn(
            "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_adj_list_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
        """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples"""
        output: list[CoordTup | str] = self._as_adj_list_tokens()
        # if getattr(self, "start_pos", None) is not None:
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_start_pos_tokens()
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_end_pos_tokens()
        if isinstance(self, SolvedMaze):
            output += self._get_solution_tokens()
        return output

    def _as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode",
    ) -> list[str]:
        # type ignores here fine since we check the instance
        if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
            maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
        if (
            isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
            and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
        ):
            coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
            coords_processed: list[str] = maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
                coords=coords_raw,
                when_noncoord="include",
            )
            return coords_processed
        else:
            err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
            raise NotImplementedError(err_msg)

    def as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> list[str]:
        """serialize maze and solution to tokens"""
        if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
            return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
        else:
            return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

    @classmethod
    def _from_tokens_AOTP(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """create a LatticeMaze from a list of tokens"""
        # figure out what input format
        # ========================================
        if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
            adj_list_tokens = get_adj_list_tokens(tokens)
        else:
            # If we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
            adj_list_tokens = tokens
            warnings.warn(
                "Assuming input is just adjacency list tokens, no special tokens found",
            )

        # process edges for adjacency list
        # ========================================
        edges: list[list[str]] = list_split(
            adj_list_tokens,
            SPECIAL_TOKENS.ADJACENCY_ENDLINE,
        )

        coordinates: list[tuple[CoordTup, CoordTup]] = list()
        for e in edges:
            # skip last endline
            if len(e) != 0:
                # convert to coords, split start and end
                e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
                    e,
                    when_noncoord="include",
                )
                # this assertion depends on the tokenizer having exactly one token for the connector
                # which is also why we "include" above
                # the connector token is discarded below
                assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
                assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                    f"invalid edge: {e = } {e_coords = }"
                )
                e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
                e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
                coordinates.append((e_coords_first, e_coords_last))

        assert all(len(c) == DIM_2 for c in coordinates), (
            f"invalid coordinates: {coordinates = }"
        )
        adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
        assert tuple(adj_list.shape) == (
            len(coordinates),
            2,
            2,
        ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

        output_maze: LatticeMaze = cls.from_adj_list(adj_list)

        # add start and end positions
        # ========================================
        is_targeted: bool = False
        if all(
            x in tokens
            for x in (
                SPECIAL_TOKENS.ORIGIN_START,
                SPECIAL_TOKENS.ORIGIN_END,
                SPECIAL_TOKENS.TARGET_START,
                SPECIAL_TOKENS.TARGET_END,
            )
        ):
            start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_origin_tokens(tokens),
                when_noncoord="error",
            )
            end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_target_tokens(tokens),
                when_noncoord="error",
            )
            assert len(start_pos_list) == 1, (
                f"invalid start_pos_list: {start_pos_list = }"
            )
            assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

            start_pos: CoordTup = start_pos_list[0]
            end_pos: CoordTup = end_pos_list[0]

            output_maze = TargetedLatticeMaze.from_lattice_maze(
                lattice_maze=output_maze,
                start_pos=start_pos,
                end_pos=end_pos,
            )

            is_targeted = True

        if all(
            x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
        ):
            assert is_targeted, "maze must be targeted to have a solution"
            solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_path_tokens(tokens, trim_end=True),
                when_noncoord="error",
            )
            output_maze = SolvedMaze.from_targeted_lattice_maze(
                # HACK: I think this is fine, but im not sure
                targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
                solution=solution,
            )

        return output_maze

    # TODO: any way to get return type hinting working for this?
    @classmethod
    def from_tokens(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """Constructs a maze from a tokenization.

        Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
        """
        # HACK: type ignores here fine since we check the instance
        if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
            maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
        if (
            isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
            and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
        ):
            err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
            raise NotImplementedError(
                err_msg,
            )

        if isinstance(tokens, str):
            tokens = tokens.split()

        if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
            return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
        else:
            raise NotImplementedError("only AOTP tokenization is supported")

    # ============================================================
    # to and from pixels
    # ============================================================
    def _as_pixels_bw(self) -> BinaryPixelGrid:
        assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
        # Create an empty pixel grid with walls
        pixel_grid: Int[np.ndarray, "x y"] = np.full(
            (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
            False,
            dtype=np.bool_,
        )

        # Set white nodes
        pixel_grid[1::2, 1::2] = True

        # Set white connections (downward)
        for i, row in enumerate(self.connection_list[0]):
            for j, connected in enumerate(row):
                if connected:
                    pixel_grid[i * 2 + 2, j * 2 + 1] = True

        # Set white connections (rightward)
        for i, row in enumerate(self.connection_list[1]):
            for j, connected in enumerate(row):
                if connected:
                    pixel_grid[i * 2 + 1, j * 2 + 2] = True

        return pixel_grid

    def as_pixels(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> PixelGrid:
        """convert the maze to a pixel grid

        - useful as a simpler way of plotting the maze than the more complex `MazePlot`
        - the same underlying representation as `as_ascii` but as an image
        - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
        """
        # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
        # but solution, start_pos, end_pos not always defined
        # but its fine since we explicitly check the type
        if show_solution and not show_endpoints:
            raise ValueError("show_solution=True requires show_endpoints=True")
        # convert original bool pixel grid to RGB
        pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
        pixel_grid: PixelGrid = np.full(
            (*pixel_grid_bw.shape, 3),
            PixelColors.WALL,
            dtype=np.uint8,
        )
        pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

        if self.__class__ == LatticeMaze:
            return pixel_grid

        # set endpoints for TargetedLatticeMaze
        if self.__class__ == TargetedLatticeMaze:
            if show_endpoints:
                pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                    PixelColors.START
                )
                pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                    PixelColors.END
                )
            return pixel_grid

        # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
        if show_solution:
            for coord in self.solution:  # type: ignore[attr-defined]
                pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

            # set pixels between coords
            for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
                next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
                # check they are adjacent using norm
                assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                    f"Coords {coord} and {next_coord} are not adjacent"
                )
                # set pixel between them
                pixel_grid[
                    coord[0] * 2 + 1 + next_coord[0] - coord[0],
                    coord[1] * 2 + 1 + next_coord[1] - coord[1],
                ] = PixelColors.PATH

            # set endpoints (again, since path would overwrite them)
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )

        return pixel_grid

    @classmethod
    def _from_pixel_grid_bw(
        cls,
        pixel_grid: BinaryPixelGrid,
    ) -> tuple[ConnectionList, tuple[int, int]]:
        grid_shape: tuple[int, int] = (
            pixel_grid.shape[0] // 2,
            pixel_grid.shape[1] // 2,
        )
        connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)

        # Extract downward connections
        connection_list[0] = pixel_grid[2::2, 1::2]

        # Extract rightward connections
        connection_list[1] = pixel_grid[1::2, 2::2]

        return connection_list, grid_shape

    @classmethod
    def _from_pixel_grid_with_positions(
        cls,
        pixel_grid: PixelGrid | BinaryPixelGrid,
        marked_positions: dict[str, RGB],
    ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
        # Convert RGB pixel grid to Bool pixel grid
        # error: Incompatible types in assignment (expression has type
        # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]",
        # variable has type "ndarray[Any, Any]")  [assignment]
        pixel_grid_bw: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
            pixel_grid == PixelColors.WALL,
            axis=-1,
        )
        connection_list: ConnectionList
        grid_shape: tuple[int, int]
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

        # Find any marked positions
        out_positions: dict[str, CoordArray] = dict()
        for key, color in marked_positions.items():
            pos_temp: Int[np.ndarray, "x y"] = np.argwhere(
                np.all(pixel_grid == color, axis=-1),
            )
            pos_save: list[CoordTup] = list()
            for pos in pos_temp:
                # if it is a coordinate and not connection (transform position, %2==1)
                if pos[0] % 2 == 1 and pos[1] % 2 == 1:
                    pos_save.append((pos[0] // 2, pos[1] // 2))

            out_positions[key] = np.array(pos_save)

        return connection_list, grid_shape, out_positions

    @classmethod
    def from_pixels(
        cls,
        pixel_grid: PixelGrid,
    ) -> "LatticeMaze":
        """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

        # Raises:
        - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
        """
        connection_list: ConnectionList
        grid_shape: tuple[int, int]

        # if a binary pixel grid, return regular LatticeMaze
        if len(pixel_grid.shape) == 2:  # noqa: PLR2004
            connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
            return LatticeMaze(connection_list=connection_list)

        # otherwise, detect and check it's valid
        cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
        if cls not in cls_detected.__mro__:
            err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
            raise ValueError(
                err_msg,
            )

        (
            connection_list,
            grid_shape,
            marked_pos,
        ) = cls._from_pixel_grid_with_positions(
            pixel_grid=pixel_grid,
            marked_positions=dict(
                start=PixelColors.START,
                end=PixelColors.END,
                solution=PixelColors.PATH,
            ),
        )
        # if we wanted a LatticeMaze, return it
        if cls == LatticeMaze:
            return LatticeMaze(connection_list=connection_list)

        # otherwise, keep going
        temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

        # start and end pos
        start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
        assert start_pos_arr.shape == (
            1,
            2,
        ), (
            f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
        )
        assert end_pos_arr.shape == (
            1,
            2,
        ), (
            f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
        )

        start_pos: Coord = start_pos_arr[0]
        end_pos: Coord = end_pos_arr[0]

        # return a TargetedLatticeMaze if that's what we wanted
        if cls == TargetedLatticeMaze:
            return TargetedLatticeMaze(
                connection_list=connection_list,
                start_pos=start_pos,
                end_pos=end_pos,
            )

        # raw solution, only contains path elements and not start or end
        solution_raw: CoordArray = marked_pos["solution"]
        if len(solution_raw.shape) == 2:  # noqa: PLR2004
            assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
                f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
            )
        elif solution_raw.shape == (0,):
            # the solution and end should be immediately adjacent
            assert np.sum(np.abs(start_pos - end_pos)) == 1, (
                f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
            )

        # order the solution, by creating a list from the start to the end
        # add end pos, since we will iterate over all these starting from the start pos
        solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
            tuple(end_pos),
        ]
        # solution starts with start point
        solution: list[CoordTup] = [tuple(start_pos)]
        while solution[-1] != tuple(end_pos):
            # use `get_coord_neighbors` to find connected neighbors
            neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
            # TODO: make this less ugly
            assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
                f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
            )
            # neighbors = neighbors[:, [1, 0]]
            # filter out neighbors that are not in the raw solution
            neighbors_filtered: CoordArray = np.array(
                [
                    coord
                    for coord in neighbors
                    if (
                        tuple(coord) in solution_raw_list
                        and tuple(coord) not in solution
                    )
                ],
            )
            # assert only one element is left, and then add it to the solution
            assert neighbors_filtered.shape == (
                1,
                2,
            ), (
                f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
            )
            solution.append(tuple(neighbors_filtered[0]))

        # assert the solution is complete
        assert solution[0] == tuple(start_pos), (
            f"solution {solution} does not start at start_pos {start_pos}"
        )
        assert solution[-1] == tuple(end_pos), (
            f"solution {solution} does not end at end_pos {end_pos}"
        )

        return cls(
            connection_list=np.array(connection_list),
            solution=np.array(solution),  # type: ignore[call-arg]
        )

    # ============================================================
    # to and from ASCII
    # ============================================================
    def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
        # Get the pixel grid using to_pixels().
        pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw()

        # Replace pixel values with ASCII characters.
1118 ascii_grid: Shaped[np.ndarray, "x y"] = np.full( 1119 pixel_grid.shape, 1120 AsciiChars.WALL, 1121 dtype=str, 1122 ) 1123 ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 1124 1125 return ascii_grid 1126 1127 def as_ascii( 1128 self, 1129 show_endpoints: bool = True, 1130 show_solution: bool = True, 1131 ) -> str: 1132 """return an ASCII grid of the maze 1133 1134 useful for debugging in the terminal, or as it's own format 1135 1136 can be reversed with `LatticeMaze.from_ascii()` 1137 """ 1138 ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() 1139 pixel_grid: PixelGrid = self.as_pixels( 1140 show_endpoints=show_endpoints, 1141 show_solution=show_solution, 1142 ) 1143 1144 chars_replace: tuple = tuple() 1145 if show_endpoints: 1146 chars_replace += (AsciiChars.START, AsciiChars.END) 1147 if show_solution: 1148 chars_replace += (AsciiChars.PATH,) 1149 1150 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1151 if ascii_char in chars_replace: 1152 ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char 1153 1154 return "\n".join("".join(row) for row in ascii_grid) 1155 1156 @classmethod 1157 def from_ascii(cls, ascii_str: str) -> "LatticeMaze": 1158 "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" 1159 lines: list[str] = ascii_str.strip().split("\n") 1160 lines = [line.strip() for line in lines] 1161 ascii_grid: Shaped[np.ndarray, "x y"] = np.array( 1162 [list(line) for line in lines], 1163 dtype=str, 1164 ) 1165 pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) 1166 1167 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1168 pixel_grid[ascii_grid == ascii_char] = pixel_color 1169 1170 return cls.from_pixels(pixel_grid)
lattice maze (nodes on a lattice, connections only to neighboring nodes)
Connection List represents which nodes (N) are connected in each direction.
First and second elements represent downward and rightward connections, respectively.
Example:
  Connection list:
    [
      [ # down
        [F T],
        [F F]
      ],
      [ # right
        [T F],
        [T F]
      ]
    ]
  Nodes with connections:
    N T N F
    F   T
    N T N F
    F   F
  Graph:
    N - N
        |
    N - N
Note: the bottom row connections going down, and the right-hand connections going right, will always be False.
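The layout above can be made concrete with a small, library-independent sketch. It follows the `# down` / `# right` labels in the example (index 0 for downward connections, index 1 for rightward ones) and uses nested lists instead of a NumPy array. The helper `connection_list_to_edges` is hypothetical and not part of `maze_dataset`.

```python
# Connection list for the 2x2 example, as nested lists of booleans.
# Index 0 holds downward connections, index 1 holds rightward connections.
connection_list = [
    [  # down
        [False, True],
        [False, False],
    ],
    [  # right
        [True, False],
        [True, False],
    ],
]


def connection_list_to_edges(conn):
    """Hypothetical helper: list every edge as ((row, col), (row, col))."""
    n_rows = len(conn[0])
    n_cols = len(conn[0][0])
    edges = []
    for row in range(n_rows):
        for col in range(n_cols):
            if conn[0][row][col]:  # connected to the node below
                edges.append(((row, col), (row + 1, col)))
            if conn[1][row][col]:  # connected to the node to the right
                edges.append(((row, col), (row, col + 1)))
    return sorted(edges)


edges = connection_list_to_edges(connection_list)
for edge in edges:
    print(edge)
```

The three edges printed correspond to the three connections drawn in the graph.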
@property
def grid_n(self) -> int:
    "grid size as int, raises `AssertionError` if not square"
    assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
    return self.grid_shape[0]
grid size as int, raises AssertionError if not square
@staticmethod
def heuristic(a: CoordTup, b: CoordTup) -> float:
    """return manhattan distance between two points"""
    return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
return manhattan distance between two points
def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """returns whether two nodes are connected"""
    delta: Coord = b - a
    if np.abs(delta).sum() != 1:
        # return false if not even adjacent
        return False
    else:
        # test for wall
        dim: int = int(np.argmax(np.abs(delta)))
        clist_node: Coord = a if (delta.sum() > 0) else b
        return self.connection_list[dim, clist_node[0], clist_node[1]]
returns whether two nodes are connected
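The lookup rule used here can be sketched without NumPy: the connection between two adjacent nodes is stored on whichever node has the smaller coordinate along the axis of the step. The function below is an illustrative re-implementation on nested lists, not the library method, and it assumes index 0 of the connection list holds downward connections.

```python
def nodes_connected_sketch(conn, a, b):
    """Illustrative check: are lattice nodes a and b adjacent with no wall between?"""
    delta = (b[0] - a[0], b[1] - a[1])
    if abs(delta[0]) + abs(delta[1]) != 1:
        return False  # not lattice neighbors at all
    dim = 0 if delta[0] != 0 else 1  # 0: vertical step, 1: horizontal step
    # the connection is stored on the node with the smaller coordinate
    node = a if sum(delta) > 0 else b
    return conn[dim][node[0]][node[1]]


# 2x2 example: index 0 = down, index 1 = right
conn = [
    [[False, True], [False, False]],
    [[True, False], [True, False]],
]
print(nodes_connected_sketch(conn, (0, 0), (0, 1)))
print(nodes_connected_sketch(conn, (0, 0), (1, 0)))
```

Because the lookup always goes through the smaller-coordinate node, the result does not depend on argument order.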
def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """check if a path is valid"""
    # check path is not empty
    if len(path) == 0:
        return empty_is_valid

    # check all coords in bounds of maze
    if not np.all((path >= 0) & (path < self.grid_shape)):
        return False

    # check all nodes connected
    for i in range(len(path) - 1):
        if not self.nodes_connected(path[i], path[i + 1]):
            return False
    return True
check if a path is valid
def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the connectivity degree of each coord.

    I.e., how many neighbors each coord has.
    """
    int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
        self.connection_list.astype(np.int8)
    )
    degrees: Int8[np.ndarray, "row col"] = np.sum(
        int_conn,
        axis=0,
    )  # Connections to east and south
    degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
    degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
    return degrees
Returns an array with the connectivity degree of each coord.
I.e., how many neighbors each coord has.
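As a worked example, the degree computation can be reproduced on nested lists: every stored connection contributes one to the degree of both nodes it joins. This is a minimal sketch, not the vectorized library code, and it assumes that bottom-row downward and right-column rightward entries are `False`, as noted in the class description.

```python
def coord_degrees_sketch(conn):
    """Count, for every node, how many neighbors it is connected to."""
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    degrees = [[0] * n_cols for _ in range(n_rows)]
    for r in range(n_rows):
        for c in range(n_cols):
            if conn[0][r][c]:  # edge between (r, c) and (r + 1, c)
                degrees[r][c] += 1
                degrees[r + 1][c] += 1
            if conn[1][r][c]:  # edge between (r, c) and (r, c + 1)
                degrees[r][c] += 1
                degrees[r][c + 1] += 1
    return degrees


# 2x2 example: index 0 = down, index 1 = right
conn = [
    [[False, True], [False, False]],
    [[True, False], [True, False]],
]
degrees = coord_degrees_sketch(conn)
print(degrees)
```

Nodes with degree 1 are dead ends, which is the property `generate_random_path` filters on when `deadend_start` or `deadend_end` is set.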
def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """Returns an array of the neighboring, connected coords of `c`."""
    c = np.array(c)  # type: ignore[assignment]
    neighbors: list[Coord] = [
        neighbor
        for neighbor in (c + NEIGHBORS_MASK)
        if (
            (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
            and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
            and self.nodes_connected(c, neighbor)  # connected
        )
    ]

    output: CoordArray = np.array(neighbors)
    if len(neighbors) > 0:
        assert output.shape == (
            len(neighbors),
            2,
        ), (
            f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
        )
    return output
Returns an array of the neighboring, connected coords of c.
def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """return the connected component from a given coordinate"""
    # Stack for DFS
    stack: list[Coord] = [c]

    # Set to store visited nodes
    visited: set[CoordTup] = set()

    while stack:
        current_node: Coord = stack.pop()
        # this is fine since we know current_node is a coord and thus of length 2
        visited.add(tuple(current_node))  # type: ignore[arg-type]

        # Get the neighbors of the current node
        neighbors = self.get_coord_neighbors(current_node)

        # Iterate over neighbors
        for neighbor in neighbors:
            if tuple(neighbor) not in visited:
                stack.append(neighbor)

    return np.array(list(visited))
return the connected component from a given coordinate
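The same stack-based depth-first traversal can be sketched over a plain adjacency mapping. The graph below is a hypothetical example in which node `(1, 0)` is isolated. The function name and the dictionary representation are assumptions made for illustration only.

```python
def connected_component_sketch(neighbors_of, start):
    """Collect every node reachable from start with an explicit DFS stack."""
    stack = [start]
    visited = set()
    while stack:
        node = stack.pop()
        visited.add(node)
        for neighbor in neighbors_of[node]:
            if neighbor not in visited:
                stack.append(neighbor)
    return visited


# hypothetical 2x2 maze in which (1, 0) has no connections
graph = {
    (0, 0): [(0, 1)],
    (0, 1): [(0, 0), (1, 1)],
    (1, 1): [(0, 1)],
    (1, 0): [],
}
component = connected_component_sketch(graph, (0, 0))
print(sorted(component))
```

Using a set for `visited` makes repeated additions harmless, so a node that ends up on the stack twice is simply processed twice without changing the result.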
def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*"""
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    g_score: dict[CoordTup, float] = (
        dict()
    )  # cost of cheapest path to node from start currently known
    f_score: dict[CoordTup, float] = {
        c_start: 0.0,
    }  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

    # init
    g_score[c_start] = 0.0
    f_score[c_start] = self.heuristic(c_start, c_end)

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    # nodes to be evaluated
    # we need a set of the tuples, dont place the ints in the set
    open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
    source: dict[CoordTup, CoordTup] = (
        dict()
    )  # node immediately preceding each node in the path (currently known shortest path)

    while open_vtx:
        # get lowest f_score node
        # mypy cant tell that c is of length 2
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
        # f_current: float = f_score[c_current]

        # check if goal is reached
        if c_end == c_current:
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # ----------------------------------------------------------------------
            # this is the only return statement
            return np.array(path[::-1])
            # ----------------------------------------------------------------------

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)

            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )
find the shortest path between two coordinates, using A*
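For readers who want to experiment with the search outside the library, here is a compact A* sketch over an adjacency mapping. It keeps the same unit step cost and Manhattan heuristic, but swaps the open set plus `min` for a `heapq` priority queue, which is a substitution made here and not what the method above does. All names in the block are hypothetical.

```python
import heapq


def manhattan(a, b):
    """Manhattan distance between two (row, col) tuples."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def a_star_sketch(neighbors_of, start, end):
    """Return the shortest path from start to end as a list of coordinates."""
    g_score = {start: 0}  # cheapest known cost from start to each node
    source = {}  # predecessor of each node on the best known path
    open_heap = [(manhattan(start, end), start)]
    closed = set()
    while open_heap:
        _, current = heapq.heappop(open_heap)
        if current == end:
            path = [current]
            while path[-1] in source:
                path.append(source[path[-1]])
            return path[::-1]
        if current in closed:
            continue  # stale heap entry
        closed.add(current)
        for neighbor in neighbors_of[current]:
            g_new = g_score[current] + 1  # every maze step costs 1
            if neighbor in closed or g_new >= g_score.get(neighbor, float("inf")):
                continue
            g_score[neighbor] = g_new
            source[neighbor] = current
            heapq.heappush(open_heap, (g_new + manhattan(neighbor, end), neighbor))
    raise ValueError(f"no path from {start} to {end}")


# 2x2 example graph: (0,0)-(0,1), (0,1)-(1,1), (1,0)-(1,1)
graph = {
    (0, 0): [(0, 1)],
    (0, 1): [(0, 0), (1, 1)],
    (1, 1): [(0, 1), (1, 0)],
    (1, 0): [(1, 1)],
}
path = a_star_sketch(graph, (0, 0), (1, 0))
print(path)
```

The Manhattan distance never overestimates the remaining cost on a lattice with unit steps, so the first time the end node is popped its path is optimal.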
def get_nodes(self) -> CoordArray:
    """return a list of all nodes in the maze"""
    rows: Int[np.ndarray, "x y"]
    cols: Int[np.ndarray, "x y"]
    rows, cols = np.meshgrid(
        range(self.grid_shape[0]),
        range(self.grid_shape[1]),
        indexing="ij",
    )
    nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
    return nodes
return a list of all nodes in the maze
def get_connected_component(self) -> CoordArray:
    """get the largest (and assumed only nonsingular) connected component of the maze

    TODO: other connected components?
    """
    if (self.generation_meta is None) or (
        self.generation_meta.get("fully_connected", False)
    ):
        # for fully connected case, pick any two positions
        return self.get_nodes()
    else:
        # if metadata provided, use visited cells
        visited_cells: set[CoordTup] | None = self.generation_meta.get(
            "visited_cells",
            None,
        )
        if visited_cells is None:
            # TODO: dynamically generate visited_cells?
            err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
            raise ValueError(
                err_msg,
            )
        visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
        return visited_cells_np
get the largest (and assumed only nonsingular) connected component of the maze
TODO: other connected components?
def generate_random_path(  # noqa: C901
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> CoordArray | None:
    """return a path between randomly chosen start and end nodes within the connected component

    Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

    # Parameters:
    - `allowed_start : CoordList | None`
        a list of allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        a list of allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to be a deadend
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to be a deadend
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to ensure that the start and end point are not the same
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise an error if no valid start or end positions are found
        if this is `False`, the function might return `None` and this must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray`
        a path between the selected start and end positions

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """
    # we can't create a "path" in a single-node maze
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    # get connected component
    connected_component: CoordArray = self.get_connected_component()

    # initialize start and end positions
    positions: Int[np.int8, "2 2"]

    # if no special conditions on start and end positions
    if (allowed_start, allowed_end, deadend_start, deadend_end) == (
        None,
        None,
        False,
        False,
    ):
        try:
            positions = connected_component[  # type: ignore[assignment]
                np.random.choice(
                    len(connected_component),
                    size=2,
                    replace=False,
                )
            ]
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # handle special conditions
    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
    # copy connected component set
    allowed_start_set: set[CoordTup] = connected_component_set.copy()
    allowed_end_set: set[CoordTup] = connected_component_set.copy()

    # filter by explicitly allowed start and end positions
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    if allowed_start is not None:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    if allowed_end is not None:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    # filter by forcing deadends
    if deadend_start:
        allowed_start_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_start_set,
            ),
        )

    if deadend_end:
        allowed_end_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_end_set,
            ),
        )

    # check we have valid positions
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            )
        return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
        )
        if endpoints_not_equal:
            # remove start position from end positions
            allowed_end_set.discard(start_pos)
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            ) from e
        return None

    return self.find_shortest_path(start_pos, end_pos)
return a path between randomly chosen start and end nodes within the connected component
Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end, unless endpoints_not_equal is True.
Parameters:
- allowed_start : CoordList | None
  a list of allowed start positions. If None, any position in the connected component is allowed (defaults to None)
- allowed_end : CoordList | None
  a list of allowed end positions. If None, any position in the connected component is allowed (defaults to None)
- deadend_start : bool
  whether to force the start position to be a deadend (defaults to False)
- deadend_end : bool
  whether to force the end position to be a deadend (defaults to False)
- endpoints_not_equal : bool
  whether to ensure that the start and end points are not the same (defaults to False)
- except_on_no_valid_endpoint : bool
  whether to raise an error if no valid start or end positions are found. If this is False, the function might return None, and this must be handled by the caller (defaults to True)
Returns:
CoordArray
a path between the selected start and end positions
Raises:
- NoValidEndpointException : if no valid start or end positions are found, and except_on_no_valid_endpoint is True
def as_adj_list(
    self,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end coord"]:
    """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
    return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
return the maze as an adjacency list, wraps maze_dataset.token_utils.connection_list_to_adj_list
@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """create a LatticeMaze from a list of connections

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
    """
    # this is where it would probably break for rectangular mazes
    grid_n: int = adj_list.max() + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for c_start, c_end in adj_list:
        # check that exactly 1 coordinate matches
        if (c_start == c_end).sum() != 1:
            raise ValueError("invalid connection")

        # get the direction
        d: int = (c_start != c_end).argmax()

        x: int
        y: int
        # pick whichever has the lesser value in the direction `d`
        if c_start[d] < c_end[d]:
            x, y = c_start
        else:
            x, y = c_end

        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )
create a LatticeMaze from a list of connections
This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
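The conversion from an adjacency list back to a connection list can be sketched on plain tuples. The function below mirrors the steps described above (infer the grid size from the largest coordinate, require exactly one matching coordinate, store the connection on the endpoint with the smaller value along the differing axis), but it is an illustrative stand-in on nested lists, not the library method. Like the original, it assumes a square grid.

```python
def adj_list_to_connection_list(adj_list):
    """Illustrative conversion: list of (coord, coord) pairs -> [down, right] grids."""
    # square-grid assumption: size is inferred from the largest coordinate value
    grid_n = max(v for edge in adj_list for coord in edge for v in coord) + 1
    conn = [[[False] * grid_n for _ in range(grid_n)] for _ in range(2)]
    for c_start, c_end in adj_list:
        # exactly one coordinate must match, otherwise the pair is not axis-aligned
        if sum(s == e for s, e in zip(c_start, c_end)) != 1:
            raise ValueError("invalid connection")
        d = 0 if c_start[0] != c_end[0] else 1  # axis along which the pair differs
        # store the connection on the endpoint with the smaller value along d
        row, col = c_start if c_start[d] < c_end[d] else c_end
        conn[d][row][col] = True
    return conn


adj_list = [
    ((0, 0), (0, 1)),
    ((1, 1), (0, 1)),  # order within a pair does not matter
    ((1, 0), (1, 1)),
]
conn = adj_list_to_connection_list(adj_list)
print(conn)
```

Inferring the size from the maximum coordinate is why rectangular mazes are fragile here: a 2x5 maze would be allocated as 5x5.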
def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return [
        SPECIAL_TOKENS.ADJLIST_START,
        *chain.from_iterable(  # type: ignore[list-item]
            [
                [
                    tuple(c_s),
                    SPECIAL_TOKENS.CONNECTOR,
                    tuple(c_e),
                    SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                ]
                for c_s, c_e in self.as_adj_list()
            ],
        ),
        SPECIAL_TOKENS.ADJLIST_END,
    ]
(deprecated!) turn the maze into adjacency list tokens, use MazeTokenizerModular instead
def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """serialize maze and solution to tokens"""
    if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
    else:
        return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
serialize maze and solution to tokens
@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Constructs a maze from a tokenization.

    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
    """
    # HACK: type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
    if (
        isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
        and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
    ):
        err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        raise NotImplementedError(
            err_msg,
        )

    if isinstance(tokens, str):
        tokens = tokens.split()

    if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
    else:
        raise NotImplementedError("only AOTP tokenization is supported")
Constructs a maze from a tokenization.
Only legacy tokenizers and their MazeTokenizerModular analogs are supported.
def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """convert the maze to a pixel grid

    - useful as a simpler way of plotting the maze than the more complex `MazePlot`
    - the same underlying representation as `as_ascii` but as an image
    - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")
    # convert original bool pixel grid to RGB
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3),
        PixelColors.WALL,
        dtype=np.uint8,
    )
    pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set endpoints for TargetedLatticeMaze
    if self.__class__ == TargetedLatticeMaze:
        if show_endpoints:
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )
        return pixel_grid

    # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
    if show_solution:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
            next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # set pixel between them
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

        # set endpoints (again, since path would overwrite them)
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    return pixel_grid
convert the maze to a pixel grid
- useful as a simpler way of plotting the maze than the more complex MazePlot
- the same underlying representation as as_ascii but as an image
- used in RasterizedMazeDataset, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
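The coordinate mapping behind the pixel grid can be shown with a small sketch: a maze with `n` by `m` nodes becomes a `(2n + 1)` by `(2m + 1)` grid, node `(r, c)` sits at pixel `(2r + 1, 2c + 1)`, and the pixel between two connected nodes is opened. The sketch renders characters instead of RGB values. The `#` and space characters are assumptions chosen for illustration, not taken from `PixelColors` or `AsciiChars`.

```python
def as_char_grid_sketch(conn, wall="#", open_=" "):
    """Render a [down, right] connection list as a (2n+1) x (2m+1) character grid."""
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    grid = [[wall] * (2 * n_cols + 1) for _ in range(2 * n_rows + 1)]
    for r in range(n_rows):
        for c in range(n_cols):
            grid[2 * r + 1][2 * c + 1] = open_  # the node itself
            if conn[0][r][c]:  # open the pixel between (r, c) and (r + 1, c)
                grid[2 * r + 2][2 * c + 1] = open_
            if conn[1][r][c]:  # open the pixel between (r, c) and (r, c + 1)
                grid[2 * r + 1][2 * c + 2] = open_
    return "\n".join("".join(row) for row in grid)


# 2x2 example: index 0 = down, index 1 = right
conn = [
    [[False, True], [False, False]],
    [[True, False], [True, False]],
]
rendered = as_char_grid_sketch(conn)
print(rendered)
```

Reading the grid back is the inverse slicing: odd-row, even-column pixels give rightward connections and even-row, odd-column pixels give downward ones.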
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    # Raises:
    - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(
            err_msg,
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START,
            end=PixelColors.END,
            solution=PixelColors.PATH,
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the start and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
        tuple(end_pos),
    ]
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    while solution[-1] != tuple(end_pos):
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # neighbors = neighbors[:, [1, 0]]
        # filter out neighbors that are not in the raw solution
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (
                    tuple(coord) in solution_raw_list
                    and tuple(coord) not in solution
                )
            ],
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        solution.append(tuple(neighbors_filtered[0]))

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == tuple(end_pos), (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
create a LatticeMaze from a pixel grid. reverse of as_pixels
Raises:
- ValueError : if the pixel grid cannot be cast to a LatticeMaze -- it's probably a TargetedLatticeMaze or SolvedMaze
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    can be reversed with `LatticeMaze.from_ascii()`
    """
    ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    pixel_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    chars_replace: tuple = tuple()
    if show_endpoints:
        chars_replace += (AsciiChars.START, AsciiChars.END)
    if show_solution:
        chars_replace += (AsciiChars.PATH,)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char in chars_replace:
            ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char

    return "\n".join("".join(row) for row in ascii_grid)
return an ASCII grid of the maze
useful for debugging in the terminal, or as its own format
can be reversed with LatticeMaze.from_ascii()
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    "get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)"
    lines: list[str] = ascii_str.strip().split("\n")
    lines = [line.strip() for line in lines]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
        [list(line) for line in lines],
        dtype=str,
    )
    pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        pixel_grid[ascii_grid == ascii_char] = pixel_color

    return cls.from_pixels(pixel_grid)
get a LatticeMaze from an ASCII representation (reverses LatticeMaze.as_ascii)
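The text side of this round trip can be sketched without the library. The snippet below is only an illustration: `ascii_to_grid`, `grid_to_ascii`, and the example maze are hypothetical names and data, and the character meanings (`#` wall, space for an open cell, `S`/`E` endpoints) are assumptions here. It mirrors the stripping and splitting done in `from_ascii` and the joining done in `as_ascii`, but skips the pixel-grid step entirely.

```python
def ascii_to_grid(ascii_str: str) -> list[list[str]]:
    """Split an ASCII maze into a rectangular grid of single characters."""
    # same cleanup as `from_ascii`: strip the whole block, then each line
    lines = [line.strip() for line in ascii_str.strip().split("\n")]
    widths = {len(line) for line in lines}
    if len(widths) != 1:
        raise ValueError(f"ragged ASCII grid, row widths: {sorted(widths)}")
    return [list(line) for line in lines]


def grid_to_ascii(grid: list[list[str]]) -> str:
    """Join a character grid back into a newline-separated string."""
    return "\n".join("".join(row) for row in grid)


# hypothetical 2x2 maze; character meanings are assumed for illustration
EXAMPLE = """
#####
#S  #
### #
#E  #
#####
"""

grid = ascii_to_grid(EXAMPLE)
roundtrip = grid_to_ascii(grid)
```

Because each line is stripped, rows must begin and end with a non-whitespace character (here the outer wall) for the round trip to be lossless.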
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
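The per-field dispatch in `serialize` (nested objects serialize themselves, otherwise a field-level function is applied, otherwise the value is stored as-is) can be illustrated with plain `dataclasses`. This is a rough sketch, not the muutils implementation: `serialize_sketch`, the `"__format__"` key, and the use of field `metadata` to carry a `serialization_fn` are all assumptions made for the example.

```python
import dataclasses
from typing import Any


def serialize_sketch(obj: Any) -> dict[str, Any]:
    """Serialize a dataclass instance field by field (illustrative only)."""
    # placeholder format key; the real key name is not assumed here
    result: dict[str, Any] = {"__format__": f"{type(obj).__name__}(sketch)"}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        fn = f.metadata.get("serialization_fn")
        if dataclasses.is_dataclass(value) and not isinstance(value, type):
            # nested dataclass instance: recurse
            value = serialize_sketch(value)
        elif fn is not None:
            # field-level override, analogous to `field.serialization_fn`
            value = fn(value)
        result[f.name] = value
    return result


@dataclasses.dataclass
class Inner:
    n: int


@dataclasses.dataclass
class Outer:
    inner: Inner
    tags: set = dataclasses.field(
        default_factory=set,
        metadata={"serialization_fn": sorted},  # sets are not JSON-friendly
    )


out = serialize_sketch(Outer(inner=Inner(n=3), tags={"b", "a"}))
```

As in the listing above, a value that can serialize itself takes precedence over the field-level function.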
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
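The "per-field dict, then `all()`" pattern used here can be sketched with standard dataclasses. This is a simplified illustration under stated assumptions: the helper names and the `Point` class are hypothetical, only plain classes such as `int` or `str` are checked, and generic annotations (which the real validator has to deal with) are simply reported as invalid.

```python
import dataclasses


def validate_fields_types_dict(obj) -> dict[str, bool]:
    """Map each field name to whether its value matches the annotated type."""
    results: dict[str, bool] = {}
    for field in dataclasses.fields(obj):
        value = getattr(obj, field.name)
        # only plain classes are supported in this sketch
        results[field.name] = isinstance(field.type, type) and isinstance(
            value, field.type
        )
    return results


def validate_fields_types(obj) -> bool:
    """True only if every field passes its individual check."""
    return all(validate_fields_types_dict(obj).values())


@dataclasses.dataclass
class Point:
    x: int
    label: str


good = Point(x=1, label="a")
bad = Point(x="oops", label="a")  # dataclasses do not check types at init
```

Keeping the per-field dict around, rather than only the boolean, is what lets a caller report which fields failed.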
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global SERIALIZE_MINIMAL_THRESHOLD"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
set the global SERIALIZE_MINIMAL_THRESHOLD
def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
    """register a maze filter, casting it to operate over the whole list of mazes

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            ),
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
register a maze filter, casting it to operate over the whole list of mazes
method should be a staticmethod of a namespace class registered with register_filter_namespace_for_dataset
this is a more restricted version of register_dataset_filter that removes the need for boilerplate for operating over the arrays
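The idea of lifting a per-maze predicate to a whole-dataset filter can be tried out on a toy container. In this sketch `ToyDataset`, `register_item_filter`, and `at_least` are hypothetical stand-ins (integers instead of mazes, a plain list instead of a config object); only the overall shape follows the listing above: filter the items, copy so the input is untouched, and record the filter name and arguments.

```python
import copy
import functools
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyDataset:
    """Hypothetical stand-in for a dataset: items plus a log of filters."""

    items: list[int]
    applied_filters: list[dict[str, Any]] = field(default_factory=list)


def register_item_filter(method: Callable[..., bool]) -> Callable[..., ToyDataset]:
    """Lift a per-item predicate into a function over the whole dataset."""

    @functools.wraps(method)
    def wrapper(dataset: ToyDataset, *args: Any, **kwargs: Any) -> ToyDataset:
        # build a new dataset; deep-copy the log so the input stays unchanged
        new_dataset = ToyDataset(
            items=[m for m in dataset.items if method(m, *args, **kwargs)],
            applied_filters=copy.deepcopy(dataset.applied_filters),
        )
        # record what was applied, so the result is reproducible
        new_dataset.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        return new_dataset

    return wrapper


@register_item_filter
def at_least(item: int, minimum: int) -> bool:
    return item >= minimum


original = ToyDataset(items=[1, 5, 10])
filtered = at_least(original, minimum=5)
```

Recording the name and arguments of each filter is what makes a filtered dataset describable purely by its configuration.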
54class LatticeMazeGenerators: 55 """namespace for lattice maze generation algorithms 56 57 examples of generated mazes can be found here: 58 https://understanding-search.github.io/maze-dataset/examples/maze_examples.html 59 """ 60 61 @staticmethod 62 def gen_dfs( 63 grid_shape: Coord | CoordTup, 64 *, 65 lattice_dim: int = 2, 66 accessible_cells: float | None = None, 67 max_tree_depth: float | None = None, 68 do_forks: bool = True, 69 randomized_stack: bool = False, 70 start_coord: Coord | None = None, 71 ) -> LatticeMaze: 72 """generate a lattice maze using depth first search, iterative 73 74 # Arguments 75 - `grid_shape: Coord`: the shape of the grid 76 - `lattice_dim: int`: the dimension of the lattice 77 (default: `2`) 78 - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells** 79 (default: `None`) 80 - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape** 81 (default: `None`) 82 - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway. 83 - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate. 84 85 # algorithm 86 1. Choose the initial cell, mark it as visited and push it to the stack 87 2. While the stack is not empty 88 1. Pop a cell from the stack and make it a current cell 89 2. If the current cell has any neighbours which have not been visited 90 1. Push the current cell to the stack 91 2. Choose one of the unvisited neighbours 92 3. Remove the wall between the current cell and the chosen cell 93 4. 
Mark the chosen cell as visited and push it to the stack 94 """ 95 # Default values if no constraints have been passed 96 grid_shape_: Coord = np.array(grid_shape) 97 n_total_cells: int = int(np.prod(grid_shape_)) 98 99 n_accessible_cells: int 100 if accessible_cells is None: 101 n_accessible_cells = n_total_cells 102 elif isinstance(accessible_cells, float): 103 assert accessible_cells <= 1, ( 104 f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}" 105 ) 106 107 n_accessible_cells = int(accessible_cells * n_total_cells) 108 else: 109 assert isinstance(accessible_cells, int) 110 n_accessible_cells = accessible_cells 111 112 if max_tree_depth is None: 113 max_tree_depth = ( 114 2 * n_total_cells 115 ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here. 116 elif isinstance(max_tree_depth, float): 117 assert max_tree_depth <= 1, ( 118 f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}" 119 ) 120 121 max_tree_depth = int(max_tree_depth * np.sum(grid_shape_)) 122 123 # choose a random start coord 124 start_coord = _random_start_coord(grid_shape_, start_coord) 125 126 # initialize the maze with no connections 127 connection_list: ConnectionList = np.zeros( 128 (lattice_dim, grid_shape_[0], grid_shape_[1]), 129 dtype=np.bool_, 130 ) 131 132 # initialize the stack with the target coord 133 visited_cells: set[tuple[int, int]] = set() 134 visited_cells.add(tuple(start_coord)) # this wasnt a bug after all lol 135 stack: list[Coord] = [start_coord] 136 137 # initialize tree_depth_counter 138 current_tree_depth: int = 1 139 140 # loop until the stack is empty or n_connected_cells is reached 141 while stack and (len(visited_cells) < n_accessible_cells): 142 # get the current coord from the stack 143 current_coord: Coord 144 if randomized_stack: 
145 current_coord = stack.pop(random.randint(0, len(stack) - 1)) 146 else: 147 current_coord = stack.pop() 148 149 # filter neighbors by being within grid bounds and being unvisited 150 unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [ 151 (neighbor, delta) 152 for neighbor, delta in zip( 153 current_coord + NEIGHBORS_MASK, 154 NEIGHBORS_MASK, 155 strict=False, 156 ) 157 if ( 158 (tuple(neighbor) not in visited_cells) 159 and (0 <= neighbor[0] < grid_shape_[0]) 160 and (0 <= neighbor[1] < grid_shape_[1]) 161 ) 162 ] 163 164 # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions) 165 if unvisited_neighbors_deltas and ( 166 current_tree_depth <= max_tree_depth / 2 167 ): 168 # if we want a maze without forks, simply don't add the current coord back to the stack 169 if do_forks and (len(unvisited_neighbors_deltas) > 1): 170 stack.append(current_coord) 171 172 # choose one of the unvisited neighbors 173 chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas) 174 175 # add connection 176 dim: int = int(np.argmax(np.abs(delta))) 177 # if positive, down/right from current coord 178 # if negative, up/left from current coord (down/right from neighbor) 179 clist_node: Coord = ( 180 current_coord if (delta.sum() > 0) else chosen_neighbor 181 ) 182 connection_list[dim, clist_node[0], clist_node[1]] = True 183 184 # add to visited cells and stack 185 visited_cells.add(tuple(chosen_neighbor)) 186 stack.append(chosen_neighbor) 187 188 # Update current tree depth 189 current_tree_depth += 1 190 else: 191 current_tree_depth -= 1 192 193 return LatticeMaze( 194 connection_list=connection_list, 195 generation_meta=dict( 196 func_name="gen_dfs", 197 grid_shape=grid_shape_, 198 start_coord=start_coord, 199 n_accessible_cells=int(n_accessible_cells), 200 max_tree_depth=int(max_tree_depth), 201 # oh my god this took so long to track down. 
its almost 5am and I've spent like 2 hours on this bug 202 # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is 203 # treated as fully connected even when it is most certainly not, causing solving the maze to break 204 fully_connected=bool(len(visited_cells) == n_total_cells), 205 visited_cells={tuple(int(x) for x in coord) for coord in visited_cells}, 206 ), 207 ) 208 209 @staticmethod 210 def gen_prim( 211 grid_shape: Coord | CoordTup, 212 lattice_dim: int = 2, 213 accessible_cells: float | None = None, 214 max_tree_depth: float | None = None, 215 do_forks: bool = True, 216 start_coord: Coord | None = None, 217 ) -> LatticeMaze: 218 "(broken!) generate a lattice maze using Prim's algorithm" 219 warnings.warn( 220 "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12", 221 ) 222 return LatticeMazeGenerators.gen_dfs( 223 grid_shape=grid_shape, 224 lattice_dim=lattice_dim, 225 accessible_cells=accessible_cells, 226 max_tree_depth=max_tree_depth, 227 do_forks=do_forks, 228 start_coord=start_coord, 229 randomized_stack=True, 230 ) 231 232 @staticmethod 233 def gen_wilson( 234 grid_shape: Coord | CoordTup, 235 **kwargs, 236 ) -> LatticeMaze: 237 """Generate a lattice maze using Wilson's algorithm. 238 239 # Algorithm 240 Wilson's algorithm generates an unbiased (random) maze 241 sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is 242 acyclic and all cells are part of a unique connected space. 
243 https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm 244 """ 245 assert not kwargs, ( 246 f"gen_wilson does not take any additional arguments, got {kwargs = }" 247 ) 248 249 grid_shape_: Coord = np.array(grid_shape) 250 251 # Initialize grid and visited cells 252 connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_) 253 visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_) 254 255 # Choose a random cell and mark it as visited 256 start_coord: Coord = _random_start_coord(grid_shape_, None) 257 visited[start_coord[0], start_coord[1]] = True 258 del start_coord 259 260 while not visited.all(): 261 # Perform loop-erased random walk from another random cell 262 263 # Choose walk_start only from unvisited cells 264 unvisited_coords: CoordArray = np.column_stack(np.where(~visited)) 265 walk_start: Coord = unvisited_coords[ 266 np.random.choice(unvisited_coords.shape[0]) 267 ] 268 269 # Perform the random walk 270 path: list[Coord] = [walk_start] 271 current: Coord = walk_start 272 273 # exit the loop once the current path hits a visited cell 274 while not visited[current[0], current[1]]: 275 # find a valid neighbor (one always exists on a lattice) 276 neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_) 277 next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])] 278 279 # Check for loop 280 loop_exit: int | None = None 281 for i, p in enumerate(path): 282 if np.array_equal(next_cell, p): 283 loop_exit = i 284 break 285 286 # erase the loop, or continue the walk 287 if loop_exit is not None: 288 # this removes everything after and including the loop start 289 path = path[: loop_exit + 1] 290 # reset current cell to end of path 291 current = path[-1] 292 else: 293 path.append(next_cell) 294 current = next_cell 295 296 # Add the path to the maze 297 for i in range(len(path) - 1): 298 c_1: Coord = path[i] 299 c_2: Coord = path[i + 1] 300 301 # find the dimension of the connection 
302 delta: Coord = c_2 - c_1 303 dim: int = int(np.argmax(np.abs(delta))) 304 305 # if positive, down/right from current coord 306 # if negative, up/left from current coord (down/right from neighbor) 307 clist_node: Coord = c_1 if (delta.sum() > 0) else c_2 308 connection_list[dim, clist_node[0], clist_node[1]] = True 309 visited[c_1[0], c_1[1]] = True 310 # we dont add c_2 because the last c_2 will have already been visited 311 312 return LatticeMaze( 313 connection_list=connection_list, 314 generation_meta=dict( 315 func_name="gen_wilson", 316 grid_shape=grid_shape_, 317 fully_connected=True, 318 ), 319 ) 320 321 @staticmethod 322 def gen_percolation( 323 grid_shape: Coord | CoordTup, 324 p: float = 0.4, 325 lattice_dim: int = 2, 326 start_coord: Coord | None = None, 327 ) -> LatticeMaze: 328 """generate a lattice maze using simple percolation 329 330 note that p in the range (0.4, 0.7) gives the most interesting mazes 331 332 # Arguments 333 - `grid_shape: Coord`: the shape of the grid 334 - `lattice_dim: int`: the dimension of the lattice (default: `2`) 335 - `p: float`: the probability of a cell being accessible (default: `0.5`) 336 - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start) 337 """ 338 assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}" # noqa: PT018 339 grid_shape_: Coord = np.array(grid_shape) 340 341 start_coord = _random_start_coord(grid_shape_, start_coord) 342 343 connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p 344 345 connection_list = _fill_edges_with_walls(connection_list) 346 347 output: LatticeMaze = LatticeMaze( 348 connection_list=connection_list, 349 generation_meta=dict( 350 func_name="gen_percolation", 351 grid_shape=grid_shape_, 352 percolation_p=p, 353 start_coord=start_coord, 354 ), 355 ) 356 357 # generation_meta is sometimes None, but not here since we just made it a dict above 358 
output.generation_meta["visited_cells"] = output.gen_connected_component_from( # type: ignore[index] 359 start_coord, 360 ) 361 362 return output 363 364 @staticmethod 365 def gen_dfs_percolation( 366 grid_shape: Coord | CoordTup, 367 p: float = 0.4, 368 lattice_dim: int = 2, 369 accessible_cells: int | None = None, 370 max_tree_depth: int | None = None, 371 start_coord: Coord | None = None, 372 ) -> LatticeMaze: 373 """dfs and then percolation (adds cycles)""" 374 grid_shape_: Coord = np.array(grid_shape) 375 start_coord = _random_start_coord(grid_shape_, start_coord) 376 377 # generate initial maze via dfs 378 maze: LatticeMaze = LatticeMazeGenerators.gen_dfs( 379 grid_shape=grid_shape_, 380 lattice_dim=lattice_dim, 381 accessible_cells=accessible_cells, 382 max_tree_depth=max_tree_depth, 383 start_coord=start_coord, 384 ) 385 386 # percolate 387 connection_list_perc: np.ndarray = ( 388 np.random.rand(*maze.connection_list.shape) < p 389 ) 390 connection_list_perc = _fill_edges_with_walls(connection_list_perc) 391 392 maze.__dict__["connection_list"] = np.logical_or( 393 maze.connection_list, 394 connection_list_perc, 395 ) 396 397 # generation_meta is sometimes None, but not here since we just made it a dict above 398 maze.generation_meta["func_name"] = "gen_dfs_percolation" # type: ignore[index] 399 maze.generation_meta["percolation_p"] = p # type: ignore[index] 400 maze.generation_meta["visited_cells"] = maze.gen_connected_component_from( # type: ignore[index] 401 start_coord, 402 ) 403 404 return maze 405 406 @staticmethod 407 def gen_kruskal( 408 grid_shape: "Coord | CoordTup", 409 lattice_dim: int = 2, 410 start_coord: "Coord | None" = None, 411 ) -> "LatticeMaze": 412 """Generate a maze using Kruskal's algorithm. 413 414 This function generates a random spanning tree over a grid using Kruskal's algorithm. 415 Each cell is treated as a node, and all valid adjacent edges are listed and processed 416 in random order. An edge is added (i.e. 
its passage carved) only if it connects two cells 417 that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) 418 without cycles. 419 420 https://en.wikipedia.org/wiki/Kruskal's_algorithm 421 422 # Parameters: 423 - `grid_shape : Coord | CoordTup` 424 The shape of the maze grid (for example, `(n_rows, n_cols)`). 425 - `lattice_dim : int` 426 The lattice dimension (default is `2`). 427 - `start_coord : Coord | None` 428 Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen. 429 - `**kwargs` 430 Additional keyword arguments (currently unused). 431 432 # Returns: 433 - `LatticeMaze` 434 A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm. 435 436 # Usage: 437 ```python 438 maze = gen_kruskal((10, 10)) 439 ``` 440 """ 441 assert lattice_dim == 2, ( # noqa: PLR2004 442 "Kruskal's algorithm is only implemented for 2D lattices." 443 ) 444 # Convert grid_shape to a tuple of ints 445 grid_shape_: CoordTup = tuple(int(x) for x in grid_shape) # type: ignore[assignment] 446 n_rows, n_cols = grid_shape_ 447 448 # Initialize union-find data structure. 449 parent: dict[tuple[int, int], tuple[int, int]] = {} 450 451 def find(cell: tuple[int, int]) -> tuple[int, int]: 452 while parent[cell] != cell: 453 parent[cell] = parent[parent[cell]] 454 cell = parent[cell] 455 return cell 456 457 def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None: 458 root1 = find(cell1) 459 root2 = find(cell2) 460 parent[root2] = root1 461 462 # Initialize each cell as its own set. 463 for i in range(n_rows): 464 for j in range(n_cols): 465 parent[(i, j)] = (i, j) 466 467 # List all possible edges. 468 # For vertical edges (i.e. connecting a cell to its right neighbor): 469 edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [] 470 for i in range(n_rows): 471 for j in range(n_cols - 1): 472 edges.append(((i, j), (i, j + 1), 1)) 473 # For horizontal edges (i.e. 
connecting a cell to its bottom neighbor): 474 for i in range(n_rows - 1): 475 for j in range(n_cols): 476 edges.append(((i, j), (i + 1, j), 0)) 477 478 # Shuffle the list of edges. 479 import random 480 481 random.shuffle(edges) 482 483 # Initialize connection_list with no connections. 484 # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)). 485 # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)). 486 import numpy as np 487 488 connection_list = np.zeros((2, n_rows, n_cols), dtype=bool) 489 490 # Process each edge; if it connects two different trees, union them and carve the passage. 491 for cell1, cell2, direction in edges: 492 if find(cell1) != find(cell2): 493 union(cell1, cell2) 494 if direction == 0: 495 # Horizontal edge: connection is stored in connection_list[0] at cell1. 496 connection_list[0, cell1[0], cell1[1]] = True 497 else: 498 # Vertical edge: connection is stored in connection_list[1] at cell1. 499 connection_list[1, cell1[0], cell1[1]] = True 500 501 if start_coord is None: 502 start_coord = tuple(np.random.randint(0, n) for n in grid_shape_) # type: ignore[assignment] 503 504 generation_meta: dict = dict( 505 func_name="gen_kruskal", 506 grid_shape=grid_shape_, 507 start_coord=start_coord, 508 algorithm="kruskal", 509 fully_connected=True, 510 ) 511 return LatticeMaze( 512 connection_list=connection_list, generation_meta=generation_meta 513 ) 514 515 @staticmethod 516 def gen_recursive_division( 517 grid_shape: "Coord | CoordTup", 518 lattice_dim: int = 2, 519 start_coord: "Coord | None" = None, 520 ) -> "LatticeMaze": 521 """Generate a maze using the recursive division algorithm. 522 523 This function generates a maze by recursively dividing the grid with walls and carving a single 524 passage through each wall. The algorithm begins with a fully connected grid (i.e. 
every pair of adjacent 525 cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. 526 The resulting maze is a perfect maze, meaning there is exactly one path between any two cells. 527 528 # Parameters: 529 - `grid_shape : Coord | CoordTup` 530 The shape of the maze grid (e.g., `(n_rows, n_cols)`). 531 - `lattice_dim : int` 532 The lattice dimension (default is `2`). 533 - `start_coord : Coord | None` 534 Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen. 535 - `**kwargs` 536 Additional keyword arguments (currently unused). 537 538 # Returns: 539 - `LatticeMaze` 540 A maze represented by a connection list, generated using recursive division. 541 542 # Usage: 543 ```python 544 maze = gen_recursive_division((10, 10)) 545 ``` 546 """ 547 assert lattice_dim == 2, ( # noqa: PLR2004 548 "Recursive division algorithm is only implemented for 2D lattices." 549 ) 550 # Convert grid_shape to a tuple of ints. 551 grid_shape_: CoordTup = tuple(int(x) for x in grid_shape) # type: ignore[assignment] 552 n_rows, n_cols = grid_shape_ 553 554 # Initialize connection_list as a fully connected grid. 555 # For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True. 556 # For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True. 557 connection_list = np.zeros((2, n_rows, n_cols), dtype=bool) 558 connection_list[0, : n_rows - 1, :] = True 559 connection_list[1, :, : n_cols - 1] = True 560 561 def divide(x: int, y: int, width: int, height: int) -> None: 562 """Recursively divide the region starting at (x, y) with the given width and height. 563 564 Removes connections along the chosen division line except for one randomly chosen gap. 565 """ 566 if width < 2 or height < 2: # noqa: PLR2004 567 return 568 569 if width > height: 570 # Vertical division. 
571 wall_col = random.randint(x + 1, x + width - 1) 572 gap_row = random.randint(y, y + height - 1) 573 for row in range(y, y + height): 574 if row == gap_row: 575 continue 576 # Remove the vertical connection between (row, wall_col-1) and (row, wall_col). 577 if wall_col - 1 < n_cols - 1: 578 connection_list[1, row, wall_col - 1] = False 579 # Recurse on the left and right subregions. 580 divide(x, y, wall_col - x, height) 581 divide(wall_col, y, x + width - wall_col, height) 582 else: 583 # Horizontal division. 584 wall_row = random.randint(y + 1, y + height - 1) 585 gap_col = random.randint(x, x + width - 1) 586 for col in range(x, x + width): 587 if col == gap_col: 588 continue 589 # Remove the horizontal connection between (wall_row-1, col) and (wall_row, col). 590 if wall_row - 1 < n_rows - 1: 591 connection_list[0, wall_row - 1, col] = False 592 # Recurse on the top and bottom subregions. 593 divide(x, y, width, wall_row - y) 594 divide(x, wall_row, width, y + height - wall_row) 595 596 # Begin the division on the full grid. 597 divide(0, 0, n_cols, n_rows) 598 599 if start_coord is None: 600 start_coord = tuple(np.random.randint(0, n) for n in grid_shape_) # type: ignore[assignment] 601 602 generation_meta: dict = dict( 603 func_name="gen_recursive_division", 604 grid_shape=grid_shape_, 605 start_coord=start_coord, 606 algorithm="recursive_division", 607 fully_connected=True, 608 ) 609 return LatticeMaze( 610 connection_list=connection_list, generation_meta=generation_meta 611 )
namespace for lattice maze generation algorithms
examples of generated mazes can be found here: https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
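All generators in this namespace write passages into a connection list of shape `(2, n_rows, n_cols)`, where index `0` stores the connection from a cell to the cell below it and index `1` the connection to the cell on its right (as described in the comments of `gen_kruskal`). The sketch below restates that indexing rule with nested Python lists instead of NumPy arrays; `new_connection_list` and `connect` are hypothetical helper names, not part of the package.

```python
Cell = tuple[int, int]


def new_connection_list(n_rows: int, n_cols: int) -> list[list[list[bool]]]:
    """All-walls connection list: [0] = downward links, [1] = rightward links."""
    return [[[False] * n_cols for _ in range(n_rows)] for _ in range(2)]


def connect(clist: list[list[list[bool]]], cell_a: Cell, cell_b: Cell) -> None:
    """Carve a passage between two adjacent cells."""
    delta = (cell_b[0] - cell_a[0], cell_b[1] - cell_a[1])
    if abs(delta[0]) + abs(delta[1]) != 1:
        raise ValueError(f"{cell_a} and {cell_b} are not adjacent")
    # row change -> dimension 0 (down), column change -> dimension 1 (right)
    dim = 0 if delta[0] != 0 else 1
    # the link is always stored at the upper/left cell of the pair
    node = cell_a if sum(delta) > 0 else cell_b
    clist[dim][node[0]][node[1]] = True


clist = new_connection_list(2, 2)
connect(clist, (0, 0), (1, 0))  # downward link, stored at (0, 0)
connect(clist, (1, 1), (1, 0))  # leftward move, stored as rightward link at (1, 0)
```

Storing each link once, at the upper or left cell, means the last row of dimension `0` and the last column of dimension `1` never hold a passage.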
    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        *,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
        (default: `2`)
        - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
        (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
        (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway.
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            max_tree_depth = (
                2 * n_total_cells
            )  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
        elif isinstance(max_tree_depth, float):
            assert max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the target coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_connected_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug
                # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is
                # treated as fully connected even when it is most certainly not, causing solving the maze to break
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )
generate a lattice maze using depth first search, iterative
Arguments
- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of total cells (default: `None`)
- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells`, twice the total number of cells in the grid. If a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: `None`)
- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
algorithm
1. Choose the initial cell, mark it as visited and push it to the stack
2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
        1. Push the current cell to the stack
        2. Choose one of the unvisited neighbours
        3. Remove the wall between the current cell and the chosen cell
        4. Mark the chosen cell as visited and push it to the stack
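The steps listed for the algorithm can be sketched with the standard library alone. This is a minimal illustration, not the code of `gen_dfs`: the function name `dfs_maze_sketch`, the `seed` parameter, and the representation of a maze as a set of passages (pairs of adjacent cells) are assumptions made for the example, and the depth limit, `accessible_cells`, and `do_forks` options are left out.

```python
import random

Cell = tuple[int, int]


def dfs_maze_sketch(n_rows: int, n_cols: int, seed: int = 0) -> set[frozenset[Cell]]:
    """Hypothetical helper: carve a maze with iterative depth first search."""
    rng = random.Random(seed)

    # 1. choose the initial cell, mark it as visited and push it to the stack
    start: Cell = (rng.randrange(n_rows), rng.randrange(n_cols))
    visited: set[Cell] = {start}
    stack: list[Cell] = [start]
    passages: set[frozenset[Cell]] = set()

    # 2. while the stack is not empty
    while stack:
        # 2.1 pop a cell and make it the current cell
        current = stack.pop()
        row, col = current
        unvisited = [
            (row + d_row, col + d_col)
            for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1))
            if 0 <= row + d_row < n_rows
            and 0 <= col + d_col < n_cols
            and (row + d_row, col + d_col) not in visited
        ]
        # 2.2 if the current cell has unvisited neighbours
        if unvisited:
            stack.append(current)  # push the current cell back
            chosen = rng.choice(unvisited)  # choose one unvisited neighbour
            passages.add(frozenset((current, chosen)))  # remove the wall
            visited.add(chosen)  # mark the chosen cell as visited
            stack.append(chosen)  # and push it to the stack

    return passages


if __name__ == "__main__":
    maze = dfs_maze_sketch(4, 5, seed=0)
    print(f"carved {len(maze)} passages on a 4x5 grid")
```

Because every cell is reached exactly once and each visit carves one passage, the result is a spanning tree of the grid.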
    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        "(broken!) generate a lattice maze using Prim's algorithm"
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )
(broken!) generate a lattice maze using Prim's algorithm
    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]

                # Check for loop
                loop_exit: int | None = None
                for i, p in enumerate(path):
                    if np.array_equal(next_cell, p):
                        loop_exit = i
                        break

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # this removes everything after and including the loop start
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )
Generate a lattice maze using Wilson's algorithm.
Algorithm
Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
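To make the loop-erased random walk concrete, here is a small standard-library sketch. It is an illustration under stated assumptions rather than the `gen_wilson` implementation: `wilson_maze_sketch`, its `seed` parameter, and the set-of-passages output format are hypothetical names chosen for the example.

```python
import random

Cell = tuple[int, int]


def wilson_maze_sketch(n_rows: int, n_cols: int, seed: int = 0) -> set[frozenset[Cell]]:
    """Hypothetical helper: build a spanning tree with loop-erased random walks."""
    rng = random.Random(seed)
    all_cells: list[Cell] = [(r, c) for r in range(n_rows) for c in range(n_cols)]

    # start the tree from one random cell
    in_tree: set[Cell] = {rng.choice(all_cells)}
    passages: set[frozenset[Cell]] = set()

    while len(in_tree) < len(all_cells):
        # begin a walk from a random cell that is not yet in the tree
        start = rng.choice([cell for cell in all_cells if cell not in in_tree])
        path: list[Cell] = [start]

        # walk until the tree is hit
        while path[-1] not in in_tree:
            row, col = path[-1]
            neighbors = [
                (row + d_row, col + d_col)
                for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1))
                if 0 <= row + d_row < n_rows and 0 <= col + d_col < n_cols
            ]
            next_cell = rng.choice(neighbors)
            if next_cell in path:
                # the walk closed a loop: erase everything after the first visit
                path = path[: path.index(next_cell) + 1]
            else:
                path.append(next_cell)

        # attach the loop-free path to the tree
        for cell_a, cell_b in zip(path, path[1:]):
            passages.add(frozenset((cell_a, cell_b)))
            in_tree.add(cell_a)

    return passages


if __name__ == "__main__":
    maze = wilson_maze_sketch(3, 4, seed=1)
    print(f"carved {len(maze)} passages on a 3x4 grid")
```

Each walk adds as many passages as new cells, so the finished maze has one passage fewer than it has cells and contains no cycles.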
    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `p: float`: the probability of a cell being accessible (default: `0.5`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output
generate a lattice maze using simple percolation
note that p in the range (0.4, 0.7) gives the most interesting mazes
Arguments
- `grid_shape: Coord`: the shape of the grid
- `lattice_dim: int`: the dimension of the lattice (default: `2`)
- `p: float`: the probability of a cell being accessible (default: `0.4`, per the function signature)
- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
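A rough standard-library sketch of the same idea follows: every possible connection between adjacent cells is opened independently with probability `p`, and the cells reachable from the start are then collected. The names `percolation_sketch` and `connected_component`, the `seed` parameter, and the set-of-passages format are assumptions for illustration and are not part of the package.

```python
import random
from collections import deque

Cell = tuple[int, int]


def percolation_sketch(
    n_rows: int, n_cols: int, p: float = 0.4, seed: int = 0
) -> set[frozenset[Cell]]:
    """Hypothetical helper: open each lattice connection with probability p."""
    assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
    rng = random.Random(seed)
    passages: set[frozenset[Cell]] = set()
    for row in range(n_rows):
        for col in range(n_cols):
            # connection to the cell below, if it is inside the grid
            if row + 1 < n_rows and rng.random() < p:
                passages.add(frozenset(((row, col), (row + 1, col))))
            # connection to the cell on the right, if it is inside the grid
            if col + 1 < n_cols and rng.random() < p:
                passages.add(frozenset(((row, col), (row, col + 1))))
    return passages


def connected_component(passages: set[frozenset[Cell]], start: Cell) -> set[Cell]:
    """Hypothetical helper: breadth first search over the open passages."""
    adjacency: dict[Cell, list[Cell]] = {}
    for passage in passages:
        cell_a, cell_b = tuple(passage)
        adjacency.setdefault(cell_a, []).append(cell_b)
        adjacency.setdefault(cell_b, []).append(cell_a)

    seen: set[Cell] = {start}
    queue: deque[Cell] = deque([start])
    while queue:
        cell = queue.popleft()
        for neighbor in adjacency.get(cell, []):
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append(neighbor)
    return seen


if __name__ == "__main__":
    maze = percolation_sketch(5, 5, p=0.5, seed=3)
    reachable = connected_component(maze, (0, 0))
    print(f"{len(reachable)} of 25 cells are reachable from (0, 0)")
```

Unlike the tree-based generators, a percolation maze may contain cycles and unreachable cells, which is why the connected component of the start coordinate is recorded.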
    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)"""
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze
dfs and then percolation (adds cycles)
    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
        - `**kwargs`
            Additional keyword arguments (currently unused).

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize union-find data structure.
        parent: dict[tuple[int, int], tuple[int, int]] = {}

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
            root1 = find(cell1)
            root2 = find(cell2)
            parent[root2] = root1

        # Initialize each cell as its own set.
        for i in range(n_rows):
            for j in range(n_cols):
                parent[(i, j)] = (i, j)

        # List all possible edges.
        # For vertical edges (i.e. connecting a cell to its right neighbor):
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
        for i in range(n_rows):
            for j in range(n_cols - 1):
                edges.append(((i, j), (i, j + 1), 1))
        # For horizontal edges (i.e. connecting a cell to its bottom neighbor):
        for i in range(n_rows - 1):
            for j in range(n_cols):
                edges.append(((i, j), (i + 1, j), 0))

        # Shuffle the list of edges.
        import random

        random.shuffle(edges)

        # Initialize connection_list with no connections.
        # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
        # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
        import numpy as np

        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it connects two different trees, union them and carve the passage.
        for cell1, cell2, direction in edges:
            if find(cell1) != find(cell2):
                union(cell1, cell2)
                if direction == 0:
                    # Horizontal edge: connection is stored in connection_list[0] at cell1.
                    connection_list[0, cell1[0], cell1[1]] = True
                else:
                    # Vertical edge: connection is stored in connection_list[1] at cell1.
                    connection_list[1, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
Generate a maze using Kruskal's algorithm.
This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.
https://en.wikipedia.org/wiki/Kruskal's_algorithm
Parameters:
- `grid_shape : Coord | CoordTup`
  The shape of the maze grid (for example, `(n_rows, n_cols)`).
- `lattice_dim : int`
  The lattice dimension (default is `2`).
- `start_coord : Coord | None`
  Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
- `**kwargs`
  Additional keyword arguments (currently unused).
Returns:
LatticeMaze
A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
Usage:
maze = LatticeMazeGenerators.gen_kruskal((10, 10))
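The union-find bookkeeping behind the randomized Kruskal approach can be sketched without the package. This is a simplified illustration, not the library code: `kruskal_maze_sketch`, its `seed` parameter, and the list-of-passages return value are assumptions made for the example.

```python
import random

Cell = tuple[int, int]


def kruskal_maze_sketch(n_rows: int, n_cols: int, seed: int = 0) -> list[tuple[Cell, Cell]]:
    """Hypothetical helper: random spanning tree via shuffled edges and union-find."""
    rng = random.Random(seed)

    # every cell starts as the root of its own set
    parent: dict[Cell, Cell] = {
        (r, c): (r, c) for r in range(n_rows) for c in range(n_cols)
    }

    def find(cell: Cell) -> Cell:
        # follow parent links to the root, halving the path on the way
        while parent[cell] != cell:
            parent[cell] = parent[parent[cell]]
            cell = parent[cell]
        return cell

    # list every pair of adjacent cells once (right and down neighbours)
    edges: list[tuple[Cell, Cell]] = []
    for r in range(n_rows):
        for c in range(n_cols):
            if c + 1 < n_cols:
                edges.append(((r, c), (r, c + 1)))
            if r + 1 < n_rows:
                edges.append(((r, c), (r + 1, c)))
    rng.shuffle(edges)

    passages: list[tuple[Cell, Cell]] = []
    for cell_a, cell_b in edges:
        root_a, root_b = find(cell_a), find(cell_b)
        if root_a != root_b:
            # the cells are in different trees: merge them and carve the passage
            parent[root_b] = root_a
            passages.append((cell_a, cell_b))
    return passages


if __name__ == "__main__":
    maze = kruskal_maze_sketch(4, 4, seed=2)
    print(f"carved {len(maze)} passages on a 4x4 grid")
```

An edge is kept only when it joins two different sets, so no cycle can form and exactly one passage fewer than the number of cells is carved.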
    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line, leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
        - `**kwargs`
            Additional keyword arguments (currently unused).

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize connection_list as a fully connected grid.
        # For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
        # For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region starting at (x, y) with the given width and height.

            Removes connections along the chosen division line except for one randomly chosen gap.
            """
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # Vertical division.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row == gap_row:
                        continue
                    # Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
                    if wall_col - 1 < n_cols - 1:
                        connection_list[1, row, wall_col - 1] = False
                # Recurse on the left and right subregions.
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # Horizontal division.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col == gap_col:
                        continue
                    # Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
                    if wall_row - 1 < n_rows - 1:
                        connection_list[0, wall_row - 1, col] = False
                # Recurse on the top and bottom subregions.
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        # Begin the division on the full grid.
        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
Generate a maze using the recursive division algorithm.
This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
Parameters:
- `grid_shape : Coord | CoordTup`
  The shape of the maze grid (e.g., `(n_rows, n_cols)`).
- `lattice_dim : int`
  The lattice dimension (default is `2`).
- `start_coord : Coord | None`
  Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
- `**kwargs`
  Additional keyword arguments (currently unused).
Returns:
LatticeMaze
A maze represented by a connection list, generated using recursive division.
Usage:
maze = LatticeMazeGenerators.gen_recursive_division((10, 10))
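For intuition, recursive division can be sketched on a plain set of passages: start with every adjacent pair connected, then cut each region along a line while keeping one gap. This is an illustrative sketch only; `recursive_division_sketch`, its `seed` parameter, and the `(top, left, height, width)` argument order are assumptions and differ from the library's `divide(x, y, width, height)`.

```python
import random

Cell = tuple[int, int]


def recursive_division_sketch(n_rows: int, n_cols: int, seed: int = 0) -> set[frozenset[Cell]]:
    """Hypothetical helper: perfect maze via recursive division of an open grid."""
    rng = random.Random(seed)

    # begin with a fully connected grid
    passages: set[frozenset[Cell]] = set()
    for r in range(n_rows):
        for c in range(n_cols):
            if r + 1 < n_rows:
                passages.add(frozenset(((r, c), (r + 1, c))))
            if c + 1 < n_cols:
                passages.add(frozenset(((r, c), (r, c + 1))))

    def divide(top: int, left: int, height: int, width: int) -> None:
        # a region one cell thick is already a corridor: nothing to split
        if height < 2 or width < 2:
            return
        if width > height:
            # wall between columns wall_col - 1 and wall_col, with one open row
            wall_col = rng.randint(left + 1, left + width - 1)
            gap_row = rng.randint(top, top + height - 1)
            for row in range(top, top + height):
                if row != gap_row:
                    passages.discard(frozenset(((row, wall_col - 1), (row, wall_col))))
            divide(top, left, height, wall_col - left)
            divide(top, wall_col, height, left + width - wall_col)
        else:
            # wall between rows wall_row - 1 and wall_row, with one open column
            wall_row = rng.randint(top + 1, top + height - 1)
            gap_col = rng.randint(left, left + width - 1)
            for col in range(left, left + width):
                if col != gap_col:
                    passages.discard(frozenset(((wall_row - 1, col), (wall_row, col))))
            divide(top, left, wall_row - top, width)
            divide(wall_row, left, top + height - wall_row, width)

    divide(0, 0, n_rows, n_cols)
    return passages


if __name__ == "__main__":
    maze = recursive_division_sketch(5, 6, seed=4)
    print(f"{len(maze)} passages remain on a 5x6 grid")
```

Each wall leaves exactly one gap between the two subregions it separates, so the two halves stay joined by a single passage and the final maze is a spanning tree.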