maze_dataset.maze
LatticeMaze and the classes like SolvedMaze that inherit from it, along with a variety of helper functions.
This package utilizes a simple, efficient representation of mazes. Using an adjacency list to represent mazes would lead to a poor lookup time of whether any given connection exists, whilst using a dense adjacency matrix would waste memory by failing to exploit the structure (e.g., only 4 of the diagonals would be filled in).
Instead, we describe mazes with the following simple representation: for a $d$-dimensional lattice with $r$ rows and $c$ columns, we initialize a boolean array $A \in \{0, 1\}^{d \times r \times c}$, which we refer to in the code as a connection_list. The value at $A[0,i,j]$ determines whether a downward connection exists from node $[i,j]$ to $[i+1, j]$. Likewise, the value at $A[1,i,j]$ determines whether a rightward connection to $[i, j+1]$ exists. Thus, we avoid duplicating data about the existence of connections, at the cost of requiring additional care with indexing when looking for a connection upward or to the left: an upward connection from $[i,j]$ is stored at $A[0,i-1,j]$, and a leftward connection at $A[1,i,j-1]$. Note that this setup allows for a periodic lattice.
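The indexing rule above can be sketched as a small lookup helper. This is a minimal illustration, not part of the package: `is_connected` is a hypothetical name, and the sketch assumes a non-periodic 2D lattice stored as a NumPy boolean array.

```python
import numpy as np

# connection_list[0, i, j]: connection from (i, j) down to (i + 1, j)
# connection_list[1, i, j]: connection from (i, j) right to (i, j + 1)
connection_list = np.array(
    [
        [[False, True], [False, False]],  # down
        [[True, False], [True, False]],  # right
    ],
    dtype=bool,
)


def is_connected(conn, a, b):
    """Hypothetical helper: True if lattice neighbors `a` and `b` are connected."""
    (ai, aj), (bi, bj) = a, b
    if abs(ai - bi) + abs(aj - bj) != 1:
        return False  # not adjacent on the lattice
    # each connection is stored once, on the upper or left node of the pair
    i, j = min(ai, bi), min(aj, bj)
    direction = 0 if ai != bi else 1  # 0 = vertical (down), 1 = horizontal (right)
    return bool(conn[direction, i, j])


up_check = is_connected(connection_list, (1, 1), (0, 1))  # upward lookup
left_check = is_connected(connection_list, (0, 1), (0, 0))  # leftward lookup
blocked = is_connected(connection_list, (0, 0), (1, 0))  # wall between these nodes
```

Taking the minimum of the two row and column indices is what handles the upward and leftward cases: the lookup always lands on the node that owns the stored connection.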
To produce solutions to mazes, two points are selected uniformly at random without replacement from the connected component of the maze, and the $A^*$ algorithm is applied to find the shortest path between them.
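The solving step can be sketched as follows. This is a standalone illustration under stated assumptions, not the package's own implementation: `neighbors`, `manhattan`, and `a_star` are hypothetical names, the maze is stored as nested lists in the down/right layout described above, and the connected component is written out by hand.

```python
import heapq
import random


def neighbors(conn, node):
    """Yield the nodes connected to `node` (non-periodic lattice assumed)."""
    i, j = node
    rows, cols = len(conn[0]), len(conn[0][0])
    if i + 1 < rows and conn[0][i][j]:
        yield (i + 1, j)  # down
    if i - 1 >= 0 and conn[0][i - 1][j]:
        yield (i - 1, j)  # up: stored on the node above
    if j + 1 < cols and conn[1][i][j]:
        yield (i, j + 1)  # right
    if j - 1 >= 0 and conn[1][i][j - 1]:
        yield (i, j - 1)  # left: stored on the node to the left


def manhattan(a, b):
    """Manhattan distance, an admissible heuristic on a 4-connected lattice."""
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def a_star(conn, start, goal):
    """Return a shortest path from `start` to `goal` as a list of nodes, or None."""
    open_heap = [(manhattan(start, goal), 0, start)]
    came_from = {start: None}
    g_score = {start: 0}
    while open_heap:
        _, g, current = heapq.heappop(open_heap)
        if current == goal:
            path = []
            while current is not None:
                path.append(current)
                current = came_from[current]
            return path[::-1]
        if g > g_score[current]:
            continue  # stale heap entry, a better route was already found
        for nxt in neighbors(conn, current):
            new_g = g + 1
            if new_g < g_score.get(nxt, float("inf")):
                g_score[nxt] = new_g
                came_from[nxt] = current
                heapq.heappush(open_heap, (new_g + manhattan(nxt, goal), new_g, nxt))
    return None


conn = [
    [[False, True], [False, False]],  # down
    [[True, False], [True, False]],  # right
]
# every node of this small maze lies in one connected component
component = [(0, 0), (0, 1), (1, 0), (1, 1)]
# two distinct endpoints, sampled uniformly without replacement
start, goal = random.Random(0).sample(component, 2)
path = a_star(conn, start, goal)
fixed_path = a_star(conn, (0, 0), (1, 0))
```

Sampling both endpoints from the same connected component guarantees that a path exists, so the search never returns `None` in this setting.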
Parallelization is implemented via the multiprocessing module in the Python standard library, and parallel generation can be controlled via keyword arguments to the MazeDataset.from_config() function.
r"""`LatticeMaze` and the classes like `SolvedMaze` that inherit from it, along with a variety of helper functions

This package utilizes a simple, efficient representation of mazes. Using an adjacency list to represent mazes would lead to a poor lookup time of whether any given connection exists, whilst using a dense adjacency matrix would waste memory by failing to exploit the structure (e.g., only 4 of the diagonals would be filled in).
Instead, we describe mazes with the following simple representation: for a $d$-dimensional lattice with $r$ rows and $c$ columns, we initialize a boolean array $A \in \{0, 1\}^{d \times r \times c}$, which we refer to in the code as a `connection_list`. The value at $A[0,i,j]$ determines whether a downward connection exists from node $[i,j]$ to $[i+1, j]$. Likewise, the value at $A[1,i,j]$ determines whether a rightwards connection to $[i, j+1]$ exists. Thus, we avoid duplication of data about the existence of connections, at the cost of requiring additional care with indexing when looking for a connection upwards or to the left. Note that this setup allows for a periodic lattice.

To produce solutions to mazes, two points are selected uniformly at random without replacement from the connected component of the maze, and the $A^*$ algorithm is applied to find the shortest path between them.

Parallelization is implemented via the `multiprocessing` module in the Python standard library, and parallel generation can be controlled via keyword arguments to the `MazeDataset.from_config()` function.
"""

from maze_dataset.maze.lattice_maze import (
    AsciiChars,
    ConnectionList,
    Coord,
    CoordArray,
    LatticeMaze,
    PixelColors,
    SolvedMaze,
    TargetedLatticeMaze,
)

__all__ = [
    # submodules
    "lattice_maze",
    # imports
    "SolvedMaze",
    "TargetedLatticeMaze",
    "LatticeMaze",
    "ConnectionList",
    "AsciiChars",
    "Coord",
    "CoordArray",
    "PixelColors",
]
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution"""

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        > DOCS: better documentation for this init method
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
                solution_valid = True

        if not solution_valid and not allow_invalid:
            err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]" [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze"""
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
Stores a maze and a solution
    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        > DOCS: better documentation for this init method
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
                solution_valid = True

        if not solution_valid and not allow_invalid:
            err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]" [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?
Create a SolvedMaze from a connection list and a solution
DOCS: better documentation for this init method
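The constructor takes the start and end positions from the first and last rows of the solution. A minimal sketch of that idea follows; `endpoints_from_solution` is a hypothetical helper written for illustration, and its extra `ndim` check is an assumption added here, not something the constructor does.

```python
import numpy as np


def endpoints_from_solution(solution):
    """Hypothetical helper: validate a path's shape and return (start, end)."""
    arr = np.array(solution)
    # a valid path is a non-empty (n, 2) array; n == 1 is allowed,
    # since the start and end positions may coincide
    if arr.ndim != 2 or arr.shape[0] == 0 or arr.shape[1] != 2:
        raise ValueError(f"invalid solution: {arr.shape = }")
    return arr[0], arr[-1]


start, end = endpoints_from_solution([(0, 0), (0, 1), (1, 1)])
single_start, single_end = endpoints_from_solution([(1, 1)])
```

In the constructor itself, explicitly passed `start_pos` and `end_pos` values are only compared against these derived endpoints, and only when `allow_invalid` is false.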
    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()
(deprecated!) return the solution as a list of tokens
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)
(deprecated!) return the maze without the solution
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )
get a SolvedMaze from a LatticeMaze by specifying a solution
    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze"""
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )
solves the given targeted lattice maze and returns a SolvedMaze
    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)
coordinates and their indices from the solution where a fork is present
- if the start point is not a dead end, this counts as a fork
- if the end point is not a dead end, this counts as a fork
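The fork rule can be sketched without the class machinery. The following is a standalone illustration: `degree` and `forking_points` are hypothetical names, and the maze is assumed to be a non-periodic lattice stored as nested lists in the down/right layout.

```python
def degree(conn, node):
    """Number of open connections at `node` (non-periodic lattice assumed)."""
    i, j = node
    rows, cols = len(conn[0]), len(conn[0][0])
    count = 0
    if i + 1 < rows and conn[0][i][j]:
        count += 1  # down
    if i - 1 >= 0 and conn[0][i - 1][j]:
        count += 1  # up
    if j + 1 < cols and conn[1][i][j]:
        count += 1  # right
    if j - 1 >= 0 and conn[1][i][j - 1]:
        count += 1  # left
    return count


def forking_points(conn, solution, always_include_endpoints=False):
    """Return (indices, coords) of the path nodes where a fork is present."""
    idxs, coords = [], []
    last = len(solution) - 1
    for idx, coord in enumerate(solution):
        is_endpoint = idx in (0, last)
        # an endpoint has one path neighbor, an interior node has two;
        # any connection beyond that is a choice, i.e. a fork
        threshold = 1 if is_endpoint else 2
        if degree(conn, coord) > threshold or (is_endpoint and always_include_endpoints):
            idxs.append(idx)
            coords.append(coord)
    return idxs, coords


# 2 x 3 maze: a corridor along the top row with a side branch below (0, 1)
conn = [
    [[False, True, False], [False, False, False]],  # down
    [[True, True, False], [False, False, False]],  # right
]
solution = [(0, 0), (0, 1), (0, 2)]
fork_idxs, fork_coords = forking_points(conn, solution)
all_idxs, _ = forking_points(conn, solution, always_include_endpoints=True)
```

Here only the middle node has a third connection, so it is the single fork; both endpoints are dead ends and appear only when `always_include_endpoints` is set.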
    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
coordinates from the solution where there is only a single (non-backtracking) point to move to
returns the complement of get_solution_forking_points from the path
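The complement is taken with `np.delete`, which drops the fork indices from both the index range and the solution array. A minimal sketch, where the solution and the fork indices are made-up values for illustration:

```python
import numpy as np

solution = np.array([[0, 0], [0, 1], [0, 2]])
fork_idxs = [1]  # assumed result of a forking-point search on this path

# indices and coordinates of every path node that is not a fork
follow_idxs = np.delete(np.arange(solution.shape[0]), fork_idxs, axis=0)
follow_coords = np.delete(solution, fork_idxs, axis=0)
```

Note that `np.delete` returns a NumPy array, so the first element is an array rather than the `list[int]` given in the annotation; that mismatch may be what the `type: ignore[return-value]` in the source is suppressing.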
    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
        }
        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # need it to be our special SerializableField
            if not isinstance(field, SerializableField):
                raise NotSerializableFieldException(
                    f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                    f"but a {type(field)} "
                    "this state should be inaccessible, please report this bug!"
                )

            # try to save it
            if field.serialize:
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

        # store each property if we can get it
        for prop in self._properties_to_serialize:
            if hasattr(cls, prop):
                value = getattr(self, prop)
                result[prop] = value
            else:
                raise AttributeError(
                    f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                    + f"but it is in {self._properties_to_serialize = }"
                    + f"\n{self = }"
                )

        return result
returns the class as a dict, implemented by using the @serializable_dataclass decorator
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the @serializable_dataclass decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position"""

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        "post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        assert self.start_pos is not None
        assert self.end_pos is not None
        # check that start and end are in bounds
        if (
            self.start_pos[0] >= self.grid_shape[0]
            or self.start_pos[1] >= self.grid_shape[1]
        ):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if (
            self.end_pos[0] >= self.grid_shape[0]
            or self.end_pos[1] >= self.grid_shape[1]
        ):
            err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
A LatticeMaze with a start and end position
    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()
(deprecated!) return the start position as a list of tokens
    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()
(deprecated!) return the end position as a list of tokens
    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
get a TargetedLatticeMaze from a LatticeMaze by specifying start and end positions
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 
741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, 
expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by the `@serializable_dataclass` decorator
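The field-by-field loading described here can be illustrated with the standard library alone. The following is a minimal sketch, not the `muutils` implementation: `load_dataclass`, the `"deserialize_fn"` metadata key, and the `Point`/`Segment` classes are hypothetical names chosen for illustration, and type validation is left out.

```python
import dataclasses
from typing import Any, Mapping


def load_dataclass(cls: type, data: Mapping[str, Any]) -> Any:
    """Build `cls` from a mapping, one field at a time (simplified stand-in)."""
    if isinstance(data, cls):
        # already an instance: return it unchanged, which keeps recursive loading simple
        return data
    ctor_kwargs: dict = {}
    for field in dataclasses.fields(cls):
        # skip fields absent from the data or not accepted by the constructor
        if field.name not in data or not field.init:
            continue
        value = data[field.name]
        # hypothetical convention: a per-field converter stored in `metadata`
        deserialize_fn = field.metadata.get("deserialize_fn")
        if deserialize_fn is not None:
            value = deserialize_fn(value)
        elif dataclasses.is_dataclass(field.type) and isinstance(value, dict):
            # nested dataclass given as a dict: load it recursively
            value = load_dataclass(field.type, value)
        ctor_kwargs[field.name] = value
    return cls(**ctor_kwargs)


@dataclasses.dataclass
class Point:
    x: int
    y: int


@dataclasses.dataclass
class Segment:
    start: Point
    end: Point
    label: str = dataclasses.field(default="", metadata={"deserialize_fn": str.upper})


seg = load_dataclass(
    Segment,
    {"start": {"x": 0, "y": 0}, "end": {"x": 1, "y": 0}, "label": "ab"},
)
```

Passing an existing instance back in returns the same object, mirroring the early return at the top of `load`.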
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
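As a rough illustration of the pattern (a per-field dict of results, reduced with `all`), here is a standard-library sketch. It is an assumption-laden simplification rather than the `muutils` code: the function names and the `Config` class are hypothetical, and only plain class hints are checked, while generic hints such as `list[int]` are treated as valid.

```python
import dataclasses
import typing


def validate_fields_types_dict(obj: typing.Any) -> dict:
    """Map each field name to whether its value matches the field's type hint."""
    hints = typing.get_type_hints(type(obj))
    results: dict = {}
    for field in dataclasses.fields(obj):
        hint = hints[field.name]
        if typing.get_origin(hint) is None and isinstance(hint, type):
            # plain class hint: a direct isinstance check is enough
            results[field.name] = isinstance(getattr(obj, field.name), hint)
        else:
            # generic or special-form hint: not checked in this sketch
            results[field.name] = True
    return results


def validate_fields_types(obj: typing.Any) -> bool:
    """True only if every field passes its individual check."""
    return all(validate_fields_types_dict(obj).values())


@dataclasses.dataclass
class Config:
    name: str
    size: int


good = Config(name="a", size=3)
bad = Config(name="a", size="3")  # dataclasses do not enforce hints at construction
```

Keeping the per-field dict separate from the boolean reduction is what lets a caller such as `load` report exactly which fields failed.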
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
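Several of the inherited `LatticeMaze` members listed above (`nodes_connected`, `get_coord_neighbors`, `find_shortest_path`) rely on the `connection_list` indexing convention: index 0 holds downward connections and index 1 holds rightward ones. The sketch below restates that lookup with nested Python lists instead of a NumPy array, purely for illustration; the free functions and the `CONNECTION_LIST` constant are hypothetical and not part of the package API.

```python
# connection_list[0][i][j]: node (i, j) is connected downward to (i + 1, j)
# connection_list[1][i][j]: node (i, j) is connected rightward to (i, j + 1)
CONNECTION_LIST = [
    [[False, True], [False, False]],  # down
    [[True, False], [True, False]],  # right
]


def nodes_connected(conn, a, b) -> bool:
    """Return whether adjacent nodes `a` and `b` share a connection."""
    delta = (b[0] - a[0], b[1] - a[1])
    if abs(delta[0]) + abs(delta[1]) != 1:
        # not lattice neighbors at all
        return False
    dim = 0 if delta[0] != 0 else 1
    # the connection is stored on the node with the smaller index along `dim`,
    # so looking up or left means reading the entry of the other node
    node = a if sum(delta) > 0 else b
    return conn[dim][node[0]][node[1]]


def coord_neighbors(conn, c):
    """List the in-bounds neighbors of `c` that are actually connected to it."""
    rows, cols = len(conn[0]), len(conn[0][0])
    candidates = [(c[0] - 1, c[1]), (c[0] + 1, c[1]), (c[0], c[1] - 1), (c[0], c[1] + 1)]
    return [
        n
        for n in candidates
        if 0 <= n[0] < rows and 0 <= n[1] < cols and nodes_connected(conn, c, n)
    ]
```

Each connection is stored exactly once, which is why the lookup has to pick the correct owning node before indexing.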
@serializable_dataclass(
    frozen=True,
    kw_only=True,
    properties_to_serialize=["lattice_dim", "generation_meta"],
)
class LatticeMaze(SerializableDataclass):
    """lattice maze (nodes on a lattice, connections only to neighboring nodes)

    Connection List represents which nodes (N) are connected in each direction.

    First and second elements represent downward and rightward connections,
    respectively.

    Example:
      Connection list:
        [
          [ # down
            [F T],
            [F F]
          ],
          [ # right
            [T F],
            [T F]
          ]
        ]

      Nodes with connections
        N T N F
        F   T
        N T N F
        F   F

      Graph:
        N - N
            |
        N - N

    Note: the bottom row connections going down, and the
    right-hand connections going right, will always be False.

    """

    connection_list: ConnectionList
    generation_meta: dict | None = serializable_field(default=None, compare=False)

    lattice_dim = property(lambda self: self.connection_list.shape[0])
    grid_shape = property(lambda self: self.connection_list.shape[1:])
    n_connections = property(lambda self: self.connection_list.sum())

    @property
    def grid_n(self) -> int:
        "grid size as int, raises `AssertionError` if not square"
        assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
        return self.grid_shape[0]

    # ============================================================
    # basic methods
    # ============================================================

    def __eq__(self, other: object) -> bool:
        "equality check calls super"
        return super().__eq__(other)

    @staticmethod
    def heuristic(a: CoordTup, b: CoordTup) -> float:
        """return manhattan distance between two points"""
        return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])

    def __hash__(self) -> int:
        """hash the connection list by converting connection list to bytes"""
        return hash(self.connection_list.tobytes())

    def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
        """returns whether two nodes are connected"""
        delta: Coord = b - a
        if np.abs(delta).sum() != 1:
            # return false if not even adjacent
            return False
        else:
            # test for wall
            dim: int = int(np.argmax(np.abs(delta)))
            clist_node: Coord = a if (delta.sum() > 0) else b
            return self.connection_list[dim, clist_node[0], clist_node[1]]

    def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
        """check if a path is valid"""
        # check path is not empty
        if len(path) == 0:
            return empty_is_valid

        # check all coords in bounds of maze
        if not np.all((path >= 0) & (path < self.grid_shape)):
            return False

        # check all nodes connected
        for i in range(len(path) - 1):
            if not self.nodes_connected(path[i], path[i + 1]):
                return False
        return True

    def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
        """Returns an array with the connectivity degree of each coord.

        I.e., how many neighbors each coord has.
        """
        int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
            self.connection_list.astype(np.int8)
        )
        degrees: Int8[np.ndarray, "row col"] = np.sum(
            int_conn,
            axis=0,
        )  # Connections to east and south
        degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
        degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
        return degrees

    def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
        """Returns an array of the neighboring, connected coords of `c`."""
        c = np.array(c)  # type: ignore[assignment]
        neighbors: list[Coord] = [
            neighbor
            for neighbor in (c + NEIGHBORS_MASK)
            if (
                (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
                and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
                and self.nodes_connected(c, neighbor)  # connected
            )
        ]

        output: CoordArray = np.array(neighbors)
        if len(neighbors) > 0:
            assert output.shape == (
                len(neighbors),
                2,
            ), (
                f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
            )
        return output

    def gen_connected_component_from(self, c: Coord) -> CoordArray:
        """return the connected component from a given coordinate"""
        # Stack for DFS
        stack: list[Coord] = [c]

        # Set to store visited nodes
        visited: set[CoordTup] = set()

        while stack:
            current_node: Coord = stack.pop()
            # this is fine since we know current_node is a coord and thus of length 2
            visited.add(tuple(current_node))  # type: ignore[arg-type]

            # Get the neighbors of the current node
            neighbors = self.get_coord_neighbors(current_node)

            # Iterate over neighbors
            for neighbor in neighbors:
                if tuple(neighbor) not in visited:
                    stack.append(neighbor)

        return np.array(list(visited))

    def find_shortest_path(
        self,
        c_start: CoordTup | Coord,
        c_end: CoordTup | Coord,
    ) -> CoordArray:
        """find the shortest path between two coordinates, using A*"""
        c_start = tuple(c_start)  # type: ignore[assignment]
        c_end = tuple(c_end)  # type: ignore[assignment]

        g_score: dict[CoordTup, float] = (
            dict()
        )  # cost of cheapest path to node from start currently known
        f_score: dict[CoordTup, float] = {
            c_start: 0.0,
        }  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

        # init
        g_score[c_start] = 0.0
        f_score[c_start] = self.heuristic(c_start, c_end)

        closed_vtx: set[CoordTup] = set()  # nodes already evaluated
        # nodes to be evaluated
        # we need a set of the tuples, dont place the ints in the set
        open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
        source: dict[CoordTup, CoordTup] = (
            dict()
        )  # node immediately preceding each node in the path (currently known shortest path)

        while open_vtx:
            # get lowest f_score node
            # mypy cant tell that c is of length 2
            c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
            # f_current: float = f_score[c_current]

            # check if goal is reached
            if c_end == c_current:
                path: list[CoordTup] = [c_current]
                p_current: CoordTup = c_current
                while p_current in source:
                    p_current = source[p_current]
                    path.append(p_current)
                # ----------------------------------------------------------------------
                # this is the only return statement
                return np.array(path[::-1])
                # ----------------------------------------------------------------------

            # close current node
            closed_vtx.add(c_current)
            open_vtx.remove(c_current)

            # update g_score of neighbors
            _np_neighbor: Coord
            for _np_neighbor in self.get_coord_neighbors(c_current):
                neighbor: CoordTup = tuple(_np_neighbor)

                if neighbor in closed_vtx:
                    # already checked
                    continue
                g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

                if neighbor not in open_vtx:
                    # found new vtx, so add
                    open_vtx.add(neighbor)

                elif g_temp >= g_score[neighbor]:
                    # if already knew about this one, but current g_score is worse, skip
                    continue

                # store g_score and source
                source[neighbor] = c_current
                g_score[neighbor] = g_temp
                f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

        raise ValueError(
            "A solution could not be found!",
            f"{c_start = }, {c_end = }",
            self.as_ascii(),
        )

    def get_nodes(self) -> CoordArray:
        """return a list of all nodes in the maze"""
        rows: Int[np.ndarray, "x y"]
        cols: Int[np.ndarray, "x y"]
        rows, cols = np.meshgrid(
            range(self.grid_shape[0]),
            range(self.grid_shape[1]),
            indexing="ij",
        )
        nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
        return nodes

    def get_connected_component(self) -> CoordArray:
        """get the largest (and assumed only nonsingular) connected component of the maze

        TODO: other connected components?
        """
        if (self.generation_meta is None) or (
            self.generation_meta.get("fully_connected", False)
        ):
            # for fully connected case, pick any two positions
            return self.get_nodes()
        else:
            # if metadata provided, use visited cells
            visited_cells: set[CoordTup] | None = self.generation_meta.get(
                "visited_cells",
                None,
            )
            if visited_cells is None:
                # TODO: dynamically generate visited_cells?
                err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
                raise ValueError(
                    err_msg,
                )
            visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
            return visited_cells_np

    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[True] = True,
    ) -> CoordArray: ...
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[False] = False,
    ) -> typing.Optional[CoordArray]: ...
    def generate_random_path(  # noqa: C901
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: bool = True,
    ) -> typing.Optional[CoordArray]:
        """return a path between randomly chosen start and end nodes within the connected component

        Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

        # Parameters:
        - `allowed_start : CoordList | None`
            a list of allowed start positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `allowed_end : CoordList | None`
            a list of allowed end positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `deadend_start : bool`
            whether to ***force*** the start position to be a deadend
            (defaults to `False`)
        - `deadend_end : bool`
            whether to ***force*** the end position to be a deadend
            (defaults to `False`)
        - `endpoints_not_equal : bool`
            whether to ensure that the start and end point are not the same
            (defaults to `False`)
        - `except_on_no_valid_endpoint : bool`
            whether to raise an error if no valid start or end positions are found
            if this is `False`, the function might return `None` and this must be handled by the caller
            (defaults to `True`)

        # Returns:
        - `CoordArray`
            a path between the selected start and end positions

        # Raises:
        - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
        """
        # we can't create a "path" in a single-node maze
        assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
            f"can't create path in single-node maze: {self.as_ascii()}"
        )

        # get connected component
        connected_component: CoordArray = self.get_connected_component()

        # initialize start and end positions
        positions: Int[np.int8, "2 2"]

        # if no special conditions on start and end positions
        if (allowed_start, allowed_end, deadend_start, deadend_end) == (
            None,
            None,
            False,
            False,
        ):
            try:
                positions = connected_component[  # type: ignore[assignment]
                    np.random.choice(
                        len(connected_component),
                        size=2,
                        replace=False,
                    )
                ]
            except ValueError as e:
                if except_on_no_valid_endpoint:
                    err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                    raise NoValidEndpointException(
                        err_msg,
                    ) from e
                return None

            return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

        # handle special conditions
        connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
        # copy connected component set
        allowed_start_set: set[CoordTup] = connected_component_set.copy()
        allowed_end_set: set[CoordTup] = connected_component_set.copy()

        # filter by explicitly allowed start and end positions
        # '# type: ignore[assignment]' here because the returned tuple can be of any length
        if allowed_start is not None:
            allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

        if allowed_end is not None:
            allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

        # filter by forcing deadends
        if deadend_start:
            allowed_start_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_start_set,
                ),
            )

        if deadend_end:
            allowed_end_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_end_set,
                ),
            )

        # check we have valid positions
        if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                )
            return None

        # randomly select start and end positions
        try:
            # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
            start_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
            )
            if endpoints_not_equal:
                # remove start position from end positions
                allowed_end_set.discard(start_pos)
            end_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
            )
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(start_pos, end_pos)

    # ============================================================
    # to and from adjacency list
    # ============================================================
    def as_adj_list(
        self,
        shuffle_d0: bool = True,
        shuffle_d1: bool = True,
    ) -> Int8[np.ndarray, "conn start_end coord"]:
        """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
        return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)

    @classmethod
    def from_adj_list(
        cls,
        adj_list: Int8[np.ndarray, "conn start_end coord"],
    ) -> "LatticeMaze":
        """create a LatticeMaze from a list of connections

        > [!NOTE]
        > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
        """
        # this is where it would probably break for rectangular mazes
        grid_n: int = adj_list.max() + 1

        connection_list: ConnectionList = np.zeros(
            (2, grid_n, grid_n),
            dtype=np.bool_,
        )

        for c_start, c_end in adj_list:
            # check that exactly 1 coordinate matches
            if (c_start == c_end).sum() != 1:
                raise ValueError("invalid connection")

            # get the direction
            d: int = (c_start != c_end).argmax()

            x: int
            y: int
            # pick whichever has the lesser value in the direction `d`
            if c_start[d] < c_end[d]:
                x, y = c_start
            else:
                x, y = c_end

            connection_list[d, x, y] = True

        return LatticeMaze(
            connection_list=connection_list,
        )

    def as_adj_list_tokens(self) -> list[str | CoordTup]:
        """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
        warnings.warn(
            "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_adj_list_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ADJLIST_START,
            *chain.from_iterable(  # type: ignore[list-item]
                [
                    [
                        tuple(c_s),
                        SPECIAL_TOKENS.CONNECTOR,
                        tuple(c_e),
                        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                    ]
                    for c_s, c_e in self.as_adj_list()
                ],
            ),
            SPECIAL_TOKENS.ADJLIST_END,
        ]

    def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
        """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples"""
        output: list[CoordTup | str] = self._as_adj_list_tokens()
        # if getattr(self, "start_pos", None) is not None:
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_start_pos_tokens()
        if isinstance(self, TargetedLatticeMaze):
            output += self._get_end_pos_tokens()
        if isinstance(self, SolvedMaze):
            output += self._get_solution_tokens()
        return output

    def _as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode",
    ) -> list[str]:
        # type ignores here fine since we check the instance
        if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
            maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
        if (
            isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
            and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
        ):
            coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
            coords_processed: list[str] = maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
                coords=coords_raw,
                when_noncoord="include",
            )
            return coords_processed
        else:
            err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
            raise NotImplementedError(err_msg)

    def as_tokens(
        self,
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> list[str]:
        """serialize maze and solution to tokens"""
        if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
            return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
        else:
            return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

    @classmethod
    def _from_tokens_AOTP(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """create a LatticeMaze from a list of tokens"""
        # figure out what input format
        # ========================================
        if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
            adj_list_tokens = get_adj_list_tokens(tokens)
        else:
            # If we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
            adj_list_tokens = tokens
            warnings.warn(
                "Assuming input is just adjacency list tokens, no special tokens found",
            )

        # process edges for adjacency list
        # ========================================
        edges: list[list[str]] = list_split(
            adj_list_tokens,
            SPECIAL_TOKENS.ADJACENCY_ENDLINE,
        )

        coordinates: list[tuple[CoordTup, CoordTup]] = list()
        for e in edges:
            # skip last endline
            if len(e) != 0:
                # convert to coords, split start and end
                e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
                    e,
                    when_noncoord="include",
                )
                # this assertion depends on the tokenizer having exactly one token for the connector
                # which is also why we "include" above
                # the connector token is discarded below
                assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
                assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                    f"invalid edge: {e = } {e_coords = }"
                )
                e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
                e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
                coordinates.append((e_coords_first, e_coords_last))

        assert all(len(c) == DIM_2 for c in coordinates), (
            f"invalid coordinates: {coordinates = }"
        )
        adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
        assert tuple(adj_list.shape) == (
            len(coordinates),
            2,
            2,
        ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

        output_maze: LatticeMaze = cls.from_adj_list(adj_list)

        # add start and end positions
        # ========================================
        is_targeted: bool = False
        if all(
            x in tokens
            for x in (
                SPECIAL_TOKENS.ORIGIN_START,
                SPECIAL_TOKENS.ORIGIN_END,
                SPECIAL_TOKENS.TARGET_START,
                SPECIAL_TOKENS.TARGET_END,
            )
        ):
            start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_origin_tokens(tokens),
                when_noncoord="error",
            )
            end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_target_tokens(tokens),
                when_noncoord="error",
            )
            assert len(start_pos_list) == 1, (
                f"invalid start_pos_list: {start_pos_list = }"
            )
            assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

            start_pos: CoordTup = start_pos_list[0]
            end_pos: CoordTup = end_pos_list[0]

            output_maze = TargetedLatticeMaze.from_lattice_maze(
                lattice_maze=output_maze,
                start_pos=start_pos,
                end_pos=end_pos,
            )

            is_targeted = True

        if all(
            x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
        ):
            assert is_targeted, "maze must be targeted to have a solution"
            solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_path_tokens(tokens, trim_end=True),
                when_noncoord="error",
            )
            output_maze = SolvedMaze.from_targeted_lattice_maze(
                # HACK: I think this is fine, but im not sure
                targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
                solution=solution,
            )

        return output_maze

    # TODO: any way to get return type hinting working for this?
    @classmethod
    def from_tokens(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """Constructs a maze from a tokenization.

        Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
        """
        # HACK: type ignores here fine since we check the instance
        if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
            maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
        if (
            isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
            and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
        ):
            err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
            raise NotImplementedError(
                err_msg,
            )

        if isinstance(tokens, str):
            tokens = tokens.split()

        if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
            return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
        else:
            raise NotImplementedError("only AOTP tokenization is supported")

    # ============================================================
    # to and from pixels
    # ============================================================
    def _as_pixels_bw(self) -> BinaryPixelGrid:
        assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
        # Create an empty pixel grid with walls
        pixel_grid: Int[np.ndarray, "x y"] = np.full(
            (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
            False,
            dtype=np.bool_,
        )

        # Set white nodes
        pixel_grid[1::2, 1::2] = True

        # Set white connections (downward)
        for i, row in enumerate(self.connection_list[0]):
            for j, connected in enumerate(row):
                if connected:
                    pixel_grid[i * 2 + 2, j * 2 + 1] = True

        # Set white connections (rightward)
        for i, row in enumerate(self.connection_list[1]):
            for j, connected in enumerate(row):
                if connected:
                    pixel_grid[i * 2 + 1, j * 2 + 2] = True

        return pixel_grid

    def as_pixels(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> PixelGrid:
        """convert the maze to a pixel grid

        - useful as a simpler way of plotting the maze than the more complex `MazePlot`
        - the same underlying representation as `as_ascii` but as an image
        - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
        """
        # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
        # but solution, start_pos, end_pos not always defined
        # but its fine since we explicitly check the type
        if show_solution and not show_endpoints:
            raise ValueError("show_solution=True requires show_endpoints=True")
        # convert original bool pixel grid to RGB
        pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
        pixel_grid: PixelGrid = np.full(
            (*pixel_grid_bw.shape, 3),
            PixelColors.WALL,
            dtype=np.uint8,
        )
        pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

        if self.__class__ == LatticeMaze:
            return pixel_grid

        # set endpoints for TargetedLatticeMaze
        if self.__class__ == TargetedLatticeMaze:
            if show_endpoints:
                pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                    PixelColors.START
                )
                pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                    PixelColors.END
                )
            return pixel_grid

        # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
        if show_solution:
            for coord in self.solution:  # type: ignore[attr-defined]
                pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

            # set pixels between coords
            for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
                next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
                # check they are adjacent using norm
                assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                    f"Coords {coord} and {next_coord} are not adjacent"
                )
                # set pixel between them
                pixel_grid[
                    coord[0] * 2 + 1 + next_coord[0] - coord[0],
                    coord[1] * 2 + 1 + next_coord[1] - coord[1],
                ] = PixelColors.PATH

            # set endpoints (again, since path would overwrite them)
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )

        return pixel_grid

    @classmethod
    def _from_pixel_grid_bw(
        cls,
        pixel_grid: BinaryPixelGrid,
    ) -> tuple[ConnectionList, tuple[int, int]]:
        grid_shape: tuple[int, int] = (
            pixel_grid.shape[0] // 2,
            pixel_grid.shape[1] // 2,
        )
        connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)

        # Extract downward connections
        connection_list[0] = pixel_grid[2::2, 1::2]

        # Extract rightward connections
        connection_list[1] = pixel_grid[1::2, 2::2]

        return connection_list, grid_shape

    @classmethod
    def _from_pixel_grid_with_positions(
        cls,
        pixel_grid: PixelGrid | BinaryPixelGrid,
        marked_positions: dict[str, RGB],
    ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
        # Convert RGB pixel grid to Bool pixel grid
        # error: Incompatible types in assignment (expression has type
        # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]",
        # variable has type "ndarray[Any, Any]") [assignment]
        pixel_grid_bw: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
            pixel_grid == PixelColors.WALL,
            axis=-1,
        )
        connection_list: ConnectionList
        grid_shape: tuple[int, int]
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

        # Find any marked positions
        out_positions: dict[str, CoordArray] = dict()
        for key, color in marked_positions.items():
            pos_temp: Int[np.ndarray, "x y"] = np.argwhere(
                np.all(pixel_grid == color, axis=-1),
            )
            pos_save: list[CoordTup] = list()
            for pos in pos_temp:
                # if it is a coordinate and not connection (transform position, %2==1)
                if pos[0] % 2 == 1 and pos[1] % 2 == 1:
                    pos_save.append((pos[0] // 2, pos[1] // 2))

            out_positions[key] = np.array(pos_save)

        return connection_list, grid_shape, out_positions

    @classmethod
    def from_pixels(
        cls,
        pixel_grid: PixelGrid,
    ) -> "LatticeMaze":
        """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

        # Raises:
        - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
        """
        connection_list: ConnectionList
        grid_shape: tuple[int, int]

        # if a binary pixel grid, return regular LatticeMaze
        if len(pixel_grid.shape) == 2:  # noqa: PLR2004
            connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
            return LatticeMaze(connection_list=connection_list)

        # otherwise, detect and check it's valid
        cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
        if cls not in cls_detected.__mro__:
            err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
            raise ValueError(
                err_msg,
            )

        (
            connection_list,
            grid_shape,
            marked_pos,
        ) = cls._from_pixel_grid_with_positions(
            pixel_grid=pixel_grid,
            marked_positions=dict(
                start=PixelColors.START,
                end=PixelColors.END,
                solution=PixelColors.PATH,
            ),
        )
        # if we wanted a LatticeMaze, return it
        if cls == LatticeMaze:
            return LatticeMaze(connection_list=connection_list)

        # otherwise, keep going
        temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

        # start and end pos
        start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
        assert start_pos_arr.shape == (
            1,
            2,
        ), (
            f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
        )
        assert end_pos_arr.shape == (
            1,
            2,
        ), (
            f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
        )

        start_pos: Coord = start_pos_arr[0]
        end_pos: Coord = end_pos_arr[0]

        # return a TargetedLatticeMaze if that's what we wanted
        if cls == TargetedLatticeMaze:
            return TargetedLatticeMaze(
                connection_list=connection_list,
                start_pos=start_pos,
                end_pos=end_pos,
            )

        # raw solution, only contains path elements and not start or end
        solution_raw: CoordArray = marked_pos["solution"]
        if len(solution_raw.shape) == 2:  # noqa: PLR2004
            assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
                f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
            )
        elif solution_raw.shape == (0,):
            # the solution and end should be immediately adjacent
            assert np.sum(np.abs(start_pos - end_pos)) == 1, (
                f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
            )

        # order the solution, by creating a list from the start to the end
        # add end pos, since we will iterate over all these starting from the start pos
        solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
            tuple(end_pos),
        ]
        # solution starts with start point
        solution: list[CoordTup] = [tuple(start_pos)]
        while solution[-1] != tuple(end_pos):
            # use `get_coord_neighbors` to find connected neighbors
            neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
            # TODO: make this less ugly
            assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
                f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
            )
            # neighbors = neighbors[:, [1, 0]]
            # filter out neighbors that are not in the raw solution
            neighbors_filtered: CoordArray = np.array(
                [
                    coord
                    for coord in neighbors
                    if (
                        tuple(coord) in solution_raw_list
                        and tuple(coord) not in solution
                    )
                ],
            )
            # assert only one element is left, and then add it to the solution
            assert neighbors_filtered.shape == (
                1,
                2,
            ), (
                f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
            )
            solution.append(tuple(neighbors_filtered[0]))

        # assert the solution is complete
        assert solution[0] == tuple(start_pos), (
            f"solution {solution} does not start at start_pos {start_pos}"
        )
        assert solution[-1] == tuple(end_pos), (
            f"solution {solution} does not end at end_pos {end_pos}"
        )

        return cls(
            connection_list=np.array(connection_list),
            solution=np.array(solution),  # type: ignore[call-arg]
        )

    # ============================================================
    # to and from ASCII
    # ============================================================
    def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
        # Get the pixel grid using to_pixels().
        pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw()

        # Replace pixel values with ASCII characters.
1118 ascii_grid: Shaped[np.ndarray, "x y"] = np.full( 1119 pixel_grid.shape, 1120 AsciiChars.WALL, 1121 dtype=str, 1122 ) 1123 ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 1124 1125 return ascii_grid 1126 1127 def as_ascii( 1128 self, 1129 show_endpoints: bool = True, 1130 show_solution: bool = True, 1131 ) -> str: 1132 """return an ASCII grid of the maze 1133 1134 useful for debugging in the terminal, or as it's own format 1135 1136 can be reversed with `LatticeMaze.from_ascii()` 1137 """ 1138 ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() 1139 pixel_grid: PixelGrid = self.as_pixels( 1140 show_endpoints=show_endpoints, 1141 show_solution=show_solution, 1142 ) 1143 1144 chars_replace: tuple = tuple() 1145 if show_endpoints: 1146 chars_replace += (AsciiChars.START, AsciiChars.END) 1147 if show_solution: 1148 chars_replace += (AsciiChars.PATH,) 1149 1150 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1151 if ascii_char in chars_replace: 1152 ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char 1153 1154 return "\n".join("".join(row) for row in ascii_grid) 1155 1156 @classmethod 1157 def from_ascii(cls, ascii_str: str) -> "LatticeMaze": 1158 "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" 1159 lines: list[str] = ascii_str.strip().split("\n") 1160 lines = [line.strip() for line in lines] 1161 ascii_grid: Shaped[np.ndarray, "x y"] = np.array( 1162 [list(line) for line in lines], 1163 dtype=str, 1164 ) 1165 pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) 1166 1167 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1168 pixel_grid[ascii_grid == ascii_char] = pixel_color 1169 1170 return cls.from_pixels(pixel_grid)
lattice maze (nodes on a lattice, connections only to neighboring nodes)
Connection List represents which nodes (N) are connected in each direction.
First and second elements represent downward and rightward connections, respectively.
Example:
  Connection list:
    [
      [ # down
        [F T],
        [F F]
      ],
      [ # right
        [T F],
        [T F]
      ]
    ]
  Nodes with connections
    N T N F
    F   T
    N T N F
    F   F
  Graph:
    N - N
        |
    N - N
Note: the bottom row connections going down, and the right-hand connections going right, will always be False.
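The example above can be checked by hand with a small standalone sketch. It uses plain nested lists in place of the boolean NumPy array, and `list_edges` is a hypothetical helper written for illustration, not part of `maze_dataset`.

```python
# Standalone sketch: nested lists stand in for the boolean NumPy array.
F, T = False, True

# connection_list[0] holds downward connections, connection_list[1] rightward ones
connection_list = [
    [[F, T], [F, F]],  # down
    [[T, F], [T, F]],  # right
]


def list_edges(conn):
    """Return every stored connection as a pair of (row, col) nodes (hypothetical helper)."""
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    edges = []
    for i in range(n_rows):
        for j in range(n_cols):
            if conn[0][i][j]:  # downward: (i, j) -> (i + 1, j)
                edges.append(((i, j), (i + 1, j)))
            if conn[1][i][j]:  # rightward: (i, j) -> (i, j + 1)
                edges.append(((i, j), (i, j + 1)))
    return edges


edges = list_edges(connection_list)
print(edges)
```

The three edges it lists are the three links drawn in the graph: the top row, the right-hand column, and the bottom row.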
@property
def grid_n(self) -> int:
    "grid size as int, raises `AssertionError` if not square"
    assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
    return self.grid_shape[0]
grid size as int, raises AssertionError if not square
@staticmethod
def heuristic(a: CoordTup, b: CoordTup) -> float:
    """return manhattan distance between two points"""
    return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
return manhattan distance between two points
def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """returns whether two nodes are connected"""
    delta: Coord = b - a
    if np.abs(delta).sum() != 1:
        # return false if not even adjacent
        return False
    else:
        # test for wall
        dim: int = int(np.argmax(np.abs(delta)))
        clist_node: Coord = a if (delta.sum() > 0) else b
        return self.connection_list[dim, clist_node[0], clist_node[1]]
returns whether two nodes are connected
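The indexing care needed for upward and leftward lookups is easiest to see in isolation. The following is a minimal pure-Python sketch of the same lookup rule, assuming nested lists instead of NumPy arrays; the function and variable names are illustrative, not the library's API.

```python
F, T = False, True
conn = [
    [[F, T], [F, F]],  # conn[0]: downward connections
    [[T, F], [T, F]],  # conn[1]: rightward connections
]


def nodes_connected_sketch(conn, a, b):
    """Illustrative re-implementation of the lookup for (row, col) tuples."""
    delta = (b[0] - a[0], b[1] - a[1])
    if abs(delta[0]) + abs(delta[1]) != 1:
        return False  # not adjacent on the lattice
    dim = 0 if delta[0] != 0 else 1  # 0: vertical step, 1: horizontal step
    # each connection is stored only once, at the upper or left node of the pair
    node = a if (delta[0] + delta[1]) > 0 else b
    return conn[dim][node[0]][node[1]]


print(nodes_connected_sketch(conn, (1, 1), (0, 1)))  # upward lookup reads conn[0][0][1]
print(nodes_connected_sketch(conn, (0, 0), (1, 0)))  # downward lookup reads conn[0][0][0]
```

Because a connection is stored only at the upper or left node, a step up or to the left reads the entry of `b` and not `a`.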
def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """check if a path is valid"""
    # check path is not empty
    if len(path) == 0:
        return empty_is_valid

    # check all coords in bounds of maze
    if not np.all((path >= 0) & (path < self.grid_shape)):
        return False

    # check all nodes connected
    for i in range(len(path) - 1):
        if not self.nodes_connected(path[i], path[i + 1]):
            return False
    return True
check if a path is valid
def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the connectivity degree of each coord.

    I.e., how many neighbors each coord has.
    """
    int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
        self.connection_list.astype(np.int8)
    )
    degrees: Int8[np.ndarray, "row col"] = np.sum(
        int_conn,
        axis=0,
    )  # Connections to east and south
    degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
    degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
    return degrees
Returns an array with the connectivity degree of each coord.
I.e., how many neighbors each coord has.
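A loop-based sketch may make the degree computation easier to follow than the vectorized version. It assumes nested lists instead of NumPy arrays, and that the bottom row has no downward and the right-hand column no rightward connections; `coord_degrees_sketch` is a hypothetical name.

```python
F, T = False, True
conn = [
    [[F, T], [F, F]],  # downward connections
    [[T, F], [T, F]],  # rightward connections
]


def coord_degrees_sketch(conn):
    """Count connected neighbors per node; every stored edge counts for both of its ends."""
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    degrees = [[0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        for j in range(n_cols):
            if conn[0][i][j]:  # edge between (i, j) and (i + 1, j)
                degrees[i][j] += 1
                degrees[i + 1][j] += 1
            if conn[1][i][j]:  # edge between (i, j) and (i, j + 1)
                degrees[i][j] += 1
                degrees[i][j + 1] += 1
    return degrees


degrees = coord_degrees_sketch(conn)
print(degrees)
```

Nodes of degree 1 are the dead ends; these are the positions that `deadend_start` and `deadend_end` select in `generate_random_path`.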
def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """Returns an array of the neighboring, connected coords of `c`."""
    c = np.array(c)  # type: ignore[assignment]
    neighbors: list[Coord] = [
        neighbor
        for neighbor in (c + NEIGHBORS_MASK)
        if (
            (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
            and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
            and self.nodes_connected(c, neighbor)  # connected
        )
    ]

    output: CoordArray = np.array(neighbors)
    if len(neighbors) > 0:
        assert output.shape == (
            len(neighbors),
            2,
        ), (
            f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
        )
    return output
Returns an array of the neighboring, connected coords of c.
def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """return the connected component from a given coordinate"""
    # Stack for DFS
    stack: list[Coord] = [c]

    # Set to store visited nodes
    visited: set[CoordTup] = set()

    while stack:
        current_node: Coord = stack.pop()
        # this is fine since we know current_node is a coord and thus of length 2
        visited.add(tuple(current_node))  # type: ignore[arg-type]

        # Get the neighbors of the current node
        neighbors = self.get_coord_neighbors(current_node)

        # Iterate over neighbors
        for neighbor in neighbors:
            if tuple(neighbor) not in visited:
                stack.append(neighbor)

    return np.array(list(visited))
return the connected component from a given coordinate
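The depth-first traversal can be sketched without NumPy as follows. The `neighbors` helper and the nested-list maze are assumptions made for this example and only mirror the behavior described here.

```python
F, T = False, True
# 2x2 maze with a single connection, (0, 0) - (0, 1); the bottom row is isolated
conn = [
    [[F, F], [F, F]],  # downward connections
    [[T, F], [F, F]],  # rightward connections
]


def neighbors(conn, c):
    """Connected neighbors of node c (hypothetical helper)."""
    i, j = c
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    out = []
    if i > 0 and conn[0][i - 1][j]:  # up: stored at the upper node
        out.append((i - 1, j))
    if i + 1 < n_rows and conn[0][i][j]:  # down
        out.append((i + 1, j))
    if j > 0 and conn[1][i][j - 1]:  # left: stored at the left node
        out.append((i, j - 1))
    if j + 1 < n_cols and conn[1][i][j]:  # right
        out.append((i, j + 1))
    return out


def connected_component(conn, start):
    """Iterative DFS collecting every node reachable from start."""
    stack = [start]
    visited = set()
    while stack:
        current = stack.pop()
        visited.add(current)
        for nb in neighbors(conn, current):
            if nb not in visited:
                stack.append(nb)
    return visited


print(connected_component(conn, (0, 0)))
print(connected_component(conn, (1, 1)))
```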
def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*"""
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    g_score: dict[CoordTup, float] = (
        dict()
    )  # cost of cheapest path to node from start currently known
    f_score: dict[CoordTup, float] = {
        c_start: 0.0,
    }  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)

    # init
    g_score[c_start] = 0.0
    f_score[c_start] = self.heuristic(c_start, c_end)

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    # nodes to be evaluated
    # we need a set of the tuples, dont place the ints in the set
    open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
    source: dict[CoordTup, CoordTup] = (
        dict()
    )  # node immediately preceding each node in the path (currently known shortest path)

    while open_vtx:
        # get lowest f_score node
        # mypy cant tell that c is of length 2
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
        # f_current: float = f_score[c_current]

        # check if goal is reached
        if c_end == c_current:
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # ----------------------------------------------------------------------
            # this is the only return statement
            return np.array(path[::-1])
            # ----------------------------------------------------------------------

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)

            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )
find the shortest path between two coordinates, using A*
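As an illustration of the search, here is a compact standalone A* sketch on a nested-list connection list. Note one deliberate difference: it keeps the open set in a `heapq` priority queue, whereas the method above takes `min()` over a set. All names in the sketch are hypothetical.

```python
import heapq

F, T = False, True
conn = [
    [[F, T], [F, F]],  # downward connections
    [[T, F], [T, F]],  # rightward connections
]


def manhattan(a, b):
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def neighbors(conn, c):
    """Connected neighbors of node c (hypothetical helper)."""
    i, j = c
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    out = []
    if i > 0 and conn[0][i - 1][j]:
        out.append((i - 1, j))
    if i + 1 < n_rows and conn[0][i][j]:
        out.append((i + 1, j))
    if j > 0 and conn[1][i][j - 1]:
        out.append((i, j - 1))
    if j + 1 < n_cols and conn[1][i][j]:
        out.append((i, j + 1))
    return out


def shortest_path(conn, start, end):
    """A* with unit edge costs and a Manhattan-distance heuristic."""
    g_score = {start: 0}  # cheapest known cost from start
    source = {}  # predecessor on the cheapest known path
    open_heap = [(manhattan(start, end), start)]  # entries are (f_score, node)
    closed = set()
    while open_heap:
        _, current = heapq.heappop(open_heap)
        if current == end:
            path = [current]
            while path[-1] in source:
                path.append(source[path[-1]])
            return path[::-1]
        if current in closed:
            continue  # stale heap entry for an already-expanded node
        closed.add(current)
        for nb in neighbors(conn, current):
            g_new = g_score[current] + 1
            if nb in closed or g_new >= g_score.get(nb, float("inf")):
                continue
            g_score[nb] = g_new
            source[nb] = current
            heapq.heappush(open_heap, (g_new + manhattan(nb, end), nb))
    raise ValueError(f"no path from {start} to {end}")


path = shortest_path(conn, (0, 0), (1, 0))
print(path)
```

On a lattice with unit step costs the Manhattan distance never overestimates the remaining path length, so the returned path is a shortest one.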
def get_nodes(self) -> CoordArray:
    """return a list of all nodes in the maze"""
    rows: Int[np.ndarray, "x y"]
    cols: Int[np.ndarray, "x y"]
    rows, cols = np.meshgrid(
        range(self.grid_shape[0]),
        range(self.grid_shape[1]),
        indexing="ij",
    )
    nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
    return nodes
return a list of all nodes in the maze
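The `indexing="ij"` mesh grid yields the nodes in row-major order. A pure-Python equivalent of that ordering, offered only as an illustration (the name `get_nodes_sketch` is hypothetical and it returns tuples, not an array), is:

```python
from itertools import product


def get_nodes_sketch(n_rows, n_cols):
    """All (row, col) coordinates in row-major order."""
    return list(product(range(n_rows), range(n_cols)))


nodes = get_nodes_sketch(2, 3)
print(nodes)
```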
def get_connected_component(self) -> CoordArray:
    """get the largest (and assumed only nonsingular) connected component of the maze

    TODO: other connected components?
    """
    if (self.generation_meta is None) or (
        self.generation_meta.get("fully_connected", False)
    ):
        # for fully connected case, pick any two positions
        return self.get_nodes()
    else:
        # if metadata provided, use visited cells
        visited_cells: set[CoordTup] | None = self.generation_meta.get(
            "visited_cells",
            None,
        )
        if visited_cells is None:
            # TODO: dynamically generate visited_cells?
            err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
            raise ValueError(
                err_msg,
            )
        visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
        return visited_cells_np
get the largest (and assumed only nonsingular) connected component of the maze
TODO: other connected components?
def generate_random_path(  # noqa: C901
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> typing.Optional[CoordArray]:
    """return a path between randomly chosen start and end nodes within the connected component

    Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

    # Parameters:
    - `allowed_start : CoordList | None`
        a list of allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        a list of allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to be a deadend
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to be a deadend
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to ensure that the start and end points are not the same
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise an error if no valid start or end positions are found
        if this is `False`, the function might return `None` and this must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray`
        a path between the selected start and end positions

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """
    # we can't create a "path" in a single-node maze
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    # get connected component
    connected_component: CoordArray = self.get_connected_component()

    # initialize start and end positions
    positions: Int[np.int8, "2 2"]

    # if no special conditions on start and end positions
    if (allowed_start, allowed_end, deadend_start, deadend_end) == (
        None,
        None,
        False,
        False,
    ):
        try:
            positions = connected_component[  # type: ignore[assignment]
                np.random.choice(
                    len(connected_component),
                    size=2,
                    replace=False,
                )
            ]
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # handle special conditions
    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
    # copy connected component set
    allowed_start_set: set[CoordTup] = connected_component_set.copy()
    allowed_end_set: set[CoordTup] = connected_component_set.copy()

    # filter by explicitly allowed start and end positions
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    if allowed_start is not None:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    if allowed_end is not None:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    # filter by forcing deadends
    if deadend_start:
        allowed_start_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_start_set,
            ),
        )

    if deadend_end:
        allowed_end_set = set(
            filter(
                lambda x: len(self.get_coord_neighbors(x)) == 1,
                allowed_end_set,
            ),
        )

    # check we have valid positions
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            )
        return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
        )
        if endpoints_not_equal:
            # remove start position from end positions
            allowed_end_set.discard(start_pos)
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            ) from e
        return None

    return self.find_shortest_path(start_pos, end_pos)
return a path between randomly chosen start and end nodes within the connected component
Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
Parameters:
- allowed_start : CoordList | None
  a list of allowed start positions. If None, any position in the connected component is allowed (defaults to None)
- allowed_end : CoordList | None
  a list of allowed end positions. If None, any position in the connected component is allowed (defaults to None)
- deadend_start : bool
  whether to force the start position to be a deadend (defaults to False)
- deadend_end : bool
  whether to force the end position to be a deadend (defaults to False)
- endpoints_not_equal : bool
  whether to ensure that the start and end points are not the same (defaults to False)
- except_on_no_valid_endpoint : bool
  whether to raise an error if no valid start or end positions are found. If this is False, the function might return None, and this must be handled by the caller (defaults to True)
Returns:
- CoordArray
  a path between the selected start and end positions
Raises:
- NoValidEndpointException : if no valid start or end positions are found, and except_on_no_valid_endpoint is True
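The way the constraints narrow the candidate endpoints can be sketched with the standard library alone. This is a simplified illustration and not the library's implementation: `pick_endpoints` and its `degrees` argument are hypothetical, it omits the `allowed_start` and `allowed_end` filters, and it returns `None` where the method would by default raise `NoValidEndpointException`.

```python
import random


def pick_endpoints(
    component,
    degrees,
    deadend_start=False,
    deadend_end=False,
    endpoints_not_equal=False,
    rng=random,
):
    """Pick (start, end) from a connected component under the given constraints.

    component: list of (row, col) nodes; degrees: dict mapping node -> neighbor count.
    """
    # a dead end is a node with exactly one connected neighbor
    starts = [c for c in component if not deadend_start or degrees[c] == 1]
    ends = [c for c in component if not deadend_end or degrees[c] == 1]
    if not starts or not ends:
        return None
    start = rng.choice(starts)
    if endpoints_not_equal:
        ends = [c for c in ends if c != start]
        if not ends:
            return None
    return start, rng.choice(ends)


component = [(0, 0), (0, 1), (1, 1), (1, 0)]
degrees = {(0, 0): 1, (0, 1): 2, (1, 1): 2, (1, 0): 1}
endpoints = pick_endpoints(
    component,
    degrees,
    deadend_start=True,
    deadend_end=True,
    endpoints_not_equal=True,
    rng=random.Random(0),
)
print(endpoints)
```

Without `endpoints_not_equal`, the start and end are drawn independently, which is why the same position can be selected for both.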
def as_adj_list(
    self,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end coord"]:
    """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
    return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
return the maze as an adjacency list, wraps maze_dataset.token_utils.connection_list_to_adj_list
@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """create a LatticeMaze from a list of connections

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
    """
    # this is where it would probably break for rectangular mazes
    grid_n: int = adj_list.max() + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for c_start, c_end in adj_list:
        # check that exactly 1 coordinate matches
        if (c_start == c_end).sum() != 1:
            raise ValueError("invalid connection")

        # get the direction
        d: int = (c_start != c_end).argmax()

        x: int
        y: int
        # pick whichever has the lesser value in the direction `d`
        if c_start[d] < c_end[d]:
            x, y = c_start
        else:
            x, y = c_end

        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )
create a LatticeMaze from a list of connections
This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
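The conversion from a list of connections back to the connection list can be illustrated with nested lists. Unlike the method above, this hypothetical `connection_list_from_edges` takes the grid shape explicitly and does not infer a square grid from the largest coordinate.

```python
def connection_list_from_edges(edges, n_rows, n_cols):
    """Build [down, right] nested lists from pairs of adjacent (row, col) nodes."""
    conn = [[[False] * n_cols for _ in range(n_rows)] for _ in range(2)]
    for c_start, c_end in edges:
        differs = [s != e for s, e in zip(c_start, c_end)]
        # exactly one coordinate may differ between the two ends
        if sum(differs) != 1:
            raise ValueError("invalid connection")
        d = differs.index(True)  # 0: vertical edge, 1: horizontal edge
        # store the edge at whichever end has the lesser value along direction d
        x, y = c_start if c_start[d] < c_end[d] else c_end
        conn[d][x][y] = True
    return conn


# edge order and orientation do not matter
edges = [((0, 1), (0, 0)), ((0, 1), (1, 1)), ((1, 1), (1, 0))]
conn = connection_list_from_edges(edges, 2, 2)
print(conn)
```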
def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    return [
        SPECIAL_TOKENS.ADJLIST_START,
        *chain.from_iterable(  # type: ignore[list-item]
            [
                [
                    tuple(c_s),
                    SPECIAL_TOKENS.CONNECTOR,
                    tuple(c_e),
                    SPECIAL_TOKENS.ADJACENCY_ENDLINE,
                ]
                for c_s, c_e in self.as_adj_list()
            ],
        ),
        SPECIAL_TOKENS.ADJLIST_END,
    ]
(deprecated!) turn the maze into adjacency list tokens, use MazeTokenizerModular instead
def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """serialize maze and solution to tokens"""
    if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
    else:
        return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
serialize maze and solution to tokens
@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Constructs a maze from a tokenization.

    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
    """
    # HACK: type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
    if (
        isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
        and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
    ):
        err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        raise NotImplementedError(
            err_msg,
        )

    if isinstance(tokens, str):
        tokens = tokens.split()

    if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
    else:
        raise NotImplementedError("only AOTP tokenization is supported")
Constructs a maze from a tokenization.
Only legacy tokenizers and their MazeTokenizerModular analogs are supported.
def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """convert the maze to a pixel grid

    - useful as a simpler way of plotting the maze than the more complex `MazePlot`
    - the same underlying representation as `as_ascii` but as an image
    - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")
    # convert original bool pixel grid to RGB
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3),
        PixelColors.WALL,
        dtype=np.uint8,
    )
    pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set endpoints for TargetedLatticeMaze
    if self.__class__ == TargetedLatticeMaze:
        if show_endpoints:
            pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.START
            )
            pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
                PixelColors.END
            )
        return pixel_grid

    # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
    if show_solution:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
            next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # set pixel between them
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

        # set endpoints (again, since path would overwrite them)
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    return pixel_grid
convert the maze to a pixel grid
- useful as a simpler way of plotting the maze than the more complex MazePlot
- the same underlying representation as as_ascii but as an image
- used in RasterizedMazeDataset, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
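The pixel layout can be inferred from the slicing in `_from_pixel_grid_bw`: node `(i, j)` sits at pixel `(2i + 1, 2j + 1)`, its downward connection at `(2i + 2, 2j + 1)`, and its rightward connection at `(2i + 1, 2j + 2)`. The sketch below is a pure-Python illustration of that layout in black and white, under the assumption that the forward direction mirrors those slices; the function names and the `#` and space characters used for printing are choices made for this example.

```python
F, T = False, True
conn = [
    [[F, T], [F, F]],  # downward connections
    [[T, F], [T, F]],  # rightward connections
]


def to_pixels_bw(conn):
    """Boolean grid of shape (2r + 1, 2c + 1); True marks an open pixel."""
    n_rows, n_cols = len(conn[0]), len(conn[0][0])
    grid = [[False] * (2 * n_cols + 1) for _ in range(2 * n_rows + 1)]
    for i in range(n_rows):
        for j in range(n_cols):
            grid[2 * i + 1][2 * j + 1] = True  # the node itself
            if conn[0][i][j]:
                grid[2 * i + 2][2 * j + 1] = True  # passage downward
            if conn[1][i][j]:
                grid[2 * i + 1][2 * j + 2] = True  # passage rightward
    return grid


def from_pixels_bw(grid):
    """Inverse mapping, equivalent to the slices [2::2, 1::2] and [1::2, 2::2]."""
    n_rows, n_cols = len(grid) // 2, len(grid[0]) // 2
    down = [[grid[2 * i + 2][2 * j + 1] for j in range(n_cols)] for i in range(n_rows)]
    right = [[grid[2 * i + 1][2 * j + 2] for j in range(n_cols)] for i in range(n_rows)]
    return [down, right]


grid = to_pixels_bw(conn)
rows = ["".join(" " if px else "#" for px in row) for row in grid]
print("\n".join(rows))
```

Every odd-odd pixel is a node and every pixel between two nodes is a potential passage, so the round trip through `from_pixels_bw` recovers the connection list exactly.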
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    # Raises:
    - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(
            err_msg,
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START,
            end=PixelColors.END,
            solution=PixelColors.PATH,
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the start and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
        tuple(end_pos),
    ]
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    while solution[-1] != tuple(end_pos):
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # neighbors = neighbors[:, [1, 0]]
        # filter out neighbors that are not in the raw solution
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (
                    tuple(coord) in solution_raw_list
                    and tuple(coord) not in solution
                )
            ],
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        solution.append(tuple(neighbors_filtered[0]))

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == tuple(end_pos), (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
create a LatticeMaze from a pixel grid. reverse of as_pixels
Raises:
- ValueError : if the pixel grid cannot be cast to a LatticeMaze -- it's probably a TargetedLatticeMaze or SolvedMaze
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    can be reversed with `LatticeMaze.from_ascii()`
    """
    ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    pixel_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    chars_replace: tuple = tuple()
    if show_endpoints:
        chars_replace += (AsciiChars.START, AsciiChars.END)
    if show_solution:
        chars_replace += (AsciiChars.PATH,)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char in chars_replace:
            ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char

    return "\n".join("".join(row) for row in ascii_grid)
return an ASCII grid of the maze
useful for debugging in the terminal, or as its own format
can be reversed with LatticeMaze.from_ascii()
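The overlay step in `as_ascii` can be hard to read because of the boolean-mask indexing. The sketch below shows the same idea with plain nested lists: start from a wall/open grid, then stamp the start, end, and path characters wherever the pixel grid has the matching color. `overlay` and `PAIRINGS` are hypothetical names; the character and color values are copied from the `AsciiChars` and `PixelColors` defaults listed on this page.

```python
RGB = tuple[int, int, int]

# assumed pairing of characters to colors, mirroring the documented defaults
PAIRINGS: dict[str, RGB] = {
    "#": (0, 0, 0),
    " ": (255, 255, 255),
    "S": (0, 255, 0),
    "E": (255, 0, 0),
    "X": (0, 0, 255),
}


def overlay(
    base: list[str],
    pixels: list[list[RGB]],
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """Stamp endpoint/path characters onto `base` where `pixels` match."""
    chars_replace: tuple[str, ...] = ()
    if show_endpoints:
        chars_replace += ("S", "E")
    if show_solution:
        chars_replace += ("X",)

    out = [list(row) for row in base]
    for char, color in PAIRINGS.items():
        if char not in chars_replace:
            continue
        # plain-loop equivalent of `grid[(pixels == color).all(axis=-1)] = char`
        for i, row in enumerate(pixels):
            for j, px in enumerate(row):
                if px == color:
                    out[i][j] = char
    return "\n".join("".join(row) for row in out)


W, S, X, E = PAIRINGS["#"], PAIRINGS["S"], PAIRINGS["X"], PAIRINGS["E"]
base = ["#####", "#   #", "#####"]
pixels = [[W] * 5, [W, S, X, E, W], [W] * 5]
full = overlay(base, pixels)
no_path = overlay(base, pixels, show_solution=False)
```

With `show_solution=False` the path cell keeps its base character, so the same pixel grid can be rendered with or without the solution.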
```python
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    "get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)"
    lines: list[str] = ascii_str.strip().split("\n")
    lines = [line.strip() for line in lines]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
        [list(line) for line in lines],
        dtype=str,
    )
    pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        pixel_grid[ascii_grid == ascii_char] = pixel_color

    return cls.from_pixels(pixel_grid)
```
get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)
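As a rough illustration of the first half of `from_ascii` (turning text into a pixel grid before handing it to `from_pixels`), here is a dependency-free sketch. `ascii_to_pixels` and `PAIRINGS` are hypothetical names, and the sketch stops at the pixel grid rather than building a maze.

```python
RGB = tuple[int, int, int]

# assumed pairing of characters to colors, mirroring the documented defaults
PAIRINGS: dict[str, RGB] = {
    "#": (0, 0, 0),
    " ": (255, 255, 255),
    "S": (0, 255, 0),
    "E": (255, 0, 0),
    "X": (0, 0, 255),
}


def ascii_to_pixels(ascii_str: str) -> list[list[RGB]]:
    """Convert an ASCII maze string into rows of RGB tuples."""
    # same whitespace handling as the source: strip the block, then each line
    lines = [line.strip() for line in ascii_str.strip().split("\n")]
    width = len(lines[0])
    if any(len(line) != width for line in lines):
        raise ValueError("ASCII maze rows must all have the same length")
    # unmatched characters stay black, as with the zero-initialized array
    return [[PAIRINGS.get(ch, (0, 0, 0)) for ch in line] for line in lines]


grid = ascii_to_pixels(
    """
    #####
    #SXE#
    #####
    """
)
```

Because each line is stripped, an open cell (`" "`) at the very edge of a row would be lost; mazes rendered with a wall border avoid this.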
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
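To make the dispatch order in `serialize` easier to follow, here is a small sketch built only on the standard-library `dataclasses` module. It is not the muutils implementation: the `"__format__"` key, the `"serialization_fn"` metadata entry, and the `Point`/`Segment` classes are all assumptions made for illustration. The order it shows is the one in the source above: a value's own `serialize()` method is tried first, and a per-field serialization function is the fallback.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class Point:
    x: int
    y: int

    def serialize(self) -> dict[str, Any]:
        return {"x": self.x, "y": self.y}


@dataclasses.dataclass
class Segment:
    start: Point
    end: Point
    # hypothetical metadata key standing in for a per-field serialization fn
    tags: set[str] = dataclasses.field(
        default_factory=set,
        metadata={"serialization_fn": sorted},
    )


def serialize(obj: Any) -> dict[str, Any]:
    """Serialize a dataclass instance field by field."""
    # hypothetical format marker, analogous in spirit to `_FORMAT_KEY`
    result: dict[str, Any] = {"__format__": type(obj).__name__}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        field_fn = f.metadata.get("serialization_fn")
        if hasattr(value, "serialize") and callable(value.serialize):
            # the value knows how to serialize itself
            value = value.serialize()
        elif field_fn is not None:
            # otherwise fall back to the field-level function
            value = field_fn(value)
        result[f.name] = value
    return result


segment = Segment(Point(0, 0), Point(1, 2), tags={"b", "a"})
serialized = serialize(segment)
```

Here `sorted` turns the set into a list, which is the kind of conversion a field-level function is typically needed for when the target format is JSON.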
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
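The recursive part of `load` (using a field's type hint to load nested values) can be sketched with the standard library alone. The code below is a simplified illustration, not the muutils implementation: it omits `deserialize_fn`, `loading_fn`, and the type-check step, raises `TypeError` where the source raises `FieldLoadingError`, and the `Point`/`Segment` classes are hypothetical.

```python
import dataclasses
import typing
from typing import Any


@dataclasses.dataclass
class Point:
    x: int
    y: int

    @classmethod
    def load(cls, data: Any) -> "Point":
        # already an instance: return as-is, as the source does
        if isinstance(data, cls):
            return data
        return cls(x=data["x"], y=data["y"])


@dataclasses.dataclass
class Segment:
    start: Point
    end: Point
    label: str = ""


def load(cls: type, data: Any) -> Any:
    """Build `cls` from a mapping, loading nested fields via their type hints."""
    if isinstance(data, cls):
        return data
    hints = typing.get_type_hints(cls)
    ctor_kwargs: dict[str, Any] = {}
    for f in dataclasses.fields(cls):
        if f.name not in data or not f.init:
            continue
        value = data[f.name]
        hint = hints.get(f.name)
        if hint is not None and callable(getattr(hint, "load", None)):
            # the hinted type knows how to load itself from a dict
            if not isinstance(value, dict):
                raise TypeError(f"cannot load {value!r} into {hint}")
            value = hint.load(value)
        # otherwise keep the value as-is
        ctor_kwargs[f.name] = value
    return cls(**ctor_kwargs)


segment = load(
    Segment,
    {"start": {"x": 0, "y": 1}, "end": {"x": 2, "y": 3}, "label": "a"},
)
```

Fields missing from the input fall back to their dataclass defaults, since they are simply left out of the constructor arguments.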
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
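The pattern here (compute a per-field validity dict, then reduce it with `all`) can be sketched as follows. This is a deliberately simplified stand-in: it only handles plain classes as type hints via `isinstance`, whereas a real validator also has to deal with generics, unions, and error modes. The function names and the `Config` class are hypothetical.

```python
import dataclasses
import typing


@dataclasses.dataclass
class Config:
    name: str
    size: int


def validate_fields_types_dict(obj: object) -> dict[str, bool]:
    """Map each field name to whether its value matches its type hint."""
    hints = typing.get_type_hints(type(obj))
    # assumption: every hint is a plain class, so `isinstance` is enough
    return {
        f.name: isinstance(getattr(obj, f.name), hints[f.name])
        for f in dataclasses.fields(obj)
    }


def validate_fields_types(obj: object) -> bool:
    """True only if every field passes its type check."""
    return all(validate_fields_types_dict(obj).values())


good = Config(name="maze", size=3)
bad = Config(name="maze", size="3")  # dataclasses do not enforce hints
per_field = validate_fields_types_dict(bad)
```

Keeping the per-field dict separate from the boolean summary lets a caller report exactly which fields failed, which is what the `load` source above does when building its error message.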
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
@dataclass(frozen=True)
class AsciiChars:
    "standard ascii characters for mazes"

    WALL: str = "#"
    OPEN: str = " "
    START: str = "S"
    END: str = "E"
    PATH: str = "X"
```
standard ascii characters for mazes
```python
@dataclass(frozen=True)
class PixelColors:
    "standard colors for pixel grids"

    WALL: RGB = (0, 0, 0)
    OPEN: RGB = (255, 255, 255)
    START: RGB = (0, 255, 0)
    END: RGB = (255, 0, 0)
    PATH: RGB = (0, 0, 255)
```
standard colors for pixel grids
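Since `AsciiChars` and `PixelColors` share the same field names, a character-to-color lookup like the `ASCII_PIXEL_PAIRINGS` used by `as_ascii` and `from_ascii` could be derived by pairing fields by name. The sketch below shows one way to do that; it is an assumption about how such a table might be built, not the library's actual definition, and `PAIRINGS` is a hypothetical name.

```python
from dataclasses import FrozenInstanceError, dataclass, fields

RGB = tuple[int, int, int]


@dataclass(frozen=True)
class AsciiChars:
    WALL: str = "#"
    OPEN: str = " "
    START: str = "S"
    END: str = "E"
    PATH: str = "X"


@dataclass(frozen=True)
class PixelColors:
    WALL: RGB = (0, 0, 0)
    OPEN: RGB = (255, 255, 255)
    START: RGB = (0, 255, 0)
    END: RGB = (255, 0, 0)
    PATH: RGB = (0, 0, 255)


# pair the two classes by field name; defaults are readable as class attributes
PAIRINGS: dict[str, RGB] = {
    getattr(AsciiChars, f.name): getattr(PixelColors, f.name)
    for f in fields(AsciiChars)
}

# `frozen=True` blocks assignment on instances
try:
    AsciiChars().WALL = "@"  # type: ignore[misc]
    frozen_ok = False
except FrozenInstanceError:
    frozen_ok = True
```

Because the defaults live on the class, code such as `AsciiChars.START` works without creating an instance, which is how the source listings above use these classes.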