docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.generation.generators

generation functions have signature (grid_shape: Coord, **kwargs) -> LatticeMaze and are methods in LatticeMazeGenerators


  1"""generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`"""
  2
  3import random
  4import warnings
  5from typing import Any, Callable
  6
  7import numpy as np
  8from jaxtyping import Bool
  9
 10from maze_dataset.constants import CoordArray, CoordTup
 11from maze_dataset.generation.seed import GLOBAL_SEED
 12from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze
 13from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls
 14
# module-level RNG setup, executed once at import time
# NOTE(review): in the code visible in this module, `numpy_rng` is never handed to the
# generators -- they draw from the global `np.random.*` functions and the stdlib `random`
# module. only the stdlib `random` state is seeded here; confirm where (or whether) the
# global numpy state gets seeded.
numpy_rng = np.random.default_rng(GLOBAL_SEED)
random.seed(GLOBAL_SEED)
 17
 18
 19def _random_start_coord(
 20	grid_shape: Coord,
 21	start_coord: Coord | CoordTup | None,
 22) -> Coord:
 23	"picking a random start coord within the bounds of `grid_shape` if none is provided"
 24	start_coord_: Coord
 25	if start_coord is None:
 26		start_coord_ = np.random.randint(
 27			0,  # lower bound
 28			np.maximum(grid_shape - 1, 1),  # upper bound (at least 1)
 29			size=len(grid_shape),  # dimensionality
 30		)
 31	else:
 32		start_coord_ = np.array(start_coord)
 33
 34	return start_coord_
 35
 36
 37def get_neighbors_in_bounds(
 38	coord: Coord,
 39	grid_shape: Coord,
 40) -> CoordArray:
 41	"get all neighbors of a coordinate that are within the bounds of the grid"
 42	# get all neighbors
 43	neighbors: CoordArray = coord + NEIGHBORS_MASK
 44
 45	# filter neighbors by being within grid bounds
 46	neighbors_in_bounds: CoordArray = neighbors[
 47		(neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1)
 48	]
 49
 50	return neighbors_in_bounds
 51
 52
 53class LatticeMazeGenerators:
 54	"""namespace for lattice maze generation algorithms
 55
 56	examples of generated mazes can be found here:
 57	https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
 58	"""
 59
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		lattice_dim: int = 2,
		accessible_cells: int | float | None = None,
		max_tree_depth: int | float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice. only sets the leading axis of the connection list -- all cell indexing in this function is two-dimensional
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is in `[0, 1]` and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells`. if a float, asserts it is in `[0, 1]` and treats it as a proportion of the **sum of the grid shape**
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, pop a uniformly random element of the stack instead of the most recently pushed one
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		- `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `start_coord`, `n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

		# Raises
		- `AssertionError`: if a float `accessible_cells` or `max_tree_depth` is outside `[0, 1]`, or if `accessible_cells` is neither `None`, a float, nor an int

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited
				1. Push the current cell to the stack
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			# the lower bound is checked too: the message promises [0, 1], and a negative
			# proportion would otherwise silently produce a maze with a single visited cell
			assert 0 <= accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert 0 <= max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the stack with the start coord, which counts as visited from the beginning
		visited_cells: set[tuple[int, int]] = set()
		visited_cells.add(tuple(start_coord))
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# compare against the TOTAL cell count, not `n_accessible_cells`: checking
				# `len(visited_cells) == n_accessible_cells` would mark the maze as fully connected
				# even when it is not, which breaks solving the maze
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)
206
207	@staticmethod
208	def gen_prim(
209		grid_shape: Coord | CoordTup,
210		lattice_dim: int = 2,
211		accessible_cells: float | None = None,
212		max_tree_depth: float | None = None,
213		do_forks: bool = True,
214		start_coord: Coord | None = None,
215	) -> LatticeMaze:
216		"(broken!) generate a lattice maze using Prim's algorithm"
217		warnings.warn(
218			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
219		)
220		return LatticeMazeGenerators.gen_dfs(
221			grid_shape=grid_shape,
222			lattice_dim=lattice_dim,
223			accessible_cells=accessible_cells,
224			max_tree_depth=max_tree_depth,
225			do_forks=do_forks,
226			start_coord=start_coord,
227			randomized_stack=True,
228		)
229
230	@staticmethod
231	def gen_wilson(
232		grid_shape: Coord | CoordTup,
233		**kwargs,
234	) -> LatticeMaze:
235		"""Generate a lattice maze using Wilson's algorithm.
236
237		# Algorithm
238		Wilson's algorithm generates an unbiased (random) maze
239		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
240		acyclic and all cells are part of a unique connected space.
241		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
242		"""
243		assert not kwargs, (
244			f"gen_wilson does not take any additional arguments, got {kwargs = }"
245		)
246
247		grid_shape_: Coord = np.array(grid_shape)
248
249		# Initialize grid and visited cells
250		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
251		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
252
253		# Choose a random cell and mark it as visited
254		start_coord: Coord = _random_start_coord(grid_shape_, None)
255		visited[start_coord[0], start_coord[1]] = True
256		del start_coord
257
258		while not visited.all():
259			# Perform loop-erased random walk from another random cell
260
261			# Choose walk_start only from unvisited cells
262			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
263			walk_start: Coord = unvisited_coords[
264				np.random.choice(unvisited_coords.shape[0])
265			]
266
267			# Perform the random walk
268			path: list[Coord] = [walk_start]
269			current: Coord = walk_start
270
271			# exit the loop once the current path hits a visited cell
272			while not visited[current[0], current[1]]:
273				# find a valid neighbor (one always exists on a lattice)
274				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
275				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
276
277				# Check for loop
278				loop_exit: int | None = None
279				for i, p in enumerate(path):
280					if np.array_equal(next_cell, p):
281						loop_exit = i
282						break
283
284				# erase the loop, or continue the walk
285				if loop_exit is not None:
286					# this removes everything after and including the loop start
287					path = path[: loop_exit + 1]
288					# reset current cell to end of path
289					current = path[-1]
290				else:
291					path.append(next_cell)
292					current = next_cell
293
294			# Add the path to the maze
295			for i in range(len(path) - 1):
296				c_1: Coord = path[i]
297				c_2: Coord = path[i + 1]
298
299				# find the dimension of the connection
300				delta: Coord = c_2 - c_1
301				dim: int = int(np.argmax(np.abs(delta)))
302
303				# if positive, down/right from current coord
304				# if negative, up/left from current coord (down/right from neighbor)
305				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
306				connection_list[dim, clist_node[0], clist_node[1]] = True
307				visited[c_1[0], c_1[1]] = True
308				# we dont add c_2 because the last c_2 will have already been visited
309
310		return LatticeMaze(
311			connection_list=connection_list,
312			generation_meta=dict(
313				func_name="gen_wilson",
314				grid_shape=grid_shape_,
315				fully_connected=True,
316			),
317		)
318
319	@staticmethod
320	def gen_percolation(
321		grid_shape: Coord | CoordTup,
322		p: float = 0.4,
323		lattice_dim: int = 2,
324		start_coord: Coord | None = None,
325	) -> LatticeMaze:
326		"""generate a lattice maze using simple percolation
327
328		note that p in the range (0.4, 0.7) gives the most interesting mazes
329
330		# Arguments
331		- `grid_shape: Coord`: the shape of the grid
332		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
333		- `p: float`: the probability of a cell being accessible (default: `0.5`)
334		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
335		"""
336		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
337		grid_shape_: Coord = np.array(grid_shape)
338
339		start_coord = _random_start_coord(grid_shape_, start_coord)
340
341		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
342
343		connection_list = _fill_edges_with_walls(connection_list)
344
345		output: LatticeMaze = LatticeMaze(
346			connection_list=connection_list,
347			generation_meta=dict(
348				func_name="gen_percolation",
349				grid_shape=grid_shape_,
350				percolation_p=p,
351				start_coord=start_coord,
352			),
353		)
354
355		# generation_meta is sometimes None, but not here since we just made it a dict above
356		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
357			start_coord,
358		)
359
360		return output
361
362	@staticmethod
363	def gen_dfs_percolation(
364		grid_shape: Coord | CoordTup,
365		p: float = 0.4,
366		lattice_dim: int = 2,
367		accessible_cells: int | None = None,
368		max_tree_depth: int | None = None,
369		start_coord: Coord | None = None,
370	) -> LatticeMaze:
371		"""dfs and then percolation (adds cycles)"""
372		grid_shape_: Coord = np.array(grid_shape)
373		start_coord = _random_start_coord(grid_shape_, start_coord)
374
375		# generate initial maze via dfs
376		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
377			grid_shape=grid_shape_,
378			lattice_dim=lattice_dim,
379			accessible_cells=accessible_cells,
380			max_tree_depth=max_tree_depth,
381			start_coord=start_coord,
382		)
383
384		# percolate
385		connection_list_perc: np.ndarray = (
386			np.random.rand(*maze.connection_list.shape) < p
387		)
388		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
389
390		maze.__dict__["connection_list"] = np.logical_or(
391			maze.connection_list,
392			connection_list_perc,
393		)
394
395		# generation_meta is sometimes None, but not here since we just made it a dict above
396		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
397		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
398		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
399			start_coord,
400		)
401
402		return maze
403
404	@staticmethod
405	def gen_kruskal(
406		grid_shape: "Coord | CoordTup",
407		lattice_dim: int = 2,
408		start_coord: "Coord | None" = None,
409	) -> "LatticeMaze":
410		"""Generate a maze using Kruskal's algorithm.
411
412		This function generates a random spanning tree over a grid using Kruskal's algorithm.
413		Each cell is treated as a node, and all valid adjacent edges are listed and processed
414		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
415		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
416		without cycles.
417
418		https://en.wikipedia.org/wiki/Kruskal's_algorithm
419
420		# Parameters:
421		- `grid_shape : Coord | CoordTup`
422			The shape of the maze grid (for example, `(n_rows, n_cols)`).
423		- `lattice_dim : int`
424			The lattice dimension (default is `2`).
425		- `start_coord : Coord | None`
426			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
427		- `**kwargs`
428			Additional keyword arguments (currently unused).
429
430		# Returns:
431		- `LatticeMaze`
432			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
433
434		# Usage:
435		```python
436		maze = gen_kruskal((10, 10))
437		```
438		"""
439		assert lattice_dim == 2, (  # noqa: PLR2004
440			"Kruskal's algorithm is only implemented for 2D lattices."
441		)
442		# Convert grid_shape to a tuple of ints
443		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
444		n_rows, n_cols = grid_shape_
445
446		# Initialize union-find data structure.
447		parent: dict[tuple[int, int], tuple[int, int]] = {}
448
449		def find(cell: tuple[int, int]) -> tuple[int, int]:
450			while parent[cell] != cell:
451				parent[cell] = parent[parent[cell]]
452				cell = parent[cell]
453			return cell
454
455		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
456			root1 = find(cell1)
457			root2 = find(cell2)
458			parent[root2] = root1
459
460		# Initialize each cell as its own set.
461		for i in range(n_rows):
462			for j in range(n_cols):
463				parent[(i, j)] = (i, j)
464
465		# List all possible edges.
466		# For vertical edges (i.e. connecting a cell to its right neighbor):
467		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
468		for i in range(n_rows):
469			for j in range(n_cols - 1):
470				edges.append(((i, j), (i, j + 1), 1))
471		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
472		for i in range(n_rows - 1):
473			for j in range(n_cols):
474				edges.append(((i, j), (i + 1, j), 0))
475
476		# Shuffle the list of edges.
477		import random
478
479		random.shuffle(edges)
480
481		# Initialize connection_list with no connections.
482		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
483		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
484		import numpy as np
485
486		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
487
488		# Process each edge; if it connects two different trees, union them and carve the passage.
489		for cell1, cell2, direction in edges:
490			if find(cell1) != find(cell2):
491				union(cell1, cell2)
492				if direction == 0:
493					# Horizontal edge: connection is stored in connection_list[0] at cell1.
494					connection_list[0, cell1[0], cell1[1]] = True
495				else:
496					# Vertical edge: connection is stored in connection_list[1] at cell1.
497					connection_list[1, cell1[0], cell1[1]] = True
498
499		if start_coord is None:
500			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
501
502		generation_meta: dict = dict(
503			func_name="gen_kruskal",
504			grid_shape=grid_shape_,
505			start_coord=start_coord,
506			algorithm="kruskal",
507			fully_connected=True,
508		)
509		return LatticeMaze(
510			connection_list=connection_list, generation_meta=generation_meta
511		)
512
513	@staticmethod
514	def gen_recursive_division(
515		grid_shape: "Coord | CoordTup",
516		lattice_dim: int = 2,
517		start_coord: "Coord | None" = None,
518	) -> "LatticeMaze":
519		"""Generate a maze using the recursive division algorithm.
520
521		This function generates a maze by recursively dividing the grid with walls and carving a single
522		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
523		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
524		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
525
526		# Parameters:
527		- `grid_shape : Coord | CoordTup`
528			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
529		- `lattice_dim : int`
530			The lattice dimension (default is `2`).
531		- `start_coord : Coord | None`
532			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
533		- `**kwargs`
534			Additional keyword arguments (currently unused).
535
536		# Returns:
537		- `LatticeMaze`
538			A maze represented by a connection list, generated using recursive division.
539
540		# Usage:
541		```python
542		maze = gen_recursive_division((10, 10))
543		```
544		"""
545		assert lattice_dim == 2, (  # noqa: PLR2004
546			"Recursive division algorithm is only implemented for 2D lattices."
547		)
548		# Convert grid_shape to a tuple of ints.
549		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
550		n_rows, n_cols = grid_shape_
551
552		# Initialize connection_list as a fully connected grid.
553		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
554		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
555		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
556		connection_list[0, : n_rows - 1, :] = True
557		connection_list[1, :, : n_cols - 1] = True
558
559		def divide(x: int, y: int, width: int, height: int) -> None:
560			"""Recursively divide the region starting at (x, y) with the given width and height.
561
562			Removes connections along the chosen division line except for one randomly chosen gap.
563			"""
564			if width < 2 or height < 2:  # noqa: PLR2004
565				return
566
567			if width > height:
568				# Vertical division.
569				wall_col = random.randint(x + 1, x + width - 1)
570				gap_row = random.randint(y, y + height - 1)
571				for row in range(y, y + height):
572					if row == gap_row:
573						continue
574					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
575					if wall_col - 1 < n_cols - 1:
576						connection_list[1, row, wall_col - 1] = False
577				# Recurse on the left and right subregions.
578				divide(x, y, wall_col - x, height)
579				divide(wall_col, y, x + width - wall_col, height)
580			else:
581				# Horizontal division.
582				wall_row = random.randint(y + 1, y + height - 1)
583				gap_col = random.randint(x, x + width - 1)
584				for col in range(x, x + width):
585					if col == gap_col:
586						continue
587					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
588					if wall_row - 1 < n_rows - 1:
589						connection_list[0, wall_row - 1, col] = False
590				# Recurse on the top and bottom subregions.
591				divide(x, y, width, wall_row - y)
592				divide(x, wall_row, width, y + height - wall_row)
593
594		# Begin the division on the full grid.
595		divide(0, 0, n_cols, n_rows)
596
597		if start_coord is None:
598			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
599
600		generation_meta: dict = dict(
601			func_name="gen_recursive_division",
602			grid_shape=grid_shape_,
603			start_coord=start_coord,
604			algorithm="recursive_division",
605			fully_connected=True,
606		)
607		return LatticeMaze(
608			connection_list=connection_list, generation_meta=generation_meta
609		)
610
611
# registry of every generator on `LatticeMazeGenerators`, keyed by the method's own name
# (each key matches the `func_name` that the generator writes into `generation_meta`)
# maintained by hand: it cant be populated automatically because that messes with pickling :(
GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = {
	"gen_dfs": LatticeMazeGenerators.gen_dfs,
	# TYPING: error: Dict entry 1 has incompatible type
	# "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]";
	# expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]"  [dict-item]
	# gen_wilson takes no kwargs and we check that the kwargs are empty
	# but mypy doesnt like this, `Any` != `KwArg(Any)`
	"gen_wilson": LatticeMazeGenerators.gen_wilson,  # type: ignore[dict-item]
	"gen_percolation": LatticeMazeGenerators.gen_percolation,
	"gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation,
	"gen_prim": LatticeMazeGenerators.gen_prim,
	"gen_kruskal": LatticeMazeGenerators.gen_kruskal,
	"gen_recursive_division": LatticeMazeGenerators.gen_recursive_division,
}
"mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`"
628
# names of the two generators above that draw random connections with probability `p`
# (`gen_percolation` and `gen_dfs_percolation`); a hand-written list of plain strings
_GENERATORS_PERCOLATED: list[str] = [
	"gen_percolation",
	"gen_dfs_percolation",
]
"""list of generator names that generate percolated mazes
we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail
this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array`
"""
637
638
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	looks up `gen_name` in `GENERATORS_MAP`, calls that generator with `grid_shape` and
	`maze_ctor_kwargs` (no kwargs when `None`), then solves the maze along a random path
	obtained from `LatticeMaze.generate_random_path`

	raises `KeyError` if `gen_name` is not a key of `GENERATORS_MAP`
	"""
	ctor_kwargs: dict = dict() if maze_ctor_kwargs is None else maze_ctor_kwargs
	generator = GENERATORS_MAP[gen_name]
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	generated: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
	random_path: CoordArray = np.array(generated.generate_random_path())
	return SolvedMaze.from_lattice_maze(lattice_maze=generated, solution=random_path)

numpy_rng = Generator(PCG64) at 0x7A06400C1460
def get_neighbors_in_bounds( coord: jaxtyping.Int8[ndarray, 'row_col=2'], grid_shape: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
def get_neighbors_in_bounds(
	coord: Coord,
	grid_shape: Coord,
) -> CoordArray:
	"""return the lattice neighbors of `coord` that lie inside a grid of shape `grid_shape`

	candidates are `coord + NEIGHBORS_MASK`; a candidate is kept only when every one of its
	components is non-negative and strictly less than the matching entry of `grid_shape`
	"""
	candidates: CoordArray = coord + NEIGHBORS_MASK

	# one boolean per candidate for each side of the bounds check
	not_below: np.ndarray = np.all(candidates >= 0, axis=1)
	not_above: np.ndarray = np.all(candidates < grid_shape, axis=1)

	return candidates[not_below & not_above]

get all neighbors of a coordinate that are within the bounds of the grid

class LatticeMazeGenerators:
 54class LatticeMazeGenerators:
 55	"""namespace for lattice maze generation algorithms
 56
 57	examples of generated mazes can be found here:
 58	https://understanding-search.github.io/maze-dataset/examples/maze_examples.html
 59	"""
 60
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		lattice_dim: int = 2,
		accessible_cells: float | None = None,
		max_tree_depth: float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, each iteration pops a uniformly random element of the stack instead of the most recently pushed one (this is how `gen_prim` calls this function)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		a `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `start_coord`,
		`n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty (and fewer than `accessible_cells` cells have been visited)
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited (and the depth limit is not exceeded)
				1. Push the current cell to the stack (only if `do_forks` is set and more than one unvisited neighbour remains)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		# resolve `accessible_cells` to an absolute cell count
		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		# resolve `max_tree_depth` to an absolute value (an int argument is used as given)
		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord (only if none was passed in)
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		# the start coord counts as visited from the beginning
		visited_cells.add(tuple(start_coord))
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# dead end or depth limit hit: step the depth counter back
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# NOTE: this must compare against n_total_cells, not n_accessible_cells.
				# an earlier version checked len(visited_cells) == n_accessible_cells, which meant the maze was
				# treated as fully connected even when it was not, causing solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)
207
208	@staticmethod
209	def gen_prim(
210		grid_shape: Coord | CoordTup,
211		lattice_dim: int = 2,
212		accessible_cells: float | None = None,
213		max_tree_depth: float | None = None,
214		do_forks: bool = True,
215		start_coord: Coord | None = None,
216	) -> LatticeMaze:
217		"(broken!) generate a lattice maze using Prim's algorithm"
218		warnings.warn(
219			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
220		)
221		return LatticeMazeGenerators.gen_dfs(
222			grid_shape=grid_shape,
223			lattice_dim=lattice_dim,
224			accessible_cells=accessible_cells,
225			max_tree_depth=max_tree_depth,
226			do_forks=do_forks,
227			start_coord=start_coord,
228			randomized_stack=True,
229		)
230
231	@staticmethod
232	def gen_wilson(
233		grid_shape: Coord | CoordTup,
234		**kwargs,
235	) -> LatticeMaze:
236		"""Generate a lattice maze using Wilson's algorithm.
237
238		# Algorithm
239		Wilson's algorithm generates an unbiased (random) maze
240		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
241		acyclic and all cells are part of a unique connected space.
242		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
243		"""
244		assert not kwargs, (
245			f"gen_wilson does not take any additional arguments, got {kwargs = }"
246		)
247
248		grid_shape_: Coord = np.array(grid_shape)
249
250		# Initialize grid and visited cells
251		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
252		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
253
254		# Choose a random cell and mark it as visited
255		start_coord: Coord = _random_start_coord(grid_shape_, None)
256		visited[start_coord[0], start_coord[1]] = True
257		del start_coord
258
259		while not visited.all():
260			# Perform loop-erased random walk from another random cell
261
262			# Choose walk_start only from unvisited cells
263			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
264			walk_start: Coord = unvisited_coords[
265				np.random.choice(unvisited_coords.shape[0])
266			]
267
268			# Perform the random walk
269			path: list[Coord] = [walk_start]
270			current: Coord = walk_start
271
272			# exit the loop once the current path hits a visited cell
273			while not visited[current[0], current[1]]:
274				# find a valid neighbor (one always exists on a lattice)
275				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
276				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
277
278				# Check for loop
279				loop_exit: int | None = None
280				for i, p in enumerate(path):
281					if np.array_equal(next_cell, p):
282						loop_exit = i
283						break
284
285				# erase the loop, or continue the walk
286				if loop_exit is not None:
287					# this removes everything after and including the loop start
288					path = path[: loop_exit + 1]
289					# reset current cell to end of path
290					current = path[-1]
291				else:
292					path.append(next_cell)
293					current = next_cell
294
295			# Add the path to the maze
296			for i in range(len(path) - 1):
297				c_1: Coord = path[i]
298				c_2: Coord = path[i + 1]
299
300				# find the dimension of the connection
301				delta: Coord = c_2 - c_1
302				dim: int = int(np.argmax(np.abs(delta)))
303
304				# if positive, down/right from current coord
305				# if negative, up/left from current coord (down/right from neighbor)
306				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
307				connection_list[dim, clist_node[0], clist_node[1]] = True
308				visited[c_1[0], c_1[1]] = True
309				# we dont add c_2 because the last c_2 will have already been visited
310
311		return LatticeMaze(
312			connection_list=connection_list,
313			generation_meta=dict(
314				func_name="gen_wilson",
315				grid_shape=grid_shape_,
316				fully_connected=True,
317			),
318		)
319
320	@staticmethod
321	def gen_percolation(
322		grid_shape: Coord | CoordTup,
323		p: float = 0.4,
324		lattice_dim: int = 2,
325		start_coord: Coord | None = None,
326	) -> LatticeMaze:
327		"""generate a lattice maze using simple percolation
328
329		note that p in the range (0.4, 0.7) gives the most interesting mazes
330
331		# Arguments
332		- `grid_shape: Coord`: the shape of the grid
333		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
334		- `p: float`: the probability of a cell being accessible (default: `0.5`)
335		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
336		"""
337		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
338		grid_shape_: Coord = np.array(grid_shape)
339
340		start_coord = _random_start_coord(grid_shape_, start_coord)
341
342		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
343
344		connection_list = _fill_edges_with_walls(connection_list)
345
346		output: LatticeMaze = LatticeMaze(
347			connection_list=connection_list,
348			generation_meta=dict(
349				func_name="gen_percolation",
350				grid_shape=grid_shape_,
351				percolation_p=p,
352				start_coord=start_coord,
353			),
354		)
355
356		# generation_meta is sometimes None, but not here since we just made it a dict above
357		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
358			start_coord,
359		)
360
361		return output
362
363	@staticmethod
364	def gen_dfs_percolation(
365		grid_shape: Coord | CoordTup,
366		p: float = 0.4,
367		lattice_dim: int = 2,
368		accessible_cells: int | None = None,
369		max_tree_depth: int | None = None,
370		start_coord: Coord | None = None,
371	) -> LatticeMaze:
372		"""dfs and then percolation (adds cycles)"""
373		grid_shape_: Coord = np.array(grid_shape)
374		start_coord = _random_start_coord(grid_shape_, start_coord)
375
376		# generate initial maze via dfs
377		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
378			grid_shape=grid_shape_,
379			lattice_dim=lattice_dim,
380			accessible_cells=accessible_cells,
381			max_tree_depth=max_tree_depth,
382			start_coord=start_coord,
383		)
384
385		# percolate
386		connection_list_perc: np.ndarray = (
387			np.random.rand(*maze.connection_list.shape) < p
388		)
389		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
390
391		maze.__dict__["connection_list"] = np.logical_or(
392			maze.connection_list,
393			connection_list_perc,
394		)
395
396		# generation_meta is sometimes None, but not here since we just made it a dict above
397		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
398		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
399		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
400			start_coord,
401		)
402
403		return maze
404
405	@staticmethod
406	def gen_kruskal(
407		grid_shape: "Coord | CoordTup",
408		lattice_dim: int = 2,
409		start_coord: "Coord | None" = None,
410	) -> "LatticeMaze":
411		"""Generate a maze using Kruskal's algorithm.
412
413		This function generates a random spanning tree over a grid using Kruskal's algorithm.
414		Each cell is treated as a node, and all valid adjacent edges are listed and processed
415		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
416		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
417		without cycles.
418
419		https://en.wikipedia.org/wiki/Kruskal's_algorithm
420
421		# Parameters:
422		- `grid_shape : Coord | CoordTup`
423			The shape of the maze grid (for example, `(n_rows, n_cols)`).
424		- `lattice_dim : int`
425			The lattice dimension (default is `2`).
426		- `start_coord : Coord | None`
427			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
428		- `**kwargs`
429			Additional keyword arguments (currently unused).
430
431		# Returns:
432		- `LatticeMaze`
433			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
434
435		# Usage:
436		```python
437		maze = gen_kruskal((10, 10))
438		```
439		"""
440		assert lattice_dim == 2, (  # noqa: PLR2004
441			"Kruskal's algorithm is only implemented for 2D lattices."
442		)
443		# Convert grid_shape to a tuple of ints
444		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
445		n_rows, n_cols = grid_shape_
446
447		# Initialize union-find data structure.
448		parent: dict[tuple[int, int], tuple[int, int]] = {}
449
450		def find(cell: tuple[int, int]) -> tuple[int, int]:
451			while parent[cell] != cell:
452				parent[cell] = parent[parent[cell]]
453				cell = parent[cell]
454			return cell
455
456		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
457			root1 = find(cell1)
458			root2 = find(cell2)
459			parent[root2] = root1
460
461		# Initialize each cell as its own set.
462		for i in range(n_rows):
463			for j in range(n_cols):
464				parent[(i, j)] = (i, j)
465
466		# List all possible edges.
467		# For vertical edges (i.e. connecting a cell to its right neighbor):
468		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
469		for i in range(n_rows):
470			for j in range(n_cols - 1):
471				edges.append(((i, j), (i, j + 1), 1))
472		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
473		for i in range(n_rows - 1):
474			for j in range(n_cols):
475				edges.append(((i, j), (i + 1, j), 0))
476
477		# Shuffle the list of edges.
478		import random
479
480		random.shuffle(edges)
481
482		# Initialize connection_list with no connections.
483		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
484		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
485		import numpy as np
486
487		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
488
489		# Process each edge; if it connects two different trees, union them and carve the passage.
490		for cell1, cell2, direction in edges:
491			if find(cell1) != find(cell2):
492				union(cell1, cell2)
493				if direction == 0:
494					# Horizontal edge: connection is stored in connection_list[0] at cell1.
495					connection_list[0, cell1[0], cell1[1]] = True
496				else:
497					# Vertical edge: connection is stored in connection_list[1] at cell1.
498					connection_list[1, cell1[0], cell1[1]] = True
499
500		if start_coord is None:
501			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
502
503		generation_meta: dict = dict(
504			func_name="gen_kruskal",
505			grid_shape=grid_shape_,
506			start_coord=start_coord,
507			algorithm="kruskal",
508			fully_connected=True,
509		)
510		return LatticeMaze(
511			connection_list=connection_list, generation_meta=generation_meta
512		)
513
514	@staticmethod
515	def gen_recursive_division(
516		grid_shape: "Coord | CoordTup",
517		lattice_dim: int = 2,
518		start_coord: "Coord | None" = None,
519	) -> "LatticeMaze":
520		"""Generate a maze using the recursive division algorithm.
521
522		This function generates a maze by recursively dividing the grid with walls and carving a single
523		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
524		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
525		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
526
527		# Parameters:
528		- `grid_shape : Coord | CoordTup`
529			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
530		- `lattice_dim : int`
531			The lattice dimension (default is `2`).
532		- `start_coord : Coord | None`
533			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
534		- `**kwargs`
535			Additional keyword arguments (currently unused).
536
537		# Returns:
538		- `LatticeMaze`
539			A maze represented by a connection list, generated using recursive division.
540
541		# Usage:
542		```python
543		maze = gen_recursive_division((10, 10))
544		```
545		"""
546		assert lattice_dim == 2, (  # noqa: PLR2004
547			"Recursive division algorithm is only implemented for 2D lattices."
548		)
549		# Convert grid_shape to a tuple of ints.
550		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
551		n_rows, n_cols = grid_shape_
552
553		# Initialize connection_list as a fully connected grid.
554		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
555		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
556		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
557		connection_list[0, : n_rows - 1, :] = True
558		connection_list[1, :, : n_cols - 1] = True
559
560		def divide(x: int, y: int, width: int, height: int) -> None:
561			"""Recursively divide the region starting at (x, y) with the given width and height.
562
563			Removes connections along the chosen division line except for one randomly chosen gap.
564			"""
565			if width < 2 or height < 2:  # noqa: PLR2004
566				return
567
568			if width > height:
569				# Vertical division.
570				wall_col = random.randint(x + 1, x + width - 1)
571				gap_row = random.randint(y, y + height - 1)
572				for row in range(y, y + height):
573					if row == gap_row:
574						continue
575					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
576					if wall_col - 1 < n_cols - 1:
577						connection_list[1, row, wall_col - 1] = False
578				# Recurse on the left and right subregions.
579				divide(x, y, wall_col - x, height)
580				divide(wall_col, y, x + width - wall_col, height)
581			else:
582				# Horizontal division.
583				wall_row = random.randint(y + 1, y + height - 1)
584				gap_col = random.randint(x, x + width - 1)
585				for col in range(x, x + width):
586					if col == gap_col:
587						continue
588					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
589					if wall_row - 1 < n_rows - 1:
590						connection_list[0, wall_row - 1, col] = False
591				# Recurse on the top and bottom subregions.
592				divide(x, y, width, wall_row - y)
593				divide(x, wall_row, width, y + height - wall_row)
594
595		# Begin the division on the full grid.
596		divide(0, 0, n_cols, n_rows)
597
598		if start_coord is None:
599			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
600
601		generation_meta: dict = dict(
602			func_name="gen_recursive_division",
603			grid_shape=grid_shape_,
604			start_coord=start_coord,
605			algorithm="recursive_division",
606			fully_connected=True,
607		)
608		return LatticeMaze(
609			connection_list=connection_list, generation_meta=generation_meta
610		)

namespace for lattice maze generation algorithms

examples of generated mazes can be found here: https://understanding-search.github.io/maze-dataset/examples/maze_examples.html

@staticmethod
def gen_dfs( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
	@staticmethod
	def gen_dfs(
		grid_shape: Coord | CoordTup,
		lattice_dim: int = 2,
		accessible_cells: float | None = None,
		max_tree_depth: float | None = None,
		do_forks: bool = True,
		randomized_stack: bool = False,
		start_coord: Coord | None = None,
	) -> LatticeMaze:
		"""generate a lattice maze using depth first search, iterative

		# Arguments
		- `grid_shape: Coord`: the shape of the grid
		- `lattice_dim: int`: the dimension of the lattice
			(default: `2`)
		- `accessible_cells: int | float | None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells**
			(default: `None`)
		- `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * n_total_cells` (twice the total number of cells in the grid). if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape**
			(default: `None`)
		- `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will have no forks and will be a simple hallway.
			(default: `True`)
		- `randomized_stack: bool`: if `True`, each iteration pops a uniformly random element of the stack instead of the most recently pushed one (this is how `gen_prim` calls this function)
			(default: `False`)
		- `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate.

		# Returns
		a `LatticeMaze` whose `generation_meta` records `func_name`, `grid_shape`, `start_coord`,
		`n_accessible_cells`, `max_tree_depth`, `fully_connected` and `visited_cells`

		# algorithm
		1. Choose the initial cell, mark it as visited and push it to the stack
		2. While the stack is not empty (and fewer than `accessible_cells` cells have been visited)
			1. Pop a cell from the stack and make it a current cell
			2. If the current cell has any neighbours which have not been visited (and the depth limit is not exceeded)
				1. Push the current cell to the stack (only if `do_forks` is set and more than one unvisited neighbour remains)
				2. Choose one of the unvisited neighbours
				3. Remove the wall between the current cell and the chosen cell
				4. Mark the chosen cell as visited and push it to the stack
		"""
		# Default values if no constraints have been passed
		grid_shape_: Coord = np.array(grid_shape)
		n_total_cells: int = int(np.prod(grid_shape_))

		# resolve `accessible_cells` to an absolute cell count
		n_accessible_cells: int
		if accessible_cells is None:
			n_accessible_cells = n_total_cells
		elif isinstance(accessible_cells, float):
			assert accessible_cells <= 1, (
				f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}"
			)

			n_accessible_cells = int(accessible_cells * n_total_cells)
		else:
			assert isinstance(accessible_cells, int)
			n_accessible_cells = accessible_cells

		# resolve `max_tree_depth` to an absolute value (an int argument is used as given)
		if max_tree_depth is None:
			max_tree_depth = (
				2 * n_total_cells
			)  # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here.
		elif isinstance(max_tree_depth, float):
			assert max_tree_depth <= 1, (
				f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}"
			)

			max_tree_depth = int(max_tree_depth * np.sum(grid_shape_))

		# choose a random start coord (only if none was passed in)
		start_coord = _random_start_coord(grid_shape_, start_coord)

		# initialize the maze with no connections
		connection_list: ConnectionList = np.zeros(
			(lattice_dim, grid_shape_[0], grid_shape_[1]),
			dtype=np.bool_,
		)

		# initialize the visited set and the stack with the start coord
		visited_cells: set[tuple[int, int]] = set()
		# the start coord counts as visited from the beginning
		visited_cells.add(tuple(start_coord))
		stack: list[Coord] = [start_coord]

		# initialize tree_depth_counter
		current_tree_depth: int = 1

		# loop until the stack is empty or n_accessible_cells is reached
		while stack and (len(visited_cells) < n_accessible_cells):
			# get the current coord from the stack
			current_coord: Coord
			if randomized_stack:
				current_coord = stack.pop(random.randint(0, len(stack) - 1))
			else:
				current_coord = stack.pop()

			# filter neighbors by being within grid bounds and being unvisited
			unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [
				(neighbor, delta)
				for neighbor, delta in zip(
					current_coord + NEIGHBORS_MASK,
					NEIGHBORS_MASK,
					strict=False,
				)
				if (
					(tuple(neighbor) not in visited_cells)
					and (0 <= neighbor[0] < grid_shape_[0])
					and (0 <= neighbor[1] < grid_shape_[1])
				)
			]

			# don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions)
			if unvisited_neighbors_deltas and (
				current_tree_depth <= max_tree_depth / 2
			):
				# if we want a maze without forks, simply don't add the current coord back to the stack
				if do_forks and (len(unvisited_neighbors_deltas) > 1):
					stack.append(current_coord)

				# choose one of the unvisited neighbors
				chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas)

				# add connection
				dim: int = int(np.argmax(np.abs(delta)))
				# if positive, down/right from current coord
				# if negative, up/left from current coord (down/right from neighbor)
				clist_node: Coord = (
					current_coord if (delta.sum() > 0) else chosen_neighbor
				)
				connection_list[dim, clist_node[0], clist_node[1]] = True

				# add to visited cells and stack
				visited_cells.add(tuple(chosen_neighbor))
				stack.append(chosen_neighbor)

				# Update current tree depth
				current_tree_depth += 1
			else:
				# dead end or depth limit hit: step the depth counter back
				current_tree_depth -= 1

		return LatticeMaze(
			connection_list=connection_list,
			generation_meta=dict(
				func_name="gen_dfs",
				grid_shape=grid_shape_,
				start_coord=start_coord,
				n_accessible_cells=int(n_accessible_cells),
				max_tree_depth=int(max_tree_depth),
				# NOTE: this must compare against n_total_cells, not n_accessible_cells.
				# an earlier version checked len(visited_cells) == n_accessible_cells, which meant the maze was
				# treated as fully connected even when it was not, causing solving the maze to break
				fully_connected=bool(len(visited_cells) == n_total_cells),
				visited_cells={tuple(int(x) for x in coord) for coord in visited_cells},
			),
		)

generate a lattice maze using depth first search, iterative

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • accessible_cells: int | float |None: the number of accessible cells in the maze. If None, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of total cells (default: None)
  • max_tree_depth: int | float | None: the maximum depth of the tree. If None, defaults to 2 * the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of the sum of the grid shape (default: None)
  • do_forks: bool: whether to allow forks in the maze. If False, the maze will have no forks and will be a simple hallway.
  • start_coord: Coord | None: the starting coordinate of the generation algorithm. If None, defaults to a random coordinate.

algorithm

  1. Choose the initial cell, mark it as visited and push it to the stack
  2. While the stack is not empty
    1. Pop a cell from the stack and make it a current cell
    2. If the current cell has any neighbours which have not been visited
      1. Push the current cell to the stack
      2. Choose one of the unvisited neighbours
      3. Remove the wall between the current cell and the chosen cell
      4. Mark the chosen cell as visited and push it to the stack
@staticmethod
def gen_prim( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
208	@staticmethod
209	def gen_prim(
210		grid_shape: Coord | CoordTup,
211		lattice_dim: int = 2,
212		accessible_cells: float | None = None,
213		max_tree_depth: float | None = None,
214		do_forks: bool = True,
215		start_coord: Coord | None = None,
216	) -> LatticeMaze:
217		"(broken!) generate a lattice maze using Prim's algorithm"
218		warnings.warn(
219			"gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12",
220		)
221		return LatticeMazeGenerators.gen_dfs(
222			grid_shape=grid_shape,
223			lattice_dim=lattice_dim,
224			accessible_cells=accessible_cells,
225			max_tree_depth=max_tree_depth,
226			do_forks=do_forks,
227			start_coord=start_coord,
228			randomized_stack=True,
229		)

(broken!) generate a lattice maze using Prim's algorithm

@staticmethod
def gen_wilson( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], **kwargs) -> maze_dataset.LatticeMaze:
231	@staticmethod
232	def gen_wilson(
233		grid_shape: Coord | CoordTup,
234		**kwargs,
235	) -> LatticeMaze:
236		"""Generate a lattice maze using Wilson's algorithm.
237
238		# Algorithm
239		Wilson's algorithm generates an unbiased (random) maze
240		sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is
241		acyclic and all cells are part of a unique connected space.
242		https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm
243		"""
244		assert not kwargs, (
245			f"gen_wilson does not take any additional arguments, got {kwargs = }"
246		)
247
248		grid_shape_: Coord = np.array(grid_shape)
249
250		# Initialize grid and visited cells
251		connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_)
252		visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_)
253
254		# Choose a random cell and mark it as visited
255		start_coord: Coord = _random_start_coord(grid_shape_, None)
256		visited[start_coord[0], start_coord[1]] = True
257		del start_coord
258
259		while not visited.all():
260			# Perform loop-erased random walk from another random cell
261
262			# Choose walk_start only from unvisited cells
263			unvisited_coords: CoordArray = np.column_stack(np.where(~visited))
264			walk_start: Coord = unvisited_coords[
265				np.random.choice(unvisited_coords.shape[0])
266			]
267
268			# Perform the random walk
269			path: list[Coord] = [walk_start]
270			current: Coord = walk_start
271
272			# exit the loop once the current path hits a visited cell
273			while not visited[current[0], current[1]]:
274				# find a valid neighbor (one always exists on a lattice)
275				neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_)
276				next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])]
277
278				# Check for loop
279				loop_exit: int | None = None
280				for i, p in enumerate(path):
281					if np.array_equal(next_cell, p):
282						loop_exit = i
283						break
284
285				# erase the loop, or continue the walk
286				if loop_exit is not None:
287					# this removes everything after and including the loop start
288					path = path[: loop_exit + 1]
289					# reset current cell to end of path
290					current = path[-1]
291				else:
292					path.append(next_cell)
293					current = next_cell
294
295			# Add the path to the maze
296			for i in range(len(path) - 1):
297				c_1: Coord = path[i]
298				c_2: Coord = path[i + 1]
299
300				# find the dimension of the connection
301				delta: Coord = c_2 - c_1
302				dim: int = int(np.argmax(np.abs(delta)))
303
304				# if positive, down/right from current coord
305				# if negative, up/left from current coord (down/right from neighbor)
306				clist_node: Coord = c_1 if (delta.sum() > 0) else c_2
307				connection_list[dim, clist_node[0], clist_node[1]] = True
308				visited[c_1[0], c_1[1]] = True
309				# we dont add c_2 because the last c_2 will have already been visited
310
311		return LatticeMaze(
312			connection_list=connection_list,
313			generation_meta=dict(
314				func_name="gen_wilson",
315				grid_shape=grid_shape_,
316				fully_connected=True,
317			),
318		)

Generate a lattice maze using Wilson's algorithm.

Algorithm

Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm

@staticmethod
def gen_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
320	@staticmethod
321	def gen_percolation(
322		grid_shape: Coord | CoordTup,
323		p: float = 0.4,
324		lattice_dim: int = 2,
325		start_coord: Coord | None = None,
326	) -> LatticeMaze:
327		"""generate a lattice maze using simple percolation
328
329		note that p in the range (0.4, 0.7) gives the most interesting mazes
330
331		# Arguments
332		- `grid_shape: Coord`: the shape of the grid
333		- `lattice_dim: int`: the dimension of the lattice (default: `2`)
334		- `p: float`: the probability of a cell being accessible (default: `0.5`)
335		- `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start)
336		"""
337		assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}"  # noqa: PT018
338		grid_shape_: Coord = np.array(grid_shape)
339
340		start_coord = _random_start_coord(grid_shape_, start_coord)
341
342		connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p
343
344		connection_list = _fill_edges_with_walls(connection_list)
345
346		output: LatticeMaze = LatticeMaze(
347			connection_list=connection_list,
348			generation_meta=dict(
349				func_name="gen_percolation",
350				grid_shape=grid_shape_,
351				percolation_p=p,
352				start_coord=start_coord,
353			),
354		)
355
356		# generation_meta is sometimes None, but not here since we just made it a dict above
357		output.generation_meta["visited_cells"] = output.gen_connected_component_from(  # type: ignore[index]
358			start_coord,
359		)
360
361		return output

generate a lattice maze using simple percolation

note that p in the range (0.4, 0.7) gives the most interesting mazes

Arguments

  • grid_shape: Coord: the shape of the grid
  • lattice_dim: int: the dimension of the lattice (default: 2)
  • p: float: the probability of each connection being open (default: 0.4)
  • start_coord: Coord | None: the starting coordinate for the connected component (default: None will give a random start)
@staticmethod
def gen_dfs_percolation( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
363	@staticmethod
364	def gen_dfs_percolation(
365		grid_shape: Coord | CoordTup,
366		p: float = 0.4,
367		lattice_dim: int = 2,
368		accessible_cells: int | None = None,
369		max_tree_depth: int | None = None,
370		start_coord: Coord | None = None,
371	) -> LatticeMaze:
372		"""dfs and then percolation (adds cycles)"""
373		grid_shape_: Coord = np.array(grid_shape)
374		start_coord = _random_start_coord(grid_shape_, start_coord)
375
376		# generate initial maze via dfs
377		maze: LatticeMaze = LatticeMazeGenerators.gen_dfs(
378			grid_shape=grid_shape_,
379			lattice_dim=lattice_dim,
380			accessible_cells=accessible_cells,
381			max_tree_depth=max_tree_depth,
382			start_coord=start_coord,
383		)
384
385		# percolate
386		connection_list_perc: np.ndarray = (
387			np.random.rand(*maze.connection_list.shape) < p
388		)
389		connection_list_perc = _fill_edges_with_walls(connection_list_perc)
390
391		maze.__dict__["connection_list"] = np.logical_or(
392			maze.connection_list,
393			connection_list_perc,
394		)
395
396		# generation_meta is sometimes None, but not here since we just made it a dict above
397		maze.generation_meta["func_name"] = "gen_dfs_percolation"  # type: ignore[index]
398		maze.generation_meta["percolation_p"] = p  # type: ignore[index]
399		maze.generation_meta["visited_cells"] = maze.gen_connected_component_from(  # type: ignore[index]
400			start_coord,
401		)
402
403		return maze

dfs and then percolation (adds cycles)

@staticmethod
def gen_kruskal( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
405	@staticmethod
406	def gen_kruskal(
407		grid_shape: "Coord | CoordTup",
408		lattice_dim: int = 2,
409		start_coord: "Coord | None" = None,
410	) -> "LatticeMaze":
411		"""Generate a maze using Kruskal's algorithm.
412
413		This function generates a random spanning tree over a grid using Kruskal's algorithm.
414		Each cell is treated as a node, and all valid adjacent edges are listed and processed
415		in random order. An edge is added (i.e. its passage carved) only if it connects two cells
416		that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree)
417		without cycles.
418
419		https://en.wikipedia.org/wiki/Kruskal's_algorithm
420
421		# Parameters:
422		- `grid_shape : Coord | CoordTup`
423			The shape of the maze grid (for example, `(n_rows, n_cols)`).
424		- `lattice_dim : int`
425			The lattice dimension (default is `2`).
426		- `start_coord : Coord | None`
427			Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen.
428		- `**kwargs`
429			Additional keyword arguments (currently unused).
430
431		# Returns:
432		- `LatticeMaze`
433			A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.
434
435		# Usage:
436		```python
437		maze = gen_kruskal((10, 10))
438		```
439		"""
440		assert lattice_dim == 2, (  # noqa: PLR2004
441			"Kruskal's algorithm is only implemented for 2D lattices."
442		)
443		# Convert grid_shape to a tuple of ints
444		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
445		n_rows, n_cols = grid_shape_
446
447		# Initialize union-find data structure.
448		parent: dict[tuple[int, int], tuple[int, int]] = {}
449
450		def find(cell: tuple[int, int]) -> tuple[int, int]:
451			while parent[cell] != cell:
452				parent[cell] = parent[parent[cell]]
453				cell = parent[cell]
454			return cell
455
456		def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None:
457			root1 = find(cell1)
458			root2 = find(cell2)
459			parent[root2] = root1
460
461		# Initialize each cell as its own set.
462		for i in range(n_rows):
463			for j in range(n_cols):
464				parent[(i, j)] = (i, j)
465
466		# List all possible edges.
467		# For vertical edges (i.e. connecting a cell to its right neighbor):
468		edges: list[tuple[tuple[int, int], tuple[int, int], int]] = []
469		for i in range(n_rows):
470			for j in range(n_cols - 1):
471				edges.append(((i, j), (i, j + 1), 1))
472		# For horizontal edges (i.e. connecting a cell to its bottom neighbor):
473		for i in range(n_rows - 1):
474			for j in range(n_cols):
475				edges.append(((i, j), (i + 1, j), 0))
476
477		# Shuffle the list of edges.
478		import random
479
480		random.shuffle(edges)
481
482		# Initialize connection_list with no connections.
483		# connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)).
484		# connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)).
485		import numpy as np
486
487		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
488
489		# Process each edge; if it connects two different trees, union them and carve the passage.
490		for cell1, cell2, direction in edges:
491			if find(cell1) != find(cell2):
492				union(cell1, cell2)
493				if direction == 0:
494					# Horizontal edge: connection is stored in connection_list[0] at cell1.
495					connection_list[0, cell1[0], cell1[1]] = True
496				else:
497					# Vertical edge: connection is stored in connection_list[1] at cell1.
498					connection_list[1, cell1[0], cell1[1]] = True
499
500		if start_coord is None:
501			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
502
503		generation_meta: dict = dict(
504			func_name="gen_kruskal",
505			grid_shape=grid_shape_,
506			start_coord=start_coord,
507			algorithm="kruskal",
508			fully_connected=True,
509		)
510		return LatticeMaze(
511			connection_list=connection_list, generation_meta=generation_meta
512		)

Generate a maze using Kruskal's algorithm.

This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles.

https://en.wikipedia.org/wiki/Kruskal's_algorithm

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (for example, (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate will be chosen.
  • **kwargs Additional keyword arguments (currently unused).

Returns:

  • LatticeMaze A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm.

Usage:

maze = gen_kruskal((10, 10))
@staticmethod
def gen_recursive_division( grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], lattice_dim: int = 2, start_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None) -> maze_dataset.LatticeMaze:
514	@staticmethod
515	def gen_recursive_division(
516		grid_shape: "Coord | CoordTup",
517		lattice_dim: int = 2,
518		start_coord: "Coord | None" = None,
519	) -> "LatticeMaze":
520		"""Generate a maze using the recursive division algorithm.
521
522		This function generates a maze by recursively dividing the grid with walls and carving a single
523		passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent
524		cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage.
525		The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.
526
527		# Parameters:
528		- `grid_shape : Coord | CoordTup`
529			The shape of the maze grid (e.g., `(n_rows, n_cols)`).
530		- `lattice_dim : int`
531			The lattice dimension (default is `2`).
532		- `start_coord : Coord | None`
533			Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen.
534		- `**kwargs`
535			Additional keyword arguments (currently unused).
536
537		# Returns:
538		- `LatticeMaze`
539			A maze represented by a connection list, generated using recursive division.
540
541		# Usage:
542		```python
543		maze = gen_recursive_division((10, 10))
544		```
545		"""
546		assert lattice_dim == 2, (  # noqa: PLR2004
547			"Recursive division algorithm is only implemented for 2D lattices."
548		)
549		# Convert grid_shape to a tuple of ints.
550		grid_shape_: CoordTup = tuple(int(x) for x in grid_shape)  # type: ignore[assignment]
551		n_rows, n_cols = grid_shape_
552
553		# Initialize connection_list as a fully connected grid.
554		# For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True.
555		# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True.
556		connection_list = np.zeros((2, n_rows, n_cols), dtype=bool)
557		connection_list[0, : n_rows - 1, :] = True
558		connection_list[1, :, : n_cols - 1] = True
559
560		def divide(x: int, y: int, width: int, height: int) -> None:
561			"""Recursively divide the region starting at (x, y) with the given width and height.
562
563			Removes connections along the chosen division line except for one randomly chosen gap.
564			"""
565			if width < 2 or height < 2:  # noqa: PLR2004
566				return
567
568			if width > height:
569				# Vertical division.
570				wall_col = random.randint(x + 1, x + width - 1)
571				gap_row = random.randint(y, y + height - 1)
572				for row in range(y, y + height):
573					if row == gap_row:
574						continue
575					# Remove the vertical connection between (row, wall_col-1) and (row, wall_col).
576					if wall_col - 1 < n_cols - 1:
577						connection_list[1, row, wall_col - 1] = False
578				# Recurse on the left and right subregions.
579				divide(x, y, wall_col - x, height)
580				divide(wall_col, y, x + width - wall_col, height)
581			else:
582				# Horizontal division.
583				wall_row = random.randint(y + 1, y + height - 1)
584				gap_col = random.randint(x, x + width - 1)
585				for col in range(x, x + width):
586					if col == gap_col:
587						continue
588					# Remove the horizontal connection between (wall_row-1, col) and (wall_row, col).
589					if wall_row - 1 < n_rows - 1:
590						connection_list[0, wall_row - 1, col] = False
591				# Recurse on the top and bottom subregions.
592				divide(x, y, width, wall_row - y)
593				divide(x, wall_row, width, y + height - wall_row)
594
595		# Begin the division on the full grid.
596		divide(0, 0, n_cols, n_rows)
597
598		if start_coord is None:
599			start_coord = tuple(np.random.randint(0, n) for n in grid_shape_)  # type: ignore[assignment]
600
601		generation_meta: dict = dict(
602			func_name="gen_recursive_division",
603			grid_shape=grid_shape_,
604			start_coord=start_coord,
605			algorithm="recursive_division",
606			fully_connected=True,
607		)
608		return LatticeMaze(
609			connection_list=connection_list, generation_meta=generation_meta
610		)

Generate a maze using the recursive division algorithm.

This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells.

Parameters:

  • grid_shape : Coord | CoordTup The shape of the maze grid (e.g., (n_rows, n_cols)).
  • lattice_dim : int The lattice dimension (default is 2).
  • start_coord : Coord | None Optionally, specify a starting coordinate. If None, a random coordinate is chosen.
  • **kwargs Additional keyword arguments (currently unused).

Returns:

  • LatticeMaze A maze represented by a connection list, generated using recursive division.

Usage:

maze = gen_recursive_division((10, 10))
GENERATORS_MAP: dict[str, typing.Callable[[jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], typing.Any], maze_dataset.LatticeMaze]] = {'gen_dfs': <function LatticeMazeGenerators.gen_dfs>, 'gen_wilson': <function LatticeMazeGenerators.gen_wilson>, 'gen_percolation': <function LatticeMazeGenerators.gen_percolation>, 'gen_dfs_percolation': <function LatticeMazeGenerators.gen_dfs_percolation>, 'gen_prim': <function LatticeMazeGenerators.gen_prim>, 'gen_kruskal': <function LatticeMazeGenerators.gen_kruskal>, 'gen_recursive_division': <function LatticeMazeGenerators.gen_recursive_division>}

mapping of generator names to generator functions, useful for loading MazeDatasetConfig

def get_maze_with_solution( gen_name: str, grid_shape: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], maze_ctor_kwargs: dict | None = None) -> maze_dataset.SolvedMaze:
def get_maze_with_solution(
	gen_name: str,
	grid_shape: Coord | CoordTup,
	maze_ctor_kwargs: dict | None = None,
) -> SolvedMaze:
	"""helper function to get a maze already with a solution

	# Arguments
	- `gen_name: str`: key into `GENERATORS_MAP` selecting the generator function
	- `grid_shape: Coord | CoordTup`: passed to the generator as its first positional argument
	- `maze_ctor_kwargs: dict | None`: extra keyword arguments for the generator (default: `None`, meaning no extra arguments)

	# Returns
	- `SolvedMaze`: the generated maze paired with a path from `generate_random_path()`

	# Raises
	- `KeyError`: if `gen_name` is not a key of `GENERATORS_MAP`
	"""
	ctor_kwargs: dict = {} if maze_ctor_kwargs is None else maze_ctor_kwargs
	generator = GENERATORS_MAP[gen_name]
	# TYPING: error: Too few arguments  [call-arg]
	# not sure why this is happening -- doesnt recognize the kwargs?
	generated: LatticeMaze = generator(grid_shape, **ctor_kwargs)  # type: ignore[call-arg]
	return SolvedMaze.from_lattice_maze(
		lattice_maze=generated,
		solution=np.array(generated.generate_random_path()),
	)

helper function to get a maze already with a solution