Coverage for maze_dataset/generation/generators.py: 83%
206 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
3import random
4import warnings
5from typing import Any, Callable
7import numpy as np
8from jaxtyping import Bool
10from maze_dataset.constants import CoordArray, CoordTup
11from maze_dataset.generation.seed import GLOBAL_SEED
12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
# seed the stdlib `random` module (used below by `gen_dfs`, `gen_kruskal` and
# `gen_recursive_division`) and build a seeded numpy Generator.
# NOTE(review): `numpy_rng` is not referenced anywhere in this file -- the generators
# here draw from the global `np.random` state instead; check whether other modules use it.
random.seed(GLOBAL_SEED)
numpy_rng = np.random.default_rng(GLOBAL_SEED)
def _random_start_coord(
    grid_shape: Coord,
    start_coord: Coord | CoordTup | None,
) -> Coord:
    """pick a random start coord within the bounds of `grid_shape` if none is provided

    # Parameters:
    - `grid_shape : Coord`
        shape of the grid (a plain tuple is also accepted)
    - `start_coord : Coord | CoordTup | None`
        if not `None`, it is converted to an array and returned as-is

    # Returns:
    - `Coord`
        an integer array of length `len(grid_shape)`. when `start_coord` is `None`,
        entry `i` is drawn uniformly from `[0, grid_shape[i])`
    """
    if start_coord is not None:
        return np.array(start_coord)

    shape: np.ndarray = np.asarray(grid_shape)
    # `randint`'s upper bound is exclusive, so passing the shape itself makes every
    # index -- including the last row/column -- reachable. `np.maximum(..., 1)` keeps
    # the bound valid (> 0) for degenerate zero-length dimensions.
    return np.random.randint(
        0,  # lower bound (inclusive)
        np.maximum(shape, 1),  # upper bound (exclusive, at least 1)
        size=len(shape),  # dimensionality
    )
def get_neighbors_in_bounds(
    coord: Coord,
    grid_shape: Coord,
) -> CoordArray:
    """return the neighbors of `coord` (offsets from `NEIGHBORS_MASK`) that lie inside the grid

    a candidate is kept when every component is `>= 0` and strictly less than the
    matching entry of `grid_shape`; the order of `NEIGHBORS_MASK` is preserved.
    """
    candidates: CoordArray = coord + NEIGHBORS_MASK
    above_lower: np.ndarray = np.all(candidates >= 0, axis=1)
    below_upper: np.ndarray = np.all(candidates < grid_shape, axis=1)
    return candidates[above_lower & below_upper]
class LatticeMazeGenerators:
    """namespace for lattice maze generation algorithms

    every generator is a `@staticmethod` that takes `grid_shape` (plus
    generator-specific keyword arguments) and returns a `LatticeMaze` whose
    `generation_meta` records how it was produced.

    examples of generated mazes can be found here:
    https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
    """

    @staticmethod
    def gen_dfs(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        randomized_stack: bool = False,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using depth first search, iterative

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `lattice_dim: int`: the dimension of the lattice
            (default: `2`)
        - `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
            (default: `None`)
        - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the number of cells in the grid), which never limits generation. if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
            (default: `None`)
        - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
            (default: `True`)
        - `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the top one (this is what `gen_prim` uses)
            (default: `False`)
        - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

        # Returns
        - `LatticeMaze` whose `generation_meta` records the parameters used, the set of
          `visited_cells`, and `fully_connected` (whether every cell of the grid was reached)

        # algorithm
        1. Choose the initial cell, mark it as visited and push it to the stack
        2. While the stack is not empty
            1. Pop a cell from the stack and make it a current cell
            2. If the current cell has any neighbours which have not been visited
                1. Push the current cell to the stack
                2. Choose one of the unvisited neighbours
                3. Remove the wall between the current cell and the chosen cell
                4. Mark the chosen cell as visited and push it to the stack
        """
        grid_shape_: Coord = np.array(grid_shape)
        n_total_cells: int = int(np.prod(grid_shape_))

        # resolve `accessible_cells` into an absolute cell count
        n_accessible_cells: int
        if accessible_cells is None:
            n_accessible_cells = n_total_cells
        elif isinstance(accessible_cells, float):
            assert 0 <= accessible_cells <= 1, (
                f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
            )
            n_accessible_cells = int(accessible_cells * n_total_cells)
        else:
            assert isinstance(accessible_cells, (int, np.integer))
            n_accessible_cells = int(accessible_cells)

        # resolve `max_tree_depth` into an absolute depth
        if max_tree_depth is None:
            # depth is counted from the start coord in both directions, so the loop
            # below compares against `max_tree_depth / 2`; doubling the cell count here
            # means the default can never be reached (depth <= number of carved cells + 1)
            max_tree_depth = 2 * n_total_cells
        elif isinstance(max_tree_depth, float):
            assert 0 <= max_tree_depth <= 1, (
                f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
            )
            max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # no connections yet: every cell is walled off
        connection_list: ConnectionList = np.zeros(
            (lattice_dim, grid_shape_[0], grid_shape_[1]),
            dtype=np.bool_,
        )

        # the start cell is visited from the outset and seeds the stack
        visited_cells: set[tuple[int, int]] = set()
        visited_cells.add(tuple(start_coord))
        stack: list[Coord] = [start_coord]

        current_tree_depth: int = 1

        # stop when the stack runs dry or enough cells have been reached
        while stack and (len(visited_cells) < n_accessible_cells):
            current_coord: Coord
            if randomized_stack:
                current_coord = stack.pop(random.randint(0, len(stack) - 1))
            else:
                current_coord = stack.pop()

            # in-bounds neighbors not yet visited, paired with the offset that reaches them
            unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
                (neighbor, delta)
                for neighbor, delta in zip(
                    current_coord + NEIGHBORS_MASK,
                    NEIGHBORS_MASK,
                    strict=False,
                )
                if (
                    (tuple(neighbor) not in visited_cells)
                    and (0 <= neighbor[0] < grid_shape_[0])
                    and (0 <= neighbor[1] < grid_shape_[1])
                )
            ]

            # only carve while under half the depth budget (the tree can extend in
            # two directions from the start coord)
            if unvisited_neighbors_deltas and (
                current_tree_depth <= max_tree_depth / 2
            ):
                # re-pushing the current cell is what allows a fork later; without
                # forks (or with a single option) the walk just moves on
                if do_forks and (len(unvisited_neighbors_deltas) > 1):
                    stack.append(current_coord)

                chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

                # connections are stored at the upper/left cell of the pair:
                # a positive delta means the neighbor is down/right of the current coord,
                # a negative delta means the current coord is down/right of the neighbor
                dim: int = int(np.argmax(np.abs(delta)))
                clist_node: Coord = (
                    current_coord if (delta.sum() > 0) else chosen_neighbor
                )
                connection_list[dim, clist_node[0], clist_node[1]] = True

                visited_cells.add(tuple(chosen_neighbor))
                stack.append(chosen_neighbor)

                current_tree_depth += 1
            else:
                current_tree_depth -= 1

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_dfs",
                grid_shape=grid_shape_,
                start_coord=start_coord,
                n_accessible_cells=int(n_accessible_cells),
                max_tree_depth=int(max_tree_depth),
                # compare against the *total* cell count, not `n_accessible_cells`:
                # a maze that stopped early is not fully connected, and marking it as
                # such breaks maze solving downstream
                fully_connected=bool(len(visited_cells) == n_total_cells),
                visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
            ),
        )

    @staticmethod
    def gen_prim(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
        accessible_cells: float | None = None,
        max_tree_depth: float | None = None,
        do_forks: bool = True,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """(broken!) generate a lattice maze using Prim's algorithm

        this does not implement Prim's algorithm: it emits a warning and delegates to
        `gen_dfs` with `randomized_stack=True`, forwarding all other arguments.
        see https://github.com/understanding-search/maze-dataset/issues/12
        """
        warnings.warn(
            "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
        )
        return LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            do_forks=do_forks,
            start_coord=start_coord,
            randomized_stack=True,
        )

    @staticmethod
    def gen_wilson(
        grid_shape: Coord | CoordTup,
        **kwargs,
    ) -> LatticeMaze:
        """Generate a lattice maze using Wilson's algorithm.

        # Algorithm
        Wilson's algorithm generates an unbiased (random) maze
        sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
        acyclic and all cells are part of a unique connected space.
        https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

        # Raises
        - `AssertionError` if any keyword arguments are passed (none are supported)
        """
        assert not kwargs, (
            f"gen_wilson does not take any additional arguments, got {kwargs = }"
        )

        grid_shape_: Coord = np.array(grid_shape)

        connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
        visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)

        # seed the tree with a single random cell
        start_coord: Coord = _random_start_coord(grid_shape_, None)
        visited[start_coord[0], start_coord[1]] = True
        del start_coord

        while not visited.all():
            # start a loop-erased random walk from a uniformly chosen unvisited cell
            unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
            walk_start: Coord = unvisited_coords[
                np.random.choice(unvisited_coords.shape[0])
            ]

            path: list[Coord] = [walk_start]
            # index of each cell in `path`, for O(1) loop detection. cells in `path`
            # are always distinct because loops are erased as soon as they form.
            path_index: dict[tuple[int, int], int] = {
                (int(walk_start[0]), int(walk_start[1])): 0,
            }
            current: Coord = walk_start

            # keep walking until the path reaches a cell already in the tree
            while not visited[current[0], current[1]]:
                # pick a random in-bounds neighbor (one always exists on a lattice
                # with more than one cell)
                neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
                next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
                next_key: tuple[int, int] = (int(next_cell[0]), int(next_cell[1]))

                loop_start: int | None = path_index.get(next_key)
                if loop_start is None:
                    path_index[next_key] = len(path)
                    path.append(next_cell)
                    current = next_cell
                else:
                    # erase the loop: keep the path up to and including the earlier
                    # visit to `next_cell`, and forget everything after it
                    for erased in path[loop_start + 1 :]:
                        del path_index[(int(erased[0]), int(erased[1]))]
                    path = path[: loop_start + 1]
                    current = path[-1]

            # carve the loop-erased path into the maze
            for i in range(len(path) - 1):
                c_1: Coord = path[i]
                c_2: Coord = path[i + 1]

                delta: Coord = c_2 - c_1
                dim: int = int(np.argmax(np.abs(delta)))

                # connections are stored at the upper/left cell of the pair:
                # positive delta -> c_2 is down/right of c_1, negative -> the reverse
                clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
                connection_list[dim, clist_node[0], clist_node[1]] = True
                # the final c_2 is the tree cell that ended the walk, so it is already visited
                visited[c_1[0], c_1[1]] = True

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_wilson",
                grid_shape=grid_shape_,
                fully_connected=True,
            ),
        )

    @staticmethod
    def gen_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """generate a lattice maze using simple percolation

        note that p in the range (0.4, 0.7) gives the most interesting mazes

        # Arguments
        - `grid_shape: Coord`: the shape of the grid
        - `p: float`: the probability of each connection (edge) being open, independently (default: `0.4`)
        - `lattice_dim: int`: the dimension of the lattice (default: `2`)
        - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)

        # Returns
        - `LatticeMaze`; `generation_meta["visited_cells"]` holds the connected
          component reachable from `start_coord`
        """
        assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
        grid_shape_: Coord = np.array(grid_shape)

        start_coord = _random_start_coord(grid_shape_, start_coord)

        # each connection is open independently with probability p
        connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p

        connection_list = _fill_edges_with_walls(connection_list)

        output: LatticeMaze = LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_percolation",
                grid_shape=grid_shape_,
                percolation_p=p,
                start_coord=start_coord,
            ),
        )

        # generation_meta is sometimes None, but not here since we just made it a dict above
        output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return output

    @staticmethod
    def gen_dfs_percolation(
        grid_shape: Coord | CoordTup,
        p: float = 0.4,
        lattice_dim: int = 2,
        accessible_cells: int | None = None,
        max_tree_depth: int | None = None,
        start_coord: Coord | None = None,
    ) -> LatticeMaze:
        """dfs and then percolation (adds cycles)

        builds a maze with `gen_dfs` (forwarding `lattice_dim`, `accessible_cells`,
        `max_tree_depth` and `start_coord`), then ORs in an independent random set of
        connections, each open with probability `p`. `generation_meta["visited_cells"]`
        is recomputed as the component reachable from `start_coord`.
        """
        grid_shape_: Coord = np.array(grid_shape)
        start_coord = _random_start_coord(grid_shape_, start_coord)

        maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
            grid_shape=grid_shape_,
            lattice_dim=lattice_dim,
            accessible_cells=accessible_cells,
            max_tree_depth=max_tree_depth,
            start_coord=start_coord,
        )

        # percolate: random extra connections, walled off at the grid edges
        connection_list_perc: np.ndarray = (
            np.random.rand(*maze.connection_list.shape) < p
        )
        connection_list_perc = _fill_edges_with_walls(connection_list_perc)

        # written through `__dict__`, presumably because `LatticeMaze` does not allow
        # ordinary attribute assignment -- NOTE(review): confirm
        maze.__dict__["connection_list"] = np.logical_or(
            maze.connection_list,
            connection_list_perc,
        )

        # generation_meta is sometimes None, but not here since gen_dfs just made it a dict
        maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
        maze.generation_meta["percolation_p"] = p  # type: ignore[index]
        maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
            start_coord,
        )

        return maze

    @staticmethod
    def gen_kruskal(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using Kruskal's algorithm.

        This function generates a random spanning tree over a grid using Kruskal's algorithm.
        Each cell is treated as a node, and all valid adjacent edges are listed and processed
        in random order. An edge is added (i.e. its passage carved) only if it connects two cells
        that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
        without cycles.

        https://en.wikipedia.org/wiki/Kruskal's_algorithm

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (for example, `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`); only `2` is supported.
        - `start_coord : Coord | None`
            Recorded in `generation_meta` only; it does not influence generation.
            If `None`, a random coordinate is recorded.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

        # Usage:
        ```python
        maze = gen_kruskal((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Kruskal's algorithm is only implemented for 2D lattices."
        )
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # union-find forest: each cell starts as the root of its own set
        parent: dict[tuple[int, int], tuple[int, int]] = {
            (i, j): (i, j) for i in range(n_rows) for j in range(n_cols)
        }

        def find(cell: tuple[int, int]) -> tuple[int, int]:
            """return the root of `cell`'s set, halving the path along the way"""
            while parent[cell] != cell:
                parent[cell] = parent[parent[cell]]
                cell = parent[cell]
            return cell

        # every candidate passage as `(cell, neighbor, dim)`, where `dim` indexes
        # `connection_list`: 1 = neighbor to the right, 0 = neighbor below.
        # the listing order (all rightward edges, then all downward edges) is kept
        # fixed so that a given RNG state always produces the same shuffle.
        edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [
            ((i, j), (i, j + 1), 1) for i in range(n_rows) for j in range(n_cols - 1)
        ]
        edges.extend(
            ((i, j), (i + 1, j), 0) for i in range(n_rows - 1) for j in range(n_cols)
        )
        random.shuffle(edges)

        # connection_list[0, i, j]: (i, j) <-> (i + 1, j)  (downward)
        # connection_list[1, i, j]: (i, j) <-> (i, j + 1)  (rightward)
        connection_list: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=bool)

        # carve an edge only if it joins two different trees, then merge them
        for cell1, cell2, dim in edges:
            root1: tuple[int, int] = find(cell1)
            root2: tuple[int, int] = find(cell2)
            if root1 != root2:
                parent[root2] = root1
                # `cell1` is always the upper/left cell of the pair
                connection_list[dim, cell1[0], cell1[1]] = True

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_kruskal",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="kruskal",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )

    @staticmethod
    def gen_recursive_division(
        grid_shape: "Coord | CoordTup",
        lattice_dim: int = 2,
        start_coord: "Coord | None" = None,
    ) -> "LatticeMaze":
        """Generate a maze using the recursive division algorithm.

        This function generates a maze by recursively dividing the grid with walls and carving a single
        passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
        cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
        The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

        # Parameters:
        - `grid_shape : Coord | CoordTup`
            The shape of the maze grid (e.g., `(n_rows, n_cols)`).
        - `lattice_dim : int`
            The lattice dimension (default is `2`); only `2` is supported.
        - `start_coord : Coord | None`
            Recorded in `generation_meta` only; it does not influence generation.
            If `None`, a random coordinate is recorded.

        # Returns:
        - `LatticeMaze`
            A maze represented by a connection list, generated using recursive division.

        # Usage:
        ```python
        maze = gen_recursive_division((10, 10))
        ```
        """
        assert lattice_dim == 2, (  # noqa: PLR2004
            "Recursive division algorithm is only implemented for 2D lattices."
        )
        grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
        n_rows, n_cols = grid_shape_

        # start fully connected: every cell links to the cell below it (except the
        # last row) and to the cell on its right (except the last column)
        connection_list: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=bool)
        connection_list[0, : n_rows - 1, :] = True
        connection_list[1, :, : n_cols - 1] = True

        def divide(x: int, y: int, width: int, height: int) -> None:
            """Recursively divide the region starting at (x, y) with the given width and height.

            `x`/`width` are along columns and `y`/`height` along rows. Removes
            connections along the chosen division line except for one randomly
            chosen gap. Regions narrower than 2 in either direction are left as-is.
            """
            if width < 2 or height < 2:  # noqa: PLR2004
                return

            if width > height:
                # vertical wall between columns wall_col - 1 and wall_col.
                # since x + width <= n_cols, wall_col - 1 <= n_cols - 2 always
                # indexes a real rightward connection.
                wall_col = random.randint(x + 1, x + width - 1)
                gap_row = random.randint(y, y + height - 1)
                for row in range(y, y + height):
                    if row != gap_row:
                        connection_list[1, row, wall_col - 1] = False
                # left subregion first, then right
                divide(x, y, wall_col - x, height)
                divide(wall_col, y, x + width - wall_col, height)
            else:
                # horizontal wall between rows wall_row - 1 and wall_row.
                # since y + height <= n_rows, wall_row - 1 <= n_rows - 2 always
                # indexes a real downward connection.
                wall_row = random.randint(y + 1, y + height - 1)
                gap_col = random.randint(x, x + width - 1)
                for col in range(x, x + width):
                    if col != gap_col:
                        connection_list[0, wall_row - 1, col] = False
                # top subregion first, then bottom
                divide(x, y, width, wall_row - y)
                divide(x, wall_row, width, y + height - wall_row)

        divide(0, 0, n_cols, n_rows)

        if start_coord is None:
            start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]

        generation_meta: dict = dict(
            func_name="gen_recursive_division",
            grid_shape=grid_shape_,
            start_coord=start_coord,
            algorithm="recursive_division",
            fully_connected=True,
        )
        return LatticeMaze(
            connection_list=connection_list, generation_meta=generation_meta
        )
# NOTE: this mapping is written out by hand on purpose -- populating it
# automatically from `LatticeMazeGenerators` interferes with pickling.
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
    "gen_dfs": LatticeMazeGenerators.gen_dfs,
    # TYPING: mypy reports `Dict entry 1 has incompatible type [dict-item]` here:
    # gen_wilson's signature is `(grid_shape, **kwargs)`, and mypy treats
    # `KwArg(Any)` as different from the positional `Any` in the value type above.
    # at runtime gen_wilson asserts that no kwargs are passed, so this is safe.
    "gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
    "gen_percolation": LatticeMazeGenerators.gen_percolation,
    "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
    "gen_prim": LatticeMazeGenerators.gen_prim,
    "gen_kruskal": LatticeMazeGenerators.gen_kruskal,
    "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

_GENERATORS_PERCOLATED: list[str] = ["gen_percolation", "gen_dfs_percolation"]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs generation might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""
def get_maze_with_solution(
    gen_name: str,
    grid_shape: Coord | CoordTup,
    maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
    """helper function to get a maze already with a solution

    # Parameters:
    - `gen_name : str`
        key into `GENERATORS_MAP` selecting the generator
    - `grid_shape : Coord | CoordTup`
        grid shape passed positionally to the generator
    - `maze_ctor_kwargs : dict | None`
        extra keyword arguments for the generator (default: none)

    # Returns:
    - `SolvedMaze` built from the generated maze, using
      `maze.generate_random_path()` as the solution
    """
    ctor_kwargs: dict = dict() if maze_ctor_kwargs is None else maze_ctor_kwargs
    generator = GENERATORS_MAP[gen_name]
    # TYPING: error: Too few arguments [call-arg]
    # the value type declared on `GENERATORS_MAP` lists two positional parameters,
    # so mypy does not accept the `**ctor_kwargs` call form
    maze: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
    random_path: CoordArray = np.array(maze.generate_random_path())
    return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=random_path)