docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.tokenization.modular.elements

Implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`.


   1"""implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`"""
   2
   3import abc
   4import random
   5from typing import (
   6	Callable,
   7	Literal,
   8	Sequence,
   9	TypedDict,
  10)
  11
  12import numpy as np
  13from jaxtyping import Bool, Int
  14from muutils.json_serialize import (
  15	serializable_dataclass,
  16	serializable_field,
  17)
  18from muutils.misc import empty_sequence_if_attr_false, flatten
  19
  20# from maze_dataset import SolvedMaze
  21from maze_dataset.constants import (
  22	VOCAB,
  23	ConnectionArray,
  24	ConnectionList,
  25	Coord,
  26	CoordTup,
  27)
  28from maze_dataset.generation import numpy_rng
  29from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze
  30from maze_dataset.token_utils import (
  31	connection_list_to_adj_list,
  32	get_cardinal_direction,
  33	get_relative_direction,
  34	is_connection,
  35	tokens_between,
  36)
  37from maze_dataset.tokenization.modular.element_base import (
  38	__TokenizerElementNamespace,
  39	_load_tokenizer_element,
  40	_TokenizerElement,
  41	_unsupported_is_invalid,
  42	mark_as_unsupported,
  43)
  44from maze_dataset.utils import lattice_connection_array
  45
  46
  47class CoordTokenizers(__TokenizerElementNamespace):
  48	"""Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
  49
  50	key = "coord_tokenizer"
  51
  52	@serializable_dataclass(frozen=True, kw_only=True)
  53	class _CoordTokenizer(_TokenizerElement, abc.ABC):
  54		"""Superclass for classes which tokenize singular coords in a maze."""
  55
  56		@abc.abstractmethod
  57		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
  58			pass
  59
  60		@classmethod
  61		def attribute_key(cls) -> str:
  62			return CoordTokenizers.key
  63
  64		def is_valid(self, do_except: bool = False) -> bool:
  65			# No invalid instances possible within data member type hint bounds
  66			return True
  67
  68	@serializable_dataclass(frozen=True, kw_only=True)
  69	class UT(_CoordTokenizer):
  70		"""Unique token coordinate tokenizer."""
  71
  72		# inherit docstring
  73		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
  74			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
  75
  76	@serializable_dataclass(frozen=True, kw_only=True)
  77	class CTT(_CoordTokenizer):
  78		"""Coordinate tuple tokenizer
  79
  80		# Parameters
  81		- `pre`: Whether all coords include an integral preceding delimiter token
  82		- `intra`: Whether all coords include a delimiter token between coordinates
  83		- `post`: Whether all coords include an integral following delimiter token
  84		"""
  85
  86		pre: bool = serializable_field(default=True)
  87		intra: bool = serializable_field(default=True)
  88		post: bool = serializable_field(default=True)
  89		# Implement methods
  90
  91		# inherit docstring
  92		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
  93			return [
  94				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
  95				str(coord[0]),
  96				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
  97				str(coord[1]),
  98				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
  99			]
 100
 101
 102class EdgeGroupings(__TokenizerElementNamespace):
 103	"""Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""
 104
 105	key = "edge_grouping"
 106
 107	class _GroupingTokenParams(TypedDict):
 108		"""A uniform private hyperparameter interface used by `AdjListTokenizer`."""
 109
 110		connection_token_ordinal: Literal[0, 1, 2]
 111		intra: bool
 112		grouped: bool
 113
 114	@serializable_dataclass(frozen=True, kw_only=True)
 115	class _EdgeGrouping(_TokenizerElement, abc.ABC):
 116		"""Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called a edge grouping."""
 117
 118		@classmethod
 119		def attribute_key(cls) -> str:
 120			return EdgeGroupings.key
 121
 122		def is_valid(self, do_except: bool = False) -> bool:
 123			return True
 124
 125		@abc.abstractmethod
 126		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
 127			"""Divides a ConnectionArray into groups of edges.
 128
 129			Shuffles/sequences within each group if applicable.
 130			"""
 131			pass
 132
 133		@abc.abstractmethod
 134		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
 135			"""Returns the tok.nization hyperparameters necessary for an `AdjListTokenizer` to tokenize.
 136
 137			These hyperparameters are not used by `_EdgeGrouping` internally.
 138			They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
 139			since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
 140			This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses
 141			into a uniform private interface used by `AdjListTokenizer`.
 142			"""
 143			pass
 144
 145	@serializable_dataclass(frozen=True, kw_only=True)
 146	class Ungrouped(_EdgeGrouping):
 147		"""No grouping occurs, each edge is tokenized individually.
 148
 149		# Parameters
 150		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
 151		Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
 152		"""
 153
 154		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
 155			default=1,
 156			assert_type=False,
 157		)
 158
 159		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
 160			return EdgeGroupings._GroupingTokenParams(
 161				connection_token_ordinal=self.connection_token_ordinal,
 162				intra=False,
 163				grouped=False,
 164			)
 165
 166		def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
 167			return np.expand_dims(edges, 1)
 168
 169	@serializable_dataclass(frozen=True, kw_only=True)
 170	@mark_as_unsupported(_unsupported_is_invalid)
 171	class ByLeadingCoord(_EdgeGrouping):
 172		"""All edges with the same leading coord are grouped together.
 173
 174		# Parameters
 175		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
 176		Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `)
 177		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
 178		If false, the fixed order is lexicographical by (row, col).
 179		In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
 180		- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
 181		Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
 182		"""
 183
 184		intra: bool = serializable_field(default=True)
 185		shuffle_group: bool = serializable_field(default=True)
 186		connection_token_ordinal: Literal[0, 1] = serializable_field(
 187			default=0,
 188			assert_type=False,
 189		)
 190
 191		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
 192			return EdgeGroupings._GroupingTokenParams(
 193				connection_token_ordinal=self.connection_token_ordinal,
 194				intra=self.intra,
 195				grouped=True,
 196			)
 197
 198		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
 199			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
 200			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
 201				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
 202			)
 203			sorted_edges: ConnectionArray = edges[index_array, ...]
 204			groups: list[ConnectionArray] = np.split(
 205				sorted_edges,
 206				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
 207			)
 208			if self.shuffle_group:
 209				[numpy_rng.shuffle(g, axis=0) for g in groups]
 210			return groups
 211
 212
 213class EdgePermuters(__TokenizerElementNamespace):
 214	"""Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""
 215
 216	key = "edge_permuter"
 217
 218	@serializable_dataclass(frozen=True, kw_only=True)
 219	class _EdgePermuter(_TokenizerElement, abc.ABC):
 220		"""Specifies how to sequence the two coords that encode a lattice edge."""
 221
 222		@classmethod
 223		def attribute_key(cls) -> str:
 224			return EdgePermuters.key
 225
 226		def is_valid(self, do_except: bool = False) -> bool:
 227			# No invalid instances possible within data member type hint bounds
 228			return True
 229
 230		@staticmethod
 231		@abc.abstractmethod
 232		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
 233			"""Executes a permutation.
 234
 235			Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.
 236
 237			# Parameters
 238			- `lattice_edges`: Array of lattice edges.
 239			The two coords in shape[1] must be adjacent in the lattice.
 240
 241			# Returns
 242			- Array of lattice edges with entries along shape[1] systematically permuted.
 243			- shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`.
 244			"""
 245			pass
 246
 247	@serializable_dataclass(frozen=True, kw_only=True)
 248	class SortedCoords(_EdgePermuter):
 249		"""returns a sorted representation. useful for checking consistency"""
 250
 251		@staticmethod
 252		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
 253			return lattice_edges[
 254				np.lexsort(
 255					(
 256						lattice_edges[:, 1, 1],
 257						lattice_edges[:, 1, 0],
 258						lattice_edges[:, 0, 1],
 259						lattice_edges[:, 0, 0],
 260					),
 261				),
 262				...,
 263			]
 264
 265	@serializable_dataclass(frozen=True, kw_only=True)
 266	class RandomCoords(_EdgePermuter):
 267		"""Permutes each edge randomly."""
 268
 269		@staticmethod
 270		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
 271			numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
 272			return lattice_edges
 273
 274	@serializable_dataclass(frozen=True, kw_only=True)
 275	class BothCoords(_EdgePermuter):
 276		"""Includes both possible permutations of every edge in the output.
 277
 278		Since input ConnectionList has only 1 instance of each edge,
 279		a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
 280		"""
 281
 282		@staticmethod
 283		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
 284			return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
 285
 286
 287class EdgeSubsets(__TokenizerElementNamespace):
 288	"""Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""
 289
 290	key = "edge_subset"
 291
 292	@serializable_dataclass(frozen=True, kw_only=True)
 293	class _EdgeSubset(_TokenizerElement, abc.ABC):
 294		"""Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""
 295
 296		@classmethod
 297		def attribute_key(cls) -> str:
 298			return EdgeSubsets.key
 299
 300		def is_valid(self, do_except: bool = False) -> bool:
 301			return True
 302
 303		@abc.abstractmethod
 304		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
 305			"""Returns the set of lattice edges to be tokenized."""
 306			pass
 307
 308	@serializable_dataclass(frozen=True, kw_only=True)
 309	class AllLatticeEdges(_EdgeSubset):
 310		"""All 2n**2-2n edges of the lattice are tokenized.
 311
 312		If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
 313		"""
 314
 315		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
 316			return lattice_connection_array(maze.grid_n)
 317
 318	@serializable_dataclass(frozen=True, kw_only=True)
 319	class ConnectionEdges(_EdgeSubset):
 320		"""Only edges which contain a connection are tokenized.
 321
 322		Alternatively, only edges which contain a wall are tokenized.
 323
 324		# Parameters
 325		- `walls`: Whether wall edges or connection edges are tokenized.
 326		If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
 327		"""
 328
 329		walls: bool = serializable_field(default=False)
 330
 331		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
 332			conn_list: ConnectionList = maze.connection_list
 333			if self.walls:
 334				conn_list = np.logical_not(conn_list)
 335				conn_list[0, -1, :] = False
 336				conn_list[1, :, -1] = False
 337			return connection_list_to_adj_list(
 338				conn_list,
 339				shuffle_d0=False,
 340				shuffle_d1=False,
 341			)
 342
 343
 344def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool:  # noqa: ANN001
 345	"""Returns False if `pre` is True, True otherwise."""
 346	output: bool = self_.pre is False
 347	if do_except and not output:
 348		raise ValueError(
 349			"AdjListCoord does not support `pre == False`.",
 350		)
 351
 352	return output
 353
 354
 355class AdjListTokenizers(__TokenizerElementNamespace):
 356	"""Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 357
 358	key = "adj_list_tokenizer"
 359
 360	@serializable_dataclass(frozen=True, kw_only=True)
 361	@mark_as_unsupported(_adjlist_no_pre_unsupported)
 362	class _AdjListTokenizer(_TokenizerElement, abc.ABC):
 363		"""Specifies how the adjacency list is tokenized.
 364
 365		Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
 366		See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.
 367
 368		# Parameters
 369		- `pre`: Whether all edge groupings include a preceding delimiter token
 370		- `post`: Whether all edge groupings include a following delimiter token
 371		- `shuffle_d0`: Specifies how to sequence the edge groupings.
 372			If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
 373		- `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
 374		- `edge_subset`: Specifies the subset of lattice edges to be tokenized.
 375		- `edge_permuter`: Specifies, in each edge tokenization, which coord either:
 376			1. Appears first in the tokenization, for `AdjListCoord`.
 377			2. Is tokenized directly as a coord, for `AdjListCardinal`.
 378			- `shuffle`: For each edge, the leading coord is selected randomly.
 379			- `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
 380			- `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
 381		"""
 382
 383		pre: bool = serializable_field(default=False, assert_type=False)
 384		post: bool = serializable_field(default=True)
 385		shuffle_d0: bool = serializable_field(default=True)
 386		edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
 387			default=EdgeGroupings.Ungrouped(),
 388			loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
 389		)
 390		edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
 391			default=EdgeSubsets.ConnectionEdges(),
 392			loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
 393		)
 394		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
 395			default=EdgePermuters.RandomCoords(),
 396			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
 397		)
 398
 399		@classmethod
 400		def attribute_key(cls) -> str:
 401			return AdjListTokenizers.key
 402
 403		def is_valid(self, do_except: bool = False) -> bool:
 404			# No invalid instances possible within data member type hint bounds
 405			return True
 406
 407		@abc.abstractmethod
 408		def _tokenization_callables(
 409			self,
 410			edges: ConnectionArray,
 411			is_conn: Bool[np.ndarray, " edges"],
 412			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 413			*args,
 414			**kwargs,
 415		) -> list[Callable]:
 416			"""Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.
 417
 418			# Returns
 419			- `[0]`: leading coord tokens
 420			- `[1]`: connector tokens
 421			- `[2]`: trailing coord tokens
 422			"""
 423			pass
 424
 425		def _tokenize_edge_grouping(
 426			self,
 427			edges: ConnectionArray,
 428			maze: LatticeMaze,
 429			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 430			group_params: EdgeGroupings._GroupingTokenParams,
 431		) -> Sequence[str]:
 432			"""Tokenizes a single edge grouping."""
 433			cxn_ord: int = group_params["connection_token_ordinal"]
 434			is_conn: Bool[np.ndarray, edges] = is_connection(
 435				edges,
 436				maze.connection_list,
 437			)
 438			tokenize_callables = self._tokenization_callables(
 439				edges,
 440				is_conn,
 441				coord_tokenizer,
 442			)
 443
 444			if group_params["grouped"]:
 445				# If grouped
 446				callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
 447				repeated_callables = [
 448					tokenize_callables[i] for i in callable_permutation
 449				]
 450				return flatten(
 451					[
 452						tokenize_callables[0](0),
 453						[
 454							[
 455								*[
 456									tok_callable(i)
 457									for tok_callable in repeated_callables
 458								],
 459								*(
 460									(VOCAB.ADJLIST_INTRA,)
 461									if group_params["intra"]
 462									else ()
 463								),
 464							]
 465							for i in range(edges.shape[0])
 466						],
 467					],
 468				)
 469			else:
 470				# If ungrouped
 471				callable_permutation = [0, 2]
 472				callable_permutation.insert(cxn_ord, 1)
 473				tokenize_callables = [
 474					tokenize_callables[i] for i in callable_permutation
 475				]
 476
 477				return flatten(
 478					[
 479						[
 480							[
 481								*[
 482									tok_callable(i)
 483									for tok_callable in tokenize_callables
 484								],
 485								*empty_sequence_if_attr_false(
 486									(VOCAB.ADJLIST_INTRA,),
 487									group_params,
 488									"intra",
 489								),
 490							]
 491							for i in range(edges.shape[0])
 492						],
 493					],
 494				)
 495
 496		def to_tokens(
 497			self,
 498			maze: LatticeMaze,
 499			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 500		) -> list[str]:
 501			# Get the set of edges to be tokenized
 502			edges: ConnectionArray = self.edge_subset._get_edges(maze)
 503			# Systematically permute the leading coord of each edge
 504			edges: ConnectionArray = self.edge_permuter._permute(edges)
 505			group_params: EdgeGroupings._GroupingTokenParams = (
 506				self.edge_grouping._token_params()
 507			)
 508			# then, we need to group the edges
 509			groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
 510			# shuffle the groups if specified
 511			if self.shuffle_d0:
 512				if isinstance(groups, np.ndarray):
 513					numpy_rng.shuffle(groups, axis=0)
 514				elif isinstance(groups, list):
 515					random.shuffle(groups)
 516				else:
 517					err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
 518					raise TypeError(err_msg)
 519			# Tokenize each group with optional delimiters
 520			tokens: list[str] = list(
 521				flatten(
 522					[
 523						[
 524							*empty_sequence_if_attr_false(
 525								(VOCAB.ADJLIST_PRE,),
 526								self,
 527								"pre",
 528							),
 529							*self._tokenize_edge_grouping(
 530								group,
 531								maze,
 532								coord_tokenizer,
 533								group_params,
 534							),
 535							*empty_sequence_if_attr_false(
 536								(VOCAB.ADJACENCY_ENDLINE,),
 537								self,
 538								"post",
 539							),
 540						]
 541						for group in groups
 542					],
 543				),
 544			)
 545			return tokens
 546
 547	@serializable_dataclass(frozen=True, kw_only=True)
 548	class AdjListCoord(_AdjListTokenizer):
 549		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""
 550
 551		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
 552			default=EdgePermuters.RandomCoords(),
 553			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
 554		)
 555
 556		def _tokenization_callables(
 557			self,
 558			edges: ConnectionArray,
 559			is_conn: Bool[np.ndarray, " edges"],
 560			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 561			*args,
 562			**kwargs,
 563		) -> list[Callable]:
 564			# Map from `is_conn` to the tokens which represent connections and walls
 565			conn_token_map: dict[bool, str] = {
 566				True: VOCAB.CONNECTOR,
 567				False: VOCAB.ADJLIST_WALL,
 568			}
 569			return [
 570				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
 571				lambda i: conn_token_map[is_conn[i]],
 572				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
 573			]
 574
 575	@serializable_dataclass(frozen=True, kw_only=True)
 576	class AdjListCardinal(_AdjListTokenizer):
 577		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
 578
 579		# Parameters
 580		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
 581		"""
 582
 583		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
 584			default=EdgePermuters.BothCoords(),
 585			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
 586		)
 587
 588		def _tokenization_callables(
 589			self,
 590			edges: ConnectionArray,
 591			is_conn: Bool[np.ndarray, " edges"],
 592			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 593			*args,
 594			**kwargs,
 595		) -> list[Callable]:
 596			# Map from `is_conn` to the tokens which represent connections and walls
 597			conn_token_map: dict[bool, str] = {
 598				True: VOCAB.CONNECTOR,
 599				False: VOCAB.ADJLIST_WALL,
 600			}
 601			return [
 602				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
 603				lambda i: conn_token_map[is_conn[i]],
 604				lambda i: get_cardinal_direction(edges[i]),
 605			]
 606
 607
 608class TargetTokenizers(__TokenizerElementNamespace):
 609	"""Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 610
 611	key = "target_tokenizer"
 612
 613	@serializable_dataclass(frozen=True, kw_only=True)
 614	class _TargetTokenizer(_TokenizerElement, abc.ABC):
 615		"""Superclass of tokenizers for maze targets."""
 616
 617		@abc.abstractmethod
 618		def to_tokens(
 619			self,
 620			targets: Sequence[Coord],
 621			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 622		) -> list[str]:
 623			"""Returns tokens representing the target."""
 624			pass
 625
 626		@classmethod
 627		def attribute_key(cls) -> str:
 628			return TargetTokenizers.key
 629
 630	@serializable_dataclass(frozen=True, kw_only=True)
 631	class Unlabeled(_TargetTokenizer):
 632		"""Targets are simply listed as coord tokens.
 633
 634		- `post`: Whether all coords include an integral following delimiter token
 635		"""
 636
 637		post: bool = serializable_field(default=False)
 638
 639		# inherit docstring
 640		def to_tokens(  # noqa: D102
 641			self,
 642			targets: Sequence[Coord],
 643			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 644		) -> list[str]:
 645			return list(
 646				flatten(
 647					[
 648						[
 649							*coord_tokenizer.to_tokens(target),
 650							*empty_sequence_if_attr_false(
 651								[VOCAB.TARGET_POST],
 652								self,
 653								"post",
 654							),
 655						]
 656						for target in targets
 657					],
 658				),
 659			)
 660
 661		# inherit docstring
 662		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
 663			# No invalid instances possible within data member type hint bounds
 664			return True
 665
 666
 667class StepSizes(__TokenizerElementNamespace):
 668	"""Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""
 669
 670	key = "step_size"
 671
 672	@serializable_dataclass(frozen=True, kw_only=True)
 673	class _StepSize(_TokenizerElement, abc.ABC):
 674		"""Specifies which coords in `maze.solution` are used to represent the path."""
 675
 676		@classmethod
 677		def attribute_key(cls) -> str:
 678			return StepSizes.key
 679
 680		@abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
 681		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
 682			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
 683			raise NotImplementedError(
 684				"Subclasses must implement `StepSize.step_indices.",
 685			)
 686
 687		def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
 688			"""Returns steps as tuples of starting and ending positions for each step."""
 689			indices: list[int] = self._step_single_indices(maze)
 690			# TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
 691			return [
 692				(start, end)
 693				for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
 694			]
 695
 696		def is_valid(self, do_except: bool = False) -> bool:
 697			# No invalid instances possible within data member type hint bounds
 698			return True
 699
 700	@serializable_dataclass(frozen=True, kw_only=True)
 701	class Singles(_StepSize):
 702		"""Every coord in `maze.solution` is represented.
 703
 704		Legacy tokenizers all use this behavior.
 705		"""
 706
 707		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
 708			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
 709			return list(range(maze.solution.shape[0]))
 710
 711	@serializable_dataclass(frozen=True, kw_only=True)
 712	@mark_as_unsupported(_unsupported_is_invalid)
 713	class Straightaways(_StepSize):
 714		"""Only coords where the path turns are represented in the path.
 715
 716		I.e., the path is represented as a sequence of straightaways,
 717		specified by the coords at the turns.
 718		"""
 719
 720		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
 721			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
 722			last_turn_coord: Coord = maze.solution[0, ...]
 723			indices: list[int] = [0]
 724			for i, coord in enumerate(maze.solution):
 725				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
 726					indices.append(i - 1)
 727					last_turn_coord = maze.solution[i - 1, ...]
 728			indices.append(i)
 729			return indices
 730
 731	@serializable_dataclass(frozen=True, kw_only=True)
 732	class Forks(_StepSize):
 733		"""Only coords at forks, where the path has >=2 options for the next step are included.
 734
 735		Excludes the option of backtracking.
 736		The starting and ending coords are always included.
 737		"""
 738
 739		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
 740			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
 741			return maze.get_solution_forking_points(always_include_endpoints=True)[0]
 742
 743	@serializable_dataclass(frozen=True, kw_only=True)
 744	@mark_as_unsupported(_unsupported_is_invalid)
 745	class ForksAndStraightaways(_StepSize):
 746		"""Includes the union of the coords included by `Forks` and `Straightaways`.
 747
 748		See documentation for those classes for details.
 749		"""
 750
 751		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
 752			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
 753			return list(
 754				np.unique(
 755					np.concatenate(
 756						(
 757							StepSizes.Straightaways()._step_single_indices(maze),
 758							StepSizes.Forks()._step_single_indices(maze),
 759						),
 760					),
 761				),
 762			)
 763
 764
 765class StepTokenizers(__TokenizerElementNamespace):
 766	"""Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 767
 768	key = "step_tokenizers"
 769
 770	@serializable_dataclass(frozen=True, kw_only=True)
 771	class _StepTokenizer(_TokenizerElement, abc.ABC):
 772		"""Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""
 773
 774		@classmethod
 775		def attribute_key(cls) -> str:
 776			return StepTokenizers.key
 777
 778		@abc.abstractmethod
 779		def to_tokens(
 780			self,
 781			maze: SolvedMaze,
 782			start_index: int,
 783			end_index: int,
 784			**kwargs,
 785		) -> list[str]:
 786			"""Tokenizes a single step in the solution.
 787
 788			# Parameters
 789			- `maze`: Maze to be tokenized
 790			- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
 791			- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
 792			"""
 793			raise NotImplementedError(
 794				"Subclasses must implement `StepTokenizer.to_tokens.",
 795			)
 796
 797		def is_valid(self, do_except: bool = False) -> bool:
 798			# No invalid instances possible within data member type hint bounds
 799			return True
 800
 801	@serializable_dataclass(frozen=True, kw_only=True)
 802	class Coord(_StepTokenizer):
 803		"""A direct tokenization of the end position coord represents the step."""
 804
 805		# inherit docstring
 806		def to_tokens(  # noqa: D102
 807			self,
 808			maze: SolvedMaze,
 809			start_index: int,
 810			end_index: int,
 811			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 812		) -> list[str]:
 813			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
 814
 815	@serializable_dataclass(frozen=True, kw_only=True)
 816	class Cardinal(_StepTokenizer):
 817		"""A step is tokenized with a cardinal direction token.
 818
 819		It is the direction of the step from the starting position along the solution.
 820		"""
 821
 822		# inherit docstring
 823		def to_tokens(  # noqa: D102
 824			self,
 825			maze: SolvedMaze,
 826			start_index: int,
 827			end_index: int,
 828			**kwargs,
 829		) -> list[str]:
 830			return [
 831				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
 832			]
 833
 834	@serializable_dataclass(frozen=True, kw_only=True)
 835	class Relative(_StepTokenizer):
 836		"""Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
 837
 838		To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
 839		Similarly to `Cardinal`, the direction is that of the step from the starting position.
 840		"""
 841
 842		# inherit docstring
 843		def to_tokens(  # noqa: D102
 844			self,
 845			maze: SolvedMaze,
 846			start_index: int,
 847			end_index: int,
 848			**kwargs,
 849		) -> list[str]:
 850			if start_index == 0:
 851				start = maze.solution[0]
 852				previous = start + np.array([1, 0])
 853				return [
 854					get_relative_direction(
 855						np.concatenate(
 856							(
 857								np.expand_dims(previous, 0),
 858								maze.solution[start_index : start_index + 2],
 859							),
 860							axis=0,
 861						),
 862					),
 863				]
 864			return [
 865				get_relative_direction(
 866					maze.solution[start_index - 1 : start_index + 2],
 867				),
 868			]
 869
	@serializable_dataclass(frozen=True, kw_only=True)
	class Distance(_StepTokenizer):
		"""Encodes a step as the number of single moves between its start and end points.

		Carries no directional information, only how far the step travels.
		`Distance` must therefore be paired with at least one other `_StepTokenizer` in a
		`StepTokenizerPermutation`; `PathTokenizers.StepSequence.is_valid` enforces this.
		"""

		# inherit docstring
		def to_tokens(  # noqa: D102
			self,
			maze: SolvedMaze,
			start_index: int,
			end_index: int,
			**kwargs,
		) -> list[str]:
			n_moves: int = end_index - start_index
			# vocabulary attribute names use a zero-padded 3-digit count, e.g. `I_004`
			vocab_attr: str = f"I_{n_moves:03}"
			return [getattr(VOCAB, vocab_attr)]
 889
	# `StepTokenizerPermutation`: a tuple of one to four `_StepTokenizer`s, intended to be unique.
	# It exists mostly for the clarity and convenience of `_PathTokenizer` code.
	# The alias itself does not enforce uniqueness; `PathTokenizers.StepSequence.is_valid` rejects repeats.
	StepTokenizerPermutation: type = (
		tuple[_StepTokenizer]
		| tuple[_StepTokenizer, _StepTokenizer]
		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
	)
 901
 902
 903class PathTokenizers(__TokenizerElementNamespace):
 904	"""Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 905
 906	key = "path_tokenizer"
 907
 908	@serializable_dataclass(frozen=True, kw_only=True)
 909	class _PathTokenizer(_TokenizerElement, abc.ABC):
 910		"""Superclass of tokenizers for maze solution paths."""
 911
 912		@abc.abstractmethod
 913		def to_tokens(
 914			self,
 915			maze: SolvedMaze,
 916			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 917		) -> list[str]:
 918			"""Returns tokens representing the solution path."""
 919			pass
 920
 921		@classmethod
 922		def attribute_key(cls) -> str:
 923			return PathTokenizers.key
 924
 925	@serializable_dataclass(frozen=True, kw_only=True)
 926	class StepSequence(_PathTokenizer, abc.ABC):
 927		"""Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.
 928
 929		Allows for a sequence of leading and trailing tokens which don't fit the step pattern.
 930
 931		# Parameters
 932		- `step_size`: Selects the size of a single step in the sequence
 933		- `step_tokenizers`: Selects the combination and permutation of tokens
 934		- `pre`: Whether all steps include an integral preceding delimiter token
 935		- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
 936		- `post`: Whether all steps include an integral following delimiter token
 937		"""
 938
 939		step_size: StepSizes._StepSize = serializable_field(
 940			default=StepSizes.Singles(),
 941			loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
 942		)
 943		step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
 944			default=(StepTokenizers.Coord(),),
 945			serialization_fn=lambda x: [y.serialize() for y in x],
 946			loading_fn=lambda x: tuple(x[StepTokenizers.key]),
 947		)
 948		pre: bool = serializable_field(default=False)
 949		intra: bool = serializable_field(default=False)
 950		post: bool = serializable_field(default=False)
 951
 952		# inherit docstring
 953		def to_tokens(  # noqa: D102
 954			self,
 955			maze: SolvedMaze,
 956			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 957		) -> list[str]:
 958			return [
 959				*self._leading_tokens(maze, coord_tokenizer),
 960				*flatten(
 961					[
 962						self._single_step_tokens(maze, start, end, coord_tokenizer)
 963						for start, end in self.step_size.step_start_end_indices(maze)
 964					],
 965				),
 966				*self._trailing_tokens(maze, coord_tokenizer),
 967			]
 968
 969		def _single_step_tokens(
 970			self,
 971			maze: SolvedMaze,
 972			i: int,
 973			j: int,
 974			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 975		) -> list[str]:
 976			"""Returns the token sequence representing a single step along the path."""
 977			step_rep_tokens: list[list[str]] = [
 978				step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
 979				for step_tokenizer in self.step_tokenizers
 980			]
 981			if self.intra:
 982				step_rep_tokens_and_intra: list[str] = [None] * (
 983					len(step_rep_tokens) * 2
 984				)
 985				step_rep_tokens_and_intra[::2] = step_rep_tokens
 986				step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
 987					step_rep_tokens,
 988				)
 989				step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
 990			all_tokens: list[str] = [
 991				*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
 992				*flatten(step_rep_tokens),
 993				*empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
 994			]
 995			return all_tokens
 996
 997		def _leading_tokens(
 998			self,
 999			maze: SolvedMaze,
1000			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1001		) -> list[str]:
1002			"""Returns tokens preceding those from the sequence from `_single_step_tokens`.
1003
1004			Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
1005			<PATH_START> should NOT be included.
1006			"""
1007			if StepTokenizers.Coord() in self.step_tokenizers:
1008				return [
1009					*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
1010					*coord_tokenizer.to_tokens(maze.solution[0, ...]),
1011					*empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
1012				]
1013			return []
1014
1015		def _trailing_tokens(
1016			self,
1017			c: Coord,
1018			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1019		) -> list[str]:
1020			"""Returns tokens following those from the sequence from `_single_step_tokens`.
1021
1022			<PATH_END> should NOT be included.
1023			"""
1024			return []
1025
1026		# inherits docstring
1027		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
1028			output: bool
1029
1030			if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
1031				# Uninteresting: repeated elements are not useful
1032				output = False
1033			else:
1034				# we do noqa for the comment if false
1035				if len(self.step_tokenizers) == 1 and isinstance(
1036					self.step_tokenizers[0],
1037					StepTokenizers.Distance,
1038				):
1039					# Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
1040					output = False
1041				else:
1042					output = True
1043
1044			if not output and do_except:
1045				raise ValueError(
1046					"PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
1047				)
1048
1049			return output
1050
1051
class PromptSequencers(__TokenizerElementNamespace):
	"""Namespace holding the `_PromptSequencer` hierarchy used by `MazeTokenizerModular`."""

	key = "prompt_sequencer"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _PromptSequencer(_TokenizerElement, abc.ABC):
		"""Assembles a maze's token regions into one complete tokenization.

		# Parameters
		- `coord_tokenizer`: Element that tokenizes a single `Coord` (a maze position).
		- `adj_list_tokenizer`: Element that tokenizes the adjacency list of a `LatticeMaze`.
		`coord_tokenizer` is also passed to the other tokenizer elements so they can tokenize coords when needed.
		"""

		coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
			default=CoordTokenizers.UT(),
			loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
		)
		adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
			default=AdjListTokenizers.AdjListCoord(),
			loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
		)

		@classmethod
		def attribute_key(cls) -> str:
			return PromptSequencers.key

		@staticmethod
		def _trim_if_unsolved_maze(
			untrimmed: list[str],
			is_untargeted: bool = False,
			is_unsolved: bool = False,
		) -> list[str]:
			"""Cuts a full prompt down to the regions an untargeted or unsolved maze actually has.

			- untargeted: keep `<ADJLIST_START>` through `<ADJLIST_END>`
			- unsolved (but targeted): keep `<ADJLIST_START>` through `<TARGET_END>`, or through
			  `<ORIGIN_END>` when the prompt contains no `<TARGET_END>`
			- neither: return `untrimmed` unchanged

			Both bounds are kept (inclusive).

			# Development
			Should work for `AOTP`, `AOP`, and any other concrete sequencer that uses a subsequence of AOTP.
			It lives here rather than in `token_utils.py` so more exotic `PromptSequencer` subclasses can override it.
			"""
			if not (is_untargeted or is_unsolved):
				return untrimmed
			if is_untargeted:
				last_kept = VOCAB.ADJLIST_END
			elif VOCAB.TARGET_END in untrimmed:
				last_kept = VOCAB.TARGET_END
			else:
				last_kept = VOCAB.ORIGIN_END
			return tokens_between(
				untrimmed,
				VOCAB.ADJLIST_START,
				last_kept,
				include_start=True,
				include_end=True,
			)

		def to_tokens(
			self,
			maze: LatticeMaze,
			*args,
			**kwargs,
		) -> list[str]:
			"""Returns the complete token list for `maze`, trimmed to the regions it actually has."""
			regions: list[list[str]] = self._get_prompt_regions(maze)
			full_prompt: list[str] = self._sequence_tokens(*regions)
			# a maze without `start_pos` is untargeted; one without `solution` is unsolved
			is_untargeted: bool = not hasattr(maze, "start_pos")
			is_unsolved: bool = not hasattr(maze, "solution")
			return self._trim_if_unsolved_maze(full_prompt, is_untargeted, is_unsolved)

		def _get_prompt_regions(
			self,
			maze: LatticeMaze,
			*args,
			**kwargs,
		) -> list[list[str]]:
			"""Computes the prompt regions of `maze` in a fixed order.

			Deciding which regions to include is NOT this method's job. It always returns the
			four regions listed below, using an empty list for any region that the maze (or this
			sequencer) does not provide. This suits most `PromptSequencer` subclasses; override
			it only when special behavior is needed.

			# Returns
			- [0]: list[str] Adjacency list tokens
			- [1]: list[str] Origin tokens
			- [2]: list[str] Target tokens
			- [3]: list[str] Path tokens

			# Missing regions
			A maze lacking `start_pos`, `end_pos`, or `solution` is untargeted or unsolved.
			The matching regions come back as empty lists so `_sequence_tokens` can always unpack them.
			"""
			origin: Coord | None = getattr(maze, "start_pos", None)
			# TargetTokenizer requires target: Sequence[Coord]
			target: list[Coord | None] = [getattr(maze, "end_pos", None)]

			adj_list_tokens: list[str] = (
				self.adj_list_tokenizer.to_tokens(
					maze,
					coord_tokenizer=self.coord_tokenizer,
				)
				if hasattr(self, "adj_list_tokenizer")
				else []
			)

			origin_tokens: list[str] = (
				[] if origin is None else self.coord_tokenizer.to_tokens(origin)
			)

			has_target: bool = target[0] is not None and hasattr(self, "target_tokenizer")
			target_tokens: list[str] = (
				self.target_tokenizer.to_tokens(
					target,
					coord_tokenizer=self.coord_tokenizer,
				)
				if has_target
				else []
			)

			has_path: bool = hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
			path_tokens: list[str] = (
				self.path_tokenizer.to_tokens(
					maze,
					coord_tokenizer=self.coord_tokenizer,
				)
				if has_path
				else []
			)

			return [adj_list_tokens, origin_tokens, target_tokens, path_tokens]

		@abc.abstractmethod
		def _sequence_tokens(
			self,
			adj_list: list[str],
			origin: list[str] | None,
			target: list[str] | None,
			path: list[str] | None,
		) -> list[str]:
			"""Joins token regions into a complete prompt.

			Implementations add the boundary tokens from `constants.SPECIAL_TOKENS`,
			such as <ADJLIST_START>, <ORIGIN_END>, etc.

			# Parameters
			- `adj_list`: Tokens representing the adjacency list
			- `origin`: Tokens representing the origin
			- `target`: Tokens representing the target
			- `path`: Tokens representing the path
			"""
			pass

		def is_valid(self, do_except: bool = False) -> bool:
			# every combination of field values allowed by the type hints is valid
			return True

	@serializable_dataclass(frozen=True, kw_only=True)
	class AOTP(_PromptSequencer):
		"""Lays out a prompt as [adjacency list, origin, target, path].

		# Parameters
		- `target_tokenizer`: Element that tokenizes the target(s) of a `TargetedLatticeMaze`.
		It uses `coord_tokenizer` to tokenize coords if that is part of the `TargetTokenizer`'s design.
		- `path_tokenizer`: Element that tokenizes the solution path of a `SolvedMaze`.
		It uses `coord_tokenizer` to tokenize coords if that is part of the `PathTokenizer`'s design.
		"""

		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
			default=TargetTokenizers.Unlabeled(),
			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
		)
		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
			default=PathTokenizers.StepSequence(),
			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
		)

		def _sequence_tokens(
			self,
			adj_list: list[str],
			origin: list[str],
			target: list[str],
			path: list[str],
		) -> list[str]:
			# each region is wrapped in its own start/end delimiter pair
			return (
				[VOCAB.ADJLIST_START, *adj_list, VOCAB.ADJLIST_END]
				+ [VOCAB.ORIGIN_START, *origin, VOCAB.ORIGIN_END]
				+ [VOCAB.TARGET_START, *target, VOCAB.TARGET_END]
				+ [VOCAB.PATH_START, *path, VOCAB.PATH_END]
			)

	@serializable_dataclass(frozen=True, kw_only=True)
	class AOP(_PromptSequencer):
		"""Lays out a prompt as [adjacency list, origin, path].

		The "<TARGET_START>" and "<TARGET_END>" tokens are still emitted, but nothing representing the target goes between them.

		# Parameters
		- `path_tokenizer`: Element that tokenizes the solution path of a `SolvedMaze`.
		It uses `coord_tokenizer` to tokenize coords if that is part of the `PathTokenizer`'s design.
		"""

		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
			default=PathTokenizers.StepSequence(),
			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
		)

		def _sequence_tokens(
			self,
			adj_list: list[str],
			origin: list[str],
			# explicitly no target in this tokenizer
			target: list[str],
			path: list[str],
		) -> list[str]:
			# the target region is an empty delimiter pair; `target` is intentionally ignored
			return (
				[VOCAB.ADJLIST_START, *adj_list, VOCAB.ADJLIST_END]
				+ [VOCAB.ORIGIN_START, *origin, VOCAB.ORIGIN_END]
				+ [VOCAB.TARGET_START, VOCAB.TARGET_END]
				+ [VOCAB.PATH_START, *path, VOCAB.PATH_END]
			)

 48class CoordTokenizers(__TokenizerElementNamespace):
 49	"""Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 50
 51	key = "coord_tokenizer"
 52
 53	@serializable_dataclass(frozen=True, kw_only=True)
 54	class _CoordTokenizer(_TokenizerElement, abc.ABC):
 55		"""Superclass for classes which tokenize singular coords in a maze."""
 56
 57		@abc.abstractmethod
 58		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
 59			pass
 60
 61		@classmethod
 62		def attribute_key(cls) -> str:
 63			return CoordTokenizers.key
 64
 65		def is_valid(self, do_except: bool = False) -> bool:
 66			# No invalid instances possible within data member type hint bounds
 67			return True
 68
 69	@serializable_dataclass(frozen=True, kw_only=True)
 70	class UT(_CoordTokenizer):
 71		"""Unique token coordinate tokenizer."""
 72
 73		# inherit docstring
 74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
 76
 77	@serializable_dataclass(frozen=True, kw_only=True)
 78	class CTT(_CoordTokenizer):
 79		"""Coordinate tuple tokenizer
 80
 81		# Parameters
 82		- `pre`: Whether all coords include an integral preceding delimiter token
 83		- `intra`: Whether all coords include a delimiter token between coordinates
 84		- `post`: Whether all coords include an integral following delimiter token
 85		"""
 86
 87		pre: bool = serializable_field(default=True)
 88		intra: bool = serializable_field(default=True)
 89		post: bool = serializable_field(default=True)
 90		# Implement methods
 91
 92		# inherit docstring
 93		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 94			return [
 95				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
 96				str(coord[0]),
 97				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
 98				str(coord[1]),
 99				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
100			]

Namespace for _CoordTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'coord_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class CoordTokenizers.UT(CoordTokenizers._CoordTokenizer):
69	@serializable_dataclass(frozen=True, kw_only=True)
70	class UT(_CoordTokenizer):
71		"""Unique token coordinate tokenizer."""
72
73		# inherit docstring
74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

Unique token coordinate tokenizer.

CoordTokenizers.UT( *, _type_: Literal["<class 'CoordTokenizers.UT'>"] = "<class 'CoordTokenizers.UT'>")
def to_tokens( self, coord: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> list[str]:
74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

Converts a maze element into a list of tokens.

Not all _TokenizerElement subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """Return True only if every field of a `SerializableDataclass` passes its type check.

    Delegates to `SerializableDataclass__validate_fields_types__dict`, which produces one
    boolean per field, and reduces those results with `all`.
    """
    per_field_ok: dict[str, bool] = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field_ok.values())

Validates the types of all the fields on a SerializableDataclass by calling SerializableDataclass__validate_fields_types__dict (one check per field), and returns True only if every field is valid.

Inherited Members
CoordTokenizers._CoordTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class CoordTokenizers.CTT(CoordTokenizers._CoordTokenizer):
	@serializable_dataclass(frozen=True, kw_only=True)
	class CTT(_CoordTokenizer):
		"""Coordinate tuple tokenizer: each coordinate value becomes its own token.

		# Parameters
		- `pre`: Whether all coords include an integral preceding delimiter token
		- `intra`: Whether all coords include a delimiter token between coordinates
		- `post`: Whether all coords include an integral following delimiter token
		"""

		pre: bool = serializable_field(default=True)
		intra: bool = serializable_field(default=True)
		post: bool = serializable_field(default=True)

		# inherit docstring
		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
			tokens: list[str] = []
			if self.pre:
				tokens.append(VOCAB.COORD_PRE)
			tokens.append(str(coord[0]))
			if self.intra:
				tokens.append(VOCAB.COORD_INTRA)
			tokens.append(str(coord[1]))
			if self.post:
				tokens.append(VOCAB.COORD_POST)
			return tokens

Coordinate tuple tokenizer.

Parameters

  • pre: Whether all coords include an integral preceding delimiter token
  • intra: Whether all coords include a delimiter token between coordinates
  • post: Whether all coords include an integral following delimiter token
CoordTokenizers.CTT( *, _type_: Literal["<class 'CoordTokenizers.CTT'>"] = "<class 'CoordTokenizers.CTT'>", pre: bool = True, intra: bool = True, post: bool = True)
pre: bool = True
intra: bool = True
post: bool = True
def to_tokens( self, coord: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> list[str]:
 93		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 94			return [
 95				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
 96				str(coord[0]),
 97				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
 98				str(coord[1]),
 99				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
100			]

Converts a maze element into a list of tokens.

Not all _TokenizerElement subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
CoordTokenizers._CoordTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
103class EdgeGroupings(__TokenizerElementNamespace):
104	"""Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""
105
106	key = "edge_grouping"
107
108	class _GroupingTokenParams(TypedDict):
109		"""A uniform private hyperparameter interface used by `AdjListTokenizer`."""
110
111		connection_token_ordinal: Literal[0, 1, 2]
112		intra: bool
113		grouped: bool
114
115	@serializable_dataclass(frozen=True, kw_only=True)
116	class _EdgeGrouping(_TokenizerElement, abc.ABC):
117		"""Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""
118
119		@classmethod
120		def attribute_key(cls) -> str:
121			return EdgeGroupings.key
122
123		def is_valid(self, do_except: bool = False) -> bool:
124			return True
125
126		@abc.abstractmethod
127		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
128			"""Divides a ConnectionArray into groups of edges.
129
130			Shuffles/sequences within each group if applicable.
131			"""
132			pass
133
134		@abc.abstractmethod
135		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
136			"""Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.
137
138			These hyperparameters are not used by `_EdgeGrouping` internally.
139			They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
140			since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
141			This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses
142			into a uniform private interface used by `AdjListTokenizer`.
143			"""
144			pass
145
146	@serializable_dataclass(frozen=True, kw_only=True)
147	class Ungrouped(_EdgeGrouping):
148		"""No grouping occurs, each edge is tokenized individually.
149
150		# Parameters
151		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
152		Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
153		"""
154
155		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
156			default=1,
157			assert_type=False,
158		)
159
160		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
161			return EdgeGroupings._GroupingTokenParams(
162				connection_token_ordinal=self.connection_token_ordinal,
163				intra=False,
164				grouped=False,
165			)
166
167		def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
168			return np.expand_dims(edges, 1)
169
170	@serializable_dataclass(frozen=True, kw_only=True)
171	@mark_as_unsupported(_unsupported_is_invalid)
172	class ByLeadingCoord(_EdgeGrouping):
173		"""All edges with the same leading coord are grouped together.
174
175		# Parameters
176		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
177		Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly a wall token)
178		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
179		If false, the fixed order is lexicographical by (row, col).
180		In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
181		- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
182		Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
183		"""
184
185		intra: bool = serializable_field(default=True)
186		shuffle_group: bool = serializable_field(default=True)
187		connection_token_ordinal: Literal[0, 1] = serializable_field(
188			default=0,
189			assert_type=False,
190		)
191
192		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
193			return EdgeGroupings._GroupingTokenParams(
194				connection_token_ordinal=self.connection_token_ordinal,
195				intra=self.intra,
196				grouped=True,
197			)
198
199		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
200			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
201			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
202				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
203			)
204			sorted_edges: ConnectionArray = edges[index_array, ...]
205			groups: list[ConnectionArray] = np.split(
206				sorted_edges,
207				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
208			)
209			if self.shuffle_group:
210				[numpy_rng.shuffle(g, axis=0) for g in groups]
211			return groups

Namespace for _EdgeGrouping subclass hierarchy used by _AdjListTokenizer.

key = 'edge_grouping'
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeGroupings.Ungrouped(EdgeGroupings._EdgeGrouping):
146	@serializable_dataclass(frozen=True, kw_only=True)
147	class Ungrouped(_EdgeGrouping):
148		"""No grouping occurs, each edge is tokenized individually.
149
150		# Parameters
151		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
152		Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
153		"""
154
155		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
156			default=1,
157			assert_type=False,
158		)
159
160		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
161			return EdgeGroupings._GroupingTokenParams(
162				connection_token_ordinal=self.connection_token_ordinal,
163				intra=False,
164				grouped=False,
165			)
166
167		def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
168			return np.expand_dims(edges, 1)

No grouping occurs, each edge is tokenized individually.

Parameters

  • connection_token_ordinal: At which index in the edge tokenization the connector (or wall) token appears. Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
EdgeGroupings.Ungrouped( *, _type_: Literal["<class 'EdgeGroupings.Ungrouped'>"] = "<class 'EdgeGroupings.Ungrouped'>", connection_token_ordinal: Literal[0, 1, 2] = 1)
connection_token_ordinal: Literal[0, 1, 2] = 1
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns this instance's serializable fields and properties as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
EdgeGroupings._EdgeGrouping
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class EdgeGroupings.ByLeadingCoord(EdgeGroupings._EdgeGrouping):
170	@serializable_dataclass(frozen=True, kw_only=True)
171	@mark_as_unsupported(_unsupported_is_invalid)
172	class ByLeadingCoord(_EdgeGrouping):
173		"""All edges with the same leading coord are grouped together.
174
175		# Parameters
176		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
177		Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly a wall token)
178		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
179		If false, the fixed order is lexicographical by (row, col).
180		In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
181		- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
182		Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
183		"""
184
185		intra: bool = serializable_field(default=True)
186		shuffle_group: bool = serializable_field(default=True)
187		connection_token_ordinal: Literal[0, 1] = serializable_field(
188			default=0,
189			assert_type=False,
190		)
191
192		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
193			return EdgeGroupings._GroupingTokenParams(
194				connection_token_ordinal=self.connection_token_ordinal,
195				intra=self.intra,
196				grouped=True,
197			)
198
199		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
200			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
201			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
202				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
203			)
204			sorted_edges: ConnectionArray = edges[index_array, ...]
205			groups: list[ConnectionArray] = np.split(
206				sorted_edges,
207				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
208			)
209			if self.shuffle_group:
210				[numpy_rng.shuffle(g, axis=0) for g in groups]
211			return groups

All edges with the same leading coord are grouped together.

Parameters

  • intra: Whether all edge groupings include a delimiter token between individual edge representations. Note that each edge representation will already always include a connector token (VOCAB.CONNECTOR, or possibly a wall token).
  • shuffle_group: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. If false, the fixed order is lexicographical by (row, col). In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
  • connection_token_ordinal: At which index in token sequence representing a single edge the connector (or wall) token appears. Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
EdgeGroupings.ByLeadingCoord( *, _type_: Literal["<class 'EdgeGroupings.ByLeadingCoord'>"] = "<class 'EdgeGroupings.ByLeadingCoord'>", intra: bool = True, shuffle_group: bool = True, connection_token_ordinal: Literal[0, 1] = 0)
intra: bool = True
shuffle_group: bool = True
connection_token_ordinal: Literal[0, 1] = 0
def is_valid(self, do_except: bool = False) -> bool:
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
258	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
259	if do_except:
260		err_msg: str = (
261			f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
262			f"{type(self) = }, {self = }"
263		)
264		raise ValueError(err_msg)
265
266	return False

Default implementation of is_valid for mark_as_unsupported-decorated classes

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns this instance's serializable fields and properties as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass and returns True only if every field passes; the per-field results come from SerializableDataclass__validate_fields_types__dict.

Inherited Members
EdgeGroupings._EdgeGrouping
attribute_key
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
214class EdgePermuters(__TokenizerElementNamespace):
215	"""Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""
216
217	key = "edge_permuter"
218
219	@serializable_dataclass(frozen=True, kw_only=True)
220	class _EdgePermuter(_TokenizerElement, abc.ABC):
221		"""Specifies how to sequence the two coords that encode a lattice edge."""
222
223		@classmethod
224		def attribute_key(cls) -> str:
225			return EdgePermuters.key
226
227		def is_valid(self, do_except: bool = False) -> bool:
228			# No invalid instances possible within data member type hint bounds
229			return True
230
231		@staticmethod
232		@abc.abstractmethod
233		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
234			"""Executes a permutation.
235
236			Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.
237
238			# Parameters
239			- `lattice_edges`: Array of lattice edges.
240			The two coords in shape[1] must be adjacent in the lattice.
241
242			# Returns
243			- Array of lattice edges with entries along shape[1] systematically permuted.
244			- shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`.
245			"""
246			pass
247
248	@serializable_dataclass(frozen=True, kw_only=True)
249	class SortedCoords(_EdgePermuter):
250		"""returns a sorted representation. useful for checking consistency"""
251
252		@staticmethod
253		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
254			return lattice_edges[
255				np.lexsort(
256					(
257						lattice_edges[:, 1, 1],
258						lattice_edges[:, 1, 0],
259						lattice_edges[:, 0, 1],
260						lattice_edges[:, 0, 0],
261					),
262				),
263				...,
264			]
265
266	@serializable_dataclass(frozen=True, kw_only=True)
267	class RandomCoords(_EdgePermuter):
268		"""Permutes each edge randomly."""
269
270		@staticmethod
271		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
272			numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
273			return lattice_edges
274
275	@serializable_dataclass(frozen=True, kw_only=True)
276	class BothCoords(_EdgePermuter):
277		"""Includes both possible permutations of every edge in the output.
278
279		Since input ConnectionList has only 1 instance of each edge,
280		a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
281		"""
282
283		@staticmethod
284		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
285			return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)

Namespace for the _EdgePermuter subclass hierarchy used by _AdjListTokenizer.

key = 'edge_permuter'
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.SortedCoords(EdgePermuters._EdgePermuter):
248	@serializable_dataclass(frozen=True, kw_only=True)
249	class SortedCoords(_EdgePermuter):
250		"""returns a sorted representation. useful for checking consistency"""
251
252		@staticmethod
253		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
254			return lattice_edges[
255				np.lexsort(
256					(
257						lattice_edges[:, 1, 1],
258						lattice_edges[:, 1, 0],
259						lattice_edges[:, 0, 1],
260						lattice_edges[:, 0, 0],
261					),
262				),
263				...,
264			]

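Returns the edges in a sorted order: edges (rows) are ordered lexicographically by first-coord row, first-coord column, second-coord row, then second-coord column. The order of the two coords within each edge is left unchanged. Useful for checking consistency.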

EdgePermuters.SortedCoords( *, _type_: Literal["<class 'EdgePermuters.SortedCoords'>"] = "<class 'EdgePermuters.SortedCoords'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict (including a format key and any properties listed in _properties_to_serialize); implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged); implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass and returns True only if every field passes; the per-field results come from SerializableDataclass__validate_fields_types__dict.

Inherited Members
EdgePermuters._EdgePermuter
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.RandomCoords(EdgePermuters._EdgePermuter):
266	@serializable_dataclass(frozen=True, kw_only=True)
267	class RandomCoords(_EdgePermuter):
268		"""Permutes each edge randomly."""
269
270		@staticmethod
271		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
272			numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
273			return lattice_edges

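Randomly orders the two coords of each edge, modifying lattice_edges in place (numpy_rng.permuted with out=lattice_edges). Note that permuted(axis=1) shuffles the row and column components independently; this is equivalent to randomly swapping the two coords only because lattice-adjacent coords always share one component.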

EdgePermuters.RandomCoords( *, _type_: Literal["<class 'EdgePermuters.RandomCoords'>"] = "<class 'EdgePermuters.RandomCoords'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict (including a format key and any properties listed in _properties_to_serialize); implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged); implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass and returns True only if every field passes; the per-field results come from SerializableDataclass__validate_fields_types__dict.

Inherited Members
EdgePermuters._EdgePermuter
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.BothCoords(EdgePermuters._EdgePermuter):
275	@serializable_dataclass(frozen=True, kw_only=True)
276	class BothCoords(_EdgePermuter):
277		"""Includes both possible permutations of every edge in the output.
278
279		Since input ConnectionList has only 1 instance of each edge,
280		a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
281		"""
282
283		@staticmethod
284		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
285			return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)

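Includes both possible orderings of every edge in the output.

Since the input ConnectionList contains only one instance of each edge, BothCoords._permute returns a new array holding every edge in both orientations, which doubles shape[0]. lattice_edges itself is not modified, because np.append returns a copy rather than operating in place.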

EdgePermuters.BothCoords( *, _type_: Literal["<class 'EdgePermuters.BothCoords'>"] = "<class 'EdgePermuters.BothCoords'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict (including a format key and any properties listed in _properties_to_serialize); implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged); implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a SerializableDataclass. Each field is checked individually (via SerializableDataclass__validate_field_type), and True is returned only if every field is valid.

Inherited Members
EdgePermuters._EdgePermuter
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
288class EdgeSubsets(__TokenizerElementNamespace):
289	"""Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""
290
291	key = "edge_subset"
292
293	@serializable_dataclass(frozen=True, kw_only=True)
294	class _EdgeSubset(_TokenizerElement, abc.ABC):
295		"""Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""
296
297		@classmethod
298		def attribute_key(cls) -> str:
299			return EdgeSubsets.key
300
301		def is_valid(self, do_except: bool = False) -> bool:
302			return True
303
304		@abc.abstractmethod
305		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
306			"""Returns the set of lattice edges to be tokenized."""
307			pass
308
309	@serializable_dataclass(frozen=True, kw_only=True)
310	class AllLatticeEdges(_EdgeSubset):
311		"""All 2n**2-2n edges of the lattice are tokenized.
312
313		If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
314		"""
315
316		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
317			return lattice_connection_array(maze.grid_n)
318
319	@serializable_dataclass(frozen=True, kw_only=True)
320	class ConnectionEdges(_EdgeSubset):
321		"""Only edges which contain a connection are tokenized.
322
323		Alternatively, only edges which contain a wall are tokenized.
324
325		# Parameters
326		- `walls`: Whether wall edges or connection edges are tokenized.
327		If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
328		"""
329
330		walls: bool = serializable_field(default=False)
331
332		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
333			conn_list: ConnectionList = maze.connection_list
334			if self.walls:
335				conn_list = np.logical_not(conn_list)
336				conn_list[0, -1, :] = False
337				conn_list[1, :, -1] = False
338			return connection_list_to_adj_list(
339				conn_list,
340				shuffle_d0=False,
341				shuffle_d1=False,
342			)

Namespace for the _EdgeSubset subclass hierarchy used by _AdjListTokenizer.

key = 'edge_subset'
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeSubsets.AllLatticeEdges(EdgeSubsets._EdgeSubset):
309	@serializable_dataclass(frozen=True, kw_only=True)
310	class AllLatticeEdges(_EdgeSubset):
311		"""All 2n**2-2n edges of the lattice are tokenized.
312
313		If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
314		"""
315
316		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
317			return lattice_connection_array(maze.grid_n)

All $2n^2 - 2n = 2n(n-1)$ edges of the lattice are tokenized.

If a wall exists on an edge, that edge is still tokenized in the same manner, but with VOCAB.ADJLIST_WALL in place of VOCAB.CONNECTOR.

EdgeSubsets.AllLatticeEdges( *, _type_: Literal["<class 'EdgeSubsets.AllLatticeEdges'>"] = "<class 'EdgeSubsets.AllLatticeEdges'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict, including a format key and any properties listed in _properties_to_serialize. Implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes in an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged). Implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a SerializableDataclass. Each field is checked individually (via SerializableDataclass__validate_field_type), and True is returned only if every field is valid.

Inherited Members
EdgeSubsets._EdgeSubset
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeSubsets.ConnectionEdges(EdgeSubsets._EdgeSubset):
319	@serializable_dataclass(frozen=True, kw_only=True)
320	class ConnectionEdges(_EdgeSubset):
321		"""Only edges which contain a connection are tokenized.
322
323		Alternatively, only edges which contain a wall are tokenized.
324
325		# Parameters
326		- `walls`: Whether wall edges or connection edges are tokenized.
327		If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
328		"""
329
330		walls: bool = serializable_field(default=False)
331
332		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
333			conn_list: ConnectionList = maze.connection_list
334			if self.walls:
335				conn_list = np.logical_not(conn_list)
336				conn_list[0, -1, :] = False
337				conn_list[1, :, -1] = False
338			return connection_list_to_adj_list(
339				conn_list,
340				shuffle_d0=False,
341				shuffle_d1=False,
342			)

Only edges which contain a connection are tokenized.

Alternatively, if walls is true, only edges which contain a wall are tokenized instead.

Parameters

  • walls: Whether wall edges (true) or connection edges (false, the default) are tokenized. If true, VOCAB.ADJLIST_WALL is used in place of VOCAB.CONNECTOR.
EdgeSubsets.ConnectionEdges( *, _type_: Literal["<class 'EdgeSubsets.ConnectionEdges'>"] = "<class 'EdgeSubsets.ConnectionEdges'>", walls: bool = False)
walls: bool = False
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict, including a format key and any properties listed in _properties_to_serialize. Implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes in an appropriately structured dict and returns an instance of the class (if data is already an instance of the class, it is returned unchanged). Implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a SerializableDataclass. Each field is checked individually (via SerializableDataclass__validate_field_type), and True is returned only if every field is valid.

Inherited Members
EdgeSubsets._EdgeSubset
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
356class AdjListTokenizers(__TokenizerElementNamespace):
357	"""Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
358
359	key = "adj_list_tokenizer"
360
361	@serializable_dataclass(frozen=True, kw_only=True)
362	@mark_as_unsupported(_adjlist_no_pre_unsupported)
363	class _AdjListTokenizer(_TokenizerElement, abc.ABC):
364		"""Specifies how the adjacency list is tokenized.
365
366		Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
367		See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.
368
369		# Parameters
370		- `pre`: Whether all edge groupings include a preceding delimiter token
371		- `post`: Whether all edge groupings include a following delimiter token
372		- `shuffle_d0`: Specifies how to sequence the edge groupings.
373			If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
374		- `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
375		- `edge_subset`: Specifies the subset of lattice edges to be tokenized.
376		- `edge_permuter`: Specifies, in each edge tokenization, which coord either:
377			1. Appears first in the tokenization, for `AdjListCoord`.
378			2. Is tokenized directly as a coord, for `AdjListCardinal`.
379			- `shuffle`: For each edge, the leading coord is selected randomly.
380			- `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
381			- `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
382		"""
383
384		pre: bool = serializable_field(default=False, assert_type=False)
385		post: bool = serializable_field(default=True)
386		shuffle_d0: bool = serializable_field(default=True)
387		edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
388			default=EdgeGroupings.Ungrouped(),
389			loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
390		)
391		edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
392			default=EdgeSubsets.ConnectionEdges(),
393			loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
394		)
395		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
396			default=EdgePermuters.RandomCoords(),
397			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
398		)
399
400		@classmethod
401		def attribute_key(cls) -> str:
402			return AdjListTokenizers.key
403
404		def is_valid(self, do_except: bool = False) -> bool:
405			# No invalid instances possible within data member type hint bounds
406			return True
407
408		@abc.abstractmethod
409		def _tokenization_callables(
410			self,
411			edges: ConnectionArray,
412			is_conn: Bool[np.ndarray, " edges"],
413			coord_tokenizer: CoordTokenizers._CoordTokenizer,
414			*args,
415			**kwargs,
416		) -> list[Callable]:
417			"""Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.
418
419			# Returns
420			- `[0]`: leading coord tokens
421			- `[1]`: connector tokens
422			- `[2]`: trailing coord tokens
423			"""
424			pass
425
426		def _tokenize_edge_grouping(
427			self,
428			edges: ConnectionArray,
429			maze: LatticeMaze,
430			coord_tokenizer: CoordTokenizers._CoordTokenizer,
431			group_params: EdgeGroupings._GroupingTokenParams,
432		) -> Sequence[str]:
433			"""Tokenizes a single edge grouping."""
434			cxn_ord: int = group_params["connection_token_ordinal"]
435			is_conn: Bool[np.ndarray, edges] = is_connection(
436				edges,
437				maze.connection_list,
438			)
439			tokenize_callables = self._tokenization_callables(
440				edges,
441				is_conn,
442				coord_tokenizer,
443			)
444
445			if group_params["grouped"]:
446				# If grouped
447				callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
448				repeated_callables = [
449					tokenize_callables[i] for i in callable_permutation
450				]
451				return flatten(
452					[
453						tokenize_callables[0](0),
454						[
455							[
456								*[
457									tok_callable(i)
458									for tok_callable in repeated_callables
459								],
460								*(
461									(VOCAB.ADJLIST_INTRA,)
462									if group_params["intra"]
463									else ()
464								),
465							]
466							for i in range(edges.shape[0])
467						],
468					],
469				)
470			else:
471				# If ungrouped
472				callable_permutation = [0, 2]
473				callable_permutation.insert(cxn_ord, 1)
474				tokenize_callables = [
475					tokenize_callables[i] for i in callable_permutation
476				]
477
478				return flatten(
479					[
480						[
481							[
482								*[
483									tok_callable(i)
484									for tok_callable in tokenize_callables
485								],
486								*empty_sequence_if_attr_false(
487									(VOCAB.ADJLIST_INTRA,),
488									group_params,
489									"intra",
490								),
491							]
492							for i in range(edges.shape[0])
493						],
494					],
495				)
496
497		def to_tokens(
498			self,
499			maze: LatticeMaze,
500			coord_tokenizer: CoordTokenizers._CoordTokenizer,
501		) -> list[str]:
502			# Get the set of edges to be tokenized
503			edges: ConnectionArray = self.edge_subset._get_edges(maze)
504			# Systematically permute the leading coord of each edge
505			edges: ConnectionArray = self.edge_permuter._permute(edges)
506			group_params: EdgeGroupings._GroupingTokenParams = (
507				self.edge_grouping._token_params()
508			)
509			# then, we need to group the edges
510			groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
511			# shuffle the groups if specified
512			if self.shuffle_d0:
513				if isinstance(groups, np.ndarray):
514					numpy_rng.shuffle(groups, axis=0)
515				elif isinstance(groups, list):
516					random.shuffle(groups)
517				else:
518					err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
519					raise TypeError(err_msg)
520			# Tokenize each group with optional delimiters
521			tokens: list[str] = list(
522				flatten(
523					[
524						[
525							*empty_sequence_if_attr_false(
526								(VOCAB.ADJLIST_PRE,),
527								self,
528								"pre",
529							),
530							*self._tokenize_edge_grouping(
531								group,
532								maze,
533								coord_tokenizer,
534								group_params,
535							),
536							*empty_sequence_if_attr_false(
537								(VOCAB.ADJACENCY_ENDLINE,),
538								self,
539								"post",
540							),
541						]
542						for group in groups
543					],
544				),
545			)
546			return tokens
547
548	@serializable_dataclass(frozen=True, kw_only=True)
549	class AdjListCoord(_AdjListTokenizer):
550		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""
551
552		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
553			default=EdgePermuters.RandomCoords(),
554			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
555		)
556
557		def _tokenization_callables(
558			self,
559			edges: ConnectionArray,
560			is_conn: Bool[np.ndarray, " edges"],
561			coord_tokenizer: CoordTokenizers._CoordTokenizer,
562			*args,
563			**kwargs,
564		) -> list[Callable]:
565			# Map from `is_conn` to the tokens which represent connections and walls
566			conn_token_map: dict[bool, str] = {
567				True: VOCAB.CONNECTOR,
568				False: VOCAB.ADJLIST_WALL,
569			}
570			return [
571				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
572				lambda i: conn_token_map[is_conn[i]],
573				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
574			]
575
576	@serializable_dataclass(frozen=True, kw_only=True)
577	class AdjListCardinal(_AdjListTokenizer):
578		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
579
580		# Parameters
581		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
582		"""
583
584		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
585			default=EdgePermuters.BothCoords(),
586			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
587		)
588
589		def _tokenization_callables(
590			self,
591			edges: ConnectionArray,
592			is_conn: Bool[np.ndarray, " edges"],
593			coord_tokenizer: CoordTokenizers._CoordTokenizer,
594			*args,
595			**kwargs,
596		) -> list[Callable]:
597			# Map from `is_conn` to the tokens which represent connections and walls
598			conn_token_map: dict[bool, str] = {
599				True: VOCAB.CONNECTOR,
600				False: VOCAB.ADJLIST_WALL,
601			}
602			return [
603				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
604				lambda i: conn_token_map[is_conn[i]],
605				lambda i: get_cardinal_direction(edges[i]),
606			]

Namespace for the _AdjListTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'adj_list_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class AdjListTokenizers.AdjListCoord(AdjListTokenizers._AdjListTokenizer):
548	@serializable_dataclass(frozen=True, kw_only=True)
549	class AdjListCoord(_AdjListTokenizer):
550		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""
551
552		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
553			default=EdgePermuters.RandomCoords(),
554			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
555		)
556
557		def _tokenization_callables(
558			self,
559			edges: ConnectionArray,
560			is_conn: Bool[np.ndarray, " edges"],
561			coord_tokenizer: CoordTokenizers._CoordTokenizer,
562			*args,
563			**kwargs,
564		) -> list[Callable]:
565			# Map from `is_conn` to the tokens which represent connections and walls
566			conn_token_map: dict[bool, str] = {
567				True: VOCAB.CONNECTOR,
568				False: VOCAB.ADJLIST_WALL,
569			}
570			return [
571				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
572				lambda i: conn_token_map[is_conn[i]],
573				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
574			]

Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.

AdjListTokenizers.AdjListCoord( *, pre: bool = False, post: bool = True, shuffle_d0: bool = True, edge_grouping: maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping = EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset: maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset = EdgeSubsets.ConnectionEdges(walls=False), edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.RandomCoords(), _type_: Literal["<class 'AdjListTokenizers.AdjListCoord'>"] = "<class 'AdjListTokenizers.AdjListCoord'>")
edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.RandomCoords()
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns this instance as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

@serializable_dataclass(frozen=True, kw_only=True)
class AdjListTokenizers.AdjListCardinal(AdjListTokenizers._AdjListTokenizer):
576	@serializable_dataclass(frozen=True, kw_only=True)
577	class AdjListCardinal(_AdjListTokenizer):
578		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
579
580		# Parameters
581		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
582		"""
583
584		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
585			default=EdgePermuters.BothCoords(),
586			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
587		)
588
589		def _tokenization_callables(
590			self,
591			edges: ConnectionArray,
592			is_conn: Bool[np.ndarray, " edges"],
593			coord_tokenizer: CoordTokenizers._CoordTokenizer,
594			*args,
595			**kwargs,
596		) -> list[Callable]:
597			# Map from `is_conn` to the tokens which represent connections and walls
598			conn_token_map: dict[bool, str] = {
599				True: VOCAB.CONNECTOR,
600				False: VOCAB.ADJLIST_WALL,
601			}
602			return [
603				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
604				lambda i: conn_token_map[is_conn[i]],
605				lambda i: get_cardinal_direction(edges[i]),
606			]

Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

Parameters

  • coord_first: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens. (Note: the constructor signature of this class does not currently expose a `coord_first` field.)
AdjListTokenizers.AdjListCardinal( *, pre: bool = False, post: bool = True, shuffle_d0: bool = True, edge_grouping: maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping = EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset: maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset = EdgeSubsets.ConnectionEdges(walls=False), edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.BothCoords(), _type_: Literal["<class 'AdjListTokenizers.AdjListCardinal'>"] = "<class 'AdjListTokenizers.AdjListCardinal'>")
edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.BothCoords()
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns this instance as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

609class TargetTokenizers(__TokenizerElementNamespace):
610	"""Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
611
612	key = "target_tokenizer"
613
614	@serializable_dataclass(frozen=True, kw_only=True)
615	class _TargetTokenizer(_TokenizerElement, abc.ABC):
616		"""Superclass of tokenizers for maze targets."""
617
618		@abc.abstractmethod
619		def to_tokens(
620			self,
621			targets: Sequence[Coord],
622			coord_tokenizer: CoordTokenizers._CoordTokenizer,
623		) -> list[str]:
624			"""Returns tokens representing the target."""
625			pass
626
627		@classmethod
628		def attribute_key(cls) -> str:
629			return TargetTokenizers.key
630
631	@serializable_dataclass(frozen=True, kw_only=True)
632	class Unlabeled(_TargetTokenizer):
633		"""Targets are simply listed as coord tokens.
634
635		- `post`: Whether all coords include an integral following delimiter token
636		"""
637
638		post: bool = serializable_field(default=False)
639
640		# inherit docstring
641		def to_tokens(  # noqa: D102
642			self,
643			targets: Sequence[Coord],
644			coord_tokenizer: CoordTokenizers._CoordTokenizer,
645		) -> list[str]:
646			return list(
647				flatten(
648					[
649						[
650							*coord_tokenizer.to_tokens(target),
651							*empty_sequence_if_attr_false(
652								[VOCAB.TARGET_POST],
653								self,
654								"post",
655							),
656						]
657						for target in targets
658					],
659				),
660			)
661
662		# inherit docstring
663		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
664			# No invalid instances possible within data member type hint bounds
665			return True

Namespace for _TargetTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'target_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class TargetTokenizers.Unlabeled(TargetTokenizers._TargetTokenizer):
631	@serializable_dataclass(frozen=True, kw_only=True)
632	class Unlabeled(_TargetTokenizer):
633		"""Targets are simply listed as coord tokens.
634
635		- `post`: Whether all coords include an integral following delimiter token
636		"""
637
638		post: bool = serializable_field(default=False)
639
640		# inherit docstring
641		def to_tokens(  # noqa: D102
642			self,
643			targets: Sequence[Coord],
644			coord_tokenizer: CoordTokenizers._CoordTokenizer,
645		) -> list[str]:
646			return list(
647				flatten(
648					[
649						[
650							*coord_tokenizer.to_tokens(target),
651							*empty_sequence_if_attr_false(
652								[VOCAB.TARGET_POST],
653								self,
654								"post",
655							),
656						]
657						for target in targets
658					],
659				),
660			)
661
662		# inherit docstring
663		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
664			# No invalid instances possible within data member type hint bounds
665			return True

Targets are simply listed as coord tokens.

  • post: Whether all coords include an integral following delimiter token
TargetTokenizers.Unlabeled( *, _type_: Literal["<class 'TargetTokenizers.Unlabeled'>"] = "<class 'TargetTokenizers.Unlabeled'>", post: bool = False)
post: bool = False
def to_tokens( self, targets: Sequence[jaxtyping.Int8[ndarray, 'row_col=2']], coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:
641		def to_tokens(  # noqa: D102
642			self,
643			targets: Sequence[Coord],
644			coord_tokenizer: CoordTokenizers._CoordTokenizer,
645		) -> list[str]:
646			return list(
647				flatten(
648					[
649						[
650							*coord_tokenizer.to_tokens(target),
651							*empty_sequence_if_attr_false(
652								[VOCAB.TARGET_POST],
653								self,
654								"post",
655							),
656						]
657						for target in targets
658					],
659				),
660			)

Returns tokens representing the target.

def is_valid(self, do_except: bool = False) -> bool:
663		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
664			# No invalid instances possible within data member type hint bounds
665			return True

Returns if self contains data members capable of producing an overall valid MazeTokenizerModular.

Some _TokenizerElement instances may be created which are not useful despite obeying data member type hints. is_valid allows for more precise detection of invalid _TokenizerElements beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return True for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment explaining the reason for invalidity and classifying it as one of the types below. Invalidity types, in ascending order of severity:

  • Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., _TokenizerElements which are strictly worse than some alternative.
  • Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
  • Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
  • Erroneous: These tokenizers might raise exceptions during use.

Development

is_valid is implemented to always return True in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call is_valid on nested _TokenizerElements contained in the class. In other words, failures of is_valid need not bubble up to the top of the nested _TokenizerElement tree. MazeTokenizerModular.is_valid calls is_valid on each of its _TokenizerElements individually, so failure at any level will be detected.

Types of Invalidity

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid _TokenizerElement instances.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns this instance as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class. Implemented by the `@serializable_dataclass` decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.

Inherited Members
TargetTokenizers._TargetTokenizer
attribute_key
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
668class StepSizes(__TokenizerElementNamespace):
669	"""Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""
670
671	key = "step_size"
672
673	@serializable_dataclass(frozen=True, kw_only=True)
674	class _StepSize(_TokenizerElement, abc.ABC):
675		"""Specifies which coords in `maze.solution` are used to represent the path."""
676
677		@classmethod
678		def attribute_key(cls) -> str:
679			return StepSizes.key
680
681		@abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
682		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
683			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
684			raise NotImplementedError(
685				"Subclasses must implement `StepSize.step_indices.",
686			)
687
688		def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
689			"""Returns steps as tuples of starting and ending positions for each step."""
690			indices: list[int] = self._step_single_indices(maze)
691			# TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
692			return [
693				(start, end)
694				for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
695			]
696
697		def is_valid(self, do_except: bool = False) -> bool:
698			# No invalid instances possible within data member type hint bounds
699			return True
700
701	@serializable_dataclass(frozen=True, kw_only=True)
702	class Singles(_StepSize):
703		"""Every coord in `maze.solution` is represented.
704
705		Legacy tokenizers all use this behavior.
706		"""
707
708		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
709			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
710			return list(range(maze.solution.shape[0]))
711
712	@serializable_dataclass(frozen=True, kw_only=True)
713	@mark_as_unsupported(_unsupported_is_invalid)
714	class Straightaways(_StepSize):
715		"""Only coords where the path turns are represented in the path.
716
717		I.e., the path is represented as a sequence of straightaways,
718		specified by the coords at the turns.
719		"""
720
721		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
722			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
723			last_turn_coord: Coord = maze.solution[0, ...]
724			indices: list[int] = [0]
725			for i, coord in enumerate(maze.solution):
726				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
727					indices.append(i - 1)
728					last_turn_coord = maze.solution[i - 1, ...]
729			indices.append(i)
730			return indices
731
732	@serializable_dataclass(frozen=True, kw_only=True)
733	class Forks(_StepSize):
734		"""Only coords at forks, where the path has >=2 options for the next step are included.
735
736		Excludes the option of backtracking.
737		The starting and ending coords are always included.
738		"""
739
740		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
741			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
742			return maze.get_solution_forking_points(always_include_endpoints=True)[0]
743
744	@serializable_dataclass(frozen=True, kw_only=True)
745	@mark_as_unsupported(_unsupported_is_invalid)
746	class ForksAndStraightaways(_StepSize):
747		"""Includes the union of the coords included by `Forks` and `Straightaways`.
748
749		See documentation for those classes for details.
750		"""
751
752		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
753			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
754			return list(
755				np.unique(
756					np.concatenate(
757						(
758							StepSizes.Straightaways()._step_single_indices(maze),
759							StepSizes.Forks()._step_single_indices(maze),
760						),
761					),
762				),
763			)

Namespace for the `_StepSize` subclass hierarchy used by `MazeTokenizerModular`.

key = 'step_size'
@serializable_dataclass(frozen=True, kw_only=True)
class StepSizes.Singles(StepSizes._StepSize):
701	@serializable_dataclass(frozen=True, kw_only=True)
702	class Singles(_StepSize):
703		"""Every coord in `maze.solution` is represented.
704
705		Legacy tokenizers all use this behavior.
706		"""
707
708		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
709			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
710			return list(range(maze.solution.shape[0]))

Every coord in `maze.solution` is represented.

Legacy tokenizers all use this behavior.

StepSizes.Singles( *, _type_: Literal["<class 'StepSizes.Singles'>"] = "<class 'StepSizes.Singles'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict. Implemented by the `@serializable_dataclass` decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class. Implemented by the `@serializable_dataclass` decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.

Inherited Members
StepSizes._StepSize
attribute_key
step_start_end_indices
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class StepSizes.Straightaways(StepSizes._StepSize):
712	@serializable_dataclass(frozen=True, kw_only=True)
713	@mark_as_unsupported(_unsupported_is_invalid)
714	class Straightaways(_StepSize):
715		"""Only coords where the path turns are represented in the path.
716
717		I.e., the path is represented as a sequence of straightaways,
718		specified by the coords at the turns.
719		"""
720
721		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
722			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
723			last_turn_coord: Coord = maze.solution[0, ...]
724			indices: list[int] = [0]
725			for i, coord in enumerate(maze.solution):
726				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
727					indices.append(i - 1)
728					last_turn_coord = maze.solution[i - 1, ...]
729			indices.append(i)
730			return indices

Only coords where the path turns are represented in the path.

I.e., the path is represented as a sequence of straightaways, specified by the coords at the turns.

StepSizes.Straightaways( *, _type_: Literal["<class 'StepSizes.Straightaways'>"] = "<class 'StepSizes.Straightaways'>")
def is_valid(self, do_except: bool = False) -> bool:
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
258	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
259	if do_except:
260		err_msg: str = (
261			f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
262			f"{type(self) = }, {self = }"
263		)
264		raise ValueError(err_msg)
265
266	return False

Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict. Implemented by the `@serializable_dataclass` decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class. Implemented by the `@serializable_dataclass` decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.

Inherited Members
StepSizes._StepSize
attribute_key
step_start_end_indices
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepSizes.Forks(StepSizes._StepSize):
732	@serializable_dataclass(frozen=True, kw_only=True)
733	class Forks(_StepSize):
734		"""Only coords at forks, where the path has >=2 options for the next step are included.
735
736		Excludes the option of backtracking.
737		The starting and ending coords are always included.
738		"""
739
740		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
741			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
742			return maze.get_solution_forking_points(always_include_endpoints=True)[0]

Only coords at forks, where the path has >=2 options for the next step, are included.

Excludes the option of backtracking. The starting and ending coords are always included.

StepSizes.Forks( *, _type_: Literal["<class 'StepSizes.Forks'>"] = "<class 'StepSizes.Forks'>")
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator. If the input is already an instance of the class, it is returned unchanged.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass, calling SerializableDataclass__validate_field_type for each field. Returns True only if every field is valid.

Inherited Members
StepSizes._StepSize
attribute_key
step_start_end_indices
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class StepSizes.ForksAndStraightaways(StepSizes._StepSize):
744	@serializable_dataclass(frozen=True, kw_only=True)
745	@mark_as_unsupported(_unsupported_is_invalid)
746	class ForksAndStraightaways(_StepSize):
747		"""Includes the union of the coords included by `Forks` and `Straightaways`.
748
749		See documentation for those classes for details.
750		"""
751
752		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
753			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
754			return list(
755				np.unique(
756					np.concatenate(
757						(
758							StepSizes.Straightaways()._step_single_indices(maze),
759							StepSizes.Forks()._step_single_indices(maze),
760						),
761					),
762				),
763			)

Includes the union of the coords included by Forks and Straightaways.

See documentation for those classes for details.

StepSizes.ForksAndStraightaways( *, _type_: Literal["<class 'StepSizes.ForksAndStraightaways'>"] = "<class 'StepSizes.ForksAndStraightaways'>")
def is_valid(self, do_except: bool = False) -> bool:
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
258	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
259	if do_except:
260		err_msg: str = (
261			f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
262			f"{type(self) = }, {self = }"
263		)
264		raise ValueError(err_msg)
265
266	return False

Default implementation of is_valid for mark_as_unsupported-decorated classes. Always returns False; if do_except is True, it raises a ValueError instead.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator. If the input is already an instance of the class, it is returned unchanged.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass, calling SerializableDataclass__validate_field_type for each field. Returns True only if every field is valid.

Inherited Members
StepSizes._StepSize
attribute_key
step_start_end_indices
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
766class StepTokenizers(__TokenizerElementNamespace):
767	"""Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
768
769	key = "step_tokenizers"
770
771	@serializable_dataclass(frozen=True, kw_only=True)
772	class _StepTokenizer(_TokenizerElement, abc.ABC):
773		"""Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""
774
775		@classmethod
776		def attribute_key(cls) -> str:
777			return StepTokenizers.key
778
779		@abc.abstractmethod
780		def to_tokens(
781			self,
782			maze: SolvedMaze,
783			start_index: int,
784			end_index: int,
785			**kwargs,
786		) -> list[str]:
787			"""Tokenizes a single step in the solution.
788
789			# Parameters
790			- `maze`: Maze to be tokenized
791			- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
792			- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
793			"""
794			raise NotImplementedError(
795				"Subclasses must implement `StepTokenizer.to_tokens.",
796			)
797
798		def is_valid(self, do_except: bool = False) -> bool:
799			# No invalid instances possible within data member type hint bounds
800			return True
801
802	@serializable_dataclass(frozen=True, kw_only=True)
803	class Coord(_StepTokenizer):
804		"""A direct tokenization of the end position coord represents the step."""
805
806		# inherit docstring
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
815
816	@serializable_dataclass(frozen=True, kw_only=True)
817	class Cardinal(_StepTokenizer):
818		"""A step is tokenized with a cardinal direction token.
819
820		It is the direction of the step from the starting position along the solution.
821		"""
822
823		# inherit docstring
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]
834
835	@serializable_dataclass(frozen=True, kw_only=True)
836	class Relative(_StepTokenizer):
837		"""Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
838
839		To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
840		Similarly to `Cardinal`, the direction is that of the step from the starting position.
841		"""
842
843		# inherit docstring
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]
870
871	@serializable_dataclass(frozen=True, kw_only=True)
872	class Distance(_StepTokenizer):
873		"""A count of the number of individual steps from the starting point to the end point.
874
875		Contains no information about directionality, only the distance traveled in the step.
876		`Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
877		This constraint is enforced in `_PathTokenizer.is_valid`.
878		"""
879
880		# inherit docstring
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]
890
891	"""
892	`StepTokenizerPermutation`
893	A sequence of unique `_StepTokenizer`s.
894	This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
895	"""
896	StepTokenizerPermutation: type = (
897		tuple[_StepTokenizer]
898		| tuple[_StepTokenizer, _StepTokenizer]
899		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
900		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
901	)

Namespace for _StepTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'step_tokenizers'
StepTokenizerPermutation: type = tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer]
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Coord(StepTokenizers._StepTokenizer):
802	@serializable_dataclass(frozen=True, kw_only=True)
803	class Coord(_StepTokenizer):
804		"""A direct tokenization of the end position coord represents the step."""
805
806		# inherit docstring
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

The step is represented by a direct tokenization of the coord at its end position.

StepTokenizers.Coord( *, _type_: Literal["<class 'StepTokenizers.Coord'>"] = "<class 'StepTokenizers.Coord'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
  • coord_tokenizer: The _CoordTokenizer used to tokenize the coord at end_index
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.

Inherited Members
StepTokenizers._StepTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Cardinal(StepTokenizers._StepTokenizer):
816	@serializable_dataclass(frozen=True, kw_only=True)
817	class Cardinal(_StepTokenizer):
818		"""A step is tokenized with a cardinal direction token.
819
820		It is the direction of the step from the starting position along the solution.
821		"""
822
823		# inherit docstring
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]

A step is tokenized with a cardinal direction token.

It is the direction of the step from the starting position along the solution.

StepTokenizers.Cardinal( *, _type_: Literal["<class 'StepTokenizers.Cardinal'>"] = "<class 'StepTokenizers.Cardinal'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.

Inherited Members
StepTokenizers._StepTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Relative(StepTokenizers._StepTokenizer):
835	@serializable_dataclass(frozen=True, kw_only=True)
836	class Relative(_StepTokenizer):
837		"""Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
838
839		To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
840		Similarly to `Cardinal`, the direction is that of the step from the starting position.
841		"""
842
843		# inherit docstring
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]

Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

To resolve the ambiguity at the start of a solution (where there is no previous step), the "agent" solving the maze is assumed to be facing NORTH. As with Cardinal, the direction is that of the step from the starting position.

StepTokenizers.Relative( *, _type_: Literal["<class 'StepTokenizers.Relative'>"] = "<class 'StepTokenizers.Relative'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.

Inherited Members
StepTokenizers._StepTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Distance(StepTokenizers._StepTokenizer):
871	@serializable_dataclass(frozen=True, kw_only=True)
872	class Distance(_StepTokenizer):
873		"""A count of the number of individual steps from the starting point to the end point.
874
875		Contains no information about directionality, only the distance traveled in the step.
876		`Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
877		This constraint is enforced in `_PathTokenizer.is_valid`.
878		"""
879
880		# inherit docstring
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]

A count of the number of individual steps from the starting point to the end point.

Contains no information about directionality, only the distance traveled in the step. Distance must be combined with at least one other _StepTokenizer in a StepTokenizerPermutation. This constraint is enforced in _PathTokenizer.is_valid.

StepTokenizers.Distance( *, _type_: Literal["<class 'StepTokenizers.Distance'>"] = "<class 'StepTokenizers.Distance'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validates the types of all the fields on a SerializableDataclass. Calls SerializableDataclass__validate_field_type for each field.

Inherited Members
StepTokenizers._StepTokenizer
attribute_key
is_valid
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
 904class PathTokenizers(__TokenizerElementNamespace):
 905	"""Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 906
 907	key = "path_tokenizer"
 908
 909	@serializable_dataclass(frozen=True, kw_only=True)
 910	class _PathTokenizer(_TokenizerElement, abc.ABC):
 911		"""Superclass of tokenizers for maze solution paths."""
 912
 913		@abc.abstractmethod
 914		def to_tokens(
 915			self,
 916			maze: SolvedMaze,
 917			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 918		) -> list[str]:
 919			"""Returns tokens representing the solution path."""
 920			pass
 921
 922		@classmethod
 923		def attribute_key(cls) -> str:
 924			return PathTokenizers.key
 925
 926	@serializable_dataclass(frozen=True, kw_only=True)
 927	class StepSequence(_PathTokenizer, abc.ABC):
 928		"""Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.
 929
 930		Allows for a sequence of leading and trailing tokens which don't fit the step pattern.
 931
 932		# Parameters
 933		- `step_size`: Selects the size of a single step in the sequence
 934		- `step_tokenizers`: Selects the combination and permutation of tokens
 935		- `pre`: Whether all steps include an integral preceding delimiter token
 936		- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
 937		- `post`: Whether all steps include an integral following delimiter token
 938		"""
 939
 940		step_size: StepSizes._StepSize = serializable_field(
 941			default=StepSizes.Singles(),
 942			loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
 943		)
 944		step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
 945			default=(StepTokenizers.Coord(),),
 946			serialization_fn=lambda x: [y.serialize() for y in x],
 947			loading_fn=lambda x: tuple(x[StepTokenizers.key]),
 948		)
 949		pre: bool = serializable_field(default=False)
 950		intra: bool = serializable_field(default=False)
 951		post: bool = serializable_field(default=False)
 952
 953		# inherit docstring
 954		def to_tokens(  # noqa: D102
 955			self,
 956			maze: SolvedMaze,
 957			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 958		) -> list[str]:
 959			return [
 960				*self._leading_tokens(maze, coord_tokenizer),
 961				*flatten(
 962					[
 963						self._single_step_tokens(maze, start, end, coord_tokenizer)
 964						for start, end in self.step_size.step_start_end_indices(maze)
 965					],
 966				),
 967				*self._trailing_tokens(maze, coord_tokenizer),
 968			]
 969
 970		def _single_step_tokens(
 971			self,
 972			maze: SolvedMaze,
 973			i: int,
 974			j: int,
 975			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 976		) -> list[str]:
 977			"""Returns the token sequence representing a single step along the path."""
 978			step_rep_tokens: list[list[str]] = [
 979				step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
 980				for step_tokenizer in self.step_tokenizers
 981			]
 982			if self.intra:
 983				step_rep_tokens_and_intra: list[str] = [None] * (
 984					len(step_rep_tokens) * 2
 985				)
 986				step_rep_tokens_and_intra[::2] = step_rep_tokens
 987				step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
 988					step_rep_tokens,
 989				)
 990				step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
 991			all_tokens: list[str] = [
 992				*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
 993				*flatten(step_rep_tokens),
 994				*empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
 995			]
 996			return all_tokens
 997
 998		def _leading_tokens(
 999			self,
1000			maze: SolvedMaze,
1001			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1002		) -> list[str]:
1003			"""Returns tokens preceding those from the sequence from `_single_step_tokens`.
1004
1005			Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
1006			<PATH_START> should NOT be included.
1007			"""
1008			if StepTokenizers.Coord() in self.step_tokenizers:
1009				return [
1010					*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
1011					*coord_tokenizer.to_tokens(maze.solution[0, ...]),
1012					*empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
1013				]
1014			return []
1015
1016		def _trailing_tokens(
1017			self,
1018			c: Coord,
1019			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1020		) -> list[str]:
1021			"""Returns tokens following those from the sequence from `_single_step_tokens`.
1022
1023			<PATH_END> should NOT be included.
1024			"""
1025			return []
1026
1027		# inherits docstring
1028		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
1029			output: bool
1030
1031			if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
1032				# Uninteresting: repeated elements are not useful
1033				output = False
1034			else:
1035				# we do noqa for the comment if false
1036				if len(self.step_tokenizers) == 1 and isinstance(
1037					self.step_tokenizers[0],
1038					StepTokenizers.Distance,
1039				):
1040					# Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
1041					output = False
1042				else:
1043					output = True
1044
1045			if not output and do_except:
1046				raise ValueError(
1047					"PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
1048				)
1049
1050			return output

Namespace for _PathTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'path_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class PathTokenizers.StepSequence(PathTokenizers._PathTokenizer, abc.ABC):
 926	@serializable_dataclass(frozen=True, kw_only=True)
 927	class StepSequence(_PathTokenizer, abc.ABC):
 928		"""Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.
 929
 930		Allows for a sequence of leading and trailing tokens which don't fit the step pattern.
 931
 932		# Parameters
 933		- `step_size`: Selects the size of a single step in the sequence
 934		- `step_tokenizers`: Selects the combination and permutation of tokens
 935		- `pre`: Whether all steps include an integral preceding delimiter token
 936		- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
 937		- `post`: Whether all steps include an integral following delimiter token
 938		"""
 939
 940		step_size: StepSizes._StepSize = serializable_field(
 941			default=StepSizes.Singles(),
 942			loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
 943		)
 944		step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
 945			default=(StepTokenizers.Coord(),),
 946			serialization_fn=lambda x: [y.serialize() for y in x],
 947			loading_fn=lambda x: tuple(x[StepTokenizers.key]),
 948		)
 949		pre: bool = serializable_field(default=False)
 950		intra: bool = serializable_field(default=False)
 951		post: bool = serializable_field(default=False)
 952
 953		# inherit docstring
 954		def to_tokens(  # noqa: D102
 955			self,
 956			maze: SolvedMaze,
 957			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 958		) -> list[str]:
 959			return [
 960				*self._leading_tokens(maze, coord_tokenizer),
 961				*flatten(
 962					[
 963						self._single_step_tokens(maze, start, end, coord_tokenizer)
 964						for start, end in self.step_size.step_start_end_indices(maze)
 965					],
 966				),
 967				*self._trailing_tokens(maze, coord_tokenizer),
 968			]
 969
 970		def _single_step_tokens(
 971			self,
 972			maze: SolvedMaze,
 973			i: int,
 974			j: int,
 975			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 976		) -> list[str]:
 977			"""Returns the token sequence representing a single step along the path."""
 978			step_rep_tokens: list[list[str]] = [
 979				step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
 980				for step_tokenizer in self.step_tokenizers
 981			]
 982			if self.intra:
 983				step_rep_tokens_and_intra: list[str] = [None] * (
 984					len(step_rep_tokens) * 2
 985				)
 986				step_rep_tokens_and_intra[::2] = step_rep_tokens
 987				step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
 988					step_rep_tokens,
 989				)
 990				step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
 991			all_tokens: list[str] = [
 992				*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
 993				*flatten(step_rep_tokens),
 994				*empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
 995			]
 996			return all_tokens
 997
 998		def _leading_tokens(
 999			self,
1000			maze: SolvedMaze,
1001			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1002		) -> list[str]:
1003			"""Returns tokens preceding those from the sequence from `_single_step_tokens`.
1004
1005			Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
1006			<PATH_START> should NOT be included.
1007			"""
1008			if StepTokenizers.Coord() in self.step_tokenizers:
1009				return [
1010					*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
1011					*coord_tokenizer.to_tokens(maze.solution[0, ...]),
1012					*empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
1013				]
1014			return []
1015
1016		def _trailing_tokens(
1017			self,
1018			c: Coord,
1019			coord_tokenizer: CoordTokenizers._CoordTokenizer,
1020		) -> list[str]:
1021			"""Returns tokens following those from the sequence from `_single_step_tokens`.
1022
1023			<PATH_END> should NOT be included.
1024			"""
1025			return []
1026
1027		# inherits docstring
1028		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
1029			output: bool
1030
1031			if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
1032				# Uninteresting: repeated elements are not useful
1033				output = False
1034			else:
1035				# we do noqa for the comment if false
1036				if len(self.step_tokenizers) == 1 and isinstance(
1037					self.step_tokenizers[0],
1038					StepTokenizers.Distance,
1039				):
1040					# Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
1041					output = False
1042				else:
1043					output = True
1044
1045			if not output and do_except:
1046				raise ValueError(
1047					"PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
1048				)
1049
1050			return output

Any PathTokenizer where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

Parameters

  • step_size: Selects the size of a single step in the sequence
  • step_tokenizers: Selects the combination and permutation of tokens
  • pre: Whether all steps include an integral preceding delimiter token
  • intra: Whether all steps include a delimiter token after each individual _StepTokenizer tokenization.
  • post: Whether all steps include an integral following delimiter token
PathTokenizers.StepSequence( *, _type_: Literal["<class 'PathTokenizers.StepSequence'>"] = "<class 'PathTokenizers.StepSequence'>", step_size: maze_dataset.tokenization.modular.elements.StepSizes._StepSize = StepSizes.Singles(), step_tokenizers: tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] = (StepTokenizers.Coord(),), pre: bool = False, intra: bool = False, post: bool = False)
step_size: maze_dataset.tokenization.modular.elements.StepSizes._StepSize = StepSizes.Singles()
step_tokenizers: tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] = (StepTokenizers.Coord(),)
pre: bool = False
intra: bool = False
post: bool = False
def to_tokens( self, maze: maze_dataset.SolvedMaze, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:
954		def to_tokens(  # noqa: D102
955			self,
956			maze: SolvedMaze,
957			coord_tokenizer: CoordTokenizers._CoordTokenizer,
958		) -> list[str]:
959			return [
960				*self._leading_tokens(maze, coord_tokenizer),
961				*flatten(
962					[
963						self._single_step_tokens(maze, start, end, coord_tokenizer)
964						for start, end in self.step_size.step_start_end_indices(maze)
965					],
966				),
967				*self._trailing_tokens(maze, coord_tokenizer),
968			]

Returns tokens representing the solution path.

def is_valid(self, do_except: bool = False) -> bool:
1028		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
1029			output: bool
1030
1031			if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
1032				# Uninteresting: repeated elements are not useful
1033				output = False
1034			else:
1035				# we do noqa for the comment if false
1036				if len(self.step_tokenizers) == 1 and isinstance(
1037					self.step_tokenizers[0],
1038					StepTokenizers.Distance,
1039				):
1040					# Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
1041					output = False
1042				else:
1043					output = True
1044
1045			if not output and do_except:
1046				raise ValueError(
1047					"PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
1048				)
1049
1050			return output

Returns whether self contains data members capable of producing an overall valid MazeTokenizerModular.

Some _TokenizerElement instances may be created which are not useful despite obeying data member type hints. is_valid allows for more precise detection of invalid _TokenizerElements beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return True for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:

  • Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., _TokenizerElements which are strictly worse than some alternative.
  • Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
  • Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
  • Erroneous: These tokenizers might raise exceptions during use.

Development

is_valid is implemented to always return True in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call is_valid on nested _TokenizerElements contained in the class. In other words, failures of is_valid need not bubble up to the top of the nested _TokenizerElement tree. MazeTokenizerModular.is_valid calls is_valid on each of its _TokenizerElements individually, so failure at any level will be detected.

Future Work: Types of Invalidity

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid _TokenizerElement instances.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance as a dict; implemented by the @serializable_dataclass decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.

Inherited Members
PathTokenizers._PathTokenizer
attribute_key
maze_dataset.tokenization.modular.element_base._TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
1053class PromptSequencers(__TokenizerElementNamespace):
1054	"""Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""
1055
1056	key = "prompt_sequencer"
1057
1058	@serializable_dataclass(frozen=True, kw_only=True)
1059	class _PromptSequencer(_TokenizerElement, abc.ABC):
1060		"""Sequences token regions into a complete maze tokenization.
1061
1062		# Parameters
1063		- `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
1064		- `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
1065		Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
1066		"""
1067
1068		coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
1069			default=CoordTokenizers.UT(),
1070			loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
1071		)
1072		adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
1073			default=AdjListTokenizers.AdjListCoord(),
1074			loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
1075		)
1076
1077		@classmethod
1078		def attribute_key(cls) -> str:
1079			return PromptSequencers.key
1080
1081		@staticmethod
1082		def _trim_if_unsolved_maze(
1083			untrimmed: list[str],
1084			is_untargeted: bool = False,
1085			is_unsolved: bool = False,
1086		) -> list[str]:
1087			"""Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.
1088
1089			# Development
1090			This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
1091			It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
1092			"""
1093			if is_untargeted:
1094				return tokens_between(
1095					untrimmed,
1096					VOCAB.ADJLIST_START,
1097					VOCAB.ADJLIST_END,
1098					include_start=True,
1099					include_end=True,
1100				)
1101			if is_unsolved:
1102				if VOCAB.TARGET_END in untrimmed:
1103					return tokens_between(
1104						untrimmed,
1105						VOCAB.ADJLIST_START,
1106						VOCAB.TARGET_END,
1107						include_start=True,
1108						include_end=True,
1109					)
1110				else:
1111					return tokens_between(
1112						untrimmed,
1113						VOCAB.ADJLIST_START,
1114						VOCAB.ORIGIN_END,
1115						include_start=True,
1116						include_end=True,
1117					)
1118			return untrimmed
1119
1120		def to_tokens(
1121			self,
1122			maze: LatticeMaze,
1123			*args,
1124			**kwargs,
1125		) -> list[str]:
1126			"""Returns a complete list of tokens for a given set of maze elements."""
1127			untrimmed: list[str] = self._sequence_tokens(
1128				*self._get_prompt_regions(maze),
1129			)
1130			return self._trim_if_unsolved_maze(
1131				untrimmed,
1132				not hasattr(maze, "start_pos"),
1133				not hasattr(maze, "solution"),
1134			)
1135
1136		def _get_prompt_regions(
1137			self,
1138			maze: LatticeMaze,
1139			*args,
1140			**kwargs,
1141		) -> list[list[str]]:
1142			"""Gets the prompt regions of a maze in a fixed sequence.
1143
1144			This method is NOT responsible for including/excluding any prompt regions.
1145			Always return according to the API described under Returns.
1146			This implementation is expected to be suitable for most `PromptSequencer` subclasses.
1147			Subclasses may override this method if needed for special behavior.
1148
1149			# Returns
1150			- [0]: list[str] Adjacency list tokens
1151			- [1]: list[str] Origin tokens
1152			- [2]: list[str] Target tokens
1153			- [3]: list[str] Path tokens
1154
1155			# `None`-valued Args
1156			If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
1157			To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables.
1158			"""
1159			origin: Coord | None = getattr(maze, "start_pos", None)
1160			target: list[Coord] | None = [
1161				getattr(maze, "end_pos", None),
1162			]  # TargetTokenizer requires target: Sequence[Coord]
1163
1164			return [
1165				(
1166					self.adj_list_tokenizer.to_tokens(
1167						maze,
1168						coord_tokenizer=self.coord_tokenizer,
1169					)
1170					if hasattr(self, "adj_list_tokenizer")
1171					else []
1172				),
1173				self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
1174				(
1175					self.target_tokenizer.to_tokens(
1176						target,
1177						coord_tokenizer=self.coord_tokenizer,
1178					)
1179					if target[0] is not None and hasattr(self, "target_tokenizer")
1180					else []
1181				),
1182				(
1183					self.path_tokenizer.to_tokens(
1184						maze,
1185						coord_tokenizer=self.coord_tokenizer,
1186					)
1187					if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
1188					else []
1189				),
1190			]
1191
1192		@abc.abstractmethod
1193		def _sequence_tokens(
1194			self,
1195			adj_list: list[str],
1196			origin: list[str] | None,
1197			target: list[str] | None,
1198			path: list[str] | None,
1199		) -> list[str]:
1200			"""Sequences token regions into a complete prompt.
1201
1202			Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.
1203
1204			# Parameters
1205			- `adj_list`: Tokens representing the adjacency list
1206			- `origin`: Tokens representing the origin
1207			- `target`: Tokens representing the target
1208			- `path`: Tokens representing the path
1209			"""
1210			pass
1211
1212		def is_valid(self, do_except: bool = False) -> bool:
1213			# No invalid instances possible within data member type hint bounds
1214			return True
1215
1216	@serializable_dataclass(frozen=True, kw_only=True)
1217	class AOTP(_PromptSequencer):
1218		"""Sequences a prompt as [adjacency list, origin, target, path].
1219
1220		# Parameters
1221		- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
1222		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
1223		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1224		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1225
1226		"""
1227
1228		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
1229			default=TargetTokenizers.Unlabeled(),
1230			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
1231		)
1232		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1233			default=PathTokenizers.StepSequence(),
1234			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1235		)
1236
1237		def _sequence_tokens(
1238			self,
1239			adj_list: list[str],
1240			origin: list[str],
1241			target: list[str],
1242			path: list[str],
1243		) -> list[str]:
1244			return [
1245				VOCAB.ADJLIST_START,
1246				*adj_list,
1247				VOCAB.ADJLIST_END,
1248				VOCAB.ORIGIN_START,
1249				*origin,
1250				VOCAB.ORIGIN_END,
1251				VOCAB.TARGET_START,
1252				*target,
1253				VOCAB.TARGET_END,
1254				VOCAB.PATH_START,
1255				*path,
1256				VOCAB.PATH_END,
1257			]
1258
1259	@serializable_dataclass(frozen=True, kw_only=True)
1260	class AOP(_PromptSequencer):
1261		"""Sequences a prompt as [adjacency list, origin, path].
1262
1263		Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
1264
1265		# Parameters
1266		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1267		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1268		"""
1269
1270		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1271			default=PathTokenizers.StepSequence(),
1272			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1273		)
1274
1275		def _sequence_tokens(
1276			self,
1277			adj_list: list[str],
1278			origin: list[str],
1279			# explicitly no target in this tokenizer
1280			target: list[str],
1281			path: list[str],
1282		) -> list[str]:
1283			return [
1284				VOCAB.ADJLIST_START,
1285				*adj_list,
1286				VOCAB.ADJLIST_END,
1287				VOCAB.ORIGIN_START,
1288				*origin,
1289				VOCAB.ORIGIN_END,
1290				VOCAB.TARGET_START,
1291				VOCAB.TARGET_END,
1292				VOCAB.PATH_START,
1293				*path,
1294				VOCAB.PATH_END,
1295			]

Namespace for the `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.

key = 'prompt_sequencer'
@serializable_dataclass(frozen=True, kw_only=True)
class PromptSequencers.AOTP(PromptSequencers._PromptSequencer):
1216	@serializable_dataclass(frozen=True, kw_only=True)
1217	class AOTP(_PromptSequencer):
1218		"""Sequences a prompt as [adjacency list, origin, target, path].
1219
1220		# Parameters
1221		- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
1222		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
1223		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1224		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1225
1226		"""
1227
1228		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
1229			default=TargetTokenizers.Unlabeled(),
1230			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
1231		)
1232		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1233			default=PathTokenizers.StepSequence(),
1234			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1235		)
1236
1237		def _sequence_tokens(
1238			self,
1239			adj_list: list[str],
1240			origin: list[str],
1241			target: list[str],
1242			path: list[str],
1243		) -> list[str]:
1244			return [
1245				VOCAB.ADJLIST_START,
1246				*adj_list,
1247				VOCAB.ADJLIST_END,
1248				VOCAB.ORIGIN_START,
1249				*origin,
1250				VOCAB.ORIGIN_END,
1251				VOCAB.TARGET_START,
1252				*target,
1253				VOCAB.TARGET_END,
1254				VOCAB.PATH_START,
1255				*path,
1256				VOCAB.PATH_END,
1257			]

Sequences a prompt as [adjacency list, origin, target, path].

Parameters

  • `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
  • `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
PromptSequencers.AOTP( *, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer = CoordTokenizers.UT(), adj_list_tokenizer: maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer = AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), _type_: Literal["<class 'PromptSequencers.AOTP'>"] = "<class 'PromptSequencers.AOTP'>", target_tokenizer: maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer = TargetTokenizers.Unlabeled(post=False), path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))
target_tokenizer: maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer = TargetTokenizers.Unlabeled(post=False)
path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the `@serializable_dataclass` decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the `@serializable_dataclass` decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.

@serializable_dataclass(frozen=True, kw_only=True)
class PromptSequencers.AOP(PromptSequencers._PromptSequencer):
1259	@serializable_dataclass(frozen=True, kw_only=True)
1260	class AOP(_PromptSequencer):
1261		"""Sequences a prompt as [adjacency list, origin, path].
1262
1263		Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
1264
1265		# Parameters
1266		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1267		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1268		"""
1269
1270		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1271			default=PathTokenizers.StepSequence(),
1272			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1273		)
1274
1275		def _sequence_tokens(
1276			self,
1277			adj_list: list[str],
1278			origin: list[str],
1279			# explicitly no target in this tokenizer
1280			target: list[str],
1281			path: list[str],
1282		) -> list[str]:
1283			return [
1284				VOCAB.ADJLIST_START,
1285				*adj_list,
1286				VOCAB.ADJLIST_END,
1287				VOCAB.ORIGIN_START,
1288				*origin,
1289				VOCAB.ORIGIN_END,
1290				VOCAB.TARGET_START,
1291				VOCAB.TARGET_END,
1292				VOCAB.PATH_START,
1293				*path,
1294				VOCAB.PATH_END,
1295			]

Sequences a prompt as [adjacency list, origin, path].

Still includes `"<TARGET_START>"` and `"<TARGET_END>"` tokens, but no representation of the target itself.

Parameters

  • `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
PromptSequencers.AOP( *, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer = CoordTokenizers.UT(), adj_list_tokenizer: maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer = AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), _type_: Literal["<class 'PromptSequencers.AOP'>"] = "<class 'PromptSequencers.AOP'>", path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))
path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

Returns the instance serialized as a dict; implemented by the `@serializable_dataclass` decorator.

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

Takes an appropriately structured dict and returns an instance of the class; implemented by the `@serializable_dataclass` decorator.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

Validate the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.