docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.tokenization.maze_tokenizer

preserving legacy imports


 1"""preserving legacy imports"""
 2
 3from maze_dataset.tokenization.maze_tokenizer_legacy import (
 4	MazeTokenizer,
 5	TokenizationMode,
 6)
 7from maze_dataset.tokenization.modular.maze_tokenizer_modular import (
 8	MazeTokenizerModular,
 9)
10
11__all__ = [
12	"MazeTokenizer",
13	"TokenizationMode",
14	"MazeTokenizerModular",
15]
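
A minimal usage sketch of the re-exported legacy and modular tokenizers (only names from the imports above are used; the 5x5 grid size is an arbitrary example):

    from maze_dataset.tokenization.maze_tokenizer import (
        MazeTokenizer,
        MazeTokenizerModular,
        TokenizationMode,
    )

    # legacy tokenizer: configured by a mode and an optional max_grid_size
    legacy_tok = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_UT_uniform,
        max_grid_size=5,
    )

    # modular tokenizer: the recommended replacement, default-constructed here
    modular_tok = MazeTokenizerModular()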

@serializable_dataclass(properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True)
class MazeTokenizer(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
140@serializable_dataclass(
141	properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
142	kw_only=True,
143)
144class MazeTokenizer(SerializableDataclass):
145	"""LEGACY Tokenizer for mazes
146
147	> [!CAUTION]
148	> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
149	> for use, but will remain for compatibility with existing code.
150
151	# Parameters:
152	- `tokenization_mode: TokenizationMode`
153		mode of tokenization. required.
154	- `max_grid_size: int | None`
155		maximum grid size. required for converting text tokens to numerical token indices, but not for converting between coordinates/mazes and text
156
157	# Properties
158	- `name: str`
159		auto-generated name of the tokenizer from mode and size
160
161	## Conditional Properties
162
163	- `node_strings_map: Mapping[CoordTup, list[str]]`
164		map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` when using a `UT` mode
165
166	these all return `None` if `max_grid_size` is `None`.
167	Prepend `_` to the name to get a variant with a guaranteed (non-`None`) return type, which instead raises an exception if `max_grid_size` is `None`
168
169	- `token_arr: list[str]`
170		list of tokens, in order of their indices in the vocabulary
171	- `tokenizer_map: Mapping[str, int]`
172		map from token to index
173	- `vocab_size: int`
174		size of the vocabulary
175	- `padding_token_index: int`
176		index of the padding token
177
178	# Methods
179	- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
180		convert a list of coordinates to a list of tokens. Optionally except, skip, or include non-coordinates
181	- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
182		convert a list of tokens to a list of coordinates. Optionally except, skip, or include non-coordinates
183
184	"""
185
186	# parameters
187	# ============================================================
188
189	tokenization_mode: TokenizationMode = serializable_field(
190		default=TokenizationMode.AOTP_UT_uniform,
191		serialization_fn=lambda x: x.value,
192		loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
193	)
194
195	max_grid_size: int | None = serializable_field(default=None)
196
197	# properties
198	# ============================================================
199
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
207
208	@cached_property
209	def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
210		"""map a coordinate to a token"""
211		if self.tokenization_mode in (
212			TokenizationMode.AOTP_UT_rasterized,
213			TokenizationMode.AOTP_UT_uniform,
214		):
215			return Kappa(_coord_to_strings_UT)
216		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
217			return Kappa(_coord_to_strings_indexed)
218		else:
219			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
220			raise ValueError(err_msg)
221
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map
232
233	# conditional properties (on max_grid_size existing)
234	# ------------------------------------------------------------
235
236	@cached_property
237	def _token_arr(self) -> list[str]:
238		"""map from index to token"""
239		if self.max_grid_size is None:
240			err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
241			raise ValueError(err_msg)
242
243		output: list[str] = list(SPECIAL_TOKENS.values())
244
245		if self.tokenization_mode in (
246			TokenizationMode.AOTP_UT_rasterized,
247			TokenizationMode.AOTP_UT_uniform,
248		):
249			output.extend(
250				[
251					self._node_strings_map[coord][0]
252					for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
253						self.max_grid_size,
254					)
255				],
256			)
257		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
258			# TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
259			output.extend(
260				[
261					"(",
262					",",
263					")",  # new special chars
264					*map(str, range(self.max_grid_size)),  # numbers
265				],
266			)
267		else:
268			err_msg: str = (
269				f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
270			)
271			raise ValueError(err_msg)
272
273		return output
274
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr
281
282	@cached_property
283	def _tokenizer_map(self) -> dict[str, int]:
284		"""map from token to index"""
285		return {token: i for i, token in enumerate(self._token_arr)}
286
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map
293
294	@property
295	def _vocab_size(self) -> int:
296		return len(self._token_arr)
297
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size
304
305	@property
306	def _n_tokens(self) -> int:
307		# TODO: deprecate
308		return self._vocab_size
309
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens
316
317	@cached_property
318	def _padding_token_index(self) -> int:
319		return self.tokenizer_map[SPECIAL_TOKENS.PADDING]
320
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index
327
328	# conversion functions
329	# ============================================================
330
331	@overload
332	def coords_to_strings(
333		self,
334		coords: list[str | CoordTup],
335		when_noncoord: Literal["include", "skip"] = "skip",
336	) -> list[str]: ...
337	@overload
338	def coords_to_strings(
339		self,
340		coords: list[CoordTup],
341		when_noncoord: Literal["error"] = "error",
342	) -> list[str]: ...
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)
371
372	@overload
373	def strings_to_coords(
374		cls,  # noqa: N805
375		text: str | list[str],
376		when_noncoord: Literal["skip"] = "skip",
377	) -> list[CoordTup]: ...
378	@overload
379	def strings_to_coords(
380		cls,  # noqa: N805
381		text: str | list[str],
382		when_noncoord: Literal["error"] = "error",
383	) -> list[CoordTup]: ...
384	@overload
385	def strings_to_coords(
386		cls,  # noqa: N805
387		text: str | list[str],
388		when_noncoord: Literal["include"] = "include",
389	) -> list[str | CoordTup]: ...
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
398
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e
410
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output
428
429	# UT-only coordinate stuff
430	# ============================================================
431
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}
455
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output
468
469	# other
470	# ============================================================
471
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}
479
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)
487
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)
491
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

LEGACY Tokenizer for mazes

Caution

MazeTokenizerModular is the new standard for tokenization. This class is no longer recommended for use, but will remain for compatibility with existing code.

Parameters:

  • tokenization_mode: TokenizationMode mode of tokenization. required.
  • max_grid_size: int | None maximum grid size. required for converting text tokens to numerical token indices, but not for converting between coordinates/mazes and text

Properties

  • name: str auto-generated name of the tokenizer from mode and size

Conditional Properties

  • node_strings_map: Mapping[CoordTup, list[str]] map from node to string. This returns a muutils.kappa.Kappa object which you can use like a dictionary. returns None when using a UT mode

these all return None if max_grid_size is None. Prepend _ to the name to get a variant with a guaranteed (non-None) return type, which instead raises an exception if max_grid_size is None (see the sketch after the method list below)

  • token_arr: list[str] list of tokens, in order of their indices in the vocabulary
  • tokenizer_map: Mapping[str, int] map from token to index
  • vocab_size: int size of the vocabulary
  • padding_token_index: int index of the padding token

Methods

  • coords_to_strings(coords: list[CoordTup]) -> list[str] convert a list of coordinates to a list of tokens. Optionally except, skip, or include non-coordinates
  • strings_to_coords(strings: list[str]) -> list[CoordTup] convert a list of tokens to a list of coordinates. Optionally except, skip, or include non-coordinates
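
A short sketch of the conditional properties and their underscore-prefixed variants (continuing from the imports above; the 3x3 grid is an arbitrary example):

    tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
    assert tok.token_arr is None and tok.vocab_size is None  # no max_grid_size set
    try:
        _ = tok._token_arr  # the underscore variant raises instead of returning None
    except ValueError:
        pass

    tok_sized = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_UT_uniform,
        max_grid_size=3,
    )
    print(tok_sized.name)        # "maze_tokenizer-AOTP_UT_uniform-g3"
    print(tok_sized.vocab_size)  # the special tokens plus one token per coordinate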
MazeTokenizer( *, tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>, max_grid_size: int | None = None)
tokenization_mode: TokenizationMode = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
max_grid_size: int | None = None
name: str
200	@property
201	def name(self) -> str:
202		"""auto-generated name of the tokenizer from mode and size"""
203		max_grid_size_str: str = (
204			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
205		)
206		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

auto-generated name of the tokenizer from mode and size

node_strings_map: Optional[Mapping[tuple[int, int], list[str]]]
222	@cached_property
223	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
224		"""map a coordinate to a token"""
225		if self.tokenization_mode in (
226			TokenizationMode.AOTP_UT_rasterized,
227			TokenizationMode.AOTP_UT_uniform,
228		):
229			return None
230		else:
231			return self._node_strings_map

map a coordinate to a token

token_arr: list[str] | None
275	@cached_property
276	def token_arr(self) -> list[str] | None:
277		"get the token array if the max_grid_size is specified"
278		if self.max_grid_size is None:
279			return None
280		return self._token_arr

get the token array if the max_grid_size is specified

tokenizer_map: dict[str, int] | None
287	@cached_property
288	def tokenizer_map(self) -> dict[str, int] | None:
289		"get the tokenizer map if the max_grid_size is specified"
290		if self.max_grid_size is None:
291			return None
292		return self._tokenizer_map

get the tokenizer map if the max_grid_size is specified

vocab_size: int | None
298	@property
299	def vocab_size(self) -> int | None:
300		"get the size of the vocabulary if the max_grid_size is specified"
301		if self.max_grid_size is None:
302			return None
303		return self._vocab_size

get the size of the vocabulary if the max_grid_size is specified

n_tokens: int | None
310	@property
311	def n_tokens(self) -> int | None:
312		"get the number of tokens if the max_grid_size is specified"
313		if self.max_grid_size is None:
314			return None
315		return self._n_tokens

get the number of tokens if the max_grid_size is specified

padding_token_index: int | None
321	@cached_property
322	def padding_token_index(self) -> int | None:
323		"get the index of the padding token if it exists"
324		if self.max_grid_size is None:
325			return None
326		return self._padding_token_index

get the index of the padding token if it exists

def coords_to_strings( self, coords: list[tuple[int, int]], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str]:
343	def coords_to_strings(
344		self,
345		coords: list[CoordTup],
346		when_noncoord: WhenMissing = "skip",
347	) -> list[str]:
348		"""map a list of coordinate tuples (and maybe other tokens) to strings
349
350		wraps `maze_dataset.token_utils.coords_to_strings` with either
351		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
352		"""
353		if self.tokenization_mode in (
354			TokenizationMode.AOTP_UT_rasterized,
355			TokenizationMode.AOTP_UT_uniform,
356		):
357			return coords_to_strings(
358				coords=coords,
359				coord_to_strings_func=_coord_to_strings_UT,
360				when_noncoord=when_noncoord,
361			)
362		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
363			return coords_to_strings(
364				coords=coords,
365				coord_to_strings_func=_coord_to_strings_indexed,
366				when_noncoord=when_noncoord,
367			)
368		else:
369			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
370			raise ValueError(err_msg)

map a list of coordinate tuples (and maybe other tokens) to strings

wraps maze_dataset.token_utils.coords_to_strings with either _coord_to_strings_UT or _coord_to_strings_indexed depending on the tokenization mode
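
A hedged sketch of the when_noncoord options for a UT-mode tokenizer, assuming that "include" passes non-coordinate strings through unchanged and that UT modes emit one token per coordinate:

    tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
    mixed = [(0, 0), "some_non_coordinate_token", (1, 1)]
    only_coords = tok.coords_to_strings(mixed, when_noncoord="skip")    # non-coordinates dropped
    everything = tok.coords_to_strings(mixed, when_noncoord="include")  # non-coordinates kept
    assert len(only_coords) == 2 and len(everything) == 3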

@classmethod
def strings_to_coords( cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
390	@classmethod
391	def strings_to_coords(
392		cls,
393		text: str | list[str],
394		when_noncoord: WhenMissing = "skip",
395	) -> list[str | CoordTup]:
396		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
397		return strings_to_coords(text=text, when_noncoord=when_noncoord)
def encode(self, text: str | list[str]) -> list[int]:
399	def encode(self, text: str | list[str]) -> list[int]:
400		"""encode a string or list of strings into a list of tokens"""
401		try:
402			if isinstance(text, str):
403				text = text.split()
404			return [self.tokenizer_map[token] for token in text]
405		except KeyError as e:
406			err_msg: str = (
407				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
408			)
409			raise TokenError(err_msg) from e

encode a string or list of strings into a list of tokens

def decode( self, tokens: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
411	def decode(
412		self,
413		tokens: Sequence[int],
414		joined_tokens: bool = False,
415	) -> list[str] | str:
416		"""decode a list of tokens into a string or list of strings"""
417		try:
418			output: list[str] = [self.token_arr[token] for token in tokens]
419		except IndexError as e:
420			err_msg: str = (
421				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
422			)
423			raise TokenError(err_msg) from e
424		if joined_tokens:
425			return " ".join(output)
426		else:
427			return output

decode a list of tokens into a string or list of strings
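
A roundtrip sketch through the legacy vocabulary; the token strings come from coords_to_strings, so no literal token formats need to be assumed (max_grid_size must be set for encode/decode to work):

    tok = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_UT_uniform,
        max_grid_size=3,
    )
    coord_tokens = tok.coords_to_strings([(0, 0), (2, 1)])
    token_ids = tok.encode(coord_tokens)              # indices into tok.token_arr
    assert tok.decode(token_ids) == coord_tokens      # decode recovers the token strings
    print(tok.decode(token_ids, joined_tokens=True))  # single space-joined string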

coordinate_tokens_coords: dict[tuple[int, int], int]
432	@cached_property
433	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
434		"map of coordinate tuples to their token ids, only valid for UT"
435		# print(f"{self.tokenization_mode = }")
436		if not self.is_UT():
437			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
438			raise ValueError(err_msg)
439
440		if self.max_grid_size is None:
441			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
442			raise ValueError(err_msg)
443
444		raw_converted: list[CoordTup | str] = self.strings_to_coords(
445			self.token_arr,
446			when_noncoord="include",
447		)
448
449		# filter out non-coordinates
450		return {
451			coord: i
452			for i, coord in enumerate(raw_converted)
453			if not isinstance(coord, str)
454		}

map of coordinate tuples to their token ids, only valid for UT

coordinate_tokens_ids: dict[str, int]
456	@cached_property
457	def coordinate_tokens_ids(self) -> dict[str, int]:
458		"map of coordinate tokens to their token ids, only valid for UT"
459		# checks performed in call
460		output: dict[str, int] = dict()
461
462		for coord, index in self.coordinate_tokens_coords.items():
463			_for_key: list[str] = self.coords_to_strings([coord])
464			assert len(_for_key) == 1
465			output[_for_key[0]] = index
466
467		return output

map of coordinate tokens to their token ids, only valid for UT
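
A small sketch of the UT-only lookup tables; the exact indices depend on how many entries SPECIAL_TOKENS has, so only the sizes are asserted here:

    tok = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_UT_uniform,
        max_grid_size=2,
    )
    by_coord = tok.coordinate_tokens_coords  # dict[CoordTup, int]
    by_token = tok.coordinate_tokens_ids     # dict[str, int], keyed by the coordinate token string
    assert len(by_coord) == len(by_token) == 4  # a 2x2 grid has 4 coordinate tokens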

def summary(self) -> dict:
472	def summary(self) -> dict:
473		"""returns a summary of the tokenization mode"""
474		return {
475			"tokenization_mode": self.tokenization_mode.value,
476			"max_grid_size": self.max_grid_size,
477			"vocab_size": self.vocab_size,
478		}

returns a summary of the tokenization mode

def is_AOTP(self) -> bool:
480	def is_AOTP(self) -> bool:
481		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
482		return self.tokenization_mode in (
483			TokenizationMode.AOTP_UT_rasterized,
484			TokenizationMode.AOTP_UT_uniform,
485			TokenizationMode.AOTP_CTT_indexed,
486		)

returns true if a tokenization mode is Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
488	def is_UT(self) -> bool:
489		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
490		return is_UT(self.tokenization_mode)

returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)

def clear_cache(self) -> None:
492	def clear_cache(self) -> None:
493		"""clears all cached properties"""
494		# delete the properties only if they exist
495		for name, prop in self.__class__.__dict__.items():
496			if isinstance(prop, cached_property):
497				# if the property exists, delete it
498				try:  # noqa: SIM105
499					delattr(self, name)
500				except AttributeError:
501					pass

clears all cached properties

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
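
Since MazeTokenizer is a SerializableDataclass, serialization roundtrips through a plain dict; a minimal sketch (assuming dataclass equality between the original and reloaded instances):

    tok = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_CTT_indexed,
        max_grid_size=4,
    )
    data: dict = tok.serialize()           # JSON-friendly dict, including serialized properties
    tok_loaded = MazeTokenizer.load(data)  # reconstructs an equivalent instance
    assert tok_loaded == tok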

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class TokenizationMode(enum.Enum):
47class TokenizationMode(Enum):
48	"""legacy tokenization modes
49
50	> [!CAUTION]
51	> Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
52	> Use `MazeTokenizerModular` instead.
53
54	# Abbreviations:
55	- `AOTP`: Adjacency list, Origin, Target, Path
56	- `UT`: Unique Token (for each coordinate)
57	- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)
58
59	# Modes:
60	- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
61		example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
62	- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
63		uses `corner_first_ndindex` function to order the tokens
64	- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
65	"""
66
67	AOTP_UT_rasterized = "AOTP_UT_rasterized"
68	AOTP_UT_uniform = "AOTP_UT_uniform"
69	AOTP_CTT_indexed = "AOTP_CTT_indexed"
70
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

legacy tokenization modes

Caution

Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. Use MazeTokenizerModular instead.

Abbreviations:

  • AOTP: Adjacency list, Origin, Target, Path
  • UT: Unique Token (for each coordinate)
  • CTT: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

Modes:

  • AOTP_UT_rasterized: the "classic" mode: assigning tokens to each coordinate is done via rasterization. Example: for a 3x3 maze, token order is (0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)
  • AOTP_UT_uniform: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible; uses the corner_first_ndindex function to order the tokens
  • AOTP_CTT_indexed: each coordinate is a tuple of integers
AOTP_UT_rasterized = <TokenizationMode.AOTP_UT_rasterized: 'AOTP_UT_rasterized'>
AOTP_UT_uniform = <TokenizationMode.AOTP_UT_uniform: 'AOTP_UT_uniform'>
AOTP_CTT_indexed = <TokenizationMode.AOTP_CTT_indexed: 'AOTP_CTT_indexed'>
def to_legacy_tokenizer( self, max_grid_size: int | None = None) -> MazeTokenizer:
71	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
72		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
73		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)

convert the mode to a legacy MazeTokenizer object given a max_grid_size
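
A one-line sketch of the convenience constructor above (grid size 7 is arbitrary):

    tok = TokenizationMode.AOTP_UT_rasterized.to_legacy_tokenizer(max_grid_size=7)
    assert tok.tokenization_mode is TokenizationMode.AOTP_UT_rasterized
    assert tok.max_grid_size == 7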

Inherited Members
enum.Enum
name
value
@serializable_dataclass(frozen=True, kw_only=True, properties_to_serialize=['tokenizer_element_tree_concrete', 'name'])
class MazeTokenizerModular(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 51@serializable_dataclass(
 52	frozen=True,
 53	kw_only=True,
 54	properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
 55)
 56class MazeTokenizerModular(SerializableDataclass):
 57	"""Tokenizer for mazes
 58
 59	# Parameters
 60	- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.
 61
 62	# Development
 63	- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
 64	- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
 65	- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
 66	"""
 67
 68	prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
 69		default=PromptSequencers.AOTP(),
 70		loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
 71	)
 72
 73	def hash_int(self) -> int:
 74		"return integer hash using blake2b"
 75		return _hash_tokenizer_name(self.name)
 76
 77	def __hash__(self) -> int:
 78		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
 79		return self.hash_int()
 80
 81	def hash_b64(self, n_bytes: int = 8) -> str:
 82		"""filename-safe base64 encoding of the hash"""
 83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
 84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
 85
 86		encoded = base64.b64encode(
 87			hash_mod.to_bytes(n_bytes, byteorder="big"),
 88			altchars=b"-_",
 89		).decode()
 90
 91		# Remove any padding equals signs
 92		return encoded.rstrip("=")
 93
 94	# Information Querying Methods
 95
 96	@cached_property
 97	def tokenizer_elements(self) -> list[_TokenizerElement]:
 98		"returns a list of all the elements of this tokenizer"
 99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
100
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)
116
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()
121
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
125
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizer into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002
130
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}
137
138	@staticmethod
139	def _type_check(obj: any) -> None:
140		"""Helper method for `has_element`"""
141		if not (
142			isinstance(obj, _TokenizerElement)
143			or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
144		):
145			err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
146			raise TypeError(err_msg)
147
148	def _has_element_singular(
149		self,
150		el: type[_TokenizerElement] | _TokenizerElement,
151	) -> bool:
152		"""Helper method for `has_element`"""
153		self._type_check(el)
154		if isinstance(el, type):
155			return any(isinstance(e, el) for e in self.tokenizer_elements)
156		else:
157			return el in self.tokenizer_elements
158
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_elements`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171		If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)
176
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
183
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)
190
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an fst on the `name` attributes of all the tokenizers
195
196		if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid
206
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)
210
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)
214
215	# Alternate Constructors
216	# ======================
217
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]
235
236	# Simple properties
237	# =================
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)
247
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST
252
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX
257
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)
262
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)
268
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
273
274	# conversion functions
275	# ============================================================
276
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)
283
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)
291
292	# TODO: unclear why we need to use `noqa: N805` here since its a classmethod
293	# maybe we need to hit every overload with `@classmethod`?
294	@overload
295	def strings_to_coords(
296		cls,  # noqa: N805
297		text: str | list[str],
298		when_noncoord: Literal["skip"] = "skip",
299	) -> list[CoordTup]: ...
300	@overload
301	def strings_to_coords(
302		cls,  # noqa: N805
303		text: str | list[str],
304		when_noncoord: Literal["error"] = "error",
305	) -> list[CoordTup]: ...
306	@overload
307	def strings_to_coords(
308		cls,  # noqa: N805
309		text: str | list[str],
310		when_noncoord: Literal["include"] = "include",
311	) -> list[str | CoordTup]: ...
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)
324
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e
335
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

Tokenizer for mazes

Parameters

  • prompt_sequencer: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

Development

  • To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy TokenizationMode.AOTP_UT_uniform (as sketched below).
  • Furthermore, the mapping reflected in from_legacy must also be maintained.
  • Updates to MazeTokenizerModular or the _TokenizerElement hierarchy must maintain that behavior.
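
A sketch of the backwards-compatibility guarantee listed above: the default constructor matches the modular equivalent of the legacy uniform-UT mode.

    default_tok = MazeTokenizerModular()
    assert default_tok == MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform)
    assert default_tok.is_legacy_equivalent() and default_tok.is_UT()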
MazeTokenizerModular( *, prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)))
prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))
def hash_int(self) -> int:
73	def hash_int(self) -> int:
74		"return integer hash using blake2b"
75		return _hash_tokenizer_name(self.name)

return integer hash using blake2b

def hash_b64(self, n_bytes: int = 8) -> str:
81	def hash_b64(self, n_bytes: int = 8) -> str:
82		"""filename-safe base64 encoding of the hash"""
83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
85
86		encoded = base64.b64encode(
87			hash_mod.to_bytes(n_bytes, byteorder="big"),
88			altchars=b"-_",
89		).decode()
90
91		# Remove any padding equals signs
92		return encoded.rstrip("=")

filename-safe base64 encoding of the hash

tokenizer_elements: list[maze_dataset.tokenization.modular.element_base._TokenizerElement]
96	@cached_property
97	def tokenizer_elements(self) -> list[_TokenizerElement]:
98		"returns a list of all the elements of this tokenizer"
99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]

returns a list of all the elements of this tokenizer

def tokenizer_element_tree(self, abstract: bool = False) -> str:
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)

Returns a string representation of the tree of tokenizer elements contained in self.

Parameters

  • abstract: bool: Whether to print the name of the abstract base class or the concrete class for each _TokenizerElement instance.
tokenizer_element_tree_concrete: str
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()

Property wrapper for tokenizer_element_tree so that it can be used in properties_to_serialize.

def tokenizer_element_dict(self) -> dict:
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}

Nested dictionary of the internal TokenizerElements.

name: str
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizer into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002

Serializes MazeTokenizer into a key for encoding in zanj

def summary(self) -> dict[str, str]:
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}

Single-level dictionary of the internal TokenizerElements.

def has_element( self, *elements: Sequence[type[maze_dataset.tokenization.modular.element_base._TokenizerElement] | maze_dataset.tokenization.modular.element_base._TokenizerElement]) -> bool:
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_elements`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171		If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)

Returns True if the MazeTokenizerModular instance contains ALL of the items specified in elements.

Querying with a partial subset of _TokenizerElement fields is not currently supported. To do such a query, assemble multiple calls to has_elements.

Parameters

  • elements: Singleton or iterable of _TokenizerElement instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via isinstance. I.e., any instance of that class is accepted. A usage sketch follows below.
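
A hedged usage sketch of has_element, assuming CoordTokenizers and PromptSequencers are importable from maze_dataset.tokenization.modular.elements as the signatures above indicate; instance comparison relies on element equality, which holds for the default elements used here:

    from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers

    tok = MazeTokenizerModular()
    assert tok.has_element(PromptSequencers.AOTP)  # class: matched via isinstance
    assert tok.has_element(CoordTokenizers.UT())   # instance: matched via equality
    assert tok.has_element(PromptSequencers.AOTP, CoordTokenizers.UT)  # ALL must be present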
def is_valid(self, do_except: bool = False) -> bool:
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)

Returns True if self is a valid tokenizer.

Evaluates the validity of all of self.tokenizer_elements according to each one's method.

def is_legacy_equivalent(self) -> bool:
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)

Returns if self has identical stringification behavior as any legacy MazeTokenizer.

def is_tested_tokenizer(self, do_except: bool = False) -> bool:
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an fst on the `name` attributes of all the tokenizers
195
196		if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid

Returns if the tokenizer is returned by all_tokenizers.get_all_tokenizers, the set of tested and reliable tokenizers.

uses an fst on the name attributes of all the tokenizers

if do_except is True, raises an AssertionError if the tokenizer is not tested.

def is_AOTP(self) -> bool:
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)

is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)

is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)

@classmethod
def from_legacy( cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode) -> MazeTokenizerModular:
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]

Maps a legacy MazeTokenizer or TokenizationMode to its equivalent MazeTokenizerModular instance.
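
A sketch of the mapping: the CTT legacy mode corresponds to a modular tokenizer using the CTT coordinate tokenizer (CoordTokenizers imported as in the sketch above); a legacy MazeTokenizer instance is accepted as well, and only its tokenization_mode is used:

    ctt_tok = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
    assert not ctt_tok.is_UT()
    assert ctt_tok.has_element(CoordTokenizers.CTT)

    legacy = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_CTT_indexed)
    assert MazeTokenizerModular.from_legacy(legacy) == ctt_tok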

@classmethod
def from_tokens( cls, tokens: str | list[str]) -> MazeTokenizerModular:
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)

Infers most MazeTokenizerModular parameters from a full sequence of tokens.

token_arr: list[str] | None
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST

map from index to token

tokenizer_map: dict[str, int]
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX

map from token to index

vocab_size: int
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)

Number of tokens in the static vocab

n_tokens: int
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)

get the number of tokens in the vocabulary (deprecated)
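
The replacement suggested by the error message, as a sketch:

import maze_dataset

# instead of the removed MazeTokenizerModular().n_tokens:
vocab_size: int = len(maze_dataset.VOCAB_LIST)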

padding_token_index: int
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]

get the index of the padding token
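
A sketch of padding encoded sequences to a common length with this index; the pad_batch helper is hypothetical, not part of the library:

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()

def pad_batch(batch: list[list[int]]) -> list[list[int]]:
	# hypothetical helper: right-pad every sequence with the padding token index
	max_len: int = max(len(seq) for seq in batch)
	return [seq + [tok.padding_token_index] * (max_len - len(seq)) for seq in batch]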

def to_tokens(self, maze: maze_dataset.LatticeMaze) -> list[str]:
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)

Converts maze into a list of tokens.
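
A sketch of tokenizing a maze end to end; the dataset construction assumes the package-level MazeDatasetConfig / MazeDataset.from_config / LatticeMazeGenerators API documented elsewhere in maze_dataset:

from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

cfg = MazeDatasetConfig(
	name="demo",
	grid_n=4,
	n_mazes=1,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)
dataset = MazeDataset.from_config(cfg)

tok = MazeTokenizerModular()
tokens: list[str] = tok.to_tokens(dataset[0])  # dataset items are solved mazes
print(" ".join(tokens))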

def coords_to_strings(self, coords: list[tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2']]) -> list[str]:
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)

calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords
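
A roundtrip sketch pairing this with strings_to_coords (documented below); with the default UT coordinate tokenizer each coordinate maps to a single token:

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
coord_tokens: list[str] = tok.coords_to_strings([(0, 0), (1, 2)])
coords = MazeTokenizerModular.strings_to_coords(coord_tokens, when_noncoord="skip")
# expected to recover [(0, 0), (1, 2)] for UT-style tokens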

@classmethod
def strings_to_coords(cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)

wrapper for maze_dataset.token_utils.strings_to_coords

@staticmethod
def encode(text: str | list[str]) -> list[int]:
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e

encode a string or list of strings into a list of tokens

@staticmethod
def decode(token_ids: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

decode a list of tokens into a string or list of strings
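
A self-contained encode/decode roundtrip sketch against the static vocabulary; both are staticmethods, so no instance is needed:

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tokens: list[str] = MazeTokenizerModular.decode([0, 1, 2])
assert MazeTokenizerModular.encode(tokens) == [0, 1, 2]

# unknown tokens raise a TokenError
try:
	MazeTokenizerModular.encode("definitely-not-a-vocab-token")
except Exception as err:
	print(err)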

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the instance as a dict, implemented via the @serializable_dataclass decorator
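
A sketch of serializing a tokenizer; the exact keys depend on the class's fields and properties_to_serialize, so only the general shape is shown:

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
data: dict = tok.serialize()
# JSON-serializable dict: a format key plus the serialized fields and properties
print(sorted(data.keys()))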

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes an appropriately structured dict and returns an instance of the class, implemented via the @serializable_dataclass decorator
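
A serialize/load roundtrip sketch; load also accepts an already-constructed instance and returns it unchanged (the isinstance shortcut above):

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
data = tok.serialize()
tok_again = MazeTokenizerModular.load(data)
# field-wise equality is expected to hold for the reloaded instance
assert tok_again == tok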

def validate_fields_types(self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field
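
This is exposed as a bound method on any SerializableDataclass, including the tokenizers documented above; a minimal sketch:

from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

tok = MazeTokenizerModular()
# with the default error mode this raises on a type mismatch, otherwise returns True
assert tok.validate_fields_types()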

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict