Coverage for maze_dataset/tokenization/modular/maze_tokenizer_modular.py: 82%
128 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"implements the actual `MazeTokenizerModular` class"
3import base64
4import warnings
5from functools import cached_property
6from typing import (
7 Iterable,
8 Literal,
9 Sequence,
10 overload,
11)
13from muutils.json_serialize import (
14 SerializableDataclass,
15 serializable_dataclass,
16 serializable_field,
17)
18from muutils.misc import flatten
19from muutils.misc.sequence import WhenMissing
21# from maze_dataset import SolvedMaze
22from maze_dataset.constants import (
23 VOCAB,
24 VOCAB_LIST,
25 VOCAB_TOKEN_TO_INDEX,
26 Coord,
27 CoordTup,
28)
29from maze_dataset.maze.lattice_maze import LatticeMaze
30from maze_dataset.token_utils import (
31 TokenizerPendingDeprecationWarning,
32 strings_to_coords,
33)
34from maze_dataset.tokenization.common import TokenError
35from maze_dataset.tokenization.maze_tokenizer_legacy import (
36 MazeTokenizer,
37 TokenizationMode,
38)
39from maze_dataset.tokenization.modular.element_base import (
40 _load_tokenizer_element,
41 _TokenizerElement,
42)
43from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers
44from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst
45from maze_dataset.tokenization.modular.hashing import (
46 _hash_tokenizer_name,
47)
50@serializable_dataclass(
51 frozen=True,
52 kw_only=True,
53 properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
54)
55class MazeTokenizerModular(SerializableDataclass):
56 """Tokenizer for mazes
58 # Parameters
59 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.
61 # Development
62 - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`.
63 - Furthermore, the mapping reflected in `from_legacy` must also be maintained.
64 - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
65 """
67 prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
68 default=PromptSequencers.AOTP(),
69 loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
70 )
72 def hash_int(self) -> int:
73 "return integer hash using blake2b"
74 return _hash_tokenizer_name(self.name)
76 def __hash__(self) -> int:
77 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
78 return self.hash_int()
80 def hash_b64(self, n_bytes: int = 8) -> str:
81 """filename-safe base64 encoding of the hash"""
82 # Use modulus to ensure the integer fits within n_bytes * 8 bits
83 hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))
85 encoded = base64.b64encode(
86 hash_mod.to_bytes(n_bytes, byteorder="big"),
87 altchars=b"-_",
88 ).decode()
90 # Remove any padding equals signs
91 return encoded.rstrip("=")
93 # Information Querying Methods
95 @cached_property
96 def tokenizer_elements(self) -> list[_TokenizerElement]:
97 "returns a list of all the elements of this tokenizer"
98 return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
100 def tokenizer_element_tree(self, abstract: bool = False) -> str:
101 """Returns a string representation of the tree of tokenizer elements contained in `self`.
103 # Parameters
104 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
105 """
106 return "\n".join(
107 [
108 type(self).__name__,
109 self.prompt_sequencer.tokenizer_element_tree(
110 abstract=abstract,
111 depth=1,
112 ),
113 ],
114 )
116 @property
117 def tokenizer_element_tree_concrete(self) -> str:
118 """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
119 return self.tokenizer_element_tree()
121 def tokenizer_element_dict(self) -> dict:
122 """Nested dictionary of the internal `TokenizerElement`s."""
123 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
125 @property
126 def name(self) -> str:
127 """Serializes MazeTokenizer into a key for encoding in zanj"""
128 return "-".join([type(self).__name__, self.prompt_sequencer.name]) # noqa: FLY002
130 def summary(self) -> dict[str, str]:
131 """Single-level dictionary of the internal `TokenizerElement`s."""
132 return {
133 # "prompt_sequencer": self.prompt_sequencer.name,
134 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
135 }
137 @staticmethod
138 def _type_check(obj: any) -> None:
139 """Helper method for `has_element`"""
140 if not (
141 isinstance(obj, _TokenizerElement)
142 or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
143 ):
144 err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
145 raise TypeError(err_msg)
147 def _has_element_singular(
148 self,
149 el: type[_TokenizerElement] | _TokenizerElement,
150 ) -> bool:
151 """Helper method for `has_element`"""
152 self._type_check(el)
153 if isinstance(el, type):
154 return any(isinstance(e, el) for e in self.tokenizer_elements)
155 else:
156 return el in self.tokenizer_elements
158 def has_element(
159 self,
160 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
161 ) -> bool:
162 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.
164 Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
165 To do such a query, assemble multiple calls to `has_elements`.
167 # Parameters
168 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
169 If an instance is provided, then comparison is done via instance equality.
170 If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted.
171 """
172 if len(elements) == 1 and isinstance(elements[0], Iterable):
173 elements = elements[0]
174 return all(self._has_element_singular(e) for e in elements)
176 def is_valid(self, do_except: bool = False) -> bool:
177 """Returns `True` if `self` is a valid tokenizer.
179 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
180 """
181 return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
183 def is_legacy_equivalent(self) -> bool:
184 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
185 return any(
186 self == MazeTokenizerModular.from_legacy(tok_mode)
187 for tok_mode in TokenizationMode
188 )
190 def is_tested_tokenizer(self, do_except: bool = False) -> bool:
191 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.
193 uses an fst on the `name` attributes of all the tokenizers
195 if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested.
196 """
197 is_valid: bool = self.is_valid(do_except=do_except)
198 in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)
200 if do_except:
201 assert is_valid, "self.is_valid returns False"
202 return True
203 else:
204 return in_tested_fst and is_valid
206 def is_AOTP(self) -> bool:
207 "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
208 return self.has_element(PromptSequencers.AOTP)
210 def is_UT(self) -> bool:
211 "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
212 return self.has_element(CoordTokenizers.UT)
214 # Alternate Constructors
215 # ======================
217 @classmethod
218 def from_legacy(
219 cls,
220 legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
221 ) -> "MazeTokenizerModular":
222 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
223 if isinstance(legacy_maze_tokenizer, MazeTokenizer):
224 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
225 return {
226 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
227 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
228 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
229 prompt_sequencer=PromptSequencers.AOTP(
230 coord_tokenizer=CoordTokenizers.CTT(),
231 ),
232 ),
233 }[legacy_maze_tokenizer]
235 # Simple properties
236 # =================
237 @classmethod
238 def from_tokens(
239 cls,
240 tokens: str | list[str],
241 ) -> "MazeTokenizerModular":
242 """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
243 raise NotImplementedError(
244 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
245 )
247 @property
248 def token_arr(self) -> list[str] | None:
249 """map from index to token"""
250 return VOCAB_LIST
252 @property
253 def tokenizer_map(self) -> dict[str, int]:
254 """map from token to index"""
255 return VOCAB_TOKEN_TO_INDEX
257 @property
258 def vocab_size(self) -> int:
259 """Number of tokens in the static vocab"""
260 return len(VOCAB_LIST)
262 @property
263 def n_tokens(self) -> int:
264 "get the number of tokens in the vocabulary (deprecated)"
265 err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
266 raise NameError(err_msg)
268 @property
269 def padding_token_index(self) -> int:
270 "get the index of the padding token"
271 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
273 # conversion functions
274 # ============================================================
276 def to_tokens(
277 self,
278 maze: LatticeMaze,
279 ) -> list[str]:
280 """Converts maze into a list of tokens."""
281 return self.prompt_sequencer.to_tokens(maze)
283 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
284 "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
285 return list(
286 flatten(
287 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
288 ),
289 )
291 # TODO: unclear why we need to use `noqa: N805` here since its a classmethod
292 # maybe we need to hit every overload with `@classmethod`?
293 @overload
294 def strings_to_coords(
295 cls, # noqa: N805
296 text: str | list[str],
297 when_noncoord: Literal["skip"] = "skip",
298 ) -> list[CoordTup]: ...
299 @overload
300 def strings_to_coords(
301 cls, # noqa: N805
302 text: str | list[str],
303 when_noncoord: Literal["error"] = "error",
304 ) -> list[CoordTup]: ...
305 @overload
306 def strings_to_coords(
307 cls, # noqa: N805
308 text: str | list[str],
309 when_noncoord: Literal["include"] = "include",
310 ) -> list[str | CoordTup]: ...
311 @classmethod
312 def strings_to_coords(
313 cls,
314 text: str | list[str],
315 when_noncoord: WhenMissing = "skip",
316 ) -> list[str | CoordTup]:
317 "wrapper for maze_dataset.token_utils.strings_to_coords"
318 warnings.warn(
319 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
320 TokenizerPendingDeprecationWarning,
321 )
322 return strings_to_coords(text=text, when_noncoord=when_noncoord)
324 @staticmethod
325 def encode(text: str | list[str]) -> list[int]:
326 """encode a string or list of strings into a list of tokens"""
327 try:
328 if isinstance(text, str):
329 text = text.split()
330 return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
331 except KeyError as e:
332 err_msg: str = f"Token {e} not found in `VOCAB`."
333 raise TokenError(err_msg) from e
335 @staticmethod
336 def decode(
337 token_ids: Sequence[int],
338 joined_tokens: bool = False,
339 ) -> list[str] | str:
340 """decode a list of tokens into a string or list of strings"""
341 try:
342 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
343 except IndexError as e:
344 err_msg: str = f"Token index '{e}' not found in `VOCAB`."
345 raise TokenError(err_msg) from e
346 if joined_tokens:
347 return " ".join(output)
348 else:
349 return output