docs for maze-dataset v1.3.2

maze_dataset.tokenization

turning a maze into text

There are many algorithms by which one might tokenize a 2D maze into a 1D format usable by autoregressive text models. Training multiple models on the encodings output from each of these algorithms may produce very different internal representations, learned solution algorithms, and levels of performance. To explore how different maze tokenization algorithms affect these models, the MazeTokenizerModular class contains a rich set of options to customize how mazes are stringified. This class contains 19 discrete parameters, resulting in 5.9 million unique tokenizers. But wait, there's more! There are 6 additional parameters available in the library which are untested but further expand the number of tokenizers by a factor of $44/3$ to 86 million.
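
A minimal end-to-end sketch is below. The dataset-construction API (MazeDatasetConfig, MazeDataset.from_config, LatticeMazeGenerators.gen_dfs) comes from elsewhere in the package and is assumed here; only MazeTokenizerModular.to_tokens is documented on this page.

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization import MazeTokenizerModular

# generate one small solved maze to tokenize
cfg = MazeDatasetConfig(
	name="example",
	grid_n=3,  # 3x3 lattice
	n_mazes=1,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)
maze = MazeDataset.from_config(cfg)[0]

# the default tokenizer reproduces the legacy AOTP_UT_uniform behavior
tokenizer = MazeTokenizerModular()
tokens: list[str] = tokenizer.to_tokens(maze)
print(" ".join(tokens))
```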

All output sequences consist of four token regions representing different features of the maze. These regions are distinguished by color in the figure below.

  • Adjacency list: A text representation of the lattice graph
  • Origin: Starting coordinate
  • Target: Ending coordinate
  • Path: Maze solution sequence from the start to the end

Example text output format with token regions highlighted.

Each MazeTokenizerModular is constructed from a set of several _TokenizerElement objects, each of which specifies how different token regions or other elements of the stringification are produced.

Nested internal structure of _TokenizerElement objects inside a typical MazeTokenizerModular object.

Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named pre, intra, and post in various _TokenizerElement classes. Each option controls a unique delimiter token. Here we describe each _TokenizerElement and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options.
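
As a sketch of the delimiter parameters, the coordinate tuple tokenizer (described under Coordinates below) plausibly exposes boolean pre, intra, and post fields; the exact signature of CoordTokenizers.CTT is an assumption to check against your installed version.

```python
from maze_dataset.tokenization import CoordTokenizers

# assumed flags: `pre`/`intra`/`post` toggle the "(", "," and ")"
# delimiter tokens around a coordinate tuple
ctt = CoordTokenizers.CTT(pre=True, intra=True, post=True)
print(ctt.to_tokens((1, 2)))  # expected form: ["(", "1", ",", "2", ")"]
```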

Coordinates

The _CoordTokenizer object controls how coordinates in the lattice are represented across all token regions. Options include the following, compared in a short sketch after the list:

  • Unique tokens: Each coordinate is represented as a single unique token "(i,j)"
  • Coordinate tuple tokens: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: ["i", ",", "j"]
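
A short sketch contrasting the two options; the exact token strings, and whether to_tokens accepts plain tuples, are assumptions here.

```python
from maze_dataset.tokenization import CoordTokenizers

ut = CoordTokenizers.UT()    # one unique token per coordinate
ctt = CoordTokenizers.CTT()  # a short token sequence per coordinate

print(ut.to_tokens((1, 2)))   # e.g. ["(1,2)"]
print(ctt.to_tokens((1, 2)))  # e.g. ["(", "1", ",", "2", ")"]
```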

Adjacency List

The _AdjListTokenizer object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice. A configuration sketch follows the option list below.

  • _EdgeSubset: Specifies the subset of lattice edges to be tokenized
    • All edges: Every edge in the lattice
    • Connections: Only edges which contain a connection
    • Walls: Only edges which contain a wall
  • _EdgePermuter: Specifies how to sequence the two coordinates in each lattice edge
    • Random
    • Sorted: The smaller coordinate always comes first
    • Both permutations: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on whether the coordinate sequence in the path matches the sequence in the adjacency list.
  • shuffle_d0: Whether to shuffle the edges randomly or sort them in the output by their first coordinate
  • connection_token_ordinal: Location in the sequence of the token representing whether the edge is a connection or a wall
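
The sketch below assembles these options; the concrete member and field names (AdjListCoord, edge_subset, edge_permuter) are assumptions based on the descriptions above, so consult the AdjListTokenizers, EdgeSubsets, and EdgePermuters namespaces in your installed version.

```python
from maze_dataset.tokenization import (
	AdjListTokenizers,
	EdgePermuters,
	EdgeSubsets,
)

# hypothetical assembly: tokenize only connected edges, smaller coordinate
# first, and sort edges by their first coordinate rather than shuffling
adj_list = AdjListTokenizers.AdjListCoord(
	edge_subset=EdgeSubsets.ConnectionEdges(),
	edge_permuter=EdgePermuters.SortedCoords(),
	shuffle_d0=False,
)
```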

Path

The _PathTokenizer object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position.

  • _StepSize: Specifies the size of each step
    • Singles: Every coordinate traversed between start and end is directly represented
    • Forks: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path.
  • _StepTokenizer: Specifies how an individual step is represented
    • Coordinate: The coordinates of each step are directly tokenized using a _CoordTokenizer
    • Cardinal direction: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., NORTH, SOUTH. If using a _StepSize other than Singles, this direction may not correspond to the final direction traveled to arrive at the end position of the step.
    • Relative direction: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., RIGHT, LEFT.
    • Distance: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a _StepSize of Singles, the Distance token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a _StepSize other than Singles.

A _PathTokenizer contains a sequence of one or more unique _StepTokenizer objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once.
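
A sketch of composing such a path tokenizer; StepSequence, Singles, Coord, and Cardinal are assumed member names matching the descriptions above.

```python
from maze_dataset.tokenization import PathTokenizers, StepSizes, StepTokenizers

# each step is one coordinate, represented by its coordinate tokens
# followed by the cardinal direction taken
path_tok = PathTokenizers.StepSequence(
	step_size=StepSizes.Singles(),
	step_tokenizers=(
		StepTokenizers.Coord(),
		StepTokenizers.Cardinal(),
	),
)
```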

Tokenized Outputs for Training and Evaluation

During deployment we provide only the prompt up to the <PATH_START> token.
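
The exported helper get_tokens_up_to_path_start performs this split; a sketch reusing tokens from the first example above, assuming the helper's default arguments suffice.

```python
from maze_dataset.tokenization import get_tokens_up_to_path_start

# everything before the path region becomes the deployment-time prompt
prompt: list[str] = get_tokens_up_to_path_start(tokens)
```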

Examples of usage of this dataset to train autoregressive transformers can be found in our maze-transformer library [@maze-transformer-github]. Other tokenization and vocabulary schemes are also included, such as representing each coordinate as a pair of $i,j$ index tokens.

Extensibility

The tokenizer architecture is purposefully designed such that adding and testing a wide variety of new tokenization algorithms is fast and minimizes disturbances to functioning code. This is enabled by the modular architecture and the automatic inclusion of any new tokenizers in integration tests. To create a new tokenizer, developers forking the library can create their own _TokenizerElement subclass and implement its abstract methods. If the behavior change is sufficiently small, adding a parameter to an existing _TokenizerElement subclass and updating its implementation will suffice, along with new cases in the existing unit tests.

The breadth of tokenizers is also easily scaled in the opposite direction. Due to the exponential scaling of parameter combinations, adding a small number of new features can significantly slow certain procedures which rely on constructing all possible tokenizers, such as integration tests. If any existing subclass contains features which aren't needed, a developer tool decorator is provided which can be applied to the unneeded _TokenizerElement subclasses to prune those features and compact the available space of tokenizers.
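
When assembling non-default tokenizers, validity and membership in the tested set can be checked directly; the sketch below uses the same elements that from_legacy uses for the CTT mode.

```python
from maze_dataset.tokenization import (
	CoordTokenizers,
	MazeTokenizerModular,
	PromptSequencers,
)

tok = MazeTokenizerModular(
	prompt_sequencer=PromptSequencers.AOTP(
		coord_tokenizer=CoordTokenizers.CTT(),
	),
)
assert tok.is_valid()
print(tok.name, tok.is_tested_tokenizer())
```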


  1"""turning a maze into text
  2
  3- `MazeTokenizerModular` is the new recommended way to do this as of 1.0.0
  4- legacy `TokenizationMode` enum and `MazeTokenizer` class for supporting existing code
  5- a variety of helper classes and functions
  6
  7There are many algorithms by which one might tokenize a 2D maze into a 1D format usable by autoregressive text models. Training multiple models on the encodings output from each of these algorithms may produce very different internal representations, learned solution algorithms, and levels of performance. To explore how different maze tokenization algorithms affect these models, the `MazeTokenizerModular` class contains a rich set of options to customize how mazes are stringified. This class contains 19 discrete parameters, resulting in 5.9 million unique tokenizers. But wait, there's more! There are 6 additional parameters available in the library which are untested but further expand the number of tokenizers by a factor of $44/3$ to 86 million.
  8
  9All output sequences consist of four token regions representing different features of the maze. These regions are distinguished by color in the figure below.
 10
 11- <span style="background-color:rgb(217,210,233)">Adjacency list</span>: A text representation of the lattice graph
 12- <span style="background-color:rgb(217,234,211)">Origin</span>: Starting coordinate
 13- <span style="background-color:rgb(234,209,220)">Target</span>: Ending coordinate
 14- <span style="background-color:rgb(207,226,243)">Path</span>: Maze solution sequence from the start to the end
 15
 16![Example text output format with token regions highlighted.](figures/outputs-tokens-colored.tex)
 17
 18Each `MazeTokenizerModular` is constructed from a set of several `_TokenizerElement` objects, each of which specifies how different token regions or other elements of the stringification are produced.
 19
 20![Nested internal structure of `_TokenizerElement` objects inside a typical `MazeTokenizerModular` object.](figures/TokenizerElement_structure.pdf)
 21
 22Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named `pre`, `intra`, and `post` in various `_TokenizerElement` classes. Each option controls a unique delimiter token.
 23Here we describe each `_TokenizerElement` and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options.
 24
 25### Coordinates
 26
 27The `_CoordTokenizer` object controls how coordinates in the lattice are represented across all token regions. Options include:
 28
 29- **Unique tokens**: Each coordinate is represented as a single unique token `"(i,j)"`
 30- **Coordinate tuple tokens**: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: `["i", ",", "j"]`
 31
 32### Adjacency List
 33
 34The `_AdjListTokenizer` object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice.
 35
 36- `_EdgeSubset`: Specifies the subset of lattice edges to be tokenized
 37  - **All edges**: Every edge in the lattice
 38  - **Connections**: Only edges which contain a connection
 39  - **Walls**: Only edges which contain a wall
 40- `_EdgePermuter`: Specifies how to sequence the two coordinates in each lattice edge
 41  - **Random**
 42  - **Sorted**: The smaller coordinate always comes first
 43  - **Both permutations**: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on whether the coordinate sequence in the path matches the sequence in the adjacency list.
 44- `shuffle_d0`: Whether to shuffle the edges randomly or sort them in the output by their first coordinate
 45- `connection_token_ordinal`: Location in the sequence of the token representing whether the edge is a connection or a wall
 46
 47### Path
 48
 49The `_PathTokenizer` object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position.
 50
 51- `_StepSize`: Specifies the size of each step
 52  - **Singles**: Every coordinate traversed between start and end is directly represented
 53  - **Forks**: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path.
 54- `_StepTokenizer`: Specifies how an individual step is represented
 55  - **Coordinate**: The coordinates of each step are directly tokenized using a `_CoordTokenizer`
 56  - **Cardinal direction**: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., `NORTH`, `SOUTH`. If using a `_StepSize` other than **Singles**, this direction may not correspond to the final direction traveled to arrive at the end position of the step.
 57  - **Relative direction**: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., `RIGHT`, `LEFT`.
 58  - **Distance**: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a `_StepSize` of **Singles**, the **Distance** token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a `_StepSize` other than **Singles**.
 59
 60A `_PathTokenizer` contains a sequence of one or more unique `_StepTokenizer` objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once.
 61
 62## Tokenized Outputs for Training and Evaluation {#token-training}
 63
 64During deployment we provide only the prompt up to the `<PATH_START>` token.
 65
 66Examples of usage of this dataset to train autoregressive transformers can be found in our `maze-transformer` library [@maze-transformer-github]. Other tokenization and vocabulary schemes are also included, such as representing each coordinate as a pair of $i,j$ index tokens.
 67
 68## Extensibility
 69
 70The tokenizer architecture is purposefully designed such that adding and testing a wide variety of new tokenization algorithms is fast and minimizes disturbances to functioning code. This is enabled by the modular architecture and the automatic inclusion of any new tokenizers in integration tests. To create a new tokenizer, developers forking the library can create their own `_TokenizerElement` subclass and implement its abstract methods. If the behavior change is sufficiently small, adding a parameter to an existing `_TokenizerElement` subclass and updating its implementation will suffice, along with new cases in the existing unit tests.
 71
 72The breadth of tokenizers is also easily scaled in the opposite direction. Due to the exponential scaling of parameter combinations, adding a small number of new features can significantly slow certain procedures which rely on constructing all possible tokenizers, such as integration tests. If any existing subclass contains features which aren't needed, a developer tool decorator is provided which can be applied to the unneeded `_TokenizerElement` subclasses to prune those features and compact the available space of tokenizers.
 73
 74"""
 75
 76from maze_dataset.tokenization.maze_tokenizer_legacy import (
 77	MazeTokenizer,
 78	TokenizationMode,
 79	get_tokens_up_to_path_start,
 80)
 81from maze_dataset.tokenization.modular.element_base import _TokenizerElement
 82from maze_dataset.tokenization.modular.elements import (
 83	AdjListTokenizers,
 84	CoordTokenizers,
 85	EdgeGroupings,
 86	EdgePermuters,
 87	EdgeSubsets,
 88	PathTokenizers,
 89	PromptSequencers,
 90	StepSizes,
 91	StepTokenizers,
 92	TargetTokenizers,
 93)
 94from maze_dataset.tokenization.modular.maze_tokenizer_modular import (
 95	MazeTokenizerModular,
 96)
 97
 98# we don't sort alphabetically on purpose, we sort by the type
 99__all__ = [
100	# submodules
101	"modular",
102	"common",
103	"maze_tokenizer_legacy",
104	"maze_tokenizer",
105	# legacy tokenizer
106	"MazeTokenizer",
107	"TokenizationMode",
108	# MMT
109	"MazeTokenizerModular",
110	# element base
111	"_TokenizerElement",
112	# elements
113	"PromptSequencers",
114	"CoordTokenizers",
115	"AdjListTokenizers",
116	"EdgeGroupings",
117	"EdgePermuters",
118	"EdgeSubsets",
119	"TargetTokenizers",
120	"StepSizes",
121	"StepTokenizers",
122	"PathTokenizers",
123	# helpers
124	"get_tokens_up_to_path_start",
125]

@serializable_dataclass(properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True)
class MazeTokenizer(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
140@serializable_dataclass(
141	properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
142	kw_only=True,
143)
144class MazeTokenizer(SerializableDataclass):
145	"""LEGACY Tokenizer for mazes
146
147	> [!CAUTION]
148	> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
149	> for use, but will remain for compatibility with existing code.
150
151	# Parameters:
152	- `tokenization_mode: TokenizationMode`
153		mode of tokenization. required.
154	- `max_grid_size: int | None`
155		maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text
156
157	# Properties
158	- `name: str`
159		auto-generated name of the tokenizer from mode and size
160
161	## Conditional Properties
162
163	- `node_strings_map: Mapping[CoordTup, str]`
164		map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` if in a `UT` mode
165
166	These all return `None` if `max_grid_size` is `None`.
167	Prepend `_` to the name to get a variant with a guaranteed type, which instead raises an exception if `max_grid_size` is `None`
168
169	- `token_arr: list[str]`
170		list of tokens, in order of their indices in the vocabulary
171	- `tokenizer_map: Mapping[str, int]`
172		map from token to index
173	- `vocab_size: int`
174		size of the vocabulary
175	- `padding_token_index: int`
176		index of the padding token
177
178	# Methods
179	- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
180		convert a list of coordinates to a list of tokens. Optionally raise an error on, skip, or include non-coordinates
181	- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
182		convert a list of tokens to a list of coordinates. Optionally raise an error on, skip, or include non-coordinates
183
184	"""
185
186	# parameters
187	# ============================================================
188
189	tokenization_mode: TokenizationMode = serializable_field(
190		default=TokenizationMode.AOTP_UT_uniform,
191		serialization_fn=lambda x: x.value,
192		loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
193	)
194
195	max_grid_size: int | None = serializable_field(default=None)
196
197	# properties
198	# ============================================================
199
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
207
208	@cached_property
209	def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
210		"""map a coordinate to a token"""
211		if self.tokenization_mode in (
212			TokenizationMode.AOTP_UT_rasterized,
213			TokenizationMode.AOTP_UT_uniform,
214		):
215			return Kappa(_coord_to_strings_UT)
216		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
217			return Kappa(_coord_to_strings_indexed)
218		else:
219			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
220			raise ValueError(err_msg)
221
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map
232
233	# conditional properties (on max_grid_size existing)
234	# ------------------------------------------------------------
235
236	@cached_property
237	def _token_arr(self) -> list[str]:
238		"""map from index to token"""
239		if self.max_grid_size is None:
240			err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
241			raise ValueError(err_msg)
242
243		output: list[str] = list(SPECIAL_TOKENS.values())
244
245		if self.tokenization_mode in (
246			TokenizationMode.AOTP_UT_rasterized,
247			TokenizationMode.AOTP_UT_uniform,
248		):
249			output.extend(
250				[
251					self._node_strings_map[coord][0]
252					for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
253						self.max_grid_size,
254					)
255				],
256			)
257		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
258			# TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
259			output.extend(
260				[
261					"(",
262					",",
263					")",  # new special chars
264					*map(str, range(self.max_grid_size)),  # numbers
265				],
266			)
267		else:
268			err_msg: str = (
269			f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
270			)
271			raise ValueError(err_msg)
272
273		return output
274
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr
281
282	@cached_property
283	def _tokenizer_map(self) -> dict[str, int]:
284		"""map from token to index"""
285		return {token: i for i, token in enumerate(self._token_arr)}
286
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map
293
294	@property
295	def _vocab_size(self) -> int:
296		return len(self._token_arr)
297
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size
304
305	@property
306	def _n_tokens(self) -> int:
307		# TODO: deprecate
308		return self._vocab_size
309
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens
316
317	@cached_property
318	def _padding_token_index(self) -> int:
319		return self.tokenizer_map[SPECIAL_TOKENS.PADDING]
320
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index
327
328	# conversion functions
329	# ============================================================
330
331	@overload
332	def coords_to_strings(
333		self,
334		coords: list[str | CoordTup],
335		when_noncoord: Literal["include", "skip"] = "skip",
336	) -> list[str]: ...
337	@overload
338	def coords_to_strings(
339		self,
340		coords: list[CoordTup],
341		when_noncoord: Literal["error"] = "error",
342	) -> list[str]: ...
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)
371
372	@overload
373	def strings_to_coords(
374		cls,  # noqa: N805
375		text: str | list[str],
376		when_noncoord: Literal["skip"] = "skip",
377	) -> list[CoordTup]: ...
378	@overload
379	def strings_to_coords(
380		cls,  # noqa: N805
381		text: str | list[str],
382		when_noncoord: Literal["error"] = "error",
383	) -> list[CoordTup]: ...
384	@overload
385	def strings_to_coords(
386		cls,  # noqa: N805
387		text: str | list[str],
388		when_noncoord: Literal["include"] = "include",
389	) -> list[str | CoordTup]: ...
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
398
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e
410
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output
428
429	# UT-only coordinate stuff
430	# ============================================================
431
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}
455
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output
468
469	# other
470	# ============================================================
471
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}
479
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)
487
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)
491
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

LEGACY Tokenizer for mazes

Caution

MazeTokenizerModular is the new standard for tokenization. This class is no longer recommended for use, but will remain for compatibility with existing code.

Parameters:

  • tokenization_mode: TokenizationMode mode of tokenization. required.
  • max_grid_size: int | None maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text

Properties

  • name: str auto-generated name of the tokenizer from mode and size

Conditional Properties

  • node_strings_map: Mapping[CoordTup, str] map from node to string. This returns a muutils.kappa.Kappa object which you can use like a dictionary. Returns None if in a UT mode

These all return None if max_grid_size is None. Prepend _ to the name to get a variant with a guaranteed type, which instead raises an exception if max_grid_size is None

  • token_arr: list[str] list of tokens, in order of their indices in the vocabulary
  • tokenizer_map: Mapping[str, int] map from token to index
  • vocab_size: int size of the vocabulary
  • padding_token_index: int index of the padding token

Methods

  • coords_to_strings(coords: list[CoordTup]) -> list[str] convert a list of coordinates to a list of tokens. Optionally raise an error on, skip, or include non-coordinates
  • strings_to_coords(strings: list[str]) -> list[CoordTup] convert a list of tokens to a list of coordinates. Optionally raise an error on, skip, or include non-coordinates
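
A round-trip sketch with the legacy tokenizer; note that max_grid_size must be set for encode/decode to work.

```python
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=5,  # required to map tokens to numerical ids
)
coord_tokens = tok.coords_to_strings([(0, 0), (1, 1)])
ids = tok.encode(coord_tokens)
assert tok.decode(ids) == coord_tokens
```
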
MazeTokenizer( *, tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>, max_grid_size: int | None = None)
tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
max_grid_size: int | None = None
name: str
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

auto-generated name of the tokenizer from mode and size

node_strings_map: Optional[Mapping[tuple[int, int], list[str]]]
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map

map a coordinate to a token

token_arr: list[str] | None
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr

get the token array if the max_grid_size is specified

tokenizer_map: dict[str, int] | None
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map

get the tokenizer map if the max_grid_size is specified

vocab_size: int | None
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size

get the size of the vocabulary if the max_grid_size is specified

n_tokens: int | None
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens

get the number of tokens if the max_grid_size is specified

padding_token_index: int | None
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index

get the index of the padding token if it exists

def coords_to_strings( self, coords: list[tuple[int, int]], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str]:
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)

map a list of coordinate tuples (and maybe other tokens) to strings

wraps maze_dataset.token_utils.coords_to_strings with either _coord_to_strings_UT or _coord_to_strings_indexed depending on the tokenization mode

@classmethod
def strings_to_coords( cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
def encode(self, text: str | list[str]) -> list[int]:
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e

encode a string or list of strings into a list of tokens

def decode( self, tokens: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output

decode a list of tokens into a string or list of strings

coordinate_tokens_coords: dict[tuple[int, int], int]
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}

map of coordinate tuples to their token ids, only valid for UT

coordinate_tokens_ids: dict[str, int]
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output

map of coordinate tokens to their token ids, only valid for UT

def summary(self) -> dict:
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}

returns a summary of the tokenization mode

def is_AOTP(self) -> bool:
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)

returns true if a tokenization mode is Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)

returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)

def clear_cache(self) -> None:
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

clears all cached properties

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict; implemented via the @serializable_dataclass decorator
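
A serialization round trip as a sketch; the final equality check assumes ordinary dataclass field comparison.

```python
from maze_dataset.tokenization import MazeTokenizer

tok = MazeTokenizer(max_grid_size=3)
data = tok.serialize()           # plain dict, safe to dump as JSON
tok2 = MazeTokenizer.load(data)  # reconstruct an equivalent instance
assert tok2 == tok
```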

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class; implemented via the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class TokenizationMode(enum.Enum):
47class TokenizationMode(Enum):
48	"""legacy tokenization modes
49
50	> [!CAUTION]
51	> Legacy mode of tokenization; it will still be around in future releases, but is no longer recommended for use.
52	> Use `MazeTokenizerModular` instead.
53
54	# Abbreviations:
55	- `AOTP`: Adjacency list, Origin, Target, Path
56	- `UT`: Unique Token (for each coordinate)
57	- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)
58
59	# Modes:
60	- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
61		example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
62	- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
63		uses `corner_first_ndindex` function to order the tokens
64	- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
65	"""
66
67	AOTP_UT_rasterized = "AOTP_UT_rasterized"
68	AOTP_UT_uniform = "AOTP_UT_uniform"
69	AOTP_CTT_indexed = "AOTP_CTT_indexed"
70
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

legacy tokenization modes

Caution

Legacy mode of tokenization; it will still be around in future releases, but is no longer recommended for use. Use MazeTokenizerModular instead.

Abbreviations:

  • AOTP: Adjacency list, Origin, Target, Path
  • UT: Unique Token (for each coordinate)
  • CTT: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

Modes:

  • AOTP_UT_rasterized: the "classic" mode: assigning tokens to each coordinate is done via rasterization. Example: for a 3x3 maze, token order is (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)
  • AOTP_UT_uniform: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible; uses the corner_first_ndindex function to order the tokens
  • AOTP_CTT_indexed: each coordinate is a tuple of integers
AOTP_UT_rasterized = <TokenizationMode.AOTP_UT_rasterized: 'AOTP_UT_rasterized'>
AOTP_UT_uniform = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
AOTP_CTT_indexed = <TokenizationMode.AOTP_CTT_indexed: 'AOTP_CTT_indexed'>
def to_legacy_tokenizer( self, max_grid_size: int | None = None) -> MazeTokenizer:
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

convert the mode to a legacy MazeTokenizer object given a max_grid_size
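
A sketch converting a mode to the legacy tokenizer and back to its modular equivalent:

```python
from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode

legacy = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5)
modern = MazeTokenizerModular.from_legacy(legacy)
print(modern.name)
```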

Inherited Members
enum.Enum
name
value
@serializable_dataclass(frozen=True, kw_only=True, properties_to_serialize=['tokenizer_element_tree_concrete', 'name'])
class MazeTokenizerModular(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 51@serializable_dataclass(
 52	frozen=True,
 53	kw_only=True,
 54	properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
 55)
 56class MazeTokenizerModular(SerializableDataclass):
 57	"""Tokenizer for mazes
 58
 59	# Parameters
 60	- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.
 61
 62	# Development
 63	- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`.
 64	- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
 65	- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
 66	"""
 67
 68	prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
 69		default=PromptSequencers.AOTP(),
 70		loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
 71	)
 72
 73	def hash_int(self) -> int:
 74		"return integer hash using blake2b"
 75		return _hash_tokenizer_name(self.name)
 76
 77	def __hash__(self) -> int:
 78		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
 79		return self.hash_int()
 80
 81	def hash_b64(self, n_bytes: int = 8) -> str:
 82		"""filename-safe base64 encoding of the hash"""
 83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
 84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
 85
 86		encoded = base64.b64encode(
 87			hash_mod.to_bytes(n_bytes, byteorder="big"),
 88			altchars=b"-_",
 89		).decode()
 90
 91		# Remove any padding equals signs
 92		return encoded.rstrip("=")
 93
 94	# Information Querying Methods
 95
 96	@cached_property
 97	def tokenizer_elements(self) -> list[_TokenizerElement]:
 98		"returns a list of all the elements of this tokenizer"
 99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
100
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)
116
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()
121
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
125
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizer into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002
130
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}
137
138	@staticmethod
139	def _type_check(obj: any) -> None:
140		"""Helper method for `has_element`"""
141		if not (
142			isinstance(obj, _TokenizerElement)
143			or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
144		):
145			err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
146			raise TypeError(err_msg)
147
148	def _has_element_singular(
149		self,
150		el: type[_TokenizerElement] | _TokenizerElement,
151	) -> bool:
152		"""Helper method for `has_element`"""
153		self._type_check(el)
154		if isinstance(el, type):
155			return any(isinstance(e, el) for e in self.tokenizer_elements)
156		else:
157			return el in self.tokenizer_elements
158
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_elements`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171	If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)
176
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
183
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)
190
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an fst on the `name` attributes of all the tokenizers
195
196	if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid
206
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)
210
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)
214
215	# Alternate Constructors
216	# ======================
217
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]
235
236	# Simple properties
237	# =================
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)
247
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST
252
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX
257
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)
262
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)
268
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
273
274	# conversion functions
275	# ============================================================
276
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)
283
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)
291
292	# TODO: unclear why we need to use `noqa: N805` here since it's a classmethod
293	# maybe we need to hit every overload with `@classmethod`?
294	@overload
295	def strings_to_coords(
296		cls,  # noqa: N805
297		text: str | list[str],
298		when_noncoord: Literal["skip"] = "skip",
299	) -> list[CoordTup]: ...
300	@overload
301	def strings_to_coords(
302		cls,  # noqa: N805
303		text: str | list[str],
304		when_noncoord: Literal["error"] = "error",
305	) -> list[CoordTup]: ...
306	@overload
307	def strings_to_coords(
308		cls,  # noqa: N805
309		text: str | list[str],
310		when_noncoord: Literal["include"] = "include",
311	) -> list[str | CoordTup]: ...
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)
324
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e
335
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

Tokenizer for mazes

Parameters

  • prompt_sequencer: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

Development

  • To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy TokenizationMode.AOTP_UT_uniform.
  • Furthermore, the mapping reflected in from_legacy must also be maintained.
  • Updates to MazeTokenizerModular or the _TokenizerElement hierarchy must maintain that behavior.
MazeTokenizerModular(
	*,
	prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(
		coord_tokenizer=CoordTokenizers.UT(),
		adj_list_tokenizer=AdjListTokenizers.AdjListCoord(
			pre=False,
			post=True,
			shuffle_d0=True,
			edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1),
			edge_subset=EdgeSubsets.ConnectionEdges(walls=False),
			edge_permuter=EdgePermuters.RandomCoords(),
		),
		target_tokenizer=TargetTokenizers.Unlabeled(post=False),
		path_tokenizer=PathTokenizers.StepSequence(
			step_size=StepSizes.Singles(),
			step_tokenizers=(StepTokenizers.Coord(),),
			pre=False,
			intra=False,
			post=False,
		),
	),
)
prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(
	coord_tokenizer=CoordTokenizers.UT(),
	adj_list_tokenizer=AdjListTokenizers.AdjListCoord(
		pre=False,
		post=True,
		shuffle_d0=True,
		edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1),
		edge_subset=EdgeSubsets.ConnectionEdges(walls=False),
		edge_permuter=EdgePermuters.RandomCoords(),
	),
	target_tokenizer=TargetTokenizers.Unlabeled(post=False),
	path_tokenizer=PathTokenizers.StepSequence(
		step_size=StepSizes.Singles(),
		step_tokenizers=(StepTokenizers.Coord(),),
		pre=False,
		intra=False,
		post=False,
	),
)
def hash_int(self) -> int:
73	def hash_int(self) -> int:
74		"return integer hash using blake2b"
75		return _hash_tokenizer_name(self.name)

return integer hash using blake2b

def hash_b64(self, n_bytes: int = 8) -> str:
81	def hash_b64(self, n_bytes: int = 8) -> str:
82		"""filename-safe base64 encoding of the hash"""
83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
85
86		encoded = base64.b64encode(
87			hash_mod.to_bytes(n_bytes, byteorder="big"),
88			altchars=b"-_",
89		).decode()
90
91		# Remove any padding equals signs
92		return encoded.rstrip("=")

filename-safe base64 encoding of the hash
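
A minimal usage sketch (outputs depend on the tokenizer's name, so values are not shown):

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    # stable integer hash derived from `tokenizer.name` via blake2b
    print(tokenizer.hash_int())
    # filename-safe base64 digest of that hash, truncated to n_bytes
    print(tokenizer.hash_b64(n_bytes=8))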

tokenizer_elements: list[_TokenizerElement]
96	@cached_property
97	def tokenizer_elements(self) -> list[_TokenizerElement]:
98		"returns a list of all the elements of this tokenizer"
99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]

returns a list of all the elements of this tokenizer
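
A short sketch inspecting the flattened element list of the default tokenizer:

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    for elem in tokenizer.tokenizer_elements:
        # each element reports its binding key and its full name string
        print(elem.attribute_key(), "->", elem.name)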

def tokenizer_element_tree(self, abstract: bool = False) -> str:
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)

Returns a string representation of the tree of tokenizer elements contained in self.

Parameters

  • abstract: bool: Whether to print the name of the abstract base class or the concrete class for each _TokenizerElement instance.
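
For example, a sketch printing both views of the tree:

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    print(tokenizer.tokenizer_element_tree())               # concrete class names
    print(tokenizer.tokenizer_element_tree(abstract=True))  # abstract base class names
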
tokenizer_element_tree_concrete: str
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()

Property wrapper for tokenizer_element_tree so that it can be used in properties_to_serialize.

def tokenizer_element_dict(self) -> dict:
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}

Nested dictionary of the internal TokenizerElements.

name: str
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizerModular into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002

Serializes MazeTokenizerModular into a key for encoding in zanj

def summary(self) -> dict[str, str]:
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}

Single-level dictionary of the internal TokenizerElements.
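
A sketch contrasting `name` and `summary()`; the keys and values shown in comments are illustrative:

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    print(tokenizer.name)       # e.g. "MazeTokenizerModular-AOTP(...)"
    print(tokenizer.summary())  # e.g. {"coord_tokenizer": "UT()", ...}, one entry per element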

def has_element( self, *elements: Sequence[type[_TokenizerElement] | _TokenizerElement]) -> bool:
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_element`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171		If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)

Returns True if the MazeTokenizerModular instance contains ALL of the items specified in elements.

Querying with a partial subset of _TokenizerElement fields is not currently supported. To do such a query, assemble multiple calls to has_element.

Parameters

  • elements: Singleton or iterable of _TokenizerElement instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via isinstance. I.e., any instance of that class is accepted.
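
A usage sketch showing both query styles; this assumes the element namespaces are importable from maze_dataset.tokenization alongside MazeTokenizerModular:

    from maze_dataset.tokenization import (
        CoordTokenizers,
        MazeTokenizerModular,
        PromptSequencers,
    )

    tokenizer = MazeTokenizerModular()
    # class query: matches via isinstance, so any UT instance counts
    assert tokenizer.has_element(CoordTokenizers.UT)
    # instance query: matches via equality against the contained elements
    assert tokenizer.has_element(CoordTokenizers.UT())
    # multiple arguments: ALL must be present
    assert tokenizer.has_element(CoordTokenizers.UT, PromptSequencers.AOTP)
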
def is_valid(self, do_except: bool = False) -> bool:
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)

Returns True if self is a valid tokenizer.

Evaluates the validity of all of self.tokenizer_elements according to each one's method.

def is_legacy_equivalent(self) -> bool:
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)

Returns if self has identical stringification behavior as any legacy MazeTokenizer.

def is_tested_tokenizer(self, do_except: bool = False) -> bool:
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an FST on the `name` attributes of all the tokenizers
195
196		if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid

Returns if the tokenizer is returned by all_tokenizers.get_all_tokenizers, the set of tested and reliable tokenizers.

uses an FST on the name attributes of all the tokenizers

if do_except is True, raises an AssertionError if the tokenizer is not tested.
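
A sketch of the typical validation flow (the default tokenizer is legacy-equivalent, so both checks should pass):

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    # element-wise validity; pass do_except=True to raise instead of returning False
    assert tokenizer.is_valid()
    # membership of tokenizer.name in the precomputed FST of tested tokenizers
    assert tokenizer.is_tested_tokenizer()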

def is_AOTP(self) -> bool:
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)

is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)

is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)

@classmethod
def from_legacy( cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode) -> MazeTokenizerModular:
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]

Maps a legacy MazeTokenizer or TokenizationMode to its equivalent MazeTokenizerModular instance.
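
For example (a minimal sketch; both UT modes map to the default tokenizer):

    from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode

    tok_ut = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform)
    assert tok_ut == MazeTokenizerModular()

    # AOTP_CTT_indexed maps to a tokenizer using coordinate tuple tokens
    tok_ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
    assert tok_ctt.is_legacy_equivalent()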

@classmethod
def from_tokens( cls, tokens: str | list[str]) -> MazeTokenizerModular:
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)

Infers most MazeTokenizerModular parameters from a full sequence of tokens.

token_arr: list[str] | None
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST

map from index to token

tokenizer_map: dict[str, int]
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX

map from token to index

vocab_size: int
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)

Number of tokens in the static vocab

n_tokens: int
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)

get the number of tokens in the vocabulary (deprecated)

padding_token_index: int
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]

get the index of the padding token
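
A sketch showing how the static-vocabulary properties fit together:

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    # token_arr and tokenizer_map are inverse mappings over the static vocab
    assert tokenizer.vocab_size == len(tokenizer.token_arr)
    pad: int = tokenizer.padding_token_index
    assert tokenizer.tokenizer_map[tokenizer.token_arr[pad]] == pad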

def to_tokens(self, maze: maze_dataset.LatticeMaze) -> list[str]:
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)

Converts maze into a list of tokens.
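
A sketch of tokenizing a generated maze; the dataset-construction calls are assumed to work as in the package README:

    from maze_dataset import MazeDataset, MazeDatasetConfig
    from maze_dataset.generation import LatticeMazeGenerators
    from maze_dataset.tokenization import MazeTokenizerModular

    dataset = MazeDataset.from_config(
        MazeDatasetConfig(
            name="demo",
            grid_n=3,
            n_mazes=1,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
    )
    tokenizer = MazeTokenizerModular()
    # full AOTP prompt: adjacency list, origin, target, path
    tokens: list[str] = tokenizer.to_tokens(dataset[0])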

def coords_to_strings( self, coords: list[tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2']]) -> list[str]:
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)

calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords
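
A minimal sketch; with the default unique-token coordinate tokenizer, each coordinate becomes a single token (the output shown is illustrative):

    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    # e.g. [(0, 0), (1, 2)] -> ["(0,0)", "(1,2)"]
    print(tokenizer.coords_to_strings([(0, 0), (1, 2)]))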

@classmethod
def strings_to_coords( cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)

wrapper for maze_dataset.token_utils.strings_to_coords

@staticmethod
def encode(text: str | list[str]) -> list[int]:
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e

encode a string or list of strings into a list of tokens

@staticmethod
def decode(token_ids: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

decode a list of tokens into a string or list of strings
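
A round-trip sketch; both methods are static, and unknown tokens or out-of-range ids raise TokenError. This assumes VOCAB is importable from the package root, as VOCAB_LIST is:

    from maze_dataset import VOCAB
    from maze_dataset.tokenization import MazeTokenizerModular

    # encode accepts a whitespace-delimited string or a list of tokens
    ids: list[int] = MazeTokenizerModular.encode([VOCAB.ADJLIST_START, VOCAB.ADJLIST_END])
    tokens: list[str] = MazeTokenizerModular.decode(ids)
    assert tokens == [VOCAB.ADJLIST_START, VOCAB.ADJLIST_END]
    # joined_tokens=True returns a single space-joined string instead of a list
    print(MazeTokenizerModular.decode(ids, joined_tokens=True))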

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class _TokenizerElement(muutils.json_serialize.serializable_dataclass.SerializableDataclass, abc.ABC):
 29@serializable_dataclass(frozen=True, kw_only=True)
 30class _TokenizerElement(SerializableDataclass, abc.ABC):
 31	"""Superclass for tokenizer elements.
 32
 33	Subclasses contain modular functionality for maze tokenization.
 34
 35	# Development
 36	> [!TIP]
 37	> Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses
 38	> may only contain fields of type `utils.FiniteValued`.
 39	> Implementing a subclass with an `int` or `float`-typed field, for example, is not supported.
 40	> In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.
 41
 42	"""
 43
 44	# TYPING: type hint `v` more specifically
 45	@staticmethod
 46	def _stringify(k: str, v: Any) -> str:  # noqa: ANN401
 47		if isinstance(v, bool):
 48			return f"{k}={str(v)[0]}"
 49		if isinstance(v, _TokenizerElement):
 50			return v.name
 51		if isinstance(v, tuple):
 52			return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}"
 53		else:
 54			return f"{k}={v}"
 55
 56	@property
 57	def name(self) -> str:
 58		members_str: str = ", ".join(
 59			[self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"],
 60		)
 61		output: str = f"{type(self).__name__}({members_str})"
 62		if "." in output and output.index("(") > output.index("."):
 63			return "".join(output.split(".")[1:])
 64		else:
 65			return output
 66
 67	def __str__(self) -> str:
 68		return self.name
 69
 70	# TYPING: type hints for `__init_subclass__`?
 71	def __init_subclass__(cls, **kwargs):  # noqa: ANN204
 72		"""Hack: dataclass hashes don't include the class itself in the hash function inputs.
 73
 74		This causes dataclasses with identical fields but different types to hash identically.
 75		This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`.
 76		To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value.
 77		So we type it as a singleton `Literal` type.
 78		muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`.
 79		Ignore Pylance complaining about the arg to `Literal` being an expression.
 80		"""
 81		super().__init_subclass__(**kwargs)
 82		# we are adding a new attr here intentionally
 83		cls._type_ = serializable_field(  # type: ignore[attr-defined]
 84			init=True,
 85			repr=False,
 86			default=repr(cls),
 87			assert_type=False,
 88		)
 89		cls.__annotations__["_type_"] = Literal[repr(cls)]
 90
 91	def __hash__(self) -> int:
 92		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
 93		return _hash_tokenizer_name(self.name)
 94
 95	@classmethod
 96	def _level_one_subclass(cls) -> type["_TokenizerElement"]:
 97		"""Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance."""
 98		return (
 99			set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop()
100		)
101
102	def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
103		"""Returns a list of all `_TokenizerElement` instances contained in the subtree.
104
105		Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
106		which sit inside a `tuple` without further nesting.
107
108		# Parameters
109		- `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
110		"""
111		if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
112			return list(
113				flatten(
114					[
115						[el, *el.tokenizer_elements()]
116						for el in self.__dict__.values()
117						if isinstance(el, _TokenizerElement)
118					],
119				)
120				if deep
121				else filter(
122					lambda x: isinstance(x, _TokenizerElement),
123					self.__dict__.values(),
124				),
125			)
126		else:
127			non_tuple_elems: list[_TokenizerElement] = list(
128				flatten(
129					[
130						[el, *el.tokenizer_elements()]
131						for el in self.__dict__.values()
132						if isinstance(el, _TokenizerElement)
133					]
134					if deep
135					else filter(
136						lambda x: isinstance(x, _TokenizerElement),
137						self.__dict__.values(),
138					),
139				),
140			)
141			tuple_elems: list[_TokenizerElement] = list(
142				flatten(
143					[
144						(
145							[
146								[tup_el, *tup_el.tokenizer_elements()]
147								for tup_el in el
148								if isinstance(tup_el, _TokenizerElement)
149							]
150							if deep
151							else filter(lambda x: isinstance(x, _TokenizerElement), el)
152						)
153						for el in self.__dict__.values()
154						if isinstance(el, tuple)
155					],
156				),
157			)
158			non_tuple_elems.extend(tuple_elems)
159			return non_tuple_elems
160
161	def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
162		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
163
164		# Parameters
165		- `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
166		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
167		"""
168		name: str = "\t" * depth + (
169			type(self).__name__
170			if not abstract
171			else type(self)._level_one_subclass().__name__
172		)
173		return (
174			name
175			+ "\n"
176			+ "".join(
177				el.tokenizer_element_tree(depth + 1, abstract)
178				for el in self.tokenizer_elements(deep=False)
179			)
180		)
181
182	def tokenizer_element_dict(self) -> dict:
183		"""Returns a dictionary representation of the tree of tokenizer elements contained in `self`."""
184		return {
185			type(self).__name__: {
186				key: (
187					val.tokenizer_element_dict()
188					if isinstance(val, _TokenizerElement)
189					else (
190						val
191						if not isinstance(val, tuple)
192						else [
193							(
194								el.tokenizer_element_dict()
195								if isinstance(el, _TokenizerElement)
196								else el
197							)
198							for el in val
199						]
200					)
201				)
202				for key, val in self.__dict__.items()
203				if key != "_type_"
204			},
205		}
206
207	@classmethod
208	@abc.abstractmethod
209	def attribute_key(cls) -> str:
210		"""Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
211		raise NotImplementedError
212
213	def to_tokens(self, *args, **kwargs) -> list[str]:
214		"""Converts a maze element into a list of tokens.
215
216		Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
217		Those subclasses which do produce tokens should override this method.
218		"""
219		raise NotImplementedError
220
221	@abc.abstractmethod
222	def is_valid(self, do_except: bool = False) -> bool:
223		"""Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.
224
225		Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
226		`is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
227		If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.
228
229		# Types of Invalidity
230		In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below.
231		Invalidity types, in ascending order of invalidity:
232		- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
233		E.g., `_TokenizerElement`s which are strictly worse than some alternative.
234		- Duplicate: These tokenizers have tokenization behavior identical to that of some other valid tokenizer.
235		- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
236		- Erroneous: These tokenizers might raise exceptions during use.
237
238		# Development
239		`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
240		When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.
241
242		## Nesting
243		In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
244		In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
245		`MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.
246
247		## Types of Invalidity
248		If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
249		This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
250		"""
251		raise NotImplementedError

Superclass for tokenizer elements.

Subclasses contain modular functionality for maze tokenization.

Development

Tip

Due to the functionality of get_all_tokenizers(), _TokenizerElement subclasses may only contain fields of type utils.FiniteValued. Implementing a subclass with an int or float-typed field, for example, is not supported. In the event that adding such fields is deemed necessary, get_all_tokenizers() must be updated.

name: str
56	@property
57	def name(self) -> str:
58		members_str: str = ", ".join(
59			[self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"],
60		)
61		output: str = f"{type(self).__name__}({members_str})"
62		if "." in output and output.index("(") > output.index("."):
63			return "".join(output.split(".")[1:])
64		else:
65			return output
def tokenizer_elements( self, deep: bool = True) -> list[_TokenizerElement]:
102	def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
103		"""Returns a list of all `_TokenizerElement` instances contained in the subtree.
104
105		Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
106		which sit inside a `tuple` without further nesting.
107
108		# Parameters
109		- `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
110		"""
111		if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
112			return list(
113				flatten(
114					[
115						[el, *el.tokenizer_elements()]
116						for el in self.__dict__.values()
117						if isinstance(el, _TokenizerElement)
118					],
119				)
120				if deep
121				else filter(
122					lambda x: isinstance(x, _TokenizerElement),
123					self.__dict__.values(),
124				),
125			)
126		else:
127			non_tuple_elems: list[_TokenizerElement] = list(
128				flatten(
129					[
130						[el, *el.tokenizer_elements()]
131						for el in self.__dict__.values()
132						if isinstance(el, _TokenizerElement)
133					]
134					if deep
135					else filter(
136						lambda x: isinstance(x, _TokenizerElement),
137						self.__dict__.values(),
138					),
139				),
140			)
141			tuple_elems: list[_TokenizerElement] = list(
142				flatten(
143					[
144						(
145							[
146								[tup_el, *tup_el.tokenizer_elements()]
147								for tup_el in el
148								if isinstance(tup_el, _TokenizerElement)
149							]
150							if deep
151							else filter(lambda x: isinstance(x, _TokenizerElement), el)
152						)
153						for el in self.__dict__.values()
154						if isinstance(el, tuple)
155					],
156				),
157			)
158			non_tuple_elems.extend(tuple_elems)
159			return non_tuple_elems

Returns a list of all _TokenizerElement instances contained in the subtree.

Currently only detects _TokenizerElement instances which are either direct attributes of another instance or which sit inside a tuple without further nesting.

Parameters

  • deep: bool: Whether to return elements nested arbitrarily deeply or just a single layer.
def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
161	def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
162		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
163
164		# Parameters
165		- `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
166		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
167		"""
168		name: str = "\t" * depth + (
169			type(self).__name__
170			if not abstract
171			else type(self)._level_one_subclass().__name__
172		)
173		return (
174			name
175			+ "\n"
176			+ "".join(
177				el.tokenizer_element_tree(depth + 1, abstract)
178				for el in self.tokenizer_elements(deep=False)
179			)
180		)

Returns a string representation of the tree of tokenizer elements contained in self.

Parameters

  • depth: int: Current depth in the tree. Used internally for recursion, no need to specify.
  • abstract: bool: Whether to print the name of the abstract base class or the concrete class for each _TokenizerElement instance.
def tokenizer_element_dict(self) -> dict:
182	def tokenizer_element_dict(self) -> dict:
183		"""Returns a dictionary representation of the tree of tokenizer elements contained in `self`."""
184		return {
185			type(self).__name__: {
186				key: (
187					val.tokenizer_element_dict()
188					if isinstance(val, _TokenizerElement)
189					else (
190						val
191						if not isinstance(val, tuple)
192						else [
193							(
194								el.tokenizer_element_dict()
195								if isinstance(el, _TokenizerElement)
196								else el
197							)
198							for el in val
199						]
200					)
201				)
202				for key, val in self.__dict__.items()
203				if key != "_type_"
204			},
205		}

Returns a dictionary representation of the tree of tokenizer elements contained in self.

@classmethod
@abc.abstractmethod
def attribute_key(cls) -> str:
207	@classmethod
208	@abc.abstractmethod
209	def attribute_key(cls) -> str:
210		"""Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
211		raise NotImplementedError

Returns the binding used in MazeTokenizerModular for that type of _TokenizerElement.

def to_tokens(self, *args, **kwargs) -> list[str]:
213	def to_tokens(self, *args, **kwargs) -> list[str]:
214		"""Converts a maze element into a list of tokens.
215
216		Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
217		Those subclasses which do produce tokens should override this method.
218		"""
219		raise NotImplementedError

Converts a maze element into a list of tokens.

Not all _TokenizerElement subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.

@abc.abstractmethod
def is_valid(self, do_except: bool = False) -> bool:
221	@abc.abstractmethod
222	def is_valid(self, do_except: bool = False) -> bool:
223		"""Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.
224
225		Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
226		`is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
227		If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.
228
229		# Types of Invalidity
230		In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below.
231		Invalidity types, in ascending order of invalidity:
232		- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
233		E.g., `_TokenizerElement`s which are strictly worse than some alternative.
234		- Duplicate: These tokenizers have tokenization behavior identical to that of some other valid tokenizer.
235		- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
236		- Erroneous: These tokenizers might raise exceptions during use.
237
238		# Development
239		`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
240		When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.
241
242		## Nesting
243		In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
244		In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
245		`MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.
246
247		## Types of Invalidity
248		If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
249		This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
250		"""
251		raise NotImplementedError

Returns if self contains data members capable of producing an overall valid MazeTokenizerModular.

Some _TokenizerElement instances may be created which are not useful despite obeying data member type hints. is_valid allows for more precise detection of invalid _TokenizerElements beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return True for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:

  • Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., _TokenizerElements which are strictly worse than some alternative.
  • Duplicate: These tokenizers have tokenization behavior identical to that of some other valid tokenizer.
  • Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
  • Erroneous: These tokenizers might raise exceptions during use.

Development

is_valid is implemented to always return True in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call is_valid on nested _TokenizerElements contained in the class. In other words, failures of is_valid need not bubble up to the top of the nested _TokenizerElement tree. MazeTokenizerModular.is_valid calls is_valid on each of its _TokenizerElements individually, so failure at any level will be detected.

Types of Invalidity

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid _TokenizerElement instances.

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
1053class PromptSequencers(__TokenizerElementNamespace):
1054	"""Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""
1055
1056	key = "prompt_sequencer"
1057
1058	@serializable_dataclass(frozen=True, kw_only=True)
1059	class _PromptSequencer(_TokenizerElement, abc.ABC):
1060		"""Sequences token regions into a complete maze tokenization.
1061
1062		# Parameters
1063		- `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
1064		- `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
1065		Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
1066		"""
1067
1068		coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
1069			default=CoordTokenizers.UT(),
1070			loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
1071		)
1072		adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
1073			default=AdjListTokenizers.AdjListCoord(),
1074			loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
1075		)
1076
1077		@classmethod
1078		def attribute_key(cls) -> str:
1079			return PromptSequencers.key
1080
1081		@staticmethod
1082		def _trim_if_unsolved_maze(
1083			untrimmed: list[str],
1084			is_untargeted: bool = False,
1085			is_unsolved: bool = False,
1086		) -> list[str]:
1087			"""Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.
1088
1089			# Development
1090			This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
1091			It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
1092			"""
1093			if is_untargeted:
1094				return tokens_between(
1095					untrimmed,
1096					VOCAB.ADJLIST_START,
1097					VOCAB.ADJLIST_END,
1098					include_start=True,
1099					include_end=True,
1100				)
1101			if is_unsolved:
1102				if VOCAB.TARGET_END in untrimmed:
1103					return tokens_between(
1104						untrimmed,
1105						VOCAB.ADJLIST_START,
1106						VOCAB.TARGET_END,
1107						include_start=True,
1108						include_end=True,
1109					)
1110				else:
1111					return tokens_between(
1112						untrimmed,
1113						VOCAB.ADJLIST_START,
1114						VOCAB.ORIGIN_END,
1115						include_start=True,
1116						include_end=True,
1117					)
1118			return untrimmed
1119
1120		def to_tokens(
1121			self,
1122			maze: LatticeMaze,
1123			*args,
1124			**kwargs,
1125		) -> list[str]:
1126			"""Returns a complete list of tokens for a given set of maze elements."""
1127			untrimmed: list[str] = self._sequence_tokens(
1128				*self._get_prompt_regions(maze),
1129			)
1130			return self._trim_if_unsolved_maze(
1131				untrimmed,
1132				not hasattr(maze, "start_pos"),
1133				not hasattr(maze, "solution"),
1134			)
1135
1136		def _get_prompt_regions(
1137			self,
1138			maze: LatticeMaze,
1139			*args,
1140			**kwargs,
1141		) -> list[list[str]]:
1142			"""Gets the prompt regions of a maze in a fixed sequence.
1143
1144			This method is NOT responsible for including/excluding any prompt regions.
1145			Always return according to the API described under Returns.
1146			This implementation is expected to be suitable for most `PromptSequencer` subclasses.
1147			Subclasses may override this method if needed for special behavior.
1148
1149			# Returns
1150			- [0]: list[str] Adjacency list tokens
1151			- [1]: list[str] Origin tokens
1152			- [2]: list[str] Target tokens
1153			- [3]: list[str] Path tokens
1154
1155			# `None`-valued Args
1156			If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
1157			To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables.
1158			"""
1159			origin: Coord | None = getattr(maze, "start_pos", None)
1160			target: list[Coord] | None = [
1161				getattr(maze, "end_pos", None),
1162			]  # TargetTokenizer requires target: Sequence[Coord]
1163
1164			return [
1165				(
1166					self.adj_list_tokenizer.to_tokens(
1167						maze,
1168						coord_tokenizer=self.coord_tokenizer,
1169					)
1170					if hasattr(self, "adj_list_tokenizer")
1171					else []
1172				),
1173				self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
1174				(
1175					self.target_tokenizer.to_tokens(
1176						target,
1177						coord_tokenizer=self.coord_tokenizer,
1178					)
1179					if target[0] is not None and hasattr(self, "target_tokenizer")
1180					else []
1181				),
1182				(
1183					self.path_tokenizer.to_tokens(
1184						maze,
1185						coord_tokenizer=self.coord_tokenizer,
1186					)
1187					if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
1188					else []
1189				),
1190			]
1191
1192		@abc.abstractmethod
1193		def _sequence_tokens(
1194			self,
1195			adj_list: list[str],
1196			origin: list[str] | None,
1197			target: list[str] | None,
1198			path: list[str] | None,
1199		) -> list[str]:
1200			"""Sequences token regions into a complete prompt.
1201
1202			Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.
1203
1204			# Parameters
1205			- `adj_list`: Tokens representing the adjacency list
1206			- `origin`: Tokens representing the origin
1207			- `target`: Tokens representing the target
1208			- `path`: Tokens representing the path
1209			"""
1210			pass
1211
1212		def is_valid(self, do_except: bool = False) -> bool:
1213			# No invalid instances possible within data member type hint bounds
1214			return True
1215
1216	@serializable_dataclass(frozen=True, kw_only=True)
1217	class AOTP(_PromptSequencer):
1218		"""Sequences a prompt as [adjacency list, origin, target, path].
1219
1220		# Parameters
1221		- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
1222		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
1223		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1224		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1225
1226		"""
1227
1228		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
1229			default=TargetTokenizers.Unlabeled(),
1230			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
1231		)
1232		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1233			default=PathTokenizers.StepSequence(),
1234			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1235		)
1236
1237		def _sequence_tokens(
1238			self,
1239			adj_list: list[str],
1240			origin: list[str],
1241			target: list[str],
1242			path: list[str],
1243		) -> list[str]:
1244			return [
1245				VOCAB.ADJLIST_START,
1246				*adj_list,
1247				VOCAB.ADJLIST_END,
1248				VOCAB.ORIGIN_START,
1249				*origin,
1250				VOCAB.ORIGIN_END,
1251				VOCAB.TARGET_START,
1252				*target,
1253				VOCAB.TARGET_END,
1254				VOCAB.PATH_START,
1255				*path,
1256				VOCAB.PATH_END,
1257			]
1258
1259	@serializable_dataclass(frozen=True, kw_only=True)
1260	class AOP(_PromptSequencer):
1261		"""Sequences a prompt as [adjacency list, origin, path].
1262
1263		Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
1264
1265		# Parameters
1266		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1267		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1268		"""
1269
1270		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1271			default=PathTokenizers.StepSequence(),
1272			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1273		)
1274
1275		def _sequence_tokens(
1276			self,
1277			adj_list: list[str],
1278			origin: list[str],
1279			# explicitly no target in this tokenizer
1280			target: list[str],
1281			path: list[str],
1282		) -> list[str]:
1283			return [
1284				VOCAB.ADJLIST_START,
1285				*adj_list,
1286				VOCAB.ADJLIST_END,
1287				VOCAB.ORIGIN_START,
1288				*origin,
1289				VOCAB.ORIGIN_END,
1290				VOCAB.TARGET_START,
1291				VOCAB.TARGET_END,
1292				VOCAB.PATH_START,
1293				*path,
1294				VOCAB.PATH_END,
1295			]

Namespace for _PromptSequencer subclass hierarchy used by MazeTokenizerModular.

key = 'prompt_sequencer'
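As a usage sketch (illustrative only; it assumes MazeTokenizerModular accepts a prompt_sequencer field, as suggested by the key above):

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.modular.elements import PromptSequencers

# default sequencing: AOTP = [adjacency list, origin, target, path]
tok_aotp = MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOTP())
# AOP drops the target tokens but keeps the <TARGET_START>/<TARGET_END> boundary tokens
tok_aop = MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP())

# given a SolvedMaze `maze`, both expose the same interface:
# tokens: list[str] = tok_aotp.to_tokens(maze)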
@serializable_dataclass(frozen=True, kw_only=True)
class PromptSequencers.AOTP(maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer):
1216	@serializable_dataclass(frozen=True, kw_only=True)
1217	class AOTP(_PromptSequencer):
1218		"""Sequences a prompt as [adjacency list, origin, target, path].
1219
1220		# Parameters
1221		- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
1222		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
1223		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1224		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1225
1226		"""
1227
1228		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
1229			default=TargetTokenizers.Unlabeled(),
1230			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
1231		)
1232		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1233			default=PathTokenizers.StepSequence(),
1234			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1235		)
1236
1237		def _sequence_tokens(
1238			self,
1239			adj_list: list[str],
1240			origin: list[str],
1241			target: list[str],
1242			path: list[str],
1243		) -> list[str]:
1244			return [
1245				VOCAB.ADJLIST_START,
1246				*adj_list,
1247				VOCAB.ADJLIST_END,
1248				VOCAB.ORIGIN_START,
1249				*origin,
1250				VOCAB.ORIGIN_END,
1251				VOCAB.TARGET_START,
1252				*target,
1253				VOCAB.TARGET_END,
1254				VOCAB.PATH_START,
1255				*path,
1256				VOCAB.PATH_END,
1257			]

Sequences a prompt as [adjacency list, origin, target, path].

Parameters

  • target_tokenizer: Tokenizer element which tokenizes the target(s) of a TargetedLatticeMaze. Uses coord_tokenizer to tokenize coords if that is part of the design of that TargetTokenizer.
  • path_tokenizer: Tokenizer element which tokenizes the solution path of a SolvedMaze. Uses coord_tokenizer to tokenize coords if that is part of the design of that PathTokenizer.
PromptSequencers.AOTP(
    *,
    coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer = CoordTokenizers.UT(),
    adj_list_tokenizer: maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer = AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()),
    _type_: Literal["<class 'PromptSequencers.AOTP'>"] = "<class 'PromptSequencers.AOTP'>",
    target_tokenizer: maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer = TargetTokenizers.Unlabeled(post=False),
    path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False),
)
target_tokenizer: maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer = TargetTokenizers.Unlabeled(post=False)
path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)
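To make the region ordering concrete, a sketch of the sequencing step (illustrative: this calls the private _sequence_tokens directly with hand-written region tokens, whereas in normal use to_tokens produces the regions via the element tokenizers):

from maze_dataset.tokenization.modular.elements import PromptSequencers

seq = PromptSequencers.AOTP()
tokens = seq._sequence_tokens(
    adj_list=["(0,0)", "<-->", "(0,1)", ";"],  # hand-written example region
    origin=["(0,0)"],
    target=["(0,1)"],
    path=["(0,0)", "(0,1)"],
)
# tokens == [VOCAB.ADJLIST_START, ..., VOCAB.ADJLIST_END,
#            VOCAB.ORIGIN_START, "(0,0)", VOCAB.ORIGIN_END,
#            VOCAB.TARGET_START, "(0,1)", VOCAB.TARGET_END,
#            VOCAB.PATH_START, "(0,0)", "(0,1)", VOCAB.PATH_END]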
serialize(), load(), and validate_fields_types() are provided by the @serializable_dataclass decorator; their listings are identical to those shown earlier in this document.

@serializable_dataclass(frozen=True, kw_only=True)
class PromptSequencers.AOP(maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer):
1259	@serializable_dataclass(frozen=True, kw_only=True)
1260	class AOP(_PromptSequencer):
1261		"""Sequences a prompt as [adjacency list, origin, path].
1262
1263		Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
1264
1265		# Parameters
1266		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
1267		Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
1268		"""
1269
1270		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
1271			default=PathTokenizers.StepSequence(),
1272			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
1273		)
1274
1275		def _sequence_tokens(
1276			self,
1277			adj_list: list[str],
1278			origin: list[str],
1279			# explicitly no target in this tokenizer
1280			target: list[str],
1281			path: list[str],
1282		) -> list[str]:
1283			return [
1284				VOCAB.ADJLIST_START,
1285				*adj_list,
1286				VOCAB.ADJLIST_END,
1287				VOCAB.ORIGIN_START,
1288				*origin,
1289				VOCAB.ORIGIN_END,
1290				VOCAB.TARGET_START,
1291				VOCAB.TARGET_END,
1292				VOCAB.PATH_START,
1293				*path,
1294				VOCAB.PATH_END,
1295			]

Sequences a prompt as [adjacency list, origin, path].

Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.

Parameters

  • path_tokenizer: Tokenizer element which tokenizes the solution path of a SolvedMaze. Uses coord_tokenizer to tokenize coords if that is part of the design of that PathTokenizer.
PromptSequencers.AOP(
    *,
    coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer = CoordTokenizers.UT(),
    adj_list_tokenizer: maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer = AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()),
    _type_: Literal["<class 'PromptSequencers.AOP'>"] = "<class 'PromptSequencers.AOP'>",
    path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False),
)
path_tokenizer: maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer = PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)
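The empty target region can be seen directly (illustrative sketch, again calling the private _sequence_tokens; it assumes the "<TARGET_START>"/"<TARGET_END>" literals shown in the docstring above):

from maze_dataset.tokenization.modular.elements import PromptSequencers

seq = PromptSequencers.AOP()
tokens = seq._sequence_tokens(
    adj_list=["(0,0)", "<-->", "(0,1)", ";"],
    origin=["(0,0)"],
    target=["(0,1)"],  # accepted but ignored: AOP emits no target tokens
    path=["(0,0)", "(0,1)"],
)
# the target region is empty: "<TARGET_START>" is immediately followed by "<TARGET_END>"
assert tokens[tokens.index("<TARGET_START>") + 1] == "<TARGET_END>"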
serialize(), load(), and validate_fields_types() are provided by the @serializable_dataclass decorator; their listings are identical to those shown earlier in this document.

 48 class CoordTokenizers(__TokenizerElementNamespace):
 49	"""Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 50
 51	key = "coord_tokenizer"
 52
 53	@serializable_dataclass(frozen=True, kw_only=True)
 54	class _CoordTokenizer(_TokenizerElement, abc.ABC):
 55		"""Superclass for classes which tokenize singular coords in a maze."""
 56
 57		@abc.abstractmethod
 58		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
 59			pass
 60
 61		@classmethod
 62		def attribute_key(cls) -> str:
 63			return CoordTokenizers.key
 64
 65		def is_valid(self, do_except: bool = False) -> bool:
 66			# No invalid instances possible within data member type hint bounds
 67			return True
 68
 69	@serializable_dataclass(frozen=True, kw_only=True)
 70	class UT(_CoordTokenizer):
 71		"""Unique token coordinate tokenizer."""
 72
 73		# inherit docstring
 74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
 76
 77	@serializable_dataclass(frozen=True, kw_only=True)
 78	class CTT(_CoordTokenizer):
 79		"""Coordinate tuple tokenizer
 80
 81		# Parameters
 82		- `pre`: Whether all coords include an integral preceding delimiter token
 83		- `intra`: Whether all coords include a delimiter token between coordinates
 84		- `post`: Whether all coords include an integral following delimiter token
 85		"""
 86
 87		pre: bool = serializable_field(default=True)
 88		intra: bool = serializable_field(default=True)
 89		post: bool = serializable_field(default=True)
 90		# Implement methods
 91
 92		# inherit docstring
 93		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 94			return [
 95				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
 96				str(coord[0]),
 97				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
 98				str(coord[1]),
 99				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
100			]

Namespace for _CoordTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'coord_tokenizer'
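A quick comparison of the two coordinate tokenizers (illustrative sketch; the UT output follows directly from the source above, while the exact CTT delimiter strings come from VOCAB and the parenthesis/comma tokens shown are assumptions):

from maze_dataset.tokenization.modular.elements import CoordTokenizers

# unique-token tokenizer: the whole coordinate is a single token
CoordTokenizers.UT().to_tokens((2, 3))   # -> ["(2,3)"]

# coordinate-tuple tokenizer: row and column are separate tokens, with delimiters
CoordTokenizers.CTT().to_tokens((2, 3))  # -> e.g. ["(", "2", ",", "3", ")"]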
@serializable_dataclass(frozen=True, kw_only=True)
class CoordTokenizers.UT(maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer):
69	@serializable_dataclass(frozen=True, kw_only=True)
70	class UT(_CoordTokenizer):
71		"""Unique token coordinate tokenizer."""
72
73		# inherit docstring
74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

Unique token coordinate tokenizer.

CoordTokenizers.UT( *, _type_: Literal["<class 'CoordTokenizers.UT'>"] = "<class 'CoordTokenizers.UT'>")
def to_tokens( self, coord: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> list[str]:
74		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
75			return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

Converts a maze element into a list of tokens.

Not all _TokenizerElement subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.

serialize(), load(), and validate_fields_types() are provided by the @serializable_dataclass decorator; their listings are identical to those shown earlier in this document.

Inherited Members
maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class CoordTokenizers.CTT(maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer):
 77	@serializable_dataclass(frozen=True, kw_only=True)
 78	class CTT(_CoordTokenizer):
 79		"""Coordinate tuple tokenizer
 80
 81		# Parameters
 82		- `pre`: Whether all coords include an integral preceding delimiter token
 83		- `intra`: Whether all coords include a delimiter token between coordinates
 84		- `post`: Whether all coords include an integral following delimiter token
 85		"""
 86
 87		pre: bool = serializable_field(default=True)
 88		intra: bool = serializable_field(default=True)
 89		post: bool = serializable_field(default=True)
 90		# Implement methods
 91
 92		# inherit docstring
 93		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 94			return [
 95				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
 96				str(coord[0]),
 97				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
 98				str(coord[1]),
 99				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
100			]

Coordinate tuple tokenizer

Parameters

  • pre: Whether all coords include an integral preceding delimiter token
  • intra: Whether all coords include a delimiter token between coordinates
  • post: Whether all coords include an integral following delimiter token
CoordTokenizers.CTT( *, _type_: Literal["<class 'CoordTokenizers.CTT'>"] = "<class 'CoordTokenizers.CTT'>", pre: bool = True, intra: bool = True, post: bool = True)
pre: bool = True
intra: bool = True
post: bool = True
def to_tokens( self, coord: jaxtyping.Int8[ndarray, 'row_col=2'] | tuple[int, int]) -> list[str]:
 93		def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
 94			return [
 95				*empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
 96				str(coord[0]),
 97				*empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
 98				str(coord[1]),
 99				*empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
100			]

Converts a maze element into a list of tokens.

Not all _TokenizerElement subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.
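For example, toggling the flags changes which delimiter tokens appear (illustrative sketch; the intra delimiter string shown in the comment is an assumption about VOCAB.COORD_INTRA):

from maze_dataset.tokenization.modular.elements import CoordTokenizers

# keep only the separator between row and column
ctt = CoordTokenizers.CTT(pre=False, intra=True, post=False)
ctt.to_tokens((2, 3))  # -> ["2", ",", "3"] (middle token is VOCAB.COORD_INTRA)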

serialize(), load(), and validate_fields_types() are provided by the @serializable_dataclass decorator; their listings are identical to those shown earlier in this document.

Inherited Members
maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
356 class AdjListTokenizers(__TokenizerElementNamespace):
357	"""Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
358
359	key = "adj_list_tokenizer"
360
361	@serializable_dataclass(frozen=True, kw_only=True)
362	@mark_as_unsupported(_adjlist_no_pre_unsupported)
363	class _AdjListTokenizer(_TokenizerElement, abc.ABC):
364		"""Specifies how the adjacency list is tokenized.
365
366		Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
367		See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.
368
369		# Parameters
370		- `pre`: Whether all edge groupings include a preceding delimiter token
371		- `post`: Whether all edge groupings include a following delimiter token
372		- `shuffle_d0`: Specifies how to sequence the edge groupings.
373			If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
374		- `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
375		- `edge_subset`: Specifies the subset of lattice edges to be tokenized.
376		- `edge_permuter`: Specifies, in each edge tokenization, which coord either:
377			1. Appears first in the tokenization, for `AdjListCoord`.
378			2. Is tokenized directly as a coord, for `AdjListCardinal`.
379			- `shuffle`: For each edge, the leading coord is selected randomly.
380			- `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
381			- `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
382		"""
383
384		pre: bool = serializable_field(default=False, assert_type=False)
385		post: bool = serializable_field(default=True)
386		shuffle_d0: bool = serializable_field(default=True)
387		edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
388			default=EdgeGroupings.Ungrouped(),
389			loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
390		)
391		edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
392			default=EdgeSubsets.ConnectionEdges(),
393			loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
394		)
395		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
396			default=EdgePermuters.RandomCoords(),
397			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
398		)
399
400		@classmethod
401		def attribute_key(cls) -> str:
402			return AdjListTokenizers.key
403
404		def is_valid(self, do_except: bool = False) -> bool:
405			# No invalid instances possible within data member type hint bounds
406			return True
407
408		@abc.abstractmethod
409		def _tokenization_callables(
410			self,
411			edges: ConnectionArray,
412			is_conn: Bool[np.ndarray, " edges"],
413			coord_tokenizer: CoordTokenizers._CoordTokenizer,
414			*args,
415			**kwargs,
416		) -> list[Callable]:
417			"""Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.
418
419			# Returns
420			- `[0]`: leading coord tokens
421			- `[1]`: connector tokens
422			- `[2]`: trailing coord tokens
423			"""
424			pass
425
426		def _tokenize_edge_grouping(
427			self,
428			edges: ConnectionArray,
429			maze: LatticeMaze,
430			coord_tokenizer: CoordTokenizers._CoordTokenizer,
431			group_params: EdgeGroupings._GroupingTokenParams,
432		) -> Sequence[str]:
433			"""Tokenizes a single edge grouping."""
434			cxn_ord: int = group_params["connection_token_ordinal"]
435			is_conn: Bool[np.ndarray, " edges"] = is_connection(
436				edges,
437				maze.connection_list,
438			)
439			tokenize_callables = self._tokenization_callables(
440				edges,
441				is_conn,
442				coord_tokenizer,
443			)
444
445			if group_params["grouped"]:
446				# If grouped
447				callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
448				repeated_callables = [
449					tokenize_callables[i] for i in callable_permutation
450				]
451				return flatten(
452					[
453						tokenize_callables[0](0),
454						[
455							[
456								*[
457									tok_callable(i)
458									for tok_callable in repeated_callables
459								],
460								*(
461									(VOCAB.ADJLIST_INTRA,)
462									if group_params["intra"]
463									else ()
464								),
465							]
466							for i in range(edges.shape[0])
467						],
468					],
469				)
470			else:
471				# If ungrouped
472				callable_permutation = [0, 2]
473				callable_permutation.insert(cxn_ord, 1)
474				tokenize_callables = [
475					tokenize_callables[i] for i in callable_permutation
476				]
477
478				return flatten(
479					[
480						[
481							[
482								*[
483									tok_callable(i)
484									for tok_callable in tokenize_callables
485								],
486								*empty_sequence_if_attr_false(
487									(VOCAB.ADJLIST_INTRA,),
488									group_params,
489									"intra",
490								),
491							]
492							for i in range(edges.shape[0])
493						],
494					],
495				)
496
497		def to_tokens(
498			self,
499			maze: LatticeMaze,
500			coord_tokenizer: CoordTokenizers._CoordTokenizer,
501		) -> list[str]:
502			# Get the set of edges to be tokenized
503			edges: ConnectionArray = self.edge_subset._get_edges(maze)
504			# Systematically permute the leading coord of each edge
505			edges: ConnectionArray = self.edge_permuter._permute(edges)
506			group_params: EdgeGroupings._GroupingTokenParams = (
507				self.edge_grouping._token_params()
508			)
509			# then, we need to group the edges
510			groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
511			# shuffle the groups if specified
512			if self.shuffle_d0:
513				if isinstance(groups, np.ndarray):
514					numpy_rng.shuffle(groups, axis=0)
515				elif isinstance(groups, list):
516					random.shuffle(groups)
517				else:
518					err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
519					raise TypeError(err_msg)
520			# Tokenize each group with optional delimiters
521			tokens: list[str] = list(
522				flatten(
523					[
524						[
525							*empty_sequence_if_attr_false(
526								(VOCAB.ADJLIST_PRE,),
527								self,
528								"pre",
529							),
530							*self._tokenize_edge_grouping(
531								group,
532								maze,
533								coord_tokenizer,
534								group_params,
535							),
536							*empty_sequence_if_attr_false(
537								(VOCAB.ADJACENCY_ENDLINE,),
538								self,
539								"post",
540							),
541						]
542						for group in groups
543					],
544				),
545			)
546			return tokens
547
548	@serializable_dataclass(frozen=True, kw_only=True)
549	class AdjListCoord(_AdjListTokenizer):
550		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""
551
552		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
553			default=EdgePermuters.RandomCoords(),
554			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
555		)
556
557		def _tokenization_callables(
558			self,
559			edges: ConnectionArray,
560			is_conn: Bool[np.ndarray, " edges"],
561			coord_tokenizer: CoordTokenizers._CoordTokenizer,
562			*args,
563			**kwargs,
564		) -> list[Callable]:
565			# Map from `is_conn` to the tokens which represent connections and walls
566			conn_token_map: dict[bool, str] = {
567				True: VOCAB.CONNECTOR,
568				False: VOCAB.ADJLIST_WALL,
569			}
570			return [
571				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
572				lambda i: conn_token_map[is_conn[i]],
573				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
574			]
575
576	@serializable_dataclass(frozen=True, kw_only=True)
577	class AdjListCardinal(_AdjListTokenizer):
578		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
579
580		# Parameters
581		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
582		"""
583
584		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
585			default=EdgePermuters.BothCoords(),
586			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
587		)
588
589		def _tokenization_callables(
590			self,
591			edges: ConnectionArray,
592			is_conn: Bool[np.ndarray, " edges"],
593			coord_tokenizer: CoordTokenizers._CoordTokenizer,
594			*args,
595			**kwargs,
596		) -> list[Callable]:
597			# Map from `is_conn` to the tokens which represent connections and walls
598			conn_token_map: dict[bool, str] = {
599				True: VOCAB.CONNECTOR,
600				False: VOCAB.ADJLIST_WALL,
601			}
602			return [
603				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
604				lambda i: conn_token_map[is_conn[i]],
605				lambda i: get_cardinal_direction(edges[i]),
606			]

Namespace for _AdjListTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'adj_list_tokenizer'
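As an end-to-end sketch, an adjacency-list tokenizer can be exercised directly. This assumes `LatticeMazeGenerators.gen_dfs` and the unique-token coordinate tokenizer `CoordTokenizers.UT` documented elsewhere in this package; treat it as illustrative rather than canonical:

```python
import numpy as np

from maze_dataset import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import AdjListTokenizers, CoordTokenizers

# generate a small maze; gen_dfs takes a (rows, cols) grid shape
maze = LatticeMazeGenerators.gen_dfs(np.array([3, 3]))

# defaults: connection edges only, random leading coord, shuffled edge order
adj_tok = AdjListTokenizers.AdjListCoord()
tokens: list[str] = adj_tok.to_tokens(maze, CoordTokenizers.UT())
# each edge contributes: leading coord, connector, trailing coord, post delimiter
```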
@serializable_dataclass(frozen=True, kw_only=True)
class AdjListTokenizers.AdjListCoord(maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer):
548	@serializable_dataclass(frozen=True, kw_only=True)
549	class AdjListCoord(_AdjListTokenizer):
550		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""
551
552		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
553			default=EdgePermuters.RandomCoords(),
554			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
555		)
556
557		def _tokenization_callables(
558			self,
559			edges: ConnectionArray,
560			is_conn: Bool[np.ndarray, " edges"],
561			coord_tokenizer: CoordTokenizers._CoordTokenizer,
562			*args,
563			**kwargs,
564		) -> list[Callable]:
565			# Map from `is_conn` to the tokens which represent connections and walls
566			conn_token_map: dict[bool, str] = {
567				True: VOCAB.CONNECTOR,
568				False: VOCAB.ADJLIST_WALL,
569			}
570			return [
571				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
572				lambda i: conn_token_map[is_conn[i]],
573				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
574			]

Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.

AdjListTokenizers.AdjListCoord( *, pre: bool = False, post: bool = True, shuffle_d0: bool = True, edge_grouping: maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping = EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset: maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset = EdgeSubsets.ConnectionEdges(walls=False), edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.RandomCoords(), _type_: Literal["<class 'AdjListTokenizers.AdjListCoord'>"] = "<class 'AdjListTokenizers.AdjListCoord'>")
edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.RandomCoords()
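A sketch of a non-default configuration, built only from elements shown in this module (the parameter choices are illustrative):

```python
from maze_dataset.tokenization.modular.elements import (
    AdjListTokenizers,
    EdgePermuters,
    EdgeSubsets,
)

# deterministic variant: wall edges only, lexicographically sorted coord pairs,
# and edge order left unshuffled
adj_tok = AdjListTokenizers.AdjListCoord(
    shuffle_d0=False,
    edge_subset=EdgeSubsets.ConnectionEdges(walls=True),
    edge_permuter=EdgePermuters.SortedCoords(),
)
```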
def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator
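For example (a sketch; the exact set of keys depends on the muutils version):

```python
from maze_dataset.tokenization.modular.elements import AdjListTokenizers

d: dict = AdjListTokenizers.AdjListCoord().serialize()
# one entry per serializable field, plus a format marker and the `_type_` field
print({k: d[k] for k in ("pre", "post", "shuffle_d0")})
# {'pre': False, 'post': True, 'shuffle_d0': True}
```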

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
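A roundtrip sketch (assuming standard dataclass equality over fields):

```python
from maze_dataset.tokenization.modular.elements import AdjListTokenizers

original = AdjListTokenizers.AdjListCoord(shuffle_d0=False)
restored = AdjListTokenizers.AdjListCoord.load(original.serialize())
assert restored == original
```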

@serializable_dataclass(frozen=True, kw_only=True)
class AdjListTokenizers.AdjListCardinal(maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer):
576	@serializable_dataclass(frozen=True, kw_only=True)
577	class AdjListCardinal(_AdjListTokenizer):
578		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
579
580		# Parameters
581		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
582		"""
583
584		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
585			default=EdgePermuters.BothCoords(),
586			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
587		)
588
589		def _tokenization_callables(
590			self,
591			edges: ConnectionArray,
592			is_conn: Bool[np.ndarray, " edges"],
593			coord_tokenizer: CoordTokenizers._CoordTokenizer,
594			*args,
595			**kwargs,
596		) -> list[Callable]:
597			# Map from `is_conn` to the tokens which represent connections and walls
598			conn_token_map: dict[bool, str] = {
599				True: VOCAB.CONNECTOR,
600				False: VOCAB.ADJLIST_WALL,
601			}
602			return [
603				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
604				lambda i: conn_token_map[is_conn[i]],
605				lambda i: get_cardinal_direction(edges[i]),
606			]

Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

Parameters

  • coord_first: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
AdjListTokenizers.AdjListCardinal( *, pre: bool = False, post: bool = True, shuffle_d0: bool = True, edge_grouping: maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping = EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset: maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset = EdgeSubsets.ConnectionEdges(walls=False), edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.BothCoords(), _type_: Literal["<class 'AdjListTokenizers.AdjListCardinal'>"] = "<class 'AdjListTokenizers.AdjListCardinal'>")
edge_permuter: maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter = EdgePermuters.BothCoords()
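A quick sketch confirming the default permuter (illustrative only):

```python
from maze_dataset.tokenization.modular.elements import AdjListTokenizers, EdgePermuters

adj_tok = AdjListTokenizers.AdjListCardinal()
# BothCoords by default: each connection appears once per leading coord, so a
# given connection can be expressed with either cardinal direction token
assert isinstance(adj_tok.edge_permuter, EdgePermuters.BothCoords)
```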

103class EdgeGroupings(__TokenizerElementNamespace):
104	"""Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""
105
106	key = "edge_grouping"
107
108	class _GroupingTokenParams(TypedDict):
109		"""A uniform private hyperparameter interface used by `AdjListTokenizer`."""
110
111		connection_token_ordinal: Literal[0, 1, 2]
112		intra: bool
113		grouped: bool
114
115	@serializable_dataclass(frozen=True, kw_only=True)
116	class _EdgeGrouping(_TokenizerElement, abc.ABC):
117		"""Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""
118
119		@classmethod
120		def attribute_key(cls) -> str:
121			return EdgeGroupings.key
122
123		def is_valid(self, do_except: bool = False) -> bool:
124			return True
125
126		@abc.abstractmethod
127		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
128			"""Divides a ConnectionArray into groups of edges.
129
130			Shuffles/sequences within each group if applicable.
131			"""
132			pass
133
134		@abc.abstractmethod
135		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
136			"""Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.
137
138			These hyperparameters are not used by `_EdgeGrouping` internally.
139			They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
140			since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
141			This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses
142			into a uniform private interface used by `AdjListTokenizer`.
143			"""
144			pass
145
146	@serializable_dataclass(frozen=True, kw_only=True)
147	class Ungrouped(_EdgeGrouping):
148		"""No grouping occurs, each edge is tokenized individually.
149
150		# Parameters
151		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
152		Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
153		"""
154
155		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
156			default=1,
157			assert_type=False,
158		)
159
160		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
161			return EdgeGroupings._GroupingTokenParams(
162				connection_token_ordinal=self.connection_token_ordinal,
163				intra=False,
164				grouped=False,
165			)
166
167		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
168			return np.expand_dims(edges, 1)
169
170	@serializable_dataclass(frozen=True, kw_only=True)
171	@mark_as_unsupported(_unsupported_is_invalid)
172	class ByLeadingCoord(_EdgeGrouping):
173		"""All edges with the same leading coord are grouped together.
174
175		# Parameters
176		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
177		Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
178		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
179		If false, the fixed order is lexicographical by (row, col).
180		In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
181		- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
182		Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
183		"""
184
185		intra: bool = serializable_field(default=True)
186		shuffle_group: bool = serializable_field(default=True)
187		connection_token_ordinal: Literal[0, 1] = serializable_field(
188			default=0,
189			assert_type=False,
190		)
191
192		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
193			return EdgeGroupings._GroupingTokenParams(
194				connection_token_ordinal=self.connection_token_ordinal,
195				intra=self.intra,
196				grouped=True,
197			)
198
199		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
200			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
201			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
202				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
203			)
204			sorted_edges: ConnectionArray = edges[index_array, ...]
205			groups: list[ConnectionArray] = np.split(
206				sorted_edges,
207				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
208			)
209			if self.shuffle_group:
210				[numpy_rng.shuffle(g, axis=0) for g in groups]
211			return groups

Namespace for _EdgeGrouping subclass hierarchy used by _AdjListTokenizer.

key = 'edge_grouping'
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeGroupings.Ungrouped(maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping):
146	@serializable_dataclass(frozen=True, kw_only=True)
147	class Ungrouped(_EdgeGrouping):
148		"""No grouping occurs, each edge is tokenized individually.
149
150		# Parameters
151		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
152		Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
153		"""
154
155		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
156			default=1,
157			assert_type=False,
158		)
159
160		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
161			return EdgeGroupings._GroupingTokenParams(
162				connection_token_ordinal=self.connection_token_ordinal,
163				intra=False,
164				grouped=False,
165			)
166
167		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
168			return np.expand_dims(edges, 1)

No grouping occurs, each edge is tokenized individually.

Parameters

  • connection_token_ordinal: At which index in the edge tokenization the connector (or wall) token appears. Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
EdgeGroupings.Ungrouped( *, _type_: Literal["<class 'EdgeGroupings.Ungrouped'>"] = "<class 'EdgeGroupings.Ungrouped'>", connection_token_ordinal: Literal[0, 1, 2] = 1)
connection_token_ordinal: Literal[0, 1, 2] = 1
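The ordinal determines where the connector callable is inserted into the (leading coord, trailing coord) pair, as in the ungrouped branch of `_tokenize_edge_grouping` above. A standalone sketch of just that ordering logic:

```python
# sketch: where the connector token lands for each ordinal value
for cxn_ord in (0, 1, 2):
    order: list[str] = ["lead_coord", "trail_coord"]
    order.insert(cxn_ord, "connector")
    print(cxn_ord, order)
# 0 ['connector', 'lead_coord', 'trail_coord']
# 1 ['lead_coord', 'connector', 'trail_coord']
# 2 ['lead_coord', 'trail_coord', 'connector']
```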

Inherited Members
maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class EdgeGroupings.ByLeadingCoord(maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping):
170	@serializable_dataclass(frozen=True, kw_only=True)
171	@mark_as_unsupported(_unsupported_is_invalid)
172	class ByLeadingCoord(_EdgeGrouping):
173		"""All edges with the same leading coord are grouped together.
174
175		# Parameters
176		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
177		Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
178		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
179		If false, the fixed order is lexicographical by (row, col).
180		In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
181		- `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears.
182		Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
183		"""
184
185		intra: bool = serializable_field(default=True)
186		shuffle_group: bool = serializable_field(default=True)
187		connection_token_ordinal: Literal[0, 1] = serializable_field(
188			default=0,
189			assert_type=False,
190		)
191
192		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
193			return EdgeGroupings._GroupingTokenParams(
194				connection_token_ordinal=self.connection_token_ordinal,
195				intra=self.intra,
196				grouped=True,
197			)
198
199		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
200			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
201			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
202				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
203			)
204			sorted_edges: ConnectionArray = edges[index_array, ...]
205			groups: list[ConnectionArray] = np.split(
206				sorted_edges,
207				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
208			)
209			if self.shuffle_group:
210				[numpy_rng.shuffle(g, axis=0) for g in groups]
211			return groups

All edges with the same leading coord are grouped together.

Parameters

  • intra: Whether all edge groupings include a delimiter token between individual edge representations. Note that each edge representation will already always include a connector token (VOCAB.CONNECTOR, or possibly VOCAB.ADJLIST_WALL).
  • shuffle_group: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. If false, the fixed order is lexicographical by (row, col). In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
  • connection_token_ordinal: At which index in token sequence representing a single edge the connector (or wall) token appears. Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
EdgeGroupings.ByLeadingCoord( *, _type_: Literal["<class 'EdgeGroupings.ByLeadingCoord'>"] = "<class 'EdgeGroupings.ByLeadingCoord'>", intra: bool = True, shuffle_group: bool = True, connection_token_ordinal: Literal[0, 1] = 0)
intra: bool = True
shuffle_group: bool = True
connection_token_ordinal: Literal[0, 1] = 0
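The grouping itself is a lexsort-then-split over leading coords; a self-contained numpy sketch of the same technique on a toy `ConnectionArray`:

```python
import numpy as np

# toy ConnectionArray of shape (n_edges, 2, 2); group rows by leading coord
edges = np.array([
    [[0, 0], [0, 1]],
    [[0, 0], [1, 0]],
    [[1, 0], [1, 1]],
])
idx = np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
sorted_edges = edges[idx]
# split wherever a new leading coord first appears
split_points = np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:]
groups = np.split(sorted_edges, split_points)
print([g.shape[0] for g in groups])  # [2, 1]: two edges lead with (0, 0)
```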
def is_valid(self, do_except: bool = False) -> bool:
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
258	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
259	if do_except:
260		err_msg: str = (
261			f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
262			f"{type(self) = }, {self = }"
263		)
264		raise ValueError(err_msg)
265
266	return False

Default implementation of is_valid for mark_as_unsupported-decorated classes
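For instance, since `ByLeadingCoord` is marked as unsupported, validity checks fail (a sketch; the exact error text may differ):

```python
from maze_dataset.tokenization.modular.elements import EdgeGroupings

grouping = EdgeGroupings.ByLeadingCoord()
assert grouping.is_valid() is False
try:
    grouping.is_valid(do_except=True)
except ValueError as e:
    print("unsupported:", e)
```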

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgeGroupings._EdgeGrouping
attribute_key
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class EdgePermuters(__TokenizerElementNamespace):
	"""Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

	key = "edge_permuter"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _EdgePermuter(_TokenizerElement, abc.ABC):
		"""Specifies how to sequence the two coords that encode a lattice edge."""

		@classmethod
		def attribute_key(cls) -> str:
			return EdgePermuters.key

		def is_valid(self, do_except: bool = False) -> bool:
			# No invalid instances possible within data member type hint bounds
			return True

		@staticmethod
		@abc.abstractmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			"""Executes a permutation.

			Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

			# Parameters
			- `lattice_edges`: Array of lattice edges.
			The two coords in shape[1] must be adjacent in the lattice.

			# Returns
			- Array of lattice edges with entries along shape[1] systematically permuted.
			- shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[0]`.
			"""
			pass

	@serializable_dataclass(frozen=True, kw_only=True)
	class SortedCoords(_EdgePermuter):
		"""returns a sorted representation. useful for checking consistency"""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			return lattice_edges[
				np.lexsort(
					(
						lattice_edges[:, 1, 1],
						lattice_edges[:, 1, 0],
						lattice_edges[:, 0, 1],
						lattice_edges[:, 0, 0],
					),
				),
				...,
			]

	@serializable_dataclass(frozen=True, kw_only=True)
	class RandomCoords(_EdgePermuter):
		"""Permutes each edge randomly."""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
			return lattice_edges

	@serializable_dataclass(frozen=True, kw_only=True)
	class BothCoords(_EdgePermuter):
		"""Includes both possible permutations of every edge in the output.

		Since the input `ConnectionArray` contains only 1 instance of each edge,
		a call to `BothCoords._permute` returns a new array with double the `shape[0]` of the input; the input is not modified in-place.
		"""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)

Namespace for _EdgePermuter subclass hierarchy used by _AdjListTokenizer.

key = 'edge_permuter'
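
To make the differences between the three permuters concrete, here is a small sketch on a toy edge array. It uses plain numpy with a local `rng`, and the random case is written as an explicit mask-based flip with the same intent as `RandomCoords`, so treat this as an illustration rather than the library's internals.

# toy edge array: shape (n_edges, 2, 2); each edge is a pair of adjacent (row, col) coords
import numpy as np

rng = np.random.default_rng(0)
edges = np.array(
    [
        [[1, 1], [0, 1]],
        [[0, 0], [0, 1]],
    ]
)

# SortedCoords: lexsort by (row0, col0, row1, col1); np.lexsort treats the LAST key as primary
sorted_edges = edges[
    np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
]

# RandomCoords (same intent, illustrative): flip each edge's coord order with probability 1/2
flip_mask = rng.random(edges.shape[0]) < 0.5
random_edges = edges.copy()
random_edges[flip_mask] = np.flip(random_edges[flip_mask], axis=1)

# BothCoords: append the coord-flipped copy of every edge, doubling shape[0]
both_edges = np.append(edges, np.flip(edges, axis=1), axis=0)
assert both_edges.shape[0] == 2 * edges.shape[0]
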
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.SortedCoords(maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter):
	@serializable_dataclass(frozen=True, kw_only=True)
	class SortedCoords(_EdgePermuter):
		"""returns a sorted representation. useful for checking consistency"""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			return lattice_edges[
				np.lexsort(
					(
						lattice_edges[:, 1, 1],
						lattice_edges[:, 1, 0],
						lattice_edges[:, 0, 1],
						lattice_edges[:, 0, 0],
					),
				),
				...,
			]

returns a sorted representation. useful for checking consistency

EdgePermuters.SortedCoords( *, _type_: Literal["<class 'EdgePermuters.SortedCoords'>"] = "<class 'EdgePermuters.SortedCoords'>")
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.RandomCoords(maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter):
	@serializable_dataclass(frozen=True, kw_only=True)
	class RandomCoords(_EdgePermuter):
		"""Permutes each edge randomly."""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
			return lattice_edges

Permutes each edge randomly.

EdgePermuters.RandomCoords( *, _type_: Literal["<class 'EdgePermuters.RandomCoords'>"] = "<class 'EdgePermuters.RandomCoords'>")
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgePermuters.BothCoords(maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter):
	@serializable_dataclass(frozen=True, kw_only=True)
	class BothCoords(_EdgePermuter):
		"""Includes both possible permutations of every edge in the output.

		Since the input `ConnectionArray` contains only 1 instance of each edge,
		a call to `BothCoords._permute` returns a new array with double the `shape[0]` of the input; the input is not modified in-place.
		"""

		@staticmethod
		def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
			return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)

Includes both possible permutations of every edge in the output.

Since the input ConnectionArray contains only 1 instance of each edge, a call to BothCoords._permute returns a new array with double the shape[0] of the input; the input is not modified in-place.

EdgePermuters.BothCoords( *, _type_: Literal["<class 'EdgePermuters.BothCoords'>"] = "<class 'EdgePermuters.BothCoords'>")
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgePermuters._EdgePermuter
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class EdgeSubsets(__TokenizerElementNamespace):
	"""Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""

	key = "edge_subset"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _EdgeSubset(_TokenizerElement, abc.ABC):
		"""Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""

		@classmethod
		def attribute_key(cls) -> str:
			return EdgeSubsets.key

		def is_valid(self, do_except: bool = False) -> bool:
			return True

		@abc.abstractmethod
		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
			"""Returns the set of lattice edges to be tokenized."""
			pass

	@serializable_dataclass(frozen=True, kw_only=True)
	class AllLatticeEdges(_EdgeSubset):
		"""All 2n**2-2n edges of the lattice are tokenized.

		If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
		"""

		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
			return lattice_connection_array(maze.grid_n)

	@serializable_dataclass(frozen=True, kw_only=True)
	class ConnectionEdges(_EdgeSubset):
		"""Only edges which contain a connection are tokenized.

		Alternatively, only edges which contain a wall are tokenized.

		# Parameters
		- `walls`: Whether wall edges or connection edges are tokenized.
		If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
		"""

		walls: bool = serializable_field(default=False)

		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
			conn_list: ConnectionList = maze.connection_list
			if self.walls:
				conn_list = np.logical_not(conn_list)
				conn_list[0, -1, :] = False
				conn_list[1, :, -1] = False
			return connection_list_to_adj_list(
				conn_list,
				shuffle_d0=False,
				shuffle_d1=False,
			)

Namespace for _EdgeSubset subclass hierarchy used by _AdjListTokenizer.

key = 'edge_subset'
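
The two subsets above partition the lattice edges. Here is a short counting sketch of the wall-complement masking in `ConnectionEdges._get_edges`, using plain numpy and assuming (as the boundary masking above suggests) that layer 0 of `connection_list` marks downward connections and layer 1 rightward ones.

# counting sketch for a toy 3x3 maze, using plain numpy
import numpy as np

n = 3
rng = np.random.default_rng(0)
connection_list = rng.random((2, n, n)) < 0.5  # layer 0: down, layer 1: right (assumed)

# AllLatticeEdges: every lattice edge, 2*n**2 - 2*n of them
n_edges_total = 2 * n**2 - 2 * n  # 12 for n=3

# ConnectionEdges(walls=True): invert, then zero the entries that are not real edges
walls = np.logical_not(connection_list)
walls[0, -1, :] = False  # no downward edge out of the last row
walls[1, :, -1] = False  # no rightward edge out of the last column

n_connections = int(connection_list[0, :-1, :].sum() + connection_list[1, :, :-1].sum())
assert n_connections + int(walls.sum()) == n_edges_total
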
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeSubsets.AllLatticeEdges(maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset):
	@serializable_dataclass(frozen=True, kw_only=True)
	class AllLatticeEdges(_EdgeSubset):
		"""All 2n**2-2n edges of the lattice are tokenized.

		If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
		"""

		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
			return lattice_connection_array(maze.grid_n)

All 2n**2-2n edges of the lattice are tokenized.

If a wall exists on that edge, the edge is tokenized in the same manner, using VOCAB.ADJLIST_WALL in place of VOCAB.CONNECTOR.

EdgeSubsets.AllLatticeEdges( *, _type_: Literal["<class 'EdgeSubsets.AllLatticeEdges'>"] = "<class 'EdgeSubsets.AllLatticeEdges'>")
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class EdgeSubsets.ConnectionEdges(maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset):
	@serializable_dataclass(frozen=True, kw_only=True)
	class ConnectionEdges(_EdgeSubset):
		"""Only edges which contain a connection are tokenized.

		Alternatively, only edges which contain a wall are tokenized.

		# Parameters
		- `walls`: Whether wall edges or connection edges are tokenized.
		If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
		"""

		walls: bool = serializable_field(default=False)

		def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
			conn_list: ConnectionList = maze.connection_list
			if self.walls:
				conn_list = np.logical_not(conn_list)
				conn_list[0, -1, :] = False
				conn_list[1, :, -1] = False
			return connection_list_to_adj_list(
				conn_list,
				shuffle_d0=False,
				shuffle_d1=False,
			)

Only edges which contain a connection are tokenized.

Alternatively, only edges which contain a wall are tokenized.

Parameters

  • walls: Whether wall edges or connection edges are tokenized. If true, VOCAB.ADJLIST_WALL is used in place of VOCAB.CONNECTOR.
EdgeSubsets.ConnectionEdges( *, _type_: Literal["<class 'EdgeSubsets.ConnectionEdges'>"] = "<class 'EdgeSubsets.ConnectionEdges'>", walls: bool = False)
walls: bool = False
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.EdgeSubsets._EdgeSubset
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class TargetTokenizers(__TokenizerElementNamespace):
	"""Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "target_tokenizer"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _TargetTokenizer(_TokenizerElement, abc.ABC):
		"""Superclass of tokenizers for maze targets."""

		@abc.abstractmethod
		def to_tokens(
			self,
			targets: Sequence[Coord],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			"""Returns tokens representing the target."""
			pass

		@classmethod
		def attribute_key(cls) -> str:
			return TargetTokenizers.key

	@serializable_dataclass(frozen=True, kw_only=True)
	class Unlabeled(_TargetTokenizer):
		"""Targets are simply listed as coord tokens.

		- `post`: Whether each coord is followed by a delimiter token
		"""

		post: bool = serializable_field(default=False)

		# inherit docstring
		def to_tokens(  # noqa: D102
			self,
			targets: Sequence[Coord],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			return list(
				flatten(
					[
						[
							*coord_tokenizer.to_tokens(target),
							*empty_sequence_if_attr_false(
								[VOCAB.TARGET_POST],
								self,
								"post",
							),
						]
						for target in targets
					],
				),
			)

		# inherit docstring
		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
			# No invalid instances possible within data member type hint bounds
			return True

Namespace for _TargetTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'target_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class TargetTokenizers.Unlabeled(maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer):
631	@serializable_dataclass(frozen=True, kw_only=True)
632	class Unlabeled(_TargetTokenizer):
633		"""Targets are simply listed as coord tokens.
634
635		- `post`: Whether all coords include an integral following delimiter token
636		"""
637
638		post: bool = serializable_field(default=False)
639
640		# inherit docstring
641		def to_tokens(  # noqa: D102
642			self,
643			targets: Sequence[Coord],
644			coord_tokenizer: CoordTokenizers._CoordTokenizer,
645		) -> list[str]:
646			return list(
647				flatten(
648					[
649						[
650							*coord_tokenizer.to_tokens(target),
651							*empty_sequence_if_attr_false(
652								[VOCAB.TARGET_POST],
653								self,
654								"post",
655							),
656						]
657						for target in targets
658					],
659				),
660			)
661
662		# inherit docstring
663		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
664			# No invalid instances possible within data member type hint bounds
665			return True

Targets are simply listed as coord tokens.

  • post: Whether every coord is followed by a delimiter token
TargetTokenizers.Unlabeled( *, _type_: Literal["<class 'TargetTokenizers.Unlabeled'>"] = "<class 'TargetTokenizers.Unlabeled'>", post: bool = False)
post: bool = False
def to_tokens( self, targets: Sequence[jaxtyping.Int8[ndarray, 'row_col=2']], coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:
641		def to_tokens(  # noqa: D102
642			self,
643			targets: Sequence[Coord],
644			coord_tokenizer: CoordTokenizers._CoordTokenizer,
645		) -> list[str]:
646			return list(
647				flatten(
648					[
649						[
650							*coord_tokenizer.to_tokens(target),
651							*empty_sequence_if_attr_false(
652								[VOCAB.TARGET_POST],
653								self,
654								"post",
655							),
656						]
657						for target in targets
658					],
659				),
660			)

Returns tokens representing the target.
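A minimal usage sketch (not from the library's own docs): it assumes `CoordTokenizers.UT` is the unique-token coordinate tokenizer described in the overview, and uses a hypothetical target coordinate.

```python
import numpy as np

from maze_dataset.tokenization.modular.elements import (
    CoordTokenizers,
    TargetTokenizers,
)

# hypothetical target: a Coord is a length-2 int8 array of (row, col)
targets = [np.array([2, 3], dtype=np.int8)]

tokenizer = TargetTokenizers.Unlabeled(post=True)
# with post=True, each coord is followed by the VOCAB.TARGET_POST delimiter
tokens = tokenizer.to_tokens(targets, coord_tokenizer=CoordTokenizers.UT())
print(tokens)  # e.g. ["(2,3)", <TARGET_POST token>] under the unique-token scheme
```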

def is_valid(self, do_except: bool = False) -> bool:
663		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
664			# No invalid instances possible within data member type hint bounds
665			return True

Returns whether self contains data members capable of producing an overall valid MazeTokenizerModular.

Some _TokenizerElement instances may be created which are not useful despite obeying data member type hints. is_valid allows for more precise detection of invalid _TokenizerElements beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return True for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:

  • Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., _TokenizerElements which are strictly worse than some alternative.
  • Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
  • Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
  • Erroneous: These tokenizers might raise exceptions during use.

Development

is_valid is implemented to always return True in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call is_valid on nested _TokenizerElements contained in the class. In other words, failures of is_valid need not bubble up to the top of the nested _TokenizerElement tree. MazeTokenizerModular.is_valid calls is_valid on each of its _TokenizerElements individually, so failure at any level will be detected.

Types of Invalidity (implementation note)

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid _TokenizerElement instances.
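As a quick sketch of the contract, reusing the class documented on this page:

```python
from maze_dataset.tokenization.modular.elements import TargetTokenizers

# Unlabeled is fully constrained by its type hints, so it is always valid
assert TargetTokenizers.Unlabeled(post=False).is_valid()

# do_except=True asks invalid elements to raise instead of returning False;
# for this always-valid subclass the call simply returns True
assert TargetTokenizers.Unlabeled(post=True).is_valid(do_except=True)
```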

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validates the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field
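Together these inherited methods support a dict round trip; a minimal sketch using the class documented above:

```python
from maze_dataset.tokenization.modular.elements import TargetTokenizers

elem = TargetTokenizers.Unlabeled(post=True)
data = elem.serialize()  # dict containing a format key plus the serialized fields
restored = TargetTokenizers.Unlabeled.load(data)  # reconstructs and type-checks
assert restored == elem
assert restored.validate_fields_types()  # True when all field types match hints
```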

Inherited Members
maze_dataset.tokenization.modular.elements.TargetTokenizers._TargetTokenizer
attribute_key
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
668class StepSizes(__TokenizerElementNamespace):
669	"""Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""
670
671	key = "step_size"
672
673	@serializable_dataclass(frozen=True, kw_only=True)
674	class _StepSize(_TokenizerElement, abc.ABC):
675		"""Specifies which coords in `maze.solution` are used to represent the path."""
676
677		@classmethod
678		def attribute_key(cls) -> str:
679			return StepSizes.key
680
681		@abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
682		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
683			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
684			raise NotImplementedError(
685				"Subclasses must implement `StepSize.step_indices.",
686			)
687
688		def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
689			"""Returns steps as tuples of starting and ending positions for each step."""
690			indices: list[int] = self._step_single_indices(maze)
691			# TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
692			return [
693				(start, end)
694				for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
695			]
696
697		def is_valid(self, do_except: bool = False) -> bool:
698			# No invalid instances possible within data member type hint bounds
699			return True
700
701	@serializable_dataclass(frozen=True, kw_only=True)
702	class Singles(_StepSize):
703		"""Every coord in `maze.solution` is represented.
704
705		Legacy tokenizers all use this behavior.
706		"""
707
708		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
709			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
710			return list(range(maze.solution.shape[0]))
711
712	@serializable_dataclass(frozen=True, kw_only=True)
713	@mark_as_unsupported(_unsupported_is_invalid)
714	class Straightaways(_StepSize):
715		"""Only coords where the path turns are represented in the path.
716
717		I.e., the path is represented as a sequence of straightaways,
718		specified by the coords at the turns.
719		"""
720
721		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
722			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
723			last_turn_coord: Coord = maze.solution[0, ...]
724			indices: list[int] = [0]
725			for i, coord in enumerate(maze.solution):
726				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
727					indices.append(i - 1)
728					last_turn_coord = maze.solution[i - 1, ...]
729			indices.append(i)
730			return indices
731
732	@serializable_dataclass(frozen=True, kw_only=True)
733	class Forks(_StepSize):
734		"""Only coords at forks, where the path has >=2 options for the next step are included.
735
736		Excludes the option of backtracking.
737		The starting and ending coords are always included.
738		"""
739
740		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
741			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
742			return maze.get_solution_forking_points(always_include_endpoints=True)[0]
743
744	@serializable_dataclass(frozen=True, kw_only=True)
745	@mark_as_unsupported(_unsupported_is_invalid)
746	class ForksAndStraightaways(_StepSize):
747		"""Includes the union of the coords included by `Forks` and `Straightaways`.
748
749		See documentation for those classes for details.
750		"""
751
752		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
753			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
754			return list(
755				np.unique(
756					np.concatenate(
757						(
758							StepSizes.Straightaways()._step_single_indices(maze),
759							StepSizes.Forks()._step_single_indices(maze),
760						),
761					),
762				),
763			)

Namespace for _StepSize subclass hierarchy used by MazeTokenizerModular.

key = 'step_size'
@serializable_dataclass(frozen=True, kw_only=True)
class StepSizes.Singles(maze_dataset.tokenization.modular.elements.StepSizes._StepSize):
701	@serializable_dataclass(frozen=True, kw_only=True)
702	class Singles(_StepSize):
703		"""Every coord in `maze.solution` is represented.
704
705		Legacy tokenizers all use this behavior.
706		"""
707
708		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
709			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
710			return list(range(maze.solution.shape[0]))

Every coord in maze.solution is represented.

Legacy tokenizers all use this behavior.

StepSizes.Singles( *, _type_: Literal["<class 'StepSizes.Singles'>"] = "<class 'StepSizes.Singles'>")
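The inherited `step_start_end_indices` pairs these indices into successive (start, end) steps; a self-contained sketch of that pairing for `Singles` on a 5-coord solution:

```python
indices = list(range(5))  # Singles keeps every index of maze.solution
steps = list(zip(indices[:-1], indices[1:]))
print(steps)  # [(0, 1), (1, 2), (2, 3), (3, 4)]
```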
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validates the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.StepSizes._StepSize
attribute_key
step_start_end_indices
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class StepSizes.Straightaways(maze_dataset.tokenization.modular.elements.StepSizes._StepSize):
712	@serializable_dataclass(frozen=True, kw_only=True)
713	@mark_as_unsupported(_unsupported_is_invalid)
714	class Straightaways(_StepSize):
715		"""Only coords where the path turns are represented in the path.
716
717		I.e., the path is represented as a sequence of straightaways,
718		specified by the coords at the turns.
719		"""
720
721		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
722			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
723			last_turn_coord: Coord = maze.solution[0, ...]
724			indices: list[int] = [0]
725			for i, coord in enumerate(maze.solution):
726				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
727					indices.append(i - 1)
728					last_turn_coord = maze.solution[i - 1, ...]
729			indices.append(i)
730			return indices

Only coords where the path turns are represented in the path.

I.e., the path is represented as a sequence of straightaways, specified by the coords at the turns.

StepSizes.Straightaways( *, _type_: Literal["<class 'StepSizes.Straightaways'>"] = "<class 'StepSizes.Straightaways'>")
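A self-contained sketch of the turn-detection loop above, applied to a small hypothetical L-shaped path:

```python
import numpy as np

solution = np.array([[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]])
last_turn_coord = solution[0]
indices = [0]
for i, coord in enumerate(solution):
    # a turn has occurred once both row and column differ from the last turn
    if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
        indices.append(i - 1)
        last_turn_coord = solution[i - 1]
indices.append(i)
print(indices)  # [0, 2, 4]: the endpoints plus the turn at (0, 2)
```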
def is_valid(self, do_except: bool = False) -> bool:
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
258	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
259	if do_except:
260		err_msg: str = (
261			f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
262			f"{type(self) = }, {self = }"
263		)
264		raise ValueError(err_msg)
265
266	return False

Default implementation of is_valid for mark_as_unsupported-decorated classes
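In practice this means unsupported elements report themselves as invalid, and can be made to raise on demand:

```python
from maze_dataset.tokenization.modular.elements import StepSizes

elem = StepSizes.Straightaways()
assert not elem.is_valid()       # marked as unsupported, so reported invalid
# elem.is_valid(do_except=True)  # would raise ValueError, per the source above
```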

def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validates the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.StepSizes._StepSize
attribute_key
step_start_end_indices
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepSizes.Forks(maze_dataset.tokenization.modular.elements.StepSizes._StepSize):
732	@serializable_dataclass(frozen=True, kw_only=True)
733	class Forks(_StepSize):
734		"""Only coords at forks, where the path has >=2 options for the next step are included.
735
736		Excludes the option of backtracking.
737		The starting and ending coords are always included.
738		"""
739
740		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
741			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
742			return maze.get_solution_forking_points(always_include_endpoints=True)[0]

Only coords at forks, where the path has >=2 options for the next step, are included.

Excludes the option of backtracking. The starting and ending coords are always included.

StepSizes.Forks( *, _type_: Literal["<class 'StepSizes.Forks'>"] = "<class 'StepSizes.Forks'>")
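Because `Forks` returns a sparser index list, the inherited `step_start_end_indices` pairing yields coarser steps; a runnable sketch with hypothetical fork indices:

```python
# hypothetical output of get_solution_forking_points(always_include_endpoints=True)[0];
# the endpoints (0 and 10 here) are always included
fork_indices = [0, 3, 7, 10]
steps = list(zip(fork_indices[:-1], fork_indices[1:]))
print(steps)  # [(0, 3), (3, 7), (7, 10)]: each fork-to-fork hop is one step
```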
def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes in an appropriately structured dict and returns an instance of the class; implemented by the @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:

validates the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.StepSizes._StepSize
attribute_key
step_start_end_indices
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class StepSizes.ForksAndStraightaways(maze_dataset.tokenization.modular.elements.StepSizes._StepSize):
744	@serializable_dataclass(frozen=True, kw_only=True)
745	@mark_as_unsupported(_unsupported_is_invalid)
746	class ForksAndStraightaways(_StepSize):
747		"""Includes the union of the coords included by `Forks` and `Straightaways`.
748
749		See documentation for those classes for details.
750		"""
751
752		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
753			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
754			return list(
755				np.unique(
756					np.concatenate(
757						(
758							StepSizes.Straightaways()._step_single_indices(maze),
759							StepSizes.Forks()._step_single_indices(maze),
760						),
761					),
762				),
763			)

Includes the union of the coords included by Forks and Straightaways.

See documentation for those classes for details.

StepSizes.ForksAndStraightaways( *, _type_: Literal["<class 'StepSizes.ForksAndStraightaways'>"] = "<class 'StepSizes.ForksAndStraightaways'>")
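A self-contained sketch of the union-and-deduplicate step used above, with hypothetical index lists:

```python
import numpy as np

straightaway_indices = [0, 2, 4]
fork_indices = [0, 3, 4]
# np.unique sorts and deduplicates the concatenated index lists
merged = np.unique(np.concatenate((straightaway_indices, fork_indices))).tolist()
print(merged)  # [0, 2, 3, 4]
```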
def is_valid(self, do_except: bool = False) -> bool:

Default implementation of is_valid for mark_as_unsupported-decorated classes

def serialize(self) -> dict[str, typing.Any]:

returns the class as a dict; implemented by the @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:

takes an appropriately structured dict and returns an instance of the class; implemented via the `@serializable_dataclass` decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validates the types of all fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
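
Since every `_TokenizerElement` is a `SerializableDataclass`, `serialize` and `load` round-trip. A minimal sketch, using `StepTokenizers.Cardinal` purely as an example element:

from maze_dataset.tokenization.modular.elements import StepTokenizers

element = StepTokenizers.Cardinal()
data = element.serialize()  # dict tagged with a format key plus all serialized fields
restored = StepTokenizers.Cardinal.load(data)  # reconstructs an equal instance
assert restored == element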

Inherited Members
maze_dataset.tokenization.modular.elements.StepSizes._StepSize
attribute_key
step_start_end_indices
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
to_tokens
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
766class StepTokenizers(__TokenizerElementNamespace):
767	"""Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
768
769	key = "step_tokenizers"
770
771	@serializable_dataclass(frozen=True, kw_only=True)
772	class _StepTokenizer(_TokenizerElement, abc.ABC):
773		"""Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""
774
775		@classmethod
776		def attribute_key(cls) -> str:
777			return StepTokenizers.key
778
779		@abc.abstractmethod
780		def to_tokens(
781			self,
782			maze: SolvedMaze,
783			start_index: int,
784			end_index: int,
785			**kwargs,
786		) -> list[str]:
787			"""Tokenizes a single step in the solution.
788
789			# Parameters
790			- `maze`: Maze to be tokenized
791			- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
792			- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
793			"""
794			raise NotImplementedError(
795				"Subclasses must implement `_StepTokenizer.to_tokens`.",
796			)
797
798		def is_valid(self, do_except: bool = False) -> bool:
799			# No invalid instances possible within data member type hint bounds
800			return True
801
802	@serializable_dataclass(frozen=True, kw_only=True)
803	class Coord(_StepTokenizer):
804		"""The step is represented by a direct tokenization of its end-position coordinate."""
805
806		# inherit docstring
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
815
816	@serializable_dataclass(frozen=True, kw_only=True)
817	class Cardinal(_StepTokenizer):
818		"""A step is tokenized with a cardinal direction token.
819
820		It is the direction of the step from the starting position along the solution.
821		"""
822
823		# inherit docstring
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]
834
835	@serializable_dataclass(frozen=True, kw_only=True)
836	class Relative(_StepTokenizer):
837		"""Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
838
839		To resolve the initial ambiguity, the "agent" solving the maze is assumed to be facing NORTH at the start of a solution.
840		As with `Cardinal`, the direction is that of the step from the starting position.
841		"""
842
843		# inherit docstring
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]
870
871	@serializable_dataclass(frozen=True, kw_only=True)
872	class Distance(_StepTokenizer):
873		"""A count of the number of individual steps from the start of the current step to its end.
874
875		Contains no information about directionality, only the distance traveled in the step.
876		`Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
877		This constraint is enforced in `_PathTokenizer.is_valid`.
878		"""
879
880		# inherit docstring
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]
890
891	"""
892	`StepTokenizerPermutation`
893	A sequence of unique `_StepTokenizer`s.
894	This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
895	"""
896	StepTokenizerPermutation: type = (
897		tuple[_StepTokenizer]
898		| tuple[_StepTokenizer, _StepTokenizer]
899		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
900		| tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
901	)

Namespace for _StepTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'step_tokenizers'
StepTokenizerPermutation: type = tuple[_StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer], where _StepTokenizer is maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer
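
Concretely, a `StepTokenizerPermutation` is any tuple of one to four distinct `_StepTokenizer` instances. A short sketch of a two-element permutation:

from maze_dataset.tokenization.modular.elements import StepTokenizers

# each step is tokenized as a cardinal direction followed by a distance token
perm: StepTokenizers.StepTokenizerPermutation = (
    StepTokenizers.Cardinal(),
    StepTokenizers.Distance(),
)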
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Coord(maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer):
802	@serializable_dataclass(frozen=True, kw_only=True)
803	class Coord(_StepTokenizer):
804		"""The step is represented by a direct tokenization of its end-position coordinate."""
805
806		# inherit docstring
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

The step is represented by a direct tokenization of its end-position coordinate.

StepTokenizers.Coord( *, _type_: Literal["<class 'StepTokenizers.Coord'>"] = "<class 'StepTokenizers.Coord'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:
807		def to_tokens(  # noqa: D102
808			self,
809			maze: SolvedMaze,
810			start_index: int,
811			end_index: int,
812			coord_tokenizer: CoordTokenizers._CoordTokenizer,
813		) -> list[str]:
814			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
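
A usage sketch: generate a small maze and tokenize its first solution step. `Coord` emits only the end coordinate of the step via the supplied `coord_tokenizer`; the unique-token coordinate tokenizer `CoordTokenizers.UT` is assumed here:

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import CoordTokenizers, StepTokenizers

maze = MazeDataset.from_config(
    MazeDatasetConfig(
        name="demo",
        grid_n=4,
        n_mazes=1,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
)[0]
# only the end coordinate of the step is tokenized, e.g. ["(1,0)"]
tokens = StepTokenizers.Coord().to_tokens(maze, 0, 1, coord_tokenizer=CoordTokenizers.UT())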
`serialize`, `load`, and `validate_fields_types` are inherited from `SerializableDataclass`; see their documentation above.

Inherited Members
maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Cardinal(maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer):
816	@serializable_dataclass(frozen=True, kw_only=True)
817	class Cardinal(_StepTokenizer):
818		"""A step is tokenized with a cardinal direction token.
819
820		It is the direction of the step from the starting position along the solution.
821		"""
822
823		# inherit docstring
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]

A step is tokenized with a cardinal direction token.

It is the direction of the step from the starting position along the solution.

StepTokenizers.Cardinal( *, _type_: Literal["<class 'StepTokenizers.Cardinal'>"] = "<class 'StepTokenizers.Cardinal'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
824		def to_tokens(  # noqa: D102
825			self,
826			maze: SolvedMaze,
827			start_index: int,
828			end_index: int,
829			**kwargs,
830		) -> list[str]:
831			return [
832				get_cardinal_direction(maze.solution[start_index : start_index + 2]),
833			]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
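
Continuing the sketch above, `Cardinal` emits a single direction token per step, obtained from `VOCAB` via `get_cardinal_direction`:

# one token per step, e.g. a NORTH/SOUTH/EAST/WEST token from VOCAB
tokens = StepTokenizers.Cardinal().to_tokens(maze, 0, 1)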
`serialize`, `load`, and `validate_fields_types` are inherited from `SerializableDataclass`; see their documentation above.

Inherited Members
maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Relative(maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer):
835	@serializable_dataclass(frozen=True, kw_only=True)
836	class Relative(_StepTokenizer):
837		"""Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
838
839		To resolve the initial ambiguity, the "agent" solving the maze is assumed to be facing NORTH at the start of a solution.
840		As with `Cardinal`, the direction is that of the step from the starting position.
841		"""
842
843		# inherit docstring
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]

Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

To resolve the initial ambiguity, the "agent" solving the maze is assumed to be facing NORTH at the start of a solution. As with Cardinal, the direction is that of the step from the starting position.

StepTokenizers.Relative( *, _type_: Literal["<class 'StepTokenizers.Relative'>"] = "<class 'StepTokenizers.Relative'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
844		def to_tokens(  # noqa: D102
845			self,
846			maze: SolvedMaze,
847			start_index: int,
848			end_index: int,
849			**kwargs,
850		) -> list[str]:
851			if start_index == 0:
852				start = maze.solution[0]
853				previous = start + np.array([1, 0])
854				return [
855					get_relative_direction(
856						np.concatenate(
857							(
858								np.expand_dims(previous, 0),
859								maze.solution[start_index : start_index + 2],
860							),
861							axis=0,
862						),
863					),
864				]
865			return [
866				get_relative_direction(
867					maze.solution[start_index - 1 : start_index + 2],
868				),
869			]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
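
Continuing the same sketch, `Relative` emits one right/left/forward-style token per step; for the first step, a virtual previous position one row below the start is prepended so the agent begins facing NORTH:

# relative direction of the step given the agent's heading;
# the first step assumes an initially NORTH-facing agent
tokens = StepTokenizers.Relative().to_tokens(maze, 0, 1)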
`serialize`, `load`, and `validate_fields_types` are inherited from `SerializableDataclass`; see their documentation above.

Inherited Members
maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class StepTokenizers.Distance(maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer):
871	@serializable_dataclass(frozen=True, kw_only=True)
872	class Distance(_StepTokenizer):
873		"""A count of the number of individual steps from the start of the current step to its end.
874
875		Contains no information about directionality, only the distance traveled in the step.
876		`Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
877		This constraint is enforced in `_PathTokenizer.is_valid`.
878		"""
879
880		# inherit docstring
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]

A count of the number of individual steps from the start of the current step to its end.

Contains no information about directionality, only the distance traveled in the step. Distance must be combined with at least one other _StepTokenizer in a StepTokenizerPermutation. This constraint is enforced in _PathTokenizer.is_valid.

StepTokenizers.Distance( *, _type_: Literal["<class 'StepTokenizers.Distance'>"] = "<class 'StepTokenizers.Distance'>")
def to_tokens( self, maze: maze_dataset.SolvedMaze, start_index: int, end_index: int, **kwargs) -> list[str]:
881		def to_tokens(  # noqa: D102
882			self,
883			maze: SolvedMaze,
884			start_index: int,
885			end_index: int,
886			**kwargs,
887		) -> list[str]:
888			d: int = end_index - start_index
889			return [getattr(VOCAB, f"I_{d:03}")]

Tokenizes a single step in the solution.

Parameters

  • maze: Maze to be tokenized
  • start_index: The index of the Coord in maze.solution at which the current step starts
  • end_index: The index of the Coord in maze.solution at which the current step ends
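
Since `Distance` carries no direction or location information, `_PathTokenizer.is_valid` rejects it on its own; it must appear in a permutation alongside another step tokenizer:

# valid: each step gets a distance token followed by its end coordinate
perm = (StepTokenizers.Distance(), StepTokenizers.Coord())
# invalid: (StepTokenizers.Distance(),) alone is rejected by _PathTokenizer.is_valid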
`serialize`, `load`, and `validate_fields_types` are inherited from `SerializableDataclass`; see their documentation above.

Inherited Members
maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer
attribute_key
is_valid
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
 904class PathTokenizers(__TokenizerElementNamespace):
 905	"""Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""
 906
 907	key = "path_tokenizer"
 908
 909	@serializable_dataclass(frozen=True, kw_only=True)
 910	class _PathTokenizer(_TokenizerElement, abc.ABC):
 911		"""Superclass of tokenizers for maze solution paths."""
 912
 913		@abc.abstractmethod
 914		def to_tokens(
 915			self,
 916			maze: SolvedMaze,
 917			coord_tokenizer: CoordTokenizers._CoordTokenizer,
 918		) -> list[str]:
 919			"""Returns tokens representing the solution path."""
 920			pass
 921
 922		@classmethod
 923		def attribute_key(cls) -> str:
		return PathTokenizers.key

@serializable_dataclass(frozen=True, kw_only=True)
class StepSequence(_PathTokenizer, abc.ABC):
	"""Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

	Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

	# Parameters
	- `step_size`: Selects the size of a single step in the sequence
	- `step_tokenizers`: Selects the combination and permutation of tokens
	- `pre`: Whether all steps include an integral preceding delimiter token
	- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization
	- `post`: Whether all steps include an integral following delimiter token
	"""

	step_size: StepSizes._StepSize = serializable_field(
		default=StepSizes.Singles(),
		loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
	)
	step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
		default=(StepTokenizers.Coord(),),
		serialization_fn=lambda x: [y.serialize() for y in x],
		loading_fn=lambda x: tuple(x[StepTokenizers.key]),
	)
	pre: bool = serializable_field(default=False)
	intra: bool = serializable_field(default=False)
	post: bool = serializable_field(default=False)

	# inherit docstring
	def to_tokens(  # noqa: D102
		self,
		maze: SolvedMaze,
		coord_tokenizer: CoordTokenizers._CoordTokenizer,
	) -> list[str]:
		return [
			*self._leading_tokens(maze, coord_tokenizer),
			*flatten(
				[
					self._single_step_tokens(maze, start, end, coord_tokenizer)
					for start, end in self.step_size.step_start_end_indices(maze)
				],
			),
			*self._trailing_tokens(maze, coord_tokenizer),
		]

	def _single_step_tokens(
		self,
		maze: SolvedMaze,
		i: int,
		j: int,
		coord_tokenizer: CoordTokenizers._CoordTokenizer,
	) -> list[str]:
		"""Returns the token sequence representing a single step along the path."""
		step_rep_tokens: list[list[str]] = [
			step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
			for step_tokenizer in self.step_tokenizers
		]
		if self.intra:
			# interleave the `_StepTokenizer` outputs with intra-step delimiter tokens;
			# placeholder slots are filled by the two slice assignments below
			step_rep_tokens_and_intra: list[list[str] | str] = [None] * (
				len(step_rep_tokens) * 2
			)
			step_rep_tokens_and_intra[::2] = step_rep_tokens
			step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
				step_rep_tokens,
			)
			step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
		all_tokens: list[str] = [
			*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
			*flatten(step_rep_tokens),
			*empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
		]
		return all_tokens

	def _leading_tokens(
		self,
		maze: SolvedMaze,
		coord_tokenizer: CoordTokenizers._CoordTokenizer,
	) -> list[str]:
		"""Returns tokens preceding those from the sequence from `_single_step_tokens`.

		Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`: the first coordinate of the path must be emitted here.
		`<PATH_START>` should NOT be included.
		"""
		if StepTokenizers.Coord() in self.step_tokenizers:
			return [
				*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
				*coord_tokenizer.to_tokens(maze.solution[0, ...]),
				*empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
			]
		return []

	def _trailing_tokens(
		self,
		maze: SolvedMaze,
		coord_tokenizer: CoordTokenizers._CoordTokenizer,
	) -> list[str]:
		"""Returns tokens following those from the sequence from `_single_step_tokens`.

		`<PATH_END>` should NOT be included.
		"""
		return []

	# inherit docstring
	def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
		output: bool

		if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
			# Uninteresting: repeated elements are not useful
			output = False
		else:
			# keep this nested `if` so the invalidity-type comment stays attached to its clause
			if len(self.step_tokenizers) == 1 and isinstance(
				self.step_tokenizers[0],
				StepTokenizers.Distance,
			):
				# Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
				output = False
			else:
				output = True

		if not output and do_except:
			raise ValueError(
				"PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
			)

		return output

Namespace for _PathTokenizer subclass hierarchy used by MazeTokenizerModular.

key = 'path_tokenizer'
@serializable_dataclass(frozen=True, kw_only=True)
class PathTokenizers.StepSequence(maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer, abc.ABC):

Any PathTokenizer where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

Parameters

  • step_size: Selects the size of a single step in the sequence
  • step_tokenizers: Selects the combination and permutation of tokens
  • pre: Whether all steps include an integral preceding delimiter token
  • intra: Whether all steps include a delimiter token after each individual _StepTokenizer tokenization
  • post: Whether all steps include an integral following delimiter token
PathTokenizers.StepSequence( *, _type_: Literal["<class 'PathTokenizers.StepSequence'>"] = "<class 'PathTokenizers.StepSequence'>", step_size: StepSizes._StepSize = StepSizes.Singles(), step_tokenizers: StepTokenizers.StepTokenizerPermutation = (StepTokenizers.Coord(),), pre: bool = False, intra: bool = False, post: bool = False)
step_size: StepSizes._StepSize = StepSizes.Singles()
step_tokenizers: StepTokenizers.StepTokenizerPermutation = (StepTokenizers.Coord(),)
pre: bool = False
intra: bool = False
post: bool = False
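
As a quick illustration, here is a minimal sketch of constructing a StepSequence path tokenizer with an intra-step delimiter (assuming the imports below resolve as for the other elements in this module):

from maze_dataset.tokenization.modular.elements import (
	PathTokenizers,
	StepSizes,
	StepTokenizers,
)

# one coordinate token per single step, with a delimiter after each
# `_StepTokenizer` output; `pre` and `post` delimiters are left off
path_tok = PathTokenizers.StepSequence(
	step_size=StepSizes.Singles(),
	step_tokenizers=(StepTokenizers.Coord(),),
	intra=True,
)
assert path_tok.is_valid()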
def to_tokens( self, maze: maze_dataset.SolvedMaze, coord_tokenizer: maze_dataset.tokenization.modular.elements.CoordTokenizers._CoordTokenizer) -> list[str]:

Returns tokens representing the solution path.
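
A hedged usage sketch: generate one small solved maze via the dataset API and tokenize its solution path with the path_tok instance from the earlier sketch. The coordinate tokenizer name CoordTokenizers.UT is assumed to be the unique-token option from this module.

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import CoordTokenizers

# generate a single 3x3 maze; `dataset[0]` is a `SolvedMaze`
cfg = MazeDatasetConfig(
	name="demo",
	grid_n=3,
	n_mazes=1,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)
dataset = MazeDataset.from_config(cfg)

# leading tokens (the start coordinate), then one step per solution move;
# with `intra=True` each step is followed by the intra-step delimiter
tokens: list[str] = path_tok.to_tokens(dataset[0], coord_tokenizer=CoordTokenizers.UT())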

def is_valid(self, do_except: bool = False) -> bool:

Returns whether self contains data members capable of producing an overall valid MazeTokenizerModular.

Some _TokenizerElement instances may be created which are not useful despite obeying data member type hints. is_valid allows for more precise detection of invalid _TokenizerElements beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return True for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:

  • Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., _TokenizerElements which are strictly worse than some alternative.
  • Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
  • Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
  • Erroneous: These tokenizers might raise exceptions during use.

Development

is_valid is implemented to always return True in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call is_valid on nested _TokenizerElements contained in the class. In other words, failures of is_valid need not bubble up to the top of the nested _TokenizerElement tree. MazeTokenizerModular.is_valid calls is_valid on each of its _TokenizerElements individually, so failure at any level will be detected.

Types of Invalidity as an Enum

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid _TokenizerElement instances.
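
The invalidity check can be exercised directly; this follows straight from the is_valid source shown earlier:

from maze_dataset.tokenization.modular.elements import PathTokenizers, StepTokenizers

# Untrainable: `Distance` alone carries no direction/location information
bad = PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Distance(),))
assert not bad.is_valid()
bad.is_valid(do_except=True)  # raises ValueError with an explanation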

def serialize(self) -> dict[str, typing.Any]:
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    # (`cls` is the decorated class, bound in the enclosing `serializable_dataclass` decorator scope)
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    # (`on_typecheck_mismatch` and `on_typecheck_error` are bound in the enclosing decorator scope)
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
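
A round-trip sketch combining serialize and load (assuming the default field loading functions reconstruct the nested elements as intended):

from maze_dataset.tokenization.modular.elements import PathTokenizers, StepTokenizers

path_tok = PathTokenizers.StepSequence(
	step_tokenizers=(StepTokenizers.Coord(),),
	post=True,
)
data: dict = path_tok.serialize()  # JSON-serializable dict with a format key
restored = PathTokenizers.StepSequence.load(data)
# `load` also accepts an existing instance and returns it unchanged
assert PathTokenizers.StepSequence.load(restored) is restored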

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.tokenization.modular.elements.PathTokenizers._PathTokenizer
attribute_key
_TokenizerElement
name
tokenizer_elements
tokenizer_element_tree
tokenizer_element_dict
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
def get_tokens_up_to_path_start( tokens: list[str], include_start_coord: bool = True, tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>) -> list[str]:
def get_tokens_up_to_path_start(
	tokens: list[str],
	include_start_coord: bool = True,
	tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
	"""get tokens up to the path start token

	# Parameters:
	- `tokens : list[str]`
	- `include_start_coord : bool`
		(defaults to `True`)
	- `tokenization_mode : TokenizationMode`
		(defaults to `TokenizationMode.AOTP_UT_uniform`)

	# Returns:
	- `list[str]` subsequence of `tokens` up to the path start token

	# Raises:
	- `ValueError` : if `tokenization_mode` is invalid
	"""
	warnings.warn(
		"`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
		TokenizerPendingDeprecationWarning,
	)
	path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
	if include_start_coord:
		if is_UT(tokenization_mode):
			# a unique-token coordinate is a single token
			return tokens[: path_start_idx + 1]
		elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
			# in AOTP_CTT_indexed mode the start coordinate spans 5 tokens
			return tokens[: path_start_idx + 5]
		else:
			err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
			raise ValueError(err_msg)
	else:
		return tokens[:path_start_idx]

get tokens up to the path start token

Parameters:

  • tokens : list[str]
  • include_start_coord : bool (defaults to True)
  • tokenization_mode : TokenizationMode (defaults to TokenizationMode.AOTP_UT_uniform)

Returns:

  • list[str] subsequence of tokens up to the path start token

Raises:

  • ValueError : if tokenization_mode is invalid
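
A short sketch of the prompt-truncation behavior, using a hypothetical hand-written token list (real sequences come from a tokenizer; the "..." entry is a placeholder):

from maze_dataset.tokenization import get_tokens_up_to_path_start

tokens = [
	"<ADJLIST_START>", "...", "<ADJLIST_END>",
	"<ORIGIN_START>", "(0,0)", "<ORIGIN_END>",
	"<TARGET_START>", "(2,2)", "<TARGET_END>",
	"<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>",
]
# default AOTP_UT_uniform mode with include_start_coord=True:
# keeps everything through "<PATH_START>" plus the single start coordinate token
prompt = get_tokens_up_to_path_start(tokens)
assert prompt[-2:] == ["<PATH_START>", "(0,0)"]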