Coverage for maze_dataset/generation/generators.py: 82%
208 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 22:20 -0700
1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
3import random
4import warnings
5from typing import Callable, Concatenate, ParamSpec
7import numpy as np
8from jaxtyping import Bool
10from maze_dataset.constants import CoordArray, CoordTup
11from maze_dataset.generation.seed import GLOBAL_SEED
12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
# seed the stdlib `random` module (used by e.g. `gen_dfs`, `gen_kruskal`,
# `gen_recursive_division`) with the package-wide seed at import time
random.seed(GLOBAL_SEED)
# module-level numpy Generator, seeded with the same package-wide seed.
# NOTE(review): the generators in this module draw from the legacy global
# `np.random` state rather than from `_NUMPY_RNG`; presumably it is used by
# importers of this module -- confirm.
_NUMPY_RNG: np.random.Generator = np.random.default_rng(seed=GLOBAL_SEED)
19def _random_start_coord(
20 grid_shape: Coord,
21 start_coord: Coord | CoordTup | None,
22) -> Coord:
23 "picking a random start coord within the bounds of `grid_shape` if none is provided"
24 start_coord_: Coord
25 if start_coord is None:
26 start_coord_ = np.random.randint(
27 0, # lower bound
28 np.maximum(grid_shape - 1, 1), # upper bound (at least 1)
29 size=len(grid_shape), # dimensionality
30 )
31 else:
32 start_coord_ = np.array(start_coord)
34 return start_coord_
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the neighbors of `coord` (offsets from `NEIGHBORS_MASK`) that lie inside the grid

    a candidate neighbor is kept only if every component is non-negative and
    strictly less than the matching entry of `grid_shape`.
    """
    candidates: CoordArray = NEIGHBORS_MASK + coord

    above_lower: np.ndarray = np.all(candidates >= 0, axis=1)
    below_upper: np.ndarray = np.all(candidates < grid_shape, axis=1)

    return candidates[above_lower & below_upper]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `@staticmethod` taking `grid_shape` as its first
    argument and returning a `LatticeMaze`.

    examples of generated mazes can be found here:
    https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        *,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**. plain ints and numpy integer scalars are treated as a count.
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells), which never limits the search. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one (this is what `gen_prim` uses)
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` records the arguments used, whether
          every grid cell was reached (`fully_connected`), and the set of reached
          cells (`visited_cells`)

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        # Default values if no constraints have been passed
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )

            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            # plain ints and numpy integer scalars (e.g. values read back from
            # arrays) are both counts; numpy ints are not `int` subclasses
            assert isinstance(accessible_cells, (int, np.integer))
            n_accessible_cells = int(accessible_cells)

        if max_tree_depth is None:
            # We define max tree depth counting from the start coord in two directions.
            # Therefore we divide by two in the if clause for neighboring sites later
            # and multiply by two here.
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )

            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        # choose a random start coord
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # initialize the maze with no connections
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # initialize the stack with the start coord
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        # initialize tree_depth_counter
        current_tree_depth: int = 1

        # loop until the stack is empty or n_accessible_cells is reached
        while stack and (len(visited_cells) < n_accessible_cells):
            # get the current coord from the stack
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # filter neighbors by being within grid bounds and being unvisited
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # don't continue if max_tree_depth/2 is already reached
            # (divide by 2 because we can branch to multiple directions)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # if we want a maze without forks, simply don't add the current coord back to the stack
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                # choose one of the unvisited neighbors
                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # add connection
                dim: int = int(np.argmax(np.abs(delta)))
                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                # add to visited cells and stack
                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                # Update current tree depth
                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against n_total_cells, NOT n_accessible_cells: the maze is
                # only fully connected if every grid cell was reached. treating a
                # partially generated maze as fully connected breaks maze solving.
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        this does not actually implement Prim's algorithm: it warns, then delegates
        to `gen_dfs` with `randomized_stack=True`, forwarding all other arguments.
        see https://github.com/understanding-search/maze-dataset/issues/12
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
            # attribute the warning to the caller, not to this module
            stacklevel=2,
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Arguments
        - `grid_shape : Coord | CoordTup`
            the shape of the grid
        - `**kwargs`
            not supported; passing any raises an `AssertionError`
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        # Initialize grid and visited cells
        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # Choose a random cell and mark it as visited
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # Perform loop-erased random walk from another random cell

            # Choose walk_start only from unvisited cells
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            # Perform the random walk.
            # `path_index` maps each cell on the path to its position in `path`, so
            # loop detection is a dict lookup rather than a scan of the whole path.
            # loops are erased as soon as they form, so the path never contains a
            # cell twice and each cell has exactly one index.
            path: list[Coord] = [walk_start]
            path_index: dict[tuple[int, ...], int] = {tuple(walk_start.tolist()): 0}
            current: Coord = walk_start

            # exit the loop once the current path hits a visited cell
            while not visited[current[0], current[1]]:
                # find a valid neighbor (one always exists on a lattice)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
                next_key: tuple[int, ...] = tuple(next_cell.tolist())

                # Check for loop
                loop_exit: int | None = path_index.get(next_key)

                # erase the loop, or continue the walk
                if loop_exit is not None:
                    # remove everything after the loop start; the loop start itself
                    # (== next_cell) stays as the new end of the path
                    for erased in path[loop_exit + 1 :]:
                        del path_index[tuple(erased.tolist())]
                    path = path[: loop_exit + 1]
                    # reset current cell to end of path
                    current = path[-1]
                else:
                    path_index[next_key] = len(path)
                    path.append(next_cell)
                    current = next_cell

            # Add the path to the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                # find the dimension of the connection
                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # if positive, down/right from current coord
                # if negative, up/left from current coord (down/right from neighbor)
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                visited[c_1[0], c_1[1]] = True
                # we dont add c_2 because the last c_2 will have already been visited

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability that each connection between neighboring cells is open;
            connections leading off the edge of the grid are always closed (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze` with `generation_meta["visited_cells"]` set to the result of
          `gen_connected_component_from(start_coord)` (presumably the cells reachable
          from `start_coord`)
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # each connection is independently open with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        generates a maze with `gen_dfs`, then independently opens each connection
        with probability `p` (connections off the grid edge stay closed) and ORs
        those into the dfs maze's connections.

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: probability that each extra connection is opened (default: `0.4`)
        - `lattice_dim`, `accessible_cells`, `max_tree_depth`: forwarded to `gen_dfs`
        - `start_coord: Coord | None`: start for the dfs and for the reported connected component (default: `None` will give a random start)
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        # generate initial maze via dfs
        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__` rather than attribute assignment
        # NOTE(review): presumably because `LatticeMaze` is frozen -- confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here: gen_dfs always sets a dict
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`; only `2` is supported).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
            It is only recorded in `generation_meta`; it does not influence the maze.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # union-find: every cell starts out as the root of its own set
        parent: dict[tuple[int, int], tuple[int, int]] = {
            (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
        }

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            "return the root of `cell`'s set, halving the path along the way"
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        # List all possible edges as `(cell, neighbor, dim)`, where `dim` is the
        # `connection_list` axis on which the passage is stored, at `cell`:
        # - dim 1: rightward, from (i, j) to (i, j + 1)
        # - dim 0: downward, from (i, j) to (i + 1, j)
        # the order of this list matters for reproducibility, since it is shuffled below
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [
            ((i, j), (i, j + 1), 1) for i in range(n_rows) for j in range(n_cols - 1)
        ]
        edges.extend(
            ((i, j), (i + 1, j), 0) for i in range(n_rows - 1) for j in range(n_cols)
        )

        random.shuffle(edges)

        # Initialize connection_list with no connections.
        # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
        # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)

        # Process each edge; if it joins two different trees, merge them and carve the passage.
        for cell1, cell2, dim in edges:
            root1: tuple[int, int] = find(cell1)
            root2: tuple[int, int] = find(cell2)
            if root1 != root2:
                parent[root2] = root1
                connection_list[dim, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        The division is driven by an explicit stack rather than Python recursion, so
        large grids do not depend on the interpreter's recursion limit.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`; only `2` is supported).
        - `start_coord : Coord | None`
            Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
            It is only recorded in `generation_meta`; it does not influence the maze.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        # Convert grid_shape to a tuple of ints.
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # Initialize connection_list as a fully connected grid.
        # downward connections: every cell (i,j) with i in [0, n_rows-2]
        # rightward connections: every cell (i,j) with j in [0, n_cols-2]
        connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        # Regions are `(x, y, width, height)` with x = first column, y = first row.
        # Each region is split by a wall with a single gap, then its two halves are
        # split in turn. Using a LIFO stack and pushing the second half before the
        # first reproduces the recursive pre-order exactly, so `random` is called in
        # the same sequence as a recursive implementation would call it.
        # Invariant: x + width <= n_cols and y + height <= n_rows, so every wall
        # index below addresses an interior connection.
        regions: list[tuple[int, int, int, int]] = [(0, 0, n_cols, n_rows)]
        while regions:
            x, y, width, height = regions.pop()
            if width < 2 or height < 2:  # noqa: PLR2004
                # 1-wide strips are simple paths: nothing left to divide
                continue

            if width > height:
                # Vertical division: cut between columns wall_col-1 and wall_col.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row != gap_row:
                        connection_list[1, row, wall_col - 1] = False
                # right half pushed first so the left half is processed next
                regions.append((wall_col, y, x + width - wall_col, height))
                regions.append((x, y, wall_col - x, height))
            else:
                # Horizontal division: cut between rows wall_row-1 and wall_row.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col != gap_col:
                        connection_list[0, wall_row - 1, col] = False
                # bottom half pushed first so the top half is processed next
                regions.append((x, wall_row, width, y + height - wall_row))
                regions.append((x, y, width, wall_row - y))

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
# ParamSpec standing in for the generator-specific arguments that follow `grid_shape`
P_GeneratorKwargs = ParamSpec("P_GeneratorKwargs")
# type of a maze generator: takes the grid shape as its first argument, then
# generator-specific arguments (captured by `P_GeneratorKwargs`), and returns a `LatticeMaze`
MazeGeneratorFunc = Callable[
    Concatenate[Coord | CoordTup, P_GeneratorKwargs],
    LatticeMaze,
]
# cant automatically populate this because it messes with pickling :(
# keys are the method names on `LatticeMazeGenerators`; keep this in sync when
# adding a generator. key order is preserved as written.
GENERATORS_MAP: dict[str, MazeGeneratorFunc] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: error: Dict entry 1 has incompatible type
    # "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
    # expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]" [dict-item]
    # gen_wilson takes no kwargs and we check that the kwargs are empty
    # but mypy doesnt like this, `Any` != `KwArg(Any)`
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
    "gen_kruskal": LatticeMazeGenerators.gen_kruskal,
    "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
637_GENERATORS_PERCOLATED: list[str] = [
638 "gen_percolation",
639 "gen_dfs_percolation",
640]
641"""list of generator names that generate percolated mazes
642we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
643this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
644"""
# TODO: we should deprecate this, always get a dataset when you want a maze with a solution
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    looks up `gen_name` in `GENERATORS_MAP`, calls that generator with `grid_shape`
    and `maze_ctor_kwargs` (as keyword arguments), and attaches the output of
    `generate_random_path()` on the new maze as its solution.
    """
    ctor_kwargs: dict = dict() if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator: MazeGeneratorFunc = GENERATORS_MAP[gen_name]

    # TYPING: error: Too few arguments [call-arg]
    # not sure why this is happening -- doesnt recognize the kwargs?
    maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]

    path: CoordArray = np.array(maze.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=path)