Coverage for maze_dataset/maze/lattice_maze.py: 65%

508 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 00:51 -0600

1"""Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses. 

2 

3also includes basic utilities, including converting to/from ascii and pixel representations. 

4""" 

5 

6import typing 

7import warnings 

8from dataclasses import dataclass 

9from itertools import chain 

10 

11import numpy as np 

12from jaxtyping import Bool, Int, Int8, Shaped 

13from muutils.json_serialize.serializable_dataclass import ( 

14 SerializableDataclass, 

15 serializable_dataclass, 

16 serializable_field, 

17) 

18from muutils.misc import isinstance_by_type_name, list_split 

19 

20from maze_dataset.constants import ( 

21 NEIGHBORS_MASK, 

22 SPECIAL_TOKENS, 

23 ConnectionList, 

24 Coord, 

25 CoordArray, 

26 CoordList, 

27 CoordTup, 

28) 

29from maze_dataset.token_utils import ( 

30 TokenizerDeprecationWarning, 

31 connection_list_to_adj_list, 

32 get_adj_list_tokens, 

33 get_origin_tokens, 

34 get_path_tokens, 

35 get_target_tokens, 

36) 

37 

38if typing.TYPE_CHECKING: 

39 from maze_dataset.tokenization import ( 

40 MazeTokenizer, 

41 MazeTokenizerModular, 

42 TokenizationMode, 

43 ) 

44 

# shared type aliases and constants for the pixel / ascii representations

RGB = tuple[int, int, int]
"color as an `(r, g, b)` tuple, values 0-255"

PixelGrid = Int[np.ndarray, "x y rgb"]
"grid of pixels, each holding an rgb color"

BinaryPixelGrid = Bool[np.ndarray, "x y"]
"grid of pixels, each a single boolean"

DIM_2: int = 2
"the constant 2, used where two dimensions are expected"

55 

56 

class NoValidEndpointException(Exception):  # noqa: N818
    """Signals that a maze has no usable start or end position."""

61 

62 

def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList:
    """Set the trailing edge of each direction's connections to `False`, in place.

    For dim 0 (down) the last row is cleared, for dim 1 (right) the last column.
    The same (mutated) array is returned.

    # Raises:
    - `NotImplementedError` : if the leading axis has more than two entries
    """
    for dim in range(connection_list.shape[0]):
        if dim not in (0, 1):
            err_msg: str = f"only 2d lattices supported. got {dim=}"
            raise NotImplementedError(err_msg)
        # dim 0 -> last row (down), dim 1 -> last column (right)
        edge_index = (dim, -1, slice(None)) if dim == 0 else (dim, slice(None), -1)
        connection_list[edge_index] = False
    return connection_list

76 

77 

def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
    """Return `True` if any pixel of `pixel_grid` equals `color`, else `False`."""
    # scan row by row, stopping at the first pixel whose entries all match
    return any(
        bool(np.all(pixel == color))
        for row in pixel_grid
        for pixel in row
    )

85 

86 

@dataclass(frozen=True)
class PixelColors:
    "rgb colors used when drawing a maze as a pixel grid"

    WALL: RGB = (0, 0, 0)  # black
    OPEN: RGB = (255, 255, 255)  # white
    START: RGB = (0, 255, 0)  # green
    END: RGB = (255, 0, 0)  # red
    PATH: RGB = (0, 0, 255)  # blue

96 

97 

@dataclass(frozen=True)
class AsciiChars:
    "characters used when drawing a maze as ascii text"

    WALL: str = "#"  # wall cell
    OPEN: str = " "  # open cell
    START: str = "S"  # start position
    END: str = "E"  # end position
    PATH: str = "X"  # cell on the path

107 

108 

# pair up the identically named fields of `AsciiChars` and `PixelColors`
ASCII_PIXEL_PAIRINGS: dict[str, RGB] = {
    getattr(AsciiChars, field_name): getattr(PixelColors, field_name)
    for field_name in ("WALL", "OPEN", "START", "END", "PATH")
}
"map ascii characters to pixel colors"

117 

118 

119@serializable_dataclass( 

120 frozen=True, 

121 kw_only=True, 

122 properties_to_serialize=["lattice_dim", "generation_meta"], 

123) 

124class LatticeMaze(SerializableDataclass): 

125 """lattice maze (nodes on a lattice, connections only to neighboring nodes) 

126 

127 Connection List represents which nodes (N) are connected in each direction. 

128 

129 First and second elements represent downward and rightward connections, 

130 respectively. 

131 

132 Example: 

133 Connection list: 

134 [ 

135 [ # down 

136 [F T], 

137 [F F] 

138 ], 

139 [ # right 

140 [T F], 

141 [T F] 

142 ] 

143 ] 

144 

145 Nodes with connections 

146 N T N F 

147 F T 

148 N T N F 

149 F F 

150 

151 Graph: 

152 N - N 

153 | 

154 N - N 

155 

156 Note: the bottom row connections going down, and the 

157 right-hand connections going right, will always be False. 

158 

159 """ 

160 

161 connection_list: ConnectionList 

162 generation_meta: dict | None = serializable_field(default=None, compare=False) 

163 

164 lattice_dim = property(lambda self: self.connection_list.shape[0]) 

165 grid_shape = property(lambda self: self.connection_list.shape[1:]) 

166 n_connections = property(lambda self: self.connection_list.sum()) 

167 

@property
def grid_n(self) -> int:
    """Side length of the grid as an int.

    Raises `AssertionError` if the two entries of `grid_shape` differ.
    """
    n_rows = self.grid_shape[0]
    n_cols = self.grid_shape[1]
    assert n_rows == n_cols, "only square mazes supported"
    return n_rows

173 

174 # ============================================================ 

175 # basic methods 

176 # ============================================================ 

177 

def __eq__(self, other: object) -> bool:
    """Compare with `other` by deferring to the parent class's `__eq__`."""
    return super().__eq__(other)

181 

@staticmethod
def heuristic(a: CoordTup, b: CoordTup) -> float:
    """Manhattan distance between the points `a` and `b` (first two entries of each)."""
    d_row = a[0] - b[0]
    d_col = a[1] - b[1]
    return np.abs(d_row) + np.abs(d_col)

186 

def __hash__(self) -> int:
    """Hash of the maze, computed from the raw bytes of `connection_list`."""
    raw: bytes = self.connection_list.tobytes()
    return hash(raw)

190 

def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
    """Tell whether nodes `a` and `b` are joined by a connection.

    Returns `False` when the two coords are not exactly one step apart;
    otherwise returns the matching `connection_list` entry.
    """
    offset: Coord = b - a
    abs_offset = np.abs(offset)

    # anything other than a single unit step cannot be a connection
    if abs_offset.sum() != 1:
        return False

    # the axis along which the two nodes differ selects the connection layer
    axis: int = int(np.argmax(abs_offset))
    # the connection is stored at whichever node has the smaller coordinate
    anchor: Coord = a if offset.sum() > 0 else b
    return self.connection_list[axis, anchor[0], anchor[1]]

202 

def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
    """Check that `path` stays inside the grid and only steps between connected nodes.

    An empty path yields `empty_is_valid`.
    """
    if len(path) == 0:
        return empty_is_valid

    # every coordinate must lie inside the grid
    inside = (path >= 0) & (path < self.grid_shape)
    if not np.all(inside):
        return False

    # every consecutive pair of coords must be connected
    return all(
        self.nodes_connected(here, there)
        for here, there in zip(path[:-1], path[1:])
    )

218 

def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
    """Return, for every coord, how many connections touch it.

    The first two layers of `connection_list` are used as the down and right
    connections; each one is also counted for the node it leads to.
    """
    conn = self.connection_list.astype(np.int8)
    down = conn[0]
    right = conn[1]

    # connections stored at the node itself
    degrees = np.sum(conn, axis=0)
    # a down-connection of the node above counts as a connection to the north
    degrees[1:, :] += down[:-1, :]
    # a right-connection of the node to the left counts as a connection to the west
    degrees[:, 1:] += right[:, :-1]
    return degrees

234 

def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
    """Return the coords that are in bounds and connected to `c`, as an array."""
    c = np.array(c)  # type: ignore[assignment]
    grid_shape = self.grid_shape

    neighbors: list[Coord] = []
    for candidate in c + NEIGHBORS_MASK:
        # skip anything outside the grid before asking about connectivity
        if not (0 <= candidate[0] < grid_shape[0]):
            continue
        if not (0 <= candidate[1] < grid_shape[1]):
            continue
        if self.nodes_connected(c, candidate):
            neighbors.append(candidate)

    output: CoordArray = np.array(neighbors)
    if len(neighbors) > 0:
        expected_shape = (len(neighbors), 2)
        assert output.shape == expected_shape, (
            f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
        )
    return output

257 

def gen_connected_component_from(self, c: Coord) -> CoordArray:
    """Collect every coord reachable from `c` via `get_coord_neighbors`.

    Uses an explicit-stack depth-first traversal and returns the visited
    coords as an array (row order follows set iteration order).
    """
    pending: list[Coord] = [c]
    seen: set[CoordTup] = set()

    while pending:
        node: Coord = pending.pop()
        # coords have length 2, so the tuple is a valid `CoordTup`
        seen.add(tuple(node))  # type: ignore[arg-type]

        # queue up every neighbor that has not been visited yet
        pending.extend(
            neighbor
            for neighbor in self.get_coord_neighbors(node)
            if tuple(neighbor) not in seen
        )

    return np.array(list(seen))

280 

def find_shortest_path(
    self,
    c_start: CoordTup | Coord,
    c_end: CoordTup | Coord,
) -> CoordArray:
    """find the shortest path between two coordinates, using A*

    # Parameters:
    - `c_start : CoordTup | Coord`
        coordinate the path starts at
    - `c_end : CoordTup | Coord`
        coordinate the path ends at

    # Returns:
    - `CoordArray`
        the coords along the path, from `c_start` to `c_end` inclusive

    # Raises:
    - `ValueError` : if `c_end` cannot be reached from `c_start`
    """
    c_start = tuple(c_start)  # type: ignore[assignment]
    c_end = tuple(c_end)  # type: ignore[assignment]

    # cost of cheapest path to node from start currently known
    g_score: dict[CoordTup, float] = {c_start: 0.0}
    # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
    # (the start node costs nothing to reach, so its estimate is the heuristic alone)
    f_score: dict[CoordTup, float] = {c_start: self.heuristic(c_start, c_end)}

    closed_vtx: set[CoordTup] = set()  # nodes already evaluated
    # nodes to be evaluated
    # we need a set of the tuples, dont place the ints in the set
    open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
    # node immediately preceding each node in the path (currently known shortest path)
    source: dict[CoordTup, CoordTup] = dict()

    while open_vtx:
        # get lowest f_score node
        # mypy cant tell that c is of length 2
        c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]

        # check if goal is reached
        if c_end == c_current:
            # walk the `source` links back to the start, then reverse
            path: list[CoordTup] = [c_current]
            p_current: CoordTup = c_current
            while p_current in source:
                p_current = source[p_current]
                path.append(p_current)
            # ----------------------------------------------------------------------
            # this is the only return statement
            return np.array(path[::-1])
            # ----------------------------------------------------------------------

        # close current node
        closed_vtx.add(c_current)
        open_vtx.remove(c_current)

        # update g_score of neighbors
        _np_neighbor: Coord
        for _np_neighbor in self.get_coord_neighbors(c_current):
            neighbor: CoordTup = tuple(_np_neighbor)

            if neighbor in closed_vtx:
                # already checked
                continue
            g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

            if neighbor not in open_vtx:
                # found new vtx, so add
                open_vtx.add(neighbor)

            elif g_temp >= g_score[neighbor]:
                # if already knew about this one, but current g_score is worse, skip
                continue

            # store g_score and source
            source[neighbor] = c_current
            g_score[neighbor] = g_temp
            f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)

    raise ValueError(
        "A solution could not be found!",
        f"{c_start = }, {c_end = }",
        self.as_ascii(),
    )

359 

def get_nodes(self) -> CoordArray:
    """Return every `(row, col)` coord of the grid as an array, in row-major order."""
    row_idx: Int[np.ndarray, "x y"]
    col_idx: Int[np.ndarray, "x y"]
    row_idx, col_idx = np.meshgrid(
        range(self.grid_shape[0]),
        range(self.grid_shape[1]),
        indexing="ij",
    )
    # one row per node: first column is the row index, second the column index
    return np.column_stack((row_idx.ravel(), col_idx.ravel()))

371 

def get_connected_component(self) -> CoordArray:
    """Return the coords of the maze's connected component.

    - without `generation_meta`, or when it marks the maze as `"fully_connected"`,
      all nodes are returned
    - otherwise the `"visited_cells"` entry of `generation_meta` is returned as an array

    # Raises:
    - `ValueError` : if the maze is not marked fully connected and has no `"visited_cells"`

    TODO: other connected components?
    """
    meta = self.generation_meta
    if meta is None or meta.get("fully_connected", False):
        # fully connected: every node belongs to the component
        return self.get_nodes()

    visited_cells: set[CoordTup] | None = meta.get("visited_cells", None)
    if visited_cells is None:
        # TODO: dynamically generate visited_cells?
        err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
        raise ValueError(
            err_msg,
        )
    return np.array(list(visited_cells))

396 

@typing.overload
def generate_random_path(
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: typing.Literal[True] = True,
) -> CoordArray: ...
@typing.overload
def generate_random_path(
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: typing.Literal[False] = False,
) -> typing.Optional[CoordArray]: ...
def generate_random_path(  # noqa: C901
    self,
    allowed_start: CoordList | None = None,
    allowed_end: CoordList | None = None,
    deadend_start: bool = False,
    deadend_end: bool = False,
    endpoints_not_equal: bool = False,
    except_on_no_valid_endpoint: bool = True,
) -> typing.Optional[CoordArray]:
    """return the shortest path between randomly chosen start and end nodes of the connected component

    Note that when special conditions are placed on the endpoints, the same position
    may be picked as both start and end unless `endpoints_not_equal` is set.

    # Parameters:
    - `allowed_start : CoordList | None`
        allowed start positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `allowed_end : CoordList | None`
        allowed end positions. If `None`, any position in the connected component is allowed
        (defaults to `None`)
    - `deadend_start : bool`
        whether to ***force*** the start position to have exactly one neighbor
        (defaults to `False`)
    - `deadend_end : bool`
        whether to ***force*** the end position to have exactly one neighbor
        (defaults to `False`)
    - `endpoints_not_equal : bool`
        whether to remove the chosen start position from the candidate end positions
        (defaults to `False`)
    - `except_on_no_valid_endpoint : bool`
        whether to raise when no valid start or end position exists;
        if `False`, `None` is returned instead and must be handled by the caller
        (defaults to `True`)

    # Returns:
    - `CoordArray`
        a path between the selected start and end positions

    # Raises:
    - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
    """
    # we can't create a "path" in a single-node maze
    assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
        f"can't create path in single-node maze: {self.as_ascii()}"
    )

    connected_component: CoordArray = self.get_connected_component()

    # ------------------------------------------------------------
    # simple case: no constraints on either endpoint
    # ------------------------------------------------------------
    unconstrained: bool = (allowed_start, allowed_end, deadend_start, deadend_end) == (
        None,
        None,
        False,
        False,
    )
    if unconstrained:
        positions: Int[np.int8, "2 2"]
        try:
            # two distinct indices into the connected component
            picked = np.random.choice(
                len(connected_component),
                size=2,
                replace=False,
            )
            positions = connected_component[picked]  # type: ignore[assignment]
        except ValueError as e:
            if except_on_no_valid_endpoint:
                err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
                raise NoValidEndpointException(
                    err_msg,
                ) from e
            return None

        return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

    # ------------------------------------------------------------
    # constrained case: build candidate sets for each endpoint
    # ------------------------------------------------------------
    def _is_deadend(coord: CoordTup) -> bool:
        "a deadend is a coord with exactly one connected neighbor"
        return len(self.get_coord_neighbors(coord)) == 1

    connected_component_set: set[CoordTup] = set(map(tuple, connected_component))

    # restrict to explicitly allowed positions, if given
    # '# type: ignore[assignment]' here because the returned tuple can be of any length
    allowed_start_set: set[CoordTup]
    if allowed_start is None:
        allowed_start_set = connected_component_set.copy()
    else:
        allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

    allowed_end_set: set[CoordTup]
    if allowed_end is None:
        allowed_end_set = connected_component_set.copy()
    else:
        allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

    # restrict to deadends, if requested
    if deadend_start:
        allowed_start_set = set(filter(_is_deadend, allowed_start_set))
    if deadend_end:
        allowed_end_set = set(filter(_is_deadend, allowed_end_set))

    # bail out if either candidate set is empty
    if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            )
        return None

    # randomly select start and end positions
    try:
        # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
        start_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
        )
        if endpoints_not_equal:
            # the start position may no longer serve as the end position
            allowed_end_set.discard(start_pos)
        end_pos: CoordTup = tuple(  # type: ignore[assignment]
            list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
        )
    except ValueError as e:
        if except_on_no_valid_endpoint:
            err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
            raise NoValidEndpointException(
                err_msg,
            ) from e
        return None

    return self.find_shortest_path(start_pos, end_pos)

555 

556 # ============================================================ 

557 # to and from adjacency list 

558 # ============================================================ 

def as_adj_list(
    self,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end coord"]:
    """Return the maze as an adjacency list.

    Thin wrapper: forwards `connection_list` and both shuffle flags to
    `maze_dataset.token_utils.connection_list_to_adj_list`.
    """
    adj_list = connection_list_to_adj_list(
        self.connection_list,
        shuffle_d0,
        shuffle_d1,
    )
    return adj_list

566 

@classmethod
def from_adj_list(
    cls,
    adj_list: Int8[np.ndarray, "conn start_end coord"],
) -> "LatticeMaze":
    """Build a `LatticeMaze` from an array of `(start, end)` coordinate pairs.

    > [!NOTE]
    > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

    # Raises:
    - `ValueError` : if the two coords of a connection do not agree in exactly one coordinate
    """
    # the grid is taken to be square, sized by the largest coordinate seen
    # (this is where it would probably break for rectangular mazes)
    grid_n: int = adj_list.max() + 1

    connection_list: ConnectionList = np.zeros(
        (2, grid_n, grid_n),
        dtype=np.bool_,
    )

    for edge_start, edge_end in adj_list:
        # exactly one coordinate must match between the two ends
        n_matching = (edge_start == edge_end).sum()
        if n_matching != 1:
            raise ValueError("invalid connection")

        # the direction is the axis along which the two ends differ
        d: int = (edge_start != edge_end).argmax()

        # the connection is stored at whichever end has the lesser value along `d`
        anchor = edge_start if edge_start[d] < edge_end[d] else edge_end
        x, y = anchor
        connection_list[d, x, y] = True

    return LatticeMaze(
        connection_list=connection_list,
    )

606 

def as_adj_list_tokens(self) -> list[str | CoordTup]:
    """(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead

    Emits a `TokenizerDeprecationWarning` and then returns the result of
    `_as_adj_list_tokens`, which builds the same token list without the warning.
    """
    warnings.warn(
        "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
        TokenizerDeprecationWarning,
        # attribute the warning to the caller rather than to this method
        stacklevel=2,
    )
    return self._as_adj_list_tokens()

628 

def _as_adj_list_tokens(self) -> list[str | CoordTup]:
    """Turn the maze into adjacency list tokens, keeping coords as tuples.

    The list starts with `ADJLIST_START`, then holds
    `start, CONNECTOR, end, ADJACENCY_ENDLINE` for every connection returned by
    `as_adj_list()`, and ends with `ADJLIST_END`.
    """
    tokens: list[str | CoordTup] = [SPECIAL_TOKENS.ADJLIST_START]
    for c_s, c_e in self.as_adj_list():
        tokens.extend(
            (
                tuple(c_s),
                SPECIAL_TOKENS.CONNECTOR,
                tuple(c_e),
                SPECIAL_TOKENS.ADJACENCY_ENDLINE,
            ),
        )
    tokens.append(SPECIAL_TOKENS.ADJLIST_END)
    return tokens

645 

def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
    """Turn the maze into adjacency list, origin, target, and solution tokens -- coords stay tuples.

    Origin and target tokens are appended only for a `TargetedLatticeMaze`,
    solution tokens only for a `SolvedMaze`.
    """
    output: list[CoordTup | str] = self._as_adj_list_tokens()

    if isinstance(self, TargetedLatticeMaze):
        # origin first, then target
        output += self._get_start_pos_tokens()
        output += self._get_end_pos_tokens()

    if isinstance(self, SolvedMaze):
        output += self._get_solution_tokens()

    return output

657 

def _as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode",
) -> list[str]:
    """Serialize the maze to string tokens with a legacy tokenizer.

    A `TokenizationMode` is first converted via `to_legacy_tokenizer()`.

    # Raises:
    - `NotImplementedError` : if the tokenizer is not a `MazeTokenizer` for which `is_AOTP()` holds
    """
    # type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

    supported: bool = (
        isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
        and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
    )
    if not supported:
        err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
        raise NotImplementedError(err_msg)

    coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
    # non-coordinate (special) tokens are passed through unchanged
    return maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
        coords=coords_raw,
        when_noncoord="include",
    )

678 

def as_tokens(
    self,
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> list[str]:
    """Serialize the maze (and solution, where present) to tokens.

    A `MazeTokenizerModular` does the conversion itself via `to_tokens`;
    any other tokenizer is handled by `_as_tokens`.
    """
    if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
        return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
    return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

688 

@classmethod
def _from_tokens_AOTP(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Build a maze from a list of tokens.

    The result is a `TargetedLatticeMaze` when all origin and target marker
    tokens are present, and a `SolvedMaze` when the path markers are present
    as well; otherwise it is whatever `cls.from_adj_list` returns.
    """
    # ========================================
    # figure out what input format
    # ========================================
    if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
        adj_list_tokens = get_adj_list_tokens(tokens)
    else:
        # not a "complete" tokenized maze: treat the input as bare adjacency list tokens
        adj_list_tokens = tokens
        warnings.warn(
            "Assuming input is just adjacency list tokens, no special tokens found",
        )

    # ========================================
    # process edges for adjacency list
    # ========================================
    edges: list[list[str]] = list_split(
        adj_list_tokens,
        SPECIAL_TOKENS.ADJACENCY_ENDLINE,
    )

    coordinates: list[tuple[CoordTup, CoordTup]] = list()
    for e in edges:
        # skip last endline
        if len(e) == 0:
            continue

        # convert to coords, keeping the connector token in place
        e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
            e,
            when_noncoord="include",
        )
        # this assertion depends on the tokenizer having exactly one token for the connector
        # which is also why we "include" above
        # the connector token is discarded below
        assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
        assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
            f"invalid edge: {e = } {e_coords = }"
        )
        edge_pair: tuple[CoordTup, CoordTup] = (e_coords[0], e_coords[-1])  # type: ignore[assignment]
        coordinates.append(edge_pair)

    assert all(len(c) == DIM_2 for c in coordinates), (
        f"invalid coordinates: {coordinates = }"
    )
    adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
    expected_shape = (len(coordinates), 2, 2)
    assert tuple(adj_list.shape) == expected_shape, (
        f"invalid adj_list: {adj_list.shape = } {coordinates = }"
    )

    output_maze: LatticeMaze = cls.from_adj_list(adj_list)

    # ========================================
    # add start and end positions
    # ========================================
    endpoint_markers = (
        SPECIAL_TOKENS.ORIGIN_START,
        SPECIAL_TOKENS.ORIGIN_END,
        SPECIAL_TOKENS.TARGET_START,
        SPECIAL_TOKENS.TARGET_END,
    )
    is_targeted: bool = all(x in tokens for x in endpoint_markers)
    if is_targeted:
        start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_origin_tokens(tokens),
            when_noncoord="error",
        )
        end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_target_tokens(tokens),
            when_noncoord="error",
        )
        assert len(start_pos_list) == 1, (
            f"invalid start_pos_list: {start_pos_list = }"
        )
        assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

        output_maze = TargetedLatticeMaze.from_lattice_maze(
            lattice_maze=output_maze,
            start_pos=start_pos_list[0],
            end_pos=end_pos_list[0],
        )

    # ========================================
    # add the solution
    # ========================================
    path_markers = (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
    if all(x in tokens for x in path_markers):
        assert is_targeted, "maze must be targeted to have a solution"
        solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
            get_path_tokens(tokens, trim_end=True),
            when_noncoord="error",
        )
        output_maze = SolvedMaze.from_targeted_lattice_maze(
            # HACK: I think this is fine, but im not sure
            targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
            solution=solution,
        )

    return output_maze

797 

798 # TODO: any way to get return type hinting working for this? 

@classmethod
def from_tokens(
    cls,
    tokens: list[str],
    maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
    """Constructs a maze from a tokenization.

    Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
    A `TokenizationMode` is converted via `to_legacy_tokenizer()`, and `tokens`
    given as a single string is split on whitespace.

    # Raises:
    - `NotImplementedError` : for a `MazeTokenizerModular` that is not legacy-equivalent,
        or for a tokenizer where `is_AOTP()` is false
    """
    # HACK: type ignores here fine since we check the instance
    if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
        maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

    is_modular: bool = isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
    if is_modular and not maze_tokenizer.is_legacy_equivalent():  # type: ignore[union-attr]
        err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
        raise NotImplementedError(
            err_msg,
        )

    if isinstance(tokens, str):
        tokens = tokens.split()

    if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
        raise NotImplementedError("only AOTP tokenization is supported")
    return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]

828 

829 # ============================================================ 

830 # to and from pixels 

831 # ============================================================ 

def _as_pixels_bw(self) -> BinaryPixelGrid:
    """Render the maze as a boolean pixel grid (`True` = open, `False` = wall).

    The grid has shape `(2 * rows + 1, 2 * cols + 1)`: node `(i, j)` sits at
    pixel `(2i + 1, 2j + 1)`, the pixel below it is open iff `connection_list[0, i, j]`,
    and the pixel to its right is open iff `connection_list[1, i, j]`.

    Raises `AssertionError` if `lattice_dim` is not 2.
    """
    assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
    # start with every pixel as wall
    pixel_grid: BinaryPixelGrid = np.zeros(
        (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
        dtype=np.bool_,
    )

    # nodes are always open
    pixel_grid[1::2, 1::2] = True

    # downward connections: pixel (2i + 2, 2j + 1) for node (i, j)
    pixel_grid[2::2, 1::2] = self.connection_list[0]

    # rightward connections: pixel (2i + 1, 2j + 2) for node (i, j)
    pixel_grid[1::2, 2::2] = self.connection_list[1]

    return pixel_grid

857 

def as_pixels(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> PixelGrid:
    """convert the maze to a pixel grid

    - useful as a simpler way of plotting the maze than the more complex `MazePlot`
    - the same underlying representation as `as_ascii` but as an image
    - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data

    # Parameters:
    - `show_endpoints : bool`
        whether to color the start and end positions (only for mazes that have them)
        (defaults to `True`)
    - `show_solution : bool`
        whether to color the solution path (only for solved mazes)
        (defaults to `True`)

    # Returns:
    - `PixelGrid`
        `uint8` rgb array built from `_as_pixels_bw()`

    # Raises:
    - `ValueError` : if `show_solution` is set without `show_endpoints`
    """
    # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
    # but solution, start_pos, end_pos not always defined
    # but its fine since we explicitly check the type
    if show_solution and not show_endpoints:
        raise ValueError("show_solution=True requires show_endpoints=True")

    # convert original bool pixel grid to RGB
    pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
    pixel_grid: PixelGrid = np.full(
        (*pixel_grid_bw.shape, 3),
        PixelColors.WALL,
        dtype=np.uint8,
    )
    pixel_grid[pixel_grid_bw] = PixelColors.OPEN

    def _paint_endpoints() -> None:
        "color the pixels of the start and end nodes"
        pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.START
        )
        pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
            PixelColors.END
        )

    # exact class comparisons are deliberate: subclasses fall through to the later branches
    if self.__class__ == LatticeMaze:
        return pixel_grid

    # set endpoints for TargetedLatticeMaze
    if self.__class__ == TargetedLatticeMaze:
        if show_endpoints:
            _paint_endpoints()
        return pixel_grid

    # set solution -- we only reach this part if `self.__class__ == SolvedMaze`
    if show_solution:
        for coord in self.solution:  # type: ignore[attr-defined]
            pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

        # set pixels between coords
        for coord, next_coord in zip(self.solution[:-1], self.solution[1:]):  # type: ignore[attr-defined]
            # check they are adjacent using norm
            assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
                f"Coords {coord} and {next_coord} are not adjacent"
            )
            # set pixel between them
            pixel_grid[
                coord[0] * 2 + 1 + next_coord[0] - coord[0],
                coord[1] * 2 + 1 + next_coord[1] - coord[1],
            ] = PixelColors.PATH

    # set endpoints last, since the path would otherwise overwrite them;
    # honor `show_endpoints` here as well, as in the targeted-maze branch above
    if show_endpoints:
        _paint_endpoints()

    return pixel_grid

924 

@classmethod
def _from_pixel_grid_bw(
    cls,
    pixel_grid: BinaryPixelGrid,
) -> tuple[ConnectionList, tuple[int, int]]:
    """read a connection list and lattice shape out of a boolean pixel grid

    # Parameters:
    - `pixel_grid : BinaryPixelGrid`
        2D boolean grid; each axis is halved (floor division) to get the lattice shape

    # Returns:
    - `tuple[ConnectionList, tuple[int, int]]`
        boolean array of shape `(2, rows, cols)` and the `(rows, cols)` lattice shape
    """
    n_rows: int = pixel_grid.shape[0] // 2
    n_cols: int = pixel_grid.shape[1] // 2
    connections: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=np.bool_)

    # downward connections: even pixel rows (from 2), odd pixel columns
    connections[0, :, :] = pixel_grid[2::2, 1::2]
    # rightward connections: odd pixel rows, even pixel columns (from 2)
    connections[1, :, :] = pixel_grid[1::2, 2::2]

    return connections, (n_rows, n_cols)

943 

@classmethod
def _from_pixel_grid_with_positions(
    cls,
    pixel_grid: PixelGrid | BinaryPixelGrid,
    marked_positions: dict[str, RGB],
) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
    """get the connection list, lattice shape, and coordinates of specially colored pixels

    # Parameters:
    - `pixel_grid : PixelGrid | BinaryPixelGrid`
        grid to read; compared against colors along its last axis
    - `marked_positions : dict[str, RGB]`
        maps an output key to the color to search for

    # Returns:
    - `tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]`
        connection list, lattice shape, and for each key an array of the lattice
        coordinates whose pixel has that key's color.
        only pixels at odd row *and* odd column are kept (the rest are connections, not cells)
    """
    # every pixel that is not wall-colored counts as open
    wall_pixels = np.all(pixel_grid == PixelColors.WALL, axis=-1)
    # type ignore: mypy thinks `~np.all(...)` might be a scalar `numpy.bool` instead of an array
    pixel_grid_bw: BinaryPixelGrid = ~wall_pixels  # type: ignore[assignment]

    connection_list: ConnectionList
    grid_shape: tuple[int, int]
    connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

    # collect lattice coordinates for each requested color
    out_positions: dict[str, CoordArray] = {}
    for key, color in marked_positions.items():
        matches: Int[np.ndarray, "x y"] = np.argwhere(
            np.all(pixel_grid == color, axis=-1),
        )
        cell_coords: list[CoordTup] = [
            (px[0] // 2, px[1] // 2)
            for px in matches
            if px[0] % 2 == 1 and px[1] % 2 == 1
        ]
        out_positions[key] = np.array(cell_coords)

    return connection_list, grid_shape, out_positions

977 

@classmethod
def from_pixels(
    cls,
    pixel_grid: PixelGrid,
) -> "LatticeMaze":
    """create a LatticeMaze from a pixel grid. reverse of `as_pixels`

    - a 2D (binary) grid always gives a plain `LatticeMaze`, regardless of `cls`
    - otherwise the maze type is detected from the colors present via `detect_pixels_type`,
      and the result is built as `cls` (which must be the detected type or one of its bases)

    # Parameters:
    - `pixel_grid : PixelGrid`
        pixel grid to convert

    # Returns:
    - `LatticeMaze`
        a `LatticeMaze`, `TargetedLatticeMaze`, or an instance of `cls` built from a
        connection list and an ordered solution

    # Raises:
    - `ValueError` : if `cls` is not the detected type or one of the classes in its MRO
    - `AssertionError` : if there is not exactly one start and one end cell, if the path
        cells have an unexpected shape, or if the path cannot be traced unambiguously
        from start to end
    """
    connection_list: ConnectionList
    grid_shape: tuple[int, int]

    # if a binary pixel grid, return regular LatticeMaze
    if len(pixel_grid.shape) == 2:  # noqa: PLR2004
        connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
        return LatticeMaze(connection_list=connection_list)

    # otherwise, detect and check it's valid
    cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
    if cls not in cls_detected.__mro__:
        err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
        raise ValueError(
            err_msg,
        )

    (
        connection_list,
        grid_shape,
        marked_pos,
    ) = cls._from_pixel_grid_with_positions(
        pixel_grid=pixel_grid,
        marked_positions=dict(
            start=PixelColors.START,
            end=PixelColors.END,
            solution=PixelColors.PATH,
        ),
    )
    # if we wanted a LatticeMaze, return it
    if cls == LatticeMaze:
        return LatticeMaze(connection_list=connection_list)

    # otherwise, keep going
    # plain maze used only to look up connected neighbors while ordering the path
    temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

    # start and end pos
    start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
    assert start_pos_arr.shape == (
        1,
        2,
    ), (
        f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )
    assert end_pos_arr.shape == (
        1,
        2,
    ), (
        f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
    )

    start_pos: Coord = start_pos_arr[0]
    end_pos: Coord = end_pos_arr[0]

    # return a TargetedLatticeMaze if that's what we wanted
    if cls == TargetedLatticeMaze:
        return TargetedLatticeMaze(
            connection_list=connection_list,
            start_pos=start_pos,
            end_pos=end_pos,
        )

    # raw solution, only contains path elements and not start or end
    solution_raw: CoordArray = marked_pos["solution"]
    if len(solution_raw.shape) == 2:  # noqa: PLR2004
        assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
            f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
        )
    elif solution_raw.shape == (0,):
        # the solution and end should be immediately adjacent
        assert np.sum(np.abs(start_pos - end_pos)) == 1, (
            f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
        )

    # order the solution, by creating a list from the start to the end
    # add end pos, since we will iterate over all these starting from the start pos
    end_tup: CoordTup = tuple(end_pos)
    solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
        end_tup,
    ]
    # hashed copies of the candidate and visited cells, so the membership tests
    # inside the loop are O(1) rather than a list scan on every neighbor
    candidates: set[CoordTup] = set(solution_raw_list)
    # solution starts with start point
    solution: list[CoordTup] = [tuple(start_pos)]
    visited: set[CoordTup] = set(solution)
    while solution[-1] != end_tup:
        # use `get_coord_neighbors` to find connected neighbors
        neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
        # TODO: make this less ugly
        assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
            f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
        )
        # neighbors = neighbors[:, [1, 0]]
        # filter out neighbors that are not in the raw solution, or were already visited
        neighbors_filtered: CoordArray = np.array(
            [
                coord
                for coord in neighbors
                if (tuple(coord) in candidates and tuple(coord) not in visited)
            ],
        )
        # assert only one element is left, and then add it to the solution
        assert neighbors_filtered.shape == (
            1,
            2,
        ), (
            f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
        )
        next_coord: CoordTup = tuple(neighbors_filtered[0])
        solution.append(next_coord)
        visited.add(next_coord)

    # assert the solution is complete
    assert solution[0] == tuple(start_pos), (
        f"solution {solution} does not start at start_pos {start_pos}"
    )
    assert solution[-1] == end_tup, (
        f"solution {solution} does not end at end_pos {end_pos}"
    )

    return cls(
        connection_list=np.array(connection_list),
        solution=np.array(solution),  # type: ignore[call-arg]
    )

1108 

1109 # ============================================================ 

1110 # to and from ASCII 

1111 # ============================================================ 

def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
    """character grid of the bare maze: `AsciiChars.WALL` everywhere except open pixels

    built from `self._as_pixels_bw()`; has the same shape as that boolean grid
    """
    bw_pixels: Bool[np.ndarray, "x y"] = self._as_pixels_bw()

    # start from an all-wall canvas, then carve out the open pixels
    char_grid: Shaped[np.ndarray, "x y"] = np.full(
        bw_pixels.shape,
        AsciiChars.WALL,
        dtype=str,
    )
    open_mask = bw_pixels == True  # noqa: E712
    char_grid[open_mask] = AsciiChars.OPEN

    return char_grid

1125 

def as_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """return an ASCII grid of the maze

    useful for debugging in the terminal, or as its own format

    can be reversed with `LatticeMaze.from_ascii()`

    # Parameters:
    - `show_endpoints : bool`
        overlay the start and end characters (defaults to `True`)
    - `show_solution : bool`
        overlay the path character (defaults to `True`)

    # Returns:
    - `str`
        rows of the character grid joined by newlines
    """
    ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
    # the colored pixel grid tells us where the special characters go
    pixel_grid: PixelGrid = self.as_pixels(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )

    # which special characters are we allowed to draw over the bare maze
    overlay_chars: list = []
    if show_endpoints:
        overlay_chars.extend((AsciiChars.START, AsciiChars.END))
    if show_solution:
        overlay_chars.append(AsciiChars.PATH)

    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        if ascii_char not in overlay_chars:
            continue
        color_mask = (pixel_grid == pixel_color).all(axis=-1)
        ascii_grid[color_mask] = ascii_char

    return "\n".join("".join(row) for row in ascii_grid)

1154 

@classmethod
def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
    """get a `LatticeMaze` from an ASCII representation (reverses `LatticeMaze.as_ascii`)

    the string is stripped, split into lines, and each line is stripped again.
    every character listed in `ASCII_PIXEL_PAIRINGS` is mapped to its color, and the
    resulting pixel grid is passed to `cls.from_pixels`
    """
    char_rows: list[list[str]] = [
        list(line.strip()) for line in ascii_str.strip().split("\n")
    ]
    ascii_grid: Shaped[np.ndarray, "x y"] = np.array(char_rows, dtype=str)

    # pixels whose character has no pairing stay at zero
    pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
    for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
        char_mask = ascii_grid == ascii_char
        pixel_grid[char_mask] = pixel_color

    return cls.from_pixels(pixel_grid)

1170 

1171 

# type ignore here even though theyre all frozen
# maybe `SerializeableDataclass` itself is not frozen, but thats an ABC
# error: Cannot inherit frozen dataclass from a non-frozen one  [misc]
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
    """A LatticeMaze with a start and end position"""

    # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
    # type ignore here because even though its a kw-only dataclass,
    # mypy doesn't like that non-default arguments are after default arguments
    # coordinate of the start cell; stored as a numpy array after `__post_init__`
    start_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )
    # coordinate of the end cell; stored as a numpy array after `__post_init__`
    end_pos: Coord = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __post_init__(self) -> None:
        """post init converts start and end pos to numpy arrays, checks they exist and are in bounds

        # Raises:
        - `AssertionError` : if `start_pos` or `end_pos` is `None`
        - `ValueError` : if either coordinate is negative or not smaller than `grid_shape` along an axis
        """
        # check for `None` *before* converting: `np.array(None)` is a 0-d object array,
        # which is never `None` and would only fail later with an opaque `IndexError`
        assert self.start_pos is not None
        assert self.end_pos is not None
        # make things numpy arrays (very jank to override frozen dataclass)
        self.__dict__["start_pos"] = np.array(self.start_pos)
        self.__dict__["end_pos"] = np.array(self.end_pos)
        # check that start and end are in bounds
        # negative values are rejected too, they would silently index from the far edge
        if (
            self.start_pos[0] < 0
            or self.start_pos[1] < 0
            or self.start_pos[0] >= self.grid_shape[0]
            or self.start_pos[1] >= self.grid_shape[1]
        ):
            err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
            raise ValueError(
                err_msg,
            )
        if (
            self.end_pos[0] < 0
            or self.end_pos[1] < 0
            or self.end_pos[0] >= self.grid_shape[0]
            or self.end_pos[1] >= self.grid_shape[1]
        ):
            err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
            raise ValueError(
                err_msg,
            )

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def _get_start_pos_tokens(self) -> list[str | CoordTup]:
        "start position as a tuple, wrapped in the `ORIGIN_START` / `ORIGIN_END` special tokens"
        return [
            SPECIAL_TOKENS.ORIGIN_START,
            tuple(self.start_pos),
            SPECIAL_TOKENS.ORIGIN_END,
        ]

    def get_start_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the start position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_start_pos_tokens()

    def _get_end_pos_tokens(self) -> list[str | CoordTup]:
        "end position as a tuple, wrapped in the `TARGET_START` / `TARGET_END` special tokens"
        return [
            SPECIAL_TOKENS.TARGET_START,
            tuple(self.end_pos),
            SPECIAL_TOKENS.TARGET_END,
        ]

    def get_end_pos_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the end position as a list of tokens"
        warnings.warn(
            "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
            TokenizerDeprecationWarning,
        )
        return self._get_end_pos_tokens()

    @classmethod
    def from_lattice_maze(
        cls,
        lattice_maze: LatticeMaze,
        start_pos: Coord | CoordTup,
        end_pos: Coord | CoordTup,
    ) -> "TargetedLatticeMaze":
        "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
        # connection list and generation metadata are passed through from the given maze
        return cls(
            connection_list=lattice_maze.connection_list,
            start_pos=np.array(start_pos),
            end_pos=np.array(end_pos),
            generation_meta=lattice_maze.generation_meta,
        )

1262 

1263 

@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
    """Stores a maze and a solution"""

    # ordered coordinates of the path; its first and last rows are used as start and end pos
    solution: CoordArray = serializable_field(  # type: ignore[misc]
        assert_type=False,
    )

    def __init__(
        self,
        connection_list: ConnectionList,
        solution: CoordArray,
        generation_meta: dict | None = None,
        start_pos: Coord | None = None,
        end_pos: Coord | None = None,
        allow_invalid: bool = False,
    ) -> None:
        """Create a SolvedMaze from a connection list and a solution

        the start and end positions are taken from the first and last coordinate of `solution`.

        # Parameters:
        - `connection_list : ConnectionList`
            connections of the maze, passed on to the parent init
        - `solution : CoordArray`
            path through the maze; valid if it converts to a non-empty array of shape `(n, 2)`
        - `generation_meta : dict | None`
            passed on to the parent init (defaults to `None`)
        - `start_pos : Coord | None`
            if given, only checked against the first coordinate of the solution (defaults to `None`)
        - `end_pos : Coord | None`
            if given, only checked against the last coordinate of the solution (defaults to `None`)
        - `allow_invalid : bool`
            if `True`, an invalid solution does not raise here and the `start_pos`/`end_pos`
            consistency checks are skipped (defaults to `False`)

        # Raises:
        - `ValueError` : if the solution is invalid (including `None`) and `allow_invalid` is `False`
        - `AssertionError` : if a given `start_pos`/`end_pos` does not match the solution
        """
        # figure out the solution
        solution_valid: bool = False
        if solution is not None:
            solution = np.array(solution)
            # note that a path length of 1 here is valid, since the start and end pos could be the same
            # `ndim` is checked first, so anything that is not 2D is reported as invalid
            # instead of raising an `IndexError` on `shape[1]`
            if (
                solution.ndim == 2  # noqa: PLR2004
                and solution.shape[0] > 0
                and solution.shape[1] == 2  # noqa: PLR2004
            ):
                solution_valid = True

        if not solution_valid and not allow_invalid:
            # `solution` may be `None` here, which has no `.shape`
            solution_shape = getattr(solution, "shape", None)
            err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
            raise ValueError(
                err_msg,
                f"{connection_list = }",
            )

        # init the TargetedLatticeMaze
        # NOTE(review): with an invalid solution and `allow_invalid=True`, `None` is passed as
        # start/end pos -- confirm the parent post-init accepts that
        super().__init__(
            connection_list=connection_list,
            generation_meta=generation_meta,
            # TODO: the argument type is stricter than the expected type but it still fails?
            # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
            # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
            start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
            end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
        )

        # write through `__dict__` since the dataclass is frozen
        self.__dict__["solution"] = solution

        # adjust the endpoints
        if not allow_invalid:
            if start_pos is not None:
                assert np.array_equal(np.array(start_pos), self.start_pos), (
                    f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
                )
            if end_pos is not None:
                assert np.array_equal(np.array(end_pos), self.end_pos), (
                    f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
                )
            # TODO: assert the path does not backtrack, walk through walls, etc?

    def __eq__(self, other: object) -> bool:
        "check equality, calls parent class equality check"
        return super().__eq__(other)

    def __hash__(self) -> int:
        "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
        return hash((self.connection_list.tobytes(), self.solution.tobytes()))

    def _get_solution_tokens(self) -> list[str | CoordTup]:
        "solution coordinates as tuples, wrapped in the `PATH_START` / `PATH_END` special tokens"
        return [
            SPECIAL_TOKENS.PATH_START,
            *[tuple(c) for c in self.solution],
            SPECIAL_TOKENS.PATH_END,
        ]

    def get_solution_tokens(self) -> list[str | CoordTup]:
        "(deprecated!) return the solution as a list of tokens"
        warnings.warn(
            "`LatticeMaze.get_solution_tokens` is deprecated.",
            TokenizerDeprecationWarning,
        )
        return self._get_solution_tokens()

    # for backwards compatibility
    @property
    def maze(self) -> LatticeMaze:
        "(deprecated!) return the maze without the solution"
        warnings.warn(
            "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
            DeprecationWarning,
        )
        return LatticeMaze(connection_list=self.connection_list)

    # type ignore here since we're overriding a method with a different signature
    @classmethod
    def from_lattice_maze(  # type: ignore[override]
        cls,
        lattice_maze: LatticeMaze,
        solution: list[CoordTup] | CoordArray,
    ) -> "SolvedMaze":
        "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
        return cls(
            connection_list=lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=lattice_maze.generation_meta,
        )

    @classmethod
    def from_targeted_lattice_maze(
        cls,
        targeted_lattice_maze: TargetedLatticeMaze,
        solution: list[CoordTup] | CoordArray | None = None,
    ) -> "SolvedMaze":
        """solves the given targeted lattice maze and returns a SolvedMaze

        if no `solution` is given, it is computed with `find_shortest_path`
        between the maze's `start_pos` and `end_pos`
        """
        if solution is None:
            solution = targeted_lattice_maze.find_shortest_path(
                targeted_lattice_maze.start_pos,
                targeted_lattice_maze.end_pos,
            )
        return cls(
            connection_list=targeted_lattice_maze.connection_list,
            solution=np.array(solution),
            generation_meta=targeted_lattice_maze.generation_meta,
        )

    def get_solution_forking_points(
        self,
        always_include_endpoints: bool = False,
    ) -> tuple[list[int], CoordArray]:
        """coordinates and their indices from the solution where a fork is present

        - if the start point is not a dead end, this counts as a fork
        - if the end point is not a dead end, this counts as a fork

        # Parameters:
        - `always_include_endpoints : bool`
            if `True`, the first and last solution coordinates are always included
            (defaults to `False`)

        # Returns:
        - `tuple[list[int], CoordArray]`
            indices into the solution, and the coordinates at those indices
        """
        output_idxs: list[int] = list()
        output_coords: list[CoordTup] = list()

        for idx, coord in enumerate(self.solution):
            # more than one choice for first coord, or more than 2 for any other
            # since the previous coord doesn't count as a choice
            is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
            threshold: int = 1 if is_endpoint else 2
            if self.get_coord_neighbors(coord).shape[0] > threshold or (
                is_endpoint and always_include_endpoints
            ):
                output_idxs.append(idx)
                output_coords.append(coord)

        return output_idxs, np.array(output_coords)

    def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
        """coordinates from the solution where there is only a single (non-backtracking) point to move to

        returns the complement of `get_solution_forking_points` from the path
        """
        forks_idxs, _ = self.get_solution_forking_points()
        # HACK: idk why type ignore here
        # (the first element is built with `np.delete`, so it is an array and not a `list[int]`)
        return (  # type: ignore[return-value]
            np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
            np.delete(self.solution, forks_idxs, axis=0),
        )

1426 

1427 

def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
    """Detects the type of pixels data by checking for the presence of start and end pixels

    - no start or end color present: `LatticeMaze`
    - start or end color present, and a path color present: `SolvedMaze`
    - start or end color present, but no path color: `TargetedLatticeMaze`
    """
    has_endpoint: bool = color_in_pixel_grid(
        data,
        PixelColors.START,
    ) or color_in_pixel_grid(data, PixelColors.END)

    if not has_endpoint:
        return LatticeMaze
    if color_in_pixel_grid(data, PixelColors.PATH):
        return SolvedMaze
    return TargetedLatticeMaze

1440 

1441 

def _remove_isolated_cells(
    image: Int[np.ndarray, "RGB x y"],
) -> Int[np.ndarray, "RGB x y"]:
    """Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.

    returns a copy of `image` in which every non-wall pixel whose four direct
    neighbors are all walls is recolored to `PixelColors.WALL`. pixels outside
    the image count as walls. the input is not modified.
    """
    # NOTE(review): colors are compared along the *last* axis here, while the
    # annotation names the axes "RGB x y" -- confirm which layout is intended
    is_wall = np.all(image == PixelColors.WALL, axis=-1)

    # surround the mask with a one-pixel border of walls so that edge pixels have four neighbors
    padded = np.pad(
        is_wall,
        ((1, 1), (1, 1)),
        mode="constant",
        constant_values=True,
    )

    # shifted views of the padded mask: the neighbor of each pixel in every direction
    neighbor_is_wall = (
        padded[1:-1, 2:],  # right
        padded[1:-1, :-2],  # left
        padded[2:, 1:-1],  # down
        padded[:-2, 1:-1],  # up
    )
    walled_in = np.logical_and.reduce(neighbor_is_wall)

    # only open pixels can be "isolated"; walls stay as they are
    isolated = walled_in & ~is_wall

    cleaned = image.copy()
    cleaned[isolated] = PixelColors.WALL
    return cleaned

1474 

1475 

# Pad widths for each direction, as `((before, after), (before, after))` for the first two axes
# (only referenced by the commented-out `_remove_isolated_cells_old` below)
_RIC_PADS: dict = dict(
    left=((1, 0), (0, 0)),
    right=((0, 1), (0, 0)),
    up=((0, 0), (1, 0)),
    down=((0, 0), (0, 1)),
)

# Slices for each direction, as a `(first axis, second axis)` pair;
# each one drops a single row or column from one edge
_RIC_SLICES: dict = dict(
    left=(slice(1, None), slice(None)),
    right=(slice(None, -1), slice(None)),
    up=(slice(None), slice(1, None)),
    down=(slice(None), slice(None, -1)),
)

1490 

1491 

1492# TODO: figure out why this function doesnt work, or maybe just get rid of it 

1493# def _remove_isolated_cells_old( 

1494# image: Int[np.ndarray, "RGB x y"], 

1495# ) -> Int[np.ndarray, "RGB x y"]: 

1496# """ 

1497# Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides. 

1498# """ 

1499# warnings.warn("this function doesn't work and I have no idea why!!!") 

1500# masks: dict[str, np.ndarray] = { 

1501# d: np.all( 

1502# np.pad( 

1503# image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL, 

1504# np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8), 

1505# mode="constant", 

1506# constant_values=True, 

1507# ), 

1508# axis=2, 

1509# ) 

1510# for d in _RIC_SLICES.keys() 

1511# } 

1512 

1513# # Create a mask for non-wall cells 

1514# mask_non_wall = np.all(image != PixelColors.WALL, axis=2) 

1515 

1516# # print(f"{mask_non_wall.shape = }") 

1517# # print(f"{ {k: masks[k].shape for k in masks.keys()} = }") 

1518 

1519# # print(f"{mask_non_wall = }") 

1520# # print(f"{masks['down'] = }") 

1521 

1522# # Combine the masks 

1523# mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"] 

1524 

1525# # Apply the mask 

1526# output_image = np.where( 

1527# np.stack([mask] * 3, axis=-1), 

1528# PixelColors.WALL, 

1529# image, 

1530# ) 

1531 

1532# return output_image