docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.tokenization.modular.maze_tokenizer_modular

implements the actual MazeTokenizerModular class


  1"implements the actual `MazeTokenizerModular` class"
  2
  3import base64
  4import warnings
  5from functools import cached_property
  6from typing import (
  7	Iterable,
  8	Literal,
  9	Sequence,
 10	overload,
 11)
 12
 13from muutils.json_serialize import (
 14	SerializableDataclass,
 15	serializable_dataclass,
 16	serializable_field,
 17)
 18from muutils.misc import flatten
 19from muutils.misc.sequence import WhenMissing
 20
 21# from maze_dataset import SolvedMaze
 22from maze_dataset.constants import (
 23	VOCAB,
 24	VOCAB_LIST,
 25	VOCAB_TOKEN_TO_INDEX,
 26	Coord,
 27	CoordTup,
 28)
 29from maze_dataset.maze.lattice_maze import LatticeMaze
 30from maze_dataset.token_utils import (
 31	TokenizerPendingDeprecationWarning,
 32	strings_to_coords,
 33)
 34from maze_dataset.tokenization.common import TokenError
 35from maze_dataset.tokenization.maze_tokenizer_legacy import (
 36	MazeTokenizer,
 37	TokenizationMode,
 38)
 39from maze_dataset.tokenization.modular.element_base import (
 40	_load_tokenizer_element,
 41	_TokenizerElement,
 42)
 43from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers
 44from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst
 45from maze_dataset.tokenization.modular.hashing import (
 46	_hash_tokenizer_name,
 47)
 48
 49
@serializable_dataclass(
	frozen=True,
	kw_only=True,
	properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
)
class MazeTokenizerModular(SerializableDataclass):
	"""Tokenizer for mazes

	# Parameters
	- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

	# Development
	- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
	- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
	- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
	"""

	# Defaults to `PromptSequencers.AOTP()`. On load, the stored value is handed to
	# `_load_tokenizer_element` together with the `PromptSequencers` namespace.
	prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
		default=PromptSequencers.AOTP(),
		loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
	)
 71
	def hash_int(self) -> int:
		"""Integer hash of this tokenizer, computed from `self.name` by `_hash_tokenizer_name`.

		The original note describes the hash as blake2b-based; the helper itself is defined elsewhere.
		"""
		tokenizer_name: str = self.name
		return _hash_tokenizer_name(tokenizer_name)
 75
 76	def __hash__(self) -> int:
 77		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
 78		return self.hash_int()
 79
 80	def hash_b64(self, n_bytes: int = 8) -> str:
 81		"""filename-safe base64 encoding of the hash"""
 82		# Use modulus to ensure the integer fits within n_bytes * 8 bits
 83		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
 84
 85		encoded = base64.b64encode(
 86			hash_mod.to_bytes(n_bytes, byteorder="big"),
 87			altchars=b"-_",
 88		).decode()
 89
 90		# Remove any padding equals signs
 91		return encoded.rstrip("=")
 92
 93	# Information Querying Methods
 94
	@cached_property
	def tokenizer_elements(self) -> list[_TokenizerElement]:
		"""All elements of this tokenizer: the prompt sequencer first, followed by everything its own `tokenizer_elements()` returns."""
		elements: list[_TokenizerElement] = [self.prompt_sequencer]
		elements.extend(self.prompt_sequencer.tokenizer_elements())
		return elements
 99
100	def tokenizer_element_tree(self, abstract: bool = False) -> str:
101		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
102
103		# Parameters
104		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
105		"""
106		return "\n".join(
107			[
108				type(self).__name__,
109				self.prompt_sequencer.tokenizer_element_tree(
110					abstract=abstract,
111					depth=1,
112				),
113			],
114		)
115
116	@property
117	def tokenizer_element_tree_concrete(self) -> str:
118		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
119		return self.tokenizer_element_tree()
120
121	def tokenizer_element_dict(self) -> dict:
122		"""Nested dictionary of the internal `TokenizerElement`s."""
123		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
124
125	@property
126	def name(self) -> str:
127		"""Serializes MazeTokenizer into a key for encoding in zanj"""
128		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002
129
130	def summary(self) -> dict[str, str]:
131		"""Single-level dictionary of the internal `TokenizerElement`s."""
132		return {
133			# "prompt_sequencer": self.prompt_sequencer.name,
134			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
135		}
136
137	@staticmethod
138	def _type_check(obj: any) -> None:
139		"""Helper method for `has_element`"""
140		if not (
141			isinstance(obj, _TokenizerElement)
142			or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
143		):
144			err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
145			raise TypeError(err_msg)
146
147	def _has_element_singular(
148		self,
149		el: type[_TokenizerElement] | _TokenizerElement,
150	) -> bool:
151		"""Helper method for `has_element`"""
152		self._type_check(el)
153		if isinstance(el, type):
154			return any(isinstance(e, el) for e in self.tokenizer_elements)
155		else:
156			return el in self.tokenizer_elements
157
158	def has_element(
159		self,
160		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
161	) -> bool:
162		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
163
164		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
165		To do such a query, assemble multiple calls to `has_elements`.
166
167		# Parameters
168		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
169		If an instance is provided, then comparison is done via instance equality.
170		If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted.
171		"""
172		if len(elements) == 1 and isinstance(elements[0], Iterable):
173			elements = elements[0]
174		return all(self._has_element_singular(e) for e in elements)
175
176	def is_valid(self, do_except: bool = False) -> bool:
177		"""Returns `True` if `self` is a valid tokenizer.
178
179		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
180		"""
181		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
182
183	def is_legacy_equivalent(self) -> bool:
184		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
185		return any(
186			self == MazeTokenizerModular.from_legacy(tok_mode)
187			for tok_mode in TokenizationMode
188		)
189
190	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
191		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
192
193		uses an fst on the `name` attributes of all the tokenizers
194
195		if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
196		"""
197		is_valid: bool = self.is_valid(do_except=do_except)
198		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
199
200		if do_except:
201			assert is_valid, "self.is_valid returns False"
202			return True
203		else:
204			return in_tested_fst and is_valid
205
206	def is_AOTP(self) -> bool:
207		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
208		return self.has_element(PromptSequencers.AOTP)
209
210	def is_UT(self) -> bool:
211		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
212		return self.has_element(CoordTokenizers.UT)
213
214	# Alternate Constructors
215	# ======================
216
217	@classmethod
218	def from_legacy(
219		cls,
220		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
221	) -> "MazeTokenizerModular":
222		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
223		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
224			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
225		return {
226			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
227			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
228			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
229				prompt_sequencer=PromptSequencers.AOTP(
230					coord_tokenizer=CoordTokenizers.CTT(),
231				),
232			),
233		}[legacy_maze_tokenizer]
234
235	# Simple properties
236	# =================
237	@classmethod
238	def from_tokens(
239		cls,
240		tokens: str | list[str],
241	) -> "MazeTokenizerModular":
242		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
243		raise NotImplementedError(
244			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
245		)
246
247	@property
248	def token_arr(self) -> list[str] | None:
249		"""map from index to token"""
250		return VOCAB_LIST
251
252	@property
253	def tokenizer_map(self) -> dict[str, int]:
254		"""map from token to index"""
255		return VOCAB_TOKEN_TO_INDEX
256
257	@property
258	def vocab_size(self) -> int:
259		"""Number of tokens in the static vocab"""
260		return len(VOCAB_LIST)
261
262	@property
263	def n_tokens(self) -> int:
264		"get the number of tokens in the vocabulary (deprecated)"
265		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
266		raise NameError(err_msg)
267
268	@property
269	def padding_token_index(self) -> int:
270		"get the index of the padding token"
271		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
272
273	# conversion functions
274	# ============================================================
275
276	def to_tokens(
277		self,
278		maze: LatticeMaze,
279	) -> list[str]:
280		"""Converts maze into a list of tokens."""
281		return self.prompt_sequencer.to_tokens(maze)
282
283	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
284		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
285		return list(
286			flatten(
287				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
288			),
289		)
290
291	# TODO: unclear why we need to use `noqa: N805` here since its a classmethod
292	# maybe we need to hit every overload with `@classmethod`?
293	@overload
294	def strings_to_coords(
295		cls,  # noqa: N805
296		text: str | list[str],
297		when_noncoord: Literal["skip"] = "skip",
298	) -> list[CoordTup]: ...
299	@overload
300	def strings_to_coords(
301		cls,  # noqa: N805
302		text: str | list[str],
303		when_noncoord: Literal["error"] = "error",
304	) -> list[CoordTup]: ...
305	@overload
306	def strings_to_coords(
307		cls,  # noqa: N805
308		text: str | list[str],
309		when_noncoord: Literal["include"] = "include",
310	) -> list[str | CoordTup]: ...
311	@classmethod
312	def strings_to_coords(
313		cls,
314		text: str | list[str],
315		when_noncoord: WhenMissing = "skip",
316	) -> list[str | CoordTup]:
317		"wrapper for maze_dataset.token_utils.strings_to_coords"
318		warnings.warn(
319			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
320			TokenizerPendingDeprecationWarning,
321		)
322		return strings_to_coords(text=text, when_noncoord=when_noncoord)
323
324	@staticmethod
325	def encode(text: str | list[str]) -> list[int]:
326		"""encode a string or list of strings into a list of tokens"""
327		try:
328			if isinstance(text, str):
329				text = text.split()
330			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
331		except KeyError as e:
332			err_msg: str = f"Token {e} not found in `VOCAB`."
333			raise TokenError(err_msg) from e
334
335	@staticmethod
336	def decode(
337		token_ids: Sequence[int],
338		joined_tokens: bool = False,
339	) -> list[str] | str:
340		"""decode a list of tokens into a string or list of strings"""
341		try:
342			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
343		except IndexError as e:
344			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
345			raise TokenError(err_msg) from e
346		if joined_tokens:
347			return " ".join(output)
348		else:
349			return output

@serializable_dataclass(frozen=True, kw_only=True, properties_to_serialize=['tokenizer_element_tree_concrete', 'name'])
class MazeTokenizerModular(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 51@serializable_dataclass(
 52	frozen=True,
 53	kw_only=True,
 54	properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
 55)
 56class MazeTokenizerModular(SerializableDataclass):
 57	"""Tokenizer for mazes
 58
 59	# Parameters
 60	- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.
 61
 62	# Development
 63	- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`.
 64	- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
 65	- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
 66	"""
 67
 68	prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
 69		default=PromptSequencers.AOTP(),
 70		loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
 71	)
 72
 73	def hash_int(self) -> int:
 74		"return integer hash using blake2b"
 75		return _hash_tokenizer_name(self.name)
 76
 77	def __hash__(self) -> int:
 78		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
 79		return self.hash_int()
 80
 81	def hash_b64(self, n_bytes: int = 8) -> str:
 82		"""filename-safe base64 encoding of the hash"""
 83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
 84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
 85
 86		encoded = base64.b64encode(
 87			hash_mod.to_bytes(n_bytes, byteorder="big"),
 88			altchars=b"-_",
 89		).decode()
 90
 91		# Remove any padding equals signs
 92		return encoded.rstrip("=")
 93
 94	# Information Querying Methods
 95
 96	@cached_property
 97	def tokenizer_elements(self) -> list[_TokenizerElement]:
 98		"returns a list of all the elements of this tokenizer"
 99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
100
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)
116
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()
121
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
125
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizer into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002
130
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}
137
138	@staticmethod
139	def _type_check(obj: any) -> None:
140		"""Helper method for `has_element`"""
141		if not (
142			isinstance(obj, _TokenizerElement)
143			or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
144		):
145			err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
146			raise TypeError(err_msg)
147
148	def _has_element_singular(
149		self,
150		el: type[_TokenizerElement] | _TokenizerElement,
151	) -> bool:
152		"""Helper method for `has_element`"""
153		self._type_check(el)
154		if isinstance(el, type):
155			return any(isinstance(e, el) for e in self.tokenizer_elements)
156		else:
157			return el in self.tokenizer_elements
158
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_elements`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171		If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)
176
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
183
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)
190
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an fst on the `name` attributes of all the tokenizers
195
196		if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid
206
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)
210
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)
214
215	# Alternate Constructors
216	# ======================
217
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]
235
236	# Simple properties
237	# =================
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)
247
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST
252
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX
257
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)
262
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)
268
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
273
274	# conversion functions
275	# ============================================================
276
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)
283
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)
291
292	# TODO: unclear why we need to use `noqa: N805` here since its a classmethod
293	# maybe we need to hit every overload with `@classmethod`?
294	@overload
295	def strings_to_coords(
296		cls,  # noqa: N805
297		text: str | list[str],
298		when_noncoord: Literal["skip"] = "skip",
299	) -> list[CoordTup]: ...
300	@overload
301	def strings_to_coords(
302		cls,  # noqa: N805
303		text: str | list[str],
304		when_noncoord: Literal["error"] = "error",
305	) -> list[CoordTup]: ...
306	@overload
307	def strings_to_coords(
308		cls,  # noqa: N805
309		text: str | list[str],
310		when_noncoord: Literal["include"] = "include",
311	) -> list[str | CoordTup]: ...
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)
324
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e
335
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

Tokenizer for mazes

Parameters

  • prompt_sequencer: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

Development

  • To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy TokenizationMode.AOTP_UT_uniform.
  • Furthermore, the mapping reflected in from_legacy must also be maintained.
  • Updates to MazeTokenizerModular or the _TokenizerElement hierarchy must maintain that behavior.
MazeTokenizerModular( *, prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)))
prompt_sequencer: maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer = PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))
def hash_int(self) -> int:
73	def hash_int(self) -> int:
74		"return integer hash using blake2b"
75		return _hash_tokenizer_name(self.name)

return integer hash using blake2b

def hash_b64(self, n_bytes: int = 8) -> str:
81	def hash_b64(self, n_bytes: int = 8) -> str:
82		"""filename-safe base64 encoding of the hash"""
83		# Use modulus to ensure the integer fits within n_bytes * 8 bits
84		hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
85
86		encoded = base64.b64encode(
87			hash_mod.to_bytes(n_bytes, byteorder="big"),
88			altchars=b"-_",
89		).decode()
90
91		# Remove any padding equals signs
92		return encoded.rstrip("=")

filename-safe base64 encoding of the hash

tokenizer_elements: list[maze_dataset.tokenization.modular.element_base._TokenizerElement]
96	@cached_property
97	def tokenizer_elements(self) -> list[_TokenizerElement]:
98		"returns a list of all the elements of this tokenizer"
99		return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]

returns a list of all the elements of this tokenizer

def tokenizer_element_tree(self, abstract: bool = False) -> str:
101	def tokenizer_element_tree(self, abstract: bool = False) -> str:
102		"""Returns a string representation of the tree of tokenizer elements contained in `self`.
103
104		# Parameters
105		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
106		"""
107		return "\n".join(
108			[
109				type(self).__name__,
110				self.prompt_sequencer.tokenizer_element_tree(
111					abstract=abstract,
112					depth=1,
113				),
114			],
115		)

Returns a string representation of the tree of tokenizer elements contained in self.

Parameters

  • abstract: bool: Whether to print the name of the abstract base class or the concrete class for each _TokenizerElement instance.
tokenizer_element_tree_concrete: str
117	@property
118	def tokenizer_element_tree_concrete(self) -> str:
119		"""Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
120		return self.tokenizer_element_tree()

Property wrapper for tokenizer_element_tree so that it can be used in properties_to_serialize.

def tokenizer_element_dict(self) -> dict:
122	def tokenizer_element_dict(self) -> dict:
123		"""Nested dictionary of the internal `TokenizerElement`s."""
124		return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}

Nested dictionary of the internal TokenizerElements.

name: str
126	@property
127	def name(self) -> str:
128		"""Serializes MazeTokenizer into a key for encoding in zanj"""
129		return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002

Serializes MazeTokenizer into a key for encoding in zanj

def summary(self) -> dict[str, str]:
131	def summary(self) -> dict[str, str]:
132		"""Single-level dictionary of the internal `TokenizerElement`s."""
133		return {
134			# "prompt_sequencer": self.prompt_sequencer.name,
135			**{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
136		}

Single-level dictionary of the internal TokenizerElements.

def has_element( self, *elements: Sequence[type[maze_dataset.tokenization.modular.element_base._TokenizerElement] | maze_dataset.tokenization.modular.element_base._TokenizerElement]) -> bool:
159	def has_element(
160		self,
161		*elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
162	) -> bool:
163		"""Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164
165		Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
166		To do such a query, assemble multiple calls to `has_elements`.
167
168		# Parameters
169		- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
170		If an instance is provided, then comparison is done via instance equality.
171		If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted.
172		"""
173		if len(elements) == 1 and isinstance(elements[0], Iterable):
174			elements = elements[0]
175		return all(self._has_element_singular(e) for e in elements)

Returns True if the MazeTokenizerModular instance contains ALL of the items specified in elements.

Querying with a partial subset of _TokenizerElement fields is not currently supported. To do such a query, assemble multiple calls to has_element.

Parameters

  • elements: Singleton or iterable of _TokenizerElement instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via isinstance. I.e., any instance of that class is accepted.
def is_valid(self, do_except: bool = False) -> bool:
177	def is_valid(self, do_except: bool = False) -> bool:
178		"""Returns `True` if `self` is a valid tokenizer.
179
180		Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
181		"""
182		return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)

Returns True if self is a valid tokenizer.

Evaluates the validity of all of self.tokenizer_elements according to each one's method.

def is_legacy_equivalent(self) -> bool:
184	def is_legacy_equivalent(self) -> bool:
185		"""Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
186		return any(
187			self == MazeTokenizerModular.from_legacy(tok_mode)
188			for tok_mode in TokenizationMode
189		)

Returns if self has identical stringification behavior as any legacy MazeTokenizer.

def is_tested_tokenizer(self, do_except: bool = False) -> bool:
191	def is_tested_tokenizer(self, do_except: bool = False) -> bool:
192		"""Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193
194		uses an fst on the `name` attributes of all the tokenizers
195
196		if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
197		"""
198		is_valid: bool = self.is_valid(do_except=do_except)
199		in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200
201		if do_except:
202			assert is_valid, "self.is_valid returns False"
203			return True
204		else:
205			return in_tested_fst and is_valid

Returns if the tokenizer is returned by all_tokenizers.get_all_tokenizers, the set of tested and reliable tokenizers.

uses an fst on the name attributes of all the tokenizers

if do_except is True, raises an AssertionError if the tokenizer is not tested.

def is_AOTP(self) -> bool:
207	def is_AOTP(self) -> bool:
208		"is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
209		return self.has_element(PromptSequencers.AOTP)

is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path

def is_UT(self) -> bool:
211	def is_UT(self) -> bool:
212		"is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
213		return self.has_element(CoordTokenizers.UT)

is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)

@classmethod
def from_legacy( cls, legacy_maze_tokenizer: maze_dataset.tokenization.MazeTokenizer | maze_dataset.tokenization.TokenizationMode) -> MazeTokenizerModular:
218	@classmethod
219	def from_legacy(
220		cls,
221		legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
222	) -> "MazeTokenizerModular":
223		"""Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
224		if isinstance(legacy_maze_tokenizer, MazeTokenizer):
225			legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
226		return {
227			TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
228			TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
229			TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
230				prompt_sequencer=PromptSequencers.AOTP(
231					coord_tokenizer=CoordTokenizers.CTT(),
232				),
233			),
234		}[legacy_maze_tokenizer]

Maps a legacy MazeTokenizer or TokenizationMode to its equivalent MazeTokenizerModular instance.

@classmethod
def from_tokens( cls, tokens: str | list[str]) -> MazeTokenizerModular:
238	@classmethod
239	def from_tokens(
240		cls,
241		tokens: str | list[str],
242	) -> "MazeTokenizerModular":
243		"""Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
244		raise NotImplementedError(
245			"Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
246		)

Infers most MazeTokenizerModular parameters from a full sequence of tokens. Not implemented: currently always raises NotImplementedError.

token_arr: list[str] | None
248	@property
249	def token_arr(self) -> list[str] | None:
250		"""map from index to token"""
251		return VOCAB_LIST

map from index to token

tokenizer_map: dict[str, int]
253	@property
254	def tokenizer_map(self) -> dict[str, int]:
255		"""map from token to index"""
256		return VOCAB_TOKEN_TO_INDEX

map from token to index

vocab_size: int
258	@property
259	def vocab_size(self) -> int:
260		"""Number of tokens in the static vocab"""
261		return len(VOCAB_LIST)

Number of tokens in the static vocab

n_tokens: int
263	@property
264	def n_tokens(self) -> int:
265		"get the number of tokens in the vocabulary (deprecated)"
266		err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
267		raise NameError(err_msg)

get the number of tokens in the vocabulary (removed: accessing this property raises NameError; use len(maze_dataset.VOCAB_LIST) instead)

padding_token_index: int
269	@property
270	def padding_token_index(self) -> int:
271		"get the index of the padding token"
272		return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]

get the index of the padding token

def to_tokens(self, maze: maze_dataset.LatticeMaze) -> list[str]:
277	def to_tokens(
278		self,
279		maze: LatticeMaze,
280	) -> list[str]:
281		"""Converts maze into a list of tokens."""
282		return self.prompt_sequencer.to_tokens(maze)

Converts maze into a list of tokens.

def coords_to_strings( self, coords: list[tuple[int, int] | jaxtyping.Int8[ndarray, 'row_col=2']]) -> list[str]:
284	def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
285		"calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
286		return list(
287			flatten(
288				[self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
289			),
290		)

calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords

@classmethod
def strings_to_coords( cls, text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
312	@classmethod
313	def strings_to_coords(
314		cls,
315		text: str | list[str],
316		when_noncoord: WhenMissing = "skip",
317	) -> list[str | CoordTup]:
318		"wrapper for maze_dataset.token_utils.strings_to_coords"
319		warnings.warn(
320			"`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
321			TokenizerPendingDeprecationWarning,
322		)
323		return strings_to_coords(text=text, when_noncoord=when_noncoord)
@staticmethod
def encode(text: str | list[str]) -> list[int]:
325	@staticmethod
326	def encode(text: str | list[str]) -> list[int]:
327		"""encode a string or list of strings into a list of tokens"""
328		try:
329			if isinstance(text, str):
330				text = text.split()
331			return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
332		except KeyError as e:
333			err_msg: str = f"Token {e} not found in `VOCAB`."
334			raise TokenError(err_msg) from e

encode a string or list of token strings into a list of token indices

@staticmethod
def decode(token_ids: Sequence[int], joined_tokens: bool = False) -> list[str] | str:
336	@staticmethod
337	def decode(
338		token_ids: Sequence[int],
339		joined_tokens: bool = False,
340	) -> list[str] | str:
341		"""decode a list of tokens into a string or list of strings"""
342		try:
343			output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
344		except IndexError as e:
345			err_msg: str = f"Token index '{e}' not found in `VOCAB`."
346			raise TokenError(err_msg) from e
347		if joined_tokens:
348			return " ".join(output)
349		else:
350			return output

decode a sequence of token indices into a list of token strings, or into a single space-joined string if joined_tokens is True

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:  # build an instance of `cls` from a mapping; an existing instance is passed through
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # only fields that are present in `data` and are constructor arguments (`field.init`) get loaded; all others are skipped
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)  # note: receives the whole `data` mapping, not just this field's `value`
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:  # NOTE(review): `on_typecheck_mismatch` / `on_typecheck_error` are not defined in this listing; presumably options captured from the enclosing decorator -- confirm
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if any field failed validation, build a per-field report and hand it to `on_typecheck_mismatch.process`
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output  # NOTE(review): this is an instance of `cls`, although the return annotation says `Type[T]`

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict