Coverage for maze_dataset/generation/generators.py: 83%

206 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`""" 

2 

3import random 

4import warnings 

5from typing import Any, Callable 

6 

7import numpy as np 

8from jaxtyping import Bool 

9 

10from maze_dataset.constants import CoordArray, CoordTup 

11from maze_dataset.generation.seed import GLOBAL_SEED 

12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze 

13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls 

14 

# seed the module-level randomness from the shared global seed.
# the generators below draw from the stdlib `random` module and from the legacy
# global `np.random` state; `numpy_rng` is a separate `Generator` instance that
# is not referenced by the generators in this module.
random.seed(GLOBAL_SEED)
numpy_rng = np.random.default_rng(GLOBAL_SEED)

17 

18 

def _random_start_coord(
    grid_shape: "Coord",
    start_coord: "Coord | CoordTup | None",
) -> "Coord":
    """pick a random start coord within the bounds of `grid_shape` if none is provided

    # Parameters:
     - `grid_shape : Coord`
        shape of the grid (a numpy array, or anything `np.asarray` accepts such as a tuple of ints)
     - `start_coord : Coord | CoordTup | None`
        if given, it is returned as a numpy array. if `None`, a coordinate is drawn
        uniformly from `[0, grid_shape[d])` independently along every dimension `d`

    # Returns:
     - `Coord`
        the start coordinate as a numpy array
    """
    if start_coord is not None:
        return np.array(start_coord)

    shape: np.ndarray = np.asarray(grid_shape)
    # `np.random.randint` samples from the half-open interval [low, high), so the
    # shape itself is the correct upper bound -- using `shape - 1` would make the
    # last row/column unreachable. the clamp to 1 keeps the previous behavior of
    # returning 0 (instead of raising) for a zero-length axis.
    return np.random.randint(
        0,  # lower bound (inclusive)
        np.maximum(shape, 1),  # upper bound (exclusive, at least 1)
        size=len(shape),  # dimensionality
    )

35 

36 

def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

    candidate neighbors are `coord + NEIGHBORS_MASK`; a candidate is kept only when
    every one of its components is `>= 0` and strictly less than the matching entry
    of `grid_shape`. candidates keep the row order of `NEIGHBORS_MASK`.
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK
    above_lower: np.ndarray = np.all(candidates >= 0, axis=1)
    below_upper: np.ndarray = np.all(candidates < grid_shape, axis=1)
    return candidates[above_lower & below_upper]

51 

52 

class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    examples of generated mazes can be found here:
    https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the number of cells in the grid, which never limits generation). if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one (this is what `gen_prim` uses)
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `start_coord`,
            `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            # enforce the full documented range: a negative proportion would
            # otherwise silently yield a maze with only the start cell
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, int)
            n_accessible_cells = accessible_cells

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later
            # and multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord, which counts as visited
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against the TOTAL number of cells, not n_accessible_cells:
                # the maze is only fully connected if every cell of the grid was
                # reached. marking a partial maze as fully connected breaks solving.
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        currently emits a warning and delegates to `gen_dfs` with `randomized_stack=True`;
        all other arguments are forwarded unchanged.
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape: Coord | CoordTup`: the shape of the grid
        - `**kwargs`: accepted only so the signature matches other generators; must be empty

        # Raises
        - `AssertionError` if any keyword arguments are passed
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk
            path: list[Coord] = [walk_start]
            # maps each cell on `path` to its position, for O(1) loop detection
            # (instead of scanning the whole path on every step). cells on `path`
            # are always distinct because loops are erased as soon as they form.
            path_index: dict[tuple[int, ...], int] = {tuple(walk_start.tolist()): 0}
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
                next_key: tuple[int, ...] = tuple(next_cell.tolist())

                # Check for loop
                loop_exit: int | None = path_index.get(next_key)

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # drop everything after the first occurrence of next_cell,
                    # keeping path_index in sync with path
                    for erased in path[loop_exit + 1 :]:
                        del path_index[tuple(erased.tolist())]
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path_index[next_key] = len(path)
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for c_1, c_2 in zip(path[:-1], path[1:]):
                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each connection being open, must be in `[0, 1]` (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` whose `generation_meta["visited_cells"]` is the connected component containing `start_coord`
        """
        assert 0 <= p <= 1, f"p must be between 0 and 1, got {p}"
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # each connection is independently open with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        generates a maze with `gen_dfs` (same `accessible_cells` / `max_tree_depth` / `start_coord`),
        then ORs its connections with an independent percolation mask where each connection
        is open with probability `p`. `generation_meta["visited_cells"]` is recomputed as the
        connected component containing `start_coord` in the combined maze.
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # write through __dict__ to replace the attribute on the existing maze object
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs always sets a dict
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`; only `2` is supported).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
            It does not affect the maze; it is only recorded in `generation_meta`.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize union-find data structure.
        parent: dict[tuple[int, int], tuple[int, int]] = {}

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            """return the root of `cell`'s set, halving the path along the way"""
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
            """merge the sets containing `cell1` and `cell2`"""
            root1 = find(cell1)
            root2 = find(cell2)
            parent[root2] = root1

        # Initialize each cell as its own set.
        for i in range(n_rows):
            for j in range(n_cols):
                parent[(i, j)] = (i, j)

        # List all possible edges as (cell, neighbor, connection_list axis).
        # Rightward edges (cell to its right neighbor) use axis 1:
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
        for i in range(n_rows):
            for j in range(n_cols - 1):
                edges.append(((i, j), (i, j + 1), 1))
        # Downward edges (cell to its bottom neighbor) use axis 0:
        for i in range(n_rows - 1):
            for j in range(n_cols):
                edges.append(((i, j), (i + 1, j), 0))

        # Shuffle the list of edges (uses the module-level, seeded `random`).
        random.shuffle(edges)

        # Initialize connection_list with no connections.
        # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
        # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it connects two different trees, union them and carve the passage.
        for cell1, cell2, direction in edges:
            if find(cell1) != find(cell2):
                union(cell1, cell2)
                if direction == 0:
                    # Downward edge: connection is stored in connection_list[0] at cell1.
                    connection_list[0, cell1[0], cell1[1]] = True
                else:
                    # Rightward edge: connection is stored in connection_list[1] at cell1.
                    connection_list[1, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`; only `2` is supported).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
            It does not affect the maze; it is only recorded in `generation_meta`.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize connection_list as a fully connected grid.
        # Downward connections: every cell (i,j) with i in [0, n_rows-2] is connected.
        # Rightward connections: every cell (i,j) with j in [0, n_cols-2] is connected.
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region starting at column `x`, row `y` with the given width and height.

            Removes connections along the chosen division line except for one randomly chosen gap.
            Regions narrower or shorter than 2 cells are left as straight corridors.
            """
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # Vertical division.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row == gap_row:
                        continue
                    # Remove the rightward connection between (row, wall_col-1) and (row, wall_col).
                    # (defensive bounds check; wall_col <= n_cols - 1 always holds here)
                    if wall_col - 1 < n_cols - 1:
                        connection_list[1, row, wall_col - 1] = False
                # Recurse on the left and right subregions.
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # Horizontal division.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col == gap_col:
                        continue
                    # Remove the downward connection between (wall_row-1, col) and (wall_row, col).
                    # (defensive bounds check; wall_row <= n_rows - 1 always holds here)
                    if wall_row - 1 < n_rows - 1:
                        connection_list[0, wall_row - 1, col] = False
                # Recurse on the top and bottom subregions.
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        # Begin the division on the full grid (x spans columns, y spans rows).
        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

610 

611 

# the mapping is spelled out by hand: populating it automatically from
# `LatticeMazeGenerators` breaks pickling :(
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: mypy reports `Dict entry 1 has incompatible type` [dict-item] here:
    # gen_wilson's signature is `Callable[[Coord | CoordTup, KwArg(Any)], LatticeMaze]`,
    # and mypy does not treat `KwArg(Any)` as matching the positional `Any` above.
    # gen_wilson asserts at runtime that no kwargs were passed, so this is safe.
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
    "gen_kruskal": LatticeMazeGenerators.gen_kruskal,
    "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

_GENERATORS_PERCOLATED: list[str] = ["gen_percolation", "gen_dfs_percolation"]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""

637 

638 

def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    looks up `gen_name` in `GENERATORS_MAP`, calls that generator with `grid_shape`
    and the optional `maze_ctor_kwargs`, then attaches the path returned by
    `maze.generate_random_path()` as the solution of a `SolvedMaze`.

    raises `KeyError` if `gen_name` is not a key of `GENERATORS_MAP`.
    """
    ctor_kwargs: dict = dict() if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments [call-arg]
    # the declared Callable type of GENERATORS_MAP expects a second positional
    # argument, so mypy does not accept the `**kwargs` call form here
    maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
    return SolvedMaze.from_lattice_maze(
        lattice_maze=maze,
        solution=np.array(maze.generate_random_path()),
    )