docs for maze-dataset v1.4.0
View Source on GitHub

maze_dataset.generation.generators

generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze and are methods in LatticeMazeGenerators


  1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
  2
  3import random
  4import warnings
  5from typing import Callable, Concatenate, ParamSpec
  6
  7import numpy as np
  8from jaxtyping import Bool
  9
 10from maze_dataset.constants import CoordArray, CoordTup
 11from maze_dataset.generation.seed import GLOBAL_SEED
 12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
 13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
 14
# module-level numpy generator seeded with `GLOBAL_SEED`
# NOTE(review): nothing in the code visible here draws from `_NUMPY_RNG`; the generators below
# sample through the stdlib `random` module or the global `np.random.*` functions -- confirm whether this is intended
_NUMPY_RNG: np.random.Generator = np.random.default_rng(GLOBAL_SEED)
# seed the stdlib `random` module at import time
# (this call does not touch numpy's global `np.random` state, which is not seeded in the code visible here)
random.seed(GLOBAL_SEED)
 17
 18
 19def _random_start_coord(
 20	grid_shape: Coord,
 21	start_coord: Coord | CoordTup | None,
 22) -> Coord:
 23	"picking a random start coord within the bounds of `grid_shape` if none is provided"
 24	start_coord_: Coord
 25	if start_coord is None:
 26		start_coord_ = np.random.randint(
 27			0,  # lower bound
 28			np.maximum(grid_shape - 1, 1),  # upper bound (at least 1)
 29			size=len(grid_shape),  # dimensionality
 30		)
 31	else:
 32		start_coord_ = np.array(start_coord)
 33
 34	return start_coord_
 35
 36
 37def get_neighbors_in_bounds(
 38	coord: Coord,
 39	grid_shape: Coord,
 40) -> CoordArray:
 41	"get all neighbors of a coordinate that are within the bounds of the grid"
 42	# get all neighbors
 43	neighbors: CoordArray = coord + NEIGHBORS_MASK
 44
 45	# filter neighbors by being within grid bounds
 46	neighbors_in_bounds: CoordArray = neighbors[
 47		(neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1)
 48	]
 49
 50	return neighbors_in_bounds
 51
 52
 53class LatticeMazeGenerators:
 54	"""namespace for lattice maze generation algorithms
 55
 56	examples of generated mazes can be found here:
 57	https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
 58	"""
 59
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		*,
		lattice_dim: int = 2,
		accessible_cells: float | None = None,
		max_tree_depth: float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice. Note that it only sizes the first axis of the connection list; the rest of this function indexes exactly two grid axes
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells** (truncated with `int`). if an int, it is used as a count
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape** (truncated with `int`). any other value is used unchanged
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, each iteration pops a uniformly random element of the stack instead of the most recently pushed one (this is how `gen_prim` calls this function)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.
			(default: `None`)

		# Returns
		- `LatticeMaze`: the generated maze. Its `generation_meta` holds `func_name`, `grid_shape`, `start_coord`,
			`n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

		# Raises
		- `AssertionError`: if a float `accessible_cells` or `max_tree_depth` is greater than 1, or if `accessible_cells` is neither `None`, a float, nor an int

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty and fewer than `n_accessible_cells` cells have been visited
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited, and the depth limit has not been reached
				1. Push the current cell to the stack (only if `do_forks` is set and more than one unvisited neighbour exists)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		# resolve `accessible_cells` into an absolute cell count
		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		# resolve `max_tree_depth`; an int value falls through both branches and is used unchanged
		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# use the provided start coord (converted to an array), or choose a random one
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))  # the start cell counts as visited from the beginning
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				# `random.randint` is inclusive on both ends, hence `len(stack) - 1`
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			# each entry pairs the neighbor coordinate with the offset (delta) that leads to it
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				# `dim` is the axis along which the two cells differ
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# the popped cell could not be extended, so step the depth counter back
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# `fully_connected` must compare against `n_total_cells`, not `n_accessible_cells`:
				# comparing against `n_accessible_cells` treated the maze as fully connected even when
				# generation stopped before covering the grid, which caused solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				# store plain python ints rather than numpy scalars
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)
207
208	@staticmethod
209	def gen_prim(
210		grid_shape: Coord | CoordTup,
211		lattice_dim: int = 2,
212		accessible_cells: float | None = None,
213		max_tree_depth: float | None = None,
214		do_forks: bool = True,
215		start_coord: Coord | None = None,
216	) -> LatticeMaze:
217		"(broken!) generate a lattice maze using Prim's algorithm"
218		warnings.warn(
219			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
220		)
221		return LatticeMazeGenerators.gen_dfs(
222			grid_shape=grid_shape,
223			lattice_dim=lattice_dim,
224			accessible_cells=accessible_cells,
225			max_tree_depth=max_tree_depth,
226			do_forks=do_forks,
227			start_coord=start_coord,
228			randomized_stack=True,
229		)
230
231	@staticmethod
232	def gen_wilson(
233		grid_shape: Coord | CoordTup,
234		**kwargs,
235	) -> LatticeMaze:
236		"""Generate a lattice maze using Wilson's algorithm.
237
238		# Algorithm
239		Wilson's algorithm generates an unbiased (random) maze
240		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
241		acyclic and all cells are part of a unique connected space.
242		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
243		"""
244		assert not kwargs, (
245			f"gen_wilson does not take any additional arguments, got {kwargs = }"
246		)
247
248		grid_shape_: Coord = np.array(grid_shape)
249
250		# Initialize grid and visited cells
251		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
252		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
253
254		# Choose a random cell and mark it as visited
255		start_coord: Coord = _random_start_coord(grid_shape_, None)
256		visited[start_coord[0], start_coord[1]] = True
257		del start_coord
258
259		while not visited.all():
260			# Perform loop-erased random walk from another random cell
261
262			# Choose walk_start only from unvisited cells
263			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
264			walk_start: Coord = unvisited_coords[
265				np.random.choice(unvisited_coords.shape[0])
266			]
267
268			# Perform the random walk
269			path: list[Coord] = [walk_start]
270			current: Coord = walk_start
271
272			# exit the loop once the current path hits a visited cell
273			while not visited[current[0], current[1]]:
274				# find a valid neighbor (one always exists on a lattice)
275				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
276				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
277
278				# Check for loop
279				loop_exit: int | None = None
280				for i, p in enumerate(path):
281					if np.array_equal(next_cell, p):
282						loop_exit = i
283						break
284
285				# erase the loop, or continue the walk
286				if loop_exit is not None:
287					# this removes everything after and including the loop start
288					path = path[: loop_exit + 1]
289					# reset current cell to end of path
290					current = path[-1]
291				else:
292					path.append(next_cell)
293					current = next_cell
294
295			# Add the path to the maze
296			for i in range(len(path) - 1):
297				c_1: Coord = path[i]
298				c_2: Coord = path[i + 1]
299
300				# find the dimension of the connection
301				delta: Coord = c_2 - c_1
302				dim: int = int(np.argmax(np.abs(delta)))
303
304				# if positive, down/right from current coord
305				# if negative, up/left from current coord (down/right from neighbor)
306				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
307				connection_list[dim, clist_node[0], clist_node[1]] = True
308				visited[c_1[0], c_1[1]] = True
309				# we dont add c_2 because the last c_2 will have already been visited
310
311		return LatticeMaze(
312			connection_list=connection_list,
313			generation_meta=dict(
314				func_name="gen_wilson",
315				grid_shape=grid_shape_,
316				fully_connected=True,
317			),
318		)
319
320	@staticmethod
321	def gen_percolation(
322		grid_shape: Coord | CoordTup,
323		p: float = 0.4,
324		lattice_dim: int = 2,
325		start_coord: Coord | None = None,
326	) -> LatticeMaze:
327		"""generate a lattice maze using simple percolation
328
329		note that p in the range (0.4, 0.7) gives the most interesting mazes
330
331		# Arguments
332		- `grid_shape: Coord`: the shape of the grid
333		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
334		- `p: float`: the probability of a cell being accessible (default: `0.5`)
335		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
336		"""
337		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
338		grid_shape_: Coord = np.array(grid_shape)
339
340		start_coord = _random_start_coord(grid_shape_, start_coord)
341
342		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
343
344		connection_list = _fill_edges_with_walls(connection_list)
345
346		output: LatticeMaze = LatticeMaze(
347			connection_list=connection_list,
348			generation_meta=dict(
349				func_name="gen_percolation",
350				grid_shape=grid_shape_,
351				percolation_p=p,
352				start_coord=start_coord,
353			),
354		)
355
356		# generation_meta is sometimes None, but not here since we just made it a dict above
357		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
358			start_coord,
359		)
360
361		return output
362
363	@staticmethod
364	def gen_dfs_percolation(
365		grid_shape: Coord | CoordTup,
366		p: float = 0.4,
367		lattice_dim: int = 2,
368		accessible_cells: int | None = None,
369		max_tree_depth: int | None = None,
370		start_coord: Coord | None = None,
371	) -> LatticeMaze:
372		"""dfs and then percolation (adds cycles)"""
373		grid_shape_: Coord = np.array(grid_shape)
374		start_coord = _random_start_coord(grid_shape_, start_coord)
375
376		# generate initial maze via dfs
377		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
378			grid_shape=grid_shape_,
379			lattice_dim=lattice_dim,
380			accessible_cells=accessible_cells,
381			max_tree_depth=max_tree_depth,
382			start_coord=start_coord,
383		)
384
385		# percolate
386		connection_list_perc: np.ndarray = (
387			np.random.rand(*maze.connection_list.shape) < p
388		)
389		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
390
391		maze.__dict__["connection_list"] = np.logical_or(
392			maze.connection_list,
393			connection_list_perc,
394		)
395
396		# generation_meta is sometimes None, but not here since we just made it a dict above
397		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
398		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
399		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
400			start_coord,
401		)
402
403		return maze
404
405	@staticmethod
406	def gen_kruskal(
407		grid_shape: "Coord | CoordTup",
408		lattice_dim: int = 2,
409		start_coord: "Coord | None" = None,
410	) -> "LatticeMaze":
411		"""Generate a maze using Kruskal's algorithm.
412
413		This function generates a random spanning tree over a grid using Kruskal's algorithm.
414		Each cell is treated as a node, and all valid adjacent edges are listed and processed
415		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
416		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
417		without cycles.
418
419		https://en.wikipedia.org/wiki/Kruskal's_algorithm
420
421		# Parameters:
422		- `grid_shape : Coord | CoordTup`
423			The shape of the maze grid (for example, `(n_rows, n_cols)`).
424		- `lattice_dim : int`
425			The lattice dimension (default is `2`).
426		- `start_coord : Coord | None`
427			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
428		- `**kwargs`
429			Additional keyword arguments (currently unused).
430
431		# Returns:
432		- `LatticeMaze`
433			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
434
435		# Usage:
436		```python
437		maze = gen_kruskal((10, 10))
438		```
439		"""
440		assert lattice_dim == 2, (  # noqa: PLR2004
441			"Kruskal's algorithm is only implemented for 2D lattices."
442		)
443		# Convert grid_shape to a tuple of ints
444		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
445		n_rows, n_cols = grid_shape_
446
447		# Initialize union-find data structure.
448		parent: dict[tuple[int, int], tuple[int, int]] = {}
449
450		def find(cell: tuple[int, int]) -> tuple[int, int]:
451			while parent[cell] != cell:
452				parent[cell] = parent[parent[cell]]
453				cell = parent[cell]
454			return cell
455
456		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
457			root1 = find(cell1)
458			root2 = find(cell2)
459			parent[root2] = root1
460
461		# Initialize each cell as its own set.
462		for i in range(n_rows):
463			for j in range(n_cols):
464				parent[(i, j)] = (i, j)
465
466		# List all possible edges.
467		# For vertical edges (i.e. connecting a cell to its right neighbor):
468		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
469		for i in range(n_rows):
470			for j in range(n_cols - 1):
471				edges.append(((i, j), (i, j + 1), 1))
472		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
473		for i in range(n_rows - 1):
474			for j in range(n_cols):
475				edges.append(((i, j), (i + 1, j), 0))
476
477		# Shuffle the list of edges.
478		import random
479
480		random.shuffle(edges)
481
482		# Initialize connection_list with no connections.
483		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
484		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
485		import numpy as np
486
487		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
488
489		# Process each edge; if it connects two different trees, union them and carve the passage.
490		for cell1, cell2, direction in edges:
491			if find(cell1) != find(cell2):
492				union(cell1, cell2)
493				if direction == 0:
494					# Horizontal edge: connection is stored in connection_list[0] at cell1.
495					connection_list[0, cell1[0], cell1[1]] = True
496				else:
497					# Vertical edge: connection is stored in connection_list[1] at cell1.
498					connection_list[1, cell1[0], cell1[1]] = True
499
500		if start_coord is None:
501			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
502
503		generation_meta: dict = dict(
504			func_name="gen_kruskal",
505			grid_shape=grid_shape_,
506			start_coord=start_coord,
507			algorithm="kruskal",
508			fully_connected=True,
509		)
510		return LatticeMaze(
511			connection_list=connection_list, generation_meta=generation_meta
512		)
513
514	@staticmethod
515	def gen_recursive_division(
516		grid_shape: "Coord | CoordTup",
517		lattice_dim: int = 2,
518		start_coord: "Coord | None" = None,
519	) -> "LatticeMaze":
520		"""Generate a maze using the recursive division algorithm.
521
522		This function generates a maze by recursively dividing the grid with walls and carving a single
523		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
524		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
525		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
526
527		# Parameters:
528		- `grid_shape : Coord | CoordTup`
529			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
530		- `lattice_dim : int`
531			The lattice dimension (default is `2`).
532		- `start_coord : Coord | None`
533			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
534		- `**kwargs`
535			Additional keyword arguments (currently unused).
536
537		# Returns:
538		- `LatticeMaze`
539			A maze represented by a connection list, generated using recursive division.
540
541		# Usage:
542		```python
543		maze = gen_recursive_division((10, 10))
544		```
545		"""
546		assert lattice_dim == 2, (  # noqa: PLR2004
547			"Recursive division algorithm is only implemented for 2D lattices."
548		)
549		# Convert grid_shape to a tuple of ints.
550		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
551		n_rows, n_cols = grid_shape_
552
553		# Initialize connection_list as a fully connected grid.
554		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
555		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
556		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
557		connection_list[0, : n_rows - 1, :] = True
558		connection_list[1, :, : n_cols - 1] = True
559
560		def divide(x: int, y: int, width: int, height: int) -> None:
561			"""Recursively divide the region starting at (x, y) with the given width and height.
562
563			Removes connections along the chosen division line except for one randomly chosen gap.
564			"""
565			if width < 2 or height < 2:  # noqa: PLR2004
566				return
567
568			if width > height:
569				# Vertical division.
570				wall_col = random.randint(x + 1, x + width - 1)
571				gap_row = random.randint(y, y + height - 1)
572				for row in range(y, y + height):
573					if row == gap_row:
574						continue
575					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
576					if wall_col - 1 < n_cols - 1:
577						connection_list[1, row, wall_col - 1] = False
578				# Recurse on the left and right subregions.
579				divide(x, y, wall_col - x, height)
580				divide(wall_col, y, x + width - wall_col, height)
581			else:
582				# Horizontal division.
583				wall_row = random.randint(y + 1, y + height - 1)
584				gap_col = random.randint(x, x + width - 1)
585				for col in range(x, x + width):
586					if col == gap_col:
587						continue
588					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
589					if wall_row - 1 < n_rows - 1:
590						connection_list[0, wall_row - 1, col] = False
591				# Recurse on the top and bottom subregions.
592				divide(x, y, width, wall_row - y)
593				divide(x, wall_row, width, y + height - wall_row)
594
595		# Begin the division on the full grid.
596		divide(0, 0, n_cols, n_rows)
597
598		if start_coord is None:
599			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
600
601		generation_meta: dict = dict(
602			func_name="gen_recursive_division",
603			grid_shape=grid_shape_,
604			start_coord=start_coord,
605			algorithm="recursive_division",
606			fully_connected=True,
607		)
608		return LatticeMaze(
609			connection_list=connection_list, generation_meta=generation_meta
610		)
611
612
# parameter specification standing in for the generator-specific arguments that follow `grid_shape`
P_GeneratorKwargs = ParamSpec("P_GeneratorKwargs")
# type of a generator function: takes a grid shape first, then generator-specific arguments, and returns a `LatticeMaze`
MazeGeneratorFunc = Callable[
	Concatenate[Coord | CoordTup, P_GeneratorKwargs],
	LatticeMaze,
]


# cant automatically populate this because it messes with pickling :(
# each key is identical to the name of the `LatticeMazeGenerators` static method it maps to
GENERATORS_MAP: dict[str, MazeGeneratorFunc] = {
	"gen_dfs": LatticeMazeGenerators.gen_dfs,
	# TYPING: error: Dict entry 1 has incompatible type
	# "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
	# expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]"  [dict-item]
	# gen_wilson takes no kwargs and we check that the kwargs are empty
	# but mypy doesnt like this, `Any` != `KwArg(Any)`
	"gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
	"gen_percolation": LatticeMazeGenerators.gen_percolation,
	"gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
	"gen_prim": LatticeMazeGenerators.gen_prim,
	"gen_kruskal": LatticeMazeGenerators.gen_kruskal,
	"gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"

# names must match keys of `GENERATORS_MAP`; both entries are generators that sample connections against a probability `p`
_GENERATORS_PERCOLATED: list[str] = [
	"gen_percolation",
	"gen_dfs_percolation",
]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""
645
646
647# TODO: we should deprecate this, always get a dataset when you want a maze with a solution
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Arguments
	- `gen_name: str`: key into `GENERATORS_MAP` selecting the generator function
	- `grid_shape: Coord | CoordTup`: the shape of the grid, passed to the generator as its first argument
	- `maze_ctor_kwargs: dict | None`: keyword arguments forwarded to the generator; `None` means no extra arguments
		(default: `None`)

	# Returns
	- `SolvedMaze`: the generated maze together with the path returned by its `generate_random_path()`

	# Raises
	- `KeyError`: if `gen_name` is not a key of `GENERATORS_MAP`
	"""
	generator_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs

	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	generator = GENERATORS_MAP[gen_name]
	maze: LatticeMaze = generator(grid_shape, **generator_kwargs)  # type: ignore[call-arg]

	random_path: CoordArray = np.array(maze.generate_random_path())
	return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=random_path)

def get_neighbors_in_bounds( coord: jaxtyping.Int8[ndarray, 'row_col=2'], grid_shape: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
def get_neighbors_in_bounds(
	coord: Coord,
	grid_shape: Coord,
) -> CoordArray:
	"""get all neighbors of a coordinate that are within the bounds of the grid

	# Arguments
	- `coord: Coord`: the coordinate whose neighbors are requested
	- `grid_shape: Coord`: the shape of the grid; a neighbor is kept only if every component is `>= 0` and strictly below the matching entry of `grid_shape`

	# Returns
	- `CoordArray`: the rows of `coord + NEIGHBORS_MASK` that lie inside the grid
	"""
	# one candidate per offset row in `NEIGHBORS_MASK`
	candidates: CoordArray = coord + NEIGHBORS_MASK

	# per-candidate checks against the lower and upper bounds of the grid
	not_negative = (candidates >= 0).all(axis=1)
	below_shape = (candidates < grid_shape).all(axis=1)

	return candidates[not_negative & below_shape]

get all neighbors of a coordinate that are within the bounds of the grid

class LatticeMazeGenerators:
 54class LatticeMazeGenerators:
 55	"""namespace for lattice maze generation algorithms
 56
 57	examples of generated mazes can be found here:
 58	https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
 59	"""
 60
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		*,
		lattice_dim: int = 2,
		accessible_cells: int | float | None = None,
		max_tree_depth: int | float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice, used as the first axis of the connection list
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**. A cell is only extended while the current depth is `<= max_tree_depth / 2`
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one (this is what `gen_prim` uses)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		- `LatticeMaze`: maze whose `generation_meta` has the keys
			`func_name`, `grid_shape`, `start_coord`, `n_accessible_cells`, `max_tree_depth`,
			`fully_connected` (whether every cell of the grid was visited) and `visited_cells`

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty (and fewer than `n_accessible_cells` cells are visited)
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited
				1. Push the current cell to the stack (only if `do_forks` and more than one unvisited neighbour remains)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		# resolve `accessible_cells` to an absolute cell count
		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		# resolve `max_tree_depth` to an absolute depth; an int passed by the caller is used as-is
		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord (only if the caller did not provide one)
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# dead end (or depth limit reached): step the depth counter back
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# NOTE: compare against n_total_cells, not n_accessible_cells. An earlier version checked
				# len(visited_cells) == n_accessible_cells, which treated the maze as fully connected
				# even when it was not, causing solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)
208
209	@staticmethod
210	def gen_prim(
211		grid_shape: Coord | CoordTup,
212		lattice_dim: int = 2,
213		accessible_cells: float | None = None,
214		max_tree_depth: float | None = None,
215		do_forks: bool = True,
216		start_coord: Coord | None = None,
217	) -> LatticeMaze:
218		"(broken!) generate a lattice maze using Prim's algorithm"
219		warnings.warn(
220			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
221		)
222		return LatticeMazeGenerators.gen_dfs(
223			grid_shape=grid_shape,
224			lattice_dim=lattice_dim,
225			accessible_cells=accessible_cells,
226			max_tree_depth=max_tree_depth,
227			do_forks=do_forks,
228			start_coord=start_coord,
229			randomized_stack=True,
230		)
231
232	@staticmethod
233	def gen_wilson(
234		grid_shape: Coord | CoordTup,
235		**kwargs,
236	) -> LatticeMaze:
237		"""Generate a lattice maze using Wilson's algorithm.
238
239		# Algorithm
240		Wilson's algorithm generates an unbiased (random) maze
241		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
242		acyclic and all cells are part of a unique connected space.
243		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
244		"""
245		assert not kwargs, (
246			f"gen_wilson does not take any additional arguments, got {kwargs = }"
247		)
248
249		grid_shape_: Coord = np.array(grid_shape)
250
251		# Initialize grid and visited cells
252		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
253		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
254
255		# Choose a random cell and mark it as visited
256		start_coord: Coord = _random_start_coord(grid_shape_, None)
257		visited[start_coord[0], start_coord[1]] = True
258		del start_coord
259
260		while not visited.all():
261			# Perform loop-erased random walk from another random cell
262
263			# Choose walk_start only from unvisited cells
264			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
265			walk_start: Coord = unvisited_coords[
266				np.random.choice(unvisited_coords.shape[0])
267			]
268
269			# Perform the random walk
270			path: list[Coord] = [walk_start]
271			current: Coord = walk_start
272
273			# exit the loop once the current path hits a visited cell
274			while not visited[current[0], current[1]]:
275				# find a valid neighbor (one always exists on a lattice)
276				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
277				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
278
279				# Check for loop
280				loop_exit: int | None = None
281				for i, p in enumerate(path):
282					if np.array_equal(next_cell, p):
283						loop_exit = i
284						break
285
286				# erase the loop, or continue the walk
287				if loop_exit is not None:
288					# this removes everything after and including the loop start
289					path = path[: loop_exit + 1]
290					# reset current cell to end of path
291					current = path[-1]
292				else:
293					path.append(next_cell)
294					current = next_cell
295
296			# Add the path to the maze
297			for i in range(len(path) - 1):
298				c_1: Coord = path[i]
299				c_2: Coord = path[i + 1]
300
301				# find the dimension of the connection
302				delta: Coord = c_2 - c_1
303				dim: int = int(np.argmax(np.abs(delta)))
304
305				# if positive, down/right from current coord
306				# if negative, up/left from current coord (down/right from neighbor)
307				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
308				connection_list[dim, clist_node[0], clist_node[1]] = True
309				visited[c_1[0], c_1[1]] = True
310				# we dont add c_2 because the last c_2 will have already been visited
311
312		return LatticeMaze(
313			connection_list=connection_list,
314			generation_meta=dict(
315				func_name="gen_wilson",
316				grid_shape=grid_shape_,
317				fully_connected=True,
318			),
319		)
320
321	@staticmethod
322	def gen_percolation(
323		grid_shape: Coord | CoordTup,
324		p: float = 0.4,
325		lattice_dim: int = 2,
326		start_coord: Coord | None = None,
327	) -> LatticeMaze:
328		"""generate a lattice maze using simple percolation
329
330		note that p in the range (0.4, 0.7) gives the most interesting mazes
331
332		# Arguments
333		- `grid_shape: Coord`: the shape of the grid
334		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
335		- `p: float`: the probability of a cell being accessible (default: `0.5`)
336		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
337		"""
338		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
339		grid_shape_: Coord = np.array(grid_shape)
340
341		start_coord = _random_start_coord(grid_shape_, start_coord)
342
343		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
344
345		connection_list = _fill_edges_with_walls(connection_list)
346
347		output: LatticeMaze = LatticeMaze(
348			connection_list=connection_list,
349			generation_meta=dict(
350				func_name="gen_percolation",
351				grid_shape=grid_shape_,
352				percolation_p=p,
353				start_coord=start_coord,
354			),
355		)
356
357		# generation_meta is sometimes None, but not here since we just made it a dict above
358		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
359			start_coord,
360		)
361
362		return output
363
364	@staticmethod
365	def gen_dfs_percolation(
366		grid_shape: Coord | CoordTup,
367		p: float = 0.4,
368		lattice_dim: int = 2,
369		accessible_cells: int | None = None,
370		max_tree_depth: int | None = None,
371		start_coord: Coord | None = None,
372	) -> LatticeMaze:
373		"""dfs and then percolation (adds cycles)"""
374		grid_shape_: Coord = np.array(grid_shape)
375		start_coord = _random_start_coord(grid_shape_, start_coord)
376
377		# generate initial maze via dfs
378		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
379			grid_shape=grid_shape_,
380			lattice_dim=lattice_dim,
381			accessible_cells=accessible_cells,
382			max_tree_depth=max_tree_depth,
383			start_coord=start_coord,
384		)
385
386		# percolate
387		connection_list_perc: np.ndarray = (
388			np.random.rand(*maze.connection_list.shape) < p
389		)
390		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
391
392		maze.__dict__["connection_list"] = np.logical_or(
393			maze.connection_list,
394			connection_list_perc,
395		)
396
397		# generation_meta is sometimes None, but not here since we just made it a dict above
398		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
399		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
400		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
401			start_coord,
402		)
403
404		return maze
405
406	@staticmethod
407	def gen_kruskal(
408		grid_shape: "Coord | CoordTup",
409		lattice_dim: int = 2,
410		start_coord: "Coord | None" = None,
411	) -> "LatticeMaze":
412		"""Generate a maze using Kruskal's algorithm.
413
414		This function generates a random spanning tree over a grid using Kruskal's algorithm.
415		Each cell is treated as a node, and all valid adjacent edges are listed and processed
416		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
417		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
418		without cycles.
419
420		https://en.wikipedia.org/wiki/Kruskal's_algorithm
421
422		# Parameters:
423		- `grid_shape : Coord | CoordTup`
424			The shape of the maze grid (for example, `(n_rows, n_cols)`).
425		- `lattice_dim : int`
426			The lattice dimension (default is `2`).
427		- `start_coord : Coord | None`
428			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
429		- `**kwargs`
430			Additional keyword arguments (currently unused).
431
432		# Returns:
433		- `LatticeMaze`
434			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
435
436		# Usage:
437		```python
438		maze = gen_kruskal((10, 10))
439		```
440		"""
441		assert lattice_dim == 2, (  # noqa: PLR2004
442			"Kruskal's algorithm is only implemented for 2D lattices."
443		)
444		# Convert grid_shape to a tuple of ints
445		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
446		n_rows, n_cols = grid_shape_
447
448		# Initialize union-find data structure.
449		parent: dict[tuple[int, int], tuple[int, int]] = {}
450
451		def find(cell: tuple[int, int]) -> tuple[int, int]:
452			while parent[cell] != cell:
453				parent[cell] = parent[parent[cell]]
454				cell = parent[cell]
455			return cell
456
457		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
458			root1 = find(cell1)
459			root2 = find(cell2)
460			parent[root2] = root1
461
462		# Initialize each cell as its own set.
463		for i in range(n_rows):
464			for j in range(n_cols):
465				parent[(i, j)] = (i, j)
466
467		# List all possible edges.
468		# For vertical edges (i.e. connecting a cell to its right neighbor):
469		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
470		for i in range(n_rows):
471			for j in range(n_cols - 1):
472				edges.append(((i, j), (i, j + 1), 1))
473		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
474		for i in range(n_rows - 1):
475			for j in range(n_cols):
476				edges.append(((i, j), (i + 1, j), 0))
477
478		# Shuffle the list of edges.
479		import random
480
481		random.shuffle(edges)
482
483		# Initialize connection_list with no connections.
484		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
485		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
486		import numpy as np
487
488		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
489
490		# Process each edge; if it connects two different trees, union them and carve the passage.
491		for cell1, cell2, direction in edges:
492			if find(cell1) != find(cell2):
493				union(cell1, cell2)
494				if direction == 0:
495					# Horizontal edge: connection is stored in connection_list[0] at cell1.
496					connection_list[0, cell1[0], cell1[1]] = True
497				else:
498					# Vertical edge: connection is stored in connection_list[1] at cell1.
499					connection_list[1, cell1[0], cell1[1]] = True
500
501		if start_coord is None:
502			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
503
504		generation_meta: dict = dict(
505			func_name="gen_kruskal",
506			grid_shape=grid_shape_,
507			start_coord=start_coord,
508			algorithm="kruskal",
509			fully_connected=True,
510		)
511		return LatticeMaze(
512			connection_list=connection_list, generation_meta=generation_meta
513		)
514
515	@staticmethod
516	def gen_recursive_division(
517		grid_shape: "Coord | CoordTup",
518		lattice_dim: int = 2,
519		start_coord: "Coord | None" = None,
520	) -> "LatticeMaze":
521		"""Generate a maze using the recursive division algorithm.
522
523		This function generates a maze by recursively dividing the grid with walls and carving a single
524		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
525		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
526		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
527
528		# Parameters:
529		- `grid_shape : Coord | CoordTup`
530			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
531		- `lattice_dim : int`
532			The lattice dimension (default is `2`).
533		- `start_coord : Coord | None`
534			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
535		- `**kwargs`
536			Additional keyword arguments (currently unused).
537
538		# Returns:
539		- `LatticeMaze`
540			A maze represented by a connection list, generated using recursive division.
541
542		# Usage:
543		```python
544		maze = gen_recursive_division((10, 10))
545		```
546		"""
547		assert lattice_dim == 2, (  # noqa: PLR2004
548			"Recursive division algorithm is only implemented for 2D lattices."
549		)
550		# Convert grid_shape to a tuple of ints.
551		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
552		n_rows, n_cols = grid_shape_
553
554		# Initialize connection_list as a fully connected grid.
555		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
556		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
557		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
558		connection_list[0, : n_rows - 1, :] = True
559		connection_list[1, :, : n_cols - 1] = True
560
561		def divide(x: int, y: int, width: int, height: int) -> None:
562			"""Recursively divide the region starting at (x, y) with the given width and height.
563
564			Removes connections along the chosen division line except for one randomly chosen gap.
565			"""
566			if width < 2 or height < 2:  # noqa: PLR2004
567				return
568
569			if width > height:
570				# Vertical division.
571				wall_col = random.randint(x + 1, x + width - 1)
572				gap_row = random.randint(y, y + height - 1)
573				for row in range(y, y + height):
574					if row == gap_row:
575						continue
576					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
577					if wall_col - 1 < n_cols - 1:
578						connection_list[1, row, wall_col - 1] = False
579				# Recurse on the left and right subregions.
580				divide(x, y, wall_col - x, height)
581				divide(wall_col, y, x + width - wall_col, height)
582			else:
583				# Horizontal division.
584				wall_row = random.randint(y + 1, y + height - 1)
585				gap_col = random.randint(x, x + width - 1)
586				for col in range(x, x + width):
587					if col == gap_col:
588						continue
589					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
590					if wall_row - 1 < n_rows - 1:
591						connection_list[0, wall_row - 1, col] = False
592				# Recurse on the top and bottom subregions.
593				divide(x, y, width, wall_row - y)
594				divide(x, wall_row, width, y + height - wall_row)
595
596		# Begin the division on the full grid.
597		divide(0, 0, n_cols, n_rows)
598
599		if start_coord is None:
600			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
601
602		generation_meta: dict = dict(
603			func_name="gen_recursive_division",
604			grid_shape=grid_shape_,
605			start_coord=start_coord,
606			algorithm="recursive_division",
607			fully_connected=True,
608		)
609		return LatticeMaze(
610			connection_list=connection_list, generation_meta=generation_meta
611		)

namespace for lattice maze generation algorithms

examples of generated mazes can be found here: https://understanding-search.github.io/maze-dataset/examples/maze_examples.html

@staticmethod
def gen_dfs( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], *, lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		*,
		lattice_dim: int = 2,
		accessible_cells: int | float | None = None,
		max_tree_depth: int | float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice, used as the first axis of the connection list
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**. A cell is only extended while the current depth is `<= max_tree_depth / 2`
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one (this is what `gen_prim` uses)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		- `LatticeMaze`: maze whose `generation_meta` has the keys
			`func_name`, `grid_shape`, `start_coord`, `n_accessible_cells`, `max_tree_depth`,
			`fully_connected` (whether every cell of the grid was visited) and `visited_cells`

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty (and fewer than `n_accessible_cells` cells are visited)
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited
				1. Push the current cell to the stack (only if `do_forks` and more than one unvisited neighbour remains)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		# resolve `accessible_cells` to an absolute cell count
		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		# resolve `max_tree_depth` to an absolute depth; an int passed by the caller is used as-is
		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord (only if the caller did not provide one)
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))  # this wasnt a bug after all lol
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# dead end (or depth limit reached): step the depth counter back
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# NOTE: compare against n_total_cells, not n_accessible_cells. An earlier version checked
				# len(visited_cells) == n_accessible_cells, which treated the maze as fully connected
				# even when it was not, causing solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float |None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to twice the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway.
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate.

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
209	@staticmethod
210	def gen_prim(
211		grid_shape: Coord | CoordTup,
212		lattice_dim: int = 2,
213		accessible_cells: float | None = None,
214		max_tree_depth: float | None = None,
215		do_forks: bool = True,
216		start_coord: Coord | None = None,
217	) -> LatticeMaze:
218		"(broken!) generate a lattice maze using Prim's algorithm"
219		warnings.warn(
220			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
221		)
222		return LatticeMazeGenerators.gen_dfs(
223			grid_shape=grid_shape,
224			lattice_dim=lattice_dim,
225			accessible_cells=accessible_cells,
226			max_tree_depth=max_tree_depth,
227			do_forks=do_forks,
228			start_coord=start_coord,
229			randomized_stack=True,
230		)

(broken!) generate a lattice maze using Prim's algorithm

@staticmethod
def gen_wilson( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], **kwargs) -> maze_dataset.LatticeMaze:
232	@staticmethod
233	def gen_wilson(
234		grid_shape: Coord | CoordTup,
235		**kwargs,
236	) -> LatticeMaze:
237		"""Generate a lattice maze using Wilson's algorithm.
238
239		# Algorithm
240		Wilson's algorithm generates an unbiased (random) maze
241		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
242		acyclic and all cells are part of a unique connected space.
243		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
244		"""
245		assert not kwargs, (
246			f"gen_wilson does not take any additional arguments, got {kwargs = }"
247		)
248
249		grid_shape_: Coord = np.array(grid_shape)
250
251		# Initialize grid and visited cells
252		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
253		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
254
255		# Choose a random cell and mark it as visited
256		start_coord: Coord = _random_start_coord(grid_shape_, None)
257		visited[start_coord[0], start_coord[1]] = True
258		del start_coord
259
260		while not visited.all():
261			# Perform loop-erased random walk from another random cell
262
263			# Choose walk_start only from unvisited cells
264			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
265			walk_start: Coord = unvisited_coords[
266				np.random.choice(unvisited_coords.shape[0])
267			]
268
269			# Perform the random walk
270			path: list[Coord] = [walk_start]
271			current: Coord = walk_start
272
273			# exit the loop once the current path hits a visited cell
274			while not visited[current[0], current[1]]:
275				# find a valid neighbor (one always exists on a lattice)
276				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
277				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
278
279				# Check for loop
280				loop_exit: int | None = None
281				for i, p in enumerate(path):
282					if np.array_equal(next_cell, p):
283						loop_exit = i
284						break
285
286				# erase the loop, or continue the walk
287				if loop_exit is not None:
288					# this removes everything after and including the loop start
289					path = path[: loop_exit + 1]
290					# reset current cell to end of path
291					current = path[-1]
292				else:
293					path.append(next_cell)
294					current = next_cell
295
296			# Add the path to the maze
297			for i in range(len(path) - 1):
298				c_1: Coord = path[i]
299				c_2: Coord = path[i + 1]
300
301				# find the dimension of the connection
302				delta: Coord = c_2 - c_1
303				dim: int = int(np.argmax(np.abs(delta)))
304
305				# if positive, down/right from current coord
306				# if negative, up/left from current coord (down/right from neighbor)
307				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
308				connection_list[dim, clist_node[0], clist_node[1]] = True
309				visited[c_1[0], c_1[1]] = True
310				# we dont add c_2 because the last c_2 will have already been visited
311
312		return LatticeMaze(
313			connection_list=connection_list,
314			generation_meta=dict(
315				func_name="gen_wilson",
316				grid_shape=grid_shape_,
317				fully_connected=True,
318			),
319		)

Generate a lattice maze using Wilson's algorithm.

Algorithm

Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

@staticmethod
def gen_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
321	@staticmethod
322	def gen_percolation(
323		grid_shape: Coord | CoordTup,
324		p: float = 0.4,
325		lattice_dim: int = 2,
326		start_coord: Coord | None = None,
327	) -> LatticeMaze:
328		"""generate a lattice maze using simple percolation
329
330		note that p in the range (0.4, 0.7) gives the most interesting mazes
331
332		# Arguments
333		- `grid_shape: Coord`: the shape of the grid
334		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
335		- `p: float`: the probability of a cell being accessible (default: `0.5`)
336		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
337		"""
338		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
339		grid_shape_: Coord = np.array(grid_shape)
340
341		start_coord = _random_start_coord(grid_shape_, start_coord)
342
343		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
344
345		connection_list = _fill_edges_with_walls(connection_list)
346
347		output: LatticeMaze = LatticeMaze(
348			connection_list=connection_list,
349			generation_meta=dict(
350				func_name="gen_percolation",
351				grid_shape=grid_shape_,
352				percolation_p=p,
353				start_coord=start_coord,
354			),
355		)
356
357		# generation_meta is sometimes None, but not here since we just made it a dict above
358		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
359			start_coord,
360		)
361
362		return output

generate a lattice maze using simple percolation

note that p in the range (0.4, 0.7) gives the most interesting mazes

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • p: float: the probability of each connection between adjacent cells being open (default: 0.4)
  • start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
364	@staticmethod
365	def gen_dfs_percolation(
366		grid_shape: Coord | CoordTup,
367		p: float = 0.4,
368		lattice_dim: int = 2,
369		accessible_cells: int | None = None,
370		max_tree_depth: int | None = None,
371		start_coord: Coord | None = None,
372	) -> LatticeMaze:
373		"""dfs and then percolation (adds cycles)"""
374		grid_shape_: Coord = np.array(grid_shape)
375		start_coord = _random_start_coord(grid_shape_, start_coord)
376
377		# generate initial maze via dfs
378		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
379			grid_shape=grid_shape_,
380			lattice_dim=lattice_dim,
381			accessible_cells=accessible_cells,
382			max_tree_depth=max_tree_depth,
383			start_coord=start_coord,
384		)
385
386		# percolate
387		connection_list_perc: np.ndarray = (
388			np.random.rand(*maze.connection_list.shape) < p
389		)
390		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
391
392		maze.__dict__["connection_list"] = np.logical_or(
393			maze.connection_list,
394			connection_list_perc,
395		)
396
397		# generation_meta is sometimes None, but not here since we just made it a dict above
398		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
399		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
400		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
401			start_coord,
402		)
403
404		return maze

dfs and then percolation (adds cycles)

@staticmethod
def gen_kruskal( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
406	@staticmethod
407	def gen_kruskal(
408		grid_shape: "Coord | CoordTup",
409		lattice_dim: int = 2,
410		start_coord: "Coord | None" = None,
411	) -> "LatticeMaze":
412		"""Generate a maze using Kruskal's algorithm.
413
414		This function generates a random spanning tree over a grid using Kruskal's algorithm.
415		Each cell is treated as a node, and all valid adjacent edges are listed and processed
416		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
417		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
418		without cycles.
419
420		https://en.wikipedia.org/wiki/Kruskal's_algorithm
421
422		# Parameters:
423		- `grid_shape : Coord | CoordTup`
424			The shape of the maze grid (for example, `(n_rows, n_cols)`).
425		- `lattice_dim : int`
426			The lattice dimension (default is `2`).
427		- `start_coord : Coord | None`
428			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
429		- `**kwargs`
430			Additional keyword arguments (currently unused).
431
432		# Returns:
433		- `LatticeMaze`
434			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
435
436		# Usage:
437		```python
438		maze = gen_kruskal((10, 10))
439		```
440		"""
441		assert lattice_dim == 2, (  # noqa: PLR2004
442			"Kruskal's algorithm is only implemented for 2D lattices."
443		)
444		# Convert grid_shape to a tuple of ints
445		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
446		n_rows, n_cols = grid_shape_
447
448		# Initialize union-find data structure.
449		parent: dict[tuple[int, int], tuple[int, int]] = {}
450
451		def find(cell: tuple[int, int]) -> tuple[int, int]:
452			while parent[cell] != cell:
453				parent[cell] = parent[parent[cell]]
454				cell = parent[cell]
455			return cell
456
457		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
458			root1 = find(cell1)
459			root2 = find(cell2)
460			parent[root2] = root1
461
462		# Initialize each cell as its own set.
463		for i in range(n_rows):
464			for j in range(n_cols):
465				parent[(i, j)] = (i, j)
466
467		# List all possible edges.
468		# For vertical edges (i.e. connecting a cell to its right neighbor):
469		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
470		for i in range(n_rows):
471			for j in range(n_cols - 1):
472				edges.append(((i, j), (i, j + 1), 1))
473		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
474		for i in range(n_rows - 1):
475			for j in range(n_cols):
476				edges.append(((i, j), (i + 1, j), 0))
477
478		# Shuffle the list of edges.
479		import random
480
481		random.shuffle(edges)
482
483		# Initialize connection_list with no connections.
484		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
485		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
486		import numpy as np
487
488		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
489
490		# Process each edge; if it connects two different trees, union them and carve the passage.
491		for cell1, cell2, direction in edges:
492			if find(cell1) != find(cell2):
493				union(cell1, cell2)
494				if direction == 0:
495					# Horizontal edge: connection is stored in connection_list[0] at cell1.
496					connection_list[0, cell1[0], cell1[1]] = True
497				else:
498					# Vertical edge: connection is stored in connection_list[1] at cell1.
499					connection_list[1, cell1[0], cell1[1]] = True
500
501		if start_coord is None:
502			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
503
504		generation_meta: dict = dict(
505			func_name="gen_kruskal",
506			grid_shape=grid_shape_,
507			start_coord=start_coord,
508			algorithm="kruskal",
509			fully_connected=True,
510		)
511		return LatticeMaze(
512			connection_list=connection_list, generation_meta=generation_meta
513		)

Generate a maze using Kruskal's algorithm.

This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.

https://en.wikipedia.org/wiki/Kruskal's_algorithm

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (for example, (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate will be chosen.
  • **kwargs Not accepted: the signature takes only grid_shape, lattice_dim, and start_coord.

Returns:

  • LatticeMaze A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

Usage:

maze = gen_kruskal((10, 10))
@staticmethod
def gen_recursive_division( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
515	@staticmethod
516	def gen_recursive_division(
517		grid_shape: "Coord | CoordTup",
518		lattice_dim: int = 2,
519		start_coord: "Coord | None" = None,
520	) -> "LatticeMaze":
521		"""Generate a maze using the recursive division algorithm.
522
523		This function generates a maze by recursively dividing the grid with walls and carving a single
524		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
525		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
526		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
527
528		# Parameters:
529		- `grid_shape : Coord | CoordTup`
530			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
531		- `lattice_dim : int`
532			The lattice dimension (default is `2`).
533		- `start_coord : Coord | None`
534			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
535		- `**kwargs`
536			Additional keyword arguments (currently unused).
537
538		# Returns:
539		- `LatticeMaze`
540			A maze represented by a connection list, generated using recursive division.
541
542		# Usage:
543		```python
544		maze = gen_recursive_division((10, 10))
545		```
546		"""
547		assert lattice_dim == 2, (  # noqa: PLR2004
548			"Recursive division algorithm is only implemented for 2D lattices."
549		)
550		# Convert grid_shape to a tuple of ints.
551		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
552		n_rows, n_cols = grid_shape_
553
554		# Initialize connection_list as a fully connected grid.
555		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
556		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
557		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
558		connection_list[0, : n_rows - 1, :] = True
559		connection_list[1, :, : n_cols - 1] = True
560
561		def divide(x: int, y: int, width: int, height: int) -> None:
562			"""Recursively divide the region starting at (x, y) with the given width and height.
563
564			Removes connections along the chosen division line except for one randomly chosen gap.
565			"""
566			if width < 2 or height < 2:  # noqa: PLR2004
567				return
568
569			if width > height:
570				# Vertical division.
571				wall_col = random.randint(x + 1, x + width - 1)
572				gap_row = random.randint(y, y + height - 1)
573				for row in range(y, y + height):
574					if row == gap_row:
575						continue
576					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
577					if wall_col - 1 < n_cols - 1:
578						connection_list[1, row, wall_col - 1] = False
579				# Recurse on the left and right subregions.
580				divide(x, y, wall_col - x, height)
581				divide(wall_col, y, x + width - wall_col, height)
582			else:
583				# Horizontal division.
584				wall_row = random.randint(y + 1, y + height - 1)
585				gap_col = random.randint(x, x + width - 1)
586				for col in range(x, x + width):
587					if col == gap_col:
588						continue
589					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
590					if wall_row - 1 < n_rows - 1:
591						connection_list[0, wall_row - 1, col] = False
592				# Recurse on the top and bottom subregions.
593				divide(x, y, width, wall_row - y)
594				divide(x, wall_row, width, y + height - wall_row)
595
596		# Begin the division on the full grid.
597		divide(0, 0, n_cols, n_rows)
598
599		if start_coord is None:
600			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
601
602		generation_meta: dict = dict(
603			func_name="gen_recursive_division",
604			grid_shape=grid_shape_,
605			start_coord=start_coord,
606			algorithm="recursive_division",
607			fully_connected=True,
608		)
609		return LatticeMaze(
610			connection_list=connection_list, generation_meta=generation_meta
611		)

Generate a maze using the recursive division algorithm.

This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (e.g., (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate is chosen.
  • **kwargs Not accepted: the signature takes only grid_shape, lattice_dim, and start_coord.

Returns:

  • LatticeMaze A maze represented by a connection list, generated using recursive division.

Usage:

maze = gen_recursive_division((10, 10))
P_GeneratorKwargs = ~P_GeneratorKwargs
MazeGeneratorFunc = typing.Callable[typing.Concatenate[jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], ~P_GeneratorKwargs], maze_dataset.LatticeMaze]
GENERATORS_MAP: dict[str, typing.Callable[typing.Concatenate[jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], ~P_GeneratorKwargs], maze_dataset.LatticeMaze]] = {'gen_dfs': <function LatticeMazeGenerators.gen_dfs>, 'gen_wilson': <function LatticeMazeGenerators.gen_wilson>, 'gen_percolation': <function LatticeMazeGenerators.gen_percolation>, 'gen_dfs_percolation': <function LatticeMazeGenerators.gen_dfs_percolation>, 'gen_prim': <function LatticeMazeGenerators.gen_prim>, 'gen_kruskal': <function LatticeMazeGenerators.gen_kruskal>, 'gen_recursive_division': <function LatticeMazeGenerators.gen_recursive_division>}

mapping of generator names to generator functions, useful for loading MazeDatasetConfig

def get_maze_with_solution( gen_name: str, grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], maze_ctor_kwargs: dict | None = None) -> maze_dataset.SolvedMaze:
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Parameters:
	- `gen_name : str`
		key into `GENERATORS_MAP` selecting the generator function
	- `grid_shape : Coord | CoordTup`
		passed to the generator as its first positional argument
	- `maze_ctor_kwargs : dict | None`
		keyword arguments for the generator; `None` means no extra arguments

	# Returns:
	- `SolvedMaze`
		the generated maze combined with the path from `generate_random_path()`

	# Raises:
	- `KeyError`
		if `gen_name` is not a key of `GENERATORS_MAP`
	"""
	generator_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
	generator = GENERATORS_MAP[gen_name]
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	maze: LatticeMaze = generator(grid_shape, **generator_kwargs)  # type: ignore[call-arg]
	return SolvedMaze.from_lattice_maze(
		lattice_maze=maze,
		solution=np.array(maze.generate_random_path()),
	)

helper function to get a maze already with a solution