maze_dataset.maze.lattice_maze
Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses.
It also includes basic utilities, such as converting to and from ASCII and pixel representations.
1"""Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses. 2 3also includes basic utilities, including converting to/from ascii and pixel representations. 4""" 5 6import typing 7import warnings 8from dataclasses import dataclass 9from itertools import chain 10 11import numpy as np 12from jaxtyping import Bool, Int, Int8, Shaped 13from muutils.json_serialize.serializable_dataclass import ( 14 SerializableDataclass, 15 serializable_dataclass, 16 serializable_field, 17) 18from muutils.misc import isinstance_by_type_name, list_split 19 20from maze_dataset.constants import ( 21 NEIGHBORS_MASK, 22 SPECIAL_TOKENS, 23 ConnectionList, 24 Coord, 25 CoordArray, 26 CoordList, 27 CoordTup, 28) 29from maze_dataset.token_utils import ( 30 TokenizerDeprecationWarning, 31 connection_list_to_adj_list, 32 get_adj_list_tokens, 33 get_origin_tokens, 34 get_path_tokens, 35 get_target_tokens, 36) 37 38if typing.TYPE_CHECKING: 39 from maze_dataset.tokenization import ( 40 MazeTokenizer, 41 MazeTokenizerModular, 42 TokenizationMode, 43 ) 44 45RGB = tuple[int, int, int] 46"rgb tuple of values 0-255" 47 48PixelGrid = Int[np.ndarray, "x y rgb"] 49"rgb grid of pixels" 50BinaryPixelGrid = Bool[np.ndarray, "x y"] 51"boolean grid of pixels" 52 53DIM_2: int = 2 54"2 dimensions" 55 56 57class NoValidEndpointException(Exception): # noqa: N818 58 """Raised when no valid start or end positions are found in a maze.""" 59 60 pass 61 62 63def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList: 64 """fill the last elements of the connections lists as false for each dim""" 65 for dim in range(connection_list.shape[0]): 66 # last row for down 67 if dim == 0: 68 connection_list[dim, -1, :] = False 69 # last column for right 70 elif dim == 1: 71 connection_list[dim, :, -1] = False 72 else: 73 err_msg: str = f"only 2d lattices supported. 
got {dim=}" 74 raise NotImplementedError(err_msg) 75 return connection_list 76 77 78def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool: 79 """check if a color is in a pixel grid""" 80 for row in pixel_grid: 81 for pixel in row: 82 if np.all(pixel == color): 83 return True 84 return False 85 86 87@dataclass(frozen=True) 88class PixelColors: 89 "standard colors for pixel grids" 90 91 WALL: RGB = (0, 0, 0) 92 OPEN: RGB = (255, 255, 255) 93 START: RGB = (0, 255, 0) 94 END: RGB = (255, 0, 0) 95 PATH: RGB = (0, 0, 255) 96 97 98@dataclass(frozen=True) 99class AsciiChars: 100 "standard ascii characters for mazes" 101 102 WALL: str = "#" 103 OPEN: str = " " 104 START: str = "S" 105 END: str = "E" 106 PATH: str = "X" 107 108 109ASCII_PIXEL_PAIRINGS: dict[str, RGB] = { 110 AsciiChars.WALL: PixelColors.WALL, 111 AsciiChars.OPEN: PixelColors.OPEN, 112 AsciiChars.START: PixelColors.START, 113 AsciiChars.END: PixelColors.END, 114 AsciiChars.PATH: PixelColors.PATH, 115} 116"map ascii characters to pixel colors" 117 118 119@serializable_dataclass( 120 frozen=True, 121 kw_only=True, 122 properties_to_serialize=["lattice_dim", "generation_meta"], 123) 124class LatticeMaze(SerializableDataclass): 125 """lattice maze (nodes on a lattice, connections only to neighboring nodes) 126 127 Connection List represents which nodes (N) are connected in each direction. 128 129 First and second elements represent rightward and downward connections, 130 respectively. 131 132 Example: 133 Connection list: 134 [ 135 [ # down 136 [F T], 137 [F F] 138 ], 139 [ # right 140 [T F], 141 [T F] 142 ] 143 ] 144 145 Nodes with connections 146 N T N F 147 F T 148 N T N F 149 F F 150 151 Graph: 152 N - N 153 | 154 N - N 155 156 Note: the bottom row connections going down, and the 157 right-hand connections going right, will always be False. 
158 159 """ 160 161 connection_list: ConnectionList 162 generation_meta: dict | None = serializable_field(default=None, compare=False) 163 164 lattice_dim = property(lambda self: self.connection_list.shape[0]) 165 grid_shape = property(lambda self: self.connection_list.shape[1:]) 166 n_connections = property(lambda self: self.connection_list.sum()) 167 168 @property 169 def grid_n(self) -> int: 170 "grid size as int, raises `AssertionError` if not square" 171 assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported" 172 return self.grid_shape[0] 173 174 # ============================================================ 175 # basic methods 176 # ============================================================ 177 178 def __eq__(self, other: object) -> bool: 179 "equality check calls super" 180 return super().__eq__(other) 181 182 @staticmethod 183 def heuristic(a: CoordTup, b: CoordTup) -> float: 184 """return manhattan distance between two points""" 185 return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1]) 186 187 def __hash__(self) -> int: 188 """hash the connection list by converting connection list to bytes""" 189 return hash(self.connection_list.tobytes()) 190 191 def nodes_connected(self, a: Coord, b: Coord, /) -> bool: 192 """returns whether two nodes are connected""" 193 delta: Coord = b - a 194 if np.abs(delta).sum() != 1: 195 # return false if not even adjacent 196 return False 197 else: 198 # test for wall 199 dim: int = int(np.argmax(np.abs(delta))) 200 clist_node: Coord = a if (delta.sum() > 0) else b 201 return self.connection_list[dim, clist_node[0], clist_node[1]] 202 203 def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool: 204 """check if a path is valid""" 205 # check path is not empty 206 if len(path) == 0: 207 return empty_is_valid 208 209 # check all coords in bounds of maze 210 if not np.all((path >= 0) & (path < self.grid_shape)): 211 return False 212 213 # check all nodes connected 214 for i in 
range(len(path) - 1): 215 if not self.nodes_connected(path[i], path[i + 1]): 216 return False 217 return True 218 219 def coord_degrees(self) -> Int8[np.ndarray, "row col"]: 220 """Returns an array with the connectivity degree of each coord. 221 222 I.e., how many neighbors each coord has. 223 """ 224 int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = ( 225 self.connection_list.astype(np.int8) 226 ) 227 degrees: Int8[np.ndarray, "row col"] = np.sum( 228 int_conn, 229 axis=0, 230 ) # Connections to east and south 231 degrees[:, 1:] += int_conn[1, :, :-1] # Connections to west 232 degrees[1:, :] += int_conn[0, :-1, :] # Connections to north 233 return degrees 234 235 def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray: 236 """Returns an array of the neighboring, connected coords of `c`.""" 237 c = np.array(c) # type: ignore[assignment] 238 neighbors: list[Coord] = [ 239 neighbor 240 for neighbor in (c + NEIGHBORS_MASK) 241 if ( 242 (0 <= neighbor[0] < self.grid_shape[0]) # in x bounds 243 and (0 <= neighbor[1] < self.grid_shape[1]) # in y bounds 244 and self.nodes_connected(c, neighbor) # connected 245 ) 246 ] 247 248 output: CoordArray = np.array(neighbors) 249 if len(neighbors) > 0: 250 assert output.shape == ( 251 len(neighbors), 252 2, 253 ), ( 254 f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}" 255 ) 256 return output 257 258 def gen_connected_component_from(self, c: Coord) -> CoordArray: 259 """return the connected component from a given coordinate""" 260 # Stack for DFS 261 stack: list[Coord] = [c] 262 263 # Set to store visited nodes 264 visited: set[CoordTup] = set() 265 266 while stack: 267 current_node: Coord = stack.pop() 268 # this is fine since we know current_node is a coord and thus of length 2 269 visited.add(tuple(current_node)) # type: ignore[arg-type] 270 271 # Get the neighbors of the current node 272 neighbors = self.get_coord_neighbors(current_node) 273 274 # Iterate 
over neighbors 275 for neighbor in neighbors: 276 if tuple(neighbor) not in visited: 277 stack.append(neighbor) 278 279 return np.array(list(visited)) 280 281 def find_shortest_path( 282 self, 283 c_start: CoordTup | Coord, 284 c_end: CoordTup | Coord, 285 ) -> CoordArray: 286 """find the shortest path between two coordinates, using A*""" 287 c_start = tuple(c_start) # type: ignore[assignment] 288 c_end = tuple(c_end) # type: ignore[assignment] 289 290 g_score: dict[CoordTup, float] = ( 291 dict() 292 ) # cost of cheapest path to node from start currently known 293 f_score: dict[CoordTup, float] = { 294 c_start: 0.0, 295 } # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end) 296 297 # init 298 g_score[c_start] = 0.0 299 g_score[c_start] = self.heuristic(c_start, c_end) 300 301 closed_vtx: set[CoordTup] = set() # nodes already evaluated 302 # nodes to be evaluated 303 # we need a set of the tuples, dont place the ints in the set 304 open_vtx: set[CoordTup] = set([c_start]) # noqa: C405 305 source: dict[CoordTup, CoordTup] = ( 306 dict() 307 ) # node immediately preceding each node in the path (currently known shortest path) 308 309 while open_vtx: 310 # get lowest f_score node 311 # mypy cant tell that c is of length 2 312 c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)]) # type: ignore[index] 313 # f_current: float = f_score[c_current] 314 315 # check if goal is reached 316 if c_end == c_current: 317 path: list[CoordTup] = [c_current] 318 p_current: CoordTup = c_current 319 while p_current in source: 320 p_current = source[p_current] 321 path.append(p_current) 322 # ---------------------------------------------------------------------- 323 # this is the only return statement 324 return np.array(path[::-1]) 325 # ---------------------------------------------------------------------- 326 327 # close current node 328 closed_vtx.add(c_current) 329 open_vtx.remove(c_current) 330 331 # update g_score of neighbors 
332 _np_neighbor: Coord 333 for _np_neighbor in self.get_coord_neighbors(c_current): 334 neighbor: CoordTup = tuple(_np_neighbor) 335 336 if neighbor in closed_vtx: 337 # already checked 338 continue 339 g_temp: float = g_score[c_current] + 1 # always 1 for maze neighbors 340 341 if neighbor not in open_vtx: 342 # found new vtx, so add 343 open_vtx.add(neighbor) 344 345 elif g_temp >= g_score[neighbor]: 346 # if already knew about this one, but current g_score is worse, skip 347 continue 348 349 # store g_score and source 350 source[neighbor] = c_current 351 g_score[neighbor] = g_temp 352 f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end) 353 354 raise ValueError( 355 "A solution could not be found!", 356 f"{c_start = }, {c_end = }", 357 self.as_ascii(), 358 ) 359 360 def get_nodes(self) -> CoordArray: 361 """return a list of all nodes in the maze""" 362 rows: Int[np.ndarray, "x y"] 363 cols: Int[np.ndarray, "x y"] 364 rows, cols = np.meshgrid( 365 range(self.grid_shape[0]), 366 range(self.grid_shape[1]), 367 indexing="ij", 368 ) 369 nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T 370 return nodes 371 372 def get_connected_component(self) -> CoordArray: 373 """get the largest (and assumed only nonsingular) connected component of the maze 374 375 TODO: other connected components? 376 """ 377 if (self.generation_meta is None) or ( 378 self.generation_meta.get("fully_connected", False) 379 ): 380 # for fully connected case, pick any two positions 381 return self.get_nodes() 382 else: 383 # if metadata provided, use visited cells 384 visited_cells: set[CoordTup] | None = self.generation_meta.get( 385 "visited_cells", 386 None, 387 ) 388 if visited_cells is None: 389 # TODO: dynamically generate visited_cells? 
390 err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}" 391 raise ValueError( 392 err_msg, 393 ) 394 visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells)) 395 return visited_cells_np 396 397 @typing.overload 398 def generate_random_path( 399 self, 400 allowed_start: CoordList | None = None, 401 allowed_end: CoordList | None = None, 402 deadend_start: bool = False, 403 deadend_end: bool = False, 404 endpoints_not_equal: bool = False, 405 except_on_no_valid_endpoint: typing.Literal[True] = True, 406 ) -> CoordArray: ... 407 @typing.overload 408 def generate_random_path( 409 self, 410 allowed_start: CoordList | None = None, 411 allowed_end: CoordList | None = None, 412 deadend_start: bool = False, 413 deadend_end: bool = False, 414 endpoints_not_equal: bool = False, 415 except_on_no_valid_endpoint: typing.Literal[False] = False, 416 ) -> typing.Optional[CoordArray]: ... 417 def generate_random_path( # noqa: C901 418 self, 419 allowed_start: CoordList | None = None, 420 allowed_end: CoordList | None = None, 421 deadend_start: bool = False, 422 deadend_end: bool = False, 423 endpoints_not_equal: bool = False, 424 except_on_no_valid_endpoint: bool = True, 425 ) -> typing.Optional[CoordArray]: 426 """return a path between randomly chosen start and end nodes within the connected component 427 428 Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end. 429 430 # Parameters: 431 - `allowed_start : CoordList | None` 432 a list of allowed start positions. If `None`, any position in the connected component is allowed 433 (defaults to `None`) 434 - `allowed_end : CoordList | None` 435 a list of allowed end positions. 
If `None`, any position in the connected component is allowed 436 (defaults to `None`) 437 - `deadend_start : bool` 438 whether to ***force*** the start position to be a deadend (defaults to `False`) 439 (defaults to `False`) 440 - `deadend_end : bool` 441 whether to ***force*** the end position to be a deadend (defaults to `False`) 442 (defaults to `False`) 443 - `endpoints_not_equal : bool` 444 whether to ensure tha the start and end point are not the same 445 (defaults to `False`) 446 - `except_on_no_valid_endpoint : bool` 447 whether to raise an error if no valid start or end positions are found 448 if this is `False`, the function might return `None` and this must be handled by the caller 449 (defaults to `True`) 450 451 # Returns: 452 - `CoordArray` 453 a path between the selected start and end positions 454 455 # Raises: 456 - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True` 457 """ 458 # we can't create a "path" in a single-node maze 459 assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, ( # noqa: PT018 460 f"can't create path in single-node maze: {self.as_ascii()}" 461 ) 462 463 # get connected component 464 connected_component: CoordArray = self.get_connected_component() 465 466 # initialize start and end positions 467 positions: Int[np.int8, "2 2"] 468 469 # if no special conditions on start and end positions 470 if (allowed_start, allowed_end, deadend_start, deadend_end) == ( 471 None, 472 None, 473 False, 474 False, 475 ): 476 try: 477 positions = connected_component[ # type: ignore[assignment] 478 np.random.choice( 479 len(connected_component), 480 size=2, 481 replace=False, 482 ) 483 ] 484 except ValueError as e: 485 if except_on_no_valid_endpoint: 486 err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }" 487 raise NoValidEndpointException( 488 err_msg, 489 ) from e 490 return None 491 492 return 
self.find_shortest_path(positions[0], positions[1]) # type: ignore[index] 493 494 # handle special conditions 495 connected_component_set: set[CoordTup] = set(map(tuple, connected_component)) 496 # copy connected component set 497 allowed_start_set: set[CoordTup] = connected_component_set.copy() 498 allowed_end_set: set[CoordTup] = connected_component_set.copy() 499 500 # filter by explicitly allowed start and end positions 501 # '# type: ignore[assignment]' here because the returned tuple can be of any length 502 if allowed_start is not None: 503 allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set # type: ignore[assignment] 504 505 if allowed_end is not None: 506 allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set # type: ignore[assignment] 507 508 # filter by forcing deadends 509 if deadend_start: 510 allowed_start_set = set( 511 filter( 512 lambda x: len(self.get_coord_neighbors(x)) == 1, 513 allowed_start_set, 514 ), 515 ) 516 517 if deadend_end: 518 allowed_end_set = set( 519 filter( 520 lambda x: len(self.get_coord_neighbors(x)) == 1, 521 allowed_end_set, 522 ), 523 ) 524 525 # check we have valid positions 526 if len(allowed_start_set) == 0 or len(allowed_end_set) == 0: 527 if except_on_no_valid_endpoint: 528 err_msg = f"No valid start (or end?) 
positions found: {allowed_start_set = }, {allowed_end_set = }" 529 raise NoValidEndpointException( 530 err_msg, 531 ) 532 return None 533 534 # randomly select start and end positions 535 try: 536 # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok 537 start_pos: CoordTup = tuple( # type: ignore[assignment] 538 list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))], 539 ) 540 if endpoints_not_equal: 541 # remove start position from end positions 542 allowed_end_set.discard(start_pos) 543 end_pos: CoordTup = tuple( # type: ignore[assignment] 544 list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))], 545 ) 546 except ValueError as e: 547 if except_on_no_valid_endpoint: 548 err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }" 549 raise NoValidEndpointException( 550 err_msg, 551 ) from e 552 return None 553 554 return self.find_shortest_path(start_pos, end_pos) 555 556 # ============================================================ 557 # to and from adjacency list 558 # ============================================================ 559 def as_adj_list( 560 self, 561 shuffle_d0: bool = True, 562 shuffle_d1: bool = True, 563 ) -> Int8[np.ndarray, "conn start_end coord"]: 564 """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`""" 565 return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1) 566 567 @classmethod 568 def from_adj_list( 569 cls, 570 adj_list: Int8[np.ndarray, "conn start_end coord"], 571 ) -> "LatticeMaze": 572 """create a LatticeMaze from a list of connections 573 574 > [!NOTE] 575 > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed. 
576 """ 577 # this is where it would probably break for rectangular mazes 578 grid_n: int = adj_list.max() + 1 579 580 connection_list: ConnectionList = np.zeros( 581 (2, grid_n, grid_n), 582 dtype=np.bool_, 583 ) 584 585 for c_start, c_end in adj_list: 586 # check that exactly 1 coordinate matches 587 if (c_start == c_end).sum() != 1: 588 raise ValueError("invalid connection") 589 590 # get the direction 591 d: int = (c_start != c_end).argmax() 592 593 x: int 594 y: int 595 # pick whichever has the lesser value in the direction `d` 596 if c_start[d] < c_end[d]: 597 x, y = c_start 598 else: 599 x, y = c_end 600 601 connection_list[d, x, y] = True 602 603 return LatticeMaze( 604 connection_list=connection_list, 605 ) 606 607 def as_adj_list_tokens(self) -> list[str | CoordTup]: 608 """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead""" 609 warnings.warn( 610 "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.", 611 TokenizerDeprecationWarning, 612 ) 613 return [ 614 SPECIAL_TOKENS.ADJLIST_START, 615 *chain.from_iterable( # type: ignore[list-item] 616 [ 617 [ 618 tuple(c_s), 619 SPECIAL_TOKENS.CONNECTOR, 620 tuple(c_e), 621 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 622 ] 623 for c_s, c_e in self.as_adj_list() 624 ], 625 ), 626 SPECIAL_TOKENS.ADJLIST_END, 627 ] 628 629 def _as_adj_list_tokens(self) -> list[str | CoordTup]: 630 return [ 631 SPECIAL_TOKENS.ADJLIST_START, 632 *chain.from_iterable( # type: ignore[list-item] 633 [ 634 [ 635 tuple(c_s), 636 SPECIAL_TOKENS.CONNECTOR, 637 tuple(c_e), 638 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 639 ] 640 for c_s, c_e in self.as_adj_list() 641 ], 642 ), 643 SPECIAL_TOKENS.ADJLIST_END, 644 ] 645 646 def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]: 647 """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples""" 648 output: list[CoordTup | str] = self._as_adj_list_tokens() 649 # if getattr(self, "start_pos", None) is 
not None: 650 if isinstance(self, TargetedLatticeMaze): 651 output += self._get_start_pos_tokens() 652 if isinstance(self, TargetedLatticeMaze): 653 output += self._get_end_pos_tokens() 654 if isinstance(self, SolvedMaze): 655 output += self._get_solution_tokens() 656 return output 657 658 def _as_tokens( 659 self, 660 maze_tokenizer: "MazeTokenizer | TokenizationMode", 661 ) -> list[str]: 662 # type ignores here fine since we check the instance 663 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 664 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 665 if ( 666 isinstance_by_type_name(maze_tokenizer, "MazeTokenizer") 667 and maze_tokenizer.is_AOTP() # type: ignore[union-attr] 668 ): 669 coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP() 670 coords_processed: list[str] = maze_tokenizer.coords_to_strings( # type: ignore[union-attr] 671 coords=coords_raw, 672 when_noncoord="include", 673 ) 674 return coords_processed 675 else: 676 err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}" 677 raise NotImplementedError(err_msg) 678 679 def as_tokens( 680 self, 681 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 682 ) -> list[str]: 683 """serialize maze and solution to tokens""" 684 if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"): 685 return maze_tokenizer.to_tokens(self) # type: ignore[union-attr] 686 else: 687 return self._as_tokens(maze_tokenizer) # type: ignore[union-attr,arg-type] 688 689 @classmethod 690 def _from_tokens_AOTP( 691 cls, 692 tokens: list[str], 693 maze_tokenizer: "MazeTokenizer | MazeTokenizerModular", 694 ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": 695 """create a LatticeMaze from a list of tokens""" 696 # figure out what input format 697 # ======================================== 698 if tokens[0] == SPECIAL_TOKENS.ADJLIST_START: 699 adj_list_tokens = get_adj_list_tokens(tokens) 700 else: 701 # If we're not getting a 
"complete" tokenized maze, assume it's just a the adjacency list tokens 702 adj_list_tokens = tokens 703 warnings.warn( 704 "Assuming input is just adjacency list tokens, no special tokens found", 705 ) 706 707 # process edges for adjacency list 708 # ======================================== 709 edges: list[list[str]] = list_split( 710 adj_list_tokens, 711 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 712 ) 713 714 coordinates: list[tuple[CoordTup, CoordTup]] = list() 715 for e in edges: 716 # skip last endline 717 if len(e) != 0: 718 # convert to coords, split start and end 719 e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords( 720 e, 721 when_noncoord="include", 722 ) 723 # this assertion depends on the tokenizer having exactly one token for the connector 724 # which is also why we "include" above 725 # the connector token is discarded below 726 assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }" # noqa: PLR2004 727 assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, ( 728 f"invalid edge: {e = } {e_coords = }" 729 ) 730 e_coords_first: CoordTup = e_coords[0] # type: ignore[assignment] 731 e_coords_last: CoordTup = e_coords[-1] # type: ignore[assignment] 732 coordinates.append((e_coords_first, e_coords_last)) 733 734 assert all(len(c) == DIM_2 for c in coordinates), ( 735 f"invalid coordinates: {coordinates = }" 736 ) 737 adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates) 738 assert tuple(adj_list.shape) == ( 739 len(coordinates), 740 2, 741 2, 742 ), f"invalid adj_list: {adj_list.shape = } {coordinates = }" 743 744 output_maze: LatticeMaze = cls.from_adj_list(adj_list) 745 746 # add start and end positions 747 # ======================================== 748 is_targeted: bool = False 749 if all( 750 x in tokens 751 for x in ( 752 SPECIAL_TOKENS.ORIGIN_START, 753 SPECIAL_TOKENS.ORIGIN_END, 754 SPECIAL_TOKENS.TARGET_START, 755 SPECIAL_TOKENS.TARGET_END, 756 ) 757 ): 758 start_pos_list: list[CoordTup] = 
maze_tokenizer.strings_to_coords( 759 get_origin_tokens(tokens), 760 when_noncoord="error", 761 ) 762 end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords( 763 get_target_tokens(tokens), 764 when_noncoord="error", 765 ) 766 assert len(start_pos_list) == 1, ( 767 f"invalid start_pos_list: {start_pos_list = }" 768 ) 769 assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }" 770 771 start_pos: CoordTup = start_pos_list[0] 772 end_pos: CoordTup = end_pos_list[0] 773 774 output_maze = TargetedLatticeMaze.from_lattice_maze( 775 lattice_maze=output_maze, 776 start_pos=start_pos, 777 end_pos=end_pos, 778 ) 779 780 is_targeted = True 781 782 if all( 783 x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END) 784 ): 785 assert is_targeted, "maze must be targeted to have a solution" 786 solution: list[CoordTup] = maze_tokenizer.strings_to_coords( 787 get_path_tokens(tokens, trim_end=True), 788 when_noncoord="error", 789 ) 790 output_maze = SolvedMaze.from_targeted_lattice_maze( 791 # HACK: I think this is fine, but im not sure 792 targeted_lattice_maze=output_maze, # type: ignore[arg-type] 793 solution=solution, 794 ) 795 796 return output_maze 797 798 # TODO: any way to get return type hinting working for this? 799 @classmethod 800 def from_tokens( 801 cls, 802 tokens: list[str], 803 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 804 ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": 805 """Constructs a maze from a tokenization. 806 807 Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported. 
808 """ 809 # HACK: type ignores here fine since we check the instance 810 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 811 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 812 if ( 813 isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular") 814 and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr] 815 ): 816 err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}." 817 raise NotImplementedError( 818 err_msg, 819 ) 820 821 if isinstance(tokens, str): 822 tokens = tokens.split() 823 824 if maze_tokenizer.is_AOTP(): # type: ignore[union-attr] 825 return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type] 826 else: 827 raise NotImplementedError("only AOTP tokenization is supported") 828 829 # ============================================================ 830 # to and from pixels 831 # ============================================================ 832 def _as_pixels_bw(self) -> BinaryPixelGrid: 833 assert self.lattice_dim == DIM_2, "only 2D mazes are supported" 834 # Create an empty pixel grid with walls 835 pixel_grid: Int[np.ndarray, "x y"] = np.full( 836 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1), 837 False, 838 dtype=np.bool_, 839 ) 840 841 # Set white nodes 842 pixel_grid[1::2, 1::2] = True 843 844 # Set white connections (downward) 845 for i, row in enumerate(self.connection_list[0]): 846 for j, connected in enumerate(row): 847 if connected: 848 pixel_grid[i * 2 + 2, j * 2 + 1] = True 849 850 # Set white connections (rightward) 851 for i, row in enumerate(self.connection_list[1]): 852 for j, connected in enumerate(row): 853 if connected: 854 pixel_grid[i * 2 + 1, j * 2 + 2] = True 855 856 return pixel_grid 857 858 def as_pixels( 859 self, 860 show_endpoints: bool = True, 861 show_solution: bool = True, 862 ) -> PixelGrid: 863 """convert the maze to a pixel grid 864 865 - useful as a simpler way 
of plotting the maze than the more complex `MazePlot` 866 - the same underlying representation as `as_ascii` but as an image 867 - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data 868 """ 869 # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze` 870 # but solution, start_pos, end_pos not always defined 871 # but its fine since we explicitly check the type 872 if show_solution and not show_endpoints: 873 raise ValueError("show_solution=True requires show_endpoints=True") 874 # convert original bool pixel grid to RGB 875 pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw() 876 pixel_grid: PixelGrid = np.full( 877 (*pixel_grid_bw.shape, 3), 878 PixelColors.WALL, 879 dtype=np.uint8, 880 ) 881 pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712 882 883 if self.__class__ == LatticeMaze: 884 return pixel_grid 885 886 # set endpoints for TargetedLatticeMaze 887 if self.__class__ == TargetedLatticeMaze: 888 if show_endpoints: 889 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 890 PixelColors.START 891 ) 892 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 893 PixelColors.END 894 ) 895 return pixel_grid 896 897 # set solution -- we only reach this part if `self.__class__ == SolvedMaze` 898 if show_solution: 899 for coord in self.solution: # type: ignore[attr-defined] 900 pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH 901 902 # set pixels between coords 903 for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined] 904 next_coord = self.solution[index + 1] # type: ignore[attr-defined] 905 # check they are adjacent using norm 906 assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, ( 907 f"Coords {coord} and {next_coord} are not adjacent" 908 ) 909 # set pixel between them 910 pixel_grid[ 911 coord[0] * 2 + 1 + 
next_coord[0] - coord[0], 912 coord[1] * 2 + 1 + next_coord[1] - coord[1], 913 ] = PixelColors.PATH 914 915 # set endpoints (again, since path would overwrite them) 916 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 917 PixelColors.START 918 ) 919 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 920 PixelColors.END 921 ) 922 923 return pixel_grid 924 925 @classmethod 926 def _from_pixel_grid_bw( 927 cls, 928 pixel_grid: BinaryPixelGrid, 929 ) -> tuple[ConnectionList, tuple[int, int]]: 930 grid_shape: tuple[int, int] = ( 931 pixel_grid.shape[0] // 2, 932 pixel_grid.shape[1] // 2, 933 ) 934 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_) 935 936 # Extract downward connections 937 connection_list[0] = pixel_grid[2::2, 1::2] 938 939 # Extract rightward connections 940 connection_list[1] = pixel_grid[1::2, 2::2] 941 942 return connection_list, grid_shape 943 944 @classmethod 945 def _from_pixel_grid_with_positions( 946 cls, 947 pixel_grid: PixelGrid | BinaryPixelGrid, 948 marked_positions: dict[str, RGB], 949 ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]: 950 # Convert RGB pixel grid to Bool pixel grid 951 # error: Incompatible types in assignment (expression has type 952 # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]", 953 # variable has type "ndarray[Any, Any]") [assignment] 954 pixel_grid_bw: BinaryPixelGrid = ~np.all( # type: ignore[assignment] 955 pixel_grid == PixelColors.WALL, 956 axis=-1, 957 ) 958 connection_list: ConnectionList 959 grid_shape: tuple[int, int] 960 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw) 961 962 # Find any marked positions 963 out_positions: dict[str, CoordArray] = dict() 964 for key, color in marked_positions.items(): 965 pos_temp: Int[np.ndarray, "x y"] = np.argwhere( 966 np.all(pixel_grid == color, axis=-1), 967 ) 968 pos_save: 
list[CoordTup] = list() 969 for pos in pos_temp: 970 # if it is a coordinate and not connection (transform position, %2==1) 971 if pos[0] % 2 == 1 and pos[1] % 2 == 1: 972 pos_save.append((pos[0] // 2, pos[1] // 2)) 973 974 out_positions[key] = np.array(pos_save) 975 976 return connection_list, grid_shape, out_positions 977 978 @classmethod 979 def from_pixels( 980 cls, 981 pixel_grid: PixelGrid, 982 ) -> "LatticeMaze": 983 """create a LatticeMaze from a pixel grid. reverse of `as_pixels` 984 985 # Raises: 986 - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze` 987 """ 988 connection_list: ConnectionList 989 grid_shape: tuple[int, int] 990 991 # if a binary pixel grid, return regular LatticeMaze 992 if len(pixel_grid.shape) == 2: # noqa: PLR2004 993 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid) 994 return LatticeMaze(connection_list=connection_list) 995 996 # otherwise, detect and check it's valid 997 cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid) 998 if cls not in cls_detected.__mro__: 999 err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }" 1000 raise ValueError( 1001 err_msg, 1002 ) 1003 1004 ( 1005 connection_list, 1006 grid_shape, 1007 marked_pos, 1008 ) = cls._from_pixel_grid_with_positions( 1009 pixel_grid=pixel_grid, 1010 marked_positions=dict( 1011 start=PixelColors.START, 1012 end=PixelColors.END, 1013 solution=PixelColors.PATH, 1014 ), 1015 ) 1016 # if we wanted a LatticeMaze, return it 1017 if cls == LatticeMaze: 1018 return LatticeMaze(connection_list=connection_list) 1019 1020 # otherwise, keep going 1021 temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list) 1022 1023 # start and end pos 1024 start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"] 1025 assert start_pos_arr.shape == ( 1026 1, 1027 2, 1028 ), ( 1029 f"start_pos_arr {start_pos_arr} has 
shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1030 ) 1031 assert end_pos_arr.shape == ( 1032 1, 1033 2, 1034 ), ( 1035 f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1036 ) 1037 1038 start_pos: Coord = start_pos_arr[0] 1039 end_pos: Coord = end_pos_arr[0] 1040 1041 # return a TargetedLatticeMaze if that's what we wanted 1042 if cls == TargetedLatticeMaze: 1043 return TargetedLatticeMaze( 1044 connection_list=connection_list, 1045 start_pos=start_pos, 1046 end_pos=end_pos, 1047 ) 1048 1049 # raw solution, only contains path elements and not start or end 1050 solution_raw: CoordArray = marked_pos["solution"] 1051 if len(solution_raw.shape) == 2: # noqa: PLR2004 1052 assert solution_raw.shape[1] == 2, ( # noqa: PLR2004 1053 f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)" 1054 ) 1055 elif solution_raw.shape == (0,): 1056 # the solution and end should be immediately adjacent 1057 assert np.sum(np.abs(start_pos - end_pos)) == 1, ( 1058 f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given" 1059 ) 1060 1061 # order the solution, by creating a list from the start to the end 1062 # add end pos, since we will iterate over all these starting from the start pos 1063 solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [ 1064 tuple(end_pos), 1065 ] 1066 # solution starts with start point 1067 solution: list[CoordTup] = [tuple(start_pos)] 1068 while solution[-1] != tuple(end_pos): 1069 # use `get_coord_neighbors` to find connected neighbors 1070 neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1]) 1071 # TODO: make this less ugly 1072 assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), ( # noqa: PT018, PLR2004 1073 f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}" 1074 ) 1075 # 
neighbors = neighbors[:, [1, 0]] 1076 # filter out neighbors that are not in the raw solution 1077 neighbors_filtered: CoordArray = np.array( 1078 [ 1079 coord 1080 for coord in neighbors 1081 if ( 1082 tuple(coord) in solution_raw_list 1083 and tuple(coord) not in solution 1084 ) 1085 ], 1086 ) 1087 # assert only one element is left, and then add it to the solution 1088 assert neighbors_filtered.shape == ( 1089 1, 1090 2, 1091 ), ( 1092 f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}" 1093 ) 1094 solution.append(tuple(neighbors_filtered[0])) 1095 1096 # assert the solution is complete 1097 assert solution[0] == tuple(start_pos), ( 1098 f"solution {solution} does not start at start_pos {start_pos}" 1099 ) 1100 assert solution[-1] == tuple(end_pos), ( 1101 f"solution {solution} does not end at end_pos {end_pos}" 1102 ) 1103 1104 return cls( 1105 connection_list=np.array(connection_list), 1106 solution=np.array(solution), # type: ignore[call-arg] 1107 ) 1108 1109 # ============================================================ 1110 # to and from ASCII 1111 # ============================================================ 1112 def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]: 1113 # Get the pixel grid using to_pixels(). 1114 pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw() 1115 1116 # Replace pixel values with ASCII characters. 
1117 ascii_grid: Shaped[np.ndarray, "x y"] = np.full( 1118 pixel_grid.shape, 1119 AsciiChars.WALL, 1120 dtype=str, 1121 ) 1122 ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 1123 1124 return ascii_grid 1125 1126 def as_ascii( 1127 self, 1128 show_endpoints: bool = True, 1129 show_solution: bool = True, 1130 ) -> str: 1131 """return an ASCII grid of the maze 1132 1133 useful for debugging in the terminal, or as it's own format 1134 1135 can be reversed with `LatticeMaze.from_ascii()` 1136 """ 1137 ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() 1138 pixel_grid: PixelGrid = self.as_pixels( 1139 show_endpoints=show_endpoints, 1140 show_solution=show_solution, 1141 ) 1142 1143 chars_replace: tuple = tuple() 1144 if show_endpoints: 1145 chars_replace += (AsciiChars.START, AsciiChars.END) 1146 if show_solution: 1147 chars_replace += (AsciiChars.PATH,) 1148 1149 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1150 if ascii_char in chars_replace: 1151 ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char 1152 1153 return "\n".join("".join(row) for row in ascii_grid) 1154 1155 @classmethod 1156 def from_ascii(cls, ascii_str: str) -> "LatticeMaze": 1157 "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" 1158 lines: list[str] = ascii_str.strip().split("\n") 1159 lines = [line.strip() for line in lines] 1160 ascii_grid: Shaped[np.ndarray, "x y"] = np.array( 1161 [list(line) for line in lines], 1162 dtype=str, 1163 ) 1164 pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) 1165 1166 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1167 pixel_grid[ascii_grid == ascii_char] = pixel_color 1168 1169 return cls.from_pixels(pixel_grid) 1170 1171 1172# type ignore here even though theyre all frozen 1173# maybe `SerializeableDataclass` itself is not frozen, but thats an ABC 1174# error: Cannot inherit frozen dataclass from a non-frozen one [misc] 
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position"""

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        """post init converts start and end pos to numpy arrays, checks they exist and are in bounds

        # Raises:
         - `AssertionError` : if `start_pos` or `end_pos` is `None`
         - `ValueError` : if `start_pos` or `end_pos` lies outside the grid (negative, or past the last row/column)
        """
        # check existence *before* converting: `np.array(None)` is a 0-d object array,
        # which is never `None`, so checking after the conversion would always pass
        assert self.start_pos is not None
        assert self.end_pos is not None
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)

        def _in_bounds(pos: Coord) -> bool:
            # negative values must be rejected as well: used as numpy indices
            # they would silently wrap around to the other side of the grid
            return bool(
                (0 <= pos[0] < self.grid_shape[0])
                and (0 <= pos[1] < self.grid_shape[1])
            )

        # check that start and end are in bounds
        if not _in_bounds(self.start_pos):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if not _in_bounds(self.end_pos):
            err_msg = f"end_pos {self.end_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        """hash the `TargetedLatticeMaze` from its connection list bytes and its endpoint coordinates

        defined explicitly (like `LatticeMaze.__hash__` and `SolvedMaze.__hash__`) because
        defining `__eq__` alone leaves the class either unhashable or with a dataclass-generated
        field hash, and hashing the numpy array fields fails in both cases
        """
        # endpoints go through `tolist()` so that equal coordinates stored with
        # different integer dtypes still hash the same
        return hash(
            (
                self.connection_list.tobytes(),
                tuple(self.start_pos.tolist()),
                tuple(self.end_pos.tolist()),
            ),
        )

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )


@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution"""

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        # Parameters:
         - `connection_list : ConnectionList`
            connections of the underlying lattice maze
         - `solution : CoordArray`
            path from start to end, of shape `(n, 2)` with `n >= 1`.
            the start and end positions are taken from its first and last rows
         - `generation_meta : dict | None`
            passed through to the parent class
            (defaults to `None`)
         - `start_pos : Coord | None`
            if given (and `allow_invalid` is `False`), must equal `solution[0]`
            (defaults to `None`)
         - `end_pos : Coord | None`
            if given (and `allow_invalid` is `False`), must equal `solution[-1]`
            (defaults to `None`)
         - `allow_invalid : bool`
            skip the solution shape check and the endpoint-matching checks.
            note that if the solution is invalid, no endpoints can be derived from it,
            and `TargetedLatticeMaze.__post_init__` still rejects the resulting `None` endpoints
            (defaults to `False`)

        # Raises:
         - `ValueError` : if `solution` is `None` or not of shape `(n, 2)` with `n >= 1`, and `allow_invalid` is `False`
         - `AssertionError` : if a given `start_pos`/`end_pos` does not match the solution, and `allow_invalid` is `False`
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            # check `ndim` first -- indexing `shape[1]` of a 1-d array would raise `IndexError`
            if (
                solution.ndim == 2  # noqa: PLR2004
                and solution.shape[0] > 0
                and solution.shape[1] == 2  # noqa: PLR2004
            ):
                solution_valid = True

        if not solution_valid and not allow_invalid:
            # `solution` may be `None` here, so don't access `.shape` directly
            solution_shape = getattr(solution, "shape", None)
            err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        # frozen dataclass, so write the field through `__dict__`
        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if `solution` is `None`, the shortest path between the maze's start and end
        positions is found with `find_shortest_path`; otherwise `solution` is used as given
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork

        # Parameters:
         - `always_include_endpoints : bool`
            include the first and last solution coordinates regardless of how many neighbors they have
            (defaults to `False`)

        # Returns:
         - `tuple[list[int], CoordArray]`
            indices into `self.solution`, and the coordinates at those indices
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )


def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """Detects the type of pixels data by checking for the presence of start and end pixels

    - start or end pixels present, and path pixels present: `SolvedMaze`
    - start or end pixels present, but no path pixels: `TargetedLatticeMaze`
    - neither start nor end pixels present: `LatticeMaze`
    """
    if color_in_pixel_grid(data, PixelColors.START) or color_in_pixel_grid(
        data,
        PixelColors.END,
    ):
        if color_in_pixel_grid(data, PixelColors.PATH):
            return SolvedMaze
        else:
            return TargetedLatticeMaze
    else:
        return LatticeMaze


def _remove_isolated_cells(
    image: Int[np.ndarray, "x y rgb"],
) -> Int[np.ndarray, "x y rgb"]:
    """Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.

    the color channels are on the last axis. pixels outside the image count as walls,
    so an open pixel on the border is isolated if all of its in-image neighbors are walls.
    isolated pixels are set to `PixelColors.WALL` in a copy; the input is not modified.
    """
    # Create a binary mask where True represents walls
    wall_mask = np.all(image == PixelColors.WALL, axis=-1)

    # Pad the wall mask to handle edge cases
    padded_wall_mask = np.pad(
        wall_mask,
        ((1, 1), (1, 1)),
        mode="constant",
        constant_values=True,
    )

    # Check neighbors in all four directions
    isolated_mask = (
        padded_wall_mask[1:-1, 2:]  # right
        & padded_wall_mask[1:-1, :-2]  # left
        & padded_wall_mask[2:, 1:-1]  # down
        & padded_wall_mask[:-2, 1:-1]  # up
    )

    # Combine with non-wall mask to only affect open cells
    non_wall_mask = ~wall_mask
    isolated_mask = isolated_mask & non_wall_mask

    # Create the output image
    output_image = image.copy()
    output_image[isolated_mask] = PixelColors.WALL

    return output_image


# per-direction `np.pad` widths, used by the commented-out `_remove_isolated_cells_old` below
_RIC_PADS: dict = {
    "left": ((1, 0), (0, 0)),
    "right": ((0, 1), (0, 0)),
    "up": ((0, 0), (1, 0)),
    "down": ((0, 0), (0, 1)),
}

# Define slices for each direction
_RIC_SLICES: dict = {
    "left": (slice(1, None), slice(None, None)),
    "right": (slice(None, -1), slice(None, None)),
    "up": (slice(None, None), slice(1, None)),
    "down": (slice(None, None), slice(None, -1)),
}


# TODO: figure out why this function doesnt work, or maybe just get rid of it
# NOTE(review): likely cause -- each `_RIC_SLICES[d]` drops a row/column on the SAME side
# that `_RIC_PADS[d]` pads it back (e.g. "left" drops row 0 and pads one row at the start),
# so the resulting masks are not shifted at all: each one is just the pixel's own wall mask
# with one edge row/column forced to True. ANDed with `mask_non_wall` this is (almost)
# always empty, so nothing gets removed. `_remove_isolated_cells` above uses correctly
# offset slices of a padded mask instead.
# def _remove_isolated_cells_old(
#     image: Int[np.ndarray, "RGB x y"],
# ) -> Int[np.ndarray, "RGB x y"]:
#     """
#     Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.
#     """
#     warnings.warn("this function doesn't work and I have no idea why!!!")
#     masks: dict[str, np.ndarray] = {
#         d: np.all(
#             np.pad(
#                 image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL,
#                 np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8),
#                 mode="constant",
#                 constant_values=True,
#             ),
#             axis=2,
#         )
#         for d in _RIC_SLICES.keys()
#     }

#     # Create a mask for non-wall cells
#     mask_non_wall = np.all(image != PixelColors.WALL, axis=2)

#     # print(f"{mask_non_wall.shape = }")
#     # print(f"{ {k: masks[k].shape for k in masks.keys()} = }")

#     # print(f"{mask_non_wall = }")
#     # print(f"{masks['down'] = }")

#     # Combine the masks
#     mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"]

#     # Apply the mask
#     output_image = np.where(
#         np.stack([mask] * 3, axis=-1),
#         PixelColors.WALL,
#         image,
#     )

#     return output_image
RGB tuple of integer channel values, each in the range 0–255
RGB grid of pixels, indexed as `(x, y, rgb)`
Boolean grid of pixels, indexed as `(x, y)`
Number of dimensions of a 2D lattice (always 2)
class NoValidEndpointException(Exception):  # noqa: N818
    """Signals that no usable start or end position could be chosen in a maze.

    For example, `LatticeMaze.generate_random_path` raises it when the filtered
    candidate start or end positions turn out to be empty.
    """
Raised when no valid start or end positions are found in a maze.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- add_note
- args
def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
    """check if a color is in a pixel grid

    vectorized over the whole grid rather than looping over pixels in Python

    # Parameters:
     - `pixel_grid : PixelGrid`
        rgb pixel grid, with the color channels on the last axis
     - `color : RGB`
        the color to search for

    # Returns:
     - `bool`
        `True` if at least one pixel matches `color` on every channel, `False` otherwise
        (including for an empty grid)
    """
    # compare every pixel against the color along the channel axis,
    # then check whether any pixel matched on all channels
    matches = np.all(pixel_grid == np.asarray(color), axis=-1)
    return bool(np.any(matches))
Check whether a given RGB color appears anywhere in a pixel grid.
@dataclass(frozen=True)
class PixelColors:
    """standard colors for pixel grids

    every value is an `RGB` tuple; the values are read as class attributes,
    e.g. `PixelColors.WALL`
    """

    WALL: RGB = (0, 0, 0)  # black
    OPEN: RGB = (255, 255, 255)  # white
    START: RGB = (0, 255, 0)  # green
    END: RGB = (255, 0, 0)  # red
    PATH: RGB = (0, 0, 255)  # blue
Standard RGB colors used for pixel-grid representations of mazes.
@dataclass(frozen=True)
class AsciiChars:
    """standard ascii characters for mazes

    single-character strings used when rendering a maze as text
    (see `LatticeMaze.as_ascii`); read as class attributes, e.g. `AsciiChars.WALL`
    """

    WALL: str = "#"  # wall pixel
    OPEN: str = " "  # open (passable) pixel
    START: str = "S"  # start position
    END: str = "E"  # end position
    PATH: str = "X"  # solution path
Standard ASCII characters used for text representations of mazes.
Mapping from ASCII characters to their corresponding pixel colors.
120@serializable_dataclass( 121 frozen=True, 122 kw_only=True, 123 properties_to_serialize=["lattice_dim", "generation_meta"], 124) 125class LatticeMaze(SerializableDataclass): 126 """lattice maze (nodes on a lattice, connections only to neighboring nodes) 127 128 Connection List represents which nodes (N) are connected in each direction. 129 130 First and second elements represent rightward and downward connections, 131 respectively. 132 133 Example: 134 Connection list: 135 [ 136 [ # down 137 [F T], 138 [F F] 139 ], 140 [ # right 141 [T F], 142 [T F] 143 ] 144 ] 145 146 Nodes with connections 147 N T N F 148 F T 149 N T N F 150 F F 151 152 Graph: 153 N - N 154 | 155 N - N 156 157 Note: the bottom row connections going down, and the 158 right-hand connections going right, will always be False. 159 160 """ 161 162 connection_list: ConnectionList 163 generation_meta: dict | None = serializable_field(default=None, compare=False) 164 165 lattice_dim = property(lambda self: self.connection_list.shape[0]) 166 grid_shape = property(lambda self: self.connection_list.shape[1:]) 167 n_connections = property(lambda self: self.connection_list.sum()) 168 169 @property 170 def grid_n(self) -> int: 171 "grid size as int, raises `AssertionError` if not square" 172 assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported" 173 return self.grid_shape[0] 174 175 # ============================================================ 176 # basic methods 177 # ============================================================ 178 179 def __eq__(self, other: object) -> bool: 180 "equality check calls super" 181 return super().__eq__(other) 182 183 @staticmethod 184 def heuristic(a: CoordTup, b: CoordTup) -> float: 185 """return manhattan distance between two points""" 186 return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1]) 187 188 def __hash__(self) -> int: 189 """hash the connection list by converting connection list to bytes""" 190 return hash(self.connection_list.tobytes()) 191 
192 def nodes_connected(self, a: Coord, b: Coord, /) -> bool: 193 """returns whether two nodes are connected""" 194 delta: Coord = b - a 195 if np.abs(delta).sum() != 1: 196 # return false if not even adjacent 197 return False 198 else: 199 # test for wall 200 dim: int = int(np.argmax(np.abs(delta))) 201 clist_node: Coord = a if (delta.sum() > 0) else b 202 return self.connection_list[dim, clist_node[0], clist_node[1]] 203 204 def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool: 205 """check if a path is valid""" 206 # check path is not empty 207 if len(path) == 0: 208 return empty_is_valid 209 210 # check all coords in bounds of maze 211 if not np.all((path >= 0) & (path < self.grid_shape)): 212 return False 213 214 # check all nodes connected 215 for i in range(len(path) - 1): 216 if not self.nodes_connected(path[i], path[i + 1]): 217 return False 218 return True 219 220 def coord_degrees(self) -> Int8[np.ndarray, "row col"]: 221 """Returns an array with the connectivity degree of each coord. 222 223 I.e., how many neighbors each coord has. 
224 """ 225 int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = ( 226 self.connection_list.astype(np.int8) 227 ) 228 degrees: Int8[np.ndarray, "row col"] = np.sum( 229 int_conn, 230 axis=0, 231 ) # Connections to east and south 232 degrees[:, 1:] += int_conn[1, :, :-1] # Connections to west 233 degrees[1:, :] += int_conn[0, :-1, :] # Connections to north 234 return degrees 235 236 def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray: 237 """Returns an array of the neighboring, connected coords of `c`.""" 238 c = np.array(c) # type: ignore[assignment] 239 neighbors: list[Coord] = [ 240 neighbor 241 for neighbor in (c + NEIGHBORS_MASK) 242 if ( 243 (0 <= neighbor[0] < self.grid_shape[0]) # in x bounds 244 and (0 <= neighbor[1] < self.grid_shape[1]) # in y bounds 245 and self.nodes_connected(c, neighbor) # connected 246 ) 247 ] 248 249 output: CoordArray = np.array(neighbors) 250 if len(neighbors) > 0: 251 assert output.shape == ( 252 len(neighbors), 253 2, 254 ), ( 255 f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}" 256 ) 257 return output 258 259 def gen_connected_component_from(self, c: Coord) -> CoordArray: 260 """return the connected component from a given coordinate""" 261 # Stack for DFS 262 stack: list[Coord] = [c] 263 264 # Set to store visited nodes 265 visited: set[CoordTup] = set() 266 267 while stack: 268 current_node: Coord = stack.pop() 269 # this is fine since we know current_node is a coord and thus of length 2 270 visited.add(tuple(current_node)) # type: ignore[arg-type] 271 272 # Get the neighbors of the current node 273 neighbors = self.get_coord_neighbors(current_node) 274 275 # Iterate over neighbors 276 for neighbor in neighbors: 277 if tuple(neighbor) not in visited: 278 stack.append(neighbor) 279 280 return np.array(list(visited)) 281 282 def find_shortest_path( 283 self, 284 c_start: CoordTup | Coord, 285 c_end: CoordTup | Coord, 286 ) -> CoordArray: 287 """find the 
shortest path between two coordinates, using A*""" 288 c_start = tuple(c_start) # type: ignore[assignment] 289 c_end = tuple(c_end) # type: ignore[assignment] 290 291 g_score: dict[CoordTup, float] = ( 292 dict() 293 ) # cost of cheapest path to node from start currently known 294 f_score: dict[CoordTup, float] = { 295 c_start: 0.0, 296 } # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end) 297 298 # init 299 g_score[c_start] = 0.0 300 g_score[c_start] = self.heuristic(c_start, c_end) 301 302 closed_vtx: set[CoordTup] = set() # nodes already evaluated 303 # nodes to be evaluated 304 # we need a set of the tuples, dont place the ints in the set 305 open_vtx: set[CoordTup] = set([c_start]) # noqa: C405 306 source: dict[CoordTup, CoordTup] = ( 307 dict() 308 ) # node immediately preceding each node in the path (currently known shortest path) 309 310 while open_vtx: 311 # get lowest f_score node 312 # mypy cant tell that c is of length 2 313 c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)]) # type: ignore[index] 314 # f_current: float = f_score[c_current] 315 316 # check if goal is reached 317 if c_end == c_current: 318 path: list[CoordTup] = [c_current] 319 p_current: CoordTup = c_current 320 while p_current in source: 321 p_current = source[p_current] 322 path.append(p_current) 323 # ---------------------------------------------------------------------- 324 # this is the only return statement 325 return np.array(path[::-1]) 326 # ---------------------------------------------------------------------- 327 328 # close current node 329 closed_vtx.add(c_current) 330 open_vtx.remove(c_current) 331 332 # update g_score of neighbors 333 _np_neighbor: Coord 334 for _np_neighbor in self.get_coord_neighbors(c_current): 335 neighbor: CoordTup = tuple(_np_neighbor) 336 337 if neighbor in closed_vtx: 338 # already checked 339 continue 340 g_temp: float = g_score[c_current] + 1 # always 1 for maze neighbors 341 342 if 
neighbor not in open_vtx: 343 # found new vtx, so add 344 open_vtx.add(neighbor) 345 346 elif g_temp >= g_score[neighbor]: 347 # if already knew about this one, but current g_score is worse, skip 348 continue 349 350 # store g_score and source 351 source[neighbor] = c_current 352 g_score[neighbor] = g_temp 353 f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end) 354 355 raise ValueError( 356 "A solution could not be found!", 357 f"{c_start = }, {c_end = }", 358 self.as_ascii(), 359 ) 360 361 def get_nodes(self) -> CoordArray: 362 """return a list of all nodes in the maze""" 363 rows: Int[np.ndarray, "x y"] 364 cols: Int[np.ndarray, "x y"] 365 rows, cols = np.meshgrid( 366 range(self.grid_shape[0]), 367 range(self.grid_shape[1]), 368 indexing="ij", 369 ) 370 nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T 371 return nodes 372 373 def get_connected_component(self) -> CoordArray: 374 """get the largest (and assumed only nonsingular) connected component of the maze 375 376 TODO: other connected components? 377 """ 378 if (self.generation_meta is None) or ( 379 self.generation_meta.get("fully_connected", False) 380 ): 381 # for fully connected case, pick any two positions 382 return self.get_nodes() 383 else: 384 # if metadata provided, use visited cells 385 visited_cells: set[CoordTup] | None = self.generation_meta.get( 386 "visited_cells", 387 None, 388 ) 389 if visited_cells is None: 390 # TODO: dynamically generate visited_cells? 
391 err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}" 392 raise ValueError( 393 err_msg, 394 ) 395 visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells)) 396 return visited_cells_np 397 398 @typing.overload 399 def generate_random_path( 400 self, 401 allowed_start: CoordList | None = None, 402 allowed_end: CoordList | None = None, 403 deadend_start: bool = False, 404 deadend_end: bool = False, 405 endpoints_not_equal: bool = False, 406 except_on_no_valid_endpoint: typing.Literal[True] = True, 407 ) -> CoordArray: ... 408 @typing.overload 409 def generate_random_path( 410 self, 411 allowed_start: CoordList | None = None, 412 allowed_end: CoordList | None = None, 413 deadend_start: bool = False, 414 deadend_end: bool = False, 415 endpoints_not_equal: bool = False, 416 except_on_no_valid_endpoint: typing.Literal[False] = False, 417 ) -> typing.Optional[CoordArray]: ... 418 def generate_random_path( # noqa: C901 419 self, 420 allowed_start: CoordList | None = None, 421 allowed_end: CoordList | None = None, 422 deadend_start: bool = False, 423 deadend_end: bool = False, 424 endpoints_not_equal: bool = False, 425 except_on_no_valid_endpoint: bool = True, 426 ) -> typing.Optional[CoordArray]: 427 """return a path between randomly chosen start and end nodes within the connected component 428 429 Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end. 430 431 # Parameters: 432 - `allowed_start : CoordList | None` 433 a list of allowed start positions. If `None`, any position in the connected component is allowed 434 (defaults to `None`) 435 - `allowed_end : CoordList | None` 436 a list of allowed end positions. 
If `None`, any position in the connected component is allowed 437 (defaults to `None`) 438 - `deadend_start : bool` 439 whether to ***force*** the start position to be a deadend (defaults to `False`) 440 (defaults to `False`) 441 - `deadend_end : bool` 442 whether to ***force*** the end position to be a deadend (defaults to `False`) 443 (defaults to `False`) 444 - `endpoints_not_equal : bool` 445 whether to ensure tha the start and end point are not the same 446 (defaults to `False`) 447 - `except_on_no_valid_endpoint : bool` 448 whether to raise an error if no valid start or end positions are found 449 if this is `False`, the function might return `None` and this must be handled by the caller 450 (defaults to `True`) 451 452 # Returns: 453 - `CoordArray` 454 a path between the selected start and end positions 455 456 # Raises: 457 - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True` 458 """ 459 # we can't create a "path" in a single-node maze 460 assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, ( # noqa: PT018 461 f"can't create path in single-node maze: {self.as_ascii()}" 462 ) 463 464 # get connected component 465 connected_component: CoordArray = self.get_connected_component() 466 467 # initialize start and end positions 468 positions: Int[np.int8, "2 2"] 469 470 # if no special conditions on start and end positions 471 if (allowed_start, allowed_end, deadend_start, deadend_end) == ( 472 None, 473 None, 474 False, 475 False, 476 ): 477 try: 478 positions = connected_component[ # type: ignore[assignment] 479 np.random.choice( 480 len(connected_component), 481 size=2, 482 replace=False, 483 ) 484 ] 485 except ValueError as e: 486 if except_on_no_valid_endpoint: 487 err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }" 488 raise NoValidEndpointException( 489 err_msg, 490 ) from e 491 return None 492 493 return 
self.find_shortest_path(positions[0], positions[1]) # type: ignore[index] 494 495 # handle special conditions 496 connected_component_set: set[CoordTup] = set(map(tuple, connected_component)) 497 # copy connected component set 498 allowed_start_set: set[CoordTup] = connected_component_set.copy() 499 allowed_end_set: set[CoordTup] = connected_component_set.copy() 500 501 # filter by explicitly allowed start and end positions 502 # '# type: ignore[assignment]' here because the returned tuple can be of any length 503 if allowed_start is not None: 504 allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set # type: ignore[assignment] 505 506 if allowed_end is not None: 507 allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set # type: ignore[assignment] 508 509 # filter by forcing deadends 510 if deadend_start: 511 allowed_start_set = set( 512 filter( 513 lambda x: len(self.get_coord_neighbors(x)) == 1, 514 allowed_start_set, 515 ), 516 ) 517 518 if deadend_end: 519 allowed_end_set = set( 520 filter( 521 lambda x: len(self.get_coord_neighbors(x)) == 1, 522 allowed_end_set, 523 ), 524 ) 525 526 # check we have valid positions 527 if len(allowed_start_set) == 0 or len(allowed_end_set) == 0: 528 if except_on_no_valid_endpoint: 529 err_msg = f"No valid start (or end?) 
positions found: {allowed_start_set = }, {allowed_end_set = }" 530 raise NoValidEndpointException( 531 err_msg, 532 ) 533 return None 534 535 # randomly select start and end positions 536 try: 537 # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok 538 start_pos: CoordTup = tuple( # type: ignore[assignment] 539 list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))], 540 ) 541 if endpoints_not_equal: 542 # remove start position from end positions 543 allowed_end_set.discard(start_pos) 544 end_pos: CoordTup = tuple( # type: ignore[assignment] 545 list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))], 546 ) 547 except ValueError as e: 548 if except_on_no_valid_endpoint: 549 err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }" 550 raise NoValidEndpointException( 551 err_msg, 552 ) from e 553 return None 554 555 return self.find_shortest_path(start_pos, end_pos) 556 557 # ============================================================ 558 # to and from adjacency list 559 # ============================================================ 560 def as_adj_list( 561 self, 562 shuffle_d0: bool = True, 563 shuffle_d1: bool = True, 564 ) -> Int8[np.ndarray, "conn start_end coord"]: 565 """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`""" 566 return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1) 567 568 @classmethod 569 def from_adj_list( 570 cls, 571 adj_list: Int8[np.ndarray, "conn start_end coord"], 572 ) -> "LatticeMaze": 573 """create a LatticeMaze from a list of connections 574 575 > [!NOTE] 576 > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed. 
577 """ 578 # this is where it would probably break for rectangular mazes 579 grid_n: int = adj_list.max() + 1 580 581 connection_list: ConnectionList = np.zeros( 582 (2, grid_n, grid_n), 583 dtype=np.bool_, 584 ) 585 586 for c_start, c_end in adj_list: 587 # check that exactly 1 coordinate matches 588 if (c_start == c_end).sum() != 1: 589 raise ValueError("invalid connection") 590 591 # get the direction 592 d: int = (c_start != c_end).argmax() 593 594 x: int 595 y: int 596 # pick whichever has the lesser value in the direction `d` 597 if c_start[d] < c_end[d]: 598 x, y = c_start 599 else: 600 x, y = c_end 601 602 connection_list[d, x, y] = True 603 604 return LatticeMaze( 605 connection_list=connection_list, 606 ) 607 608 def as_adj_list_tokens(self) -> list[str | CoordTup]: 609 """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead""" 610 warnings.warn( 611 "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.", 612 TokenizerDeprecationWarning, 613 ) 614 return [ 615 SPECIAL_TOKENS.ADJLIST_START, 616 *chain.from_iterable( # type: ignore[list-item] 617 [ 618 [ 619 tuple(c_s), 620 SPECIAL_TOKENS.CONNECTOR, 621 tuple(c_e), 622 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 623 ] 624 for c_s, c_e in self.as_adj_list() 625 ], 626 ), 627 SPECIAL_TOKENS.ADJLIST_END, 628 ] 629 630 def _as_adj_list_tokens(self) -> list[str | CoordTup]: 631 return [ 632 SPECIAL_TOKENS.ADJLIST_START, 633 *chain.from_iterable( # type: ignore[list-item] 634 [ 635 [ 636 tuple(c_s), 637 SPECIAL_TOKENS.CONNECTOR, 638 tuple(c_e), 639 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 640 ] 641 for c_s, c_e in self.as_adj_list() 642 ], 643 ), 644 SPECIAL_TOKENS.ADJLIST_END, 645 ] 646 647 def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]: 648 """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples""" 649 output: list[CoordTup | str] = self._as_adj_list_tokens() 650 # if getattr(self, "start_pos", None) is 
not None: 651 if isinstance(self, TargetedLatticeMaze): 652 output += self._get_start_pos_tokens() 653 if isinstance(self, TargetedLatticeMaze): 654 output += self._get_end_pos_tokens() 655 if isinstance(self, SolvedMaze): 656 output += self._get_solution_tokens() 657 return output 658 659 def _as_tokens( 660 self, 661 maze_tokenizer: "MazeTokenizer | TokenizationMode", 662 ) -> list[str]: 663 # type ignores here fine since we check the instance 664 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 665 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 666 if ( 667 isinstance_by_type_name(maze_tokenizer, "MazeTokenizer") 668 and maze_tokenizer.is_AOTP() # type: ignore[union-attr] 669 ): 670 coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP() 671 coords_processed: list[str] = maze_tokenizer.coords_to_strings( # type: ignore[union-attr] 672 coords=coords_raw, 673 when_noncoord="include", 674 ) 675 return coords_processed 676 else: 677 err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}" 678 raise NotImplementedError(err_msg) 679 680 def as_tokens( 681 self, 682 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 683 ) -> list[str]: 684 """serialize maze and solution to tokens""" 685 if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"): 686 return maze_tokenizer.to_tokens(self) # type: ignore[union-attr] 687 else: 688 return self._as_tokens(maze_tokenizer) # type: ignore[union-attr,arg-type] 689 690 @classmethod 691 def _from_tokens_AOTP( 692 cls, 693 tokens: list[str], 694 maze_tokenizer: "MazeTokenizer | MazeTokenizerModular", 695 ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": 696 """create a LatticeMaze from a list of tokens""" 697 # figure out what input format 698 # ======================================== 699 if tokens[0] == SPECIAL_TOKENS.ADJLIST_START: 700 adj_list_tokens = get_adj_list_tokens(tokens) 701 else: 702 # If we're not getting a 
"complete" tokenized maze, assume it's just a the adjacency list tokens 703 adj_list_tokens = tokens 704 warnings.warn( 705 "Assuming input is just adjacency list tokens, no special tokens found", 706 ) 707 708 # process edges for adjacency list 709 # ======================================== 710 edges: list[list[str]] = list_split( 711 adj_list_tokens, 712 SPECIAL_TOKENS.ADJACENCY_ENDLINE, 713 ) 714 715 coordinates: list[tuple[CoordTup, CoordTup]] = list() 716 for e in edges: 717 # skip last endline 718 if len(e) != 0: 719 # convert to coords, split start and end 720 e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords( 721 e, 722 when_noncoord="include", 723 ) 724 # this assertion depends on the tokenizer having exactly one token for the connector 725 # which is also why we "include" above 726 # the connector token is discarded below 727 assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }" # noqa: PLR2004 728 assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, ( 729 f"invalid edge: {e = } {e_coords = }" 730 ) 731 e_coords_first: CoordTup = e_coords[0] # type: ignore[assignment] 732 e_coords_last: CoordTup = e_coords[-1] # type: ignore[assignment] 733 coordinates.append((e_coords_first, e_coords_last)) 734 735 assert all(len(c) == DIM_2 for c in coordinates), ( 736 f"invalid coordinates: {coordinates = }" 737 ) 738 adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates) 739 assert tuple(adj_list.shape) == ( 740 len(coordinates), 741 2, 742 2, 743 ), f"invalid adj_list: {adj_list.shape = } {coordinates = }" 744 745 output_maze: LatticeMaze = cls.from_adj_list(adj_list) 746 747 # add start and end positions 748 # ======================================== 749 is_targeted: bool = False 750 if all( 751 x in tokens 752 for x in ( 753 SPECIAL_TOKENS.ORIGIN_START, 754 SPECIAL_TOKENS.ORIGIN_END, 755 SPECIAL_TOKENS.TARGET_START, 756 SPECIAL_TOKENS.TARGET_END, 757 ) 758 ): 759 start_pos_list: list[CoordTup] = 
maze_tokenizer.strings_to_coords( 760 get_origin_tokens(tokens), 761 when_noncoord="error", 762 ) 763 end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords( 764 get_target_tokens(tokens), 765 when_noncoord="error", 766 ) 767 assert len(start_pos_list) == 1, ( 768 f"invalid start_pos_list: {start_pos_list = }" 769 ) 770 assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }" 771 772 start_pos: CoordTup = start_pos_list[0] 773 end_pos: CoordTup = end_pos_list[0] 774 775 output_maze = TargetedLatticeMaze.from_lattice_maze( 776 lattice_maze=output_maze, 777 start_pos=start_pos, 778 end_pos=end_pos, 779 ) 780 781 is_targeted = True 782 783 if all( 784 x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END) 785 ): 786 assert is_targeted, "maze must be targeted to have a solution" 787 solution: list[CoordTup] = maze_tokenizer.strings_to_coords( 788 get_path_tokens(tokens, trim_end=True), 789 when_noncoord="error", 790 ) 791 output_maze = SolvedMaze.from_targeted_lattice_maze( 792 # HACK: I think this is fine, but im not sure 793 targeted_lattice_maze=output_maze, # type: ignore[arg-type] 794 solution=solution, 795 ) 796 797 return output_maze 798 799 # TODO: any way to get return type hinting working for this? 800 @classmethod 801 def from_tokens( 802 cls, 803 tokens: list[str], 804 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", 805 ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": 806 """Constructs a maze from a tokenization. 807 808 Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported. 
809 """ 810 # HACK: type ignores here fine since we check the instance 811 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): 812 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] 813 if ( 814 isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular") 815 and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr] 816 ): 817 err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}." 818 raise NotImplementedError( 819 err_msg, 820 ) 821 822 if isinstance(tokens, str): 823 tokens = tokens.split() 824 825 if maze_tokenizer.is_AOTP(): # type: ignore[union-attr] 826 return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type] 827 else: 828 raise NotImplementedError("only AOTP tokenization is supported") 829 830 # ============================================================ 831 # to and from pixels 832 # ============================================================ 833 def _as_pixels_bw(self) -> BinaryPixelGrid: 834 assert self.lattice_dim == DIM_2, "only 2D mazes are supported" 835 # Create an empty pixel grid with walls 836 pixel_grid: Int[np.ndarray, "x y"] = np.full( 837 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1), 838 False, 839 dtype=np.bool_, 840 ) 841 842 # Set white nodes 843 pixel_grid[1::2, 1::2] = True 844 845 # Set white connections (downward) 846 for i, row in enumerate(self.connection_list[0]): 847 for j, connected in enumerate(row): 848 if connected: 849 pixel_grid[i * 2 + 2, j * 2 + 1] = True 850 851 # Set white connections (rightward) 852 for i, row in enumerate(self.connection_list[1]): 853 for j, connected in enumerate(row): 854 if connected: 855 pixel_grid[i * 2 + 1, j * 2 + 2] = True 856 857 return pixel_grid 858 859 def as_pixels( 860 self, 861 show_endpoints: bool = True, 862 show_solution: bool = True, 863 ) -> PixelGrid: 864 """convert the maze to a pixel grid 865 866 - useful as a simpler way 
of plotting the maze than the more complex `MazePlot` 867 - the same underlying representation as `as_ascii` but as an image 868 - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data 869 """ 870 # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze` 871 # but solution, start_pos, end_pos not always defined 872 # but its fine since we explicitly check the type 873 if show_solution and not show_endpoints: 874 raise ValueError("show_solution=True requires show_endpoints=True") 875 # convert original bool pixel grid to RGB 876 pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw() 877 pixel_grid: PixelGrid = np.full( 878 (*pixel_grid_bw.shape, 3), 879 PixelColors.WALL, 880 dtype=np.uint8, 881 ) 882 pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712 883 884 if self.__class__ == LatticeMaze: 885 return pixel_grid 886 887 # set endpoints for TargetedLatticeMaze 888 if self.__class__ == TargetedLatticeMaze: 889 if show_endpoints: 890 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 891 PixelColors.START 892 ) 893 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 894 PixelColors.END 895 ) 896 return pixel_grid 897 898 # set solution -- we only reach this part if `self.__class__ == SolvedMaze` 899 if show_solution: 900 for coord in self.solution: # type: ignore[attr-defined] 901 pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH 902 903 # set pixels between coords 904 for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined] 905 next_coord = self.solution[index + 1] # type: ignore[attr-defined] 906 # check they are adjacent using norm 907 assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, ( 908 f"Coords {coord} and {next_coord} are not adjacent" 909 ) 910 # set pixel between them 911 pixel_grid[ 912 coord[0] * 2 + 1 + 
next_coord[0] - coord[0], 913 coord[1] * 2 + 1 + next_coord[1] - coord[1], 914 ] = PixelColors.PATH 915 916 # set endpoints (again, since path would overwrite them) 917 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 918 PixelColors.START 919 ) 920 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] 921 PixelColors.END 922 ) 923 924 return pixel_grid 925 926 @classmethod 927 def _from_pixel_grid_bw( 928 cls, 929 pixel_grid: BinaryPixelGrid, 930 ) -> tuple[ConnectionList, tuple[int, int]]: 931 grid_shape: tuple[int, int] = ( 932 pixel_grid.shape[0] // 2, 933 pixel_grid.shape[1] // 2, 934 ) 935 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_) 936 937 # Extract downward connections 938 connection_list[0] = pixel_grid[2::2, 1::2] 939 940 # Extract rightward connections 941 connection_list[1] = pixel_grid[1::2, 2::2] 942 943 return connection_list, grid_shape 944 945 @classmethod 946 def _from_pixel_grid_with_positions( 947 cls, 948 pixel_grid: PixelGrid | BinaryPixelGrid, 949 marked_positions: dict[str, RGB], 950 ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]: 951 # Convert RGB pixel grid to Bool pixel grid 952 # error: Incompatible types in assignment (expression has type 953 # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]", 954 # variable has type "ndarray[Any, Any]") [assignment] 955 pixel_grid_bw: BinaryPixelGrid = ~np.all( # type: ignore[assignment] 956 pixel_grid == PixelColors.WALL, 957 axis=-1, 958 ) 959 connection_list: ConnectionList 960 grid_shape: tuple[int, int] 961 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw) 962 963 # Find any marked positions 964 out_positions: dict[str, CoordArray] = dict() 965 for key, color in marked_positions.items(): 966 pos_temp: Int[np.ndarray, "x y"] = np.argwhere( 967 np.all(pixel_grid == color, axis=-1), 968 ) 969 pos_save: 
list[CoordTup] = list() 970 for pos in pos_temp: 971 # if it is a coordinate and not connection (transform position, %2==1) 972 if pos[0] % 2 == 1 and pos[1] % 2 == 1: 973 pos_save.append((pos[0] // 2, pos[1] // 2)) 974 975 out_positions[key] = np.array(pos_save) 976 977 return connection_list, grid_shape, out_positions 978 979 @classmethod 980 def from_pixels( 981 cls, 982 pixel_grid: PixelGrid, 983 ) -> "LatticeMaze": 984 """create a LatticeMaze from a pixel grid. reverse of `as_pixels` 985 986 # Raises: 987 - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze` 988 """ 989 connection_list: ConnectionList 990 grid_shape: tuple[int, int] 991 992 # if a binary pixel grid, return regular LatticeMaze 993 if len(pixel_grid.shape) == 2: # noqa: PLR2004 994 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid) 995 return LatticeMaze(connection_list=connection_list) 996 997 # otherwise, detect and check it's valid 998 cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid) 999 if cls not in cls_detected.__mro__: 1000 err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }" 1001 raise ValueError( 1002 err_msg, 1003 ) 1004 1005 ( 1006 connection_list, 1007 grid_shape, 1008 marked_pos, 1009 ) = cls._from_pixel_grid_with_positions( 1010 pixel_grid=pixel_grid, 1011 marked_positions=dict( 1012 start=PixelColors.START, 1013 end=PixelColors.END, 1014 solution=PixelColors.PATH, 1015 ), 1016 ) 1017 # if we wanted a LatticeMaze, return it 1018 if cls == LatticeMaze: 1019 return LatticeMaze(connection_list=connection_list) 1020 1021 # otherwise, keep going 1022 temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list) 1023 1024 # start and end pos 1025 start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"] 1026 assert start_pos_arr.shape == ( 1027 1, 1028 2, 1029 ), ( 1030 f"start_pos_arr {start_pos_arr} has 
shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1031 ) 1032 assert end_pos_arr.shape == ( 1033 1, 1034 2, 1035 ), ( 1036 f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" 1037 ) 1038 1039 start_pos: Coord = start_pos_arr[0] 1040 end_pos: Coord = end_pos_arr[0] 1041 1042 # return a TargetedLatticeMaze if that's what we wanted 1043 if cls == TargetedLatticeMaze: 1044 return TargetedLatticeMaze( 1045 connection_list=connection_list, 1046 start_pos=start_pos, 1047 end_pos=end_pos, 1048 ) 1049 1050 # raw solution, only contains path elements and not start or end 1051 solution_raw: CoordArray = marked_pos["solution"] 1052 if len(solution_raw.shape) == 2: # noqa: PLR2004 1053 assert solution_raw.shape[1] == 2, ( # noqa: PLR2004 1054 f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)" 1055 ) 1056 elif solution_raw.shape == (0,): 1057 # the solution and end should be immediately adjacent 1058 assert np.sum(np.abs(start_pos - end_pos)) == 1, ( 1059 f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given" 1060 ) 1061 1062 # order the solution, by creating a list from the start to the end 1063 # add end pos, since we will iterate over all these starting from the start pos 1064 solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [ 1065 tuple(end_pos), 1066 ] 1067 # solution starts with start point 1068 solution: list[CoordTup] = [tuple(start_pos)] 1069 while solution[-1] != tuple(end_pos): 1070 # use `get_coord_neighbors` to find connected neighbors 1071 neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1]) 1072 # TODO: make this less ugly 1073 assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), ( # noqa: PT018, PLR2004 1074 f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}" 1075 ) 1076 # 
neighbors = neighbors[:, [1, 0]] 1077 # filter out neighbors that are not in the raw solution 1078 neighbors_filtered: CoordArray = np.array( 1079 [ 1080 coord 1081 for coord in neighbors 1082 if ( 1083 tuple(coord) in solution_raw_list 1084 and tuple(coord) not in solution 1085 ) 1086 ], 1087 ) 1088 # assert only one element is left, and then add it to the solution 1089 assert neighbors_filtered.shape == ( 1090 1, 1091 2, 1092 ), ( 1093 f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}" 1094 ) 1095 solution.append(tuple(neighbors_filtered[0])) 1096 1097 # assert the solution is complete 1098 assert solution[0] == tuple(start_pos), ( 1099 f"solution {solution} does not start at start_pos {start_pos}" 1100 ) 1101 assert solution[-1] == tuple(end_pos), ( 1102 f"solution {solution} does not end at end_pos {end_pos}" 1103 ) 1104 1105 return cls( 1106 connection_list=np.array(connection_list), 1107 solution=np.array(solution), # type: ignore[call-arg] 1108 ) 1109 1110 # ============================================================ 1111 # to and from ASCII 1112 # ============================================================ 1113 def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]: 1114 # Get the pixel grid using to_pixels(). 1115 pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw() 1116 1117 # Replace pixel values with ASCII characters. 
1118 ascii_grid: Shaped[np.ndarray, "x y"] = np.full( 1119 pixel_grid.shape, 1120 AsciiChars.WALL, 1121 dtype=str, 1122 ) 1123 ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 1124 1125 return ascii_grid 1126 1127 def as_ascii( 1128 self, 1129 show_endpoints: bool = True, 1130 show_solution: bool = True, 1131 ) -> str: 1132 """return an ASCII grid of the maze 1133 1134 useful for debugging in the terminal, or as it's own format 1135 1136 can be reversed with `LatticeMaze.from_ascii()` 1137 """ 1138 ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() 1139 pixel_grid: PixelGrid = self.as_pixels( 1140 show_endpoints=show_endpoints, 1141 show_solution=show_solution, 1142 ) 1143 1144 chars_replace: tuple = tuple() 1145 if show_endpoints: 1146 chars_replace += (AsciiChars.START, AsciiChars.END) 1147 if show_solution: 1148 chars_replace += (AsciiChars.PATH,) 1149 1150 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1151 if ascii_char in chars_replace: 1152 ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char 1153 1154 return "\n".join("".join(row) for row in ascii_grid) 1155 1156 @classmethod 1157 def from_ascii(cls, ascii_str: str) -> "LatticeMaze": 1158 "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" 1159 lines: list[str] = ascii_str.strip().split("\n") 1160 lines = [line.strip() for line in lines] 1161 ascii_grid: Shaped[np.ndarray, "x y"] = np.array( 1162 [list(line) for line in lines], 1163 dtype=str, 1164 ) 1165 pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) 1166 1167 for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): 1168 pixel_grid[ascii_grid == ascii_char] = pixel_color 1169 1170 return cls.from_pixels(pixel_grid)
lattice maze (nodes on a lattice, connections only to neighboring nodes)
Connection List represents which nodes (N) are connected in each direction.
The first and second elements represent downward and rightward connections, respectively.
Example:

    Connection list:
        [
            [ # down
                [F T],
                [F F]
            ],
            [ # right
                [T F],
                [T F]
            ]
        ]

    Nodes with connections:
        N T N F
        F   T
        N T N F
        F   F

    Graph:
        N - N
            |
        N - N
Note: the bottom row connections going down, and the right-hand connections going right, will always be False.
@property
def grid_n(self) -> int:
    """side length of the (square) maze grid

    # Raises:
     - `AssertionError` : if the grid is not square
    """
    n_rows: int = self.grid_shape[0]
    n_cols: int = self.grid_shape[1]
    assert n_rows == n_cols, "only square mazes supported"
    return n_rows
grid size as int, raises `AssertionError` if not square
@staticmethod
def heuristic(a: CoordTup, b: CoordTup) -> float:
    """manhattan (L1) distance between coordinates `a` and `b`; the A* heuristic of `find_shortest_path`"""
    d_row = a[0] - b[0]
    d_col = a[1] - b[1]
    return np.abs(d_row) + np.abs(d_col)
return manhattan distance between two points
def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """returns whether two nodes are connected

    Two nodes are connected only if they are lattice neighbors (manhattan distance
    exactly 1) AND the connection between them is set in `connection_list`.

    # Parameters:
     - `a : Coord` first node (a `CoordTup` is also accepted)
     - `b : Coord` second node (a `CoordTup` is also accepted)

    # Returns:
     - `bool` : `True` if `a` and `b` are adjacent and connected
    """
    # coerce so that tuple coordinates also support elementwise subtraction
    a_arr = np.asarray(a)
    b_arr = np.asarray(b)
    delta = b_arr - a_arr
    if np.abs(delta).sum() != 1:
        # not even adjacent
        return False

    # axis along which the two nodes differ
    dim: int = int(np.argmax(np.abs(delta)))
    # the connection is stored on the node with the lesser coordinate along `dim`
    clist_node = a_arr if (delta.sum() > 0) else b_arr
    return bool(self.connection_list[dim, clist_node[0], clist_node[1]])
returns whether two nodes are connected
def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """check if a path is valid

    A path is valid if every coordinate lies inside the grid and every pair of
    consecutive coordinates is connected (see `nodes_connected`).

    # Parameters:
     - `path : CoordArray` sequence of coordinates (a list of coordinate tuples is also accepted)
     - `empty_is_valid : bool` value returned for an empty path (defaults to `False`)

    # Returns:
     - `bool` : whether the path is valid
    """
    # accept lists of tuples as well as arrays; comparisons below need an array
    path = np.asarray(path)

    # check path is not empty
    if len(path) == 0:
        return empty_is_valid

    # check all coords in bounds of maze
    if not np.all((path >= 0) & (path < self.grid_shape)):
        return False

    # check all consecutive nodes are connected
    return all(
        self.nodes_connected(c_from, c_to) for c_from, c_to in zip(path[:-1], path[1:])
    )
check if a path is valid
def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the connectivity degree of each coord.

    I.e., how many connected neighbors each coord has (0 to 4).
    """
    conn: Int8[np.ndarray, "lattice_dim=2 row col"] = self.connection_list.astype(
        np.int8,
    )
    down_conn = conn[0]
    right_conn = conn[1]

    # each node's own outgoing connections: south + east
    # NOTE: summing an int8 array upcasts to the platform integer, so the result
    # dtype is typically wider than the annotation suggests
    degrees: Int8[np.ndarray, "row col"] = conn.sum(axis=0)
    # west: node (r, c) is connected to (r, c-1) via (r, c-1)'s rightward connection
    degrees[:, 1:] += right_conn[:, :-1]
    # north: node (r, c) is connected to (r-1, c) via (r-1, c)'s downward connection
    degrees[1:, :] += down_conn[:-1, :]
    return degrees
Returns an array with the connectivity degree of each coord.
I.e., how many neighbors each coord has.
def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """Returns an array of the neighboring, connected coords of `c`.

    Candidates are `c + NEIGHBORS_MASK`. A candidate is kept if it lies inside
    `grid_shape` and `nodes_connected(c, candidate)` holds.

    # Returns:
     - `CoordArray` : shape `(n_neighbors, 2)`. When `c` has no connected
       neighbors this is an empty `(0, 2)` array, so the result is always 2D.
    """
    c = np.array(c)  # type: ignore[assignment]
    candidates = c + NEIGHBORS_MASK
    neighbors: list[Coord] = [
        neighbor
        for neighbor in candidates
        if (
            (0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
            and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
            and self.nodes_connected(c, neighbor)  # connected
        )
    ]

    if len(neighbors) == 0:
        # `np.array([])` would be a 1D float array; keep the coord-array shape and dtype instead
        return np.empty((0, 2), dtype=candidates.dtype)

    output: CoordArray = np.array(neighbors)
    assert output.shape == (
        len(neighbors),
        2,
    ), (
        f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
    )
    return output
Returns an array of the neighboring, connected coords of `c`.
def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """return the connected component from a given coordinate

    Iterative depth-first search over connected neighbors (`get_coord_neighbors`),
    starting at `c`.

    # Returns:
     - `CoordArray` : every coordinate reachable from `c`, including `c` itself.
       The order of the rows is not meaningful.
    """
    # Stack for DFS
    stack: list[Coord] = [c]

    # Set to store visited nodes
    visited: set[CoordTup] = set()

    while stack:
        current_node: Coord = stack.pop()
        # this is fine since we know current_node is a coord and thus of length 2
        current_tup: CoordTup = tuple(current_node)  # type: ignore[assignment]

        # a node can be pushed several times before it is first popped (when it is
        # adjacent to several expanded nodes); skip the repeats rather than re-expanding
        if current_tup in visited:
            continue
        visited.add(current_tup)

        for neighbor in self.get_coord_neighbors(current_node):
            if tuple(neighbor) not in visited:
                stack.append(neighbor)

    return np.array(list(visited))
return the connected component from a given coordinate
def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*

    Every step between connected neighbors costs 1. `heuristic` (manhattan
    distance) estimates the remaining cost.

    # Returns:
     - `CoordArray` : the coordinates along the path, from `c_start` to `c_end` inclusive

    # Raises:
     - `ValueError` : if `c_end` cannot be reached from `c_start`
    """
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    # cost of cheapest path to node from start currently known
    g_score: dict[CoordTup, float] = {c_start: 0.0}
    # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
    f_score: dict[CoordTup, float] = {c_start: self.heuristic(c_start, c_end)}

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    # nodes to be evaluated -- a set holding the start tuple itself, not its ints
    open_vtx: set[CoordTup] = {c_start}
    # node immediately preceding each node in the path (currently known shortest path)
    source: dict[CoordTup, CoordTup] = dict()

    while open_vtx:
        # get lowest f_score node (open nodes are already tuples)
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[c])

        # check if goal is reached
        if c_end == c_current:
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # ----------------------------------------------------------------------
            # this is the only return statement
            return np.array(path[::-1])
            # ----------------------------------------------------------------------

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)  # type: ignore[assignment]

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)

            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )
find the shortest path between two coordinates, using A*
def get_nodes(self) -> CoordArray:
    """return every coordinate of the maze grid

    # Returns:
     - `CoordArray` : shape `(n_rows * n_cols, 2)`, in row-major order
       (`(0, 0), (0, 1), ..., (1, 0), ...`)
    """
    n_rows: int = self.grid_shape[0]
    n_cols: int = self.grid_shape[1]
    # index grids of shape (2, n_rows, n_cols): flatten the spatial axes,
    # then transpose so each row is one (row, col) coordinate
    index_grids = np.indices((n_rows, n_cols))
    return index_grids.reshape(2, -1).T
return an array of all nodes (coordinates) in the maze
def get_connected_component(self) -> CoordArray:
    """get the largest (and assumed only nonsingular) connected component of the maze

    - if `generation_meta` is `None` or its `fully_connected` entry is truthy,
      every node of the grid is returned (via `get_nodes`)
    - otherwise the `visited_cells` entry of `generation_meta` is returned as an array

    TODO: other connected components?

    # Raises:
     - `ValueError` : if the maze is not marked fully connected and
       `generation_meta` has no `visited_cells` entry
    """
    meta = self.generation_meta
    if meta is None or meta.get("fully_connected", False):
        # the whole grid forms the single connected component
        return self.get_nodes()

    # not fully connected: rely on the cells recorded during generation
    visited_cells: set[CoordTup] | None = meta.get("visited_cells", None)
    if visited_cells is None:
        # TODO: dynamically generate visited_cells?
        err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
        raise ValueError(err_msg)

    return np.array(list(visited_cells))
get the largest (and assumed only nonsingular) connected component of the maze
TODO: other connected components?
def generate_random_path(  # noqa: C901
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> typing.Optional[CoordArray]:
    """return a path between randomly chosen start and end nodes within the connected component

    Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

    # Parameters:
    - `allowed_start : CoordList | None`
        a list (or array) of allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        a list (or array) of allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to be a deadend
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to be a deadend
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to ensure that the start and end point are not the same
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise an error if no valid start or end positions are found.
        if this is `False`, the function might return `None` and this must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray | None`
        a path between the selected start and end positions, or `None` if no valid
        endpoints were found and `except_on_no_valid_endpoint` is `False`

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """
    # we can't create a "path" in a single-node maze
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    # get connected component
    connected_component: CoordArray = self.get_connected_component()

    # initialize start and end positions
    positions: Int[np.int8, "2 2"]

    # if no special conditions on start and end positions
    # NOTE: identity/truthiness checks are used instead of comparing a tuple of the arguments
    # against `(None, None, False, False)`: that comparison calls `==` on `allowed_start` /
    # `allowed_end`, which raises `ValueError` when they are numpy arrays
    if (
        allowed_start is None
        and allowed_end is None
        and not deadend_start
        and not deadend_end
    ):
        try:
            # sampling without replacement already guarantees distinct endpoints
            positions = connected_component[  # type: ignore[assignment]
                np.random.choice(
                    len(connected_component),
                    size=2,
                    replace=False,
                )
            ]
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # handle special conditions
    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
    # copy connected component set
    allowed_start_set: set[CoordTup] = connected_component_set.copy()
    allowed_end_set: set[CoordTup] = connected_component_set.copy()

    # filter by explicitly allowed start and end positions
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    if allowed_start is not None:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    if allowed_end is not None:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    def _is_deadend(coord: CoordTup) -> bool:
        # a deadend has exactly one connected neighbor
        return len(self.get_coord_neighbors(coord)) == 1

    # filter by forcing deadends
    if deadend_start:
        allowed_start_set = set(filter(_is_deadend, allowed_start_set))

    if deadend_end:
        allowed_end_set = set(filter(_is_deadend, allowed_end_set))

    # check we have valid positions
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            )
        return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
        )
        if endpoints_not_equal:
            # remove start position from end positions
            allowed_end_set.discard(start_pos)
        # `np.random.randint(0, 0)` raises `ValueError` if the end set became empty
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            ) from e
        return None

    return self.find_shortest_path(start_pos, end_pos)
return a path between randomly chosen start and end nodes within the connected component
Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
Parameters:
allowed_start : CoordList | None
a list of allowed start positions. If `None`, any position in the connected component is allowed (defaults to `None`)
allowed_end : CoordList | None
a list of allowed end positions. If `None`, any position in the connected component is allowed (defaults to `None`)
deadend_start : bool
whether to force the start position to be a deadend (defaults to `False`)
deadend_end : bool
whether to force the end position to be a deadend (defaults to `False`)
endpoints_not_equal : bool
whether to ensure that the start and end point are not the same (defaults to `False`)
except_on_no_valid_endpoint : bool
whether to raise an error if no valid start or end positions are found. If this is `False`, the function might return `None` and this must be handled by the caller (defaults to `True`)
Returns:
CoordArray
a path between the selected start and end positions
Raises:
NoValidEndpointException: if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
def as_adj_list(
    self,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end coord"]:
    """return the maze as an adjacency list

    thin wrapper: forwards `self.connection_list` and both shuffle flags to
    `maze_dataset.token_utils.connection_list_to_adj_list`
    """
    adj_list = connection_list_to_adj_list(
        self.connection_list,
        shuffle_d0,
        shuffle_d1,
    )
    return adj_list
return the maze as an adjacency list, wraps maze_dataset.token_utils.connection_list_to_adj_list
@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """create a LatticeMaze from a list of connections

    each entry of `adj_list` is a pair `(c_start, c_end)` of coordinates, which must be
    adjacent: equal along one axis and differing by exactly 1 along the other.

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

    # Raises:
    - `ValueError` : if any connection is not between two adjacent coordinates
    """
    # this is where it would probably break for rectangular mazes
    # cast to a python int first: `adj_list` is typically int8, where `max() + 1` could wrap around
    grid_n: int = int(adj_list.max()) + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for c_start, c_end in adj_list:
        # exactly 1 coordinate must match, and the other must differ by exactly 1
        # (widen to int before subtracting to avoid int8 wraparound)
        step: int = int(
            np.abs(np.asarray(c_start, dtype=int) - np.asarray(c_end, dtype=int)).sum()
        )
        if (c_start == c_end).sum() != 1 or step != 1:
            raise ValueError("invalid connection")

        # get the direction
        d: int = (c_start != c_end).argmax()

        x: int
        y: int
        # pick whichever has the lesser value in the direction `d`
        if c_start[d] < c_end[d]:
            x, y = c_start
        else:
            x, y = c_end

        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )
create a LatticeMaze from a list of connections
This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead

    emits a `TokenizerDeprecationWarning`, then returns
    `ADJLIST_START, c_s, CONNECTOR, c_e, ADJACENCY_ENDLINE, ..., ADJLIST_END`,
    with one `c_s, CONNECTOR, c_e, ADJACENCY_ENDLINE` group per entry of `as_adj_list()`
    """
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
    )
    tokens: list[str | CoordTup] = [SPECIAL_TOKENS.ADJLIST_START]
    for c_s, c_e in self.as_adj_list():
        tokens.extend(
            (
                tuple(c_s),
                SPECIAL_TOKENS.CONNECTOR,
                tuple(c_e),
                SPECIAL_TOKENS.ADJACENCY_ENDLINE,
            )
        )
    tokens.append(SPECIAL_TOKENS.ADJLIST_END)
    return tokens
(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead
def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """serialize maze and solution to tokens

    a `MazeTokenizerModular` tokenizes the maze itself via `to_tokens`; any other
    tokenizer (legacy `MazeTokenizer` or `TokenizationMode`) is handled by `self._as_tokens`
    """
    if not isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
    return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
serialize maze and solution to tokens
@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Constructs a maze from a tokenization.

    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.

    - a `TokenizationMode` is first converted with `to_legacy_tokenizer()`
    - `tokens` may also be given as a single whitespace-separated string
    - the actual decoding is delegated to `cls._from_tokens_AOTP`

    # Raises:
    - `NotImplementedError` : for a `MazeTokenizerModular` that is not legacy-equivalent,
      or for any tokenizer that is not AOTP
    """
    # HACK: type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

    is_modular: bool = isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
    if is_modular and not maze_tokenizer.is_legacy_equivalent():  # type: ignore[union-attr]
        err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        raise NotImplementedError(err_msg)

    # accept a whitespace-separated string as well as a token list
    token_list = tokens.split() if isinstance(tokens, str) else tokens

    if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        raise NotImplementedError("only AOTP tokenization is supported")
    return cls._from_tokens_AOTP(token_list, maze_tokenizer)  # type: ignore[arg-type]
Constructs a maze from a tokenization.
Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """convert the maze to a pixel grid

    - useful as a simpler way of plotting the maze than the more complex `MazePlot`
    - the same underlying representation as `as_ascii` but as an image
    - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data

    maze cell `(r, c)` is drawn at pixel `(2 * r + 1, 2 * c + 1)`.

    # Parameters:
    - `show_endpoints : bool`
        draw the start and end positions (`TargetedLatticeMaze` and `SolvedMaze` only)
        (defaults to `True`)
    - `show_solution : bool`
        draw the solution path (`SolvedMaze` only); requires `show_endpoints`
        (defaults to `True`)

    # Raises:
    - `ValueError` : if `show_solution` is `True` but `show_endpoints` is `False`
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")
    # convert original bool pixel grid to RGB: everything starts as wall, open pixels are recolored
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3),
        PixelColors.WALL,
        dtype=np.uint8,
    )
    pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

    # a plain `LatticeMaze` has no endpoints or solution to draw
    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set solution -- a `TargetedLatticeMaze` has no solution, so only `SolvedMaze` draws one
    if show_solution and self.__class__ != TargetedLatticeMaze:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
            next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # the pixel between two adjacent cells is one step from `coord` toward `next_coord`
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

    # set endpoints last, so the path never overwrites them. drawing depends only on
    # `show_endpoints`, for `SolvedMaze` as well as `TargetedLatticeMaze`
    if show_endpoints:
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    return pixel_grid
convert the maze to a pixel grid
- useful as a simpler way of plotting the maze than the more complex `MazePlot`
- the same underlying representation as `as_ascii` but as an image
- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    - a 2D (black/white) grid is always decoded as a plain `LatticeMaze`
    - an RGB grid is classified with `detect_pixels_type`, and `cls` must appear in the
      detected type's MRO; start, end and path pixels are then decoded, and the solution
      is ordered by walking connected neighbors from the start position to the end position

    # Raises:
    - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # binary pixel grid: only walls and open cells, so a regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # RGB pixel grid: make sure the requested class is compatible with what was drawn
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(err_msg)

    marker_colors = dict(
        start=PixelColors.START,
        end=PixelColors.END,
        solution=PixelColors.PATH,
    )
    connection_list, grid_shape, marked_pos = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=marker_colors,
    )

    # if we wanted a LatticeMaze, the markers are simply ignored
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # exactly one start pixel and one end pixel are expected
    start_pos_arr = marked_pos["start"]
    end_pos_arr = marked_pos["end"]
    assert start_pos_arr.shape == (1, 2), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (1, 2), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # with no intermediate path pixels, start and end must be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # candidate cells for the ordered path: the raw path pixels, plus the end position
    end_tup: CoordTup = tuple(end_pos)
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw]
    solution_raw_list.append(end_tup)

    # walk from the start, each time stepping to the single connected, unvisited candidate
    solution: list[CoordTup] = [tuple(start_pos)]
    while solution[-1] != end_tup:
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if tuple(coord) in solution_raw_list and tuple(coord) not in solution
            ],
        )
        # the path must be unambiguous: exactly one way forward
        assert neighbors_filtered.shape == (1, 2), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        solution.append(tuple(neighbors_filtered[0]))

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == end_tup, (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
create a LatticeMaze from a pixel grid. reverse of as_pixels
Raises:
ValueError: if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    the wall/open layout comes from `_as_ascii_grid`; start, end and path characters are
    then overlaid wherever `as_pixels` produced the color paired with them in
    `ASCII_PIXEL_PAIRINGS` (only for the characters enabled by the two flags).

    can be reversed with `LatticeMaze.from_ascii()`
    """
    ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    pixel_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    # only the requested marker characters may be overlaid
    overlay_chars: tuple = (
        (AsciiChars.START, AsciiChars.END) if show_endpoints else ()
    ) + ((AsciiChars.PATH,) if show_solution else ())

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char not in overlay_chars:
            continue
        color_mask = (pixel_grid == pixel_color).all(axis=-1)
        ascii_grid[color_mask] = ascii_char

    return "\n".join("".join(row) for row in ascii_grid)
return an ASCII grid of the maze
useful for debugging in the terminal, or as its own format
can be reversed with LatticeMaze.from_ascii()
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    """get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)

    each line is stripped, characters are mapped to colors through `ASCII_PIXEL_PAIRINGS`,
    and the resulting RGB grid is decoded with `cls.from_pixels`.
    """
    char_rows: list[list[str]] = [
        list(line.strip()) for line in ascii_str.strip().split("\n")
    ]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(char_rows, dtype=str)

    # characters without a pairing are left as (0, 0, 0)
    pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        pixel_grid[ascii_grid == ascii_char] = pixel_color

    return cls.from_pixels(pixel_grid)
get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)
def serialize(self) -> dict[str, Any]:
    """returns the class as a dict, implemented by using the `@serializable_dataclass` decorator

    - the result always contains `_FORMAT_KEY`, set to `"<ClassName>(SerializableDataclass)"`
    - every field with `field.serialize` set is stored under its name: a `SerializableDataclass`
      value is serialized; then, if the (possibly already serialized) value has a callable
      `.serialize`, it is called, otherwise `field.serialization_fn` is applied if set
    - every name in `self._properties_to_serialize` is stored as well

    NOTE(review): `cls` below is not a parameter -- presumably it is the decorated class,
    captured from the enclosing `serializable_dataclass` scope; verify.
    NOTE(review): for a `SerializableDataclass` value with a `serialization_fn`, the function is
    applied to the already-serialized dict -- confirm this is intended.

    # Raises:
    - `NotSerializableFieldException` : if a field is not a `SerializableField`
    - `FieldSerializationError` : if serializing a field fails
    - `AttributeError` : if a property in `_properties_to_serialize` does not exist on the class
    """
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            # reset first, so the error report below never shows an unbound or stale value
            value = None
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f" but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> T:
    """takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator

    - if `data` is already an instance of `cls`, it is returned unchanged
    - for each init-able field present in `data`, the value is passed through the field's
      `deserialize_fn`, else its `loading_fn` (which receives the whole `data` dict), else
      the `load` classmethod of the field's type hint (only for dict values)
    - field types of the constructed instance are then validated unless
      `on_typecheck_mismatch` is `ErrorMode.IGNORE`

    NOTE(review): `on_typecheck_mismatch` and `on_typecheck_error` are not parameters --
    presumably captured from the enclosing `serializable_dataclass` decorator; verify.

    # Raises:
    - `AssertionError` : if `data` is not a `Mapping`, or a field is not a `SerializableField`
    - `FieldLoadingError` : if a field with a loadable type hint receives a non-dict value
    - `FieldTypeMismatchError` : on type mismatches, depending on `on_typecheck_mismatch`
    """
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that (it receives the whole dict)
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`

    delegates to `SerializableDataclass__validate_fields_types__dict` and returns `True`
    only if every per-field result is truthy (an empty result counts as valid)
    """
    per_field_validity: dict[str, bool] = (
        SerializableDataclass__validate_fields_types__dict(
            self,
            on_typecheck_error=on_typecheck_error,
        )
    )
    return all(per_field_validity.values())
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position

    `start_pos` and `end_pos` are converted to numpy arrays on construction and must lie
    inside the grid (`0 <= pos[i] < grid_shape[i]`).
    """

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        "post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
        # check existence *before* converting: `np.array(None)` is not `None`,
        # so an assertion placed after the conversion could never fail
        assert self.start_pos is not None
        assert self.end_pos is not None
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        # check that start and end are in bounds
        # (negative coordinates must be rejected too -- they would silently wrap when indexing)
        if (
            self.start_pos[0] < 0
            or self.start_pos[1] < 0
            or self.start_pos[0] >= self.grid_shape[0]
            or self.start_pos[1] >= self.grid_shape[1]
        ):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if (
            self.end_pos[0] < 0
            or self.end_pos[1] < 0
            or self.end_pos[0] >= self.grid_shape[0]
            or self.end_pos[1] >= self.grid_shape[1]
        ):
            err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
A LatticeMaze with a start and end position
def get_start_pos_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) return the start position as a list of tokens

    emits a `TokenizerDeprecationWarning`, then delegates to `_get_start_pos_tokens`
    """
    deprecation_msg: str = "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release."
    warnings.warn(deprecation_msg, TokenizerDeprecationWarning)
    return self._get_start_pos_tokens()
(deprecated!) return the start position as a list of tokens
def get_end_pos_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) return the end position as a list of tokens

    emits a `TokenizerDeprecationWarning`, then delegates to `_get_end_pos_tokens`
    """
    deprecation_msg: str = "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release."
    warnings.warn(deprecation_msg, TokenizerDeprecationWarning)
    return self._get_end_pos_tokens()
(deprecated!) return the end position as a list of tokens
@classmethod
def from_lattice_maze(
    cls,
    lattice_maze: LatticeMaze,
    start_pos: Coord | CoordTup,
    end_pos: Coord | CoordTup,
) -> "TargetedLatticeMaze":
    """get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions

    the connection list and `generation_meta` of `lattice_maze` are reused as-is;
    both positions are converted with `np.array`.
    """
    init_kwargs = dict(
        connection_list=lattice_maze.connection_list,
        start_pos=np.array(start_pos),
        end_pos=np.array(end_pos),
        generation_meta=lattice_maze.generation_meta,
    )
    return cls(**init_kwargs)
get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions
def serialize(self) -> dict[str, Any]:
    """returns the instance as a dict, implemented by the `@serializable_dataclass` decorator

    for every field with `serialize=True`, the value is converted as follows
    (first match wins):
    1. a `SerializableDataclass` value, or any value with a callable
       `.serialize()` method, is converted with that method
    2. otherwise, if the field defines a `serialization_fn`, it is applied
    3. otherwise the raw value is stored as-is

    every name in `self._properties_to_serialize` is then read and stored as well.

    # Raises:
    - `NotSerializableFieldException` : a dataclass field is not a `SerializableField`
    - `FieldSerializationError` : reading or converting a field failed (chained to the original error)
    - `AttributeError` : a name in `_properties_to_serialize` is not an attribute of the class
    """
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        if not field.serialize:
            continue

        # reset for every field, so the error report below never shows a stale
        # value from a previous field (or hits an unbound name if `getattr` fails)
        value: Any = None
        try:
            value = getattr(self, field.name)
            # a value's own serialization takes precedence over the field's
            # `serialization_fn` -- it would be nice to be able to override a class's
            # `.serialize()`, but that could lead to some inconsistencies!
            # this is a single if/elif chain so an already-serialized value is
            # never passed on to `serialization_fn` as well
            if isinstance(value, SerializableDataclass):
                value = value.serialize()
            elif hasattr(value, "serialize") and callable(value.serialize):
                value = value.serialize()
            elif field.serialization_fn:
                value = field.serialization_fn(value)

            # store the value in the result
            result[field.name] = value
        except Exception as e:
            raise FieldSerializationError(
                "\n".join(
                    [
                        f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                        f"{field = }",
                        f"{value = }",
                        f"{self = }",
                    ]
                )
            ) from e

    # store each property if we can get it
    # NOTE(review): `cls` is not defined in this function; presumably it is the
    # decorated class from the enclosing decorator scope -- confirm
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            result[prop] = getattr(self, prop)
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
returns the instance as a dict; implemented by the @serializable_dataclass decorator
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> T:
    """takes in an appropriately structured dict and returns an instance of the class,
    implemented by the `@serializable_dataclass` decorator

    # Parameters:
    - `data : dict[str, Any] | T`
        a mapping of field names to serialized values, or an already-loaded
        instance of `cls` (returned unchanged)

    # Returns:
    - `T`
        an instance of `cls` (an instance, not a type)

    # Raises:
    - `AssertionError` : `data` is not a `Mapping`, or a field is not a `SerializableField`
    - `FieldLoadingError` : a field's type hint has a `.load` method but the stored value is not a `Mapping`
    - on a field type mismatch, whatever `on_typecheck_mismatch.process(..., except_cls=FieldTypeMismatchError)` raises
    """
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that (note: it receives the whole `data`)
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that.
                # accept any Mapping, matching what `load` itself asserts above
                if isinstance(value, typing.Mapping):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a Mapping\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: T = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    # NOTE(review): `on_typecheck_mismatch` / `on_typecheck_error` are not defined in this
    # function; presumably they come from the enclosing decorator's arguments -- confirm
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns `False` as soon as any entry of the per-field result dict from
    `SerializableDataclass__validate_fields_types__dict` is falsy, else `True`.
    """
    per_field_results: dict[str, bool] = (
        SerializableDataclass__validate_fields_types__dict(
            self,
            on_typecheck_error=on_typecheck_error,
        )
    )
    for is_valid in per_field_results.values():
        if not is_valid:
            return False
    return True
validate the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution

    when the solution is valid, `start_pos` and `end_pos` are taken from its
    first and last coordinates.
    """

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        # Parameters:
        - `connection_list : ConnectionList`
            passed through to `TargetedLatticeMaze.__init__`
        - `solution : CoordArray`
            the path through the maze, converted with `np.array`. it is valid if it
            is 2-dimensional with at least one row and exactly 2 columns (a single
            row is allowed, since start and end may coincide)
        - `generation_meta : dict | None`
            passed through to `TargetedLatticeMaze.__init__`
        - `start_pos : Coord | None`, `end_pos : Coord | None`
            optional expected endpoints; when given (and `allow_invalid` is False)
            they must equal the first / last coordinate of `solution`
        - `allow_invalid : bool`
            skip the solution-validity and endpoint checks (default: `False`)

        # Raises:
        - `ValueError` : the solution is invalid and `allow_invalid` is False
        - `AssertionError` : a given `start_pos` / `end_pos` does not match the solution
        """
        # figure out whether the solution is usable
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # a path length of 1 is valid, since the start and end pos could be the same.
            # `ndim` is checked first so 1-D input is reported as invalid instead of
            # raising `IndexError`, and higher-rank input is no longer accepted
            if (
                solution.ndim == 2  # noqa: PLR2004
                and solution.shape[0] > 0
                and solution.shape[1] == 2  # noqa: PLR2004
            ):
                solution_valid = True

        if not solution_valid and not allow_invalid:
            # `solution` may still be `None` here; an unconditional `.shape` access
            # would raise `AttributeError` instead of the intended `ValueError`
            solution_shape = None if solution is None else solution.shape
            err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze, with endpoints taken from the solution
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        # the dataclass is frozen, so the field is written through `__dict__`
        self.__dict__["solution"] = solution

        # check any explicitly given endpoints against the solution
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        # path-start token, every solution coordinate as a tuple, then path-end token
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`SolvedMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
            # point the warning at the caller, not at this line
            stacklevel=2,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
            # point the warning at the attribute access, not at this line
            stacklevel=2,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if `solution` is `None`, the shortest path between the maze's
        `start_pos` and `end_pos` is used.
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork

        # Parameters:
        - `always_include_endpoints : bool`
            if True, the first and last coordinates are always included
            (default: `False`)

        # Returns:
        - `tuple[list[int], CoordArray]`
            indices into `self.solution`, and the corresponding rows of
            `self.solution`; the array has shape `(k, 2)`, including `(0, 2)`
            when there are no forks
        """
        last_idx: int = self.solution.shape[0] - 1
        output_idxs: list[int] = []

        for idx, coord in enumerate(self.solution):
            # an endpoint is a fork if it has more than one neighbor; an interior
            # coord needs more than 2, since the previous coord doesn't count as a choice
            is_endpoint: bool = idx in (0, last_idx)
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)

        # index the solution directly so the coordinates keep shape `(k, 2)` and the
        # solution's dtype even when no forks were found
        return output_idxs, self.solution[output_idxs]

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # `np.delete` returns an ndarray rather than the annotated `list[int]`
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
Stores a maze and a solution
def __init__(
    self,
    connection_list: ConnectionList,
    solution: CoordArray,
    generation_meta: dict | None = None,
    start_pos: Coord | None = None,
    end_pos: Coord | None = None,
    allow_invalid: bool = False,
) -> None:
    """Create a SolvedMaze from a connection list and a solution

    # Parameters:
    - `connection_list : ConnectionList`
        passed through to `TargetedLatticeMaze.__init__`
    - `solution : CoordArray`
        the path through the maze, converted with `np.array`. it is valid if it
        is 2-dimensional with at least one row and exactly 2 columns (a single
        row is allowed, since start and end may coincide)
    - `generation_meta : dict | None`
        passed through to `TargetedLatticeMaze.__init__`
    - `start_pos : Coord | None`, `end_pos : Coord | None`
        optional expected endpoints; when given (and `allow_invalid` is False)
        they must equal the first / last coordinate of `solution`
    - `allow_invalid : bool`
        skip the solution-validity and endpoint checks (default: `False`)

    # Raises:
    - `ValueError` : the solution is invalid and `allow_invalid` is False
    - `AssertionError` : a given `start_pos` / `end_pos` does not match the solution
    """
    # figure out whether the solution is usable
    solution_valid: bool = False
    if solution is not None:
        solution = np.array(solution)
        # a path length of 1 is valid, since the start and end pos could be the same.
        # `ndim` is checked first so 1-D input is reported as invalid instead of
        # raising `IndexError`, and higher-rank input is no longer accepted
        if (
            solution.ndim == 2  # noqa: PLR2004
            and solution.shape[0] > 0
            and solution.shape[1] == 2  # noqa: PLR2004
        ):
            solution_valid = True

    if not solution_valid and not allow_invalid:
        # `solution` may still be `None` here; an unconditional `.shape` access
        # would raise `AttributeError` instead of the intended `ValueError`
        solution_shape = None if solution is None else solution.shape
        err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
        raise ValueError(
            err_msg,
            f"{connection_list = }",
        )

    # init the TargetedLatticeMaze, with endpoints taken from the solution
    super().__init__(
        connection_list=connection_list,
        generation_meta=generation_meta,
        # TODO: the argument type is stricter than the expected type but it still fails?
        # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
        # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
        start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
        end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
    )

    # the dataclass is frozen, so the field is written through `__dict__`
    self.__dict__["solution"] = solution

    # check any explicitly given endpoints against the solution
    if not allow_invalid:
        if start_pos is not None:
            assert np.array_equal(np.array(start_pos), self.start_pos), (
                f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
            )
        if end_pos is not None:
            assert np.array_equal(np.array(end_pos), self.end_pos), (
                f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
            )
        # TODO: assert the path does not backtrack, walk through walls, etc?
Create a SolvedMaze from a connection list and a solution
DOCS: better documentation for this init method
def get_solution_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) return the solution as a list of tokens

    emits a `TokenizerDeprecationWarning`, then delegates to the private
    `_get_solution_tokens` method and returns its result unchanged.
    """
    warnings.warn(
        # this method lives on `SolvedMaze`, not `LatticeMaze`
        "`SolvedMaze.get_solution_tokens` is deprecated.",
        TokenizerDeprecationWarning,
        # attribute the warning to the caller, not to this line
        stacklevel=2,
    )
    return self._get_solution_tokens()
(deprecated!) return the solution as a list of tokens
@property
def maze(self) -> LatticeMaze:
    """(deprecated!) return the maze without the solution

    emits a `DeprecationWarning` and returns a new `LatticeMaze` built from
    this object's `connection_list` only.
    """
    warnings.warn(
        "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
        DeprecationWarning,
        # attribute the warning to the attribute access in the caller's code
        stacklevel=2,
    )
    return LatticeMaze(connection_list=self.connection_list)
(deprecated!) return the maze without the solution
# the signature differs from the parent's `from_lattice_maze`, hence the override ignore
@classmethod
def from_lattice_maze(  # type: ignore[override]
    cls,
    lattice_maze: LatticeMaze,
    solution: list[CoordTup] | CoordArray,
) -> "SolvedMaze":
    """get a `SolvedMaze` from a `LatticeMaze` by specifying a solution

    `solution` is converted with `np.array`; the connection list and the
    generation metadata are taken directly from `lattice_maze`.
    """
    solution_arr = np.array(solution)
    return cls(
        connection_list=lattice_maze.connection_list,
        solution=solution_arr,
        generation_meta=lattice_maze.generation_meta,
    )
get a SolvedMaze from a LatticeMaze by specifying a solution
@classmethod
def from_targeted_lattice_maze(
    cls,
    targeted_lattice_maze: TargetedLatticeMaze,
    solution: list[CoordTup] | CoordArray | None = None,
) -> "SolvedMaze":
    """solves the given targeted lattice maze and returns a SolvedMaze

    when `solution` is `None`, the path returned by `find_shortest_path`
    between the maze's `start_pos` and `end_pos` is used; otherwise the given
    solution is used as-is.
    """
    path = (
        solution
        if solution is not None
        else targeted_lattice_maze.find_shortest_path(
            targeted_lattice_maze.start_pos,
            targeted_lattice_maze.end_pos,
        )
    )
    return cls(
        connection_list=targeted_lattice_maze.connection_list,
        solution=np.array(path),
        generation_meta=targeted_lattice_maze.generation_meta,
    )
solves the given targeted lattice maze and returns a SolvedMaze
def get_solution_forking_points(
    self,
    always_include_endpoints: bool = False,
) -> tuple[list[int], CoordArray]:
    """coordinates and their indices from the solution where a fork is present

    - if the start point is not a dead end, this counts as a fork
    - if the end point is not a dead end, this counts as a fork

    # Parameters:
    - `always_include_endpoints : bool`
        if True, the first and last coordinates are always included
        (default: `False`)

    # Returns:
    - `tuple[list[int], CoordArray]`
        indices into `self.solution`, and the corresponding rows of
        `self.solution`; the array has shape `(k, 2)`, including `(0, 2)`
        when there are no forks
    """
    last_idx: int = self.solution.shape[0] - 1
    output_idxs: list[int] = []

    for idx, coord in enumerate(self.solution):
        # an endpoint is a fork if it has more than one neighbor; an interior
        # coord needs more than 2, since the previous coord doesn't count as a choice
        is_endpoint: bool = idx in (0, last_idx)
        threshold: int = 1 if is_endpoint else 2
        if self.get_coord_neighbors(coord).shape[0] > threshold or (
            is_endpoint and always_include_endpoints
        ):
            output_idxs.append(idx)

    # index the solution directly so the coordinates keep shape `(k, 2)` and the
    # solution's dtype even when no forks were found (`np.array([])` would be `(0,)`)
    return output_idxs, self.solution[output_idxs]
coordinates and their indices from the solution where a fork is present
- if the start point is not a dead end, this counts as a fork
- if the end point is not a dead end, this counts as a fork
def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
    """coordinates from the solution where there is only a single (non-backtracking) point to move to

    returns the complement of `get_solution_forking_points` from the path: the
    indices of `self.solution` that are not forks, and the matching rows.
    """
    fork_indices = self.get_solution_forking_points()[0]
    all_indices = np.arange(self.solution.shape[0])
    # `np.delete` returns an ndarray rather than the annotated `list[int]`, hence the ignore
    non_fork_indices = np.delete(all_indices, fork_indices, axis=0)
    non_fork_coords = np.delete(self.solution, fork_indices, axis=0)
    return non_fork_indices, non_fork_coords  # type: ignore[return-value]
coordinates from the solution where there is only a single (non-backtracking) point to move to
returns the complement of get_solution_forking_points from the path
def serialize(self) -> dict[str, Any]:
    """returns the instance as a dict, implemented by the `@serializable_dataclass` decorator

    for every field with `serialize=True`, the value is converted as follows
    (first match wins):
    1. a `SerializableDataclass` value, or any value with a callable
       `.serialize()` method, is converted with that method
    2. otherwise, if the field defines a `serialization_fn`, it is applied
    3. otherwise the raw value is stored as-is

    every name in `self._properties_to_serialize` is then read and stored as well.

    # Raises:
    - `NotSerializableFieldException` : a dataclass field is not a `SerializableField`
    - `FieldSerializationError` : reading or converting a field failed (chained to the original error)
    - `AttributeError` : a name in `_properties_to_serialize` is not an attribute of the class
    """
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        if not field.serialize:
            continue

        # reset for every field, so the error report below never shows a stale
        # value from a previous field (or hits an unbound name if `getattr` fails)
        value: Any = None
        try:
            value = getattr(self, field.name)
            # a value's own serialization takes precedence over the field's
            # `serialization_fn` -- it would be nice to be able to override a class's
            # `.serialize()`, but that could lead to some inconsistencies!
            # this is a single if/elif chain so an already-serialized value is
            # never passed on to `serialization_fn` as well
            if isinstance(value, SerializableDataclass):
                value = value.serialize()
            elif hasattr(value, "serialize") and callable(value.serialize):
                value = value.serialize()
            elif field.serialization_fn:
                value = field.serialization_fn(value)

            # store the value in the result
            result[field.name] = value
        except Exception as e:
            raise FieldSerializationError(
                "\n".join(
                    [
                        f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                        f"{field = }",
                        f"{value = }",
                        f"{self = }",
                    ]
                )
            ) from e

    # store each property if we can get it
    # NOTE(review): `cls` is not defined in this function; presumably it is the
    # decorated class from the enclosing decorator scope -- confirm
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            result[prop] = getattr(self, prop)
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__} "
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
returns the instance as a dict; implemented by the @serializable_dataclass decorator
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> T:
    """takes in an appropriately structured dict and returns an instance of the class,
    implemented by the `@serializable_dataclass` decorator

    # Parameters:
    - `data : dict[str, Any] | T`
        a mapping of field names to serialized values, or an already-loaded
        instance of `cls` (returned unchanged)

    # Returns:
    - `T`
        an instance of `cls` (an instance, not a type)

    # Raises:
    - `AssertionError` : `data` is not a `Mapping`, or a field is not a `SerializableField`
    - `FieldLoadingError` : a field's type hint has a `.load` method but the stored value is not a `Mapping`
    - on a field type mismatch, whatever `on_typecheck_mismatch.process(..., except_cls=FieldTypeMismatchError)` raises
    """
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that (note: it receives the whole `data`)
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that.
                # accept any Mapping, matching what `load` itself asserts above
                if isinstance(value, typing.Mapping):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a Mapping\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: T = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    # NOTE(review): `on_typecheck_mismatch` / `on_typecheck_error` are not defined in this
    # function; presumably they come from the enclosing decorator's arguments -- confirm
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

    returns `False` as soon as any entry of the per-field result dict from
    `SerializableDataclass__validate_fields_types__dict` is falsy, else `True`.
    """
    per_field_results: dict[str, bool] = (
        SerializableDataclass__validate_fields_types__dict(
            self,
            on_typecheck_error=on_typecheck_error,
        )
    )
    for is_valid in per_field_results.values():
        if not is_valid:
            return False
    return True
validate the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field
Inherited Members
- LatticeMaze
- connection_list
- generation_meta
- lattice_dim
- grid_shape
- n_connections
- grid_n
- heuristic
- nodes_connected
- is_valid_path
- coord_degrees
- get_coord_neighbors
- gen_connected_component_from
- find_shortest_path
- get_nodes
- get_connected_component
- generate_random_path
- as_adj_list
- from_adj_list
- as_adj_list_tokens
- as_tokens
- from_tokens
- as_pixels
- from_pixels
- as_ascii
- from_ascii
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """Detects the type of pixels data by checking for the presence of start and end pixels

    - no start pixel and no end pixel: `LatticeMaze`
    - a start or end pixel, plus a path pixel: `SolvedMaze`
    - a start or end pixel, but no path pixel: `TargetedLatticeMaze`
    """
    has_endpoint = color_in_pixel_grid(data, PixelColors.START) or color_in_pixel_grid(
        data,
        PixelColors.END,
    )
    if not has_endpoint:
        return LatticeMaze
    if color_in_pixel_grid(data, PixelColors.PATH):
        return SolvedMaze
    return TargetedLatticeMaze
Detects the type of pixels data by checking for the presence of start and end pixels