Coverage for maze_dataset/tokenization/maze_tokenizer_legacy.py: 78% (172 statements)
1"""legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class
3> [!CAUTION]
4> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
5> for use, but will remain for compatibility with existing code.
7"""
9import warnings
10from enum import Enum
11from functools import cached_property
12from typing import (
13 Callable,
14 Iterable,
15 Literal,
16 Mapping,
17 Sequence,
18 overload,
19)
21import numpy as np
22from muutils.json_serialize import (
23 SerializableDataclass,
24 serializable_dataclass,
25 serializable_field,
26)
27from muutils.kappa import Kappa
28from muutils.misc.sequence import WhenMissing
30# from maze_dataset import SolvedMaze
31from maze_dataset.constants import (
32 SPECIAL_TOKENS,
33 CoordTup,
34)
35from maze_dataset.token_utils import (
36 TokenizerPendingDeprecationWarning,
37 _coord_to_strings_indexed,
38 _coord_to_strings_UT,
39 coords_to_strings,
40 strings_to_coords,
41)
42from maze_dataset.tokenization.common import TokenError
43from maze_dataset.utils import corner_first_ndindex
class TokenizationMode(Enum):
    """legacy tokenization modes

    > [!CAUTION]
    > Legacy mode of tokenization. Will still be around in future releases, but is no longer recommended for use.
    > Use `MazeTokenizerModular` instead.

    # Abbreviations:
    - `AOTP`: Adjacency list, Origin, Target, Path
    - `UT`: Unique Token (for each coordinate)
    - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

    # Modes:
    - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
        example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
    - `AOTP_UT_uniform`: newer mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
        uses the `corner_first_ndindex` function to order the tokens
    - `AOTP_CTT_indexed`: each coordinate is a tuple of integers
    """

    AOTP_UT_rasterized = "AOTP_UT_rasterized"
    AOTP_UT_uniform = "AOTP_UT_uniform"
    AOTP_CTT_indexed = "AOTP_CTT_indexed"

    def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
        "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
        return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
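

# Illustrative usage sketch of the mode enum (kept as a comment so importing this
# module has no side effects); the expected name follows `MazeTokenizer.name` below:
# >>> mode = TokenizationMode.AOTP_UT_uniform
# >>> tok = mode.to_legacy_tokenizer(max_grid_size=3)
# >>> tok.name
# 'maze_tokenizer-AOTP_UT_uniform-g3'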

_NDINDEX_FUNC_MAP: dict[
    TokenizationMode,
    Callable[[int], Iterable[tuple[int, ...]]],
] = {
    TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)),
    TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2),
}
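
# Sketch of the orderings produced above (comment only): the rasterized ordering is
# shown exactly; the uniform ordering comes from `corner_first_ndindex` and is built
# corner-first, so the tokens for a smaller grid are a prefix of those for a larger one.
# >>> _NDINDEX_FUNC_MAP[TokenizationMode.AOTP_UT_rasterized](2)
# [(0, 0), (0, 1), (1, 0), (1, 1)]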

def is_UT(tokenization_mode: TokenizationMode) -> bool:
    "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
    return tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    )

def get_tokens_up_to_path_start(
    tokens: list[str],
    include_start_coord: bool = True,
    tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
    """get tokens up to the path start token

    # Parameters:
    - `tokens : list[str]`
    - `include_start_coord : bool`
        (defaults to `True`)
    - `tokenization_mode : TokenizationMode`
        (defaults to `TokenizationMode.AOTP_UT_uniform`)

    # Returns:
    - `list[str]` subsequence of `tokens` up to the path start token

    # Raises:
    - `ValueError` : if `tokenization_mode` is invalid
    """
    warnings.warn(
        "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
    if include_start_coord:
        if is_UT(tokenization_mode):
            return tokens[: path_start_idx + 1]
        elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return tokens[: path_start_idx + 5]
        else:
            err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
            raise ValueError(err_msg)
    else:
        return tokens[:path_start_idx]
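
# Illustrative sketch of the truncation behavior (comment only, so it is not run on
# import; the coordinate token strings are assumptions about the UT vocabulary format,
# and calling the function also emits a `TokenizerPendingDeprecationWarning`):
# >>> toks = [SPECIAL_TOKENS.PATH_START, "(1,2)", "(1,3)", SPECIAL_TOKENS.PATH_END]
# >>> get_tokens_up_to_path_start(toks, include_start_coord=True)
# [SPECIAL_TOKENS.PATH_START, '(1,2)']  # everything up to and including the start coordinate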

_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [
    "name",
    "max_grid_size",
    "token_arr",
    "tokenizer_map",
    "vocab_size",
    "padding_token_index",
]

@serializable_dataclass(
    properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
    kw_only=True,
)
class MazeTokenizer(SerializableDataclass):
    """LEGACY tokenizer for mazes

    > [!CAUTION]
    > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
    > for use, but will remain for compatibility with existing code.

    # Parameters:
    - `tokenization_mode: TokenizationMode`
        mode of tokenization. required.
    - `max_grid_size: int | None`
        maximum grid size. required for converting between text tokens and numerical token ids, but not for moving between coordinates/mazes and text

    # Properties
    - `name: str`
        auto-generated name of the tokenizer from mode and size

    ## Conditional Properties

    - `node_strings_map: Mapping[CoordTup, str]`
        map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` for `UT` modes.

    these all return `None` if `max_grid_size` is `None`.
    Prepend `_` to the name to get a guaranteed type, which raises an exception if `max_grid_size` is `None`

    - `token_arr: list[str]`
        list of tokens, in order of their indices in the vocabulary
    - `tokenizer_map: Mapping[str, int]`
        map from token to index
    - `vocab_size: int`
        size of the vocabulary
    - `padding_token_index: int`
        index of the padding token

    # Methods
    - `coords_to_strings(coords: list[CoordTup]) -> list[str]`
        convert a list of coordinates to a list of tokens. Non-coordinate entries can raise an error, be skipped, or be included, depending on `when_noncoord`
    - `strings_to_coords(strings: list[str]) -> list[CoordTup]`
        convert a list of tokens to a list of coordinates. Non-coordinate tokens can raise an error, be skipped, or be included, depending on `when_noncoord`

    """
    # parameters
    # ============================================================

    tokenization_mode: TokenizationMode = serializable_field(
        default=TokenizationMode.AOTP_UT_uniform,
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
    )

    max_grid_size: int | None = serializable_field(default=None)

    # properties
    # ============================================================

    @property
    def name(self) -> str:
        """auto-generated name of the tokenizer from mode and size"""
        max_grid_size_str: str = (
            f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
        )
        return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

    @cached_property
    def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return Kappa(_coord_to_strings_UT)
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return Kappa(_coord_to_strings_indexed)
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)
    @cached_property
    def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
        """map a coordinate to a token; returns `None` for UT modes"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return None
        else:
            return self._node_strings_map

    # conditional properties (on max_grid_size existing)
    # ------------------------------------------------------------
    @cached_property
    def _token_arr(self) -> list[str]:
        """map from index to token"""
        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
            raise ValueError(err_msg)

        output: list[str] = list(SPECIAL_TOKENS.values())

        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            output.extend(
                [
                    self._node_strings_map[coord][0]
                    for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
                        self.max_grid_size,
                    )
                ],
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
            output.extend(
                [
                    "(",
                    ",",
                    ")",  # new special chars
                    *map(str, range(self.max_grid_size)),  # numbers
                ],
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

        return output
    @cached_property
    def token_arr(self) -> list[str] | None:
        "get the token array if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._token_arr

    @cached_property
    def _tokenizer_map(self) -> dict[str, int]:
        """map from token to index"""
        return {token: i for i, token in enumerate(self._token_arr)}

    @cached_property
    def tokenizer_map(self) -> dict[str, int] | None:
        "get the tokenizer map if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._tokenizer_map

    @property
    def _vocab_size(self) -> int:
        return len(self._token_arr)

    @property
    def vocab_size(self) -> int | None:
        "get the size of the vocabulary if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._vocab_size

    @property
    def _n_tokens(self) -> int:
        # TODO: deprecate
        return self._vocab_size

    @property
    def n_tokens(self) -> int | None:
        "get the number of tokens if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._n_tokens

    @cached_property
    def _padding_token_index(self) -> int:
        return self.tokenizer_map[SPECIAL_TOKENS.PADDING]

    @cached_property
    def padding_token_index(self) -> int | None:
        "get the index of the padding token if it exists"
        if self.max_grid_size is None:
            return None
        return self._padding_token_index
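
    # Sketch of the conditional-property pattern above (comment only): without a
    # `max_grid_size`, the public properties return `None`, while the underscore
    # variants raise a `ValueError`.
    # >>> MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform).vocab_size is None
    # True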

    # conversion functions
    # ============================================================

    @overload
    def coords_to_strings(
        self,
        coords: list[str | CoordTup],
        when_noncoord: Literal["include", "skip"] = "skip",
    ) -> list[str]: ...
    @overload
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: Literal["error"] = "error",
    ) -> list[str]: ...
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str]:
        """map a list of coordinate tuples (and maybe other tokens) to strings

        wraps `maze_dataset.token_utils.coords_to_strings` with either
        `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
        """
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_UT,
                when_noncoord=when_noncoord,
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_indexed,
                when_noncoord=when_noncoord,
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        "wrapper for `maze_dataset.token_utils.strings_to_coords`"
        return strings_to_coords(text=text, when_noncoord=when_noncoord)
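
    # Illustrative round trip between coordinates and token strings (comment only;
    # the exact UT token string format shown is an assumption, not something this
    # module guarantees):
    # >>> tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
    # >>> tok.coords_to_strings([(1, 2)])
    # ['(1,2)']
    # >>> MazeTokenizer.strings_to_coords("<PATH_START> (1,2)", when_noncoord="skip")
    # [(1, 2)]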

    def encode(self, text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of token ids"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [self.tokenizer_map[token] for token in text]
        except KeyError as e:
            err_msg: str = (
                f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
            )
            raise TokenError(err_msg) from e

    def decode(
        self,
        tokens: Sequence[int],
        joined_tokens: bool = False,
    ) -> list[str] | str:
        """decode a list of token ids into a string or list of strings"""
        try:
            output: list[str] = [self.token_arr[token] for token in tokens]
        except IndexError as e:
            err_msg: str = (
                f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
            )
            raise TokenError(err_msg) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output
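
    # Encode/decode round-trip sketch (comment only; assumes `max_grid_size` is set so
    # the vocabulary exists, and that the literal token strings shown match the special
    # tokens and the UT coordinate token format):
    # >>> tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=3)
    # >>> ids = tok.encode("<PATH_START> (0,0)")
    # >>> tok.decode(ids, joined_tokens=True)
    # '<PATH_START> (0,0)'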

    # UT-only coordinate stuff
    # ============================================================

    @cached_property
    def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
        "map of coordinate tuples to their token ids, only valid for UT"
        # print(f"{self.tokenization_mode = }")
        if not self.is_UT():
            err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
            raise ValueError(err_msg)

        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
            raise ValueError(err_msg)

        raw_converted: list[CoordTup | str] = self.strings_to_coords(
            self.token_arr,
            when_noncoord="include",
        )

        # filter out non-coordinates
        return {
            coord: i
            for i, coord in enumerate(raw_converted)
            if not isinstance(coord, str)
        }

    @cached_property
    def coordinate_tokens_ids(self) -> dict[str, int]:
        "map of coordinate tokens to their token ids, only valid for UT"
        # checks performed in call
        output: dict[str, int] = dict()

        for coord, index in self.coordinate_tokens_coords.items():
            _for_key: list[str] = self.coords_to_strings([coord])
            assert len(_for_key) == 1
            output[_for_key[0]] = index

        return output
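
    # Sketch (comment only): for a UT tokenizer with a known grid size, every grid
    # coordinate gets exactly one token id, so a 2x2 grid yields 4 coordinate tokens.
    # >>> tok = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=2)
    # >>> len(tok.coordinate_tokens_coords) == len(tok.coordinate_tokens_ids) == 4
    # True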

    # other
    # ============================================================

    def summary(self) -> dict:
        """returns a summary of the tokenization mode"""
        return {
            "tokenization_mode": self.tokenization_mode.value,
            "max_grid_size": self.max_grid_size,
            "vocab_size": self.vocab_size,
        }

    def is_AOTP(self) -> bool:
        """returns true if the tokenization mode is an AOTP mode: Adjacency list, Origin, Target, Path"""
        return self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
            TokenizationMode.AOTP_CTT_indexed,
        )

    def is_UT(self) -> bool:
        "returns true if the tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
        return is_UT(self.tokenization_mode)

    def clear_cache(self) -> None:
        """clears all cached properties"""
        # delete the properties only if they exist
        for name, prop in self.__class__.__dict__.items():
            if isinstance(prop, cached_property):
                # if the property exists, delete it
                try:  # noqa: SIM105
                    delattr(self, name)
                except AttributeError:
                    pass