Coverage for maze_dataset/maze/lattice_maze.py: 65%
508 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:51 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:51 -0600
1"""Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses.
3also includes basic utilities, including converting to/from ascii and pixel representations.
4"""
6import typing
7import warnings
8from dataclasses import dataclass
9from itertools import chain
11import numpy as np
12from jaxtyping import Bool, Int, Int8, Shaped
13from muutils.json_serialize.serializable_dataclass import (
14 SerializableDataclass,
15 serializable_dataclass,
16 serializable_field,
17)
18from muutils.misc import isinstance_by_type_name, list_split
20from maze_dataset.constants import (
21 NEIGHBORS_MASK,
22 SPECIAL_TOKENS,
23 ConnectionList,
24 Coord,
25 CoordArray,
26 CoordList,
27 CoordTup,
28)
29from maze_dataset.token_utils import (
30 TokenizerDeprecationWarning,
31 connection_list_to_adj_list,
32 get_adj_list_tokens,
33 get_origin_tokens,
34 get_path_tokens,
35 get_target_tokens,
36)
38if typing.TYPE_CHECKING:
39 from maze_dataset.tokenization import (
40 MazeTokenizer,
41 MazeTokenizerModular,
42 TokenizationMode,
43 )
# color triple; the `PixelColors` constants in this module are annotated with this alias
RGB = tuple[int, int, int]
"rgb tuple of values 0-255"

# integer array annotated with three named axes: two grid axes plus the color channel
PixelGrid = Int[np.ndarray, "x y rgb"]
"rgb grid of pixels"
# boolean array annotated with the two grid axes only (no color channel)
BinaryPixelGrid = Bool[np.ndarray, "x y"]
"boolean grid of pixels"

# named constant for the literal 2, used in this module's dimensionality checks
DIM_2: int = 2
"2 dimensions"
class NoValidEndpointException(Exception):  # noqa: N818
    """Signals that a maze offers no usable start or end position for a path."""
def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList:
    """Close off the far boundary of the lattice by writing `False` into it.

    The array is modified in place and also returned.

    # Parameters:
    - `connection_list : ConnectionList`
        connection array whose first axis indexes the direction (0 = down, 1 = right)

    # Returns:
    - `ConnectionList`
        the same array object that was passed in

    # Raises:
    - `NotImplementedError` : if the first axis has more than two entries
        (entries 0 and 1 have already been modified by the time this is raised)
    """
    n_dims: int = connection_list.shape[0]
    for dim in range(n_dims):
        if dim > 1:
            err_msg: str = f"only 2d lattices supported. got {dim=}"
            raise NotImplementedError(err_msg)
        # dim 0 (down): blank the last row; dim 1 (right): blank the last column
        edge = (dim, -1, slice(None)) if dim == 0 else (dim, slice(None), -1)
        connection_list[edge] = False
    return connection_list
def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
    """Report whether any pixel of `pixel_grid` equals `color` in every channel.

    Pixels are visited row by row and the scan stops at the first match.
    """
    return any(
        np.all(pixel == color)
        for row in pixel_grid
        for pixel in row
    )
@dataclass(frozen=True)
class PixelColors:
    "standard colors for pixel grids"

    # each value is an (R, G, B) triple; these are used as class-level constants
    WALL: RGB = (0, 0, 0)  # black
    OPEN: RGB = (255, 255, 255)  # white
    START: RGB = (0, 255, 0)  # green
    END: RGB = (255, 0, 0)  # red
    PATH: RGB = (0, 0, 255)  # blue
@dataclass(frozen=True)
class AsciiChars:
    "standard ascii characters for mazes"

    # one character per cell kind; the names mirror the fields of `PixelColors`
    WALL: str = "#"
    OPEN: str = " "
    START: str = "S"
    END: str = "E"
    PATH: str = "X"
# pairs every `AsciiChars` character with the `PixelColors` color of the same name
ASCII_PIXEL_PAIRINGS: dict[str, RGB] = {
    AsciiChars.WALL: PixelColors.WALL,
    AsciiChars.OPEN: PixelColors.OPEN,
    AsciiChars.START: PixelColors.START,
    AsciiChars.END: PixelColors.END,
    AsciiChars.PATH: PixelColors.PATH,
}
"map ascii characters to pixel colors"
119@serializable_dataclass(
120 frozen=True,
121 kw_only=True,
122 properties_to_serialize=["lattice_dim", "generation_meta"],
123)
124class LatticeMaze(SerializableDataclass):
125 """lattice maze (nodes on a lattice, connections only to neighboring nodes)
127 Connection List represents which nodes (N) are connected in each direction.
129 First and second elements represent downward and rightward connections,
130 respectively.
132 Example:
133 Connection list:
134 [
135 [ # down
136 [F T],
137 [F F]
138 ],
139 [ # right
140 [T F],
141 [T F]
142 ]
143 ]
145 Nodes with connections
146 N T N F
147 F T
148 N T N F
149 F F
151 Graph:
152 N - N
153 |
154 N - N
156 Note: the bottom row connections going down, and the
157 right-hand connections going right, will always be False.
159 """
    # boolean connection array; first axis is the direction (index 0 is treated as
    # "down" and index 1 as "right" by the pixel conversion code), remaining axes are the grid
    connection_list: ConnectionList
    # optional dict excluded from equality (`compare=False`); `get_connected_component`
    # reads its `fully_connected` and `visited_cells` keys
    generation_meta: dict | None = serializable_field(default=None, compare=False)

    # length of the direction axis of `connection_list`
    lattice_dim = property(lambda self: self.connection_list.shape[0])
    # shape of `connection_list` without the leading direction axis
    grid_shape = property(lambda self: self.connection_list.shape[1:])
    # sum over all entries of `connection_list`, i.e. the number of `True` entries
    n_connections = property(lambda self: self.connection_list.sum())
168 @property
169 def grid_n(self) -> int:
170 "grid size as int, raises `AssertionError` if not square"
171 assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
172 return self.grid_shape[0]
174 # ============================================================
175 # basic methods
176 # ============================================================
    def __eq__(self, other: object) -> bool:
        """equality check calls super

        Forwards unchanged to the parent class's `__eq__`; no comparison logic is added here.
        NOTE(review): presumably spelled out so that an explicit `__eq__` sits next to
        the custom `__hash__` defined on this class -- confirm.
        """
        return super().__eq__(other)
182 @staticmethod
183 def heuristic(a: CoordTup, b: CoordTup) -> float:
184 """return manhattan distance between two points"""
185 return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
187 def __hash__(self) -> int:
188 """hash the connection list by converting connection list to bytes"""
189 return hash(self.connection_list.tobytes())
191 def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
192 """returns whether two nodes are connected"""
193 delta: Coord = b - a
194 if np.abs(delta).sum() != 1:
195 # return false if not even adjacent
196 return False
197 else:
198 # test for wall
199 dim: int = int(np.argmax(np.abs(delta)))
200 clist_node: Coord = a if (delta.sum() > 0) else b
201 return self.connection_list[dim, clist_node[0], clist_node[1]]
203 def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
204 """check if a path is valid"""
205 # check path is not empty
206 if len(path) == 0:
207 return empty_is_valid
209 # check all coords in bounds of maze
210 if not np.all((path >= 0) & (path < self.grid_shape)):
211 return False
213 # check all nodes connected
214 for i in range(len(path) - 1):
215 if not self.nodes_connected(path[i], path[i + 1]):
216 return False
217 return True
219 def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
220 """Returns an array with the connectivity degree of each coord.
222 I.e., how many neighbors each coord has.
223 """
224 int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
225 self.connection_list.astype(np.int8)
226 )
227 degrees: Int8[np.ndarray, "row col"] = np.sum(
228 int_conn,
229 axis=0,
230 ) # Connections to east and south
231 degrees[:, 1:] += int_conn[1, :, :-1] # Connections to west
232 degrees[1:, :] += int_conn[0, :-1, :] # Connections to north
233 return degrees
235 def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
236 """Returns an array of the neighboring, connected coords of `c`."""
237 c = np.array(c) # type: ignore[assignment]
238 neighbors: list[Coord] = [
239 neighbor
240 for neighbor in (c + NEIGHBORS_MASK)
241 if (
242 (0 <= neighbor[0] < self.grid_shape[0]) # in x bounds
243 and (0 <= neighbor[1] < self.grid_shape[1]) # in y bounds
244 and self.nodes_connected(c, neighbor) # connected
245 )
246 ]
248 output: CoordArray = np.array(neighbors)
249 if len(neighbors) > 0:
250 assert output.shape == (
251 len(neighbors),
252 2,
253 ), (
254 f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
255 )
256 return output
258 def gen_connected_component_from(self, c: Coord) -> CoordArray:
259 """return the connected component from a given coordinate"""
260 # Stack for DFS
261 stack: list[Coord] = [c]
263 # Set to store visited nodes
264 visited: set[CoordTup] = set()
266 while stack:
267 current_node: Coord = stack.pop()
268 # this is fine since we know current_node is a coord and thus of length 2
269 visited.add(tuple(current_node)) # type: ignore[arg-type]
271 # Get the neighbors of the current node
272 neighbors = self.get_coord_neighbors(current_node)
274 # Iterate over neighbors
275 for neighbor in neighbors:
276 if tuple(neighbor) not in visited:
277 stack.append(neighbor)
279 return np.array(list(visited))
281 def find_shortest_path(
282 self,
283 c_start: CoordTup | Coord,
284 c_end: CoordTup | Coord,
285 ) -> CoordArray:
286 """find the shortest path between two coordinates, using A*"""
287 c_start = tuple(c_start) # type: ignore[assignment]
288 c_end = tuple(c_end) # type: ignore[assignment]
290 g_score: dict[CoordTup, float] = (
291 dict()
292 ) # cost of cheapest path to node from start currently known
293 f_score: dict[CoordTup, float] = {
294 c_start: 0.0,
295 } # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
297 # init
298 g_score[c_start] = 0.0
299 g_score[c_start] = self.heuristic(c_start, c_end)
301 closed_vtx: set[CoordTup] = set() # nodes already evaluated
302 # nodes to be evaluated
303 # we need a set of the tuples, dont place the ints in the set
304 open_vtx: set[CoordTup] = set([c_start]) # noqa: C405
305 source: dict[CoordTup, CoordTup] = (
306 dict()
307 ) # node immediately preceding each node in the path (currently known shortest path)
309 while open_vtx:
310 # get lowest f_score node
311 # mypy cant tell that c is of length 2
312 c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)]) # type: ignore[index]
313 # f_current: float = f_score[c_current]
315 # check if goal is reached
316 if c_end == c_current:
317 path: list[CoordTup] = [c_current]
318 p_current: CoordTup = c_current
319 while p_current in source:
320 p_current = source[p_current]
321 path.append(p_current)
322 # ----------------------------------------------------------------------
323 # this is the only return statement
324 return np.array(path[::-1])
325 # ----------------------------------------------------------------------
327 # close current node
328 closed_vtx.add(c_current)
329 open_vtx.remove(c_current)
331 # update g_score of neighbors
332 _np_neighbor: Coord
333 for _np_neighbor in self.get_coord_neighbors(c_current):
334 neighbor: CoordTup = tuple(_np_neighbor)
336 if neighbor in closed_vtx:
337 # already checked
338 continue
339 g_temp: float = g_score[c_current] + 1 # always 1 for maze neighbors
341 if neighbor not in open_vtx:
342 # found new vtx, so add
343 open_vtx.add(neighbor)
345 elif g_temp >= g_score[neighbor]:
346 # if already knew about this one, but current g_score is worse, skip
347 continue
349 # store g_score and source
350 source[neighbor] = c_current
351 g_score[neighbor] = g_temp
352 f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)
354 raise ValueError(
355 "A solution could not be found!",
356 f"{c_start = }, {c_end = }",
357 self.as_ascii(),
358 )
360 def get_nodes(self) -> CoordArray:
361 """return a list of all nodes in the maze"""
362 rows: Int[np.ndarray, "x y"]
363 cols: Int[np.ndarray, "x y"]
364 rows, cols = np.meshgrid(
365 range(self.grid_shape[0]),
366 range(self.grid_shape[1]),
367 indexing="ij",
368 )
369 nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
370 return nodes
372 def get_connected_component(self) -> CoordArray:
373 """get the largest (and assumed only nonsingular) connected component of the maze
375 TODO: other connected components?
376 """
377 if (self.generation_meta is None) or (
378 self.generation_meta.get("fully_connected", False)
379 ):
380 # for fully connected case, pick any two positions
381 return self.get_nodes()
382 else:
383 # if metadata provided, use visited cells
384 visited_cells: set[CoordTup] | None = self.generation_meta.get(
385 "visited_cells",
386 None,
387 )
388 if visited_cells is None:
389 # TODO: dynamically generate visited_cells?
390 err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
391 raise ValueError(
392 err_msg,
393 )
394 visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
395 return visited_cells_np
    # overload: with `except_on_no_valid_endpoint=True` a path is always returned (or an exception raised)
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[True] = True,
    ) -> CoordArray: ...
    # overload: with `except_on_no_valid_endpoint=False` the result may be `None`
    @typing.overload
    def generate_random_path(
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: typing.Literal[False] = False,
    ) -> typing.Optional[CoordArray]: ...
    def generate_random_path(  # noqa: C901
        self,
        allowed_start: CoordList | None = None,
        allowed_end: CoordList | None = None,
        deadend_start: bool = False,
        deadend_end: bool = False,
        endpoints_not_equal: bool = False,
        except_on_no_valid_endpoint: bool = True,
    ) -> typing.Optional[CoordArray]:
        """return a path between randomly chosen start and end nodes within the connected component

        Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
        When no special conditions are given, the two endpoints are sampled without replacement.

        # Parameters:
        - `allowed_start : CoordList | None`
            a list of allowed start positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `allowed_end : CoordList | None`
            a list of allowed end positions. If `None`, any position in the connected component is allowed
            (defaults to `None`)
        - `deadend_start : bool`
            whether to ***force*** the start position to be a deadend (a node with exactly one connected neighbor)
            (defaults to `False`)
        - `deadend_end : bool`
            whether to ***force*** the end position to be a deadend (a node with exactly one connected neighbor)
            (defaults to `False`)
        - `endpoints_not_equal : bool`
            whether to ensure that the start and end point are not the same
            (defaults to `False`)
        - `except_on_no_valid_endpoint : bool`
            whether to raise an error if no valid start or end positions are found
            if this is `False`, the function might return `None` and this must be handled by the caller
            (defaults to `True`)

        # Returns:
        - `CoordArray`
            a path between the selected start and end positions,
            or `None` if no valid endpoints exist and `except_on_no_valid_endpoint` is `False`

        # Raises:
        - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
        """
        # we can't create a "path" in a single-node maze
        assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
            f"can't create path in single-node maze: {self.as_ascii()}"
        )

        # get connected component
        connected_component: CoordArray = self.get_connected_component()

        # initialize start and end positions
        positions: Int[np.int8, "2 2"]

        # if no special conditions on start and end positions
        if (allowed_start, allowed_end, deadend_start, deadend_end) == (
            None,
            None,
            False,
            False,
        ):
            try:
                # two distinct row indices into the component (`replace=False`)
                positions = connected_component[  # type: ignore[assignment]
                    np.random.choice(
                        len(connected_component),
                        size=2,
                        replace=False,
                    )
                ]
            except ValueError as e:
                # a `ValueError` from the sampling call is reported as "no valid endpoint"
                if except_on_no_valid_endpoint:
                    err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                    raise NoValidEndpointException(
                        err_msg,
                    ) from e
                return None

            return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

        # handle special conditions
        connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
        # copy connected component set
        allowed_start_set: set[CoordTup] = connected_component_set.copy()
        allowed_end_set: set[CoordTup] = connected_component_set.copy()

        # filter by explicitly allowed start and end positions
        # '# type: ignore[assignment]' here because the returned tuple can be of any length
        if allowed_start is not None:
            allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

        if allowed_end is not None:
            allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

        # filter by forcing deadends
        if deadend_start:
            # keep only candidates with exactly one connected neighbor
            allowed_start_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_start_set,
                ),
            )

        if deadend_end:
            # keep only candidates with exactly one connected neighbor
            allowed_end_set = set(
                filter(
                    lambda x: len(self.get_coord_neighbors(x)) == 1,
                    allowed_end_set,
                ),
            )

        # check we have valid positions
        if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                )
            return None

        # randomly select start and end positions
        try:
            # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
            start_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
            )
            if endpoints_not_equal:
                # remove start position from end positions
                # (this can leave the end set empty; the `randint` call below then fails and is handled)
                allowed_end_set.discard(start_pos)
            end_pos: CoordTup = tuple(  # type: ignore[assignment]
                list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
            )
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(start_pos, end_pos)
556 # ============================================================
557 # to and from adjacency list
558 # ============================================================
    def as_adj_list(
        self,
        shuffle_d0: bool = True,
        shuffle_d1: bool = True,
    ) -> Int8[np.ndarray, "conn start_end coord"]:
        """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`

        `shuffle_d0` and `shuffle_d1` are passed through positionally, unchanged, after
        `self.connection_list`; their meaning is defined by the wrapped function.
        NOTE(review): presumably they shuffle along the first and second axes of the
        returned array -- confirm against `connection_list_to_adj_list`.
        """
        return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
567 @classmethod
568 def from_adj_list(
569 cls,
570 adj_list: Int8[np.ndarray, "conn start_end coord"],
571 ) -> "LatticeMaze":
572 """create a LatticeMaze from a list of connections
574 > [!NOTE]
575 > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
576 """
577 # this is where it would probably break for rectangular mazes
578 grid_n: int = adj_list.max() + 1
580 connection_list: ConnectionList = np.zeros(
581 (2, grid_n, grid_n),
582 dtype=np.bool_,
583 )
585 for c_start, c_end in adj_list:
586 # check that exactly 1 coordinate matches
587 if (c_start == c_end).sum() != 1:
588 raise ValueError("invalid connection")
590 # get the direction
591 d: int = (c_start != c_end).argmax()
593 x: int
594 y: int
595 # pick whichever has the lesser value in the direction `d`
596 if c_start[d] < c_end[d]:
597 x, y = c_start
598 else:
599 x, y = c_end
601 connection_list[d, x, y] = True
603 return LatticeMaze(
604 connection_list=connection_list,
605 )
607 def as_adj_list_tokens(self) -> list[str | CoordTup]:
608 """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
609 warnings.warn(
610 "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
611 TokenizerDeprecationWarning,
612 )
613 return [
614 SPECIAL_TOKENS.ADJLIST_START,
615 *chain.from_iterable( # type: ignore[list-item]
616 [
617 [
618 tuple(c_s),
619 SPECIAL_TOKENS.CONNECTOR,
620 tuple(c_e),
621 SPECIAL_TOKENS.ADJACENCY_ENDLINE,
622 ]
623 for c_s, c_e in self.as_adj_list()
624 ],
625 ),
626 SPECIAL_TOKENS.ADJLIST_END,
627 ]
629 def _as_adj_list_tokens(self) -> list[str | CoordTup]:
630 return [
631 SPECIAL_TOKENS.ADJLIST_START,
632 *chain.from_iterable( # type: ignore[list-item]
633 [
634 [
635 tuple(c_s),
636 SPECIAL_TOKENS.CONNECTOR,
637 tuple(c_e),
638 SPECIAL_TOKENS.ADJACENCY_ENDLINE,
639 ]
640 for c_s, c_e in self.as_adj_list()
641 ],
642 ),
643 SPECIAL_TOKENS.ADJLIST_END,
644 ]
646 def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
647 """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples"""
648 output: list[CoordTup | str] = self._as_adj_list_tokens()
649 # if getattr(self, "start_pos", None) is not None:
650 if isinstance(self, TargetedLatticeMaze):
651 output += self._get_start_pos_tokens()
652 if isinstance(self, TargetedLatticeMaze):
653 output += self._get_end_pos_tokens()
654 if isinstance(self, SolvedMaze):
655 output += self._get_solution_tokens()
656 return output
658 def _as_tokens(
659 self,
660 maze_tokenizer: "MazeTokenizer | TokenizationMode",
661 ) -> list[str]:
662 # type ignores here fine since we check the instance
663 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
664 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr]
665 if (
666 isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
667 and maze_tokenizer.is_AOTP() # type: ignore[union-attr]
668 ):
669 coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
670 coords_processed: list[str] = maze_tokenizer.coords_to_strings( # type: ignore[union-attr]
671 coords=coords_raw,
672 when_noncoord="include",
673 )
674 return coords_processed
675 else:
676 err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
677 raise NotImplementedError(err_msg)
679 def as_tokens(
680 self,
681 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
682 ) -> list[str]:
683 """serialize maze and solution to tokens"""
684 if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
685 return maze_tokenizer.to_tokens(self) # type: ignore[union-attr]
686 else:
687 return self._as_tokens(maze_tokenizer) # type: ignore[union-attr,arg-type]
    @classmethod
    def _from_tokens_AOTP(
        cls,
        tokens: list[str],
        maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
    ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
        """create a LatticeMaze from a list of tokens

        The adjacency list is always parsed. If all four origin/target delimiter tokens
        are present the result is upgraded to a `TargetedLatticeMaze`, and if both path
        delimiter tokens are present as well it is upgraded to a `SolvedMaze`.

        # Parameters:
        - `tokens : list[str]`
            the tokenized maze; if it does not begin with the adjacency-list start token,
            the whole list is treated as adjacency-list tokens (with a warning)
        - `maze_tokenizer : MazeTokenizer | MazeTokenizerModular`
            tokenizer whose `strings_to_coords` is used to decode coordinates

        # Raises:
        - `AssertionError` : on malformed edges, on more than one origin or target
            coordinate, or if path tokens are present without origin/target tokens
        """
        # figure out what input format
        # ========================================
        if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
            adj_list_tokens = get_adj_list_tokens(tokens)
        else:
            # If we're not getting a "complete" tokenized maze, assume it's just a the adjacency list tokens
            adj_list_tokens = tokens
            warnings.warn(
                "Assuming input is just adjacency list tokens, no special tokens found",
            )

        # process edges for adjacency list
        # ========================================
        # one sub-list of tokens per edge, split on the endline token
        edges: list[list[str]] = list_split(
            adj_list_tokens,
            SPECIAL_TOKENS.ADJACENCY_ENDLINE,
        )

        coordinates: list[tuple[CoordTup, CoordTup]] = list()
        for e in edges:
            # skip last endline
            if len(e) != 0:
                # convert to coords, split start and end
                e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
                    e,
                    when_noncoord="include",
                )
                # this assertion depends on the tokenizer having exactly one token for the connector
                # which is also why we "include" above
                # the connector token is discarded below
                assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
                assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
                    f"invalid edge: {e = } {e_coords = }"
                )
                e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
                e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
                coordinates.append((e_coords_first, e_coords_last))

        # NOTE(review): each `c` here is a (start, end) pair, so this checks the pair
        # length rather than the length of the individual coordinates
        assert all(len(c) == DIM_2 for c in coordinates), (
            f"invalid coordinates: {coordinates = }"
        )
        adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
        assert tuple(adj_list.shape) == (
            len(coordinates),
            2,
            2,
        ), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

        output_maze: LatticeMaze = cls.from_adj_list(adj_list)

        # add start and end positions
        # ========================================
        is_targeted: bool = False
        # only treated as targeted when all four origin/target delimiters are present
        if all(
            x in tokens
            for x in (
                SPECIAL_TOKENS.ORIGIN_START,
                SPECIAL_TOKENS.ORIGIN_END,
                SPECIAL_TOKENS.TARGET_START,
                SPECIAL_TOKENS.TARGET_END,
            )
        ):
            start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_origin_tokens(tokens),
                when_noncoord="error",
            )
            end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_target_tokens(tokens),
                when_noncoord="error",
            )
            assert len(start_pos_list) == 1, (
                f"invalid start_pos_list: {start_pos_list = }"
            )
            assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

            start_pos: CoordTup = start_pos_list[0]
            end_pos: CoordTup = end_pos_list[0]

            output_maze = TargetedLatticeMaze.from_lattice_maze(
                lattice_maze=output_maze,
                start_pos=start_pos,
                end_pos=end_pos,
            )

            is_targeted = True

        # a solution is attached only when both path delimiters are present
        if all(
            x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
        ):
            assert is_targeted, "maze must be targeted to have a solution"
            solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
                get_path_tokens(tokens, trim_end=True),
                when_noncoord="error",
            )
            output_maze = SolvedMaze.from_targeted_lattice_maze(
                # HACK: I think this is fine, but im not sure
                targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
                solution=solution,
            )

        return output_maze
798 # TODO: any way to get return type hinting working for this?
799 @classmethod
800 def from_tokens(
801 cls,
802 tokens: list[str],
803 maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
804 ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
805 """Constructs a maze from a tokenization.
807 Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
808 """
809 # HACK: type ignores here fine since we check the instance
810 if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
811 maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr]
812 if (
813 isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
814 and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr]
815 ):
816 err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
817 raise NotImplementedError(
818 err_msg,
819 )
821 if isinstance(tokens, str):
822 tokens = tokens.split()
824 if maze_tokenizer.is_AOTP(): # type: ignore[union-attr]
825 return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type]
826 else:
827 raise NotImplementedError("only AOTP tokenization is supported")
829 # ============================================================
830 # to and from pixels
831 # ============================================================
832 def _as_pixels_bw(self) -> BinaryPixelGrid:
833 assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
834 # Create an empty pixel grid with walls
835 pixel_grid: Int[np.ndarray, "x y"] = np.full(
836 (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
837 False,
838 dtype=np.bool_,
839 )
841 # Set white nodes
842 pixel_grid[1::2, 1::2] = True
844 # Set white connections (downward)
845 for i, row in enumerate(self.connection_list[0]):
846 for j, connected in enumerate(row):
847 if connected:
848 pixel_grid[i * 2 + 2, j * 2 + 1] = True
850 # Set white connections (rightward)
851 for i, row in enumerate(self.connection_list[1]):
852 for j, connected in enumerate(row):
853 if connected:
854 pixel_grid[i * 2 + 1, j * 2 + 2] = True
856 return pixel_grid
858 def as_pixels(
859 self,
860 show_endpoints: bool = True,
861 show_solution: bool = True,
862 ) -> PixelGrid:
863 """convert the maze to a pixel grid
865 - useful as a simpler way of plotting the maze than the more complex `MazePlot`
866 - the same underlying representation as `as_ascii` but as an image
867 - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
868 """
869 # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
870 # but solution, start_pos, end_pos not always defined
871 # but its fine since we explicitly check the type
872 if show_solution and not show_endpoints:
873 raise ValueError("show_solution=True requires show_endpoints=True")
874 # convert original bool pixel grid to RGB
875 pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
876 pixel_grid: PixelGrid = np.full(
877 (*pixel_grid_bw.shape, 3),
878 PixelColors.WALL,
879 dtype=np.uint8,
880 )
881 pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712
883 if self.__class__ == LatticeMaze:
884 return pixel_grid
886 # set endpoints for TargetedLatticeMaze
887 if self.__class__ == TargetedLatticeMaze:
888 if show_endpoints:
889 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined]
890 PixelColors.START
891 )
892 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined]
893 PixelColors.END
894 )
895 return pixel_grid
897 # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
898 if show_solution:
899 for coord in self.solution: # type: ignore[attr-defined]
900 pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH
902 # set pixels between coords
903 for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined]
904 next_coord = self.solution[index + 1] # type: ignore[attr-defined]
905 # check they are adjacent using norm
906 assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
907 f"Coords {coord} and {next_coord} are not adjacent"
908 )
909 # set pixel between them
910 pixel_grid[
911 coord[0] * 2 + 1 + next_coord[0] - coord[0],
912 coord[1] * 2 + 1 + next_coord[1] - coord[1],
913 ] = PixelColors.PATH
915 # set endpoints (again, since path would overwrite them)
916 pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined]
917 PixelColors.START
918 )
919 pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined]
920 PixelColors.END
921 )
923 return pixel_grid
925 @classmethod
926 def _from_pixel_grid_bw(
927 cls,
928 pixel_grid: BinaryPixelGrid,
929 ) -> tuple[ConnectionList, tuple[int, int]]:
930 grid_shape: tuple[int, int] = (
931 pixel_grid.shape[0] // 2,
932 pixel_grid.shape[1] // 2,
933 )
934 connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_)
936 # Extract downward connections
937 connection_list[0] = pixel_grid[2::2, 1::2]
939 # Extract rightward connections
940 connection_list[1] = pixel_grid[1::2, 2::2]
942 return connection_list, grid_shape
@classmethod
def _from_pixel_grid_with_positions(
    cls,
    pixel_grid: PixelGrid | BinaryPixelGrid,
    marked_positions: dict[str, RGB],
) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
    """extract connections, grid shape, and the cells marked with given colors

    # Parameters:
    - `pixel_grid : PixelGrid | BinaryPixelGrid`
        grid whose last axis is compared against the colors
    - `marked_positions : dict[str, RGB]`
        maps an output key to the color to search for

    # Returns:
    - `tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]`
        connection list and grid shape as given by `_from_pixel_grid_bw`,
        plus, for every key, an array of the cell coordinates drawn in that key's color
        (built with `np.array` from a list, so an empty result has shape `(0,)`)
    """
    # every pixel which is not exactly the wall color counts as open
    # type ignore: `np.all` is typed as possibly returning a scalar bool
    open_pixels: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
        pixel_grid == PixelColors.WALL,
        axis=-1,
    )
    connection_list, grid_shape = cls._from_pixel_grid_bw(open_pixels)

    # look up each requested color
    out_positions: dict[str, CoordArray] = {}
    for key, color in marked_positions.items():
        color_hits: Int[np.ndarray, "x y"] = np.argwhere(
            np.all(pixel_grid == color, axis=-1),
        )
        out_positions[key] = np.array(
            [
                (px[0] // 2, px[1] // 2)
                for px in color_hits
                # odd row and odd column: a cell pixel, not a connection pixel
                if px[0] % 2 == 1 and px[1] % 2 == 1
            ],
        )

    return connection_list, grid_shape, out_positions
@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    - a 2-dimensional (binary) grid always gives a plain `LatticeMaze`
    - otherwise the maze type is detected with `detect_pixels_type`, and what is built depends on `cls`:
      a `LatticeMaze`, a `TargetedLatticeMaze`, or `cls(connection_list=..., solution=...)`
    - for the last case the solution is ordered by walking from the start cell to the end cell,
      at each step moving to the one connected neighbor that is marked as path (or is the end) and not yet visited

    # Parameters:
    - `pixel_grid : PixelGrid`
        the pixel grid to convert

    # Returns:
    - `LatticeMaze`
        the reconstructed maze (possibly an instance of a subclass, see above)

    # Raises:
    - `ValueError` : if `cls` is not in the MRO of the detected type, meaning the grid lacks the markings (endpoints and/or path) that `cls` needs
    - `AssertionError` : if there is not exactly one start and one end cell, or the marked path cannot be ordered unambiguously
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    # the detected type has to be `cls` or a subclass of it
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(
            err_msg,
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START,
            end=PixelColors.END,
            solution=PixelColors.PATH,
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    # (plain maze used further down to look up connected neighbors)
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    # shape `(0,)` is what `np.array` of an empty list gives: no path cells were found
    elif solution_raw.shape == (0,):
        # the solution and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
        tuple(end_pos),
    ]
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    # extend the path by one cell per iteration until the end cell is reached
    while solution[-1] != tuple(end_pos):
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # neighbors = neighbors[:, [1, 0]]
        # filter out neighbors that are not in the raw solution
        # (and those already visited, so the walk never turns back)
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (
                    tuple(coord) in solution_raw_list
                    and tuple(coord) not in solution
                )
            ],
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        solution.append(tuple(neighbors_filtered[0]))

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == tuple(end_pos), (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    # only `connection_list` and `solution` are passed on to `cls`
    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )
1109 # ============================================================
1110 # to and from ASCII
1111 # ============================================================
def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
    """character grid of the bare maze: `AsciiChars.OPEN` where `_as_pixels_bw()` is true, `AsciiChars.WALL` elsewhere"""
    # boolean pixel grid from `_as_pixels_bw()`
    bw_pixels: Bool[np.ndarray, "x y"] = self._as_pixels_bw()
    open_mask = bw_pixels == True  # noqa: E712

    # everything is a wall until marked open
    char_grid: Shaped[np.ndarray, "x y"] = np.full(
        bw_pixels.shape,
        AsciiChars.WALL,
        dtype=str,
    )
    char_grid[open_mask] = AsciiChars.OPEN

    return char_grid
def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    can be reversed with `LatticeMaze.from_ascii()`

    # Parameters:
    - `show_endpoints : bool`
        whether to draw the start and end characters
        (defaults to `True`)
    - `show_solution : bool`
        whether to draw the path character
        (defaults to `True`)

    # Returns:
    - `str`
        the rows of the grid, joined by newlines
    """
    char_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    colors: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    # the special characters which are allowed to overwrite the bare maze
    enabled_chars: list = []
    if show_endpoints:
        enabled_chars.extend((AsciiChars.START, AsciiChars.END))
    if show_solution:
        enabled_chars.append(AsciiChars.PATH)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char not in enabled_chars:
            continue
        # wherever the pixel has exactly this color, write the paired character
        color_mask = (colors == pixel_color).all(axis=-1)
        char_grid[color_mask] = ascii_char

    row_strings: list[str] = ["".join(row) for row in char_grid]
    return "\n".join(row_strings)
@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    """get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)

    surrounding whitespace is stripped from the whole string and from each line,
    every character is mapped to a color via `ASCII_PIXEL_PAIRINGS` (unmapped characters stay all-zero),
    and the resulting pixel grid is passed to `cls.from_pixels`
    """
    char_rows: list[list[str]] = [
        list(raw_line.strip()) for raw_line in ascii_str.strip().split("\n")
    ]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(char_rows, dtype=str)

    # one RGB triple per character
    pixel_grid: PixelGrid = np.zeros(ascii_grid.shape + (3,), dtype=np.uint8)
    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        char_mask = ascii_grid == ascii_char
        pixel_grid[char_mask] = pixel_color

    return cls.from_pixels(pixel_grid)
# type ignore here even though they're all frozen
# maybe `SerializableDataclass` itself is not frozen, but that's an ABC
# error: Cannot inherit frozen dataclass from a non-frozen one [misc]
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position

    both positions are converted to numpy arrays in `__post_init__`,
    which also rejects missing or out-of-bounds positions
    """

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        """post init converts start and end pos to numpy arrays, checks they exist and are in bounds

        # Raises:
        - `ValueError` : if a position is `None`, or if one of its first two coordinates
          is negative or not smaller than the matching entry of `grid_shape`
        """
        err_msg: str
        for name in ("start_pos", "end_pos"):
            raw_pos = getattr(self, name)
            # has to be checked *before* the conversion: `np.array(None)` is a
            # 0-d object array, so a `None` check on the converted value never fails
            if raw_pos is None:
                err_msg = f"{name} is None, but a position is required"
                raise ValueError(
                    err_msg,
                )
            # make things numpy arrays (very jank to override frozen dataclass)
            self.__dict__[name] = np.array(raw_pos)

        # check that start and end are in bounds -- from below as well as from above,
        # a negative coordinate is not a cell of the grid either
        for name in ("start_pos", "end_pos"):
            pos = getattr(self, name)
            if not (
                0 <= pos[0] < self.grid_shape[0]
                and 0 <= pos[1] < self.grid_shape[1]
            ):
                err_msg = f"{name} {pos} is out of bounds for grid shape {self.grid_shape}"
                raise ValueError(
                    err_msg,
                )

    # NOTE(review): presumably spelled out so that the dataclass decorator does not
    # generate a field-wise `__eq__` for this class -- confirm against `serializable_dataclass`
    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        "start position as a tuple, wrapped in the `ORIGIN_START` / `ORIGIN_END` special tokens"
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        "end position as a tuple, wrapped in the `TARGET_START` / `TARGET_END` special tokens"
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        # connection list and generation metadata are passed on as they are
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution

    the endpoints are taken from the solution:
    `start_pos` is `solution[0]` and `end_pos` is `solution[-1]`
    """

    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        a solution counts as valid if, as a numpy array, it is 2-dimensional
        with at least one row and exactly two columns

        # Parameters:
        - `connection_list : ConnectionList`
            connections of the maze
        - `solution : CoordArray`
            path through the maze, the first and last entries become `start_pos` and `end_pos`
        - `generation_meta : dict | None`
            passed through to the parent initializer
            (defaults to `None`)
        - `start_pos : Coord | None`
            only used as a check against `solution[0]`
            (defaults to `None`)
        - `end_pos : Coord | None`
            only used as a check against `solution[-1]`
            (defaults to `None`)
        - `allow_invalid : bool`
            skip the solution validity check and the endpoint checks
            (defaults to `False`)

        # Raises:
        - `ValueError` : if the solution is not valid and `allow_invalid` is `False`
        - `AssertionError` : if a given `start_pos` / `end_pos` does not match the solution and `allow_invalid` is `False`
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            # the rank goes first: `shape[1]` does not exist for a 1-d array, and an
            # array of higher rank does not hold one coordinate per row
            if (
                solution.ndim == 2  # noqa: PLR2004
                and solution.shape[0] > 0
                and solution.shape[1] == 2  # noqa: PLR2004
            ):
                solution_valid = True

        if not solution_valid and not allow_invalid:
            # `solution` may be `None` here, which has no `.shape`
            solution_shape: tuple | None = (
                None if solution is None else solution.shape
            )
            err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        # NOTE(review): for an invalid solution with `allow_invalid=True`, `None` endpoints
        # are handed to the parent initializer -- confirm that the parent accepts them
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        # write through `__dict__` since the dataclass is frozen
        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        "solution coordinates as tuples, wrapped in the `PATH_START` / `PATH_END` special tokens"
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if no `solution` is given, it is computed with `find_shortest_path`
        between the start and end positions of the maze
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork

        # Parameters:
        - `always_include_endpoints : bool`
            report the first and last solution points regardless of their neighbor count
            (defaults to `False`)

        # Returns:
        - `tuple[list[int], CoordArray]`
            indices into the solution, and the coordinates at those indices
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # the indices come out of `np.delete` as an array rather than a `list[int]`,
        # which is what the type ignore covers
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )
def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """Detects the type of pixels data from which marker colors are present

    - neither the start nor the end color: `LatticeMaze`
    - start or end color, and also the path color: `SolvedMaze`
    - start or end color, but no path color: `TargetedLatticeMaze`
    """
    has_endpoint: bool = color_in_pixel_grid(
        data,
        PixelColors.START,
    ) or color_in_pixel_grid(data, PixelColors.END)

    if not has_endpoint:
        return LatticeMaze
    if color_in_pixel_grid(data, PixelColors.PATH):
        return SolvedMaze
    return TargetedLatticeMaze
def _remove_isolated_cells(
    image: Int[np.ndarray, "RGB x y"],
) -> Int[np.ndarray, "RGB x y"]:
    """Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.

    the color is compared along the last axis of `image`, and everything
    outside of the image counts as wall. the input is not modified, a copy is returned.
    """
    # True wherever the pixel has the wall color
    is_wall = np.all(image == PixelColors.WALL, axis=-1)

    # a frame of walls around the mask, so that border pixels have four neighbors too
    framed = np.pad(
        is_wall,
        ((1, 1), (1, 1)),
        mode="constant",
        constant_values=True,
    )

    # the four neighbors, as shifted views of the framed mask
    wall_right = framed[1:-1, 2:]
    wall_left = framed[1:-1, :-2]
    wall_down = framed[2:, 1:-1]
    wall_up = framed[:-2, 1:-1]
    walled_in = wall_right & wall_left & wall_down & wall_up

    # only open pixels can be isolated
    isolated = walled_in & ~is_wall

    # paint the isolated pixels as walls on a copy
    cleaned = image.copy()
    cleaned[isolated] = PixelColors.WALL
    return cleaned
# Pad widths for each direction, as one `(before, after)` pair per axis
_RIC_PADS: dict = dict(
    left=((1, 0), (0, 0)),
    right=((0, 1), (0, 0)),
    up=((0, 0), (1, 0)),
    down=((0, 0), (0, 1)),
)

# Define slices for each direction
_RIC_SLICES: dict = dict(
    left=(slice(1, None), slice(None)),
    right=(slice(None, -1), slice(None)),
    up=(slice(None), slice(1, None)),
    down=(slice(None), slice(None, -1)),
)
# TODO: figure out why this function doesn't work, or maybe just get rid of it
1493# def _remove_isolated_cells_old(
1494# image: Int[np.ndarray, "RGB x y"],
1495# ) -> Int[np.ndarray, "RGB x y"]:
1496# """
1497# Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.
1498# """
1499# warnings.warn("this functin doesn't work and I have no idea why!!!")
1500# masks: dict[str, np.ndarray] = {
1501# d: np.all(
1502# np.pad(
1503# image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL,
1504# np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8),
1505# mode="constant",
1506# constant_values=True,
1507# ),
1508# axis=2,
1509# )
1510# for d in _RIC_SLICES.keys()
1511# }
1513# # Create a mask for non-wall cells
1514# mask_non_wall = np.all(image != PixelColors.WALL, axis=2)
1516# # print(f"{mask_non_wall.shape = }")
1517# # print(f"{ {k: masks[k].shape for k in masks.keys()} = }")
1519# # print(f"{mask_non_wall = }")
1520# # print(f"{masks['down'] = }")
1522# # Combine the masks
1523# mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"]
1525# # Apply the mask
1526# output_image = np.where(
1527# np.stack([mask] * 3, axis=-1),
1528# PixelColors.WALL,
1529# image,
1530# )
1532# return output_image