docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.maze

LatticeMaze and the classes like SolvedMaze that inherit from it, along with a variety of helper functions

This package utilizes a simple, efficient representation of mazes. Using an adjacency list to represent mazes would lead to a poor lookup time of whether any given connection exists, whilst using a dense adjacency matrix would waste memory by failing to exploit the structure (e.g., only 4 of the diagonals would be filled in). Instead, we describe mazes with the following simple representation: for a $d$-dimensional lattice with $r$ rows and $c$ columns, we initialize a boolean array $A \in \{0, 1\}^{d \times r \times c}$, which we refer to in the code as a connection_list. The value at $A[0,i,j]$ determines whether a downward connection exists from node $[i,j]$ to $[i+1, j]$. Likewise, the value at $A[1,i,j]$ determines whether a rightwards connection to $[i, j+1]$ exists. Thus, we avoid duplication of data about the existence of connections, at the cost of requiring additional care with indexing when looking for a connection upwards or to the left. Note that this setup allows for a periodic lattice.

To produce solutions to mazes, two points are selected uniformly at random without replacement from the connected component of the maze, and the $A^*$ algorithm is applied to find the shortest path between them.

Parallelization is implemented via the multiprocessing module in the Python standard library, and parallel generation can be controlled via keyword arguments to the MazeDataset.from_config() function.


 1r"""`LatticeMaze` and the classes like `SolvedMaze` that inherit from it, along with a variety of helper functions
 2
 3This package utilizes a simple, efficient representation of mazes. Using an adjacency list to represent mazes would lead to a poor lookup time of whether any given connection exists, whilst using a dense adjacency matrix would waste memory by failing to exploit the structure (e.g., only 4 of the diagonals would be filled in).
 4Instead, we describe mazes with the following simple representation: for a $d$-dimensional lattice with $r$ rows and $c$ columns, we initialize a boolean array $A \in \{0, 1\}^{d \times r \times c}$, which we refer to in the code as a `connection_list`. The value at $A[0,i,j]$ determines whether a downward connection exists from node $[i,j]$ to $[i+1, j]$. Likewise, the value at $A[1,i,j]$ determines whether a rightwards connection to $[i, j+1]$ exists. Thus, we avoid duplication of data about the existence of connections, at the cost of requiring additional care with indexing when looking for a connection upwards or to the left. Note that this setup allows for a periodic lattice.
 5
 6To produce solutions to mazes, two points are selected uniformly at random without replacement from the connected component of the maze, and the $A^*$ algorithm is applied to find the shortest path between them.
 7
 8Parallelization is implemented via the `multiprocessing` module in the Python standard library, and parallel generation can be controlled via keyword arguments to the `MazeDataset.from_config()` function.
 9"""
10
11from maze_dataset.maze.lattice_maze import (
12	AsciiChars,
13	ConnectionList,
14	Coord,
15	CoordArray,
16	LatticeMaze,
17	PixelColors,
18	SolvedMaze,
19	TargetedLatticeMaze,
20)
21
22__all__ = [
23	# submodules
24	"lattice_maze",
25	# imports
26	"SolvedMaze",
27	"TargetedLatticeMaze",
28	"LatticeMaze",
29	"ConnectionList",
30	"AsciiChars",
31	"Coord",
32	"CoordArray",
33	"PixelColors",
34]

@serializable_dataclass(frozen=True, kw_only=True)
class SolvedMaze(maze_dataset.maze.TargetedLatticeMaze):
1265@serializable_dataclass(frozen=True, kw_only=True)
1266class SolvedMaze(TargetedLatticeMaze):  # type: ignore[misc]
1267	"""Stores a maze and a solution"""
1268
1269	solution: CoordArray = serializable_field(  # type: ignore[misc]
1270		assert_type=False,
1271	)
1272
1273	def __init__(
1274		self,
1275		connection_list: ConnectionList,
1276		solution: CoordArray,
1277		generation_meta: dict | None = None,
1278		start_pos: Coord | None = None,
1279		end_pos: Coord | None = None,
1280		allow_invalid: bool = False,
1281	) -> None:
1282		"""Create a SolvedMaze from a connection list and a solution
1283
1284		> DOCS: better documentation for this init method
1285		"""
1286		# figure out the solution
1287		solution_valid: bool = False
1288		if solution is not None:
1289			solution = np.array(solution)
1290			# note that a path length of 1 here is valid, since the start and end pos could be the same
1291			if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
1292				solution_valid = True
1293
1294		if not solution_valid and not allow_invalid:
1295			err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
1296			raise ValueError(
1297				err_msg,
1298				f"{connection_list = }",
1299			)
1300
1301		# init the TargetedLatticeMaze
1302		super().__init__(
1303			connection_list=connection_list,
1304			generation_meta=generation_meta,
1305			# TODO: the argument type is stricter than the expected type but it still fails?
1306			# error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
1307			# "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
1308			start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
1309			end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
1310		)
1311
1312		self.__dict__["solution"] = solution
1313
1314		# adjust the endpoints
1315		if not allow_invalid:
1316			if start_pos is not None:
1317				assert np.array_equal(np.array(start_pos), self.start_pos), (
1318					f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
1319				)
1320			if end_pos is not None:
1321				assert np.array_equal(np.array(end_pos), self.end_pos), (
1322					f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
1323				)
1324			# TODO: assert the path does not backtrack, walk through walls, etc?
1325
1326	def __eq__(self, other: object) -> bool:
1327		"check equality, calls parent class equality check"
1328		return super().__eq__(other)
1329
1330	def __hash__(self) -> int:
1331		"hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes"
1332		return hash((self.connection_list.tobytes(), self.solution.tobytes()))
1333
1334	def _get_solution_tokens(self) -> list[str | CoordTup]:
1335		return [
1336			SPECIAL_TOKENS.PATH_START,
1337			*[tuple(c) for c in self.solution],
1338			SPECIAL_TOKENS.PATH_END,
1339		]
1340
1341	def get_solution_tokens(self) -> list[str | CoordTup]:
1342		"(deprecated!) return the solution as a list of tokens"
1343		warnings.warn(
1344			"`LatticeMaze.get_solution_tokens` is deprecated.",
1345			TokenizerDeprecationWarning,
1346		)
1347		return self._get_solution_tokens()
1348
1349	# for backwards compatibility
1350	@property
1351	def maze(self) -> LatticeMaze:
1352		"(deprecated!) return the maze without the solution"
1353		warnings.warn(
1354			"`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
1355			DeprecationWarning,
1356		)
1357		return LatticeMaze(connection_list=self.connection_list)
1358
1359	# type ignore here since we're overriding a method with a different signature
1360	@classmethod
1361	def from_lattice_maze(  # type: ignore[override]
1362		cls,
1363		lattice_maze: LatticeMaze,
1364		solution: list[CoordTup] | CoordArray,
1365	) -> "SolvedMaze":
1366		"get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
1367		return cls(
1368			connection_list=lattice_maze.connection_list,
1369			solution=np.array(solution),
1370			generation_meta=lattice_maze.generation_meta,
1371		)
1372
1373	@classmethod
1374	def from_targeted_lattice_maze(
1375		cls,
1376		targeted_lattice_maze: TargetedLatticeMaze,
1377		solution: list[CoordTup] | CoordArray | None = None,
1378	) -> "SolvedMaze":
1379		"""solves the given targeted lattice maze and returns a SolvedMaze"""
1380		if solution is None:
1381			solution = targeted_lattice_maze.find_shortest_path(
1382				targeted_lattice_maze.start_pos,
1383				targeted_lattice_maze.end_pos,
1384			)
1385		return cls(
1386			connection_list=targeted_lattice_maze.connection_list,
1387			solution=np.array(solution),
1388			generation_meta=targeted_lattice_maze.generation_meta,
1389		)
1390
1391	def get_solution_forking_points(
1392		self,
1393		always_include_endpoints: bool = False,
1394	) -> tuple[list[int], CoordArray]:
1395		"""coordinates and their indices from the solution where a fork is present
1396
1397		- if the start point is not a dead end, this counts as a fork
1398		- if the end point is not a dead end, this counts as a fork
1399		"""
1400		output_idxs: list[int] = list()
1401		output_coords: list[CoordTup] = list()
1402
1403		for idx, coord in enumerate(self.solution):
1404			# more than one choice for first coord, or more than 2 for any other
1405			# since the previous coord doesn't count as a choice
1406			is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
1407			theshold: int = 1 if is_endpoint else 2
1408			if self.get_coord_neighbors(coord).shape[0] > theshold or (
1409				is_endpoint and always_include_endpoints
1410			):
1411				output_idxs.append(idx)
1412				output_coords.append(coord)
1413
1414		return output_idxs, np.array(output_coords)
1415
1416	def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
1417		"""coordinates from the solution where there is only a single (non-backtracking) point to move to
1418
1419		returns the complement of `get_solution_forking_points` from the path
1420		"""
1421		forks_idxs, _ = self.get_solution_forking_points()
1422		# HACK: idk why type ignore here
1423		return (  # type: ignore[return-value]
1424			np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
1425			np.delete(self.solution, forks_idxs, axis=0),
1426		)

Stores a maze and a solution

SolvedMaze( connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], solution: jaxtyping.Int8[ndarray, 'coord row_col=2'], generation_meta: dict | None = None, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, end_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, allow_invalid: bool = False)
1273	def __init__(
1274		self,
1275		connection_list: ConnectionList,
1276		solution: CoordArray,
1277		generation_meta: dict | None = None,
1278		start_pos: Coord | None = None,
1279		end_pos: Coord | None = None,
1280		allow_invalid: bool = False,
1281	) -> None:
1282		"""Create a SolvedMaze from a connection list and a solution
1283
1284		> DOCS: better documentation for this init method
1285		"""
1286		# figure out the solution
1287		solution_valid: bool = False
1288		if solution is not None:
1289			solution = np.array(solution)
1290			# note that a path length of 1 here is valid, since the start and end pos could be the same
1291			if (solution.shape[0] > 0) and (solution.shape[1] == 2):  # noqa: PLR2004
1292				solution_valid = True
1293
1294		if not solution_valid and not allow_invalid:
1295			err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }"
1296			raise ValueError(
1297				err_msg,
1298				f"{connection_list = }",
1299			)
1300
1301		# init the TargetedLatticeMaze
1302		super().__init__(
1303			connection_list=connection_list,
1304			generation_meta=generation_meta,
1305			# TODO: the argument type is stricter than the expected type but it still fails?
1306			# error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type
1307			# "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]"  [arg-type]
1308			start_pos=np.array(solution[0]) if solution_valid else None,  # type: ignore[arg-type]
1309			end_pos=np.array(solution[-1]) if solution_valid else None,  # type: ignore[arg-type]
1310		)
1311
1312		self.__dict__["solution"] = solution
1313
1314		# adjust the endpoints
1315		if not allow_invalid:
1316			if start_pos is not None:
1317				assert np.array_equal(np.array(start_pos), self.start_pos), (
1318					f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}"
1319				)
1320			if end_pos is not None:
1321				assert np.array_equal(np.array(end_pos), self.end_pos), (
1322					f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}"
1323				)
1324			# TODO: assert the path does not backtrack, walk through walls, etc?

Create a SolvedMaze from a connection list and a solution

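The solution is converted to a numpy array and is treated as valid when it contains at least one coordinate and has exactly two columns (a path of length 1 is valid, since the start and end positions may coincide). If the solution is not valid and allow_invalid is False, a ValueError is raised. The start and end positions passed to the parent TargetedLatticeMaze are taken from the first and last coordinates of a valid solution (None otherwise). If start_pos or end_pos are given explicitly and allow_invalid is False, they are only checked, via assertion, against the endpoints taken from the solution.

DOCS: better documentation for this init method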

solution: jaxtyping.Int8[ndarray, 'coord row_col=2']
def get_solution_tokens(self) -> list[str | tuple[int, int]]:
1341	def get_solution_tokens(self) -> list[str | CoordTup]:
1342		"(deprecated!) return the solution as a list of tokens"
1343		warnings.warn(
1344			"`LatticeMaze.get_solution_tokens` is deprecated.",
1345			TokenizerDeprecationWarning,
1346		)
1347		return self._get_solution_tokens()

(deprecated!) return the solution as a list of tokens

maze: LatticeMaze
1350	@property
1351	def maze(self) -> LatticeMaze:
1352		"(deprecated!) return the maze without the solution"
1353		warnings.warn(
1354			"`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.",
1355			DeprecationWarning,
1356		)
1357		return LatticeMaze(connection_list=self.connection_list)

(deprecated!) return the maze without the solution

@classmethod
def from_lattice_maze( cls, lattice_maze: LatticeMaze, solution: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2']) -> SolvedMaze:
1360	@classmethod
1361	def from_lattice_maze(  # type: ignore[override]
1362		cls,
1363		lattice_maze: LatticeMaze,
1364		solution: list[CoordTup] | CoordArray,
1365	) -> "SolvedMaze":
1366		"get a `SolvedMaze` from a `LatticeMaze` by specifying a solution"
1367		return cls(
1368			connection_list=lattice_maze.connection_list,
1369			solution=np.array(solution),
1370			generation_meta=lattice_maze.generation_meta,
1371		)

get a SolvedMaze from a LatticeMaze by specifying a solution

@classmethod
def from_targeted_lattice_maze( cls, targeted_lattice_maze: TargetedLatticeMaze, solution: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | None = None) -> SolvedMaze:
1373	@classmethod
1374	def from_targeted_lattice_maze(
1375		cls,
1376		targeted_lattice_maze: TargetedLatticeMaze,
1377		solution: list[CoordTup] | CoordArray | None = None,
1378	) -> "SolvedMaze":
1379		"""solves the given targeted lattice maze and returns a SolvedMaze"""
1380		if solution is None:
1381			solution = targeted_lattice_maze.find_shortest_path(
1382				targeted_lattice_maze.start_pos,
1383				targeted_lattice_maze.end_pos,
1384			)
1385		return cls(
1386			connection_list=targeted_lattice_maze.connection_list,
1387			solution=np.array(solution),
1388			generation_meta=targeted_lattice_maze.generation_meta,
1389		)

solves the given targeted lattice maze and returns a SolvedMaze

def get_solution_forking_points( self, always_include_endpoints: bool = False) -> tuple[list[int], jaxtyping.Int8[ndarray, 'coord row_col=2']]:
1391	def get_solution_forking_points(
1392		self,
1393		always_include_endpoints: bool = False,
1394	) -> tuple[list[int], CoordArray]:
1395		"""coordinates and their indices from the solution where a fork is present
1396
1397		- if the start point is not a dead end, this counts as a fork
1398		- if the end point is not a dead end, this counts as a fork
1399		"""
1400		output_idxs: list[int] = list()
1401		output_coords: list[CoordTup] = list()
1402
1403		for idx, coord in enumerate(self.solution):
1404			# more than one choice for first coord, or more than 2 for any other
1405			# since the previous coord doesn't count as a choice
1406			is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1
1407			theshold: int = 1 if is_endpoint else 2
1408			if self.get_coord_neighbors(coord).shape[0] > theshold or (
1409				is_endpoint and always_include_endpoints
1410			):
1411				output_idxs.append(idx)
1412				output_coords.append(coord)
1413
1414		return output_idxs, np.array(output_coords)

coordinates and their indices from the solution where a fork is present

  • if the start point is not a dead end, this counts as a fork
  • if the end point is not a dead end, this counts as a fork
def get_solution_path_following_points(self) -> tuple[list[int], jaxtyping.Int8[ndarray, 'coord row_col=2']]:
1416	def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]:
1417		"""coordinates from the solution where there is only a single (non-backtracking) point to move to
1418
1419		returns the complement of `get_solution_forking_points` from the path
1420		"""
1421		forks_idxs, _ = self.get_solution_forking_points()
1422		# HACK: idk why type ignore here
1423		return (  # type: ignore[return-value]
1424			np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0),
1425			np.delete(self.solution, forks_idxs, axis=0),
1426		)

coordinates from the solution where there is only a single (non-backtracking) point to move to

returns the complement of get_solution_forking_points from the path

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@serializable_dataclass(frozen=True, kw_only=True)
class TargetedLatticeMaze(maze_dataset.maze.LatticeMaze):
1176@serializable_dataclass(frozen=True, kw_only=True)
1177class TargetedLatticeMaze(LatticeMaze):  # type: ignore[misc]
1178	"""A LatticeMaze with a start and end position"""
1179
1180	# this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos
1181	# type ignore here because even though its a kw-only dataclass,
1182	# mypy doesn't like that non-default arguments are after default arguments
1183	start_pos: Coord = serializable_field(  # type: ignore[misc]
1184		assert_type=False,
1185	)
1186	end_pos: Coord = serializable_field(  # type: ignore[misc]
1187		assert_type=False,
1188	)
1189
1190	def __post_init__(self) -> None:
1191		"post init converts start and end pos to numpy arrays, checks they exist and are in bounds"
1192		# make things numpy arrays (very jank to override frozen dataclass)
1193		self.__dict__["start_pos"] = np.array(self.start_pos)
1194		self.__dict__["end_pos"] = np.array(self.end_pos)
1195		assert self.start_pos is not None
1196		assert self.end_pos is not None
1197		# check that start and end are in bounds
1198		if (
1199			self.start_pos[0] >= self.grid_shape[0]
1200			or self.start_pos[1] >= self.grid_shape[1]
1201		):
1202			err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}"
1203			raise ValueError(
1204				err_msg,
1205			)
1206		if (
1207			self.end_pos[0] >= self.grid_shape[0]
1208			or self.end_pos[1] >= self.grid_shape[1]
1209		):
1210			err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }"
1211			raise ValueError(
1212				err_msg,
1213			)
1214
1215	def __eq__(self, other: object) -> bool:
1216		"check equality, calls parent class equality check"
1217		return super().__eq__(other)
1218
1219	def _get_start_pos_tokens(self) -> list[str | CoordTup]:
1220		return [
1221			SPECIAL_TOKENS.ORIGIN_START,
1222			tuple(self.start_pos),
1223			SPECIAL_TOKENS.ORIGIN_END,
1224		]
1225
1226	def get_start_pos_tokens(self) -> list[str | CoordTup]:
1227		"(deprecated!) return the start position as a list of tokens"
1228		warnings.warn(
1229			"`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
1230			TokenizerDeprecationWarning,
1231		)
1232		return self._get_start_pos_tokens()
1233
1234	def _get_end_pos_tokens(self) -> list[str | CoordTup]:
1235		return [
1236			SPECIAL_TOKENS.TARGET_START,
1237			tuple(self.end_pos),
1238			SPECIAL_TOKENS.TARGET_END,
1239		]
1240
1241	def get_end_pos_tokens(self) -> list[str | CoordTup]:
1242		"(deprecated!) return the end position as a list of tokens"
1243		warnings.warn(
1244			"`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
1245			TokenizerDeprecationWarning,
1246		)
1247		return self._get_end_pos_tokens()
1248
1249	@classmethod
1250	def from_lattice_maze(
1251		cls,
1252		lattice_maze: LatticeMaze,
1253		start_pos: Coord | CoordTup,
1254		end_pos: Coord | CoordTup,
1255	) -> "TargetedLatticeMaze":
1256		"get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
1257		return cls(
1258			connection_list=lattice_maze.connection_list,
1259			start_pos=np.array(start_pos),
1260			end_pos=np.array(end_pos),
1261			generation_meta=lattice_maze.generation_meta,
1262		)

A LatticeMaze with a start and end position

TargetedLatticeMaze( *, connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], generation_meta: dict | None = None, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'], end_pos: jaxtyping.Int8[ndarray, 'row_col=2'])
start_pos: jaxtyping.Int8[ndarray, 'row_col=2']
end_pos: jaxtyping.Int8[ndarray, 'row_col=2']
def get_start_pos_tokens(self) -> list[str | tuple[int, int]]:
1226	def get_start_pos_tokens(self) -> list[str | CoordTup]:
1227		"(deprecated!) return the start position as a list of tokens"
1228		warnings.warn(
1229			"`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.",
1230			TokenizerDeprecationWarning,
1231		)
1232		return self._get_start_pos_tokens()

(deprecated!) return the start position as a list of tokens

def get_end_pos_tokens(self) -> list[str | tuple[int, int]]:
1241	def get_end_pos_tokens(self) -> list[str | CoordTup]:
1242		"(deprecated!) return the end position as a list of tokens"
1243		warnings.warn(
1244			"`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.",
1245			TokenizerDeprecationWarning,
1246		)
1247		return self._get_end_pos_tokens()

(deprecated!) return the end position as a list of tokens

@classmethod
def from_lattice_maze( cls, lattice_maze: LatticeMaze, start_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int], end_pos: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> TargetedLatticeMaze:
1249	@classmethod
1250	def from_lattice_maze(
1251		cls,
1252		lattice_maze: LatticeMaze,
1253		start_pos: Coord | CoordTup,
1254		end_pos: Coord | CoordTup,
1255	) -> "TargetedLatticeMaze":
1256		"get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions"
1257		return cls(
1258			connection_list=lattice_maze.connection_list,
1259			start_pos=np.array(start_pos),
1260			end_pos=np.array(end_pos),
1261			generation_meta=lattice_maze.generation_meta,
1262		)

get a TargetedLatticeMaze from a LatticeMaze by specifying start and end positions

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@serializable_dataclass(frozen=True, kw_only=True, properties_to_serialize=['lattice_dim', 'generation_meta'])
class LatticeMaze(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 120@serializable_dataclass(
 121	frozen=True,
 122	kw_only=True,
 123	properties_to_serialize=["lattice_dim", "generation_meta"],
 124)
 125class LatticeMaze(SerializableDataclass):
 126	"""lattice maze (nodes on a lattice, connections only to neighboring nodes)
 127
 128	Connection List represents which nodes (N) are connected in each direction.
 129
 130	First and second elements represent downward and rightward connections,
 131	respectively.
 132
 133	Example:
 134		Connection list:
 135			[
 136				[ # down
 137					[F T],
 138					[F F]
 139				],
 140				[ # right
 141					[T F],
 142					[T F]
 143				]
 144			]
 145
 146		Nodes with connections
 147			N T N F
 148			F	T
 149			N T N F
 150			F	F
 151
 152		Graph:
 153			N - N
 154				|
 155			N - N
 156
 157	Note: the bottom row connections going down, and the
 158	right-hand connections going right, will always be False.
 159
 160	"""
 161
 162	connection_list: ConnectionList
 163	generation_meta: dict | None = serializable_field(default=None, compare=False)
 164
 165	lattice_dim = property(lambda self: self.connection_list.shape[0])
 166	grid_shape = property(lambda self: self.connection_list.shape[1:])
 167	n_connections = property(lambda self: self.connection_list.sum())
 168
 169	@property
 170	def grid_n(self) -> int:
 171		"grid size as int, raises `AssertionError` if not square"
 172		assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
 173		return self.grid_shape[0]
 174
 175	# ============================================================
 176	# basic methods
 177	# ============================================================
 178
 179	def __eq__(self, other: object) -> bool:
 180		"equality check calls super"
 181		return super().__eq__(other)
 182
 183	@staticmethod
 184	def heuristic(a: CoordTup, b: CoordTup) -> float:
 185		"""return manhattan distance between two points"""
 186		return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])
 187
 188	def __hash__(self) -> int:
 189		"""hash the connection list by converting connection list to bytes"""
 190		return hash(self.connection_list.tobytes())
 191
 192	def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
 193		"""returns whether two nodes are connected"""
 194		delta: Coord = b - a
 195		if np.abs(delta).sum() != 1:
 196			# return false if not even adjacent
 197			return False
 198		else:
 199			# test for wall
 200			dim: int = int(np.argmax(np.abs(delta)))
 201			clist_node: Coord = a if (delta.sum() > 0) else b
 202			return self.connection_list[dim, clist_node[0], clist_node[1]]
 203
 204	def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
 205		"""check if a path is valid"""
 206		# check path is not empty
 207		if len(path) == 0:
 208			return empty_is_valid
 209
 210		# check all coords in bounds of maze
 211		if not np.all((path >= 0) & (path < self.grid_shape)):
 212			return False
 213
 214		# check all nodes connected
 215		for i in range(len(path) - 1):
 216			if not self.nodes_connected(path[i], path[i + 1]):
 217				return False
 218		return True
 219
 220	def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
 221		"""Returns an array with the connectivity degree of each coord.
 222
 223		I.e., how many neighbors each coord has.
 224		"""
 225		int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
 226			self.connection_list.astype(np.int8)
 227		)
 228		degrees: Int8[np.ndarray, "row col"] = np.sum(
 229			int_conn,
 230			axis=0,
 231		)  # Connections to east and south
 232		degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
 233		degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
 234		return degrees
 235
 236	def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
 237		"""Returns an array of the neighboring, connected coords of `c`."""
 238		c = np.array(c)  # type: ignore[assignment]
 239		neighbors: list[Coord] = [
 240			neighbor
 241			for neighbor in (c + NEIGHBORS_MASK)
 242			if (
 243				(0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
 244				and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
 245				and self.nodes_connected(c, neighbor)  # connected
 246			)
 247		]
 248
 249		output: CoordArray = np.array(neighbors)
 250		if len(neighbors) > 0:
 251			assert output.shape == (
 252				len(neighbors),
 253				2,
 254			), (
 255				f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
 256			)
 257		return output
 258
 259	def gen_connected_component_from(self, c: Coord) -> CoordArray:
 260		"""return the connected component from a given coordinate"""
 261		# Stack for DFS
 262		stack: list[Coord] = [c]
 263
 264		# Set to store visited nodes
 265		visited: set[CoordTup] = set()
 266
 267		while stack:
 268			current_node: Coord = stack.pop()
 269			# this is fine since we know current_node is a coord and thus of length 2
 270			visited.add(tuple(current_node))  # type: ignore[arg-type]
 271
 272			# Get the neighbors of the current node
 273			neighbors = self.get_coord_neighbors(current_node)
 274
 275			# Iterate over neighbors
 276			for neighbor in neighbors:
 277				if tuple(neighbor) not in visited:
 278					stack.append(neighbor)
 279
 280		return np.array(list(visited))
 281
 282	def find_shortest_path(
 283		self,
 284		c_start: CoordTup | Coord,
 285		c_end: CoordTup | Coord,
 286	) -> CoordArray:
 287		"""find the shortest path between two coordinates, using A*"""
 288		c_start = tuple(c_start)  # type: ignore[assignment]
 289		c_end = tuple(c_end)  # type: ignore[assignment]
 290
 291		g_score: dict[CoordTup, float] = (
 292			dict()
 293		)  # cost of cheapest path to node from start currently known
 294		f_score: dict[CoordTup, float] = {
 295			c_start: 0.0,
 296		}  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
 297
 298		# init
 299		g_score[c_start] = 0.0
 300		g_score[c_start] = self.heuristic(c_start, c_end)
 301
 302		closed_vtx: set[CoordTup] = set()  # nodes already evaluated
 303		# nodes to be evaluated
 304		# we need a set of the tuples, dont place the ints in the set
 305		open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
 306		source: dict[CoordTup, CoordTup] = (
 307			dict()
 308		)  # node immediately preceding each node in the path (currently known shortest path)
 309
 310		while open_vtx:
 311			# get lowest f_score node
 312			# mypy cant tell that c is of length 2
 313			c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
 314			# f_current: float = f_score[c_current]
 315
 316			# check if goal is reached
 317			if c_end == c_current:
 318				path: list[CoordTup] = [c_current]
 319				p_current: CoordTup = c_current
 320				while p_current in source:
 321					p_current = source[p_current]
 322					path.append(p_current)
 323				# ----------------------------------------------------------------------
 324				# this is the only return statement
 325				return np.array(path[::-1])
 326				# ----------------------------------------------------------------------
 327
 328			# close current node
 329			closed_vtx.add(c_current)
 330			open_vtx.remove(c_current)
 331
 332			# update g_score of neighbors
 333			_np_neighbor: Coord
 334			for _np_neighbor in self.get_coord_neighbors(c_current):
 335				neighbor: CoordTup = tuple(_np_neighbor)
 336
 337				if neighbor in closed_vtx:
 338					# already checked
 339					continue
 340				g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors
 341
 342				if neighbor not in open_vtx:
 343					# found new vtx, so add
 344					open_vtx.add(neighbor)
 345
 346				elif g_temp >= g_score[neighbor]:
 347					# if already knew about this one, but current g_score is worse, skip
 348					continue
 349
 350				# store g_score and source
 351				source[neighbor] = c_current
 352				g_score[neighbor] = g_temp
 353				f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)
 354
 355		raise ValueError(
 356			"A solution could not be found!",
 357			f"{c_start = }, {c_end = }",
 358			self.as_ascii(),
 359		)
 360
 361	def get_nodes(self) -> CoordArray:
 362		"""return a list of all nodes in the maze"""
 363		rows: Int[np.ndarray, "x y"]
 364		cols: Int[np.ndarray, "x y"]
 365		rows, cols = np.meshgrid(
 366			range(self.grid_shape[0]),
 367			range(self.grid_shape[1]),
 368			indexing="ij",
 369		)
 370		nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
 371		return nodes
 372
 373	def get_connected_component(self) -> CoordArray:
 374		"""get the largest (and assumed only nonsingular) connected component of the maze
 375
 376		TODO: other connected components?
 377		"""
 378		if (self.generation_meta is None) or (
 379			self.generation_meta.get("fully_connected", False)
 380		):
 381			# for fully connected case, pick any two positions
 382			return self.get_nodes()
 383		else:
 384			# if metadata provided, use visited cells
 385			visited_cells: set[CoordTup] | None = self.generation_meta.get(
 386				"visited_cells",
 387				None,
 388			)
 389			if visited_cells is None:
 390				# TODO: dynamically generate visited_cells?
 391				err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
 392				raise ValueError(
 393					err_msg,
 394				)
 395			visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
 396			return visited_cells_np
 397
 398	@typing.overload
 399	def generate_random_path(
 400		self,
 401		allowed_start: CoordList | None = None,
 402		allowed_end: CoordList | None = None,
 403		deadend_start: bool = False,
 404		deadend_end: bool = False,
 405		endpoints_not_equal: bool = False,
 406		except_on_no_valid_endpoint: typing.Literal[True] = True,
 407	) -> CoordArray: ...
 408	@typing.overload
 409	def generate_random_path(
 410		self,
 411		allowed_start: CoordList | None = None,
 412		allowed_end: CoordList | None = None,
 413		deadend_start: bool = False,
 414		deadend_end: bool = False,
 415		endpoints_not_equal: bool = False,
 416		except_on_no_valid_endpoint: typing.Literal[False] = False,
 417	) -> typing.Optional[CoordArray]: ...
 418	def generate_random_path(  # noqa: C901
 419		self,
 420		allowed_start: CoordList | None = None,
 421		allowed_end: CoordList | None = None,
 422		deadend_start: bool = False,
 423		deadend_end: bool = False,
 424		endpoints_not_equal: bool = False,
 425		except_on_no_valid_endpoint: bool = True,
 426	) -> typing.Optional[CoordArray]:
 427		"""return a path between randomly chosen start and end nodes within the connected component
 428
 429		Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
 430
 431		# Parameters:
 432		- `allowed_start : CoordList | None`
 433			a list of allowed start positions. If `None`, any position in the connected component is allowed
 434			(defaults to `None`)
 435		- `allowed_end : CoordList | None`
 436			a list of allowed end positions. If `None`, any position in the connected component is allowed
 437			(defaults to `None`)
 438		- `deadend_start : bool`
 439			whether to ***force*** the start position to be a deadend
 440			(defaults to `False`)
 441		- `deadend_end : bool`
 442			whether to ***force*** the end position to be a deadend
 443			(defaults to `False`)
 444		- `endpoints_not_equal : bool`
 445			whether to ensure that the start and end point are not the same
 446			(defaults to `False`)
 447		- `except_on_no_valid_endpoint : bool`
 448			whether to raise an error if no valid start or end positions are found
 449			if this is `False`, the function might return `None` and this must be handled by the caller
 450			(defaults to `True`)
 451
 452		# Returns:
 453		- `CoordArray`
 454			a path between the selected start and end positions
 455
 456		# Raises:
 457		- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
 458		"""
 459		# we can't create a "path" in a single-node maze
 460		assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
 461			f"can't create path in single-node maze: {self.as_ascii()}"
 462		)
 463
 464		# get connected component
 465		connected_component: CoordArray = self.get_connected_component()
 466
 467		# initialize start and end positions
 468		positions: Int[np.int8, "2 2"]
 469
 470		# if no special conditions on start and end positions
 471		if (allowed_start, allowed_end, deadend_start, deadend_end) == (
 472			None,
 473			None,
 474			False,
 475			False,
 476		):
 477			try:
 478				positions = connected_component[  # type: ignore[assignment]
 479					np.random.choice(
 480						len(connected_component),
 481						size=2,
 482						replace=False,
 483					)
 484				]
 485			except ValueError as e:
 486				if except_on_no_valid_endpoint:
 487					err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
 488					raise NoValidEndpointException(
 489						err_msg,
 490					) from e
 491				return None
 492
 493			return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]
 494
 495		# handle special conditions
 496		connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
 497		# copy connected component set
 498		allowed_start_set: set[CoordTup] = connected_component_set.copy()
 499		allowed_end_set: set[CoordTup] = connected_component_set.copy()
 500
 501		# filter by explicitly allowed start and end positions
 502		# '# type: ignore[assignment]' here because the returned tuple can be of any length
 503		if allowed_start is not None:
 504			allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]
 505
 506		if allowed_end is not None:
 507			allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]
 508
 509		# filter by forcing deadends
 510		if deadend_start:
 511			allowed_start_set = set(
 512				filter(
 513					lambda x: len(self.get_coord_neighbors(x)) == 1,
 514					allowed_start_set,
 515				),
 516			)
 517
 518		if deadend_end:
 519			allowed_end_set = set(
 520				filter(
 521					lambda x: len(self.get_coord_neighbors(x)) == 1,
 522					allowed_end_set,
 523				),
 524			)
 525
 526		# check we have valid positions
 527		if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
 528			if except_on_no_valid_endpoint:
 529				err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
 530				raise NoValidEndpointException(
 531					err_msg,
 532				)
 533			return None
 534
 535		# randomly select start and end positions
 536		try:
 537			# ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
 538			start_pos: CoordTup = tuple(  # type: ignore[assignment]
 539				list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
 540			)
 541			if endpoints_not_equal:
 542				# remove start position from end positions
 543				allowed_end_set.discard(start_pos)
 544			end_pos: CoordTup = tuple(  # type: ignore[assignment]
 545				list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
 546			)
 547		except ValueError as e:
 548			if except_on_no_valid_endpoint:
 549				err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
 550				raise NoValidEndpointException(
 551					err_msg,
 552				) from e
 553			return None
 554
 555		return self.find_shortest_path(start_pos, end_pos)
 556
 557	# ============================================================
 558	# to and from adjacency list
 559	# ============================================================
 560	def as_adj_list(
 561		self,
 562		shuffle_d0: bool = True,
 563		shuffle_d1: bool = True,
 564	) -> Int8[np.ndarray, "conn start_end coord"]:
 565		"""return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
 566		return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)
 567
 568	@classmethod
 569	def from_adj_list(
 570		cls,
 571		adj_list: Int8[np.ndarray, "conn start_end coord"],
 572	) -> "LatticeMaze":
 573		"""create a LatticeMaze from a list of connections
 574
 575		> [!NOTE]
 576		> This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
 577		"""
 578		# this is where it would probably break for rectangular mazes
 579		grid_n: int = adj_list.max() + 1
 580
 581		connection_list: ConnectionList = np.zeros(
 582			(2, grid_n, grid_n),
 583			dtype=np.bool_,
 584		)
 585
 586		for c_start, c_end in adj_list:
 587			# check that exactly 1 coordinate matches
 588			if (c_start == c_end).sum() != 1:
 589				raise ValueError("invalid connection")
 590
 591			# get the direction
 592			d: int = (c_start != c_end).argmax()
 593
 594			x: int
 595			y: int
 596			# pick whichever has the lesser value in the direction `d`
 597			if c_start[d] < c_end[d]:
 598				x, y = c_start
 599			else:
 600				x, y = c_end
 601
 602			connection_list[d, x, y] = True
 603
 604		return LatticeMaze(
 605			connection_list=connection_list,
 606		)
 607
 608	def as_adj_list_tokens(self) -> list[str | CoordTup]:
 609		"""(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
 610		warnings.warn(
 611			"`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
 612			TokenizerDeprecationWarning,
 613		)
 614		return [
 615			SPECIAL_TOKENS.ADJLIST_START,
 616			*chain.from_iterable(  # type: ignore[list-item]
 617				[
 618					[
 619						tuple(c_s),
 620						SPECIAL_TOKENS.CONNECTOR,
 621						tuple(c_e),
 622						SPECIAL_TOKENS.ADJACENCY_ENDLINE,
 623					]
 624					for c_s, c_e in self.as_adj_list()
 625				],
 626			),
 627			SPECIAL_TOKENS.ADJLIST_END,
 628		]
 629
 630	def _as_adj_list_tokens(self) -> list[str | CoordTup]:
 631		return [
 632			SPECIAL_TOKENS.ADJLIST_START,
 633			*chain.from_iterable(  # type: ignore[list-item]
 634				[
 635					[
 636						tuple(c_s),
 637						SPECIAL_TOKENS.CONNECTOR,
 638						tuple(c_e),
 639						SPECIAL_TOKENS.ADJACENCY_ENDLINE,
 640					]
 641					for c_s, c_e in self.as_adj_list()
 642				],
 643			),
 644			SPECIAL_TOKENS.ADJLIST_END,
 645		]
 646
 647	def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]:
 648		"""turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples"""
 649		output: list[CoordTup | str] = self._as_adj_list_tokens()
 650		# if getattr(self, "start_pos", None) is not None:
 651		if isinstance(self, TargetedLatticeMaze):
 652			output += self._get_start_pos_tokens()
 653		if isinstance(self, TargetedLatticeMaze):
 654			output += self._get_end_pos_tokens()
 655		if isinstance(self, SolvedMaze):
 656			output += self._get_solution_tokens()
 657		return output
 658
 659	def _as_tokens(
 660		self,
 661		maze_tokenizer: "MazeTokenizer | TokenizationMode",
 662	) -> list[str]:
 663		# type ignores here fine since we check the instance
 664		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
 665			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
 666		if (
 667			isinstance_by_type_name(maze_tokenizer, "MazeTokenizer")
 668			and maze_tokenizer.is_AOTP()  # type: ignore[union-attr]
 669		):
 670			coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP()
 671			coords_processed: list[str] = maze_tokenizer.coords_to_strings(  # type: ignore[union-attr]
 672				coords=coords_raw,
 673				when_noncoord="include",
 674			)
 675			return coords_processed
 676		else:
 677			err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}"
 678			raise NotImplementedError(err_msg)
 679
 680	def as_tokens(
 681		self,
 682		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
 683	) -> list[str]:
 684		"""serialize maze and solution to tokens"""
 685		if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
 686			return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
 687		else:
 688			return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]
 689
 690	@classmethod
 691	def _from_tokens_AOTP(
 692		cls,
 693		tokens: list[str],
 694		maze_tokenizer: "MazeTokenizer | MazeTokenizerModular",
 695	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
 696		"""create a LatticeMaze from a list of tokens"""
 697		# figure out what input format
 698		# ========================================
 699		if tokens[0] == SPECIAL_TOKENS.ADJLIST_START:
 700			adj_list_tokens = get_adj_list_tokens(tokens)
 701		else:
 702			# If we're not getting a "complete" tokenized maze, assume it's just the adjacency list tokens
 703			adj_list_tokens = tokens
 704			warnings.warn(
 705				"Assuming input is just adjacency list tokens, no special tokens found",
 706			)
 707
 708		# process edges for adjacency list
 709		# ========================================
 710		edges: list[list[str]] = list_split(
 711			adj_list_tokens,
 712			SPECIAL_TOKENS.ADJACENCY_ENDLINE,
 713		)
 714
 715		coordinates: list[tuple[CoordTup, CoordTup]] = list()
 716		for e in edges:
 717			# skip last endline
 718			if len(e) != 0:
 719				# convert to coords, split start and end
 720				e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords(
 721					e,
 722					when_noncoord="include",
 723				)
 724				# this assertion depends on the tokenizer having exactly one token for the connector
 725				# which is also why we "include" above
 726				# the connector token is discarded below
 727				assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }"  # noqa: PLR2004
 728				assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, (
 729					f"invalid edge: {e = } {e_coords = }"
 730				)
 731				e_coords_first: CoordTup = e_coords[0]  # type: ignore[assignment]
 732				e_coords_last: CoordTup = e_coords[-1]  # type: ignore[assignment]
 733				coordinates.append((e_coords_first, e_coords_last))
 734
 735		assert all(len(c) == DIM_2 for c in coordinates), (
 736			f"invalid coordinates: {coordinates = }"
 737		)
 738		adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates)
 739		assert tuple(adj_list.shape) == (
 740			len(coordinates),
 741			2,
 742			2,
 743		), f"invalid adj_list: {adj_list.shape = } {coordinates = }"
 744
 745		output_maze: LatticeMaze = cls.from_adj_list(adj_list)
 746
 747		# add start and end positions
 748		# ========================================
 749		is_targeted: bool = False
 750		if all(
 751			x in tokens
 752			for x in (
 753				SPECIAL_TOKENS.ORIGIN_START,
 754				SPECIAL_TOKENS.ORIGIN_END,
 755				SPECIAL_TOKENS.TARGET_START,
 756				SPECIAL_TOKENS.TARGET_END,
 757			)
 758		):
 759			start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
 760				get_origin_tokens(tokens),
 761				when_noncoord="error",
 762			)
 763			end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords(
 764				get_target_tokens(tokens),
 765				when_noncoord="error",
 766			)
 767			assert len(start_pos_list) == 1, (
 768				f"invalid start_pos_list: {start_pos_list = }"
 769			)
 770			assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }"
 771
 772			start_pos: CoordTup = start_pos_list[0]
 773			end_pos: CoordTup = end_pos_list[0]
 774
 775			output_maze = TargetedLatticeMaze.from_lattice_maze(
 776				lattice_maze=output_maze,
 777				start_pos=start_pos,
 778				end_pos=end_pos,
 779			)
 780
 781			is_targeted = True
 782
 783		if all(
 784			x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END)
 785		):
 786			assert is_targeted, "maze must be targeted to have a solution"
 787			solution: list[CoordTup] = maze_tokenizer.strings_to_coords(
 788				get_path_tokens(tokens, trim_end=True),
 789				when_noncoord="error",
 790			)
 791			output_maze = SolvedMaze.from_targeted_lattice_maze(
 792				# HACK: I think this is fine, but im not sure
 793				targeted_lattice_maze=output_maze,  # type: ignore[arg-type]
 794				solution=solution,
 795			)
 796
 797		return output_maze
 798
 799	# TODO: any way to get return type hinting working for this?
	@classmethod
	def from_tokens(
		cls,
		tokens: list[str],
		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
		"""Build a maze from its tokenized form.

		`tokens` may be a list of tokens or a single string, which is split on whitespace.
		A `TokenizationMode` is first converted with `to_legacy_tokenizer()`.

		Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.

		# Raises:
		- `NotImplementedError` : if `maze_tokenizer` is a `MazeTokenizerModular` which is not
			legacy-equivalent, or if the tokenizer is not AOTP
		"""
		# HACK: the type ignores below are fine, since the instance type is checked by name first
		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]

		is_modular = isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
		if is_modular and not maze_tokenizer.is_legacy_equivalent():  # type: ignore[union-attr]
			err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
			raise NotImplementedError(err_msg)

		# accept a raw string as a convenience
		if isinstance(tokens, str):
			tokens = tokens.split()

		# guard clause: AOTP is the only tokenization we can parse
		if not maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
			raise NotImplementedError("only AOTP tokenization is supported")

		return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
 829
 830	# ============================================================
 831	# to and from pixels
 832	# ============================================================
	def _as_pixels_bw(self) -> BinaryPixelGrid:
		"""Rasterize the maze into a boolean grid of shape `(2 * rows + 1, 2 * cols + 1)`.

		`True` marks an open pixel, `False` a wall. Node `(i, j)` sits at pixel
		`(2i + 1, 2j + 1)`; the downward connection from it sits at `(2i + 2, 2j + 1)`
		and the rightward connection at `(2i + 1, 2j + 2)`.

		This is the exact inverse of the slicing done in `_from_pixel_grid_bw`.

		# Raises:
		- `AssertionError` : if the maze is not 2-dimensional
		"""
		assert self.lattice_dim == DIM_2, "only 2D mazes are supported"

		# start from an all-wall grid
		pixel_grid: BinaryPixelGrid = np.zeros(
			(self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1),
			dtype=np.bool_,
		)

		# nodes are always open
		pixel_grid[1::2, 1::2] = True

		# strided slice assignment replaces the per-cell Python loops:
		# each slice below has exactly shape `(rows, cols)`
		pixel_grid[2::2, 1::2] = self.connection_list[0]  # downward connections
		pixel_grid[1::2, 2::2] = self.connection_list[1]  # rightward connections

		return pixel_grid
 858
	def as_pixels(
		self,
		show_endpoints: bool = True,
		show_solution: bool = True,
	) -> PixelGrid:
		"""render the maze as an RGB pixel grid

		- a lighter-weight alternative to plotting with `MazePlot`
		- shares its underlying layout with `as_ascii`, but produces an image
		- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data

		What gets drawn depends on the exact class of `self`:
		- `LatticeMaze`: walls and open cells only
		- `TargetedLatticeMaze`: additionally the endpoints, if `show_endpoints`
		- anything else (treated as a `SolvedMaze`): path and endpoints, but only if `show_solution`.
			Note that with `show_solution=False` the endpoints are not drawn either.

		# Raises:
		- `ValueError` : if `show_solution` is set without `show_endpoints`
		"""
		# HACK: lots of `# type: ignore[attr-defined]` below, since this method is defined for any
		# `LatticeMaze` while `solution`, `start_pos`, `end_pos` only exist on subclasses.
		# this is fine because the class is checked explicitly before they are accessed
		if show_solution and not show_endpoints:
			raise ValueError("show_solution=True requires show_endpoints=True")

		# expand the boolean wall/open grid into RGB
		open_mask: BinaryPixelGrid = self._as_pixels_bw()
		pixel_grid: PixelGrid = np.full(
			(*open_mask.shape, 3),
			PixelColors.WALL,
			dtype=np.uint8,
		)
		pixel_grid[open_mask == True] = PixelColors.OPEN  # noqa: E712

		maze_cls = self.__class__
		if maze_cls == LatticeMaze:
			return pixel_grid

		def _paint_endpoints() -> None:
			"color the start and end nodes"
			pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
				PixelColors.START
			)
			pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
				PixelColors.END
			)

		# a targeted (but unsolved) maze only has endpoints to draw
		if maze_cls == TargetedLatticeMaze:
			if show_endpoints:
				_paint_endpoints()
			return pixel_grid

		# from here on we assume `self.__class__ == SolvedMaze`
		if not show_solution:
			return pixel_grid

		# path nodes
		for coord in self.solution:  # type: ignore[attr-defined]
			pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH

		# the connection pixel between each consecutive pair of path nodes
		for coord, next_coord in zip(self.solution[:-1], self.solution[1:]):  # type: ignore[attr-defined]
			# consecutive path nodes must be lattice neighbors
			assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
				f"Coords {coord} and {next_coord} are not adjacent"
			)
			pixel_grid[
				coord[0] * 2 + 1 + next_coord[0] - coord[0],
				coord[1] * 2 + 1 + next_coord[1] - coord[1],
			] = PixelColors.PATH

		# endpoints go last, since the path would otherwise overwrite them
		_paint_endpoints()

		return pixel_grid
 925
	@classmethod
	def _from_pixel_grid_bw(
		cls,
		pixel_grid: BinaryPixelGrid,
	) -> tuple[ConnectionList, tuple[int, int]]:
		"""Recover the connection list and grid shape from a 2D wall/open pixel grid.

		Reads the pixels where `_as_pixels_bw` writes connections: downward connections
		at even rows / odd columns, rightward connections at odd rows / even columns.
		"""
		n_rows: int = pixel_grid.shape[0] // 2
		n_cols: int = pixel_grid.shape[1] // 2
		grid_shape: tuple[int, int] = (n_rows, n_cols)

		connection_list: ConnectionList = np.zeros((2, n_rows, n_cols), dtype=np.bool_)
		connection_list[0] = pixel_grid[2::2, 1::2]  # downward
		connection_list[1] = pixel_grid[1::2, 2::2]  # rightward

		return connection_list, grid_shape
 944
	@classmethod
	def _from_pixel_grid_with_positions(
		cls,
		pixel_grid: PixelGrid | BinaryPixelGrid,
		marked_positions: dict[str, RGB],
	) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]:
		"""Extract the connection list, grid shape, and specially-colored nodes from a pixel grid.

		# Parameters:
		- `pixel_grid`
			grid whose last axis is compared against colors
		- `marked_positions : dict[str, RGB]`
			maps an output key to the color to search for

		# Returns:
		- the connection list and grid shape as given by `_from_pixel_grid_bw`
		- a dict mapping each key of `marked_positions` to an array of the maze coordinates
			whose node pixel has that color (an empty array if there are none)
		"""
		# every pixel that is not wall-colored counts as open
		is_wall = np.all(pixel_grid == PixelColors.WALL, axis=-1)
		# mypy: `~` on the result of `np.all` is typed as possibly a scalar bool
		pixel_grid_bw: BinaryPixelGrid = ~is_wall  # type: ignore[assignment]

		connection_list: ConnectionList
		grid_shape: tuple[int, int]
		connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw)

		out_positions: dict[str, CoordArray] = {}
		for key, color in marked_positions.items():
			matching_pixels: Int[np.ndarray, "x y"] = np.argwhere(
				np.all(pixel_grid == color, axis=-1),
			)
			# only pixels with both indices odd are nodes (the rest are connections or walls);
			# map those back to maze coordinates
			out_positions[key] = np.array(
				[
					(pos[0] // 2, pos[1] // 2)
					for pos in matching_pixels
					if pos[0] % 2 == 1 and pos[1] % 2 == 1
				],
			)

		return connection_list, grid_shape, out_positions
 978
	@classmethod
	def from_pixels(
		cls,
		pixel_grid: PixelGrid,
	) -> "LatticeMaze":
		"""create a maze from a pixel grid. reverse of `as_pixels`

		A 2D (binary) grid always yields a plain `LatticeMaze`. For an RGB grid the maze type is
		detected with `detect_pixels_type`, and `cls` must be in the detected type's MRO.
		For a solved maze, the path pixels are unordered, so the solution is rebuilt by walking
		from the start to the end through connected path nodes.

		# Raises:
		- `ValueError` : if the pixel grid cannot be cast to `cls`
		- `AssertionError` : if there is not exactly one start and one end, or if the path
			cannot be ordered unambiguously
		"""
		connection_list: ConnectionList
		grid_shape: tuple[int, int]

		# binary grids carry no endpoint or path information
		if len(pixel_grid.shape) == 2:  # noqa: PLR2004
			connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
			return LatticeMaze(connection_list=connection_list)

		# RGB grid: check that the requested class is compatible with what the grid contains
		cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
		if cls not in cls_detected.__mro__:
			err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
			raise ValueError(err_msg)

		connection_list, grid_shape, marked_pos = cls._from_pixel_grid_with_positions(
			pixel_grid=pixel_grid,
			marked_positions={
				"start": PixelColors.START,
				"end": PixelColors.END,
				"solution": PixelColors.PATH,
			},
		)

		# plain maze requested: done
		if cls == LatticeMaze:
			return LatticeMaze(connection_list=connection_list)

		# used below to look up connected neighbors while ordering the path
		temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

		# exactly one start and one end are required
		start_pos_arr = marked_pos["start"]
		end_pos_arr = marked_pos["end"]
		assert start_pos_arr.shape == (1, 2), (
			f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		assert end_pos_arr.shape == (1, 2), (
			f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		start_pos: Coord = start_pos_arr[0]
		end_pos: Coord = end_pos_arr[0]

		# targeted maze requested: done
		if cls == TargetedLatticeMaze:
			return TargetedLatticeMaze(
				connection_list=connection_list,
				start_pos=start_pos,
				end_pos=end_pos,
			)

		# path-colored nodes only -- the start and end have their own colors
		solution_raw: CoordArray = marked_pos["solution"]
		if solution_raw.ndim == 2:  # noqa: PLR2004
			assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
				f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
			)
		elif solution_raw.shape == (0,):
			# no intermediate nodes, so start and end must be immediately adjacent
			assert np.sum(np.abs(start_pos - end_pos)) == 1, (
				f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
			)

		start_tup: CoordTup = tuple(start_pos)
		end_tup: CoordTup = tuple(end_pos)

		# candidates for the next step: all path nodes, plus the end so that the walk can terminate
		solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw]
		solution_raw_list.append(end_tup)

		# walk from the start, extending by the single unvisited connected path node each step
		solution: list[CoordTup] = [start_tup]
		while solution[-1] != end_tup:
			neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
			# TODO: make this less ugly
			assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
				f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
			)
			# keep only neighbors which are on the path and not yet visited
			neighbors_filtered: CoordArray = np.array(
				[
					nbr
					for nbr in neighbors
					if (tuple(nbr) in solution_raw_list and tuple(nbr) not in solution)
				],
			)
			# the next step must be unambiguous
			assert neighbors_filtered.shape == (1, 2), (
				f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
			)
			solution.append(tuple(neighbors_filtered[0]))

		# sanity check on the endpoints of the reconstructed path
		assert solution[0] == start_tup, (
			f"solution {solution} does not start at start_pos {start_pos}"
		)
		assert solution[-1] == end_tup, (
			f"solution {solution} does not end at end_pos {end_pos}"
		)

		return cls(
			connection_list=np.array(connection_list),
			solution=np.array(solution),  # type: ignore[call-arg]
		)
1109
1110	# ============================================================
1111	# to and from ASCII
1112	# ============================================================
1113	def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]:
1114		# Get the pixel grid using to_pixels().
1115		pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw()
1116
1117		# Replace pixel values with ASCII characters.
1118		ascii_grid: Shaped[np.ndarray, "x y"] = np.full(
1119			pixel_grid.shape,
1120			AsciiChars.WALL,
1121			dtype=str,
1122		)
1123		ascii_grid[pixel_grid == True] = AsciiChars.OPEN  # noqa: E712
1124
1125		return ascii_grid
1126
1127	def as_ascii(
1128		self,
1129		show_endpoints: bool = True,
1130		show_solution: bool = True,
1131	) -> str:
1132		"""return an ASCII grid of the maze
1133
1134		useful for debugging in the terminal, or as it's own format
1135
1136		can be reversed with `LatticeMaze.from_ascii()`
1137		"""
1138		ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
1139		pixel_grid: PixelGrid = self.as_pixels(
1140			show_endpoints=show_endpoints,
1141			show_solution=show_solution,
1142		)
1143
1144		chars_replace: tuple = tuple()
1145		if show_endpoints:
1146			chars_replace += (AsciiChars.START, AsciiChars.END)
1147		if show_solution:
1148			chars_replace += (AsciiChars.PATH,)
1149
1150		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1151			if ascii_char in chars_replace:
1152				ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char
1153
1154		return "\n".join("".join(row) for row in ascii_grid)
1155
1156	@classmethod
1157	def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
1158		"get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)"
1159		lines: list[str] = ascii_str.strip().split("\n")
1160		lines = [line.strip() for line in lines]
1161		ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
1162			[list(line) for line in lines],
1163			dtype=str,
1164		)
1165		pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
1166
1167		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1168			pixel_grid[ascii_grid == ascii_char] = pixel_color
1169
1170		return cls.from_pixels(pixel_grid)

lattice maze (nodes on a lattice, connections only to neighboring nodes)

Connection List represents which nodes (N) are connected in each direction.

First and second elements represent downward and rightward connections, respectively.

Example: Connection list: [ [ # down [F T], [F F] ], [ # right [T F], [T F] ] ]

    Nodes with connections
            N T N F
            F       T
            N T N F
            F       F

    Graph:
            N - N
                    |
            N - N

Note: the bottom row connections going down, and the right-hand connections going right, will always be False.

LatticeMaze( *, connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], generation_meta: dict | None = None)
connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col']
generation_meta: dict | None = None
lattice_dim
165	lattice_dim = property(lambda self: self.connection_list.shape[0])
grid_shape
166	grid_shape = property(lambda self: self.connection_list.shape[1:])
n_connections
167	n_connections = property(lambda self: self.connection_list.sum())
grid_n: int
169	@property
170	def grid_n(self) -> int:
171		"grid size as int, raises `AssertionError` if not square"
172		assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported"
173		return self.grid_shape[0]

grid size as int, raises AssertionError if not square

@staticmethod
def heuristic(a: tuple[int, int], b: tuple[int, int]) -> float:
183	@staticmethod
184	def heuristic(a: CoordTup, b: CoordTup) -> float:
185		"""return manhattan distance between two points"""
186		return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1])

return manhattan distance between two points

def nodes_connected( self, a: jaxtyping.Int8[ndarray, 'row_col=2'], b: jaxtyping.Int8[ndarray, 'row_col=2'], /) -> bool:
192	def nodes_connected(self, a: Coord, b: Coord, /) -> bool:
193		"""returns whether two nodes are connected"""
194		delta: Coord = b - a
195		if np.abs(delta).sum() != 1:
196			# return false if not even adjacent
197			return False
198		else:
199			# test for wall
200			dim: int = int(np.argmax(np.abs(delta)))
201			clist_node: Coord = a if (delta.sum() > 0) else b
202			return self.connection_list[dim, clist_node[0], clist_node[1]]

returns whether two nodes are connected

def is_valid_path( self, path: jaxtyping.Int8[ndarray, 'coord row_col=2'], empty_is_valid: bool = False) -> bool:
204	def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool:
205		"""check if a path is valid"""
206		# check path is not empty
207		if len(path) == 0:
208			return empty_is_valid
209
210		# check all coords in bounds of maze
211		if not np.all((path >= 0) & (path < self.grid_shape)):
212			return False
213
214		# check all nodes connected
215		for i in range(len(path) - 1):
216			if not self.nodes_connected(path[i], path[i + 1]):
217				return False
218		return True

check if a path is valid

def coord_degrees(self) -> jaxtyping.Int8[ndarray, 'row col']:
220	def coord_degrees(self) -> Int8[np.ndarray, "row col"]:
221		"""Returns an array with the connectivity degree of each coord.
222
223		I.e., how many neighbors each coord has.
224		"""
225		int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = (
226			self.connection_list.astype(np.int8)
227		)
228		degrees: Int8[np.ndarray, "row col"] = np.sum(
229			int_conn,
230			axis=0,
231		)  # Connections to east and south
232		degrees[:, 1:] += int_conn[1, :, :-1]  # Connections to west
233		degrees[1:, :] += int_conn[0, :-1, :]  # Connections to north
234		return degrees

Returns an array with the connectivity degree of each coord.

I.e., how many neighbors each coord has.

def get_coord_neighbors( self, c: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
236	def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray:
237		"""Returns an array of the neighboring, connected coords of `c`."""
238		c = np.array(c)  # type: ignore[assignment]
239		neighbors: list[Coord] = [
240			neighbor
241			for neighbor in (c + NEIGHBORS_MASK)
242			if (
243				(0 <= neighbor[0] < self.grid_shape[0])  # in x bounds
244				and (0 <= neighbor[1] < self.grid_shape[1])  # in y bounds
245				and self.nodes_connected(c, neighbor)  # connected
246			)
247		]
248
249		output: CoordArray = np.array(neighbors)
250		if len(neighbors) > 0:
251			assert output.shape == (
252				len(neighbors),
253				2,
254			), (
255				f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}"
256			)
257		return output

Returns an array of the neighboring, connected coords of c.

def gen_connected_component_from( self, c: jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
259	def gen_connected_component_from(self, c: Coord) -> CoordArray:
260		"""return the connected component from a given coordinate"""
261		# Stack for DFS
262		stack: list[Coord] = [c]
263
264		# Set to store visited nodes
265		visited: set[CoordTup] = set()
266
267		while stack:
268			current_node: Coord = stack.pop()
269			# this is fine since we know current_node is a coord and thus of length 2
270			visited.add(tuple(current_node))  # type: ignore[arg-type]
271
272			# Get the neighbors of the current node
273			neighbors = self.get_coord_neighbors(current_node)
274
275			# Iterate over neighbors
276			for neighbor in neighbors:
277				if tuple(neighbor) not in visited:
278					stack.append(neighbor)
279
280		return np.array(list(visited))

return the connected component from a given coordinate

def find_shortest_path( self, c_start: tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2'], c_end: tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2']) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
282	def find_shortest_path(
283		self,
284		c_start: CoordTup | Coord,
285		c_end: CoordTup | Coord,
286	) -> CoordArray:
287		"""find the shortest path between two coordinates, using A*"""
288		c_start = tuple(c_start)  # type: ignore[assignment]
289		c_end = tuple(c_end)  # type: ignore[assignment]
290
291		g_score: dict[CoordTup, float] = (
292			dict()
293		)  # cost of cheapest path to node from start currently known
294		f_score: dict[CoordTup, float] = {
295			c_start: 0.0,
296		}  # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end)
297
298		# init
299		g_score[c_start] = 0.0
300		g_score[c_start] = self.heuristic(c_start, c_end)
301
302		closed_vtx: set[CoordTup] = set()  # nodes already evaluated
303		# nodes to be evaluated
304		# we need a set of the tuples, dont place the ints in the set
305		open_vtx: set[CoordTup] = set([c_start])  # noqa: C405
306		source: dict[CoordTup, CoordTup] = (
307			dict()
308		)  # node immediately preceding each node in the path (currently known shortest path)
309
310		while open_vtx:
311			# get lowest f_score node
312			# mypy cant tell that c is of length 2
313			c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)])  # type: ignore[index]
314			# f_current: float = f_score[c_current]
315
316			# check if goal is reached
317			if c_end == c_current:
318				path: list[CoordTup] = [c_current]
319				p_current: CoordTup = c_current
320				while p_current in source:
321					p_current = source[p_current]
322					path.append(p_current)
323				# ----------------------------------------------------------------------
324				# this is the only return statement
325				return np.array(path[::-1])
326				# ----------------------------------------------------------------------
327
328			# close current node
329			closed_vtx.add(c_current)
330			open_vtx.remove(c_current)
331
332			# update g_score of neighbors
333			_np_neighbor: Coord
334			for _np_neighbor in self.get_coord_neighbors(c_current):
335				neighbor: CoordTup = tuple(_np_neighbor)
336
337				if neighbor in closed_vtx:
338					# already checked
339					continue
340				g_temp: float = g_score[c_current] + 1  # always 1 for maze neighbors
341
342				if neighbor not in open_vtx:
343					# found new vtx, so add
344					open_vtx.add(neighbor)
345
346				elif g_temp >= g_score[neighbor]:
347					# if already knew about this one, but current g_score is worse, skip
348					continue
349
350				# store g_score and source
351				source[neighbor] = c_current
352				g_score[neighbor] = g_temp
353				f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end)
354
355		raise ValueError(
356			"A solution could not be found!",
357			f"{c_start = }, {c_end = }",
358			self.as_ascii(),
359		)

find the shortest path between two coordinates, using A*

def get_nodes(self) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
361	def get_nodes(self) -> CoordArray:
362		"""return a list of all nodes in the maze"""
363		rows: Int[np.ndarray, "x y"]
364		cols: Int[np.ndarray, "x y"]
365		rows, cols = np.meshgrid(
366			range(self.grid_shape[0]),
367			range(self.grid_shape[1]),
368			indexing="ij",
369		)
370		nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T
371		return nodes

return a list of all nodes in the maze

def get_connected_component(self) -> jaxtyping.Int8[ndarray, 'coord row_col=2']:
373	def get_connected_component(self) -> CoordArray:
374		"""get the largest (and assumed only nonsingular) connected component of the maze
375
376		TODO: other connected components?
377		"""
378		if (self.generation_meta is None) or (
379			self.generation_meta.get("fully_connected", False)
380		):
381			# for fully connected case, pick any two positions
382			return self.get_nodes()
383		else:
384			# if metadata provided, use visited cells
385			visited_cells: set[CoordTup] | None = self.generation_meta.get(
386				"visited_cells",
387				None,
388			)
389			if visited_cells is None:
390				# TODO: dynamically generate visited_cells?
391				err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}"
392				raise ValueError(
393					err_msg,
394				)
395			visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells))
396			return visited_cells_np

get the largest (and assumed only nonsingular) connected component of the maze

TODO: other connected components?

def generate_random_path( self, allowed_start: list[tuple[int, int]] | None = None, allowed_end: list[tuple[int, int]] | None = None, deadend_start: bool = False, deadend_end: bool = False, endpoints_not_equal: bool = False, except_on_no_valid_endpoint: bool = True) -> Optional[jaxtyping.Int8[ndarray, 'coord row_col=2']]:
418	def generate_random_path(  # noqa: C901
419		self,
420		allowed_start: CoordList | None = None,
421		allowed_end: CoordList | None = None,
422		deadend_start: bool = False,
423		deadend_end: bool = False,
424		endpoints_not_equal: bool = False,
425		except_on_no_valid_endpoint: bool = True,
426	) -> typing.Optional[CoordArray]:
427		"""return a path between randomly chosen start and end nodes within the connected component
428
429		Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.
430
431		# Parameters:
432		- `allowed_start : CoordList | None`
433			a list of allowed start positions. If `None`, any position in the connected component is allowed
434			(defaults to `None`)
435		- `allowed_end : CoordList | None`
436			a list of allowed end positions. If `None`, any position in the connected component is allowed
437			(defaults to `None`)
438		- `deadend_start : bool`
439			whether to ***force*** the start position to be a deadend (defaults to `False`)
440			(defaults to `False`)
441		- `deadend_end : bool`
442			whether to ***force*** the end position to be a deadend (defaults to `False`)
443			(defaults to `False`)
444		- `endpoints_not_equal : bool`
445			whether to ensure tha the start and end point are not the same
446			(defaults to `False`)
447		- `except_on_no_valid_endpoint : bool`
448			whether to raise an error if no valid start or end positions are found
449			if this is `False`, the function might return `None` and this must be handled by the caller
450			(defaults to `True`)
451
452		# Returns:
453		- `CoordArray`
454			a path between the selected start and end positions
455
456		# Raises:
457		- `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True`
458		"""
459		# we can't create a "path" in a single-node maze
460		assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, (  # noqa: PT018
461			f"can't create path in single-node maze: {self.as_ascii()}"
462		)
463
464		# get connected component
465		connected_component: CoordArray = self.get_connected_component()
466
467		# initialize start and end positions
468		positions: Int[np.int8, "2 2"]
469
470		# if no special conditions on start and end positions
471		if (allowed_start, allowed_end, deadend_start, deadend_end) == (
472			None,
473			None,
474			False,
475			False,
476		):
477			try:
478				positions = connected_component[  # type: ignore[assignment]
479					np.random.choice(
480						len(connected_component),
481						size=2,
482						replace=False,
483					)
484				]
485			except ValueError as e:
486				if except_on_no_valid_endpoint:
487					err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }"
488					raise NoValidEndpointException(
489						err_msg,
490					) from e
491				return None
492
493			return self.find_shortest_path(positions[0], positions[1])  # type: ignore[index]
494
495		# handle special conditions
496		connected_component_set: set[CoordTup] = set(map(tuple, connected_component))
497		# copy connected component set
498		allowed_start_set: set[CoordTup] = connected_component_set.copy()
499		allowed_end_set: set[CoordTup] = connected_component_set.copy()
500
501		# filter by explicitly allowed start and end positions
502		# '# type: ignore[assignment]' here because the returned tuple can be of any length
503		if allowed_start is not None:
504			allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set  # type: ignore[assignment]
505
506		if allowed_end is not None:
507			allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set  # type: ignore[assignment]
508
509		# filter by forcing deadends
510		if deadend_start:
511			allowed_start_set = set(
512				filter(
513					lambda x: len(self.get_coord_neighbors(x)) == 1,
514					allowed_start_set,
515				),
516			)
517
518		if deadend_end:
519			allowed_end_set = set(
520				filter(
521					lambda x: len(self.get_coord_neighbors(x)) == 1,
522					allowed_end_set,
523				),
524			)
525
526		# check we have valid positions
527		if len(allowed_start_set) == 0 or len(allowed_end_set) == 0:
528			if except_on_no_valid_endpoint:
529				err_msg = f"No valid start (or end?) positions found: {allowed_start_set = }, {allowed_end_set = }"
530				raise NoValidEndpointException(
531					err_msg,
532				)
533			return None
534
535		# randomly select start and end positions
536		try:
537			# ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok
538			start_pos: CoordTup = tuple(  # type: ignore[assignment]
539				list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))],
540			)
541			if endpoints_not_equal:
542				# remove start position from end positions
543				allowed_end_set.discard(start_pos)
544			end_pos: CoordTup = tuple(  # type: ignore[assignment]
545				list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))],
546			)
547		except ValueError as e:
548			if except_on_no_valid_endpoint:
549				err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }"
550				raise NoValidEndpointException(
551					err_msg,
552				) from e
553			return None
554
555		return self.find_shortest_path(start_pos, end_pos)

return a path between randomly chosen start and end nodes within the connected component

Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end.

Parameters:

  • allowed_start : CoordList | None a list of allowed start positions. If None, any position in the connected component is allowed (defaults to None)
  • allowed_end : CoordList | None a list of allowed end positions. If None, any position in the connected component is allowed (defaults to None)
  • deadend_start : bool whether to force the start position to be a deadend (defaults to False)
  • deadend_end : bool whether to force the end position to be a deadend (defaults to False)
  • endpoints_not_equal : bool whether to ensure that the start and end point are not the same (defaults to False)
  • except_on_no_valid_endpoint : bool whether to raise an error if no valid start or end positions are found if this is False, the function might return None and this must be handled by the caller (defaults to True)

Returns:

  • CoordArray a path between the selected start and end positions

Raises:

  • NoValidEndpointException : if no valid start or end positions are found, and except_on_no_valid_endpoint is True
def as_adj_list( self, shuffle_d0: bool = True, shuffle_d1: bool = True) -> jaxtyping.Int8[ndarray, 'conn start_end coord']:
560	def as_adj_list(
561		self,
562		shuffle_d0: bool = True,
563		shuffle_d1: bool = True,
564	) -> Int8[np.ndarray, "conn start_end coord"]:
565		"""return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`"""
566		return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1)

return the maze as an adjacency list, wraps maze_dataset.token_utils.connection_list_to_adj_list

@classmethod
def from_adj_list( cls, adj_list: jaxtyping.Int8[ndarray, 'conn start_end coord']) -> LatticeMaze:
568	@classmethod
569	def from_adj_list(
570		cls,
571		adj_list: Int8[np.ndarray, "conn start_end coord"],
572	) -> "LatticeMaze":
573		"""create a LatticeMaze from a list of connections
574
575		> [!NOTE]
576		> This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.
577		"""
578		# this is where it would probably break for rectangular mazes
579		grid_n: int = adj_list.max() + 1
580
581		connection_list: ConnectionList = np.zeros(
582			(2, grid_n, grid_n),
583			dtype=np.bool_,
584		)
585
586		for c_start, c_end in adj_list:
587			# check that exactly 1 coordinate matches
588			if (c_start == c_end).sum() != 1:
589				raise ValueError("invalid connection")
590
591			# get the direction
592			d: int = (c_start != c_end).argmax()
593
594			x: int
595			y: int
596			# pick whichever has the lesser value in the direction `d`
597			if c_start[d] < c_end[d]:
598				x, y = c_start
599			else:
600				x, y = c_end
601
602			connection_list[d, x, y] = True
603
604		return LatticeMaze(
605			connection_list=connection_list,
606		)

create a LatticeMaze from a list of connections

Note

This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed.

def as_adj_list_tokens(self) -> list[str | tuple[int, int]]:
608	def as_adj_list_tokens(self) -> list[str | CoordTup]:
609		"""(deprecated!) turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead"""
610		warnings.warn(
611			"`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.",
612			TokenizerDeprecationWarning,
613		)
614		return [
615			SPECIAL_TOKENS.ADJLIST_START,
616			*chain.from_iterable(  # type: ignore[list-item]
617				[
618					[
619						tuple(c_s),
620						SPECIAL_TOKENS.CONNECTOR,
621						tuple(c_e),
622						SPECIAL_TOKENS.ADJACENCY_ENDLINE,
623					]
624					for c_s, c_e in self.as_adj_list()
625				],
626			),
627			SPECIAL_TOKENS.ADJLIST_END,
628		]

(deprecated!) turn the maze into adjacency list tokens, use MazeTokenizerModular instead

680	def as_tokens(
681		self,
682		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
683	) -> list[str]:
684		"""serialize maze and solution to tokens"""
685		if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"):
686			return maze_tokenizer.to_tokens(self)  # type: ignore[union-attr]
687		else:
688			return self._as_tokens(maze_tokenizer)  # type: ignore[union-attr,arg-type]

serialize maze and solution to tokens

800	@classmethod
801	def from_tokens(
802		cls,
803		tokens: list[str],
804		maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular",
805	) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze":
806		"""Constructs a maze from a tokenization.
807
808		Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported.
809		"""
810		# HACK: type ignores here fine since we check the instance
811		if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"):
812			maze_tokenizer = maze_tokenizer.to_legacy_tokenizer()  # type: ignore[union-attr]
813		if (
814			isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular")
815			and not maze_tokenizer.is_legacy_equivalent()  # type: ignore[union-attr]
816		):
817			err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}."
818			raise NotImplementedError(
819				err_msg,
820			)
821
822		if isinstance(tokens, str):
823			tokens = tokens.split()
824
825		if maze_tokenizer.is_AOTP():  # type: ignore[union-attr]
826			return cls._from_tokens_AOTP(tokens, maze_tokenizer)  # type: ignore[arg-type]
827		else:
828			raise NotImplementedError("only AOTP tokenization is supported")

Constructs a maze from a tokenization.

Only legacy tokenizers and their MazeTokenizerModular analogs are supported.

def as_pixels( self, show_endpoints: bool = True, show_solution: bool = True) -> jaxtyping.Int[ndarray, 'x y rgb']:
859	def as_pixels(
860		self,
861		show_endpoints: bool = True,
862		show_solution: bool = True,
863	) -> PixelGrid:
864		"""convert the maze to a pixel grid
865
866		- useful as a simpler way of plotting the maze than the more complex `MazePlot`
867		- the same underlying representation as `as_ascii` but as an image
868		- used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data
869		"""
870		# HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze`
871		# but solution, start_pos, end_pos not always defined
872		# but its fine since we explicitly check the type
873		if show_solution and not show_endpoints:
874			raise ValueError("show_solution=True requires show_endpoints=True")
875		# convert original bool pixel grid to RGB
876		pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw()
877		pixel_grid: PixelGrid = np.full(
878			(*pixel_grid_bw.shape, 3),
879			PixelColors.WALL,
880			dtype=np.uint8,
881		)
882		pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN  # noqa: E712
883
884		if self.__class__ == LatticeMaze:
885			return pixel_grid
886
887		# set endpoints for TargetedLatticeMaze
888		if self.__class__ == TargetedLatticeMaze:
889			if show_endpoints:
890				pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
891					PixelColors.START
892				)
893				pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
894					PixelColors.END
895				)
896			return pixel_grid
897
898		# set solution -- we only reach this part if `self.__class__ == SolvedMaze`
899		if show_solution:
900			for coord in self.solution:  # type: ignore[attr-defined]
901				pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH
902
903			# set pixels between coords
904			for index, coord in enumerate(self.solution[:-1]):  # type: ignore[attr-defined]
905				next_coord = self.solution[index + 1]  # type: ignore[attr-defined]
906				# check they are adjacent using norm
907				assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, (
908					f"Coords {coord} and {next_coord} are not adjacent"
909				)
910				# set pixel between them
911				pixel_grid[
912					coord[0] * 2 + 1 + next_coord[0] - coord[0],
913					coord[1] * 2 + 1 + next_coord[1] - coord[1],
914				] = PixelColors.PATH
915
916			# set endpoints (again, since path would overwrite them)
917			pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
918				PixelColors.START
919			)
920			pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = (  # type: ignore[attr-defined]
921				PixelColors.END
922			)
923
924		return pixel_grid

convert the maze to a pixel grid

@classmethod
def from_pixels( cls, pixel_grid: jaxtyping.Int[ndarray, 'x y rgb']) -> LatticeMaze:
	@classmethod
	def from_pixels(
		cls,
		pixel_grid: PixelGrid,
	) -> "LatticeMaze":
		"""build a maze from a pixel grid, inverting `as_pixels`

		a grid with only two axes (binary) always yields a plain `LatticeMaze`.
		otherwise the maze type is detected with `detect_pixels_type`, and depending on `cls` the result is
		a `LatticeMaze`, a `TargetedLatticeMaze`, or an instance of `cls` built with the ordered solution

		# Parameters:
		- `pixel_grid : PixelGrid`
			the pixel grid to parse

		# Returns:
		- `LatticeMaze` : the reconstructed maze

		# Raises:
		- `ValueError` : if the pixel grid cannot be cast to `cls` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze`
		- `AssertionError` : if the marked start, end and path pixels do not describe exactly one start, one end, and an unambiguous path between them
		"""
		connection_list: ConnectionList
		grid_shape: tuple[int, int]

		# no color axis: nothing but walls and open cells can be encoded
		if len(pixel_grid.shape) == 2:  # noqa: PLR2004
			connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid)
			return LatticeMaze(connection_list=connection_list)

		# the requested class must be the detected class or one of its ancestors
		cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid)
		if cls not in cls_detected.__mro__:
			err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }"
			raise ValueError(err_msg)

		marked_colors: dict = dict(
			start=PixelColors.START,
			end=PixelColors.END,
			solution=PixelColors.PATH,
		)
		connection_list, grid_shape, marked_pos = cls._from_pixel_grid_with_positions(
			pixel_grid=pixel_grid,
			marked_positions=marked_colors,
		)

		# a plain maze needs nothing beyond the connections
		if cls == LatticeMaze:
			return LatticeMaze(connection_list=connection_list)

		# used further down to look up connected neighbors while ordering the solution
		temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list)

		# exactly one start and one end pixel are expected
		start_pos_arr = marked_pos["start"]
		end_pos_arr = marked_pos["end"]
		assert start_pos_arr.shape == (1, 2), (
			f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		assert end_pos_arr.shape == (1, 2), (
			f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate"
		)
		start_pos: Coord = start_pos_arr[0]
		end_pos: Coord = end_pos_arr[0]

		# endpoints are all a targeted maze needs
		if cls == TargetedLatticeMaze:
			return TargetedLatticeMaze(
				connection_list=connection_list,
				start_pos=start_pos,
				end_pos=end_pos,
			)

		# unordered path cells, not including the start or the end
		solution_raw: CoordArray = marked_pos["solution"]
		if len(solution_raw.shape) == 2:  # noqa: PLR2004
			assert solution_raw.shape[1] == 2, (  # noqa: PLR2004
				f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)"
			)
		elif solution_raw.shape == (0,):
			# without any path cells, the start and end must be direct neighbors
			assert np.sum(np.abs(start_pos - end_pos)) == 1, (
				f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given"
			)

		# walk from the start to the end, picking at every step the one connected neighbor
		# that is a path cell (or the end) and has not been visited yet
		end_tup: CoordTup = tuple(end_pos)
		solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw]
		solution_raw_list.append(end_tup)
		solution: list[CoordTup] = [tuple(start_pos)]
		while solution[-1] != end_tup:
			# connected neighbors of the most recently added cell
			neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1])
			# TODO: make this less ugly
			assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), (  # noqa: PT018, PLR2004
				f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}"
			)
			# keep only unvisited neighbors that belong to the marked path
			neighbors_filtered: CoordArray = np.array(
				[
					nbr
					for nbr in neighbors
					if tuple(nbr) in solution_raw_list and tuple(nbr) not in solution
				],
			)
			# the next step has to be unique
			assert neighbors_filtered.shape == (1, 2), (
				f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}"
			)
			solution.append(tuple(neighbors_filtered[0]))

		# sanity check both ends of the ordered path
		assert solution[0] == tuple(start_pos), (
			f"solution {solution} does not start at start_pos {start_pos}"
		)
		assert solution[-1] == tuple(end_pos), (
			f"solution {solution} does not end at end_pos {end_pos}"
		)

		return cls(
			connection_list=np.array(connection_list),
			solution=np.array(solution),  # type: ignore[call-arg]
		)

create a LatticeMaze from a pixel grid. reverse of as_pixels

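Raises:

  • ValueError : if the pixel grid cannot be cast to a LatticeMaze -- it's probably a TargetedLatticeMaze or SolvedMaze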

def as_ascii(self, show_endpoints: bool = True, show_solution: bool = True) -> str:
1127	def as_ascii(
1128		self,
1129		show_endpoints: bool = True,
1130		show_solution: bool = True,
1131	) -> str:
1132		"""return an ASCII grid of the maze
1133
1134		useful for debugging in the terminal, or as it's own format
1135
1136		can be reversed with `LatticeMaze.from_ascii()`
1137		"""
1138		ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid()
1139		pixel_grid: PixelGrid = self.as_pixels(
1140			show_endpoints=show_endpoints,
1141			show_solution=show_solution,
1142		)
1143
1144		chars_replace: tuple = tuple()
1145		if show_endpoints:
1146			chars_replace += (AsciiChars.START, AsciiChars.END)
1147		if show_solution:
1148			chars_replace += (AsciiChars.PATH,)
1149
1150		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1151			if ascii_char in chars_replace:
1152				ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char
1153
1154		return "\n".join("".join(row) for row in ascii_grid)

return an ASCII grid of the maze

useful for debugging in the terminal, or as its own format

can be reversed with LatticeMaze.from_ascii()

@classmethod
def from_ascii(cls, ascii_str: str) -> LatticeMaze:
1156	@classmethod
1157	def from_ascii(cls, ascii_str: str) -> "LatticeMaze":
1158		"get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)"
1159		lines: list[str] = ascii_str.strip().split("\n")
1160		lines = [line.strip() for line in lines]
1161		ascii_grid: Shaped[np.ndarray, "x y"] = np.array(
1162			[list(line) for line in lines],
1163			dtype=str,
1164		)
1165		pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8)
1166
1167		for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items():
1168			pixel_grid[ascii_grid == ascii_char] = pixel_color
1169
1170		return cls.from_pixels(pixel_grid)

get a LatticeMaze from an ASCII representation (reverses LatticeMaze.as_ascii)

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check whether every field of a `SerializableDataclass` has a valid type

    the per-field results come from `SerializableDataclass__validate_fields_types__dict`,
    which receives `on_typecheck_error` unchanged

    # Returns:
    - `bool` : `True` only if every per-field result is truthy
    """
    per_field_valid = SerializableDataclass__validate_fields_types__dict(
        self,
        on_typecheck_error=on_typecheck_error,
    )
    return all(per_field_valid.values())

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
ConnectionList = <class 'jaxtyping.Bool[ndarray, 'lattice_dim=2 row col']'>
@dataclass(frozen=True)
class AsciiChars:
 99@dataclass(frozen=True)
100class AsciiChars:
101	"standard ascii characters for mazes"
102
103	WALL: str = "#"
104	OPEN: str = " "
105	START: str = "S"
106	END: str = "E"
107	PATH: str = "X"

standard ascii characters for mazes

AsciiChars( WALL: str = '#', OPEN: str = ' ', START: str = 'S', END: str = 'E', PATH: str = 'X')
WALL: str = '#'
OPEN: str = ' '
START: str = 'S'
END: str = 'E'
PATH: str = 'X'
Coord = <class 'jaxtyping.Int8[ndarray, 'row_col=2']'>
CoordArray = <class 'jaxtyping.Int8[ndarray, 'coord row_col=2']'>
@dataclass(frozen=True)
class PixelColors:
@dataclass(frozen=True)
class PixelColors:
	"""standard colors for pixel grids

	one RGB triple per kind of cell; instances are frozen
	"""

	# impassable cell: black
	WALL: RGB = (0, 0, 0)
	# passable cell: white
	OPEN: RGB = (255, 255, 255)
	# start of the path: green
	START: RGB = (0, 255, 0)
	# end of the path: red
	END: RGB = (255, 0, 0)
	# cell on the solution path: blue
	PATH: RGB = (0, 0, 255)

standard colors for pixel grids

PixelColors( WALL: tuple[int, int, int] = (0, 0, 0), OPEN: tuple[int, int, int] = (255, 255, 255), START: tuple[int, int, int] = (0, 255, 0), END: tuple[int, int, int] = (255, 0, 0), PATH: tuple[int, int, int] = (0, 0, 255))
WALL: tuple[int, int, int] = (0, 0, 0)
OPEN: tuple[int, int, int] = (255, 255, 255)
START: tuple[int, int, int] = (0, 255, 0)
END: tuple[int, int, int] = (255, 0, 0)
PATH: tuple[int, int, int] = (0, 0, 255)