maze_dataset.tokenization.maze_tokenizer
preserving legacy imports
1"""preserving legacy imports""" 2 3from maze_dataset.tokenization.maze_tokenizer_legacy import ( 4 MazeTokenizer, 5 TokenizationMode, 6) 7from maze_dataset.tokenization.modular.maze_tokenizer_modular import ( 8 MazeTokenizerModular, 9) 10 11__all__ = [ 12 "MazeTokenizer", 13 "TokenizationMode", 14 "MazeTokenizerModular", 15]
140@serializable_dataclass( 141 properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, 142 kw_only=True, 143) 144class MazeTokenizer(SerializableDataclass): 145 """LEGACY Tokenizer for mazes 146 147 > [!CAUTION] 148 > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended 149 > for use, but will remain for compatibility with existing code. 150 151 # Parameters: 152 - `tokenization_mode: TokenizationMode` 153 mode of tokenization. required. 154 - `max_grid_size: int | None` 155 maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text 156 157 # Properties 158 - `name: str` 159 auto-generated name of the tokenizer from mode and size 160 161 ## Conditional Properties 162 163 - `node_strings_map: Mapping[CoordTup, str]` 164 map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode 165 166 these all return `None` if `max_grid_size` is `None`. 167 Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None` 168 169 - `token_arr: list[str]` 170 list of tokens, in order of their indices in the vocabulary 171 - `tokenizer_map: Mapping[str, int]` 172 map from token to index 173 - `vocab_size: int` 174 size of the vocabulary 175 - `padding_token_index: int` 176 index of the padding token 177 178 # Methods 179 - `coords_to_strings(coords: list[CoordTup]) -> list[str]` 180 convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates 181 - `strings_to_coords(strings: list[str]) -> list[CoordTup]` 182 convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates 183 184 """ 185 186 # parameters 187 # ============================================================ 188 189 tokenization_mode: TokenizationMode = serializable_field( 190 default=TokenizationMode.AOTP_UT_uniform, 191 serialization_fn=lambda x: x.value, 192 loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]], 193 ) 194 195 max_grid_size: int | None = serializable_field(default=None) 196 197 # properties 198 # ============================================================ 199 200 @property 201 def name(self) -> str: 202 """auto-generated name of the tokenizer from mode and size""" 203 max_grid_size_str: str = ( 204 f"-g{self.max_grid_size}" if self.max_grid_size is not None else "" 205 ) 206 return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}" 207 208 @cached_property 209 def _node_strings_map(self) -> Mapping[CoordTup, list[str]]: 210 """map a coordinate to a token""" 211 if self.tokenization_mode in ( 212 TokenizationMode.AOTP_UT_rasterized, 213 TokenizationMode.AOTP_UT_uniform, 214 ): 215 return Kappa(_coord_to_strings_UT) 216 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 217 return Kappa(_coord_to_strings_indexed) 218 else: 219 err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" 220 raise ValueError(err_msg) 221 222 @cached_property 223 def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None: 224 """map a coordinate to a token""" 225 if self.tokenization_mode in ( 226 TokenizationMode.AOTP_UT_rasterized, 227 TokenizationMode.AOTP_UT_uniform, 228 ): 229 return None 230 else: 231 return self._node_strings_map 232 233 # conditional properties (on max_grid_size existing) 234 # 
------------------------------------------------------------ 235 236 @cached_property 237 def _token_arr(self) -> list[str]: 238 """map from index to token""" 239 if self.max_grid_size is None: 240 err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }" 241 raise ValueError(err_msg) 242 243 output: list[str] = list(SPECIAL_TOKENS.values()) 244 245 if self.tokenization_mode in ( 246 TokenizationMode.AOTP_UT_rasterized, 247 TokenizationMode.AOTP_UT_uniform, 248 ): 249 output.extend( 250 [ 251 self._node_strings_map[coord][0] 252 for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode]( 253 self.max_grid_size, 254 ) 255 ], 256 ) 257 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 258 # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models 259 output.extend( 260 [ 261 "(", 262 ",", 263 ")", # new special chars 264 *map(str, range(self.max_grid_size)), # numbers 265 ], 266 ) 267 else: 268 err_msg: str = ( 269 f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}", 270 ) 271 raise ValueError(err_msg) 272 273 return output 274 275 @cached_property 276 def token_arr(self) -> list[str] | None: 277 "get the token array if the max_grid_size is specified" 278 if self.max_grid_size is None: 279 return None 280 return self._token_arr 281 282 @cached_property 283 def _tokenizer_map(self) -> dict[str, int]: 284 """map from token to index""" 285 return {token: i for i, token in enumerate(self._token_arr)} 286 287 @cached_property 288 def tokenizer_map(self) -> dict[str, int] | None: 289 "get the tokenizer map if the max_grid_size is specified" 290 if self.max_grid_size is None: 291 return None 292 return self._tokenizer_map 293 294 @property 295 def _vocab_size(self) -> int: 296 return len(self._token_arr) 297 298 @property 299 def vocab_size(self) -> int | None: 300 "get the size of the vocabulary if the max_grid_size is specified" 301 if self.max_grid_size is None: 302 return None 303 return self._vocab_size 304 305 @property 306 def _n_tokens(self) -> int: 307 # TODO: deprecate 308 return self._vocab_size 309 310 @property 311 def n_tokens(self) -> int | None: 312 "get the number of tokens if the max_grid_size is specified" 313 if self.max_grid_size is None: 314 return None 315 return self._n_tokens 316 317 @cached_property 318 def _padding_token_index(self) -> int: 319 return self.tokenizer_map[SPECIAL_TOKENS.PADDING] 320 321 @cached_property 322 def padding_token_index(self) -> int | None: 323 "get the index of the padding token if it exists" 324 if self.max_grid_size is None: 325 return None 326 return self._padding_token_index 327 328 # conversion functions 329 # ============================================================ 330 331 @overload 332 def coords_to_strings( 333 self, 334 coords: list[str | CoordTup], 335 when_noncoord: Literal["include", "skip"] = "skip", 336 ) -> list[str]: ... 337 @overload 338 def coords_to_strings( 339 self, 340 coords: list[CoordTup], 341 when_noncoord: Literal["error"] = "error", 342 ) -> list[str]: ... 
343 def coords_to_strings( 344 self, 345 coords: list[CoordTup], 346 when_noncoord: WhenMissing = "skip", 347 ) -> list[str]: 348 """map a list of coordinate tuples (and maybe other tokens) to strings 349 350 wraps `maze_dataset.token_utils.coords_to_strings` with either 351 `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode 352 """ 353 if self.tokenization_mode in ( 354 TokenizationMode.AOTP_UT_rasterized, 355 TokenizationMode.AOTP_UT_uniform, 356 ): 357 return coords_to_strings( 358 coords=coords, 359 coord_to_strings_func=_coord_to_strings_UT, 360 when_noncoord=when_noncoord, 361 ) 362 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 363 return coords_to_strings( 364 coords=coords, 365 coord_to_strings_func=_coord_to_strings_indexed, 366 when_noncoord=when_noncoord, 367 ) 368 else: 369 err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" 370 raise ValueError(err_msg) 371 372 @overload 373 def strings_to_coords( 374 cls, # noqa: N805 375 text: str | list[str], 376 when_noncoord: Literal["skip"] = "skip", 377 ) -> list[CoordTup]: ... 378 @overload 379 def strings_to_coords( 380 cls, # noqa: N805 381 text: str | list[str], 382 when_noncoord: Literal["error"] = "error", 383 ) -> list[CoordTup]: ... 384 @overload 385 def strings_to_coords( 386 cls, # noqa: N805 387 text: str | list[str], 388 when_noncoord: Literal["include"] = "include", 389 ) -> list[str | CoordTup]: ... 390 @classmethod 391 def strings_to_coords( 392 cls, 393 text: str | list[str], 394 when_noncoord: WhenMissing = "skip", 395 ) -> list[str | CoordTup]: 396 "wrapper for `maze_dataset.token_utils.strings_to_coords`" 397 return strings_to_coords(text=text, when_noncoord=when_noncoord) 398 399 def encode(self, text: str | list[str]) -> list[int]: 400 """encode a string or list of strings into a list of tokens""" 401 try: 402 if isinstance(text, str): 403 text = text.split() 404 return [self.tokenizer_map[token] for token in text] 405 except KeyError as e: 406 err_msg: str = ( 407 f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}" 408 ) 409 raise TokenError(err_msg) from e 410 411 def decode( 412 self, 413 tokens: Sequence[int], 414 joined_tokens: bool = False, 415 ) -> list[str] | str: 416 """decode a list of tokens into a string or list of strings""" 417 try: 418 output: list[str] = [self.token_arr[token] for token in tokens] 419 except IndexError as e: 420 err_msg: str = ( 421 f"Token index '{e}' not found in vocabulary of length {self.vocab_size}" 422 ) 423 raise TokenError(err_msg) from e 424 if joined_tokens: 425 return " ".join(output) 426 else: 427 return output 428 429 # UT-only coordinate stuff 430 # ============================================================ 431 432 @cached_property 433 def coordinate_tokens_coords(self) -> dict[CoordTup, int]: 434 "map of coordiante tuples to their token ids, only valid for UT" 435 # print(f"{self.tokenization_mode = }") 436 if not self.is_UT(): 437 err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }" 438 raise ValueError(err_msg) 439 440 if self.max_grid_size is None: 441 err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }" 442 raise ValueError(err_msg) 443 444 raw_converted: list[CoordTup | str] = self.strings_to_coords( 445 self.token_arr, 446 when_noncoord="include", 447 ) 448 449 # filter out non-coordinates 450 return { 451 coord: i 452 
for i, coord in enumerate(raw_converted) 453 if not isinstance(coord, str) 454 } 455 456 @cached_property 457 def coordinate_tokens_ids(self) -> dict[str, int]: 458 "map of coordinate tokens to their token ids, only valid for UT" 459 # checks performed in call 460 output: dict[str, int] = dict() 461 462 for coord, index in self.coordinate_tokens_coords.items(): 463 _for_key: list[str] = self.coords_to_strings([coord]) 464 assert len(_for_key) == 1 465 output[_for_key[0]] = index 466 467 return output 468 469 # other 470 # ============================================================ 471 472 def summary(self) -> dict: 473 """returns a summary of the tokenization mode""" 474 return { 475 "tokenization_mode": self.tokenization_mode.value, 476 "max_grid_size": self.max_grid_size, 477 "vocab_size": self.vocab_size, 478 } 479 480 def is_AOTP(self) -> bool: 481 """returns true if a tokenization mode is Adjacency list, Origin, Target, Path""" 482 return self.tokenization_mode in ( 483 TokenizationMode.AOTP_UT_rasterized, 484 TokenizationMode.AOTP_UT_uniform, 485 TokenizationMode.AOTP_CTT_indexed, 486 ) 487 488 def is_UT(self) -> bool: 489 "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)" 490 return is_UT(self.tokenization_mode) 491 492 def clear_cache(self) -> None: 493 """clears all cached properties""" 494 # delete the properties only if they exist 495 for name, prop in self.__class__.__dict__.items(): 496 if isinstance(prop, cached_property): 497 # if the property exists, delete it 498 try: # noqa: SIM105 499 delattr(self, name) 500 except AttributeError: 501 pass
LEGACY Tokenizer for mazes
> [!CAUTION]
> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
> for use, but will remain for compatibility with existing code.

# Parameters:
- `tokenization_mode: TokenizationMode`
    mode of tokenization. required.
- `max_grid_size: int | None`
    maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text

# Properties
- `name: str`
    auto-generated name of the tokenizer from mode and size

## Conditional Properties

- `node_strings_map: Mapping[CoordTup, str]`
    map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode

these all return `None` if `max_grid_size` is `None`.
Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`

- `token_arr: list[str]`
    list of tokens, in order of their indices in the vocabulary
- `tokenizer_map: Mapping[str, int]`
    map from token to index
- `vocab_size: int`
    size of the vocabulary
- `padding_token_index: int`
    index of the padding token

# Methods
- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
    convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
    convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
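A minimal usage sketch of this legacy tokenizer, based only on the members documented below; the printed values are illustrative rather than guaranteed:

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizer, TokenizationMode

# construct a legacy tokenizer; max_grid_size is only needed for numerical encoding
tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=5,
)

print(tok.name)        # "maze_tokenizer-AOTP_UT_uniform-g5"
print(tok.vocab_size)  # an int, since max_grid_size is set (None otherwise)

# coordinate -> token-string conversion works even without max_grid_size
coord_tokens: list[str] = tok.coords_to_strings([(0, 0), (1, 2)])

# numerical encoding/decoding requires max_grid_size
ids: list[int] = tok.encode(coord_tokens)
assert tok.decode(ids) == coord_tokens
```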
```python
@property
def name(self) -> str:
    """auto-generated name of the tokenizer from mode and size"""
    max_grid_size_str: str = (
        f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
    )
    return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
```
auto-generated name of the tokenizer from mode and size
```python
@cached_property
def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
    """map a coordinate to a token"""
    if self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    ):
        return None
    else:
        return self._node_strings_map
```
map a coordinate to a token
```python
@cached_property
def token_arr(self) -> list[str] | None:
    "get the token array if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._token_arr
```
get the token array if the max_grid_size is specified
```python
@cached_property
def tokenizer_map(self) -> dict[str, int] | None:
    "get the tokenizer map if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._tokenizer_map
```
get the tokenizer map if the max_grid_size is specified
```python
@property
def vocab_size(self) -> int | None:
    "get the size of the vocabulary if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._vocab_size
```
get the size of the vocabulary if the max_grid_size is specified
```python
@property
def n_tokens(self) -> int | None:
    "get the number of tokens if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._n_tokens
```
get the number of tokens if the max_grid_size is specified
```python
@cached_property
def padding_token_index(self) -> int | None:
    "get the index of the padding token if it exists"
    if self.max_grid_size is None:
        return None
    return self._padding_token_index
```
get the index of the padding token if it exists
```python
def coords_to_strings(
    self,
    coords: list[CoordTup],
    when_noncoord: WhenMissing = "skip",
) -> list[str]:
    """map a list of coordinate tuples (and maybe other tokens) to strings

    wraps `maze_dataset.token_utils.coords_to_strings` with either
    `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
    """
    if self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    ):
        return coords_to_strings(
            coords=coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord=when_noncoord,
        )
    elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
        return coords_to_strings(
            coords=coords,
            coord_to_strings_func=_coord_to_strings_indexed,
            when_noncoord=when_noncoord,
        )
    else:
        err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
        raise ValueError(err_msg)
```
map a list of coordinate tuples (and maybe other tokens) to strings
wraps `maze_dataset.token_utils.coords_to_strings` with either `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
```python
@classmethod
def strings_to_coords(
    cls,
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    "wrapper for `maze_dataset.token_utils.strings_to_coords`"
    return strings_to_coords(text=text, when_noncoord=when_noncoord)
```
wrapper for maze_dataset.token_utils.strings_to_coords
```python
def encode(self, text: str | list[str]) -> list[int]:
    """encode a string or list of strings into a list of tokens"""
    try:
        if isinstance(text, str):
            text = text.split()
        return [self.tokenizer_map[token] for token in text]
    except KeyError as e:
        err_msg: str = (
            f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
        )
        raise TokenError(err_msg) from e
```
encode a string or list of strings into a list of tokens
```python
def decode(
    self,
    tokens: Sequence[int],
    joined_tokens: bool = False,
) -> list[str] | str:
    """decode a list of tokens into a string or list of strings"""
    try:
        output: list[str] = [self.token_arr[token] for token in tokens]
    except IndexError as e:
        err_msg: str = (
            f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
        )
        raise TokenError(err_msg) from e
    if joined_tokens:
        return " ".join(output)
    else:
        return output
```
decode a list of tokens into a string or list of strings
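A short sketch of how `when_noncoord` changes the behavior of these conversions; it assumes UT coordinate tokens round-trip through `strings_to_coords`, which is exactly what `coordinate_tokens_coords` below relies on:

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizer, TokenizationMode

tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=3,
)

# a mixed sequence: coordinate tokens plus an arbitrary non-coordinate token
mixed: list[str] = [*tok.coords_to_strings([(0, 0), (2, 1)]), "not-a-coord"]

# "skip" (the default) silently drops non-coordinate tokens
print(tok.strings_to_coords(mixed))  # expected: [(0, 0), (2, 1)]

# "include" keeps non-coordinate tokens as-is, mixing strings and tuples
print(tok.strings_to_coords(mixed, when_noncoord="include"))

# "error" raises on the first non-coordinate token
try:
    tok.strings_to_coords(mixed, when_noncoord="error")
except Exception as e:  # exact exception type is defined in maze_dataset.token_utils
    print(f"raised as expected: {e!r}")
```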
```python
@cached_property
def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
    "map of coordinate tuples to their token ids, only valid for UT"
    # print(f"{self.tokenization_mode = }")
    if not self.is_UT():
        err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
        raise ValueError(err_msg)

    if self.max_grid_size is None:
        err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
        raise ValueError(err_msg)

    raw_converted: list[CoordTup | str] = self.strings_to_coords(
        self.token_arr,
        when_noncoord="include",
    )

    # filter out non-coordinates
    return {
        coord: i
        for i, coord in enumerate(raw_converted)
        if not isinstance(coord, str)
    }
```
map of coordinate tuples to their token ids, only valid for UT
```python
@cached_property
def coordinate_tokens_ids(self) -> dict[str, int]:
    "map of coordinate tokens to their token ids, only valid for UT"
    # checks performed in call
    output: dict[str, int] = dict()

    for coord, index in self.coordinate_tokens_coords.items():
        _for_key: list[str] = self.coords_to_strings([coord])
        assert len(_for_key) == 1
        output[_for_key[0]] = index

    return output
```
map of coordinate tokens to their token ids, only valid for UT
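A sketch of what these UT-only lookup tables contain; the exact token strings and indices depend on `SPECIAL_TOKENS` and the grid ordering, so the values in the comments are only indicative:

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizer, TokenizationMode

tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=2,
)

# coordinate tuple -> token id, e.g. {(0, 0): 11, (0, 1): 12, ...}
print(tok.coordinate_tokens_coords)

# coordinate token string -> token id, e.g. {"(0,0)": 11, ...}
print(tok.coordinate_tokens_ids)

# both raise ValueError for non-UT modes or when max_grid_size is None
```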
```python
def summary(self) -> dict:
    """returns a summary of the tokenization mode"""
    return {
        "tokenization_mode": self.tokenization_mode.value,
        "max_grid_size": self.max_grid_size,
        "vocab_size": self.vocab_size,
    }
```
returns a summary of the tokenization mode
```python
def is_AOTP(self) -> bool:
    """returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
    return self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    )
```
returns true if a tokenization mode is Adjacency list, Origin, Target, Path
```python
def is_UT(self) -> bool:
    "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
    return is_UT(self.tokenization_mode)
```
returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)
```python
def clear_cache(self) -> None:
    """clears all cached properties"""
    # delete the properties only if they exist
    for name, prop in self.__class__.__dict__.items():
        if isinstance(prop, cached_property):
            # if the property exists, delete it
            try:  # noqa: SIM105
                delattr(self, name)
            except AttributeError:
                pass
```
clears all cached properties
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
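Because `MazeTokenizer` is a `SerializableDataclass`, the inherited `serialize`/`load` pair gives a plain-dict round trip; a brief sketch:

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizer, TokenizationMode

tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_CTT_indexed,
    max_grid_size=4,
)

data: dict = tok.serialize()         # plain dict, suitable for JSON / ZANJ storage
restored = MazeTokenizer.load(data)  # classmethod, rebuilds the instance

assert restored == tok
assert restored.tokenization_mode is TokenizationMode.AOTP_CTT_indexed
```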
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class TokenizationMode(Enum):
    """legacy tokenization modes

    > [!CAUTION]
    > Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
    > Use `MazeTokenizerModular` instead.

    # Abbreviations:
    - `AOTP`: Adjacency list, Origin, Target, Path
    - `UT`: Unique Token (for each coordinate)
    - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

    # Modes:
    - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
        example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
    - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
        uses the `corner_first_ndindex` function to order the tokens
    - `AOTP_CTT_indexed`: each coordinate is a tuple of integers
    """

    AOTP_UT_rasterized = "AOTP_UT_rasterized"
    AOTP_UT_uniform = "AOTP_UT_uniform"
    AOTP_CTT_indexed = "AOTP_CTT_indexed"

    def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
        "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
        return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
legacy tokenization modes
> [!CAUTION]
> Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
> Use `MazeTokenizerModular` instead.

# Abbreviations:
- `AOTP`: Adjacency list, Origin, Target, Path
- `UT`: Unique Token (for each coordinate)
- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

# Modes:
- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
    example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
    uses the `corner_first_ndindex` function to order the tokens
- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
```python
def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
    "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
    return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`
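For example, a mode can be turned directly into a configured legacy tokenizer:

```python
from maze_dataset.tokenization.maze_tokenizer import TokenizationMode

mode = TokenizationMode.AOTP_UT_uniform
legacy_tok = mode.to_legacy_tokenizer(max_grid_size=7)

print(legacy_tok.name)     # "maze_tokenizer-AOTP_UT_uniform-g7"
print(legacy_tok.is_UT())  # True for both UT modes
```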
Inherited Members
- enum.Enum
- name
- value
51@serializable_dataclass( 52 frozen=True, 53 kw_only=True, 54 properties_to_serialize=["tokenizer_element_tree_concrete", "name"], 55) 56class MazeTokenizerModular(SerializableDataclass): 57 """Tokenizer for mazes 58 59 # Parameters 60 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt. 61 62 # Development 63 - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`. 64 - Furthermore, the mapping reflected in `from_legacy` must also be maintained. 65 - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior. 66 """ 67 68 prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field( 69 default=PromptSequencers.AOTP(), 70 loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers), 71 ) 72 73 def hash_int(self) -> int: 74 "return integer hash using blake2b" 75 return _hash_tokenizer_name(self.name) 76 77 def __hash__(self) -> int: 78 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" 79 return self.hash_int() 80 81 def hash_b64(self, n_bytes: int = 8) -> str: 82 """filename-safe base64 encoding of the hash""" 83 # Use modulus to ensure the integer fits within n_bytes * 8 bits 84 hash_mod: int = self.hash_int() % (1 << (n_bytes * 8)) 85 86 encoded = base64.b64encode( 87 hash_mod.to_bytes(n_bytes, byteorder="big"), 88 altchars=b"-_", 89 ).decode() 90 91 # Remove any padding equals signs 92 return encoded.rstrip("=") 93 94 # Information Querying Methods 95 96 @cached_property 97 def tokenizer_elements(self) -> list[_TokenizerElement]: 98 "returns a list of all the elements of this tokenizer" 99 return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()] 100 101 def tokenizer_element_tree(self, abstract: bool = False) -> str: 102 """Returns a string representation of the tree of tokenizer elements contained in `self`. 103 104 # Parameters 105 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 106 """ 107 return "\n".join( 108 [ 109 type(self).__name__, 110 self.prompt_sequencer.tokenizer_element_tree( 111 abstract=abstract, 112 depth=1, 113 ), 114 ], 115 ) 116 117 @property 118 def tokenizer_element_tree_concrete(self) -> str: 119 """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.""" 120 return self.tokenizer_element_tree() 121 122 def tokenizer_element_dict(self) -> dict: 123 """Nested dictionary of the internal `TokenizerElement`s.""" 124 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()} 125 126 @property 127 def name(self) -> str: 128 """Serializes MazeTokenizer into a key for encoding in zanj""" 129 return "-".join([type(self).__name__, self.prompt_sequencer.name]) # noqa: FLY002 130 131 def summary(self) -> dict[str, str]: 132 """Single-level dictionary of the internal `TokenizerElement`s.""" 133 return { 134 # "prompt_sequencer": self.prompt_sequencer.name, 135 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements}, 136 } 137 138 @staticmethod 139 def _type_check(obj: any) -> None: 140 """Helper method for `has_element`""" 141 if not ( 142 isinstance(obj, _TokenizerElement) 143 or (isinstance(obj, type) and issubclass(obj, _TokenizerElement)) 144 ): 145 err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass." 
146 raise TypeError(err_msg) 147 148 def _has_element_singular( 149 self, 150 el: type[_TokenizerElement] | _TokenizerElement, 151 ) -> bool: 152 """Helper method for `has_element`""" 153 self._type_check(el) 154 if isinstance(el, type): 155 return any(isinstance(e, el) for e in self.tokenizer_elements) 156 else: 157 return el in self.tokenizer_elements 158 159 def has_element( 160 self, 161 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement], 162 ) -> bool: 163 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`. 164 165 Querying with a partial subset of `_TokenizerElement` fields is not currently supported. 166 To do such a query, assemble multiple calls to `has_elements`. 167 168 # Parameters 169 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. 170 If an instance is provided, then comparison is done via instance equality. 171 If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted. 172 """ 173 if len(elements) == 1 and isinstance(elements[0], Iterable): 174 elements = elements[0] 175 return all(self._has_element_singular(e) for e in elements) 176 177 def is_valid(self, do_except: bool = False) -> bool: 178 """Returns `True` if `self` is a valid tokenizer. 179 180 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method. 181 """ 182 return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements) 183 184 def is_legacy_equivalent(self) -> bool: 185 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.""" 186 return any( 187 self == MazeTokenizerModular.from_legacy(tok_mode) 188 for tok_mode in TokenizationMode 189 ) 190 191 def is_tested_tokenizer(self, do_except: bool = False) -> bool: 192 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers. 193 194 uses an fst on the `name` attributes of all the tokenizers 195 196 if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested. 197 """ 198 is_valid: bool = self.is_valid(do_except=do_except) 199 in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except) 200 201 if do_except: 202 assert is_valid, "self.is_valid returns False" 203 return True 204 else: 205 return in_tested_fst and is_valid 206 207 def is_AOTP(self) -> bool: 208 "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path" 209 return self.has_element(PromptSequencers.AOTP) 210 211 def is_UT(self) -> bool: 212 "is this tokenizer a UT tokenizer? 
UT = Unique Token (for each coord)" 213 return self.has_element(CoordTokenizers.UT) 214 215 # Alternate Constructors 216 # ====================== 217 218 @classmethod 219 def from_legacy( 220 cls, 221 legacy_maze_tokenizer: MazeTokenizer | TokenizationMode, 222 ) -> "MazeTokenizerModular": 223 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" 224 if isinstance(legacy_maze_tokenizer, MazeTokenizer): 225 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode 226 return { 227 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), 228 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), 229 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( 230 prompt_sequencer=PromptSequencers.AOTP( 231 coord_tokenizer=CoordTokenizers.CTT(), 232 ), 233 ), 234 }[legacy_maze_tokenizer] 235 236 # Simple properties 237 # ================= 238 @classmethod 239 def from_tokens( 240 cls, 241 tokens: str | list[str], 242 ) -> "MazeTokenizerModular": 243 """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.""" 244 raise NotImplementedError( 245 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported", 246 ) 247 248 @property 249 def token_arr(self) -> list[str] | None: 250 """map from index to token""" 251 return VOCAB_LIST 252 253 @property 254 def tokenizer_map(self) -> dict[str, int]: 255 """map from token to index""" 256 return VOCAB_TOKEN_TO_INDEX 257 258 @property 259 def vocab_size(self) -> int: 260 """Number of tokens in the static vocab""" 261 return len(VOCAB_LIST) 262 263 @property 264 def n_tokens(self) -> int: 265 "get the number of tokens in the vocabulary (deprecated)" 266 err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 267 raise NameError(err_msg) 268 269 @property 270 def padding_token_index(self) -> int: 271 "get the index of the padding token" 272 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING] 273 274 # conversion functions 275 # ============================================================ 276 277 def to_tokens( 278 self, 279 maze: LatticeMaze, 280 ) -> list[str]: 281 """Converts maze into a list of tokens.""" 282 return self.prompt_sequencer.to_tokens(maze) 283 284 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: 285 "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords" 286 return list( 287 flatten( 288 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords], 289 ), 290 ) 291 292 # TODO: unclear why we need to use `noqa: N805` here since its a classmethod 293 # maybe we need to hit every overload with `@classmethod`? 294 @overload 295 def strings_to_coords( 296 cls, # noqa: N805 297 text: str | list[str], 298 when_noncoord: Literal["skip"] = "skip", 299 ) -> list[CoordTup]: ... 300 @overload 301 def strings_to_coords( 302 cls, # noqa: N805 303 text: str | list[str], 304 when_noncoord: Literal["error"] = "error", 305 ) -> list[CoordTup]: ... 306 @overload 307 def strings_to_coords( 308 cls, # noqa: N805 309 text: str | list[str], 310 when_noncoord: Literal["include"] = "include", 311 ) -> list[str | CoordTup]: ... 
312 @classmethod 313 def strings_to_coords( 314 cls, 315 text: str | list[str], 316 when_noncoord: WhenMissing = "skip", 317 ) -> list[str | CoordTup]: 318 "wrapper for maze_dataset.token_utils.strings_to_coords" 319 warnings.warn( 320 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.", 321 TokenizerPendingDeprecationWarning, 322 ) 323 return strings_to_coords(text=text, when_noncoord=when_noncoord) 324 325 @staticmethod 326 def encode(text: str | list[str]) -> list[int]: 327 """encode a string or list of strings into a list of tokens""" 328 try: 329 if isinstance(text, str): 330 text = text.split() 331 return [VOCAB_TOKEN_TO_INDEX[token] for token in text] 332 except KeyError as e: 333 err_msg: str = f"Token {e} not found in `VOCAB`." 334 raise TokenError(err_msg) from e 335 336 @staticmethod 337 def decode( 338 token_ids: Sequence[int], 339 joined_tokens: bool = False, 340 ) -> list[str] | str: 341 """decode a list of tokens into a string or list of strings""" 342 try: 343 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids] 344 except IndexError as e: 345 err_msg: str = f"Token index '{e}' not found in `VOCAB`." 346 raise TokenError(err_msg) from e 347 if joined_tokens: 348 return " ".join(output) 349 else: 350 return output
Tokenizer for mazes
# Parameters
- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

# Development
- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
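A minimal usage sketch of the modular tokenizer, using only the members documented below; the name and vocabulary size printed here depend on the installed version:

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular, TokenizationMode

# the default constructor is equivalent to the legacy AOTP_UT_uniform mode
mt = MazeTokenizerModular()

print(mt.name)                   # stable key used for serialization
print(mt.summary())              # one entry per contained _TokenizerElement
print(mt.is_UT(), mt.is_AOTP())  # True True for the default tokenizer

# the vocabulary is static and shared by all modular tokenizers
print(mt.vocab_size)

# encode/decode against the static vocab
ids = mt.encode(mt.token_arr[:5])
assert mt.decode(ids) == mt.token_arr[:5]

# constructing from a legacy mode gives an equal instance
assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_uniform) == mt
```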
```python
def hash_int(self) -> int:
    "return integer hash using blake2b"
    return _hash_tokenizer_name(self.name)
```
return integer hash using blake2b
```python
def hash_b64(self, n_bytes: int = 8) -> str:
    """filename-safe base64 encoding of the hash"""
    # Use modulus to ensure the integer fits within n_bytes * 8 bits
    hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))

    encoded = base64.b64encode(
        hash_mod.to_bytes(n_bytes, byteorder="big"),
        altchars=b"-_",
    ).decode()

    # Remove any padding equals signs
    return encoded.rstrip("=")
```
filename-safe base64 encoding of the hash
```python
@cached_property
def tokenizer_elements(self) -> list[_TokenizerElement]:
    "returns a list of all the elements of this tokenizer"
    return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
```
returns a list of all the elements of this tokenizer
```python
def tokenizer_element_tree(self, abstract: bool = False) -> str:
    """Returns a string representation of the tree of tokenizer elements contained in `self`.

    # Parameters
    - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
    """
    return "\n".join(
        [
            type(self).__name__,
            self.prompt_sequencer.tokenizer_element_tree(
                abstract=abstract,
                depth=1,
            ),
        ],
    )
```
Returns a string representation of the tree of tokenizer elements contained in `self`.

# Parameters
- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
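For instance, to inspect the structure of the default tokenizer (the exact element names in the output depend on the installed version):

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

mt = MazeTokenizerModular()

# concrete classes, one element per line, indented by depth
print(mt.tokenizer_element_tree())

# abstract base classes instead of concrete ones
print(mt.tokenizer_element_tree(abstract=True))

# the same structure as a nested dictionary
print(mt.tokenizer_element_dict())
```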
```python
@property
def tokenizer_element_tree_concrete(self) -> str:
    """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
    return self.tokenizer_element_tree()
```
Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.
```python
def tokenizer_element_dict(self) -> dict:
    """Nested dictionary of the internal `TokenizerElement`s."""
    return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
```
Nested dictionary of the internal `TokenizerElement`s.
```python
@property
def name(self) -> str:
    """Serializes MazeTokenizer into a key for encoding in zanj"""
    return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002
```
Serializes MazeTokenizer into a key for encoding in zanj
```python
def summary(self) -> dict[str, str]:
    """Single-level dictionary of the internal `TokenizerElement`s."""
    return {
        # "prompt_sequencer": self.prompt_sequencer.name,
        **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
    }
```
Single-level dictionary of the internal `TokenizerElement`s.
```python
def has_element(
    self,
    *elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
) -> bool:
    """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

    Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
    To do such a query, assemble multiple calls to `has_element`.

    # Parameters
    - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
        If an instance is provided, then comparison is done via instance equality.
        If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
    """
    if len(elements) == 1 and isinstance(elements[0], Iterable):
        elements = elements[0]
    return all(self._has_element_singular(e) for e in elements)
```
Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
To do such a query, assemble multiple calls to `has_element`.

# Parameters
- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via `isinstance`, i.e. any instance of that class is accepted.
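A hedged example of element queries; it assumes `PromptSequencers` and `CoordTokenizers` can be imported from `maze_dataset.tokenization`, which matches how they are referenced in the source above but is not shown in this module:

```python
from maze_dataset.tokenization import CoordTokenizers, PromptSequencers  # assumed import path
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

mt = MazeTokenizerModular()

# class arguments match via isinstance: any UT coordinate tokenizer counts
print(mt.has_element(CoordTokenizers.UT))

# instance arguments match via equality; multiple arguments must ALL be present
print(mt.has_element(PromptSequencers.AOTP(), CoordTokenizers.UT))

# a single iterable argument is also accepted
print(mt.has_element([PromptSequencers.AOTP, CoordTokenizers.UT]))
```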
```python
def is_valid(self, do_except: bool = False) -> bool:
    """Returns `True` if `self` is a valid tokenizer.

    Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
    """
    return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
```
Returns `True` if `self` is a valid tokenizer.

Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
```python
def is_legacy_equivalent(self) -> bool:
    """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
    return any(
        self == MazeTokenizerModular.from_legacy(tok_mode)
        for tok_mode in TokenizationMode
    )
```
Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.
```python
def is_tested_tokenizer(self, do_except: bool = False) -> bool:
    """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

    uses an fst on the `name` attributes of all the tokenizers

    if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
    """
    is_valid: bool = self.is_valid(do_except=do_except)
    in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)

    if do_except:
        assert is_valid, "self.is_valid returns False"
        return True
    else:
        return in_tested_fst and is_valid
```
Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

uses an fst on the `name` attributes of all the tokenizers

if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
```python
def is_AOTP(self) -> bool:
    "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
    return self.has_element(PromptSequencers.AOTP)
```
is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path
```python
def is_UT(self) -> bool:
    "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
    return self.has_element(CoordTokenizers.UT)
```
is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)
```python
@classmethod
def from_legacy(
    cls,
    legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
) -> "MazeTokenizerModular":
    """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
    if isinstance(legacy_maze_tokenizer, MazeTokenizer):
        legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
    return {
        TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
        TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
        TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
            prompt_sequencer=PromptSequencers.AOTP(
                coord_tokenizer=CoordTokenizers.CTT(),
            ),
        ),
    }[legacy_maze_tokenizer]
```
Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.
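For example, the indexed legacy mode maps onto a modular tokenizer that swaps in the `CTT` coordinate tokenizer, while both UT modes collapse onto the default:

```python
from maze_dataset.tokenization.maze_tokenizer import (
    MazeTokenizer,
    MazeTokenizerModular,
    TokenizationMode,
)

# from a mode, or from a configured legacy tokenizer (its max_grid_size is ignored)
mt_ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
mt_same = MazeTokenizerModular.from_legacy(
    MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=10),
)
assert mt_ctt == mt_same

# both UT legacy modes collapse onto the default modular tokenizer
assert MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_UT_rasterized) == MazeTokenizerModular()

# every from_legacy result is, by construction, legacy-equivalent
assert mt_ctt.is_legacy_equivalent()
```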
```python
@classmethod
def from_tokens(
    cls,
    tokens: str | list[str],
) -> "MazeTokenizerModular":
    """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
    raise NotImplementedError(
        "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
    )
```
Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.
```python
@property
def token_arr(self) -> list[str] | None:
    """map from index to token"""
    return VOCAB_LIST
```
map from index to token
```python
@property
def tokenizer_map(self) -> dict[str, int]:
    """map from token to index"""
    return VOCAB_TOKEN_TO_INDEX
```
map from token to index
```python
@property
def vocab_size(self) -> int:
    """Number of tokens in the static vocab"""
    return len(VOCAB_LIST)
```
Number of tokens in the static vocab
```python
@property
def n_tokens(self) -> int:
    "get the number of tokens in the vocabulary (deprecated)"
    err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
    raise NameError(err_msg)
```
get the number of tokens in the vocabulary (deprecated)
```python
@property
def padding_token_index(self) -> int:
    "get the index of the padding token"
    return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
```
get the index of the padding token
```python
def to_tokens(
    self,
    maze: LatticeMaze,
) -> list[str]:
    """Converts maze into a list of tokens."""
    return self.prompt_sequencer.to_tokens(maze)
```
Converts maze into a list of tokens.
```python
def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
    "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
    return list(
        flatten(
            [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
        ),
    )
```
calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords
```python
@classmethod
def strings_to_coords(
    cls,
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    "wrapper for `maze_dataset.token_utils.strings_to_coords`"
    warnings.warn(
        "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
        TokenizerPendingDeprecationWarning,
    )
    return strings_to_coords(text=text, when_noncoord=when_noncoord)
```
wrapper for maze_dataset.token_utils.strings_to_coords
```python
@staticmethod
def encode(text: str | list[str]) -> list[int]:
    """encode a string or list of strings into a list of tokens"""
    try:
        if isinstance(text, str):
            text = text.split()
        return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
    except KeyError as e:
        err_msg: str = f"Token {e} not found in `VOCAB`."
        raise TokenError(err_msg) from e
```
encode a string or list of strings into a list of tokens
```python
@staticmethod
def decode(
    token_ids: Sequence[int],
    joined_tokens: bool = False,
) -> list[str] | str:
    """decode a list of tokens into a string or list of strings"""
    try:
        output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
    except IndexError as e:
        err_msg: str = f"Token index '{e}' not found in `VOCAB`."
        raise TokenError(err_msg) from e
    if joined_tokens:
        return " ".join(output)
    else:
        return output
```
decode a list of tokens into a string or list of strings
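Since both methods are static and keyed to the fixed vocabulary, they can be used without configuring an instance; a short sketch (the `TokenError` type lives in maze_dataset's token utilities, so the generic `except` below is an assumption):

```python
from maze_dataset.tokenization.maze_tokenizer import MazeTokenizerModular

# any tokens from the static vocab round-trip through encode/decode
some_tokens = MazeTokenizerModular().token_arr[:3]
ids = MazeTokenizerModular.encode(some_tokens)
assert MazeTokenizerModular.decode(ids) == some_tokens

# joined_tokens=True returns a single whitespace-joined string
print(MazeTokenizerModular.decode(ids, joined_tokens=True))

# unknown tokens raise a TokenError
try:
    MazeTokenizerModular.encode("definitely-not-in-vocab")
except Exception as e:
    print(type(e).__name__, e)
```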
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict