maze_dataset.tokenization.modular.maze_tokenizer_modular
implements the actual `MazeTokenizerModular` class
1"implements the actual `MazeTokenizerModular` class" 2 3import base64 4import warnings 5from functools import cached_property 6from typing import ( 7 Iterable, 8 Literal, 9 Sequence, 10 overload, 11) 12 13from muutils.json_serialize import ( 14 SerializableDataclass, 15 serializable_dataclass, 16 serializable_field, 17) 18from muutils.misc import flatten 19from muutils.misc.sequence import WhenMissing 20 21# from maze_dataset import SolvedMaze 22from maze_dataset.constants import ( 23 VOCAB, 24 VOCAB_LIST, 25 VOCAB_TOKEN_TO_INDEX, 26 Coord, 27 CoordTup, 28) 29from maze_dataset.maze.lattice_maze import LatticeMaze 30from maze_dataset.token_utils import ( 31 TokenizerPendingDeprecationWarning, 32 strings_to_coords, 33) 34from maze_dataset.tokenization.common import TokenError 35from maze_dataset.tokenization.maze_tokenizer_legacy import ( 36 MazeTokenizer, 37 TokenizationMode, 38) 39from maze_dataset.tokenization.modular.element_base import ( 40 _load_tokenizer_element, 41 _TokenizerElement, 42) 43from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers 44from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst 45from maze_dataset.tokenization.modular.hashing import ( 46 _hash_tokenizer_name, 47) 48 49 50@serializable_dataclass( 51 frozen=True, 52 kw_only=True, 53 properties_to_serialize=["tokenizer_element_tree_concrete", "name"], 54) 55class MazeTokenizerModular(SerializableDataclass): 56 """Tokenizer for mazes 57 58 # Parameters 59 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt. 60 61 # Development 62 - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`. 63 - Furthermore, the mapping reflected in `from_legacy` must also be maintained. 64 - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior. 65 """ 66 67 prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field( 68 default=PromptSequencers.AOTP(), 69 loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers), 70 ) 71 72 def hash_int(self) -> int: 73 "return integer hash using blake2b" 74 return _hash_tokenizer_name(self.name) 75 76 def __hash__(self) -> int: 77 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" 78 return self.hash_int() 79 80 def hash_b64(self, n_bytes: int = 8) -> str: 81 """filename-safe base64 encoding of the hash""" 82 # Use modulus to ensure the integer fits within n_bytes * 8 bits 83 hash_mod: int = self.hash_int() % (1 << (n_bytes * 8)) 84 85 encoded = base64.b64encode( 86 hash_mod.to_bytes(n_bytes, byteorder="big"), 87 altchars=b"-_", 88 ).decode() 89 90 # Remove any padding equals signs 91 return encoded.rstrip("=") 92 93 # Information Querying Methods 94 95 @cached_property 96 def tokenizer_elements(self) -> list[_TokenizerElement]: 97 "returns a list of all the elements of this tokenizer" 98 return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()] 99 100 def tokenizer_element_tree(self, abstract: bool = False) -> str: 101 """Returns a string representation of the tree of tokenizer elements contained in `self`. 102 103 # Parameters 104 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 
105 """ 106 return "\n".join( 107 [ 108 type(self).__name__, 109 self.prompt_sequencer.tokenizer_element_tree( 110 abstract=abstract, 111 depth=1, 112 ), 113 ], 114 ) 115 116 @property 117 def tokenizer_element_tree_concrete(self) -> str: 118 """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.""" 119 return self.tokenizer_element_tree() 120 121 def tokenizer_element_dict(self) -> dict: 122 """Nested dictionary of the internal `TokenizerElement`s.""" 123 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()} 124 125 @property 126 def name(self) -> str: 127 """Serializes MazeTokenizer into a key for encoding in zanj""" 128 return "-".join([type(self).__name__, self.prompt_sequencer.name]) # noqa: FLY002 129 130 def summary(self) -> dict[str, str]: 131 """Single-level dictionary of the internal `TokenizerElement`s.""" 132 return { 133 # "prompt_sequencer": self.prompt_sequencer.name, 134 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements}, 135 } 136 137 @staticmethod 138 def _type_check(obj: any) -> None: 139 """Helper method for `has_element`""" 140 if not ( 141 isinstance(obj, _TokenizerElement) 142 or (isinstance(obj, type) and issubclass(obj, _TokenizerElement)) 143 ): 144 err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass." 145 raise TypeError(err_msg) 146 147 def _has_element_singular( 148 self, 149 el: type[_TokenizerElement] | _TokenizerElement, 150 ) -> bool: 151 """Helper method for `has_element`""" 152 self._type_check(el) 153 if isinstance(el, type): 154 return any(isinstance(e, el) for e in self.tokenizer_elements) 155 else: 156 return el in self.tokenizer_elements 157 158 def has_element( 159 self, 160 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement], 161 ) -> bool: 162 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`. 163 164 Querying with a partial subset of `_TokenizerElement` fields is not currently supported. 165 To do such a query, assemble multiple calls to `has_elements`. 166 167 # Parameters 168 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. 169 If an instance is provided, then comparison is done via instance equality. 170 If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted. 171 """ 172 if len(elements) == 1 and isinstance(elements[0], Iterable): 173 elements = elements[0] 174 return all(self._has_element_singular(e) for e in elements) 175 176 def is_valid(self, do_except: bool = False) -> bool: 177 """Returns `True` if `self` is a valid tokenizer. 178 179 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method. 180 """ 181 return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements) 182 183 def is_legacy_equivalent(self) -> bool: 184 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.""" 185 return any( 186 self == MazeTokenizerModular.from_legacy(tok_mode) 187 for tok_mode in TokenizationMode 188 ) 189 190 def is_tested_tokenizer(self, do_except: bool = False) -> bool: 191 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers. 192 193 uses an fst on the `name` attributes of all the tokenizers 194 195 if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested. 
196 """ 197 is_valid: bool = self.is_valid(do_except=do_except) 198 in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except) 199 200 if do_except: 201 assert is_valid, "self.is_valid returns False" 202 return True 203 else: 204 return in_tested_fst and is_valid 205 206 def is_AOTP(self) -> bool: 207 "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path" 208 return self.has_element(PromptSequencers.AOTP) 209 210 def is_UT(self) -> bool: 211 "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)" 212 return self.has_element(CoordTokenizers.UT) 213 214 # Alternate Constructors 215 # ====================== 216 217 @classmethod 218 def from_legacy( 219 cls, 220 legacy_maze_tokenizer: MazeTokenizer | TokenizationMode, 221 ) -> "MazeTokenizerModular": 222 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" 223 if isinstance(legacy_maze_tokenizer, MazeTokenizer): 224 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode 225 return { 226 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), 227 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), 228 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( 229 prompt_sequencer=PromptSequencers.AOTP( 230 coord_tokenizer=CoordTokenizers.CTT(), 231 ), 232 ), 233 }[legacy_maze_tokenizer] 234 235 # Simple properties 236 # ================= 237 @classmethod 238 def from_tokens( 239 cls, 240 tokens: str | list[str], 241 ) -> "MazeTokenizerModular": 242 """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.""" 243 raise NotImplementedError( 244 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported", 245 ) 246 247 @property 248 def token_arr(self) -> list[str] | None: 249 """map from index to token""" 250 return VOCAB_LIST 251 252 @property 253 def tokenizer_map(self) -> dict[str, int]: 254 """map from token to index""" 255 return VOCAB_TOKEN_TO_INDEX 256 257 @property 258 def vocab_size(self) -> int: 259 """Number of tokens in the static vocab""" 260 return len(VOCAB_LIST) 261 262 @property 263 def n_tokens(self) -> int: 264 "get the number of tokens in the vocabulary (deprecated)" 265 err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 266 raise NameError(err_msg) 267 268 @property 269 def padding_token_index(self) -> int: 270 "get the index of the padding token" 271 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING] 272 273 # conversion functions 274 # ============================================================ 275 276 def to_tokens( 277 self, 278 maze: LatticeMaze, 279 ) -> list[str]: 280 """Converts maze into a list of tokens.""" 281 return self.prompt_sequencer.to_tokens(maze) 282 283 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: 284 "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords" 285 return list( 286 flatten( 287 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords], 288 ), 289 ) 290 291 # TODO: unclear why we need to use `noqa: N805` here since its a classmethod 292 # maybe we need to hit every overload with `@classmethod`? 293 @overload 294 def strings_to_coords( 295 cls, # noqa: N805 296 text: str | list[str], 297 when_noncoord: Literal["skip"] = "skip", 298 ) -> list[CoordTup]: ... 
299 @overload 300 def strings_to_coords( 301 cls, # noqa: N805 302 text: str | list[str], 303 when_noncoord: Literal["error"] = "error", 304 ) -> list[CoordTup]: ... 305 @overload 306 def strings_to_coords( 307 cls, # noqa: N805 308 text: str | list[str], 309 when_noncoord: Literal["include"] = "include", 310 ) -> list[str | CoordTup]: ... 311 @classmethod 312 def strings_to_coords( 313 cls, 314 text: str | list[str], 315 when_noncoord: WhenMissing = "skip", 316 ) -> list[str | CoordTup]: 317 "wrapper for maze_dataset.token_utils.strings_to_coords" 318 warnings.warn( 319 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.", 320 TokenizerPendingDeprecationWarning, 321 ) 322 return strings_to_coords(text=text, when_noncoord=when_noncoord) 323 324 @staticmethod 325 def encode(text: str | list[str]) -> list[int]: 326 """encode a string or list of strings into a list of tokens""" 327 try: 328 if isinstance(text, str): 329 text = text.split() 330 return [VOCAB_TOKEN_TO_INDEX[token] for token in text] 331 except KeyError as e: 332 err_msg: str = f"Token {e} not found in `VOCAB`." 333 raise TokenError(err_msg) from e 334 335 @staticmethod 336 def decode( 337 token_ids: Sequence[int], 338 joined_tokens: bool = False, 339 ) -> list[str] | str: 340 """decode a list of tokens into a string or list of strings""" 341 try: 342 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids] 343 except IndexError as e: 344 err_msg: str = f"Token index '{e}' not found in `VOCAB`." 345 raise TokenError(err_msg) from e 346 if joined_tokens: 347 return " ".join(output) 348 else: 349 return output
`class MazeTokenizerModular(SerializableDataclass)`

Tokenizer for mazes

# Parameters
- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

# Development
- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
`hash_int(self) -> int`

return integer hash using blake2b
`hash_b64(self, n_bytes: int = 8) -> str`

filename-safe base64 encoding of the hash
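A minimal usage sketch of the hashing helpers (the exact values depend on the tokenizer's `name`, so none are asserted here):

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()  # default tokenizer, equivalent to legacy AOTP_UT_uniform

# all three identifiers are derived from `tok.name`
print(tok.name)        # canonical string key
print(tok.hash_int())  # blake2b-based integer hash of the name
print(tok.hash_b64())  # same hash, truncated to 8 bytes and base64-encoded with "-_" altchars
```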
`tokenizer_elements: list[_TokenizerElement]` (cached property)

returns a list of all the elements of this tokenizer
`tokenizer_element_tree(self, abstract: bool = False) -> str`

Returns a string representation of the tree of tokenizer elements contained in `self`.

# Parameters
- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
`tokenizer_element_tree_concrete: str` (property)

Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.
`tokenizer_element_dict(self) -> dict`

Nested dictionary of the internal `TokenizerElement`s.
`name: str` (property)

Serializes `MazeTokenizer` into a key for encoding in `zanj`
`summary(self) -> dict[str, str]`

Single-level dictionary of the internal `TokenizerElement`s.
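The introspection helpers above can be combined as in this sketch (output shapes follow from the source; the exact element names depend on the installed version):

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()

print(tok.tokenizer_element_tree())  # indented tree rooted at "MazeTokenizerModular"
print(tok.tokenizer_element_dict())  # the same structure as a nested dict
print(tok.summary())                 # flat dict: one entry per tokenizer element
```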
`has_element(self, *elements: Sequence[type[_TokenizerElement] | _TokenizerElement]) -> bool`

Returns `True` if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

Querying with a partial subset of `_TokenizerElement` fields is not currently supported. To do such a query, assemble multiple calls to `has_element`.

# Parameters
- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via `isinstance`; i.e., any instance of that class is accepted.
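A sketch of the two query modes, assuming the default tokenizer (whose prompt sequencer is `PromptSequencers.AOTP()` with a unique-token coordinate tokenizer):

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular
from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers

tok = MazeTokenizerModular()

assert tok.has_element(CoordTokenizers.UT)       # class: matched via isinstance
assert tok.has_element(PromptSequencers.AOTP())  # instance: matched via equality
assert tok.has_element(CoordTokenizers.UT, PromptSequencers.AOTP)  # ALL must be present
```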
`is_valid(self, do_except: bool = False) -> bool`

Returns `True` if `self` is a valid tokenizer.

Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
`is_legacy_equivalent(self) -> bool`

Returns whether `self` has identical stringification behavior as any legacy `MazeTokenizer`.
`is_tested_tokenizer(self, do_except: bool = False) -> bool`

Returns whether the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

uses an fst on the `name` attributes of all the tokenizers

if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
`is_AOTP(self) -> bool`

is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path
`is_UT(self) -> bool`

is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)
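For the default tokenizer these predicates all hold, since the class docstring guarantees equivalence to the legacy `AOTP_UT_uniform` mode; a quick sketch:

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()

assert tok.is_AOTP()               # prompt sequencer is PromptSequencers.AOTP
assert tok.is_UT()                 # coordinate tokenizer is CoordTokenizers.UT
assert tok.is_legacy_equivalent()  # matches an entry in the from_legacy mapping
assert tok.is_valid()              # every tokenizer element passes its own check
```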
`from_legacy(cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode) -> MazeTokenizerModular` (classmethod)

Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.
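The mapping can be exercised directly; this follows the dictionary in the source above:

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import TokenizationMode
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

# both legacy UT modes map to the default modular tokenizer
assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform) == MazeTokenizerModular()
assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_rasterized) == MazeTokenizerModular()

# the CTT mode swaps in a coordinate-tuple coordinate tokenizer
tok_ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
assert not tok_ctt.is_UT()
```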
`from_tokens(cls, tokens: str | list[str]) -> MazeTokenizerModular` (classmethod)

Infers most `MazeTokenizerModular` parameters from a full sequence of tokens. Not currently supported; always raises `NotImplementedError`.
`token_arr: list[str] | None` (property)

map from index to token
`tokenizer_map: dict[str, int]` (property)

map from token to index
`vocab_size: int` (property)

Number of tokens in the static vocab
`n_tokens: int` (property)

get the number of tokens in the vocabulary (removed; accessing it raises a `NameError` pointing to `len(maze_dataset.VOCAB_LIST)`)
`padding_token_index: int` (property)

get the index of the padding token
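Since the vocabulary is static, these properties are mutually consistent; a sketch:

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()

assert tok.vocab_size == len(tok.token_arr)
pad_idx: int = tok.padding_token_index
assert tok.tokenizer_map[tok.token_arr[pad_idx]] == pad_idx  # token_arr and tokenizer_map are inverses
```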
`to_tokens(self, maze: LatticeMaze) -> list[str]`

Converts maze into a list of tokens.
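A tokenization sketch; the dataset-construction calls (`MazeDatasetConfig`, `MazeDataset.from_config`, `LatticeMazeGenerators.gen_dfs`) are assumptions about the wider `maze_dataset` package, not defined in this module:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig, LatticeMazeGenerators  # assumed top-level exports
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

# hypothetical: build one small solved maze to tokenize
cfg = MazeDatasetConfig(name="demo", grid_n=3, n_mazes=1, maze_ctor=LatticeMazeGenerators.gen_dfs)
maze = MazeDataset.from_config(cfg)[0]  # a solved maze, which is a LatticeMaze

tok = MazeTokenizerModular()
tokens: list[str] = tok.to_tokens(maze)  # adjacency list, origin, target, path regions
print(" ".join(tokens))
```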
`coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]`

calls `self.prompt_sequencer.coord_tokenizer.to_tokens(c)` for each `c` in `coords`
`strings_to_coords(cls, text: str | list[str], when_noncoord: WhenMissing = "skip") -> list[str | CoordTup]` (classmethod)

wrapper for `maze_dataset.token_utils.strings_to_coords`
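A coordinate round-trip sketch for the default unique-token coordinate tokenizer (round-tripping relies on UT-style coordinate tokens, per the deprecation warning in the source):

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()

coord_tokens: list[str] = tok.coords_to_strings([(1, 2), (0, 0)])

# non-coordinate tokens are skipped by default; this call emits a
# TokenizerPendingDeprecationWarning since only legacy UT strings are supported
assert MazeTokenizerModular.strings_to_coords(coord_tokens) == [(1, 2), (0, 0)]
```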
`encode(text: str | list[str]) -> list[int]` (staticmethod)

encode a string or list of strings into a list of token ids
`decode(token_ids: Sequence[int], joined_tokens: bool = False) -> list[str] | str` (staticmethod)

decode a list of token ids into a list of tokens, or into a single space-joined string if `joined_tokens` is `True`
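An encode/decode round trip over the fixed global vocabulary, using only names defined in this module's imports:

```python
from maze_dataset.constants import VOCAB
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

ids: list[int] = MazeTokenizerModular.encode(VOCAB.PADDING)
assert ids == [MazeTokenizerModular().padding_token_index]
assert MazeTokenizerModular.decode(ids, joined_tokens=True) == VOCAB.PADDING
```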
`serialize(self) -> dict[str, Any]` (inherited from `muutils.json_serialize.serializable_dataclass.SerializableDataclass`)

returns the class as a dict, implemented by using the `@serializable_dataclass` decorator

```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
`load(cls, data: dict[str, Any] | T) -> Type[T]` (classmethod, inherited from `muutils.json_serialize.serializable_dataclass.SerializableDataclass`)

takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator

```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
`validate_fields_types(self: SerializableDataclass, on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR) -> bool`

validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field

```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
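A serialization round-trip sketch using the inherited methods (key names follow `properties_to_serialize` and the field set shown above):

```python
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

tok = MazeTokenizerModular()

data: dict = tok.serialize()
assert data["name"] == tok.name  # "name" is in properties_to_serialize

tok2 = MazeTokenizerModular.load(data)  # prompt_sequencer rebuilt via its loading_fn
assert tok2 == tok
```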
Inherited Members
- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
  - `validate_field_type`
  - `diff`
  - `update_from_nested_dict`