docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.maze.lattice_maze

Implements LatticeMaze, and the TargetedLatticeMaze and SolvedMaze subclasses.

Also includes basic utilities, such as converting to/from ASCII and pixel representations.


"""Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses.

Also includes basic utilities, such as converting to/from ASCII and pixel representations.
"""
   5
   6import typing
   7import warnings
   8from dataclasses import dataclass
   9from itertools import chain
  10
  11import numpy as np
  12from jaxtyping import Bool, Int, Int8, Shaped
  13from muutils.json_serialize.serializable_dataclass import (
  14	SerializableDataclass,
  15	serializable_dataclass,
  16	serializable_field,
  17)
  18from muutils.misc import isinstance_by_type_name, list_split
  19
  20from maze_dataset.constants import (
  21	NEIGHBORS_MASK,
  22	SPECIAL_TOKENS,
  23	ConnectionList,
  24	Coord,
  25	CoordArray,
  26	CoordList,
  27	CoordTup,
  28)
  29from maze_dataset.token_utils import (
  30	TokenizerDeprecationWarning,
  31	connection_list_to_adj_list,
  32	get_adj_list_tokens,
  33	get_origin_tokens,
  34	get_path_tokens,
  35	get_target_tokens,
  36)
  37
  38if typing.TYPE_CHECKING:
  39	from maze_dataset.tokenization import (
  40		MazeTokenizer,
  41		MazeTokenizerModular,
  42		TokenizationMode,
  43	)
  44
# a single color, one int per channel
RGB = tuple[int, int, int]
"rgb tuple of values 0-255"

# color image of a maze, as produced by `LatticeMaze.as_pixels` (stored there as uint8)
PixelGrid = Int[np.ndarray, "x y rgb"]
"rgb grid of pixels"
# wall/open image of a maze, as produced by `LatticeMaze._as_pixels_bw` (`True` = open cell)
BinaryPixelGrid = Bool[np.ndarray, "x y"]
"boolean grid of pixels"

# used both as the only supported lattice dimensionality (checked in `_as_pixels_bw`)
# and as the expected length of a (start, end) edge pair in `_from_tokens_AOTP`
DIM_2: int = 2
"2 dimensions"
  55
  56
  57class NoValidEndpointException(Exception):  # noqa: N818
  58	"""Raised when no valid start or end positions are found in a maze."""
  59
  60	pass
  61
  62
  63def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList:
  64	"""fill the last elements of the connections lists as false for each dim"""
  65	for dim in range(connection_list.shape[0]):
  66		# last row for down
  67		if dim == 0:
  68			connection_list[dim, -1, :] = False
  69		# last column for right
  70		elif dim == 1:
  71			connection_list[dim, :, -1] = False
  72		else:
  73			err_msg: str = f"only 2d lattices supported. got {dim=}"
  74			raise NotImplementedError(err_msg)
  75	return connection_list
  76
  77
  78def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
  79	"""check if a color is in a pixel grid"""
  80	for row in pixel_grid:
  81		for pixel in row:
  82			if np.all(pixel == color):
  83				return True
  84	return False
  85
  86
@dataclass(frozen=True)
class PixelColors:
	"standard colors for pixel grids"

	# read as class attributes (e.g. `PixelColors.WALL`); `LatticeMaze.as_pixels`
	# writes these values into its uint8 rgb grid
	WALL: RGB = (0, 0, 0)  # black
	OPEN: RGB = (255, 255, 255)  # white
	START: RGB = (0, 255, 0)  # green
	END: RGB = (255, 0, 0)  # red
	PATH: RGB = (0, 0, 255)  # blue
  96
  97
@dataclass(frozen=True)
class AsciiChars:
	"standard ascii characters for mazes"

	# read as class attributes (e.g. `AsciiChars.WALL`); each field name matches
	# the corresponding `PixelColors` field (see `ASCII_PIXEL_PAIRINGS`)
	WALL: str = "#"
	OPEN: str = " "
	START: str = "S"
	END: str = "E"
	PATH: str = "X"
 107
 108
 109ASCII_PIXEL_PAIRINGS: dict[str, RGB] = {
 110	AsciiChars.WALL: PixelColors.WALL,
 111	AsciiChars.OPEN: PixelColors.OPEN,
 112	AsciiChars.START: PixelColors.START,
 113	AsciiChars.END: PixelColors.END,
 114	AsciiChars.PATH: PixelColors.PATH,
 115}
 116"map ascii characters to pixel colors"
 117
 118
 119@serializable_dataclass(
 120	frozen=True,
 121	kw_only=True,
 122	properties_to_serialize=["lattice_dim", "generation_meta"],
 123)
 124class LatticeMaze(SerializableDataclass):
 125	"""lattice maze (nodes on a lattice, connections only to neighboring nodes)
 126
 127	Connection List represents which nodes (N) are connected in each direction.
 128
	First and second elements represent downward and rightward connections,
	respectively.
 131
 132	Example:
 133		Connection list:
 134			[
 135				[ # down
 136					[F T],
 137					[F F]
 138				],
 139				[ # right
 140					[T F],
 141					[T F]
 142				]
 143			]
 144
 145		Nodes with connections
 146			N T N F
 147			F	T
 148			N T N F
 149			F	F
 150
 151		Graph:
 152			N - N
 153				|
 154			N - N
 155
 156	Note: the bottom row connections going down, and the
 157	right-hand connections going right, will always be False.
 158
 159	"""
 160
 161	connection_list: ConnectionList
 162	generation_meta: dict | None = serializable_field(default=None, compare=False)
 163
 164	lattice_dim = property(lambda self: self.connection_list.shape[0])
 165	grid_shape = property(lambda self: self.connection_list.shape[1:])
 166	n_connections = property(lambda self: self.connection_list.sum())
 167
	@property
	def grid_n(self) -> int:
		"""grid size as int, raises `AssertionError` if not square

		The check is an explicit `raise` rather than an `assert` statement, so it
		still fires when python runs with `-O` (which strips `assert`s).
		"""
		if self.grid_shape[0] != self.grid_shape[1]:
			raise AssertionError("only square mazes supported")
		return self.grid_shape[0]
 173
 174	# ============================================================
 175	# basic methods
 176	# ============================================================
 177
	def __eq__(self, other: object) -> bool:
		"""equality check: defers entirely to the base class `__eq__` via `super()`"""
		is_equal: bool = super().__eq__(other)
		return is_equal
 181
	@staticmethod
	def heuristic(a: CoordTup, b: CoordTup) -> float:
		"""manhattan (L1) distance between `a` and `b`, using their first two components"""
		d_row = a[0] - b[0]
		d_col = a[1] - b[1]
		return np.abs(d_row) + np.abs(d_col)
 186
	def __hash__(self) -> int:
		"""hash derived from the raw buffer bytes of `connection_list`

		Only the array contents feed the hash; `generation_meta` does not.
		"""
		raw_bytes: bytes = self.connection_list.tobytes()
		return hash(raw_bytes)
 190
	def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
		"""returns whether two nodes are connected

		Two nodes are connected only if they are orthogonally adjacent, both lie
		inside the grid, and the corresponding `connection_list` entry is set.

		# Parameters:
		- `a : Coord`, `b : Coord`
			the two nodes; anything `np.asarray` turns into a length-2 integer
			array (e.g. a `CoordTup`) is accepted

		# Returns:
		- `bool`
		"""
		a_arr = np.asarray(a)
		b_arr = np.asarray(b)
		delta = b_arr - a_arr
		if np.abs(delta).sum() != 1:
			# return false if not even adjacent
			return False

		# reject out-of-bounds nodes explicitly: a negative index would otherwise
		# silently wrap around to the opposite edge of `connection_list`
		shape = np.asarray(self.grid_shape)
		a_inside: bool = bool(np.all((a_arr >= 0) & (a_arr < shape)))
		b_inside: bool = bool(np.all((b_arr >= 0) & (b_arr < shape)))
		if not (a_inside and b_inside):
			return False

		# test for wall: the connection is stored on the node with the smaller
		# index along the axis where the two nodes differ
		dim: int = int(np.argmax(np.abs(delta)))
		clist_node = a_arr if (delta.sum() > 0) else b_arr
		return bool(self.connection_list[dim, clist_node[0], clist_node[1]])
 202
	def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
		"""check if a path is valid

		A path is valid when every coordinate lies inside the grid and each pair of
		consecutive coordinates is connected (see `nodes_connected`).

		# Parameters:
		- `path : CoordArray`
			coordinates of shape `(n, 2)`; a sequence of `CoordTup`s is also accepted
		- `empty_is_valid : bool`
			value returned for an empty path (defaults to `False`)

		# Returns:
		- `bool`
		"""
		# check path is not empty
		if len(path) == 0:
			return empty_is_valid

		# allow plain sequences of coordinates, not only arrays
		path_arr = np.asarray(path)

		# check all coords in bounds of maze
		if not np.all((path_arr >= 0) & (path_arr < self.grid_shape)):
			return False

		# check all consecutive nodes are connected
		return all(
			self.nodes_connected(prev_node, next_node)
			for prev_node, next_node in zip(path_arr[:-1], path_arr[1:])
		)
 218
	def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
		"""Count, for every coord, how many neighbors it is connected to.

		NOTE: `.sum()` over an int8 array promotes to the platform integer by
		default (numpy's reduction dtype rule), so the returned dtype is wider than
		the `Int8` annotation suggests.
		"""
		conn = self.connection_list.astype(np.int8)
		down_links = conn[0]
		right_links = conn[1]
		# each node's own outgoing links: south (down) plus east (right)
		degrees = conn.sum(axis=0)
		# incoming from the west: the right link of the node to the left
		degrees[:, 1:] += right_links[:, :-1]
		# incoming from the north: the down link of the node above
		degrees[1:, :] += down_links[:-1, :]
		return degrees
 234
	def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
		"""Returns an array of the neighboring, connected coords of `c`.

		The result always has shape `(n_neighbors, 2)`, including `(0, 2)` when `c`
		has no connected neighbors. Previously an empty result had shape `(0,)`.
		Rows appear in the order of `NEIGHBORS_MASK`.
		"""
		c = np.array(c)  # type: ignore[assignment]
		candidates = c + NEIGHBORS_MASK
		# boolean mask over candidate rows: in bounds and connected to `c`
		keep = np.array(
			[
				(0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
				and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
				and self.nodes_connected(c, neighbor)  # connected
				for neighbor in candidates
			],
			dtype=np.bool_,
		)
		# masking (rather than rebuilding from a list) keeps the 2d shape and the
		# candidates' dtype even when nothing survives
		output: CoordArray = candidates[keep]
		assert output.shape == (len(output), 2), (
			f"invalid shape: {output.shape}, expected ({len(output)}, 2))\n{c = }\n{self.as_ascii()}"
		)
		return output
 257
	def gen_connected_component_from(self, c: Coord) -> CoordArray:
		"""return the connected component from a given coordinate

		Iterative depth-first search from `c`. Returns every coord reachable from
		`c` (including `c` itself) as an array with one coord per row, in no
		particular order.
		"""
		# Stack for DFS
		stack: list[Coord] = [c]
		# Set to store visited nodes
		visited: set[CoordTup] = set()

		while stack:
			current_node: Coord = stack.pop()
			# this is fine since we know current_node is a coord and thus of length 2
			node_key: CoordTup = tuple(current_node)  # type: ignore[assignment]
			if node_key in visited:
				# a node can be pushed more than once before it is first popped
				# (e.g. in mazes with loops); expanding it again is wasted work
				continue
			visited.add(node_key)

			for neighbor in self.get_coord_neighbors(current_node):
				if tuple(neighbor) not in visited:
					stack.append(neighbor)

		return np.array(list(visited))
 280
	def find_shortest_path(
		self,
		c_start: CoordTup | Coord,
		c_end: CoordTup | Coord,
	) -> CoordArray:
		"""find the shortest path between two coordinates, using A*

		Every step between connected neighbors costs 1. The manhattan distance
		(`heuristic`) is used as the A* heuristic.

		# Parameters:
		- `c_start : CoordTup | Coord`
			start coordinate
		- `c_end : CoordTup | Coord`
			end coordinate

		# Returns:
		- `CoordArray`
			the path from `c_start` to `c_end`, both endpoints included

		# Raises:
		- `ValueError` : if `c_end` is not reachable from `c_start`
		"""
		c_start = tuple(c_start)  # type: ignore[assignment]
		c_end = tuple(c_end)  # type: ignore[assignment]

		# cost of cheapest path to node from start currently known
		g_score: dict[CoordTup, float] = {c_start: 0.0}  # type: ignore[dict-item]
		# estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
		# (the heuristic seeds f_score; it was previously written into g_score by mistake)
		f_score: dict[CoordTup, float] = {
			c_start: self.heuristic(c_start, c_end),  # type: ignore[arg-type,dict-item]
		}

		closed_vtx: set[CoordTup] = set()  # nodes already evaluated
		# nodes to be evaluated -- the set holds coordinate *tuples*, not their ints
		open_vtx: set[CoordTup] = {c_start}  # type: ignore[arg-type]
		# node immediately preceding each node in the path (currently known shortest path)
		source: dict[CoordTup, CoordTup] = {}

		while open_vtx:
			# get lowest f_score node. ties are broken by set iteration order; this is
			# kept as-is so the path chosen among equal-length alternatives is unchanged
			c_current: CoordTup = min(open_vtx, key=lambda c: f_score[c])

			# check if goal is reached
			if c_end == c_current:
				path: list[CoordTup] = [c_current]
				p_current: CoordTup = c_current
				while p_current in source:
					p_current = source[p_current]
					path.append(p_current)
				# this is the only return statement
				return np.array(path[::-1])

			# close current node
			closed_vtx.add(c_current)
			open_vtx.remove(c_current)

			# update g_score of neighbors
			_np_neighbor: Coord
			for _np_neighbor in self.get_coord_neighbors(c_current):
				neighbor: CoordTup = tuple(_np_neighbor)  # type: ignore[assignment]

				if neighbor in closed_vtx:
					# already checked
					continue
				g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

				if neighbor not in open_vtx:
					# found new vtx, so add
					open_vtx.add(neighbor)
				elif g_temp >= g_score[neighbor]:
					# if already knew about this one, but current g_score is worse, skip
					continue

				# store g_score and source
				source[neighbor] = c_current
				g_score[neighbor] = g_temp
				f_score[neighbor] = g_temp + self.heuristic(neighbor, c_end)

		raise ValueError(
			"A solution could not be found!",
			f"{c_start = }, {c_end = }",
			self.as_ascii(),
		)
 359
	def get_nodes(self) -> CoordArray:
		"""return every coordinate of the grid, one `(row, col)` per row of the output

		Coords are listed in row-major order:
		`(0, 0), (0, 1), ..., (n_rows - 1, n_cols - 1)`.
		"""
		n_rows: int = self.grid_shape[0]
		n_cols: int = self.grid_shape[1]
		row_idx, col_idx = np.indices((n_rows, n_cols))
		return np.stack([row_idx.ravel(), col_idx.ravel()], axis=1)
 371
	def get_connected_component(self) -> CoordArray:
		"""get the largest (and assumed only nonsingular) connected component of the maze

		- no `generation_meta`, or a truthy `generation_meta["fully_connected"]`:
			every node of the grid
		- otherwise: the `visited_cells` stored in `generation_meta`

		# Raises:
		- `ValueError` : if the maze is not marked fully connected and has no `visited_cells`

		TODO: other connected components?
		"""
		meta = self.generation_meta
		if meta is None or meta.get("fully_connected", False):
			# the whole grid is one component
			return self.get_nodes()

		visited_cells: set[CoordTup] | None = meta.get("visited_cells", None)
		if visited_cells is None:
			# TODO: dynamically generate visited_cells?
			err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
			raise ValueError(err_msg)
		return np.array(list(visited_cells))
 396
	# overload: with `except_on_no_valid_endpoint=True` (the default) a path is always returned
	@typing.overload
	def generate_random_path(
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: typing.Literal[True] = True,
	) -> CoordArray: ...
	# overload: with `except_on_no_valid_endpoint=False` the result may be `None`
	@typing.overload
	def generate_random_path(
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: typing.Literal[False] = False,
	) -> typing.Optional[CoordArray]: ...
	def generate_random_path(  # noqa: C901
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: bool = True,
	) -> typing.Optional[CoordArray]:
		"""return a path between randomly chosen start and end nodes within the connected component

		Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
		Without any special conditions, start and end are sampled without replacement and are therefore always distinct.
		Randomness comes from numpy's global RNG (`np.random.choice` / `np.random.randint`).

		# Parameters:
		- `allowed_start : CoordList | None`
			a list of allowed start positions. If `None`, any position in the connected component is allowed
			(defaults to `None`)
		- `allowed_end : CoordList | None`
			a list of allowed end positions. If `None`, any position in the connected component is allowed
			(defaults to `None`)
		- `deadend_start : bool`
			whether to ***force*** the start position to be a deadend (exactly one connected neighbor)
			(defaults to `False`)
		- `deadend_end : bool`
			whether to ***force*** the end position to be a deadend (exactly one connected neighbor)
			(defaults to `False`)
		- `endpoints_not_equal : bool`
			whether to ensure that the start and end point are not the same
			(defaults to `False`)
		- `except_on_no_valid_endpoint : bool`
			whether to raise an error if no valid start or end positions are found
			if this is `False`, the function might return `None` and this must be handled by the caller
			(defaults to `True`)

		# Returns:
		- `CoordArray`
			a path between the selected start and end positions, as returned by `find_shortest_path`

		# Raises:
		- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
		- `AssertionError` : if the grid has only one row or only one column
		"""
		# we can't create a "path" in a single-node maze
		# NOTE(review): this condition also rejects any 1xN / Nx1 maze, not only a single node
		assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
			f"can't create path in single-node maze: {self.as_ascii()}"
		)

		# get connected component
		connected_component: CoordArray = self.get_connected_component()

		# initialize start and end positions
		positions: Int[np.ndarray, "2 2"]

		# if no special conditions on start and end positions
		if (allowed_start, allowed_end, deadend_start, deadend_end) == (
			None,
			None,
			False,
			False,
		):
			try:
				# sample two distinct rows of the component (replace=False);
				# raises ValueError if the component has fewer than 2 nodes
				positions = connected_component[  # type: ignore[assignment]
					np.random.choice(
						len(connected_component),
						size=2,
						replace=False,
					)
				]
			except ValueError as e:
				if except_on_no_valid_endpoint:
					err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
					raise NoValidEndpointException(
						err_msg,
					) from e
				return None

			return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

		# handle special conditions
		# coords become tuples so they can live in sets and be intersected / filtered
		connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
		# copy connected component set
		allowed_start_set: set[CoordTup] = connected_component_set.copy()
		allowed_end_set: set[CoordTup] = connected_component_set.copy()

		# filter by explicitly allowed start and end positions
		# '# type: ignore[assignment]' here because the returned tuple can be of any length
		if allowed_start is not None:
			allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

		if allowed_end is not None:
			allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

		# filter by forcing deadends
		if deadend_start:
			allowed_start_set = set(
				filter(
					lambda x: len(self.get_coord_neighbors(x)) == 1,
					allowed_start_set,
				),
			)

		if deadend_end:
			allowed_end_set = set(
				filter(
					lambda x: len(self.get_coord_neighbors(x)) == 1,
					allowed_end_set,
				),
			)

		# check we have valid positions
		if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
			if except_on_no_valid_endpoint:
				err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
				raise NoValidEndpointException(
					err_msg,
				)
			return None

		# randomly select start and end positions
		# (rewrites of this block should keep the same sequence of np.random calls,
		# otherwise a given seed will produce different endpoints)
		try:
			# ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
			start_pos: CoordTup = tuple(  # type: ignore[assignment]
				list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
			)
			if endpoints_not_equal:
				# remove start position from end positions
				allowed_end_set.discard(start_pos)
			# np.random.randint(0, 0) raises ValueError if the end set is now empty
			end_pos: CoordTup = tuple(  # type: ignore[assignment]
				list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
			)
		except ValueError as e:
			if except_on_no_valid_endpoint:
				err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
				raise NoValidEndpointException(
					err_msg,
				) from e
			return None

		return self.find_shortest_path(start_pos, end_pos)
 555
 556	# ============================================================
 557	# to and from adjacency list
 558	# ============================================================
	def as_adj_list(
		self,
		shuffle_d0: bool = True,
		shuffle_d1: bool = True,
	) -> Int8[np.ndarray, "conn start_end coord"]:
		"""adjacency-list view of the maze

		Thin wrapper: forwards `connection_list` and both shuffle flags (positionally)
		to `maze_dataset.token_utils.connection_list_to_adj_list` and returns its result.
		"""
		clist = self.connection_list
		adj_list = connection_list_to_adj_list(clist, shuffle_d0, shuffle_d1)
		return adj_list
 566
	@classmethod
	def from_adj_list(
		cls,
		adj_list: Int8[np.ndarray, "conn start_end coord"],
	) -> "LatticeMaze":
		"""create a LatticeMaze from a list of connections

		# Parameters:
		- `adj_list : Int8[np.ndarray, "conn start_end coord"]`
			each entry is a pair `(c_start, c_end)` of orthogonally adjacent coords

		# Returns:
		- `LatticeMaze`
			always the base class, regardless of `cls`

		# Raises:
		- `ValueError` : if any pair is not orthogonally adjacent. The coords must
			differ by exactly 1 along exactly one axis; previously only "exactly
			one coordinate equal" was checked, so e.g. `(0, 0) -> (0, 3)` was
			silently accepted.

		> [!NOTE]
		> This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
		"""
		# this is where it would probably break for rectangular mazes
		grid_n: int = adj_list.max() + 1

		connection_list: ConnectionList = np.zeros(
			(2, grid_n, grid_n),
			dtype=np.bool_,
		)

		for c_start, c_end in adj_list:
			# signed, wide ints so the difference cannot wrap for small/unsigned dtypes
			delta = np.asarray(c_end, dtype=np.int64) - np.asarray(c_start, dtype=np.int64)
			if np.abs(delta).sum() != 1:
				err_msg: str = f"invalid connection: {c_start} -> {c_end}"
				raise ValueError(err_msg)

			# the direction is the single axis along which the coords differ
			d: int = int(np.argmax(delta != 0))

			# the connection is stored on whichever coord has the lesser value in direction `d`
			x: int
			y: int
			x, y = c_start if delta[d] > 0 else c_end

			connection_list[d, x, y] = True

		return LatticeMaze(
			connection_list=connection_list,
		)
 606
	def as_adj_list_tokens(self) -> list[str | CoordTup]:
		"""(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead

		Emits a `TokenizerDeprecationWarning`, then returns exactly what
		`_as_adj_list_tokens` returns. `stacklevel=2` makes the warning point at
		the caller rather than at this method.
		"""
		warnings.warn(
			"`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
			TokenizerDeprecationWarning,
			stacklevel=2,
		)
		return self._as_adj_list_tokens()
 628
	def _as_adj_list_tokens(self) -> list[str | CoordTup]:
		"""adjacency list as tokens, with coords kept as tuples (no deprecation warning)

		Layout: `ADJLIST_START`, then for every edge
		`c_s CONNECTOR c_e ADJACENCY_ENDLINE`, then `ADJLIST_END`. Edges come from
		`as_adj_list()` with its default shuffling.
		"""
		tokens: list[str | CoordTup] = [SPECIAL_TOKENS.ADJLIST_START]
		for edge_start, edge_end in self.as_adj_list():
			tokens.extend(
				(
					tuple(edge_start),
					SPECIAL_TOKENS.CONNECTOR,
					tuple(edge_end),
					SPECIAL_TOKENS.ADJACENCY_ENDLINE,
				),
			)
		tokens.append(SPECIAL_TOKENS.ADJLIST_END)
		return tokens
 645
	def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
		"""turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples

		Sections are appended in AOTP order. Origin and target are present only for
		a `TargetedLatticeMaze`; the path is present only for a `SolvedMaze`.
		"""
		tokens: list[CoordTup | str] = self._as_adj_list_tokens()
		if isinstance(self, TargetedLatticeMaze):
			tokens += self._get_start_pos_tokens()
			tokens += self._get_end_pos_tokens()
		if isinstance(self, SolvedMaze):
			tokens += self._get_solution_tokens()
		return tokens
 657
	def _as_tokens(
		self,
		maze_tokenizer: "MazeTokenizer | TokenizationMode",
	) -> list[str]:
		"""legacy (non-modular) tokenization of the maze

		A `TokenizationMode` is first converted via `to_legacy_tokenizer()`. Only a
		`MazeTokenizer` whose `is_AOTP()` is truthy is supported: the AOTP coords and
		special tokens are turned into strings with
		`coords_to_strings(..., when_noncoord="include")`.

		# Raises:
		- `NotImplementedError` : for any other tokenizer
		"""
		# type ignores here fine since we check the instance
		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

		is_supported = (
			isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
			and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
		)
		if not is_supported:
			err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
			raise NotImplementedError(err_msg)

		coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
		return maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
			coords=coords_raw,
			when_noncoord="include",
		)
 678
	def as_tokens(
		self,
		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
	) -> list[str]:
		"""serialize maze and solution to tokens

		A `MazeTokenizerModular` tokenizes the maze itself via `to_tokens`; any other
		tokenizer goes through the legacy `_as_tokens` path.
		"""
		if not isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
			return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
		return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
 688
	@classmethod
	def _from_tokens_AOTP(
		cls,
		tokens: list[str],
		maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
		"""create a LatticeMaze from a list of tokens

		Parses AOTP-ordered tokens: adjacency list, then optionally origin + target,
		then optionally a path.

		- only the adjacency list present -> `LatticeMaze`
		- plus all four origin/target delimiters -> `TargetedLatticeMaze`
		- plus path start/end delimiters -> `SolvedMaze`

		If `tokens` does not start with `ADJLIST_START`, the whole list is treated
		as bare adjacency-list tokens and a warning is emitted.

		# Raises:
		- `AssertionError` : on malformed edges, multiple origin/target coords, or a
			path without origin/target
		NOTE(review): an empty `tokens` list fails with `IndexError` at `tokens[0]`.
		"""
		# figure out what input format
		# ========================================
		if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
			adj_list_tokens = get_adj_list_tokens(tokens)
		else:
			# If we're not getting a "complete" tokenized maze, assume it's just a the adjacency list tokens
			adj_list_tokens = tokens
			warnings.warn(
				"Assuming input is just adjacency list tokens, no special tokens found",
			)

		# process edges for adjacency list
		# ========================================
		edges: list[list[str]] = list_split(
			adj_list_tokens,
			SPECIAL_TOKENS.ADJACENCY_ENDLINE,
		)

		coordinates: list[tuple[CoordTup, CoordTup]] = list()
		for e in edges:
			# skip last endline
			if len(e) != 0:
				# convert to coords, split start and end
				e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
					e,
					when_noncoord="include",
				)
				# this assertion depends on the tokenizer having exactly one token for the connector
				# which is also why we "include" above
				# the connector token is discarded below
				assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
				assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
					f"invalid edge: {e = } {e_coords = }"
				)
				e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
				e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
				coordinates.append((e_coords_first, e_coords_last))

		# NOTE(review): each element of `coordinates` is a (start, end) pair, so this
		# check is always true; it may have been meant to check each coord's length -- confirm
		assert all(len(c) == DIM_2 for c in coordinates), (
			f"invalid coordinates: {coordinates = }"
		)
		adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
		assert tuple(adj_list.shape) == (
			len(coordinates),
			2,
			2,
		), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

		output_maze: LatticeMaze = cls.from_adj_list(adj_list)

		# add start and end positions
		# ========================================
		is_targeted: bool = False
		if all(
			x in tokens
			for x in (
				SPECIAL_TOKENS.ORIGIN_START,
				SPECIAL_TOKENS.ORIGIN_END,
				SPECIAL_TOKENS.TARGET_START,
				SPECIAL_TOKENS.TARGET_END,
			)
		):
			start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_origin_tokens(tokens),
				when_noncoord="error",
			)
			end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_target_tokens(tokens),
				when_noncoord="error",
			)
			assert len(start_pos_list) == 1, (
				f"invalid start_pos_list: {start_pos_list = }"
			)
			assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

			start_pos: CoordTup = start_pos_list[0]
			end_pos: CoordTup = end_pos_list[0]

			output_maze = TargetedLatticeMaze.from_lattice_maze(
				lattice_maze=output_maze,
				start_pos=start_pos,
				end_pos=end_pos,
			)

			is_targeted = True

		# a solution is only accepted on top of an origin/target (asserted below)
		if all(
			x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
		):
			assert is_targeted, "maze must be targeted to have a solution"
			solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_path_tokens(tokens, trim_end=True),
				when_noncoord="error",
			)
			output_maze = SolvedMaze.from_targeted_lattice_maze(
				# HACK: I think this is fine, but im not sure
				targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
				solution=solution,
			)

		return output_maze
 797
	# TODO: any way to get return type hinting working for this?
	@classmethod
	def from_tokens(
		cls,
		tokens: list[str],
		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
		"""Constructs a maze from a tokenization.

		Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
		A single whitespace-separated string is also accepted for `tokens`.

		# Raises:
		- `NotImplementedError` : for a modular tokenizer that is not legacy-equivalent,
			or for any tokenizer that is not AOTP
		"""
		# HACK: type ignores here fine since we check the instance
		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

		is_modular = isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
		if is_modular and not maze_tokenizer.is_legacy_equivalent():  # type: ignore[union-attr]
			err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
			raise NotImplementedError(err_msg)

		if isinstance(tokens, str):
			tokens = tokens.split()

		if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
			raise NotImplementedError("only AOTP tokenization is supported")
		return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
 828
 829	# ============================================================
 830	# to and from pixels
 831	# ============================================================
	def _as_pixels_bw(self) -> BinaryPixelGrid:
		"""render the maze as a boolean pixel grid (`True` = open, `False` = wall)

		Node `(i, j)` maps to pixel `(2i + 1, 2j + 1)`. The pixel below it,
		`(2i + 2, 2j + 1)`, is open iff `connection_list[0, i, j]` is set. The pixel
		to its right, `(2i + 1, 2j + 2)`, is open iff `connection_list[1, i, j]` is
		set. The grid has shape `(2 * n_rows + 1, 2 * n_cols + 1)`, and its outer
		border is wall as long as the edge connections are `False`. This uses the
		same slices that `_from_pixel_grid_bw` reads back, so the two round-trip.

		# Raises:
		- `AssertionError` : if the maze is not 2D
		"""
		assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
		n_rows: int = self.grid_shape[0]
		n_cols: int = self.grid_shape[1]
		# Create an empty pixel grid with walls
		pixel_grid: BinaryPixelGrid = np.zeros(
			(n_rows * 2 + 1, n_cols * 2 + 1),
			dtype=np.bool_,
		)

		# Set white nodes (odd rows, odd columns)
		pixel_grid[1::2, 1::2] = True
		# Set white connections (downward): even rows, odd columns
		pixel_grid[2::2, 1::2] = self.connection_list[0]
		# Set white connections (rightward): odd rows, even columns
		pixel_grid[1::2, 2::2] = self.connection_list[1]

		return pixel_grid
 857
	def as_pixels(
		self,
		show_endpoints: bool = True,
		show_solution: bool = True,
	) -> PixelGrid:
		"""convert the maze to a pixel grid

		- useful as a simpler way of plotting the maze than the more complex `MazePlot`
		- the same underlying representation as `as_ascii` but as an image
		- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data

		# Parameters:
		- `show_endpoints : bool`
			color the start/end nodes of a `TargetedLatticeMaze` or `SolvedMaze` (defaults to `True`)
		- `show_solution : bool`
			color the solution path of a `SolvedMaze`; requires `show_endpoints` (defaults to `True`)

		# Returns:
		- `PixelGrid`
			uint8 rgb grid colored with `PixelColors`

		# Raises:
		- `ValueError` : if `show_solution` is `True` but `show_endpoints` is `False`
		"""
		# HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
		# but solution, start_pos, end_pos not always defined
		# but its fine since we explicitly check the type
		if show_solution and not show_endpoints:
			raise ValueError("show_solution=True requires show_endpoints=True")
		# convert original bool pixel grid to RGB
		pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
		pixel_grid: PixelGrid = np.full(
			(*pixel_grid_bw.shape, 3),
			PixelColors.WALL,
			dtype=np.uint8,
		)
		pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

		# draw the solution first, so the endpoints drawn below end up on top of it
		if show_solution and isinstance(self, SolvedMaze):
			solution = self.solution  # type: ignore[attr-defined]
			for coord in solution:
				pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

			# set pixels between consecutive coords
			for coord, next_coord in zip(solution[:-1], solution[1:]):
				# check they are adjacent using norm
				assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
					f"Coords {coord} and {next_coord} are not adjacent"
				)
				pixel_grid[
					coord[0] * 2 + 1 + next_coord[0] - coord[0],
					coord[1] * 2 + 1 + next_coord[1] - coord[1],
				] = PixelColors.PATH

		# endpoints are drawn for any targeted maze, including a `SolvedMaze` whose
		# solution is hidden (previously those endpoints were silently dropped)
		if show_endpoints and isinstance(self, (TargetedLatticeMaze, SolvedMaze)):
			pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
				PixelColors.START
			)
			pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
				PixelColors.END
			)

		return pixel_grid
 924
	@classmethod
	def _from_pixel_grid_bw(
		cls,
		pixel_grid: BinaryPixelGrid,
	) -> tuple[ConnectionList, tuple[int, int]]:
		"""recover a connection list from a boolean (open/wall) pixel grid

		pixel `(2*i + 2, 2*j + 1)` holds the downward connection out of node `(i, j)`,
		and pixel `(2*i + 1, 2*j + 2)` holds the rightward one.

		# Returns:
		- `tuple[ConnectionList, tuple[int, int]]`
			the connection list of shape `(2, rows, cols)`, and `(rows, cols)`, where
			`rows, cols` are the pixel dimensions floor-divided by 2
		"""
		n_rows: int = pixel_grid.shape[0] // 2
		n_cols: int = pixel_grid.shape[1] // 2
		grid_shape: tuple[int, int] = (n_rows, n_cols)

		connections: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=np.bool_)
		# index 0: downward -- the pixel directly below each node
		connections[0] = pixel_grid[2::2, 1::2]
		# index 1: rightward -- the pixel directly right of each node
		connections[1] = pixel_grid[1::2, 2::2]

		return connections, grid_shape
 943
	@classmethod
	def _from_pixel_grid_with_positions(
		cls,
		pixel_grid: PixelGrid | BinaryPixelGrid,
		marked_positions: dict[str, RGB],
	) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
		"""parse an RGB pixel grid into a connection list plus the node coords of marked colors

		# Parameters:
		- `pixel_grid` : RGB pixel grid; every pixel not exactly `PixelColors.WALL` counts as open
		- `marked_positions : dict[str, RGB]` : maps a name to the color to search for

		# Returns:
		- the connection list, the grid shape, and a dict mapping each key of
		  `marked_positions` to the node coords (never connection pixels) having that
		  color: shape `(n, 2)` if any were found, otherwise an empty array of shape `(0,)`
		"""
		# anything that isn't a wall pixel counts as open.
		# error: Incompatible types in assignment (expression has type
		# "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]",
		# variable has type "ndarray[Any, Any]")  [assignment]
		pixel_grid_bw: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
			pixel_grid == PixelColors.WALL,
			axis=-1,
		)
		connection_list: ConnectionList
		grid_shape: tuple[int, int]
		connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

		out_positions: dict[str, CoordArray] = {}
		for key, color in marked_positions.items():
			color_mask = np.all(pixel_grid == color, axis=-1)
			# only pixels at odd (row, col) are nodes; keeping just those maps
			# pixel (2i + 1, 2j + 1) directly onto node (i, j)
			node_hits = np.argwhere(color_mask[1::2, 1::2])
			out_positions[key] = np.array([(row, col) for row, col in node_hits])

		return connection_list, grid_shape, out_positions
 977
	@classmethod
	def from_pixels(
		cls,
		pixel_grid: PixelGrid,
	) -> "LatticeMaze":
		"""create a LatticeMaze from a pixel grid. reverse of `as_pixels`

		- a 2D (binary) pixel grid always gives a plain `LatticeMaze`, regardless of `cls`
		- an RGB pixel grid has its type detected with `detect_pixels_type` and is then
		  parsed as `cls`: start/end pixels become `start_pos`/`end_pos`, and the path
		  pixels are ordered into a solution by walking from the start to the end

		# Raises:
		- `ValueError` : if the pixel grid cannot be cast to `cls`, i.e. it lacks the
		  markings `cls` needs (e.g. asking for a `SolvedMaze` from a grid with no path pixels)
		- `AssertionError` : if the markings are malformed (not exactly one start and one
		  end pixel, or the path pixels do not form a single path from start to end)
		"""
		connection_list: ConnectionList
		grid_shape: tuple[int, int]

		# a binary pixel grid carries no markings, so it can only be a plain LatticeMaze
		if len(pixel_grid.shape) == 2:  # noqa: PLR2004
			connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
			return LatticeMaze(connection_list=connection_list)

		# otherwise, detect the type and check it can be cast to `cls`
		cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
		if cls not in cls_detected.__mro__:
			err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
			raise ValueError(
				err_msg,
			)

		(
			connection_list,
			grid_shape,
			marked_pos,
		) = cls._from_pixel_grid_with_positions(
			pixel_grid=pixel_grid,
			marked_positions=dict(
				start=PixelColors.START,
				end=PixelColors.END,
				solution=PixelColors.PATH,
			),
		)
		# if we wanted a LatticeMaze, return it
		if cls == LatticeMaze:
			return LatticeMaze(connection_list=connection_list)

		# otherwise, keep going
		temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

		# start and end pos -- exactly one pixel of each is required
		start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
		assert start_pos_arr.shape == (
			1,
			2,
		), (
			f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		assert end_pos_arr.shape == (
			1,
			2,
		), (
			f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)

		start_pos: Coord = start_pos_arr[0]
		end_pos: Coord = end_pos_arr[0]

		# return a TargetedLatticeMaze if that's what we wanted
		if cls == TargetedLatticeMaze:
			return TargetedLatticeMaze(
				connection_list=connection_list,
				start_pos=start_pos,
				end_pos=end_pos,
			)

		# raw solution, only contains path elements and not start or end
		solution_raw: CoordArray = marked_pos["solution"]
		if len(solution_raw.shape) == 2:  # noqa: PLR2004
			assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
				f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
			)
		elif solution_raw.shape == (0,):
			# with no path pixels, the start and end must be immediately adjacent
			assert np.sum(np.abs(start_pos - end_pos)) == 1, (
				f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
			)

		# order the solution by walking from the start pos through the path pixels.
		# the end pos is added to the candidates, since the walk has to step onto it
		solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
			tuple(end_pos),
		]
		# sets give O(1) membership tests in the walk below (the lists would make it
		# quadratic in the path length); `solution_raw_list` is kept for error messages
		solution_raw_set: set[CoordTup] = set(solution_raw_list)
		# solution starts with start point
		solution: list[CoordTup] = [tuple(start_pos)]
		solution_visited: set[CoordTup] = {solution[0]}
		while solution[-1] != tuple(end_pos):
			# use `get_coord_neighbors` to find connected neighbors
			neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
			assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
				f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
			)
			# keep only unvisited neighbors that are marked as part of the path
			neighbors_filtered: CoordArray = np.array(
				[
					coord
					for coord in neighbors
					if (
						tuple(coord) in solution_raw_set
						and tuple(coord) not in solution_visited
					)
				],
			)
			# exactly one way forward must remain, otherwise the path is ambiguous or broken
			assert neighbors_filtered.shape == (
				1,
				2,
			), (
				f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
			)
			next_coord: CoordTup = tuple(neighbors_filtered[0])
			solution.append(next_coord)
			solution_visited.add(next_coord)

		# assert the solution is complete
		assert solution[0] == tuple(start_pos), (
			f"solution {solution} does not start at start_pos {start_pos}"
		)
		assert solution[-1] == tuple(end_pos), (
			f"solution {solution} does not end at end_pos {end_pos}"
		)

		return cls(
			connection_list=np.array(connection_list),
			solution=np.array(solution),  # type: ignore[call-arg]
		)
1108
1109	# ============================================================
1110	# to and from ASCII
1111	# ============================================================
1112	def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
1113		# Get the pixel grid using to_pixels().
1114		pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw()
1115
1116		# Replace pixel values with ASCII characters.
1117		ascii_grid: Shaped[np.ndarray, "x y"] = np.full(
1118			pixel_grid.shape,
1119			AsciiChars.WALL,
1120			dtype=str,
1121		)
1122		ascii_grid[pixel_grid == True] = AsciiChars.OPEN  # noqa: E712
1123
1124		return ascii_grid
1125
1126	def as_ascii(
1127		self,
1128		show_endpoints: bool = True,
1129		show_solution: bool = True,
1130	) -> str:
1131		"""return an ASCII grid of the maze
1132
1133		useful for debugging in the terminal, or as it's own format
1134
1135		can be reversed with `LatticeMaze.from_ascii()`
1136		"""
1137		ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
1138		pixel_grid: PixelGrid = self.as_pixels(
1139			show_endpoints=show_endpoints,
1140			show_solution=show_solution,
1141		)
1142
1143		chars_replace: tuple = tuple()
1144		if show_endpoints:
1145			chars_replace += (AsciiChars.START, AsciiChars.END)
1146		if show_solution:
1147			chars_replace += (AsciiChars.PATH,)
1148
1149		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1150			if ascii_char in chars_replace:
1151				ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char
1152
1153		return "\n".join("".join(row) for row in ascii_grid)
1154
1155	@classmethod
1156	def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
1157		"get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)"
1158		lines: list[str] = ascii_str.strip().split("\n")
1159		lines = [line.strip() for line in lines]
1160		ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
1161			[list(line) for line in lines],
1162			dtype=str,
1163		)
1164		pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
1165
1166		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1167			pixel_grid[ascii_grid == ascii_char] = pixel_color
1168
1169		return cls.from_pixels(pixel_grid)
1170
1171
1172# type ignore here even though theyre all frozen
1173# maybe `SerializeableDataclass` itself is not frozen, but thats an ABC
1174# error: Cannot inherit frozen dataclass from a non-frozen one  [misc]
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
	"""A LatticeMaze with a start and end position

	`start_pos` and `end_pos` are converted to numpy arrays in `__post_init__`
	and must lie inside `grid_shape`.
	"""

	# this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
	# type ignore here because even though its a kw-only dataclass,
	# mypy doesn't like that non-default arguments are after default arguments
	start_pos: Coord = serializable_field(  # type: ignore[misc]
		assert_type=False,
	)
	end_pos: Coord = serializable_field(  # type: ignore[misc]
		assert_type=False,
	)

	def __post_init__(self) -> None:
		"""post init converts start and end pos to numpy arrays and checks they are in bounds

		# Raises:
		- `ValueError` : if `start_pos` or `end_pos` lies outside `[0, grid_shape)` on either axis
		"""
		# make things numpy arrays (very jank to override frozen dataclass)
		self.__dict__["start_pos"] = np.array(self.start_pos)
		self.__dict__["end_pos"] = np.array(self.end_pos)
		# check that start and end are in bounds. negative coords are rejected too,
		# since numpy would otherwise silently treat them as indices from the end
		if not (
			0 <= self.start_pos[0] < self.grid_shape[0]
			and 0 <= self.start_pos[1] < self.grid_shape[1]
		):
			err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
			raise ValueError(
				err_msg,
			)
		if not (
			0 <= self.end_pos[0] < self.grid_shape[0]
			and 0 <= self.end_pos[1] < self.grid_shape[1]
		):
			err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
			raise ValueError(
				err_msg,
			)

	def __eq__(self, other: object) -> bool:
		"check equality, calls parent class equality check"
		return super().__eq__(other)

	def __hash__(self) -> int:
		"""hash the connection list bytes together with the start and end coordinates

		defined explicitly because a class body that overrides `__eq__` without
		`__hash__` is otherwise left without a usable hash (`LatticeMaze` and
		`SolvedMaze` both define their own). `.tolist()` makes the coordinate part
		independent of the array dtype.
		"""
		return hash(
			(
				self.connection_list.tobytes(),
				tuple(self.start_pos.tolist()),
				tuple(self.end_pos.tolist()),
			),
		)

	def _get_start_pos_tokens(self) -> list[str | CoordTup]:
		return [
			SPECIAL_TOKENS.ORIGIN_START,
			tuple(self.start_pos),
			SPECIAL_TOKENS.ORIGIN_END,
		]

	def get_start_pos_tokens(self) -> list[str | CoordTup]:
		"(deprecated!) return the start position as a list of tokens"
		warnings.warn(
			"`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
			TokenizerDeprecationWarning,
		)
		return self._get_start_pos_tokens()

	def _get_end_pos_tokens(self) -> list[str | CoordTup]:
		return [
			SPECIAL_TOKENS.TARGET_START,
			tuple(self.end_pos),
			SPECIAL_TOKENS.TARGET_END,
		]

	def get_end_pos_tokens(self) -> list[str | CoordTup]:
		"(deprecated!) return the end position as a list of tokens"
		warnings.warn(
			"`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
			TokenizerDeprecationWarning,
		)
		return self._get_end_pos_tokens()

	@classmethod
	def from_lattice_maze(
		cls,
		lattice_maze: LatticeMaze,
		start_pos: Coord | CoordTup,
		end_pos: Coord | CoordTup,
	) -> "TargetedLatticeMaze":
		"get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
		return cls(
			connection_list=lattice_maze.connection_list,
			start_pos=np.array(start_pos),
			end_pos=np.array(end_pos),
			generation_meta=lattice_maze.generation_meta,
		)
1262
1263
@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
	"""Stores a maze and a solution

	`start_pos` and `end_pos` are taken from the first and last coordinates of `solution`.
	"""

	solution: CoordArray = serializable_field(  # type: ignore[misc]
		assert_type=False,
	)

	def __init__(
		self,
		connection_list: ConnectionList,
		solution: CoordArray,
		generation_meta: dict | None = None,
		start_pos: Coord | None = None,
		end_pos: Coord | None = None,
		allow_invalid: bool = False,
	) -> None:
		"""Create a SolvedMaze from a connection list and a solution

		# Parameters:
		- `connection_list : ConnectionList`
		- `solution : CoordArray`
			path through the maze; must have shape `(n, 2)` with `n >= 1`
		- `generation_meta : dict | None`
		- `start_pos : Coord | None`, `end_pos : Coord | None`
			optional; if given (and `allow_invalid` is False) they must equal
			`solution[0]` / `solution[-1]`
		- `allow_invalid : bool`
			skip the solution shape check and the endpoint consistency checks

		# Raises:
		- `ValueError` : if `solution` is not a non-empty `(n, 2)` array and `allow_invalid` is False
		- `AssertionError` : if a given `start_pos` / `end_pos` does not match the solution
		"""
		# figure out whether the solution is a non-empty `(n, 2)` array
		solution_valid: bool = False
		if solution is not None:
			solution = np.array(solution)
			# note that a path length of 1 here is valid, since the start and end pos could be the same.
			# `ndim` is checked first so 0D/1D input reaches the `ValueError` below
			# instead of raising `IndexError` on `shape[1]`
			if (
				solution.ndim == 2  # noqa: PLR2004
				and solution.shape[0] > 0
				and solution.shape[1] == 2  # noqa: PLR2004
			):
				solution_valid = True

		if not solution_valid and not allow_invalid:
			# `solution` may be `None` here, which has no `.shape`
			solution_shape = None if solution is None else solution.shape
			err_msg: str = f"invalid solution: solution.shape = {solution_shape!r} {solution = } {solution_valid = } {allow_invalid = }"
			raise ValueError(
				err_msg,
				f"{connection_list = }",
			)

		# init the TargetedLatticeMaze, with endpoints taken from the solution
		super().__init__(
			connection_list=connection_list,
			generation_meta=generation_meta,
			# TODO: the argument type is stricter than the expected type but it still fails?
			# error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
			# "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
			start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
			end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
		)

		# frozen dataclass, so the field is set through `__dict__`
		self.__dict__["solution"] = solution

		# check that any explicitly given endpoints agree with the solution
		if not allow_invalid:
			if start_pos is not None:
				assert np.array_equal(np.array(start_pos), self.start_pos), (
					f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
				)
			if end_pos is not None:
				assert np.array_equal(np.array(end_pos), self.end_pos), (
					f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
				)
			# TODO: assert the path does not backtrack, walk through walls, etc?

	def __eq__(self, other: object) -> bool:
		"check equality, calls parent class equality check"
		return super().__eq__(other)

	def __hash__(self) -> int:
		"hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
		return hash((self.connection_list.tobytes(), self.solution.tobytes()))

	def _get_solution_tokens(self) -> list[str | CoordTup]:
		return [
			SPECIAL_TOKENS.PATH_START,
			*[tuple(c) for c in self.solution],
			SPECIAL_TOKENS.PATH_END,
		]

	def get_solution_tokens(self) -> list[str | CoordTup]:
		"(deprecated!) return the solution as a list of tokens"
		# the warning names the class this method is actually defined on
		warnings.warn(
			"`SolvedMaze.get_solution_tokens` is deprecated.",
			TokenizerDeprecationWarning,
		)
		return self._get_solution_tokens()

	# for backwards compatibility
	@property
	def maze(self) -> LatticeMaze:
		"(deprecated!) return the maze without the solution"
		warnings.warn(
			"`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
			DeprecationWarning,
		)
		return LatticeMaze(connection_list=self.connection_list)

	# type ignore here since we're overriding a method with a different signature
	@classmethod
	def from_lattice_maze(  # type: ignore[override]
		cls,
		lattice_maze: LatticeMaze,
		solution: list[CoordTup] | CoordArray,
	) -> "SolvedMaze":
		"get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
		return cls(
			connection_list=lattice_maze.connection_list,
			solution=np.array(solution),
			generation_meta=lattice_maze.generation_meta,
		)

	@classmethod
	def from_targeted_lattice_maze(
		cls,
		targeted_lattice_maze: TargetedLatticeMaze,
		solution: list[CoordTup] | CoordArray | None = None,
	) -> "SolvedMaze":
		"""solves the given targeted lattice maze and returns a SolvedMaze

		if `solution` is None, it is computed with `find_shortest_path` between the
		maze's `start_pos` and `end_pos`
		"""
		if solution is None:
			solution = targeted_lattice_maze.find_shortest_path(
				targeted_lattice_maze.start_pos,
				targeted_lattice_maze.end_pos,
			)
		return cls(
			connection_list=targeted_lattice_maze.connection_list,
			solution=np.array(solution),
			generation_meta=targeted_lattice_maze.generation_meta,
		)

	def get_solution_forking_points(
		self,
		always_include_endpoints: bool = False,
	) -> tuple[list[int], CoordArray]:
		"""coordinates and their indices from the solution where a fork is present

		- if the start point is not a dead end, this counts as a fork
		- if the end point is not a dead end, this counts as a fork

		# Parameters:
		- `always_include_endpoints : bool`
			include the first and last solution coords regardless of their neighbor count

		# Returns:
		- `tuple[list[int], CoordArray]` : indices into `self.solution`, and the coords at those indices
		"""
		output_idxs: list[int] = list()
		output_coords: list[CoordTup] = list()

		last_idx: int = self.solution.shape[0] - 1
		for idx, coord in enumerate(self.solution):
			# an endpoint has only one neighbor on the path, an interior coord has two
			# (where it came from and where it goes); any extra neighbor is a choice
			is_endpoint: bool = idx == 0 or idx == last_idx
			threshold: int = 1 if is_endpoint else 2
			if self.get_coord_neighbors(coord).shape[0] > threshold or (
				is_endpoint and always_include_endpoints
			):
				output_idxs.append(idx)
				output_coords.append(coord)

		return output_idxs, np.array(output_coords)

	def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
		"""coordinates from the solution where there is only a single (non-backtracking) point to move to

		returns the complement of `get_solution_forking_points` from the path
		"""
		forks_idxs, _ = self.get_solution_forking_points()
		# HACK: idk why type ignore here
		return (  # type: ignore[return-value]
			np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
			np.delete(self.solution, forks_idxs, axis=0),
		)
1426
1427
def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
	"""Detects the type of pixels data by checking for the presence of start and end pixels

	- neither a start nor an end pixel: `LatticeMaze`
	- a start or end pixel, but no path pixel: `TargetedLatticeMaze`
	- a start or end pixel, and a path pixel: `SolvedMaze`
	"""
	has_endpoint: bool = color_in_pixel_grid(
		data,
		PixelColors.START,
	) or color_in_pixel_grid(data, PixelColors.END)
	if not has_endpoint:
		return LatticeMaze
	if color_in_pixel_grid(data, PixelColors.PATH):
		return SolvedMaze
	return TargetedLatticeMaze
1440
1441
def _remove_isolated_cells(
	image: Int[np.ndarray, "RGB x y"],
) -> Int[np.ndarray, "RGB x y"]:
	"""Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.

	returns a copy of `image` in which every non-wall pixel whose four orthogonal
	neighbors are all walls is set to `PixelColors.WALL`; pixels beyond the image
	border count as walls. the input is not modified.

	NOTE(review): colors are compared along the *last* axis, so despite the
	`"RGB x y"` annotation the layout is presumably `(x, y, rgb)` -- confirm.
	"""
	is_wall = np.all(image == PixelColors.WALL, axis=-1)
	# surround with a one-pixel wall border so edge pixels see "outside" as wall
	walled = np.pad(is_wall, 1, mode="constant", constant_values=True)

	neighbors_all_walls = (
		walled[:-2, 1:-1]  # up
		& walled[2:, 1:-1]  # down
		& walled[1:-1, :-2]  # left
		& walled[1:-1, 2:]  # right
	)
	# only open pixels can be "isolated"
	to_fill = neighbors_all_walls & ~is_wall

	result = image.copy()
	result[to_fill] = PixelColors.WALL
	return result
1474
1475
1476_RIC_PADS: dict = {
1477	"left": ((1, 0), (0, 0)),
1478	"right": ((0, 1), (0, 0)),
1479	"up": ((0, 0), (1, 0)),
1480	"down": ((0, 0), (0, 1)),
1481}
1482
1483# Define slices for each direction
1484_RIC_SLICES: dict = {
1485	"left": (slice(1, None), slice(None, None)),
1486	"right": (slice(None, -1), slice(None, None)),
1487	"up": (slice(None, None), slice(1, None)),
1488	"down": (slice(None, None), slice(None, -1)),
1489}
1490
1491
1492# TODO: figure out why this function doesnt work, or maybe just get rid of it
1493# def _remove_isolated_cells_old(
1494#     image: Int[np.ndarray, "RGB x y"],
1495# ) -> Int[np.ndarray, "RGB x y"]:
1496#     """
1497#     Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.
1498#     """
1499#     warnings.warn("this functin doesn't work and I have no idea why!!!")
1500#     masks: dict[str, np.ndarray] = {
1501#         d: np.all(
1502#             np.pad(
1503#                 image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL,
1504#                 np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8),
1505#                 mode="constant",
1506#                 constant_values=True,
1507#             ),
1508#             axis=2,
1509#         )
1510#         for d in _RIC_SLICES.keys()
1511#     }
1512
1513#     # Create a mask for non-wall cells
1514#     mask_non_wall = np.all(image != PixelColors.WALL, axis=2)
1515
1516#     # print(f"{mask_non_wall.shape = }")
1517#     # print(f"{ {k: masks[k].shape for k in masks.keys()} = }")
1518
1519#     # print(f"{mask_non_wall = }")
1520#     # print(f"{masks['down'] = }")
1521
1522#     # Combine the masks
1523#     mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"]
1524
1525#     # Apply the mask
1526#     output_image = np.where(
1527#         np.stack([mask] * 3, axis=-1),
1528#         PixelColors.WALL,
1529#         image,
1530#     )
1531
1532#     return output_image

RGB = tuple[int, int, int]

rgb tuple of values 0-255

PixelGrid = <class 'jaxtyping.Int[ndarray, 'x y rgb']'>

rgb grid of pixels

BinaryPixelGrid = <class 'jaxtyping.Bool[ndarray, 'x y']'>

boolean grid of pixels

DIM_2: int = 2

2 dimensions

class NoValidEndpointException(builtins.Exception):
class NoValidEndpointException(Exception):  # noqa: N818
	"""Raised when no valid start or end positions are found in a maze."""

Raised when no valid start or end positions are found in a maze.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
add_note
args
def color_in_pixel_grid( pixel_grid: jaxtyping.Int[ndarray, 'x y rgb'], color: tuple[int, int, int]) -> bool:
def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool:
	"""check if a color is in a pixel grid

	done as a single vectorized numpy comparison along the last (rgb) axis, rather
	than a python loop over every pixel

	# Parameters:
	- `pixel_grid : PixelGrid` : RGB pixel grid, colors along the last axis
	- `color : RGB` : color to look for

	# Returns:
	- `bool` : whether any pixel exactly equals `color` (False for an empty grid)
	"""
	return bool(np.any(np.all(pixel_grid == np.asarray(color), axis=-1)))

check if a color is in a pixel grid

@dataclass(frozen=True)
class PixelColors:
@dataclass(frozen=True)
class PixelColors:
	"""standard colors for pixel grids

	every value is an `RGB` tuple of 0-255 ints
	"""

	WALL: RGB = (0, 0, 0)  # black
	OPEN: RGB = (255, 255, 255)  # white
	START: RGB = (0, 255, 0)  # green
	END: RGB = (255, 0, 0)  # red
	PATH: RGB = (0, 0, 255)  # blue

standard colors for pixel grids

PixelColors( WALL: tuple[int, int, int] = (0, 0, 0), OPEN: tuple[int, int, int] = (255, 255, 255), START: tuple[int, int, int] = (0, 255, 0), END: tuple[int, int, int] = (255, 0, 0), PATH: tuple[int, int, int] = (0, 0, 255))
WALL: tuple[int, int, int] = (0, 0, 0)
OPEN: tuple[int, int, int] = (255, 255, 255)
START: tuple[int, int, int] = (0, 255, 0)
END: tuple[int, int, int] = (255, 0, 0)
PATH: tuple[int, int, int] = (0, 0, 255)
@dataclass(frozen=True)
class AsciiChars:
 99@dataclass(frozen=True)
100class AsciiChars:
101	"standard ascii characters for mazes"
102
103	WALL: str = "#"
104	OPEN: str = " "
105	START: str = "S"
106	END: str = "E"
107	PATH: str = "X"

standard ascii characters for mazes

AsciiChars( WALL: str = '#', OPEN: str = ' ', START: str = 'S', END: str = 'E', PATH: str = 'X')
WALL: str = '#'
OPEN: str = ' '
START: str = 'S'
END: str = 'E'
PATH: str = 'X'
ASCII_PIXEL_PAIRINGS: dict[str, tuple[int, int, int]] = {'#': (0, 0, 0), ' ': (255, 255, 255), 'S': (0, 255, 0), 'E': (255, 0, 0), 'X': (0, 0, 255)}

map ascii characters to pixel colors

@serializable_dataclass(frozen=True, kw_only=True, properties_to_serialize=['lattice_dim', 'generation_meta'])
class LatticeMaze(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 120@serializable_dataclass(
 121	frozen=True,
 122	kw_only=True,
 123	properties_to_serialize=["lattice_dim", "generation_meta"],
 124)
 125class LatticeMaze(SerializableDataclass):
 126	"""lattice maze (nodes on a lattice, connections only to neighboring nodes)
 127
 128	Connection List represents which nodes (N) are connected in each direction.
 129
 130	First and second elements represent rightward and downward connections,
 131	respectively.
 132
 133	Example:
 134		Connection list:
 135			[
 136				[ # down
 137					[F T],
 138					[F F]
 139				],
 140				[ # right
 141					[T F],
 142					[T F]
 143				]
 144			]
 145
 146		Nodes with connections
 147			N T N F
 148			F	T
 149			N T N F
 150			F	F
 151
 152		Graph:
 153			N - N
 154				|
 155			N - N
 156
 157	Note: the bottom row connections going down, and the
 158	right-hand connections going right, will always be False.
 159
 160	"""
 161
 162	connection_list: ConnectionList
 163	generation_meta: dict | None = serializable_field(default=None, compare=False)
 164
 165	lattice_dim = property(lambda self: self.connection_list.shape[0])
 166	grid_shape = property(lambda self: self.connection_list.shape[1:])
 167	n_connections = property(lambda self: self.connection_list.sum())
 168
 169	@property
 170	def grid_n(self) -> int:
 171		"grid size as int, raises `AssertionError` if not square"
 172		assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
 173		return self.grid_shape[0]
 174
 175	# ============================================================
 176	# basic methods
 177	# ============================================================
 178
 179	def __eq__(self, other: object) -> bool:
 180		"equality check calls super"
 181		return super().__eq__(other)
 182
 183	@staticmethod
 184	def heuristic(a: CoordTup, b: CoordTup) -> float:
 185		"""return manhattan distance between two points"""
 186		return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
 187
 188	def __hash__(self) -> int:
 189		"""hash the connection list by converting connection list to bytes"""
 190		return hash(self.connection_list.tobytes())
 191
 192	def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
 193		"""returns whether two nodes are connected"""
 194		delta: Coord = b - a
 195		if np.abs(delta).sum() != 1:
 196			# return false if not even adjacent
 197			return False
 198		else:
 199			# test for wall
 200			dim: int = int(np.argmax(np.abs(delta)))
 201			clist_node: Coord = a if (delta.sum() > 0) else b
 202			return self.connection_list[dim, clist_node[0], clist_node[1]]
 203
 204	def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
 205		"""check if a path is valid"""
 206		# check path is not empty
 207		if len(path) == 0:
 208			return empty_is_valid
 209
 210		# check all coords in bounds of maze
 211		if not np.all((path >= 0) & (path < self.grid_shape)):
 212			return False
 213
 214		# check all nodes connected
 215		for i in range(len(path) - 1):
 216			if not self.nodes_connected(path[i], path[i + 1]):
 217				return False
 218		return True
 219
 220	def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
 221		"""Returns an array with the connectivity degree of each coord.
 222
 223		I.e., how many neighbors each coord has.
 224		"""
 225		int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
 226			self.connection_list.astype(np.int8)
 227		)
 228		degrees: Int8[np.ndarray, "row col"] = np.sum(
 229			int_conn,
 230			axis=0,
 231		)  # Connections to east and south
 232		degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
 233		degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
 234		return degrees
 235
 236	def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
 237		"""Returns an array of the neighboring, connected coords of `c`."""
 238		c = np.array(c)  # type: ignore[assignment]
 239		neighbors: list[Coord] = [
 240			neighbor
 241			for neighbor in (c + NEIGHBORS_MASK)
 242			if (
 243				(0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
 244				and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
 245				and self.nodes_connected(c, neighbor)  # connected
 246			)
 247		]
 248
 249		output: CoordArray = np.array(neighbors)
 250		if len(neighbors) > 0:
 251			assert output.shape == (
 252				len(neighbors),
 253				2,
 254			), (
 255				f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
 256			)
 257		return output
 258
 259	def gen_connected_component_from(self, c: Coord) -> CoordArray:
 260		"""return the connected component from a given coordinate"""
 261		# Stack for DFS
 262		stack: list[Coord] = [c]
 263
 264		# Set to store visited nodes
 265		visited: set[CoordTup] = set()
 266
 267		while stack:
 268			current_node: Coord = stack.pop()
 269			# this is fine since we know current_node is a coord and thus of length 2
 270			visited.add(tuple(current_node))  # type: ignore[arg-type]
 271
 272			# Get the neighbors of the current node
 273			neighbors = self.get_coord_neighbors(current_node)
 274
 275			# Iterate over neighbors
 276			for neighbor in neighbors:
 277				if tuple(neighbor) not in visited:
 278					stack.append(neighbor)
 279
 280		return np.array(list(visited))
 281
	def find_shortest_path(
		self,
		c_start: CoordTup | Coord,
		c_end: CoordTup | Coord,
	) -> CoordArray:
		"""find the shortest path between two coordinates, using A*

		Every step between connected neighbors costs 1. The manhattan distance
		(`heuristic`) estimates the remaining cost.

		# Parameters:
		- `c_start : CoordTup | Coord`
			coordinate the path starts at
		- `c_end : CoordTup | Coord`
			coordinate the path ends at

		# Returns:
		- `CoordArray`
			the path from `c_start` to `c_end`, both included, one coordinate per row

		# Raises:
		- `ValueError` : if `c_end` cannot be reached from `c_start`
		"""
		c_start = tuple(c_start)  # type: ignore[assignment]
		c_end = tuple(c_end)  # type: ignore[assignment]

		# cost of cheapest path to node from start currently known
		g_score: dict[CoordTup, float] = {c_start: 0.0}
		# estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
		# The start's g_score stays 0; its heuristic belongs only in f_score.
		f_score: dict[CoordTup, float] = {
			c_start: g_score[c_start] + self.heuristic(c_start, c_end),
		}

		closed_vtx: set[CoordTup] = set()  # nodes already evaluated
		# nodes to be evaluated
		# we need a set of the tuples, dont place the ints in the set
		open_vtx: set[CoordTup] = {c_start}
		# node immediately preceding each node in the path (currently known shortest path)
		source: dict[CoordTup, CoordTup] = dict()

		while open_vtx:
			# get lowest f_score node (open_vtx only ever holds tuples)
			c_current: CoordTup = min(open_vtx, key=lambda c: f_score[c])

			# check if goal is reached
			if c_end == c_current:
				path: list[CoordTup] = [c_current]
				p_current: CoordTup = c_current
				while p_current in source:
					p_current = source[p_current]
					path.append(p_current)
				# the only successful return: the walk back is reversed to run start -> end
				return np.array(path[::-1])

			# close current node
			closed_vtx.add(c_current)
			open_vtx.remove(c_current)

			# update g_score of neighbors
			_np_neighbor: Coord
			for _np_neighbor in self.get_coord_neighbors(c_current):
				neighbor: CoordTup = tuple(_np_neighbor)  # type: ignore[assignment]

				if neighbor in closed_vtx:
					# already checked
					continue
				g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors

				if neighbor not in open_vtx:
					# found new vtx, so add
					open_vtx.add(neighbor)
				elif g_temp >= g_score[neighbor]:
					# already known, and this route is no cheaper: skip
					continue

				# store g_score and source
				source[neighbor] = c_current
				g_score[neighbor] = g_temp
				f_score[neighbor] = g_temp + self.heuristic(neighbor, c_end)

		raise ValueError(
			"A solution could not be found!",
			f"{c_start = }, {c_end = }",
			self.as_ascii(),
		)
 360
	def get_nodes(self) -> CoordArray:
		"""return every coordinate of the grid, one `(row, col)` pair per row

		Coordinates are listed in row-major order: `(0, 0), (0, 1), ..., (1, 0), ...`
		"""
		n_rows: int = self.grid_shape[0]
		n_cols: int = self.grid_shape[1]
		# index_grids[0][i, j] == i and index_grids[1][i, j] == j
		index_grids: Int[np.ndarray, "2 x y"] = np.indices((n_rows, n_cols))
		# flatten both index grids, then transpose so each row is one coordinate
		return index_grids.reshape(2, -1).T
 372
	def get_connected_component(self) -> CoordArray:
		"""get the largest (and assumed only nonsingular) connected component of the maze

		- If `generation_meta` is `None`, or its `"fully_connected"` entry is truthy,
		  the maze is assumed to be fully connected and every node is returned.
		- Otherwise, the `"visited_cells"` entry of `generation_meta` is returned as
		  an array, one coordinate per row.

		TODO: other connected components?

		# Raises:
		- `ValueError` : if the maze is not marked fully connected and has no `visited_cells`
		"""
		meta = self.generation_meta
		if meta is None or meta.get("fully_connected", False):
			# assumed fully connected: the component is the whole grid
			return self.get_nodes()

		visited_cells: set[CoordTup] | None = meta.get("visited_cells", None)
		if visited_cells is not None:
			return np.array(list(visited_cells))

		# TODO: dynamically generate visited_cells?
		err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
		raise ValueError(err_msg)
 397
	@typing.overload
	def generate_random_path(
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: typing.Literal[True] = True,
	) -> CoordArray: ...
	@typing.overload
	def generate_random_path(
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: typing.Literal[False] = False,
	) -> typing.Optional[CoordArray]: ...
	def generate_random_path(  # noqa: C901
		self,
		allowed_start: CoordList | None = None,
		allowed_end: CoordList | None = None,
		deadend_start: bool = False,
		deadend_end: bool = False,
		endpoints_not_equal: bool = False,
		except_on_no_valid_endpoint: bool = True,
	) -> typing.Optional[CoordArray]:
		"""return a path between randomly chosen start and end nodes within the connected component

		With no special conditions, start and end are always distinct.
		Setting special conditions on start and end positions might cause the same
		position to be selected as both start and end, unless `endpoints_not_equal` is `True`.

		# Parameters:
		- `allowed_start : CoordList | None`
			a list of allowed start positions. If `None`, any position in the connected component is allowed
			(defaults to `None`)
		- `allowed_end : CoordList | None`
			a list of allowed end positions. If `None`, any position in the connected component is allowed
			(defaults to `None`)
		- `deadend_start : bool`
			whether to ***force*** the start position to be a deadend (exactly one connected neighbor)
			(defaults to `False`)
		- `deadend_end : bool`
			whether to ***force*** the end position to be a deadend (exactly one connected neighbor)
			(defaults to `False`)
		- `endpoints_not_equal : bool`
			whether to ensure that the start and end point are not the same
			(defaults to `False`)
		- `except_on_no_valid_endpoint : bool`
			whether to raise an error if no valid start or end positions are found.
			If this is `False`, the function might return `None` and this must be handled by the caller
			(defaults to `True`)

		# Returns:
		- `CoordArray`
			a path between the selected start and end positions (via `find_shortest_path`)

		# Raises:
		- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
		- `AssertionError` : if the maze has only a single node
		"""
		# We can't create a "path" in a single-node maze. Only the total node count
		# matters: a `1 x n` strip with `n > 1` can still hold a path.
		assert self.grid_shape[0] * self.grid_shape[1] > 1, (
			f"can't create path in single-node maze: {self.as_ascii()}"
		)

		# get connected component
		connected_component: CoordArray = self.get_connected_component()

		# initialize start and end positions
		positions: Int[np.ndarray, "2 2"]

		# if no special conditions on start and end positions
		if (allowed_start, allowed_end, deadend_start, deadend_end) == (
			None,
			None,
			False,
			False,
		):
			try:
				# sampling without replacement guarantees start != end
				positions = connected_component[  # type: ignore[assignment]
					np.random.choice(
						len(connected_component),
						size=2,
						replace=False,
					)
				]
			except ValueError as e:
				if except_on_no_valid_endpoint:
					err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
					raise NoValidEndpointException(
						err_msg,
					) from e
				return None

			return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]

		# handle special conditions
		connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
		# copy connected component set
		allowed_start_set: set[CoordTup] = connected_component_set.copy()
		allowed_end_set: set[CoordTup] = connected_component_set.copy()

		# filter by explicitly allowed start and end positions
		# '# type: ignore[assignment]' here because the returned tuple can be of any length
		if allowed_start is not None:
			allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]

		if allowed_end is not None:
			allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]

		# filter by forcing deadends
		if deadend_start:
			allowed_start_set = set(
				filter(
					lambda x: len(self.get_coord_neighbors(x)) == 1,
					allowed_start_set,
				),
			)

		if deadend_end:
			allowed_end_set = set(
				filter(
					lambda x: len(self.get_coord_neighbors(x)) == 1,
					allowed_end_set,
				),
			)

		# check we have valid positions
		if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
			if except_on_no_valid_endpoint:
				err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
				raise NoValidEndpointException(
					err_msg,
				)
			return None

		# randomly select start and end positions
		try:
			# ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
			start_pos: CoordTup = tuple(  # type: ignore[assignment]
				list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
			)
			if endpoints_not_equal:
				# remove start position from end positions
				allowed_end_set.discard(start_pos)
			# if that emptied the set, `randint(0, 0)` raises ValueError, handled below
			end_pos: CoordTup = tuple(  # type: ignore[assignment]
				list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
			)
		except ValueError as e:
			if except_on_no_valid_endpoint:
				err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
				raise NoValidEndpointException(
					err_msg,
				) from e
			return None

		return self.find_shortest_path(start_pos, end_pos)
 556
 557	# ============================================================
 558	# to and from adjacency list
 559	# ============================================================
	def as_adj_list(
		self,
		shuffle_d0: bool = True,
		shuffle_d1: bool = True,
	) -> Int8[np.ndarray, "conn start_end coord"]:
		"""return the maze's connections as an adjacency list

		This is a thin wrapper around `maze_dataset.token_utils.connection_list_to_adj_list`.
		It forwards `self.connection_list` and both shuffle flags unchanged.
		"""
		adj_list: Int8[np.ndarray, "conn start_end coord"] = connection_list_to_adj_list(
			self.connection_list,
			shuffle_d0,
			shuffle_d1,
		)
		return adj_list
 567
	@classmethod
	def from_adj_list(
		cls,
		adj_list: Int8[np.ndarray, "conn start_end coord"],
	) -> "LatticeMaze":
		"""create a LatticeMaze from a list of connections

		> [!NOTE]
		> This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

		# Parameters:
		- `adj_list : Int8[np.ndarray, "conn start_end coord"]`
			the connections, each a pair of `(row, col)` coordinates

		# Returns:
		- `LatticeMaze`
			always a plain `LatticeMaze` (regardless of `cls`), on a square grid
			of side `adj_list.max() + 1`

		# Raises:
		- `ValueError` : if any connection is not between two lattice-adjacent coordinates
		"""
		# This is where it would probably break for rectangular mazes: the grid is
		# sized from the largest coordinate value along either axis.
		# `int()` keeps `+ 1` from wrapping around when `adj_list` is int8.
		grid_n: int = int(adj_list.max()) + 1

		connection_list: ConnectionList = np.zeros(
			(2, grid_n, grid_n),
			dtype=np.bool_,
		)

		for c_start, c_end in adj_list:
			# widen before subtracting so the difference cannot wrap around in int8
			delta = np.asarray(c_end, dtype=np.int64) - np.asarray(c_start, dtype=np.int64)
			# A valid connection joins nodes differing by exactly 1 along exactly one axis.
			# Only checking that one coordinate is equal would also accept
			# non-adjacent pairs such as (0, 0) -- (0, 3).
			if np.abs(delta).sum() != 1:
				raise ValueError("invalid connection")

			# the direction is the axis along which the two coords differ
			d: int = int(np.abs(delta).argmax())

			x: int
			y: int
			# the connection is stored on whichever node has the lesser value along `d`
			x, y = c_start if delta[d] > 0 else c_end

			connection_list[d, x, y] = True

		return LatticeMaze(
			connection_list=connection_list,
		)
 607
	def as_adj_list_tokens(self) -> list[str | CoordTup]:
		"""(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead

		Emits a `TokenizerDeprecationWarning`, then returns the same tokens as
		`_as_adj_list_tokens`.
		"""
		warnings.warn(
			"`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
			TokenizerDeprecationWarning,
			# attribute the warning to the caller of this deprecated method
			stacklevel=2,
		)
		# delegate rather than duplicating the token-building logic
		return self._as_adj_list_tokens()
 629
	def _as_adj_list_tokens(self) -> list[str | CoordTup]:
		"""build adjacency-list tokens, keeping coordinates as tuples

		Each connection from `as_adj_list()` becomes
		`(start), CONNECTOR, (end), ADJACENCY_ENDLINE`. The whole sequence is
		wrapped in `ADJLIST_START` ... `ADJLIST_END`.
		"""
		tokens: list[str | CoordTup] = [SPECIAL_TOKENS.ADJLIST_START]
		for c_s, c_e in self.as_adj_list():
			tokens.extend(
				(
					tuple(c_s),
					SPECIAL_TOKENS.CONNECTOR,
					tuple(c_e),
					SPECIAL_TOKENS.ADJACENCY_ENDLINE,
				),
			)
		tokens.append(SPECIAL_TOKENS.ADJLIST_END)
		return tokens
 646
	def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
		"""turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples

		The origin and target sections are appended only for `TargetedLatticeMaze`
		instances. The path section is appended only for `SolvedMaze` instances.
		"""
		output: list[CoordTup | str] = self._as_adj_list_tokens()
		if isinstance(self, TargetedLatticeMaze):
			output.extend(self._get_start_pos_tokens())
			output.extend(self._get_end_pos_tokens())
		if isinstance(self, SolvedMaze):
			output.extend(self._get_solution_tokens())
		return output
 658
	def _as_tokens(
		self,
		maze_tokenizer: "MazeTokenizer | TokenizationMode",
	) -> list[str]:
		"""serialize with a legacy `MazeTokenizer`

		A `TokenizationMode` is first converted with `to_legacy_tokenizer()`.

		# Raises:
		- `NotImplementedError` : unless the (converted) tokenizer is a `MazeTokenizer` in AOTP mode
		"""
		# type ignores here fine since we check the instance
		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

		is_aotp_legacy = (
			isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
			and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
		)
		if not is_aotp_legacy:
			err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
			raise NotImplementedError(err_msg)

		return maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
			coords=self._as_coords_and_special_AOTP(),
			when_noncoord="include",
		)
 679
	def as_tokens(
		self,
		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
	) -> list[str]:
		"""serialize maze and solution to tokens

		A `MazeTokenizerModular` tokenizes the maze via its own `to_tokens`.
		Legacy tokenizers and `TokenizationMode`s go through `_as_tokens`.
		"""
		if not isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
			return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
		return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
 689
	@classmethod
	def _from_tokens_AOTP(
		cls,
		tokens: list[str],
		maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
		"""create a LatticeMaze from a list of tokens

		Parses AOTP-ordered tokens in three stages:
		1. the adjacency list;
		2. the origin and target, if all four origin/target delimiters are present;
		3. the solution path, if both path delimiters are present.

		# Parameters:
		- `tokens : list[str]`
			the tokens. If the first token is not `SPECIAL_TOKENS.ADJLIST_START`, all
			tokens are treated as adjacency-list tokens and a warning is emitted.
		- `maze_tokenizer : MazeTokenizer | MazeTokenizerModular`
			used only to convert token strings back into coordinates (`strings_to_coords`)

		# Returns:
		- a `LatticeMaze`, upgraded to a `TargetedLatticeMaze` when origin and target
		  tokens are present, and further to a `SolvedMaze` when path tokens are present

		# Raises:
		- `AssertionError` : if an edge, the origin/target, or the adjacency array is
		  malformed, or if path tokens appear without origin/target tokens
		- `ValueError` : propagated from `from_adj_list` for an invalid connection
		"""
		# figure out what input format
		# ========================================
		# NOTE(review): an empty `tokens` list raises IndexError here
		if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
			adj_list_tokens = get_adj_list_tokens(tokens)
		else:
			# If we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
			adj_list_tokens = tokens
			warnings.warn(
				"Assuming input is just adjacency list tokens, no special tokens found",
			)

		# process edges for adjacency list
		# ========================================
		# split on the endline token: one sublist of tokens per edge
		edges: list[list[str]] = list_split(
			adj_list_tokens,
			SPECIAL_TOKENS.ADJACENCY_ENDLINE,
		)

		coordinates: list[tuple[CoordTup, CoordTup]] = list()
		for e in edges:
			# skip last endline
			# (presumably `list_split` yields an empty sublist after the final endline)
			if len(e) != 0:
				# convert to coords, split start and end
				e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
					e,
					when_noncoord="include",
				)
				# this assertion depends on the tokenizer having exactly one token for the connector
				# which is also why we "include" above
				# the connector token is discarded below
				assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
				assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
					f"invalid edge: {e = } {e_coords = }"
				)
				e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
				e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
				coordinates.append((e_coords_first, e_coords_last))

		# each entry was built above as a (start, end) pair, so this holds by construction
		assert all(len(c) == DIM_2 for c in coordinates), (
			f"invalid coordinates: {coordinates = }"
		)
		adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
		# the shape check catches coordinates that were not 2D
		assert tuple(adj_list.shape) == (
			len(coordinates),
			2,
			2,
		), f"invalid adj_list: {adj_list.shape = } {coordinates = }"

		output_maze: LatticeMaze = cls.from_adj_list(adj_list)

		# add start and end positions
		# ========================================
		# records whether origin/target were found; a solution requires them
		is_targeted: bool = False
		if all(
			x in tokens
			for x in (
				SPECIAL_TOKENS.ORIGIN_START,
				SPECIAL_TOKENS.ORIGIN_END,
				SPECIAL_TOKENS.TARGET_START,
				SPECIAL_TOKENS.TARGET_END,
			)
		):
			start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_origin_tokens(tokens),
				when_noncoord="error",
			)
			end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_target_tokens(tokens),
				when_noncoord="error",
			)
			assert len(start_pos_list) == 1, (
				f"invalid start_pos_list: {start_pos_list = }"
			)
			assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"

			start_pos: CoordTup = start_pos_list[0]
			end_pos: CoordTup = end_pos_list[0]

			output_maze = TargetedLatticeMaze.from_lattice_maze(
				lattice_maze=output_maze,
				start_pos=start_pos,
				end_pos=end_pos,
			)

			is_targeted = True

		if all(
			x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
		):
			assert is_targeted, "maze must be targeted to have a solution"
			solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
				get_path_tokens(tokens, trim_end=True),
				when_noncoord="error",
			)
			output_maze = SolvedMaze.from_targeted_lattice_maze(
				# HACK: I think this is fine, but im not sure
				targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
				solution=solution,
			)

		return output_maze
 798
	# TODO: any way to get return type hinting working for this?
	@classmethod
	def from_tokens(
		cls,
		tokens: list[str],
		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
		"""Constructs a maze from a tokenization.

		Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
		`tokens` may also be passed as a single whitespace-separated string.

		# Raises:
		- `NotImplementedError` : for a `MazeTokenizerModular` that is not legacy-equivalent,
		  or for any tokenizer that is not in AOTP mode
		"""
		# HACK: type ignores here fine since we check the instance
		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

		unsupported_modular = (
			isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
			and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
		)
		if unsupported_modular:
			err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
			raise NotImplementedError(
				err_msg,
			)

		token_list: list[str] = tokens.split() if isinstance(tokens, str) else tokens

		if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
			raise NotImplementedError("only AOTP tokenization is supported")
		return cls._from_tokens_AOTP(token_list, maze_tokenizer)  # type: ignore[arg-type]
 829
 830	# ============================================================
 831	# to and from pixels
 832	# ============================================================
	def _as_pixels_bw(self) -> BinaryPixelGrid:
		"""render the maze as a boolean pixel grid (`True` = open, `False` = wall)

		Pixel mapping:
		- node `(i, j)` maps to pixel `(2i + 1, 2j + 1)`;
		- a downward connection of node `(i, j)` (`connection_list[0]`) opens pixel `(2i + 2, 2j + 1)`;
		- a rightward connection (`connection_list[1]`) opens pixel `(2i + 1, 2j + 2)`.

		# Returns:
		- `BinaryPixelGrid` of shape `(2 * n_rows + 1, 2 * n_cols + 1)`

		# Raises:
		- `AssertionError` : if the maze is not 2D
		"""
		assert self.lattice_dim == DIM_2, "only 2D mazes are supported"
		# Create an empty pixel grid with walls
		pixel_grid: BinaryPixelGrid = np.full(
			(self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
			False,
			dtype=np.bool_,
		)

		# Set white nodes
		pixel_grid[1::2, 1::2] = True

		# Set white connections. These strided views are exactly the connection
		# pixels listed above (the same slices `_from_pixel_grid_bw` reads back).
		# They are untouched walls until now, so assigning the connections
		# directly sets a pixel open iff the connection is truthy.
		pixel_grid[2::2, 1::2] = self.connection_list[0]  # downward
		pixel_grid[1::2, 2::2] = self.connection_list[1]  # rightward

		return pixel_grid
 858
	def as_pixels(
		self,
		show_endpoints: bool = True,
		show_solution: bool = True,
	) -> PixelGrid:
		"""convert the maze to a pixel grid

		- useful as a simpler way of plotting the maze than the more complex `MazePlot`
		- the same underlying representation as `as_ascii` but as an image
		- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data

		# Parameters:
		- `show_endpoints : bool`
			color the start and end nodes (`TargetedLatticeMaze` and subclasses only)
			(defaults to `True`)
		- `show_solution : bool`
			color the solution path (`SolvedMaze` only); requires `show_endpoints`
			(defaults to `True`)

		# Returns:
		- `PixelGrid`
			`uint8` RGB grid of shape `(2 * n_rows + 1, 2 * n_cols + 1, 3)`

		# Raises:
		- `ValueError` : if `show_solution` is `True` but `show_endpoints` is `False`
		- `AssertionError` : if consecutive solution coordinates are not adjacent
		"""
		if show_solution and not show_endpoints:
			raise ValueError("show_solution=True requires show_endpoints=True")
		# convert original bool pixel grid to RGB
		pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
		pixel_grid: PixelGrid = np.full(
			(*pixel_grid_bw.shape, 3),
			PixelColors.WALL,
			dtype=np.uint8,
		)
		pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712

		# a plain `LatticeMaze` has no endpoints or solution to draw
		if not isinstance(self, TargetedLatticeMaze):
			return pixel_grid

		# Draw the solution first, so the endpoints drawn afterwards stay visible.
		if show_solution and isinstance(self, SolvedMaze):
			for coord in self.solution:
				pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

			# set pixels between consecutive coords
			for coord, next_coord in zip(self.solution[:-1], self.solution[1:]):
				# check they are adjacent using norm
				assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
					f"Coords {coord} and {next_coord} are not adjacent"
				)
				# set pixel between them
				pixel_grid[
					coord[0] * 2 + 1 + next_coord[0] - coord[0],
					coord[1] * 2 + 1 + next_coord[1] - coord[1],
				] = PixelColors.PATH

		# Endpoints are drawn whenever requested. This includes a `SolvedMaze`
		# rendered with `show_solution=False`.
		if show_endpoints:
			pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (
				PixelColors.START
			)
			pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (
				PixelColors.END
			)

		return pixel_grid
 925
	@classmethod
	def _from_pixel_grid_bw(
		cls,
		pixel_grid: BinaryPixelGrid,
	) -> tuple[ConnectionList, tuple[int, int]]:
		"""recover the connection list from a boolean pixel grid (inverse of `_as_pixels_bw`)

		The grid shape is half the pixel shape (floor division) along each axis.
		Downward connections are read from pixels `(2i + 2, 2j + 1)` and rightward
		ones from pixels `(2i + 1, 2j + 2)`.

		# Returns:
		- `(connection_list, grid_shape)`
		"""
		n_rows: int = pixel_grid.shape[0] // 2
		n_cols: int = pixel_grid.shape[1] // 2
		grid_shape: tuple[int, int] = (n_rows, n_cols)

		connection_list: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=np.bool_)
		# downward connections sit on even rows / odd columns
		connection_list[0] = pixel_grid[2::2, 1::2]
		# rightward connections sit on odd rows / even columns
		connection_list[1] = pixel_grid[1::2, 2::2]

		return connection_list, grid_shape
 944
	@classmethod
	def _from_pixel_grid_with_positions(
		cls,
		pixel_grid: PixelGrid | BinaryPixelGrid,
		marked_positions: dict[str, RGB],
	) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
		"""recover connections and marked node positions from an RGB pixel grid

		- Any pixel that is not exactly `PixelColors.WALL` counts as open when
		  rebuilding the connection list (via `_from_pixel_grid_bw`).
		- For each `key, color` in `marked_positions`, collect the node coordinates
		  whose pixel is exactly `color`. Nodes are the pixels with both indices odd,
		  mapped back with `// 2`. Pixels between nodes are ignored.

		# Returns:
		- `(connection_list, grid_shape, positions)`, where `positions[key]` is an
		  array of coordinates (shape `(0,)` when nothing matched)
		"""
		# error: Incompatible types in assignment (expression has type
		# "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]",
		# variable has type "ndarray[Any, Any]")  [assignment]
		pixel_grid_bw: BinaryPixelGrid = ~np.all(  # type: ignore[assignment]
			pixel_grid == PixelColors.WALL,
			axis=-1,
		)
		connection_list: ConnectionList
		grid_shape: tuple[int, int]
		connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

		out_positions: dict[str, CoordArray] = {}
		for key, color in marked_positions.items():
			matches: Int[np.ndarray, "n 2"] = np.argwhere(
				np.all(pixel_grid == color, axis=-1),
			)
			# keep only node pixels (both indices odd), converted to grid coordinates
			node_coords: list[CoordTup] = [
				(pos[0] // 2, pos[1] // 2)
				for pos in matches
				if pos[0] % 2 == 1 and pos[1] % 2 == 1
			]
			out_positions[key] = np.array(node_coords)

		return connection_list, grid_shape, out_positions
 978
	@classmethod
	def from_pixels(
		cls,
		pixel_grid: PixelGrid,
	) -> "LatticeMaze":
		"""create a LatticeMaze from a pixel grid. reverse of `as_pixels`

		- A 2D (boolean) grid always yields a plain `LatticeMaze`, whatever `cls` is.
		- An RGB grid is classified with `detect_pixels_type`, and `cls` must appear in
		  the detected type's MRO.
		- For `TargetedLatticeMaze` and beyond, exactly one `PixelColors.START` node and
		  exactly one `PixelColors.END` node are required.
		- For a solved maze, the `PixelColors.PATH` nodes are ordered into a path from
		  start to end by repeatedly stepping to the unique connected, unused path node.

		# Raises:
		- `ValueError` : if the pixel grid cannot be cast to `cls`, i.e. `cls` is not the
		  type detected by `detect_pixels_type` nor one of its bases (for example, a
		  `SolvedMaze` was requested but a less specific type was detected)
		- `AssertionError` : if the start/end markers are not single coordinates, or the
		  path pixels cannot be ordered into one path from start to end
		"""
		connection_list: ConnectionList
		grid_shape: tuple[int, int]

		# if a binary pixel grid, return regular LatticeMaze
		# (regardless of `cls` -- a 2D grid carries no endpoint or path colors)
		if len(pixel_grid.shape) == 2:  # noqa: PLR2004
			connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
			return LatticeMaze(connection_list=connection_list)

		# otherwise, detect and check it's valid
		cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
		# `cls` must be the detected type or one of its bases
		if cls not in cls_detected.__mro__:
			err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
			raise ValueError(
				err_msg,
			)

		(
			connection_list,
			grid_shape,
			marked_pos,
		) = cls._from_pixel_grid_with_positions(
			pixel_grid=pixel_grid,
			marked_positions=dict(
				start=PixelColors.START,
				end=PixelColors.END,
				solution=PixelColors.PATH,
			),
		)
		# if we wanted a LatticeMaze, return it
		if cls == LatticeMaze:
			return LatticeMaze(connection_list=connection_list)

		# otherwise, keep going
		# (a plain maze over the recovered connections, used below to walk the solution)
		temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

		# start and end pos
		start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
		assert start_pos_arr.shape == (
			1,
			2,
		), (
			f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		assert end_pos_arr.shape == (
			1,
			2,
		), (
			f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)

		start_pos: Coord = start_pos_arr[0]
		end_pos: Coord = end_pos_arr[0]

		# return a TargetedLatticeMaze if that's what we wanted
		if cls == TargetedLatticeMaze:
			return TargetedLatticeMaze(
				connection_list=connection_list,
				start_pos=start_pos,
				end_pos=end_pos,
			)

		# raw solution, only contains path elements and not start or end
		# (unordered -- the ordering is reconstructed below)
		solution_raw: CoordArray = marked_pos["solution"]
		if len(solution_raw.shape) == 2:  # noqa: PLR2004
			assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
				f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
			)
		elif solution_raw.shape == (0,):
			# no path pixels at all: the start and end should be immediately adjacent
			assert np.sum(np.abs(start_pos - end_pos)) == 1, (
				f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
			)

		# order the solution, by creating a list from the start to the end
		# add end pos, since we will iterate over all these starting from the start pos
		solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
			tuple(end_pos),
		]
		# solution starts with start point
		solution: list[CoordTup] = [tuple(start_pos)]
		# Walk from the start. At each step, exactly one connected, not-yet-used
		# neighbor must be a path node (or the end); the assertion below enforces this.
		while solution[-1] != tuple(end_pos):
			# use `get_coord_neighbors` to find connected neighbors
			neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
			# TODO: make this less ugly
			assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
				f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
			)
			# neighbors = neighbors[:, [1, 0]]
			# filter out neighbors that are not in the raw solution
			neighbors_filtered: CoordArray = np.array(
				[
					coord
					for coord in neighbors
					if (
						tuple(coord) in solution_raw_list
						and tuple(coord) not in solution
					)
				],
			)
			# assert only one element is left, and then add it to the solution
			assert neighbors_filtered.shape == (
				1,
				2,
			), (
				f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
			)
			solution.append(tuple(neighbors_filtered[0]))

		# assert the solution is complete
		assert solution[0] == tuple(start_pos), (
			f"solution {solution} does not start at start_pos {start_pos}"
		)
		assert solution[-1] == tuple(end_pos), (
			f"solution {solution} does not end at end_pos {end_pos}"
		)

		# NOTE(review): presumably `cls` (a solved-maze type here) derives start_pos /
		# end_pos from the ends of `solution` -- confirm against its constructor
		return cls(
			connection_list=np.array(connection_list),
			solution=np.array(solution),  # type: ignore[call-arg]
		)
1109
1110	# ============================================================
1111	# to and from ASCII
1112	# ============================================================
1113	def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
1114		# Get the pixel grid using to_pixels().
1115		pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw()
1116
1117		# Replace pixel values with ASCII characters.
1118		ascii_grid: Shaped[np.ndarray, "x y"] = np.full(
1119			pixel_grid.shape,
1120			AsciiChars.WALL,
1121			dtype=str,
1122		)
1123		ascii_grid[pixel_grid == True] = AsciiChars.OPEN  # noqa: E712
1124
1125		return ascii_grid
1126
1127	def as_ascii(
1128		self,
1129		show_endpoints: bool = True,
1130		show_solution: bool = True,
1131	) -> str:
1132		"""return an ASCII grid of the maze
1133
1134		useful for debugging in the terminal, or as it's own format
1135
1136		can be reversed with `LatticeMaze.from_ascii()`
1137		"""
1138		ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
1139		pixel_grid: PixelGrid = self.as_pixels(
1140			show_endpoints=show_endpoints,
1141			show_solution=show_solution,
1142		)
1143
1144		chars_replace: tuple = tuple()
1145		if show_endpoints:
1146			chars_replace += (AsciiChars.START, AsciiChars.END)
1147		if show_solution:
1148			chars_replace += (AsciiChars.PATH,)
1149
1150		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1151			if ascii_char in chars_replace:
1152				ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char
1153
1154		return "\n".join("".join(row) for row in ascii_grid)
1155
1156	@classmethod
1157	def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
1158		"get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)"
1159		lines: list[str] = ascii_str.strip().split("\n")
1160		lines = [line.strip() for line in lines]
1161		ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
1162			[list(line) for line in lines],
1163			dtype=str,
1164		)
1165		pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
1166
1167		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1168			pixel_grid[ascii_grid == ascii_char] = pixel_color
1169
1170		return cls.from_pixels(pixel_grid)

lattice maze (nodes on a lattice, connections only to neighboring nodes)

Connection List represents which nodes (N) are connected in each direction.

First and second elements represent downward and rightward connections, respectively.

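Example:

    Connection list:
        [
            [ # down
                [F T],
                [F F]
            ],
            [ # right
                [T F],
                [T F]
            ]
        ]

    Nodes with connections
        N T N F
        F   T
        N T N F
        F   F

    Graph:
        N - N
            |
        N - N

Here node (0, 0) connects right to (0, 1), (0, 1) connects down to (1, 1), and (1, 0) connects right to (1, 1).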

Note: downward connections from the bottom row, and rightward connections from the right-most column, will always be False.

LatticeMaze( *, connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], generation_meta: dict | None = None)
connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col']
generation_meta: dict | None = None
lattice_dim
165	lattice_dim = property(lambda self: self.connection_list.shape[0])
grid_shape
166	grid_shape = property(lambda self: self.connection_list.shape[1:])
n_connections
167	n_connections = property(lambda self: self.connection_list.sum())
grid_n: int
169	@property
170	def grid_n(self) -> int:
171		"grid size as int, raises `AssertionError` if not square"
172		assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
173		return self.grid_shape[0]

grid size as int, raises AssertionError if not square

@staticmethod
def heuristic(a: tuple[int, int], b: tuple[int, int]) -> float:
183	@staticmethod
184	def heuristic(a: CoordTup, b: CoordTup) -> float:
185		"""return manhattan distance between two points"""
186		return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])

return manhattan distance between two points

def nodes_connected( self, a: jaxtyping.Int8[ndarray, 'row_col=2'], b: jaxtyping.Int8[ndarray, 'row_col=2'], /) -> bool:
192	def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
193		"""returns whether two nodes are connected"""
194		delta: Coord = b - a
195		if np.abs(delta).sum() != 1:
196			# return false if not even adjacent
197			return False
198		else:
199			# test for wall
200			dim: int = int(np.argmax(np.abs(delta)))
201			clist_node: Coord = a if (delta.sum() > 0) else b
202			return self.connection_list[dim, clist_node[0], clist_node[1]]

returns whether two nodes are connected

def is_valid_path( self, path: jaxtyping.Int8[ndarray, 'coord row_col=2'], empty_is_valid: bool = False) -> bool:
204	def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
205		"""check if a path is valid"""
206		# check path is not empty
207		if len(path) == 0:
208			return empty_is_valid
209
210		# check all coords in bounds of maze
211		if not np.all((path >= 0) & (path < self.grid_shape)):
212			return False
213
214		# check all nodes connected
215		for i in range(len(path) - 1):
216			if not self.nodes_connected(path[i], path[i + 1]):
217				return False
218		return True

check if a path is valid

def coord_degrees(self) -> jaxtyping.Int8[ndarray, 'row col']:
220	def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
221		"""Returns an array with the connectivity degree of each coord.
222
223		I.e., how many neighbors each coord has.
224		"""
225		int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
226			self.connection_list.astype(np.int8)
227		)
228		degrees: Int8[np.ndarray, "row col"] = np.sum(
229			int_conn,
230			axis=0,
231		)  # Connections to east and south
232		degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
233		degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
234		return degrees

Returns an array with the connectivity degree of each coord.

I.e., how many neighbors each coord has.

def get_coord_neighbors( self, c: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
236	def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
237		"""Returns an array of the neighboring, connected coords of `c`."""
238		c = np.array(c)  # type: ignore[assignment]
239		neighbors: list[Coord] = [
240			neighbor
241			for neighbor in (c + NEIGHBORS_MASK)
242			if (
243				(0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
244				and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
245				and self.nodes_connected(c, neighbor)  # connected
246			)
247		]
248
249		output: CoordArray = np.array(neighbors)
250		if len(neighbors) > 0:
251			assert output.shape == (
252				len(neighbors),
253				2,
254			), (
255				f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
256			)
257		return output

Returns an array of the neighboring, connected coords of c.

def gen_connected_component_from( self, c: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
259	def gen_connected_component_from(self, c: Coord) -> CoordArray:
260		"""return the connected component from a given coordinate"""
261		# Stack for DFS
262		stack: list[Coord] = [c]
263
264		# Set to store visited nodes
265		visited: set[CoordTup] = set()
266
267		while stack:
268			current_node: Coord = stack.pop()
269			# this is fine since we know current_node is a coord and thus of length 2
270			visited.add(tuple(current_node))  # type: ignore[arg-type]
271
272			# Get the neighbors of the current node
273			neighbors = self.get_coord_neighbors(current_node)
274
275			# Iterate over neighbors
276			for neighbor in neighbors:
277				if tuple(neighbor) not in visited:
278					stack.append(neighbor)
279
280		return np.array(list(visited))

return the connected component from a given coordinate

def find_shortest_path( self, c_start: tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2'], c_end: tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
282	def find_shortest_path(
283		self,
284		c_start: CoordTup | Coord,
285		c_end: CoordTup | Coord,
286	) -> CoordArray:
287		"""find the shortest path between two coordinates, using A*"""
288		c_start = tuple(c_start)  # type: ignore[assignment]
289		c_end = tuple(c_end)  # type: ignore[assignment]
290
291		g_score: dict[CoordTup, float] = (
292			dict()
293		)  # cost of cheapest path to node from start currently known
294		f_score: dict[CoordTup, float] = {
295			c_start: 0.0,
296		}  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
297
298		# init
299		g_score[c_start] = 0.0
300		g_score[c_start] = self.heuristic(c_start, c_end)
301
302		closed_vtx: set[CoordTup] = set()  # nodes already evaluated
303		# nodes to be evaluated
304		# we need a set of the tuples, dont place the ints in the set
305		open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
306		source: dict[CoordTup, CoordTup] = (
307			dict()
308		)  # node immediately preceding each node in the path (currently known shortest path)
309
310		while open_vtx:
311			# get lowest f_score node
312			# mypy cant tell that c is of length 2
313			c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
314			# f_current: float = f_score[c_current]
315
316			# check if goal is reached
317			if c_end == c_current:
318				path: list[CoordTup] = [c_current]
319				p_current: CoordTup = c_current
320				while p_current in source:
321					p_current = source[p_current]
322					path.append(p_current)
323				# ----------------------------------------------------------------------
324				# this is the only return statement
325				return np.array(path[::-1])
326				# ----------------------------------------------------------------------
327
328			# close current node
329			closed_vtx.add(c_current)
330			open_vtx.remove(c_current)
331
332			# update g_score of neighbors
333			_np_neighbor: Coord
334			for _np_neighbor in self.get_coord_neighbors(c_current):
335				neighbor: CoordTup = tuple(_np_neighbor)
336
337				if neighbor in closed_vtx:
338					# already checked
339					continue
340				g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors
341
342				if neighbor not in open_vtx:
343					# found new vtx, so add
344					open_vtx.add(neighbor)
345
346				elif g_temp >= g_score[neighbor]:
347					# if already knew about this one, but current g_score is worse, skip
348					continue
349
350				# store g_score and source
351				source[neighbor] = c_current
352				g_score[neighbor] = g_temp
353				f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)
354
355		raise ValueError(
356			"A solution could not be found!",
357			f"{c_start = }, {c_end = }",
358			self.as_ascii(),
359		)

find the shortest path between two coordinates, using A*

def get_nodes(self) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
361	def get_nodes(self) -> CoordArray:
362		"""return a list of all nodes in the maze"""
363		rows: Int[np.ndarray, "x y"]
364		cols: Int[np.ndarray, "x y"]
365		rows, cols = np.meshgrid(
366			range(self.grid_shape[0]),
367			range(self.grid_shape[1]),
368			indexing="ij",
369		)
370		nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
371		return nodes

return a list of all nodes in the maze

def get_connected_component(self) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
373	def get_connected_component(self) -> CoordArray:
374		"""get the largest (and assumed only nonsingular) connected component of the maze
375
376		TODO: other connected components?
377		"""
378		if (self.generation_meta is None) or (
379			self.generation_meta.get("fully_connected", False)
380		):
381			# for fully connected case, pick any two positions
382			return self.get_nodes()
383		else:
384			# if metadata provided, use visited cells
385			visited_cells: set[CoordTup] | None = self.generation_meta.get(
386				"visited_cells",
387				None,
388			)
389			if visited_cells is None:
390				# TODO: dynamically generate visited_cells?
391				err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
392				raise ValueError(
393					err_msg,
394				)
395			visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
396			return visited_cells_np

get the largest (and assumed only nonsingular) connected component of the maze

TODO: other connected components?

def generate_random_path( self, allowed_start: list[tuple[int, int]] | None = None, allowed_end: list[tuple[int, int]] | None = None, deadend_start: bool = False, deadend_end: bool = False, endpoints_not_equal: bool = False, except_on_no_valid_endpoint: bool = True) -> Optional[jaxtyping.Int8[ndarray, 'coord row_col=2']]:
418	def generate_random_path(  # noqa: C901
419		self,
420		allowed_start: CoordList | None = None,
421		allowed_end: CoordList | None = None,
422		deadend_start: bool = False,
423		deadend_end: bool = False,
424		endpoints_not_equal: bool = False,
425		except_on_no_valid_endpoint: bool = True,
426	) -> typing.Optional[CoordArray]:
427		"""return a path between randomly chosen start and end nodes within the connected component
428
429		Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
430
431		# Parameters:
432		- `allowed_start : CoordList | None`
433			a list of allowed start positions. If `None`, any position in the connected component is allowed
434			(defaults to `None`)
435		- `allowed_end : CoordList | None`
436			a list of allowed end positions. If `None`, any position in the connected component is allowed
437			(defaults to `None`)
438		- `deadend_start : bool`
439			whether to ***force*** the start position to be a deadend (defaults to `False`)
440			(defaults to `False`)
441		- `deadend_end : bool`
442			whether to ***force*** the end position to be a deadend (defaults to `False`)
443			(defaults to `False`)
444		- `endpoints_not_equal : bool`
445			whether to ensure tha the start and end point are not the same
446			(defaults to `False`)
447		- `except_on_no_valid_endpoint : bool`
448			whether to raise an error if no valid start or end positions are found
449			if this is `False`, the function might return `None` and this must be handled by the caller
450			(defaults to `True`)
451
452		# Returns:
453		- `CoordArray`
454			a path between the selected start and end positions
455
456		# Raises:
457		- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
458		"""
459		# we can't create a "path" in a single-node maze
460		assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
461			f"can't create path in single-node maze: {self.as_ascii()}"
462		)
463
464		# get connected component
465		connected_component: CoordArray = self.get_connected_component()
466
467		# initialize start and end positions
468		positions: Int[np.int8, "2 2"]
469
470		# if no special conditions on start and end positions
471		if (allowed_start, allowed_end, deadend_start, deadend_end) == (
472			None,
473			None,
474			False,
475			False,
476		):
477			try:
478				positions = connected_component[  # type: ignore[assignment]
479					np.random.choice(
480						len(connected_component),
481						size=2,
482						replace=False,
483					)
484				]
485			except ValueError as e:
486				if except_on_no_valid_endpoint:
487					err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
488					raise NoValidEndpointException(
489						err_msg,
490					) from e
491				return None
492
493			return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]
494
495		# handle special conditions
496		connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
497		# copy connected component set
498		allowed_start_set: set[CoordTup] = connected_component_set.copy()
499		allowed_end_set: set[CoordTup] = connected_component_set.copy()
500
501		# filter by explicitly allowed start and end positions
502		# '# type: ignore[assignment]' here because the returned tuple can be of any length
503		if allowed_start is not None:
504			allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]
505
506		if allowed_end is not None:
507			allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]
508
509		# filter by forcing deadends
510		if deadend_start:
511			allowed_start_set = set(
512				filter(
513					lambda x: len(self.get_coord_neighbors(x)) == 1,
514					allowed_start_set,
515				),
516			)
517
518		if deadend_end:
519			allowed_end_set = set(
520				filter(
521					lambda x: len(self.get_coord_neighbors(x)) == 1,
522					allowed_end_set,
523				),
524			)
525
526		# check we have valid positions
527		if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
528			if except_on_no_valid_endpoint:
529				err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
530				raise NoValidEndpointException(
531					err_msg,
532				)
533			return None
534
535		# randomly select start and end positions
536		try:
537			# ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
538			start_pos: CoordTup = tuple(  # type: ignore[assignment]
539				list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
540			)
541			if endpoints_not_equal:
542				# remove start position from end positions
543				allowed_end_set.discard(start_pos)
544			end_pos: CoordTup = tuple(  # type: ignore[assignment]
545				list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
546			)
547		except ValueError as e:
548			if except_on_no_valid_endpoint:
549				err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
550				raise NoValidEndpointException(
551					err_msg,
552				) from e
553			return None
554
555		return self.find_shortest_path(start_pos, end_pos)

return a path between randomly chosen start and end nodes within the connected component

Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

Parameters:

  • allowed_start : CoordList | None a list of allowed start positions. If None, any position in the connected component is allowed (defaults to None)
  • allowed_end : CoordList | None a list of allowed end positions. If None, any position in the connected component is allowed (defaults to None)
  • deadend_start : bool whether to force the start position to be a deadend (defaults to False)
  • deadend_end : bool whether to force the end position to be a deadend (defaults to False)
  • endpoints_not_equal : bool whether to ensure that the start and end point are not the same (defaults to False)
  • except_on_no_valid_endpoint : bool whether to raise an error if no valid start or end positions are found if this is False, the function might return None and this must be handled by the caller (defaults to True)

Returns:

  • CoordArray a path between the selected start and end positions

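Raises:

  • NoValidEndpointException : if no valid start or end positions are found, and except_on_no_valid_endpoint is True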

def as_adj_list( self, shuffle_d0: bool = True, shuffle_d1: bool = True) -> jaxtyping.Int8[ndarray, 'conn start_end coord']:
560	def as_adj_list(
561		self,
562		shuffle_d0: bool = True,
563		shuffle_d1: bool = True,
564	) -> Int8[np.ndarray, "conn start_end coord"]:
565		"""return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
566		return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)

return the maze as an adjacency list, wraps maze_dataset.token_utils.connection_list_to_adj_list

@classmethod
def from_adj_list( cls, adj_list: jaxtyping.Int8[ndarray, 'conn start_end coord']) -> LatticeMaze:
568	@classmethod
569	def from_adj_list(
570		cls,
571		adj_list: Int8[np.ndarray, "conn start_end coord"],
572	) -> "LatticeMaze":
573		"""create a LatticeMaze from a list of connections
574
575		> [!NOTE]
576		> This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
577		"""
578		# this is where it would probably break for rectangular mazes
579		grid_n: int = adj_list.max() + 1
580
581		connection_list: ConnectionList = np.zeros(
582			(2, grid_n, grid_n),
583			dtype=np.bool_,
584		)
585
586		for c_start, c_end in adj_list:
587			# check that exactly 1 coordinate matches
588			if (c_start == c_end).sum() != 1:
589				raise ValueError("invalid connection")
590
591			# get the direction
592			d: int = (c_start != c_end).argmax()
593
594			x: int
595			y: int
596			# pick whichever has the lesser value in the direction `d`
597			if c_start[d] < c_end[d]:
598				x, y = c_start
599			else:
600				x, y = c_end
601
602			connection_list[d, x, y] = True
603
604		return LatticeMaze(
605			connection_list=connection_list,
606		)

create a LatticeMaze from a list of connections

Note

This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

def as_adj_list_tokens(self) -> list[str | tuple[int, int]]:
608	def as_adj_list_tokens(self) -> list[str | CoordTup]:
609		"""(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
610		warnings.warn(
611			"`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
612			TokenizerDeprecationWarning,
613		)
614		return [
615			SPECIAL_TOKENS.ADJLIST_START,
616			*chain.from_iterable(  # type: ignore[list-item]
617				[
618					[
619						tuple(c_s),
620						SPECIAL_TOKENS.CONNECTOR,
621						tuple(c_e),
622						SPECIAL_TOKENS.ADJACENCY_ENDLINE,
623					]
624					for c_s, c_e in self.as_adj_list()
625				],
626			),
627			SPECIAL_TOKENS.ADJLIST_END,
628		]

(deprecated!) turn the maze into adjacency list tokens, use MazeTokenizerModular instead

680	def as_tokens(
681		self,
682		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
683	) -> list[str]:
684		"""serialize maze and solution to tokens"""
685		if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
686			return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
687		else:
688			return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

serialize maze and solution to tokens

800	@classmethod
801	def from_tokens(
802		cls,
803		tokens: list[str],
804		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
805	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
806		"""Constructs a maze from a tokenization.
807
808		Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
809		"""
810		# HACK: type ignores here fine since we check the instance
811		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
812			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
813		if (
814			isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
815			and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
816		):
817			err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
818			raise NotImplementedError(
819				err_msg,
820			)
821
822		if isinstance(tokens, str):
823			tokens = tokens.split()
824
825		if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
826			return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
827		else:
828			raise NotImplementedError("only AOTP tokenization is supported")

Constructs a maze from a tokenization.

Only legacy tokenizers and their MazeTokenizerModular analogs are supported.

def as_pixels( self, show_endpoints: bool = True, show_solution: bool = True) -> jaxtyping.Int[ndarray, 'x y rgb']:
859	def as_pixels(
860		self,
861		show_endpoints: bool = True,
862		show_solution: bool = True,
863	) -> PixelGrid:
864		"""convert the maze to a pixel grid
865
866		- useful as a simpler way of plotting the maze than the more complex `MazePlot`
867		- the same underlying representation as `as_ascii` but as an image
868		- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
869		"""
870		# HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
871		# but solution, start_pos, end_pos not always defined
872		# but its fine since we explicitly check the type
873		if show_solution and not show_endpoints:
874			raise ValueError("show_solution=True requires show_endpoints=True")
875		# convert original bool pixel grid to RGB
876		pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
877		pixel_grid: PixelGrid = np.full(
878			(*pixel_grid_bw.shape, 3),
879			PixelColors.WALL,
880			dtype=np.uint8,
881		)
882		pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712
883
884		if self.__class__ == LatticeMaze:
885			return pixel_grid
886
887		# set endpoints for TargetedLatticeMaze
888		if self.__class__ == TargetedLatticeMaze:
889			if show_endpoints:
890				pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
891					PixelColors.START
892				)
893				pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
894					PixelColors.END
895				)
896			return pixel_grid
897
898		# set solution -- we only reach this part if `self.__class__ == SolvedMaze`
899		if show_solution:
900			for coord in self.solution:  # type: ignore[attr-defined]
901				pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH
902
903			# set pixels between coords
904			for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
905				next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
906				# check they are adjacent using norm
907				assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
908					f"Coords {coord} and {next_coord} are not adjacent"
909				)
910				# set pixel between them
911				pixel_grid[
912					coord[0] * 2 + 1 + next_coord[0] - coord[0],
913					coord[1] * 2 + 1 + next_coord[1] - coord[1],
914				] = PixelColors.PATH
915
916			# set endpoints (again, since path would overwrite them)
917			pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
918				PixelColors.START
919			)
920			pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
921				PixelColors.END
922			)
923
924		return pixel_grid

convert the maze to a pixel grid

@classmethod
def from_pixels( cls, pixel_grid: jaxtyping.Int[ndarray, 'x y rgb']) -> LatticeMaze:
	@classmethod
	def from_pixels(
		cls,
		pixel_grid: PixelGrid,
	) -> "LatticeMaze":
		"""create a LatticeMaze from a pixel grid. reverse of `as_pixels`

		- a 2D (binary) grid always yields a plain `LatticeMaze`
		- an RGB grid is checked with `detect_pixels_type`, then parsed as `cls`;
		  for `SolvedMaze` the solution pixels are ordered into a path from the
		  start to the end by walking connected neighbors

		# Raises:
		- `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
		"""
		connection_list: ConnectionList
		grid_shape: tuple[int, int]

		# if a binary pixel grid, return regular LatticeMaze
		if len(pixel_grid.shape) == 2:  # noqa: PLR2004
			connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
			return LatticeMaze(connection_list=connection_list)

		# otherwise, detect and check it's valid
		cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
		if cls not in cls_detected.__mro__:
			err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
			raise ValueError(
				err_msg,
			)

		(
			connection_list,
			grid_shape,
			marked_pos,
		) = cls._from_pixel_grid_with_positions(
			pixel_grid=pixel_grid,
			marked_positions=dict(
				start=PixelColors.START,
				end=PixelColors.END,
				solution=PixelColors.PATH,
			),
		)
		# if we wanted a LatticeMaze, return it
		if cls == LatticeMaze:
			return LatticeMaze(connection_list=connection_list)

		# otherwise, keep going
		temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

		# start and end pos
		start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"]
		assert start_pos_arr.shape == (
			1,
			2,
		), (
			f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		assert end_pos_arr.shape == (
			1,
			2,
		), (
			f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)

		start_pos: Coord = start_pos_arr[0]
		end_pos: Coord = end_pos_arr[0]

		# return a TargetedLatticeMaze if that's what we wanted
		if cls == TargetedLatticeMaze:
			return TargetedLatticeMaze(
				connection_list=connection_list,
				start_pos=start_pos,
				end_pos=end_pos,
			)

		# raw solution, only contains path elements and not start or end
		solution_raw: CoordArray = marked_pos["solution"]
		if len(solution_raw.shape) == 2:  # noqa: PLR2004
			assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
				f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
			)
		elif solution_raw.shape == (0,):
			# the solution and end should be immediately adjacent
			assert np.sum(np.abs(start_pos - end_pos)) == 1, (
				f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
			)

		# order the solution, by creating a list from the start to the end
		# add end pos, since we will iterate over all these starting from the start pos
		solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [
			tuple(end_pos),
		]
		# set views for O(1) membership checks (the lists are kept for error messages)
		solution_raw_set: set[CoordTup] = set(solution_raw_list)
		# solution starts with start point
		solution: list[CoordTup] = [tuple(start_pos)]
		solution_set: set[CoordTup] = {tuple(start_pos)}
		while solution[-1] != tuple(end_pos):
			# use `get_coord_neighbors` to find connected neighbors
			neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
			# TODO: make this less ugly
			assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
				f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
			)
			# keep only neighbors on the raw solution that are not yet in the path
			neighbors_filtered: CoordArray = np.array(
				[
					coord
					for coord in neighbors
					if (
						tuple(coord) in solution_raw_set
						and tuple(coord) not in solution_set
					)
				],
			)
			# assert only one element is left, and then add it to the solution
			assert neighbors_filtered.shape == (
				1,
				2,
			), (
				f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
			)
			next_coord: CoordTup = tuple(neighbors_filtered[0])
			solution.append(next_coord)
			solution_set.add(next_coord)

		# assert the solution is complete
		assert solution[0] == tuple(start_pos), (
			f"solution {solution} does not start at start_pos {start_pos}"
		)
		assert solution[-1] == tuple(end_pos), (
			f"solution {solution} does not end at end_pos {end_pos}"
		)

		return cls(
			connection_list=np.array(connection_list),
			solution=np.array(solution),  # type: ignore[call-arg]
		)

create a LatticeMaze from a pixel grid. reverse of as_pixels

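Raises:

  • ValueError : if the pixel grid cannot be cast to a LatticeMaze -- it's probably a TargetedLatticeMaze or SolvedMaze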

def as_ascii(self, show_endpoints: bool = True, show_solution: bool = True) -> str:
1127	def as_ascii(
1128		self,
1129		show_endpoints: bool = True,
1130		show_solution: bool = True,
1131	) -> str:
1132		"""return an ASCII grid of the maze
1133
1134		useful for debugging in the terminal, or as it's own format
1135
1136		can be reversed with `LatticeMaze.from_ascii()`
1137		"""
1138		ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
1139		pixel_grid: PixelGrid = self.as_pixels(
1140			show_endpoints=show_endpoints,
1141			show_solution=show_solution,
1142		)
1143
1144		chars_replace: tuple = tuple()
1145		if show_endpoints:
1146			chars_replace += (AsciiChars.START, AsciiChars.END)
1147		if show_solution:
1148			chars_replace += (AsciiChars.PATH,)
1149
1150		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1151			if ascii_char in chars_replace:
1152				ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char
1153
1154		return "\n".join("".join(row) for row in ascii_grid)

return an ASCII grid of the maze

useful for debugging in the terminal, or as its own format

can be reversed with LatticeMaze.from_ascii()

@classmethod
def from_ascii(cls, ascii_str: str) -> LatticeMaze:
1156	@classmethod
1157	def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
1158		"get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)"
1159		lines: list[str] = ascii_str.strip().split("\n")
1160		lines = [line.strip() for line in lines]
1161		ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
1162			[list(line) for line in lines],
1163			dtype=str,
1164		)
1165		pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
1166
1167		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1168			pixel_grid[ascii_grid == ascii_char] = pixel_color
1169
1170		return cls.from_pixels(pixel_grid)

get a LatticeMaze from an ASCII representation (reverses LatticeMaze.as_ascii)

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(LatticeMaze):
1176@serializable_dataclass(frozen=True, kw_only=True)
1177class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
1178	"""A LatticeMaze with a start and end position"""
1179
1180	# this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
1181	# type ignore here because even though its a kw-only dataclass,
1182	# mypy doesn't like that non-default arguments are after default arguments
1183	start_pos: Coord = serializable_field(  # type: ignore[misc]
1184		assert_type=False,
1185	)
1186	end_pos: Coord = serializable_field(  # type: ignore[misc]
1187		assert_type=False,
1188	)
1189
1190	def __post_init__(self) -> None:
1191		"post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
1192		# make things numpy arrays (very jank to override frozen dataclass)
1193		self.__dict__["start_pos"] = np.array(self.start_pos)
1194		self.__dict__["end_pos"] = np.array(self.end_pos)
1195		assert self.start_pos is not None
1196		assert self.end_pos is not None
1197		# check that start and end are in bounds
1198		if (
1199			self.start_pos[0] >= self.grid_shape[0]
1200			or self.start_pos[1] >= self.grid_shape[1]
1201		):
1202			err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
1203			raise ValueError(
1204				err_msg,
1205			)
1206		if (
1207			self.end_pos[0] >= self.grid_shape[0]
1208			or self.end_pos[1] >= self.grid_shape[1]
1209		):
1210			err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
1211			raise ValueError(
1212				err_msg,
1213			)
1214
1215	def __eq__(self, other: object) -> bool:
1216		"check equality, calls parent class equality check"
1217		return super().__eq__(other)
1218
1219	def _get_start_pos_tokens(self) -> list[str | CoordTup]:
1220		return [
1221			SPECIAL_TOKENS.ORIGIN_START,
1222			tuple(self.start_pos),
1223			SPECIAL_TOKENS.ORIGIN_END,
1224		]
1225
1226	def get_start_pos_tokens(self) -> list[str | CoordTup]:
1227		"(deprecated!) return the start position as a list of tokens"
1228		warnings.warn(
1229			"`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
1230			TokenizerDeprecationWarning,
1231		)
1232		return self._get_start_pos_tokens()
1233
1234	def _get_end_pos_tokens(self) -> list[str | CoordTup]:
1235		return [
1236			SPECIAL_TOKENS.TARGET_START,
1237			tuple(self.end_pos),
1238			SPECIAL_TOKENS.TARGET_END,
1239		]
1240
1241	def get_end_pos_tokens(self) -> list[str | CoordTup]:
1242		"(deprecated!) return the end position as a list of tokens"
1243		warnings.warn(
1244			"`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
1245			TokenizerDeprecationWarning,
1246		)
1247		return self._get_end_pos_tokens()
1248
1249	@classmethod
1250	def from_lattice_maze(
1251		cls,
1252		lattice_maze: LatticeMaze,
1253		start_pos: Coord | CoordTup,
1254		end_pos: Coord | CoordTup,
1255	) -> "TargetedLatticeMaze":
1256		"get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
1257		return cls(
1258			connection_list=lattice_maze.connection_list,
1259			start_pos=np.array(start_pos),
1260			end_pos=np.array(end_pos),
1261			generation_meta=lattice_maze.generation_meta,
1262		)

A LatticeMaze with a start and end position

TargetedLatticeMaze( *, connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], generation_meta: dict | None = None, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'], end_pos: jaxtyping.Int8[ndarray, 'row_col=2'])
start_pos: jaxtyping.Int8[ndarray, 'row_col=2']
end_pos: jaxtyping.Int8[ndarray, 'row_col=2']
def get_start_pos_tokens(self) -> list[str | tuple[int, int]]:
1226	def get_start_pos_tokens(self) -> list[str | CoordTup]:
1227		"(deprecated!) return the start position as a list of tokens"
1228		warnings.warn(
1229			"`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
1230			TokenizerDeprecationWarning,
1231		)
1232		return self._get_start_pos_tokens()

(deprecated!) return the start position as a list of tokens

def get_end_pos_tokens(self) -> list[str | tuple[int, int]]:
1241	def get_end_pos_tokens(self) -> list[str | CoordTup]:
1242		"(deprecated!) return the end position as a list of tokens"
1243		warnings.warn(
1244			"`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
1245			TokenizerDeprecationWarning,
1246		)
1247		return self._get_end_pos_tokens()

(deprecated!) return the end position as a list of tokens

@classmethod
def from_lattice_maze( cls, lattice_maze: LatticeMaze, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], end_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> TargetedLatticeMaze:
1249	@classmethod
1250	def from_lattice_maze(
1251		cls,
1252		lattice_maze: LatticeMaze,
1253		start_pos: Coord | CoordTup,
1254		end_pos: Coord | CoordTup,
1255	) -> "TargetedLatticeMaze":
1256		"get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
1257		return cls(
1258			connection_list=lattice_maze.connection_list,
1259			start_pos=np.array(start_pos),
1260			end_pos=np.array(end_pos),
1261			generation_meta=lattice_maze.generation_meta,
1262		)

get a TargetedLatticeMaze from a LatticeMaze by specifying start and end positions

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(TargetedLatticeMaze):
1265@serializable_dataclass(frozen=True, kw_only=True)
1266class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
1267	"""Stores a maze and a solution"""
1268
1269	solution: CoordArray = serializable_field(  # type: ignore[misc]
1270		assert_type=False,
1271	)
1272
1273	def __init__(
1274		self,
1275		connection_list: ConnectionList,
1276		solution: CoordArray,
1277		generation_meta: dict | None = None,
1278		start_pos: Coord | None = None,
1279		end_pos: Coord | None = None,
1280		allow_invalid: bool = False,
1281	) -> None:
1282		"""Create a SolvedMaze from a connection list and a solution
1283
1284		> DOCS: better documentation for this init method
1285		"""
1286		# figure out the solution
1287		solution_valid: bool = False
1288		if solution is not None:
1289			solution = np.array(solution)
1290			# note that a path length of 1 here is valid, since the start and end pos could be the same
1291			if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
1292				solution_valid = True
1293
1294		if not solution_valid and not allow_invalid:
1295			err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
1296			raise ValueError(
1297				err_msg,
1298				f"{connection_list = }",
1299			)
1300
1301		# init the TargetedLatticeMaze
1302		super().__init__(
1303			connection_list=connection_list,
1304			generation_meta=generation_meta,
1305			# TODO: the argument type is stricter than the expected type but it still fails?
1306			# error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
1307			# "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
1308			start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
1309			end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
1310		)
1311
1312		self.__dict__["solution"] = solution
1313
1314		# adjust the endpoints
1315		if not allow_invalid:
1316			if start_pos is not None:
1317				assert np.array_equal(np.array(start_pos), self.start_pos), (
1318					f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
1319				)
1320			if end_pos is not None:
1321				assert np.array_equal(np.array(end_pos), self.end_pos), (
1322					f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
1323				)
1324			# TODO: assert the path does not backtrack, walk through walls, etc?
1325
1326	def __eq__(self, other: object) -> bool:
1327		"check equality, calls parent class equality check"
1328		return super().__eq__(other)
1329
1330	def __hash__(self) -> int:
1331		"hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
1332		return hash((self.connection_list.tobytes(), self.solution.tobytes()))
1333
1334	def _get_solution_tokens(self) -> list[str | CoordTup]:
1335		return [
1336			SPECIAL_TOKENS.PATH_START,
1337			*[tuple(c) for c in self.solution],
1338			SPECIAL_TOKENS.PATH_END,
1339		]
1340
1341	def get_solution_tokens(self) -> list[str | CoordTup]:
1342		"(deprecated!) return the solution as a list of tokens"
1343		warnings.warn(
1344			"`LatticeMaze.get_solution_tokens` is deprecated.",
1345			TokenizerDeprecationWarning,
1346		)
1347		return self._get_solution_tokens()
1348
1349	# for backwards compatibility
1350	@property
1351	def maze(self) -> LatticeMaze:
1352		"(deprecated!) return the maze without the solution"
1353		warnings.warn(
1354			"`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
1355			DeprecationWarning,
1356		)
1357		return LatticeMaze(connection_list=self.connection_list)
1358
1359	# type ignore here since we're overriding a method with a different signature
1360	@classmethod
1361	def from_lattice_maze(  # type: ignore[override]
1362		cls,
1363		lattice_maze: LatticeMaze,
1364		solution: list[CoordTup] | CoordArray,
1365	) -> "SolvedMaze":
1366		"get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
1367		return cls(
1368			connection_list=lattice_maze.connection_list,
1369			solution=np.array(solution),
1370			generation_meta=lattice_maze.generation_meta,
1371		)
1372
1373	@classmethod
1374	def from_targeted_lattice_maze(
1375		cls,
1376		targeted_lattice_maze: TargetedLatticeMaze,
1377		solution: list[CoordTup] | CoordArray | None = None,
1378	) -> "SolvedMaze":
1379		"""solves the given targeted lattice maze and returns a SolvedMaze"""
1380		if solution is None:
1381			solution = targeted_lattice_maze.find_shortest_path(
1382				targeted_lattice_maze.start_pos,
1383				targeted_lattice_maze.end_pos,
1384			)
1385		return cls(
1386			connection_list=targeted_lattice_maze.connection_list,
1387			solution=np.array(solution),
1388			generation_meta=targeted_lattice_maze.generation_meta,
1389		)
1390
1391	def get_solution_forking_points(
1392		self,
1393		always_include_endpoints: bool = False,
1394	) -> tuple[list[int], CoordArray]:
1395		"""coordinates and their indicies from the solution where a fork is present
1396
1397		- if the start point is not a dead end, this counts as a fork
1398		- if the end point is not a dead end, this counts as a fork
1399		"""
1400		output_idxs: list[int] = list()
1401		output_coords: list[CoordTup] = list()
1402
1403		for idx, coord in enumerate(self.solution):
1404			# more than one choice for first coord, or more than 2 for any other
1405			# since the previous coord doesn't count as a choice
1406			is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
1407			theshold: int = 1 if is_endpoint else 2
1408			if self.get_coord_neighbors(coord).shape[0] > theshold or (
1409				is_endpoint and always_include_endpoints
1410			):
1411				output_idxs.append(idx)
1412				output_coords.append(coord)
1413
1414		return output_idxs, np.array(output_coords)
1415
1416	def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
1417		"""coordinates from the solution where there is only a single (non-backtracking) point to move to
1418
1419		returns the complement of `get_solution_forking_points` from the path
1420		"""
1421		forks_idxs, _ = self.get_solution_forking_points()
1422		# HACK: idk why type ignore here
1423		return (  # type: ignore[return-value]
1424			np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
1425			np.delete(self.solution, forks_idxs, axis=0),
1426		)

Stores a maze and a solution

SolvedMaze( connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], solution: jaxtyping.Int8[ndarray, 'coord row_col=2'], generation_meta: dict | None = None, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, end_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, allow_invalid: bool = False)
1273	def __init__(
1274		self,
1275		connection_list: ConnectionList,
1276		solution: CoordArray,
1277		generation_meta: dict | None = None,
1278		start_pos: Coord | None = None,
1279		end_pos: Coord | None = None,
1280		allow_invalid: bool = False,
1281	) -> None:
1282		"""Create a SolvedMaze from a connection list and a solution
1283
1284		> DOCS: better documentation for this init method
1285		"""
1286		# figure out the solution
1287		solution_valid: bool = False
1288		if solution is not None:
1289			solution = np.array(solution)
1290			# note that a path length of 1 here is valid, since the start and end pos could be the same
1291			if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
1292				solution_valid = True
1293
1294		if not solution_valid and not allow_invalid:
1295			err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
1296			raise ValueError(
1297				err_msg,
1298				f"{connection_list = }",
1299			)
1300
1301		# init the TargetedLatticeMaze
1302		super().__init__(
1303			connection_list=connection_list,
1304			generation_meta=generation_meta,
1305			# TODO: the argument type is stricter than the expected type but it still fails?
1306			# error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
1307			# "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
1308			start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
1309			end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
1310		)
1311
1312		self.__dict__["solution"] = solution
1313
1314		# adjust the endpoints
1315		if not allow_invalid:
1316			if start_pos is not None:
1317				assert np.array_equal(np.array(start_pos), self.start_pos), (
1318					f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
1319				)
1320			if end_pos is not None:
1321				assert np.array_equal(np.array(end_pos), self.end_pos), (
1322					f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
1323				)
1324			# TODO: assert the path does not backtrack, walk through walls, etc?

Create a SolvedMaze from a connection list and a solution

DOCS: better documentation for this init method

solution: jaxtyping.Int8[ndarray, 'coord row_col=2']
def get_solution_tokens(self) -> list[str | tuple[int, int]]:
1341	def get_solution_tokens(self) -> list[str | CoordTup]:
1342		"(deprecated!) return the solution as a list of tokens"
1343		warnings.warn(
1344			"`LatticeMaze.get_solution_tokens` is deprecated.",
1345			TokenizerDeprecationWarning,
1346		)
1347		return self._get_solution_tokens()

(deprecated!) return the solution as a list of tokens

maze: LatticeMaze
1350	@property
1351	def maze(self) -> LatticeMaze:
1352		"(deprecated!) return the maze without the solution"
1353		warnings.warn(
1354			"`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
1355			DeprecationWarning,
1356		)
1357		return LatticeMaze(connection_list=self.connection_list)

(deprecated!) return the maze without the solution

@classmethod
def from_lattice_maze( cls, lattice_maze: LatticeMaze, solution: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2']) -> SolvedMaze:
1360	@classmethod
1361	def from_lattice_maze(  # type: ignore[override]
1362		cls,
1363		lattice_maze: LatticeMaze,
1364		solution: list[CoordTup] | CoordArray,
1365	) -> "SolvedMaze":
1366		"get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
1367		return cls(
1368			connection_list=lattice_maze.connection_list,
1369			solution=np.array(solution),
1370			generation_meta=lattice_maze.generation_meta,
1371		)

get a SolvedMaze from a LatticeMaze by specifying a solution

@classmethod
def from_targeted_lattice_maze( cls, targeted_lattice_maze: TargetedLatticeMaze, solution: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | None = None) -> SolvedMaze:
1373	@classmethod
1374	def from_targeted_lattice_maze(
1375		cls,
1376		targeted_lattice_maze: TargetedLatticeMaze,
1377		solution: list[CoordTup] | CoordArray | None = None,
1378	) -> "SolvedMaze":
1379		"""solves the given targeted lattice maze and returns a SolvedMaze"""
1380		if solution is None:
1381			solution = targeted_lattice_maze.find_shortest_path(
1382				targeted_lattice_maze.start_pos,
1383				targeted_lattice_maze.end_pos,
1384			)
1385		return cls(
1386			connection_list=targeted_lattice_maze.connection_list,
1387			solution=np.array(solution),
1388			generation_meta=targeted_lattice_maze.generation_meta,
1389		)

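solves the given targeted lattice maze (using the shortest path from its start to its end position, unless a solution is explicitly given) and returns a SolvedMaze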

def get_solution_forking_points( self, always_include_endpoints: bool = False) -> tuple[list[int], jaxtyping.Int8[ndarray, 'coord row_col=2']]:
1391	def get_solution_forking_points(
1392		self,
1393		always_include_endpoints: bool = False,
1394	) -> tuple[list[int], CoordArray]:
1395		"""coordinates and their indicies from the solution where a fork is present
1396
1397		- if the start point is not a dead end, this counts as a fork
1398		- if the end point is not a dead end, this counts as a fork
1399		"""
1400		output_idxs: list[int] = list()
1401		output_coords: list[CoordTup] = list()
1402
1403		for idx, coord in enumerate(self.solution):
1404			# more than one choice for first coord, or more than 2 for any other
1405			# since the previous coord doesn't count as a choice
1406			is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
1407			theshold: int = 1 if is_endpoint else 2
1408			if self.get_coord_neighbors(coord).shape[0] > theshold or (
1409				is_endpoint and always_include_endpoints
1410			):
1411				output_idxs.append(idx)
1412				output_coords.append(coord)
1413
1414		return output_idxs, np.array(output_coords)

coordinates and their indices from the solution where a fork is present

  • if the start point is not a dead end, this counts as a fork
  • if the end point is not a dead end, this counts as a fork
  • if always_include_endpoints is True, the start and end points are always included
def get_solution_path_following_points(self) -> tuple[list[int], jaxtyping.Int8[ndarray, 'coord row_col=2']]:
1416	def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
1417		"""coordinates from the solution where there is only a single (non-backtracking) point to move to
1418
1419		returns the complement of `get_solution_forking_points` from the path
1420		"""
1421		forks_idxs, _ = self.get_solution_forking_points()
1422		# HACK: idk why type ignore here
1423		return (  # type: ignore[return-value]
1424			np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
1425			np.delete(self.solution, forks_idxs, axis=0),
1426		)

coordinates from the solution where there is only a single (non-backtracking) point to move to

returns the indices and coordinates of these points, i.e. the complement along the path of get_solution_forking_points (called with its default arguments)

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the instance serialized as a dict, implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged), implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass, returning True only if every field is valid. calls SerializableDataclass__validate_field_type for each field

def detect_pixels_type( data: jaxtyping.Int[ndarray, 'x y rgb']) -> Type[LatticeMaze]:
1429def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]:
1430	"""Detects the type of pixels data by checking for the presence of start and end pixels"""
1431	if color_in_pixel_grid(data, PixelColors.START) or color_in_pixel_grid(
1432		data,
1433		PixelColors.END,
1434	):
1435		if color_in_pixel_grid(data, PixelColors.PATH):
1436			return SolvedMaze
1437		else:
1438			return TargetedLatticeMaze
1439	else:
1440		return LatticeMaze

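Detects the type of pixels data by checking for the presence of start and end pixels: if either is present, returns SolvedMaze when path pixels are also present and TargetedLatticeMaze otherwise; if neither is present, returns LatticeMaze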