docs for maze-dataset v1.3.2

maze_dataset.tokenization.maze_tokenizer_legacy

legacy tokenizer which uses a TokenizationMode enum and a MazeTokenizer class

Caution

MazeTokenizerModular is the new standard for tokenization. This class is no longer recommended for use, but will remain for compatibility with existing code.
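
As a quick orientation, here is a minimal sketch of the legacy interface documented below (the grid size of 5 and the coordinates are arbitrary illustrative values):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

# the legacy tokenizer pairs a TokenizationMode with an optional max_grid_size
tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=5,
)
tokens = tok.coords_to_strings([(0, 0), (0, 1)])  # coordinates -> token strings
ids = tok.encode(tokens)                          # token strings -> token ids
assert tok.decode(ids) == tokens                  # ...and back
```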


  1"""legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class
  2
  3> [!CAUTION]
  4> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
  5> for use, but will remain for compatibility with existing code.
  6
  7"""
  8
  9import warnings
 10from enum import Enum
 11from functools import cached_property
 12from typing import (
 13	Callable,
 14	Iterable,
 15	Literal,
 16	Mapping,
 17	Sequence,
 18	overload,
 19)
 20
 21import numpy as np
 22from muutils.json_serialize import (
 23	SerializableDataclass,
 24	serializable_dataclass,
 25	serializable_field,
 26)
 27from muutils.kappa import Kappa
 28from muutils.misc.sequence import WhenMissing
 29
 30# from maze_dataset import SolvedMaze
 31from maze_dataset.constants import (
 32	SPECIAL_TOKENS,
 33	CoordTup,
 34)
 35from maze_dataset.token_utils import (
 36	TokenizerPendingDeprecationWarning,
 37	_coord_to_strings_indexed,
 38	_coord_to_strings_UT,
 39	coords_to_strings,
 40	strings_to_coords,
 41)
 42from maze_dataset.tokenization.common import TokenError
 43from maze_dataset.utils import corner_first_ndindex
 44
 45
 46class TokenizationMode(Enum):
 47	"""legacy tokenization modes
 48
 49	> [!CAUTION]
 50	> Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
 51	> Use `MazeTokenizerModular` instead.
 52
 53	# Abbreviations:
 54	- `AOTP`: Adjacency list, Origin, Target, Path
 55	- `UT`: Unique Token (for each coordinate)
 56	- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)
 57
 58	# Modes:
 59	- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
 60		example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
 61	- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
 62		uses `corner_first_ndindex` function to order the tokens
 63	- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
 64	"""
 65
 66	AOTP_UT_rasterized = "AOTP_UT_rasterized"
 67	AOTP_UT_uniform = "AOTP_UT_uniform"
 68	AOTP_CTT_indexed = "AOTP_CTT_indexed"
 69
 70	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
 71		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
 72		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
 73
 74
 75_NDINDEX_FUNC_MAP: dict[
 76	TokenizationMode,
 77	Callable[[int], Iterable[tuple[int, ...]]],
 78] = {
 79	TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)),
 80	TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2),
 81}
 82
 83
 84def is_UT(tokenization_mode: TokenizationMode) -> bool:
 85	"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
 86	return tokenization_mode in (
 87		TokenizationMode.AOTP_UT_rasterized,
 88		TokenizationMode.AOTP_UT_uniform,
 89	)
 90
 91
 92def get_tokens_up_to_path_start(
 93	tokens: list[str],
 94	include_start_coord: bool = True,
 95	tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
 96) -> list[str]:
 97	"""get tokens up to the path start token
 98
 99	# Parameters:
100	- `tokens : list[str]`
101	- `include_start_coord : bool`
102		(defaults to `True`)
103	- `tokenization_mode : TokenizationMode`
104		(defaults to `TokenizationMode.AOTP_UT_uniform`)
105
106	# Returns:
107	- `list[str]` subsequence of `tokens` up to the path start token
108
109	# Raises:
110	- `ValueError` : if `tokenization_mode` is invalid
111	"""
112	warnings.warn(
113		"`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
114		TokenizerPendingDeprecationWarning,
115	)
116	path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
117	if include_start_coord:
118		if is_UT(tokenization_mode):
119			return tokens[: path_start_idx + 1]
120		elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
121			return tokens[: path_start_idx + 5]
122		else:
123			err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
124			raise ValueError(err_msg)
125	else:
126		return tokens[:path_start_idx]
127
128
129_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [
130	"name",
131	"max_grid_size",
132	"token_arr",
133	"tokenizer_map",
134	"vocab_size",
135	"padding_token_index",
136]
137
138
139@serializable_dataclass(
140	properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
141	kw_only=True,
142)
143class MazeTokenizer(SerializableDataclass):
144	"""LEGACY Tokenizer for mazes
145
146	> [!CAUTION]
147	> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
148	> for use, but will remain for compatibility with existing code.
149
150	# Parameters:
151	- `tokenization_mode: TokenizationMode`
152		mode of tokenization. required.
153	- `max_grid_size: int | None`
154		maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text
155
156	# Properties
157	- `name: str`
158		auto-generated name of the tokenizer from mode and size
159
160	## Conditional Properties
161
162	- `node_strings_map: Mapping[CoordTup, str]`
163		map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode
164
165	these all return `None` if `max_grid_size` is `None`.
166	Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`
167
168	- `token_arr: list[str]`
169		list of tokens, in order of their indices in the vocabulary
170	- `tokenizer_map: Mapping[str, int]`
171		map from token to index
172	- `vocab_size: int`
173		size of the vocabulary
174	- `padding_token_index: int`
175		index of the padding token
176
177	# Methods
178	- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
179		convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
180	- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
181		convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
182
183	"""
184
185	# parameters
186	# ============================================================
187
188	tokenization_mode: TokenizationMode = serializable_field(
189		default=TokenizationMode.AOTP_UT_uniform,
190		serialization_fn=lambda x: x.value,
191		loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
192	)
193
194	max_grid_size: int | None = serializable_field(default=None)
195
196	# properties
197	# ============================================================
198
199	@property
200	def name(self) -> str:
201		"""auto-generated name of the tokenizer from mode and size"""
202		max_grid_size_str: str = (
203			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
204		)
205		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
206
207	@cached_property
208	def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
209		"""map a coordinate to a token"""
210		if self.tokenization_mode in (
211			TokenizationMode.AOTP_UT_rasterized,
212			TokenizationMode.AOTP_UT_uniform,
213		):
214			return Kappa(_coord_to_strings_UT)
215		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
216			return Kappa(_coord_to_strings_indexed)
217		else:
218			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
219			raise ValueError(err_msg)
220
221	@cached_property
222	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
223		"""map a coordinate to a token"""
224		if self.tokenization_mode in (
225			TokenizationMode.AOTP_UT_rasterized,
226			TokenizationMode.AOTP_UT_uniform,
227		):
228			return None
229		else:
230			return self._node_strings_map
231
232	# conditional properties (on max_grid_size existing)
233	# ------------------------------------------------------------
234
235	@cached_property
236	def _token_arr(self) -> list[str]:
237		"""map from index to token"""
238		if self.max_grid_size is None:
239			err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
240			raise ValueError(err_msg)
241
242		output: list[str] = list(SPECIAL_TOKENS.values())
243
244		if self.tokenization_mode in (
245			TokenizationMode.AOTP_UT_rasterized,
246			TokenizationMode.AOTP_UT_uniform,
247		):
248			output.extend(
249				[
250					self._node_strings_map[coord][0]
251					for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
252						self.max_grid_size,
253					)
254				],
255			)
256		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
257			# TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
258			output.extend(
259				[
260					"(",
261					",",
262					")",  # new special chars
263					*map(str, range(self.max_grid_size)),  # numbers
264				],
265			)
266		else:
267			err_msg: str = (
268				f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}",
269			)
270			raise ValueError(err_msg)
271
272		return output
273
274	@cached_property
275	def token_arr(self) -> list[str] | None:
276		"get the token array if the max_grid_size is specified"
277		if self.max_grid_size is None:
278			return None
279		return self._token_arr
280
281	@cached_property
282	def _tokenizer_map(self) -> dict[str, int]:
283		"""map from token to index"""
284		return {token: i for i, token in enumerate(self._token_arr)}
285
286	@cached_property
287	def tokenizer_map(self) -> dict[str, int] | None:
288		"get the tokenizer map if the max_grid_size is specified"
289		if self.max_grid_size is None:
290			return None
291		return self._tokenizer_map
292
293	@property
294	def _vocab_size(self) -> int:
295		return len(self._token_arr)
296
297	@property
298	def vocab_size(self) -> int | None:
299		"get the size of the vocabulary if the max_grid_size is specified"
300		if self.max_grid_size is None:
301			return None
302		return self._vocab_size
303
304	@property
305	def _n_tokens(self) -> int:
306		# TODO: deprecate
307		return self._vocab_size
308
309	@property
310	def n_tokens(self) -> int | None:
311		"get the number of tokens if the max_grid_size is specified"
312		if self.max_grid_size is None:
313			return None
314		return self._n_tokens
315
316	@cached_property
317	def _padding_token_index(self) -> int:
318		return self.tokenizer_map[SPECIAL_TOKENS.PADDING]
319
320	@cached_property
321	def padding_token_index(self) -> int | None:
322		"get the index of the padding token if it exists"
323		if self.max_grid_size is None:
324			return None
325		return self._padding_token_index
326
327	# conversion functions
328	# ============================================================
329
330	@overload
331	def coords_to_strings(
332		self,
333		coords: list[str | CoordTup],
334		when_noncoord: Literal["include", "skip"] = "skip",
335	) -> list[str]: ...
336	@overload
337	def coords_to_strings(
338		self,
339		coords: list[CoordTup],
340		when_noncoord: Literal["error"] = "error",
341	) -> list[str]: ...
342	def coords_to_strings(
343		self,
344		coords: list[CoordTup],
345		when_noncoord: WhenMissing = "skip",
346	) -> list[str]:
347		"""map a list of coordinate tuples (and maybe other tokens) to strings
348
349		wraps `maze_dataset.token_utils.coords_to_strings` with either
350		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
351		"""
352		if self.tokenization_mode in (
353			TokenizationMode.AOTP_UT_rasterized,
354			TokenizationMode.AOTP_UT_uniform,
355		):
356			return coords_to_strings(
357				coords=coords,
358				coord_to_strings_func=_coord_to_strings_UT,
359				when_noncoord=when_noncoord,
360			)
361		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
362			return coords_to_strings(
363				coords=coords,
364				coord_to_strings_func=_coord_to_strings_indexed,
365				when_noncoord=when_noncoord,
366			)
367		else:
368			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
369			raise ValueError(err_msg)
370
371	@overload
372	def strings_to_coords(
373		cls,  # noqa: N805
374		text: str | list[str],
375		when_noncoord: Literal["skip"] = "skip",
376	) -> list[CoordTup]: ...
377	@overload
378	def strings_to_coords(
379		cls,  # noqa: N805
380		text: str | list[str],
381		when_noncoord: Literal["error"] = "error",
382	) -> list[CoordTup]: ...
383	@overload
384	def strings_to_coords(
385		cls,  # noqa: N805
386		text: str | list[str],
387		when_noncoord: Literal["include"] = "include",
388	) -> list[str | CoordTup]: ...
389	@classmethod
390	def strings_to_coords(
391		cls,
392		text: str | list[str],
393		when_noncoord: WhenMissing = "skip",
394	) -> list[str | CoordTup]:
395		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
396		return strings_to_coords(text=text, when_noncoord=when_noncoord)
397
398	def encode(self, text: str | list[str]) -> list[int]:
399		"""encode a string or list of strings into a list of tokens"""
400		try:
401			if isinstance(text, str):
402				text = text.split()
403			return [self.tokenizer_map[token] for token in text]
404		except KeyError as e:
405			err_msg: str = (
406				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
407			)
408			raise TokenError(err_msg) from e
409
410	def decode(
411		self,
412		tokens: Sequence[int],
413		joined_tokens: bool = False,
414	) -> list[str] | str:
415		"""decode a list of tokens into a string or list of strings"""
416		try:
417			output: list[str] = [self.token_arr[token] for token in tokens]
418		except IndexError as e:
419			err_msg: str = (
420				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
421			)
422			raise TokenError(err_msg) from e
423		if joined_tokens:
424			return " ".join(output)
425		else:
426			return output
427
428	# UT-only coordinate stuff
429	# ============================================================
430
431	@cached_property
432	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
 433		"map of coordinate tuples to their token ids, only valid for UT"
434		# print(f"{self.tokenization_mode = }")
435		if not self.is_UT():
436			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
437			raise ValueError(err_msg)
438
439		if self.max_grid_size is None:
440			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
441			raise ValueError(err_msg)
442
443		raw_converted: list[CoordTup | str] = self.strings_to_coords(
444			self.token_arr,
445			when_noncoord="include",
446		)
447
448		# filter out non-coordinates
449		return {
450			coord: i
451			for i, coord in enumerate(raw_converted)
452			if not isinstance(coord, str)
453		}
454
455	@cached_property
456	def coordinate_tokens_ids(self) -> dict[str, int]:
457		"map of coordinate tokens to their token ids, only valid for UT"
458		# checks performed in call
459		output: dict[str, int] = dict()
460
461		for coord, index in self.coordinate_tokens_coords.items():
462			_for_key: list[str] = self.coords_to_strings([coord])
463			assert len(_for_key) == 1
464			output[_for_key[0]] = index
465
466		return output
467
468	# other
469	# ============================================================
470
471	def summary(self) -> dict:
472		"""returns a summary of the tokenization mode"""
473		return {
474			"tokenization_mode": self.tokenization_mode.value,
475			"max_grid_size": self.max_grid_size,
476			"vocab_size": self.vocab_size,
477		}
478
479	def is_AOTP(self) -> bool:
480		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
481		return self.tokenization_mode in (
482			TokenizationMode.AOTP_UT_rasterized,
483			TokenizationMode.AOTP_UT_uniform,
484			TokenizationMode.AOTP_CTT_indexed,
485		)
486
487	def is_UT(self) -> bool:
488		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
489		return is_UT(self.tokenization_mode)
490
491	def clear_cache(self) -> None:
492		"""clears all cached properties"""
493		# delete the properties only if they exist
494		for name, prop in self.__class__.__dict__.items():
495			if isinstance(prop, cached_property):
496				# if the property exists, delete it
497				try:  # noqa: SIM105
498					delattr(self, name)
499				except AttributeError:
500					pass

class TokenizationMode(enum.Enum):
47class TokenizationMode(Enum):
48	"""legacy tokenization modes
49
50	> [!CAUTION]
51	> Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
52	> Use `MazeTokenizerModular` instead.
53
54	# Abbreviations:
55	- `AOTP`: Adjacency list, Origin, Target, Path
56	- `UT`: Unique Token (for each coordinate)
57	- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)
58
59	# Modes:
60	- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
61		example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
62	- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
63		uses `corner_first_ndindex` function to order the tokens
64	- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
65	"""
66
67	AOTP_UT_rasterized = "AOTP_UT_rasterized"
68	AOTP_UT_uniform = "AOTP_UT_uniform"
69	AOTP_CTT_indexed = "AOTP_CTT_indexed"
70
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

legacy tokenization modes

Caution

Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. Use MazeTokenizerModular instead.

Abbreviations:

  • AOTP: Adjacency list, Origin, Target, Path
  • UT: Unique Token (for each coordinate)
  • CTT: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

Modes:

  • AOTP_UT_rasterized: the "classic" mode: assigning tokens to each coordinate is done via rasterization; example: for a 3x3 maze, token order is (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)
  • AOTP_UT_uniform: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible; uses the corner_first_ndindex function to order the tokens
  • AOTP_CTT_indexed: each coordinate is a tuple of integers
AOTP_UT_rasterized = <TokenizationMode.AOTP_UT_rasterized: 'AOTP_UT_rasterized'>
AOTP_UT_uniform = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
AOTP_CTT_indexed = <TokenizationMode.AOTP_CTT_indexed: 'AOTP_CTT_indexed'>
def to_legacy_tokenizer( self, max_grid_size: int | None = None) -> MazeTokenizer:
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

convert the mode to a legacy MazeTokenizer object given a max_grid_size
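
For example (the grid size is an arbitrary illustrative value):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import TokenizationMode

tok = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5)
print(tok.tokenization_mode)  # TokenizationMode.AOTP_UT_uniform
print(tok.max_grid_size)      # 5
```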

Inherited Members
enum.Enum
name
value
def is_UT( tokenization_mode: TokenizationMode) -> bool:
85def is_UT(tokenization_mode: TokenizationMode) -> bool:
86	"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
87	return tokenization_mode in (
88		TokenizationMode.AOTP_UT_rasterized,
89		TokenizationMode.AOTP_UT_uniform,
90	)

returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)
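
A quick sketch of the helper:

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import TokenizationMode, is_UT

assert is_UT(TokenizationMode.AOTP_UT_uniform)       # unique-token modes
assert is_UT(TokenizationMode.AOTP_UT_rasterized)
assert not is_UT(TokenizationMode.AOTP_CTT_indexed)  # coordinate-tuple-token mode
```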

def get_tokens_up_to_path_start( tokens: list[str], include_start_coord: bool = True, tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>) -> list[str]:
 93def get_tokens_up_to_path_start(
 94	tokens: list[str],
 95	include_start_coord: bool = True,
 96	tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
 97) -> list[str]:
 98	"""get tokens up to the path start token
 99
100	# Parameters:
101	- `tokens : list[str]`
102	- `include_start_coord : bool`
103		(defaults to `True`)
104	- `tokenization_mode : TokenizationMode`
105		(defaults to `TokenizationMode.AOTP_UT_uniform`)
106
107	# Returns:
108	- `list[str]` subsequence of `tokens` up to the path start token
109
110	# Raises:
111	- `ValueError` : if `tokenization_mode` is invalid
112	"""
113	warnings.warn(
114		"`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
115		TokenizerPendingDeprecationWarning,
116	)
117	path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
118	if include_start_coord:
119		if is_UT(tokenization_mode):
120			return tokens[: path_start_idx + 1]
121		elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
122			return tokens[: path_start_idx + 5]
123		else:
124			err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
125			raise ValueError(err_msg)
126	else:
127		return tokens[:path_start_idx]

get tokens up to the path start token

Parameters:

  • tokens : list[str]
  • include_start_coord : bool (defaults to True)
  • tokenization_mode : TokenizationMode (defaults to TokenizationMode.AOTP_UT_uniform)

Returns:

  • list[str] subsequence of tokens up to the path start token

Raises:

  • ValueError : if tokenization_mode is invalid
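
A sketch of the truncation behaviour; the tokens "(0,0)", "(1,1)", and "(0,1)" are hand-written stand-ins for UT coordinate tokens, and the actual token strings depend on the tokenization mode:

```python
from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.tokenization.maze_tokenizer_legacy import get_tokens_up_to_path_start

# illustrative token sequence ending with a short path
tokens = ["(0,0)", "(1,1)", SPECIAL_TOKENS.PATH_START, "(0,0)", "(0,1)"]

# default: keep the first path coordinate after the path-start token
print(get_tokens_up_to_path_start(tokens))
# exclude the start coordinate: stop right after the path-start token
print(get_tokens_up_to_path_start(tokens, include_start_coord=False))
```

Note that calling this function also emits a `TokenizerPendingDeprecationWarning`, as shown in the source above.
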
@serializable_dataclass(properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True)
class MazeTokenizer(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
140@serializable_dataclass(
141	properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
142	kw_only=True,
143)
144class MazeTokenizer(SerializableDataclass):
145	"""LEGACY Tokenizer for mazes
146
147	> [!CAUTION]
148	> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
149	> for use, but will remain for compatibility with existing code.
150
151	# Parameters:
152	- `tokenization_mode: TokenizationMode`
153		mode of tokenization. required.
154	- `max_grid_size: int | None`
155		maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text
156
157	# Properties
158	- `name: str`
159		auto-generated name of the tokenizer from mode and size
160
161	## Conditional Properties
162
163	- `node_strings_map: Mapping[CoordTup, str]`
164		map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode
165
166	these all return `None` if `max_grid_size` is `None`.
167	Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`
168
169	- `token_arr: list[str]`
170		list of tokens, in order of their indices in the vocabulary
171	- `tokenizer_map: Mapping[str, int]`
172		map from token to index
173	- `vocab_size: int`
174		size of the vocabulary
175	- `padding_token_index: int`
176		index of the padding token
177
178	# Methods
179	- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
180		convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
181	- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
182		convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
183
184	"""
185
186	# parameters
187	# ============================================================
188
189	tokenization_mode: TokenizationMode = serializable_field(
190		default=TokenizationMode.AOTP_UT_uniform,
191		serialization_fn=lambda x: x.value,
192		loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
193	)
194
195	max_grid_size: int | None = serializable_field(default=None)
196
197	# properties
198	# ============================================================
199
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
207
208	@cached_property
209	def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
210		"""map a coordinate to a token"""
211		if self.tokenization_mode in (
212			TokenizationMode.AOTP_UT_rasterized,
213			TokenizationMode.AOTP_UT_uniform,
214		):
215			return Kappa(_coord_to_strings_UT)
216		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
217			return Kappa(_coord_to_strings_indexed)
218		else:
219			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
220			raise ValueError(err_msg)
221
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map
232
233	# conditional properties (on max_grid_size existing)
234	# ------------------------------------------------------------
235
236	@cached_property
237	def _token_arr(self) -> list[str]:
238		"""map from index to token"""
239		if self.max_grid_size is None:
240			err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
241			raise ValueError(err_msg)
242
243		output: list[str] = list(SPECIAL_TOKENS.values())
244
245		if self.tokenization_mode in (
246			TokenizationMode.AOTP_UT_rasterized,
247			TokenizationMode.AOTP_UT_uniform,
248		):
249			output.extend(
250				[
251					self._node_strings_map[coord][0]
252					for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
253						self.max_grid_size,
254					)
255				],
256			)
257		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
258			# TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
259			output.extend(
260				[
261					"(",
262					",",
263					")",  # new special chars
264					*map(str, range(self.max_grid_size)),  # numbers
265				],
266			)
267		else:
268			err_msg: str = (
269				f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}",
270			)
271			raise ValueError(err_msg)
272
273		return output
274
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr
281
282	@cached_property
283	def _tokenizer_map(self) -> dict[str, int]:
284		"""map from token to index"""
285		return {token: i for i, token in enumerate(self._token_arr)}
286
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map
293
294	@property
295	def _vocab_size(self) -> int:
296		return len(self._token_arr)
297
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size
304
305	@property
306	def _n_tokens(self) -> int:
307		# TODO: deprecate
308		return self._vocab_size
309
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens
316
317	@cached_property
318	def _padding_token_index(self) -> int:
319		return self.tokenizer_map[SPECIAL_TOKENS.PADDING]
320
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index
327
328	# conversion functions
329	# ============================================================
330
331	@overload
332	def coords_to_strings(
333		self,
334		coords: list[str | CoordTup],
335		when_noncoord: Literal["include", "skip"] = "skip",
336	) -> list[str]: ...
337	@overload
338	def coords_to_strings(
339		self,
340		coords: list[CoordTup],
341		when_noncoord: Literal["error"] = "error",
342	) -> list[str]: ...
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)
371
372	@overload
373	def strings_to_coords(
374		cls,  # noqa: N805
375		text: str | list[str],
376		when_noncoord: Literal["skip"] = "skip",
377	) -> list[CoordTup]: ...
378	@overload
379	def strings_to_coords(
380		cls,  # noqa: N805
381		text: str | list[str],
382		when_noncoord: Literal["error"] = "error",
383	) -> list[CoordTup]: ...
384	@overload
385	def strings_to_coords(
386		cls,  # noqa: N805
387		text: str | list[str],
388		when_noncoord: Literal["include"] = "include",
389	) -> list[str | CoordTup]: ...
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
398
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e
410
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output
428
429	# UT-only coordinate stuff
430	# ============================================================
431
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}
455
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output
468
469	# other
470	# ============================================================
471
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}
479
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)
487
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)
491
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

LEGACY Tokenizer for mazes

Caution

MazeTokenizerModular is the new standard for tokenization. This class is no longer recommended for use, but will remain for compatibility with existing code.

Parameters:

  • tokenization_mode: TokenizationMode mode of tokenization. required.
  • max_grid_size: int | None maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text

Properties

  • name: str auto-generated name of the tokenizer from mode and size

Conditional Properties

  • node_strings_map: Mapping[CoordTup, str] map from node to string. This returns a muutils.kappa.Kappa object which you can use like a dictionary. returns None if not a UT mode

these all return None if max_grid_size is None. Prepend _ to the name to get a guaranteed type, and cause an exception if max_grid_size is None

  • token_arr: list[str] list of tokens, in order of their indices in the vocabulary
  • tokenizer_map: Mapping[str, int] map from token to index
  • vocab_size: int size of the vocabulary
  • padding_token_index: int index of the padding token

Methods

  • coords_to_strings(coords: list[CoordTup]) -> list[str] convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
  • strings_to_coords(strings: list[str]) -> list[CoordTup] convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
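
A rough usage sketch (the grid size of 3 is an arbitrary illustrative value):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=3,
)
print(tok.name)           # mode and grid size are baked into the name
print(tok.vocab_size)     # populated because max_grid_size is set
print(tok.token_arr[:5])  # special tokens come first in the vocabulary
```
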
MazeTokenizer( *, tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>, max_grid_size: int | None = None)
tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
max_grid_size: int | None = None
name: str
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

auto-generated name of the tokenizer from mode and size

node_strings_map: Optional[Mapping[tuple[int, int], list[str]]]
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map

map a coordinate to a token

token_arr: list[str] | None
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr

get the token array if the max_grid_size is specified

tokenizer_map: dict[str, int] | None
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map

get the tokenizer map if the max_grid_size is specified

vocab_size: int | None
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size

get the size of the vocabulary if the max_grid_size is specified

n_tokens: int | None
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens

get the number of tokens if the max_grid_size is specified

padding_token_index: int | None
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index

get the index of the padding token if it exists

def coords_to_strings( self, coords: list[tuple[int, int]], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str]:
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)

map a list of coordinate tuples (and maybe other tokens) to strings

wraps maze_dataset.token_utils.coords_to_strings with either _coord_to_strings_UT or _coord_to_strings_indexed depending on the tokenization mode
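
A sketch of skipping versus including non-coordinate entries (the mixed input list is illustrative):

```python
from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.tokenization.maze_tokenizer_legacy import MazeTokenizer

# max_grid_size is not needed for coordinate <-> text conversion
tok = MazeTokenizer()  # defaults to TokenizationMode.AOTP_UT_uniform

mixed = [(0, 0), SPECIAL_TOKENS.PATH_START, (1, 1)]
print(tok.coords_to_strings(mixed))                           # non-coordinates skipped
print(tok.coords_to_strings(mixed, when_noncoord="include"))  # non-coordinates kept as-is
```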

@classmethod
def strings_to_coords( cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
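
strings_to_coords wraps maze_dataset.token_utils.strings_to_coords; a sketch of the inverse direction (as a classmethod it does not depend on the tokenization mode or max_grid_size):

```python
from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.tokenization.maze_tokenizer_legacy import MazeTokenizer

tok = MazeTokenizer()  # default UT mode
strings = tok.coords_to_strings([(1, 2), (2, 0)])
# non-coordinate tokens are skipped by default
print(MazeTokenizer.strings_to_coords([SPECIAL_TOKENS.PATH_START, *strings]))
```
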
def encode(self, text: str | list[str]) -> list[int]:
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e

encode a string or list of strings into a list of tokens
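
A sketch of encoding; max_grid_size must be set so that the vocabulary exists (the grid size and coordinates are illustrative):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=3,
)
# a plain string is split on whitespace before lookup; unknown tokens raise TokenError
strings = tok.coords_to_strings([(0, 0), (2, 2)])
print(tok.encode(" ".join(strings)))
```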

def decode( self, tokens: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output

decode a list of tokens into a string or list of strings
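
And the reverse direction, optionally joining the result into a single string (same illustrative setup as above):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=3,
)
ids = tok.encode(tok.coords_to_strings([(0, 0), (2, 2)]))
print(tok.decode(ids))                      # list of token strings
print(tok.decode(ids, joined_tokens=True))  # single space-joined string
```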

coordinate_tokens_coords: dict[tuple[int, int], int]
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}

map of coordinate tuples to their token ids, only valid for UT

coordinate_tokens_ids: dict[str, int]
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output

map of coordinate tokens to their token ids, only valid for UT
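
A sketch of the two UT-only mappings for a small grid (grid size 2 is an arbitrary illustrative value):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tok = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=2,
)
print(tok.coordinate_tokens_coords)  # {(row, col): token_id, ...}
print(tok.coordinate_tokens_ids)     # {"<coordinate token string>": token_id, ...}
```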

def summary(self) -> dict:
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}

returns a summary of the tokenization mode
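
For example (the grid size is illustrative; the exact vocab_size depends on it):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import MazeTokenizer

tok = MazeTokenizer(max_grid_size=4)  # default mode is AOTP_UT_uniform
print(tok.summary())
# {"tokenization_mode": "AOTP_UT_uniform", "max_grid_size": 4, "vocab_size": ...}
```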

def is_AOTP(self) -> bool:
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)

returns true if a tokenization mode is Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)

returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)

def clear_cache(self) -> None:
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

clears all cached properties

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict