maze_dataset.tokenization.maze_tokenizer_legacy
legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class

> [!CAUTION]
> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
> for use, but will remain for compatibility with existing code.
1"""legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class 2 3> [!CAUTION] 4> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended 5> for use, but will remain for compatibility with existing code. 6 7""" 8 9import warnings 10from enum import Enum 11from functools import cached_property 12from typing import ( 13 Callable, 14 Iterable, 15 Literal, 16 Mapping, 17 Sequence, 18 overload, 19) 20 21import numpy as np 22from muutils.json_serialize import ( 23 SerializableDataclass, 24 serializable_dataclass, 25 serializable_field, 26) 27from muutils.kappa import Kappa 28from muutils.misc.sequence import WhenMissing 29 30# from maze_dataset import SolvedMaze 31from maze_dataset.constants import ( 32 SPECIAL_TOKENS, 33 CoordTup, 34) 35from maze_dataset.token_utils import ( 36 TokenizerPendingDeprecationWarning, 37 _coord_to_strings_indexed, 38 _coord_to_strings_UT, 39 coords_to_strings, 40 strings_to_coords, 41) 42from maze_dataset.tokenization.common import TokenError 43from maze_dataset.utils import corner_first_ndindex 44 45 46class TokenizationMode(Enum): 47 """legacy tokenization modes 48 49 > [!CAUTION] 50 > Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. 51 > Use `MazeTokenizerModular` instead. 52 53 # Abbreviations: 54 - `AOTP`: Ajacency list, Origin, Target, Path 55 - `UT`: Unique Token (for each coordiate) 56 - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers) 57 58 # Modes: 59 - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization 60 example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)` 61 - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and 5x5 tokenizations scheme are compatible 62 uses `corner_first_ndindex` function to order the tokens 63 - `AOTP_CTT_indexed`: each coordinate is a tuple of integers 64 """ 65 66 AOTP_UT_rasterized = "AOTP_UT_rasterized" 67 AOTP_UT_uniform = "AOTP_UT_uniform" 68 AOTP_CTT_indexed = "AOTP_CTT_indexed" 69 70 def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer": 71 "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`" 72 return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size) 73 74 75_NDINDEX_FUNC_MAP: dict[ 76 TokenizationMode, 77 Callable[[int], Iterable[tuple[int, ...]]], 78] = { 79 TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)), 80 TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2), 81} 82 83 84def is_UT(tokenization_mode: TokenizationMode) -> bool: 85 "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)" 86 return tokenization_mode in ( 87 TokenizationMode.AOTP_UT_rasterized, 88 TokenizationMode.AOTP_UT_uniform, 89 ) 90 91 92def get_tokens_up_to_path_start( 93 tokens: list[str], 94 include_start_coord: bool = True, 95 tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform, 96) -> list[str]: 97 """get tokens up to the path start token 98 99 # Parameters: 100 - `tokens : list[str]` 101 - `include_start_coord : bool` 102 (defaults to `True`) 103 - `tokenization_mode : TokenizationMode` 104 (defaults to `TokenizationMode.AOTP_UT_uniform`) 105 106 # Returns: 107 - `list[str]` subsequence of `tokens` up to the path start token 108 109 # Raises: 110 - `ValueError` : if `tokenization_mode` is invalid 111 """ 112 
warnings.warn( 113 "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.", 114 TokenizerPendingDeprecationWarning, 115 ) 116 path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1 117 if include_start_coord: 118 if is_UT(tokenization_mode): 119 return tokens[: path_start_idx + 1] 120 elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 121 return tokens[: path_start_idx + 5] 122 else: 123 err_msg: str = f"Invalid tokenization mode: {tokenization_mode}" 124 raise ValueError(err_msg) 125 else: 126 return tokens[:path_start_idx] 127 128 129_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [ 130 "name", 131 "max_grid_size", 132 "token_arr", 133 "tokenizer_map", 134 "vocab_size", 135 "padding_token_index", 136] 137 138 139@serializable_dataclass( 140 properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, 141 kw_only=True, 142) 143class MazeTokenizer(SerializableDataclass): 144 """LEGACY Tokenizer for mazes 145 146 > [!CAUTION] 147 > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended 148 > for use, but will remain for compatibility with existing code. 149 150 # Parameters: 151 - `tokenization_mode: TokenizationMode` 152 mode of tokenization. required. 153 - `max_grid_size: int | None` 154 maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text 155 156 # Properties 157 - `name: str` 158 auto-generated name of the tokenizer from mode and size 159 160 ## Conditional Properties 161 162 - `node_strings_map: Mapping[CoordTup, str]` 163 map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode 164 165 these all return `None` if `max_grid_size` is `None`. 166 Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None` 167 168 - `token_arr: list[str]` 169 list of tokens, in order of their indices in the vocabulary 170 - `tokenizer_map: Mapping[str, int]` 171 map from token to index 172 - `vocab_size: int` 173 size of the vocabulary 174 - `padding_token_index: int` 175 index of the padding token 176 177 # Methods 178 - `coords_to_strings(coords: list[CoordTup]) -> list[str]` 179 convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates 180 - `strings_to_coords(strings: list[str]) -> list[CoordTup]` 181 convert a list of tokens to a list of coordinates. 
Optionally except, skip, or ignore non-coordinates 182 183 """ 184 185 # parameters 186 # ============================================================ 187 188 tokenization_mode: TokenizationMode = serializable_field( 189 default=TokenizationMode.AOTP_UT_uniform, 190 serialization_fn=lambda x: x.value, 191 loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]], 192 ) 193 194 max_grid_size: int | None = serializable_field(default=None) 195 196 # properties 197 # ============================================================ 198 199 @property 200 def name(self) -> str: 201 """auto-generated name of the tokenizer from mode and size""" 202 max_grid_size_str: str = ( 203 f"-g{self.max_grid_size}" if self.max_grid_size is not None else "" 204 ) 205 return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}" 206 207 @cached_property 208 def _node_strings_map(self) -> Mapping[CoordTup, list[str]]: 209 """map a coordinate to a token""" 210 if self.tokenization_mode in ( 211 TokenizationMode.AOTP_UT_rasterized, 212 TokenizationMode.AOTP_UT_uniform, 213 ): 214 return Kappa(_coord_to_strings_UT) 215 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 216 return Kappa(_coord_to_strings_indexed) 217 else: 218 err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" 219 raise ValueError(err_msg) 220 221 @cached_property 222 def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None: 223 """map a coordinate to a token""" 224 if self.tokenization_mode in ( 225 TokenizationMode.AOTP_UT_rasterized, 226 TokenizationMode.AOTP_UT_uniform, 227 ): 228 return None 229 else: 230 return self._node_strings_map 231 232 # conditional properties (on max_grid_size existing) 233 # ------------------------------------------------------------ 234 235 @cached_property 236 def _token_arr(self) -> list[str]: 237 """map from index to token""" 238 if self.max_grid_size is None: 239 err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }" 240 raise ValueError(err_msg) 241 242 output: list[str] = list(SPECIAL_TOKENS.values()) 243 244 if self.tokenization_mode in ( 245 TokenizationMode.AOTP_UT_rasterized, 246 TokenizationMode.AOTP_UT_uniform, 247 ): 248 output.extend( 249 [ 250 self._node_strings_map[coord][0] 251 for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode]( 252 self.max_grid_size, 253 ) 254 ], 255 ) 256 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 257 # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models 258 output.extend( 259 [ 260 "(", 261 ",", 262 ")", # new special chars 263 *map(str, range(self.max_grid_size)), # numbers 264 ], 265 ) 266 else: 267 err_msg: str = ( 268 f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}", 269 ) 270 raise ValueError(err_msg) 271 272 return output 273 274 @cached_property 275 def token_arr(self) -> list[str] | None: 276 "get the token array if the max_grid_size is specified" 277 if self.max_grid_size is None: 278 return None 279 return self._token_arr 280 281 @cached_property 282 def _tokenizer_map(self) -> dict[str, int]: 283 """map from token to index""" 284 return {token: i for i, token in enumerate(self._token_arr)} 285 286 @cached_property 287 def tokenizer_map(self) -> dict[str, int] | None: 288 "get the tokenizer map if the max_grid_size is specified" 289 if self.max_grid_size is None: 290 return None 291 
return self._tokenizer_map 292 293 @property 294 def _vocab_size(self) -> int: 295 return len(self._token_arr) 296 297 @property 298 def vocab_size(self) -> int | None: 299 "get the size of the vocabulary if the max_grid_size is specified" 300 if self.max_grid_size is None: 301 return None 302 return self._vocab_size 303 304 @property 305 def _n_tokens(self) -> int: 306 # TODO: deprecate 307 return self._vocab_size 308 309 @property 310 def n_tokens(self) -> int | None: 311 "get the number of tokens if the max_grid_size is specified" 312 if self.max_grid_size is None: 313 return None 314 return self._n_tokens 315 316 @cached_property 317 def _padding_token_index(self) -> int: 318 return self.tokenizer_map[SPECIAL_TOKENS.PADDING] 319 320 @cached_property 321 def padding_token_index(self) -> int | None: 322 "get the index of the padding token if it exists" 323 if self.max_grid_size is None: 324 return None 325 return self._padding_token_index 326 327 # conversion functions 328 # ============================================================ 329 330 @overload 331 def coords_to_strings( 332 self, 333 coords: list[str | CoordTup], 334 when_noncoord: Literal["include", "skip"] = "skip", 335 ) -> list[str]: ... 336 @overload 337 def coords_to_strings( 338 self, 339 coords: list[CoordTup], 340 when_noncoord: Literal["error"] = "error", 341 ) -> list[str]: ... 342 def coords_to_strings( 343 self, 344 coords: list[CoordTup], 345 when_noncoord: WhenMissing = "skip", 346 ) -> list[str]: 347 """map a list of coordinate tuples (and maybe other tokens) to strings 348 349 wraps `maze_dataset.token_utils.coords_to_strings` with either 350 `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode 351 """ 352 if self.tokenization_mode in ( 353 TokenizationMode.AOTP_UT_rasterized, 354 TokenizationMode.AOTP_UT_uniform, 355 ): 356 return coords_to_strings( 357 coords=coords, 358 coord_to_strings_func=_coord_to_strings_UT, 359 when_noncoord=when_noncoord, 360 ) 361 elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: 362 return coords_to_strings( 363 coords=coords, 364 coord_to_strings_func=_coord_to_strings_indexed, 365 when_noncoord=when_noncoord, 366 ) 367 else: 368 err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" 369 raise ValueError(err_msg) 370 371 @overload 372 def strings_to_coords( 373 cls, # noqa: N805 374 text: str | list[str], 375 when_noncoord: Literal["skip"] = "skip", 376 ) -> list[CoordTup]: ... 377 @overload 378 def strings_to_coords( 379 cls, # noqa: N805 380 text: str | list[str], 381 when_noncoord: Literal["error"] = "error", 382 ) -> list[CoordTup]: ... 383 @overload 384 def strings_to_coords( 385 cls, # noqa: N805 386 text: str | list[str], 387 when_noncoord: Literal["include"] = "include", 388 ) -> list[str | CoordTup]: ... 
389 @classmethod 390 def strings_to_coords( 391 cls, 392 text: str | list[str], 393 when_noncoord: WhenMissing = "skip", 394 ) -> list[str | CoordTup]: 395 "wrapper for `maze_dataset.token_utils.strings_to_coords`" 396 return strings_to_coords(text=text, when_noncoord=when_noncoord) 397 398 def encode(self, text: str | list[str]) -> list[int]: 399 """encode a string or list of strings into a list of tokens""" 400 try: 401 if isinstance(text, str): 402 text = text.split() 403 return [self.tokenizer_map[token] for token in text] 404 except KeyError as e: 405 err_msg: str = ( 406 f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}" 407 ) 408 raise TokenError(err_msg) from e 409 410 def decode( 411 self, 412 tokens: Sequence[int], 413 joined_tokens: bool = False, 414 ) -> list[str] | str: 415 """decode a list of tokens into a string or list of strings""" 416 try: 417 output: list[str] = [self.token_arr[token] for token in tokens] 418 except IndexError as e: 419 err_msg: str = ( 420 f"Token index '{e}' not found in vocabulary of length {self.vocab_size}" 421 ) 422 raise TokenError(err_msg) from e 423 if joined_tokens: 424 return " ".join(output) 425 else: 426 return output 427 428 # UT-only coordinate stuff 429 # ============================================================ 430 431 @cached_property 432 def coordinate_tokens_coords(self) -> dict[CoordTup, int]: 433 "map of coordiante tuples to their token ids, only valid for UT" 434 # print(f"{self.tokenization_mode = }") 435 if not self.is_UT(): 436 err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }" 437 raise ValueError(err_msg) 438 439 if self.max_grid_size is None: 440 err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }" 441 raise ValueError(err_msg) 442 443 raw_converted: list[CoordTup | str] = self.strings_to_coords( 444 self.token_arr, 445 when_noncoord="include", 446 ) 447 448 # filter out non-coordinates 449 return { 450 coord: i 451 for i, coord in enumerate(raw_converted) 452 if not isinstance(coord, str) 453 } 454 455 @cached_property 456 def coordinate_tokens_ids(self) -> dict[str, int]: 457 "map of coordinate tokens to their token ids, only valid for UT" 458 # checks performed in call 459 output: dict[str, int] = dict() 460 461 for coord, index in self.coordinate_tokens_coords.items(): 462 _for_key: list[str] = self.coords_to_strings([coord]) 463 assert len(_for_key) == 1 464 output[_for_key[0]] = index 465 466 return output 467 468 # other 469 # ============================================================ 470 471 def summary(self) -> dict: 472 """returns a summary of the tokenization mode""" 473 return { 474 "tokenization_mode": self.tokenization_mode.value, 475 "max_grid_size": self.max_grid_size, 476 "vocab_size": self.vocab_size, 477 } 478 479 def is_AOTP(self) -> bool: 480 """returns true if a tokenization mode is Adjacency list, Origin, Target, Path""" 481 return self.tokenization_mode in ( 482 TokenizationMode.AOTP_UT_rasterized, 483 TokenizationMode.AOTP_UT_uniform, 484 TokenizationMode.AOTP_CTT_indexed, 485 ) 486 487 def is_UT(self) -> bool: 488 "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)" 489 return is_UT(self.tokenization_mode) 490 491 def clear_cache(self) -> None: 492 """clears all cached properties""" 493 # delete the properties only if they exist 494 for name, prop in self.__class__.__dict__.items(): 495 if isinstance(prop, cached_property): 496 # if 
the property exists, delete it 497 try: # noqa: SIM105 498 delattr(self, name) 499 except AttributeError: 500 pass
```python
class TokenizationMode(Enum):
	"""legacy tokenization modes

	> [!CAUTION]
	> Legacy mode of tokenization. Will still be around in future releases, but is no longer recommended for use.
	> Use `MazeTokenizerModular` instead.

	# Abbreviations:
	- `AOTP`: Adjacency list, Origin, Target, Path
	- `UT`: Unique Token (for each coordinate)
	- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

	# Modes:
	- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
		example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
	- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and 5x5 tokenization scheme are compatible
		uses `corner_first_ndindex` function to order the tokens
	- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
	"""

	AOTP_UT_rasterized = "AOTP_UT_rasterized"
	AOTP_UT_uniform = "AOTP_UT_uniform"
	AOTP_CTT_indexed = "AOTP_CTT_indexed"

	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
legacy tokenization modes

> [!CAUTION]
> Legacy mode of tokenization. Will still be around in future releases, but is no longer recommended for use.
> Use `MazeTokenizerModular` instead.

# Abbreviations:
- `AOTP`: Adjacency list, Origin, Target, Path
- `UT`: Unique Token (for each coordinate)
- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

# Modes:
- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization.
	Example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible;
	uses the `corner_first_ndindex` function to order the tokens
- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
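To make the difference between the two `UT` orderings concrete, here is a small sketch using the same helpers the module uses internally (`np.ndindex` and `corner_first_ndindex`); the exact return types and printed form may differ slightly:

```python
import numpy as np

from maze_dataset.utils import corner_first_ndindex

# AOTP_UT_rasterized: plain row-major (raster) order over the full grid
print(list(np.ndindex(3, 3)))

# AOTP_UT_uniform: corner-first order, intended so that the ordering for a
# 3x3 grid stays consistent with the ordering for a larger 5x5 grid
print(list(corner_first_ndindex(3, 2)))
print(list(corner_first_ndindex(5, 2)))
```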
```python
	def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
		"convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
		return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`
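A minimal sketch of going from a mode to a legacy tokenizer (import path taken from this page's module name):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import TokenizationMode

# build a legacy tokenizer directly from the mode
tokenizer = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5)
print(tokenizer.name)  # "maze_tokenizer-AOTP_UT_uniform-g5"
```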
Inherited Members:

- `enum.Enum`
	- `name`
	- `value`
```python
def is_UT(tokenization_mode: TokenizationMode) -> bool:
	"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
	return tokenization_mode in (
		TokenizationMode.AOTP_UT_rasterized,
		TokenizationMode.AOTP_UT_uniform,
	)
```
returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)
```python
def get_tokens_up_to_path_start(
	tokens: list[str],
	include_start_coord: bool = True,
	tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
	"""get tokens up to the path start token

	# Parameters:
	- `tokens : list[str]`
	- `include_start_coord : bool`
		(defaults to `True`)
	- `tokenization_mode : TokenizationMode`
		(defaults to `TokenizationMode.AOTP_UT_uniform`)

	# Returns:
	- `list[str]` subsequence of `tokens` up to the path start token

	# Raises:
	- `ValueError` : if `tokenization_mode` is invalid
	"""
	warnings.warn(
		"`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
		TokenizerPendingDeprecationWarning,
	)
	path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
	if include_start_coord:
		if is_UT(tokenization_mode):
			return tokens[: path_start_idx + 1]
		elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
			return tokens[: path_start_idx + 5]
		else:
			err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
			raise ValueError(err_msg)
	else:
		return tokens[:path_start_idx]
```
get tokens up to the path start token

# Parameters:
- `tokens : list[str]`
- `include_start_coord : bool`
	(defaults to `True`)
- `tokenization_mode : TokenizationMode`
	(defaults to `TokenizationMode.AOTP_UT_uniform`)

# Returns:
- `list[str]` subsequence of `tokens` up to the path start token

# Raises:
- `ValueError` : if `tokenization_mode` is invalid
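A sketch of slicing a token sequence at the path start. The token list below is a toy sequence built from the tokenizer itself; real sequences come from a maze's token representation. Note that the call emits a `TokenizerPendingDeprecationWarning`:

```python
from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	TokenizationMode,
	get_tokens_up_to_path_start,
)

tokenizer = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=3)

# toy sequence: a couple of coordinate tokens, then the path section
tokens: list[str] = [
	*tokenizer.coords_to_strings([(0, 0), (2, 2)]),
	SPECIAL_TOKENS.PATH_START,
	*tokenizer.coords_to_strings([(0, 0), (0, 1), (1, 1)]),
]

# everything up to and including the first coordinate of the path
prefix = get_tokens_up_to_path_start(tokens, include_start_coord=True)
print(prefix)
```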
```python
@serializable_dataclass(
	properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
	kw_only=True,
)
class MazeTokenizer(SerializableDataclass):
	"""LEGACY Tokenizer for mazes

	> [!CAUTION]
	> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
	> for use, but will remain for compatibility with existing code.

	# Parameters:
	- `tokenization_mode: TokenizationMode`
		mode of tokenization. required.
	- `max_grid_size: int | None`
		maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text

	# Properties
	- `name: str`
		auto-generated name of the tokenizer from mode and size

	## Conditional Properties

	- `node_strings_map: Mapping[CoordTup, str]`
		map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` for `UT` modes

	these all return `None` if `max_grid_size` is `None`.
	Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`

	- `token_arr: list[str]`
		list of tokens, in order of their indices in the vocabulary
	- `tokenizer_map: Mapping[str, int]`
		map from token to index
	- `vocab_size: int`
		size of the vocabulary
	- `padding_token_index: int`
		index of the padding token

	# Methods
	- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
		convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
	- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
		convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
	"""

	# parameters
	# ============================================================

	tokenization_mode: TokenizationMode = serializable_field(
		default=TokenizationMode.AOTP_UT_uniform,
		serialization_fn=lambda x: x.value,
		loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
	)

	max_grid_size: int | None = serializable_field(default=None)

	# private helpers (the public properties and methods built on these are
	# documented individually below)
	# ============================================================

	@cached_property
	def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
		"""map a coordinate to a token"""
		if self.tokenization_mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
		):
			return Kappa(_coord_to_strings_UT)
		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
			return Kappa(_coord_to_strings_indexed)
		else:
			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
			raise ValueError(err_msg)

	@cached_property
	def _token_arr(self) -> list[str]:
		"""map from index to token"""
		if self.max_grid_size is None:
			err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
			raise ValueError(err_msg)

		output: list[str] = list(SPECIAL_TOKENS.values())

		if self.tokenization_mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
		):
			output.extend(
				[
					self._node_strings_map[coord][0]
					for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
						self.max_grid_size,
					)
				],
			)
		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
			# TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
			output.extend(
				[
					"(",
					",",
					")",  # new special chars
					*map(str, range(self.max_grid_size)),  # numbers
				],
			)
		else:
			err_msg: str = (
				f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}",
			)
			raise ValueError(err_msg)

		return output

	@cached_property
	def _tokenizer_map(self) -> dict[str, int]:
		"""map from token to index"""
		return {token: i for i, token in enumerate(self._token_arr)}

	@property
	def _vocab_size(self) -> int:
		return len(self._token_arr)

	@property
	def _n_tokens(self) -> int:
		# TODO: deprecate
		return self._vocab_size

	@cached_property
	def _padding_token_index(self) -> int:
		return self.tokenizer_map[SPECIAL_TOKENS.PADDING]
```
LEGACY Tokenizer for mazes

> [!CAUTION]
> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
> for use, but will remain for compatibility with existing code.

# Parameters:
- `tokenization_mode: TokenizationMode`
	mode of tokenization. required.
- `max_grid_size: int | None`
	maximum grid size. required for actually turning text tokens into numerical tokens, but not for moving between coordinates/mazes and text

# Properties
- `name: str`
	auto-generated name of the tokenizer from mode and size

## Conditional Properties

- `node_strings_map: Mapping[CoordTup, str]`
	map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` for `UT` modes.

these all return `None` if `max_grid_size` is `None`.
Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`:

- `token_arr: list[str]`
	list of tokens, in order of their indices in the vocabulary
- `tokenizer_map: Mapping[str, int]`
	map from token to index
- `vocab_size: int`
	size of the vocabulary
- `padding_token_index: int`
	index of the padding token

# Methods
- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
	convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
	convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
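A minimal usage sketch (exact token strings and vocabulary contents depend on the mode and grid size):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=5,
)

print(tokenizer.name)           # "maze_tokenizer-AOTP_UT_uniform-g5"
print(tokenizer.summary())      # mode, max_grid_size, and vocab_size in one dict
print(tokenizer.token_arr[:5])  # special tokens come first, then coordinate tokens
```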
```python
	@property
	def name(self) -> str:
		"""auto-generated name of the tokenizer from mode and size"""
		max_grid_size_str: str = (
			f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
		)
		return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
```
auto-generated name of the tokenizer from mode and size
```python
	@cached_property
	def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
		"""map a coordinate to a token"""
		if self.tokenization_mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
		):
			return None
		else:
			return self._node_strings_map
```
map a coordinate to a token
```python
	@cached_property
	def token_arr(self) -> list[str] | None:
		"get the token array if the max_grid_size is specified"
		if self.max_grid_size is None:
			return None
		return self._token_arr
```
get the token array if the max_grid_size is specified
```python
	@cached_property
	def tokenizer_map(self) -> dict[str, int] | None:
		"get the tokenizer map if the max_grid_size is specified"
		if self.max_grid_size is None:
			return None
		return self._tokenizer_map
```
get the tokenizer map if the max_grid_size is specified
```python
	@property
	def vocab_size(self) -> int | None:
		"get the size of the vocabulary if the max_grid_size is specified"
		if self.max_grid_size is None:
			return None
		return self._vocab_size
```
get the size of the vocabulary if the max_grid_size is specified
```python
	@property
	def n_tokens(self) -> int | None:
		"get the number of tokens if the max_grid_size is specified"
		if self.max_grid_size is None:
			return None
		return self._n_tokens
```
get the number of tokens if the max_grid_size is specified
```python
	@cached_property
	def padding_token_index(self) -> int | None:
		"get the index of the padding token if it exists"
		if self.max_grid_size is None:
			return None
		return self._padding_token_index
```
get the index of the padding token if it exists
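All of the conditional properties above follow the same pattern: they return `None` when `max_grid_size` is unset, while the underscore-prefixed variants raise instead. A small sketch:

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

# without max_grid_size, the conditional properties return None
tok_unsized = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
assert tok_unsized.token_arr is None
assert tok_unsized.vocab_size is None

# with max_grid_size, the vocabulary is fully determined
tok_sized = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=3,
)
assert tok_sized.vocab_size == len(tok_sized.token_arr)
print(tok_sized.padding_token_index)
```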
```python
	def coords_to_strings(
		self,
		coords: list[CoordTup],
		when_noncoord: WhenMissing = "skip",
	) -> list[str]:
		"""map a list of coordinate tuples (and maybe other tokens) to strings

		wraps `maze_dataset.token_utils.coords_to_strings` with either
		`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
		"""
		if self.tokenization_mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
		):
			return coords_to_strings(
				coords=coords,
				coord_to_strings_func=_coord_to_strings_UT,
				when_noncoord=when_noncoord,
			)
		elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
			return coords_to_strings(
				coords=coords,
				coord_to_strings_func=_coord_to_strings_indexed,
				when_noncoord=when_noncoord,
			)
		else:
			err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
			raise ValueError(err_msg)
```
map a list of coordinate tuples (and maybe other tokens) to strings

wraps `maze_dataset.token_utils.coords_to_strings` with either
`_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
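A sketch of the `when_noncoord` options (exact coordinate token strings depend on the mode; `max_grid_size` is not needed for this direction):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)

mixed = [(0, 0), "not-a-coordinate", (1, 1)]
print(tokenizer.coords_to_strings(mixed, when_noncoord="skip"))     # drops the string
print(tokenizer.coords_to_strings(mixed, when_noncoord="include"))  # keeps it in the output
# the error variant raises on the non-coordinate entry instead
```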
```python
	@classmethod
	def strings_to_coords(
		cls,
		text: str | list[str],
		when_noncoord: WhenMissing = "skip",
	) -> list[str | CoordTup]:
		"wrapper for `maze_dataset.token_utils.strings_to_coords`"
		return strings_to_coords(text=text, when_noncoord=when_noncoord)
```
wrapper for `maze_dataset.token_utils.strings_to_coords`
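And the reverse direction; round-tripping through `coords_to_strings` avoids assuming any particular coordinate token format:

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)

coords = [(0, 0), (2, 1)]
tokens = tokenizer.coords_to_strings(coords)
# strings_to_coords is a classmethod and does not depend on the mode;
# this should give back the original coordinate tuples
print(MazeTokenizer.strings_to_coords(tokens))
```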
```python
	def encode(self, text: str | list[str]) -> list[int]:
		"""encode a string or list of strings into a list of tokens"""
		try:
			if isinstance(text, str):
				text = text.split()
			return [self.tokenizer_map[token] for token in text]
		except KeyError as e:
			err_msg: str = (
				f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
			)
			raise TokenError(err_msg) from e
```
encode a string or list of strings into a list of tokens
```python
	def decode(
		self,
		tokens: Sequence[int],
		joined_tokens: bool = False,
	) -> list[str] | str:
		"""decode a list of tokens into a string or list of strings"""
		try:
			output: list[str] = [self.token_arr[token] for token in tokens]
		except IndexError as e:
			err_msg: str = (
				f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
			)
			raise TokenError(err_msg) from e
		if joined_tokens:
			return " ".join(output)
		else:
			return output
```
decode a list of tokens into a string or list of strings
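A short sketch of the id round trip, including the joined-string form of `decode` (requires `max_grid_size` so that the vocabulary exists):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=3,
)

ids = tokenizer.encode(tokenizer.coords_to_strings([(0, 0), (1, 2)]))
print(ids)                                        # vocabulary indices
print(tokenizer.decode(ids))                      # back to a list of tokens
print(tokenizer.decode(ids, joined_tokens=True))  # or a single space-joined string
# unknown tokens raise TokenError from encode; out-of-range ids raise it from decode
```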
```python
	@cached_property
	def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
		"map of coordinate tuples to their token ids, only valid for UT"
		if not self.is_UT():
			err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
			raise ValueError(err_msg)

		if self.max_grid_size is None:
			err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
			raise ValueError(err_msg)

		raw_converted: list[CoordTup | str] = self.strings_to_coords(
			self.token_arr,
			when_noncoord="include",
		)

		# filter out non-coordinates
		return {
			coord: i
			for i, coord in enumerate(raw_converted)
			if not isinstance(coord, str)
		}
```
map of coordinate tuples to their token ids, only valid for UT
```python
	@cached_property
	def coordinate_tokens_ids(self) -> dict[str, int]:
		"map of coordinate tokens to their token ids, only valid for UT"
		# checks performed in call
		output: dict[str, int] = dict()

		for coord, index in self.coordinate_tokens_coords.items():
			_for_key: list[str] = self.coords_to_strings([coord])
			assert len(_for_key) == 1
			output[_for_key[0]] = index

		return output
```
map of coordinate tokens to their token ids, only valid for UT
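A sketch of the UT-only coordinate lookups for a tiny grid; both raise `ValueError` in CTT mode or when `max_grid_size` is unset:

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_UT_uniform,
	max_grid_size=2,
)

print(tokenizer.coordinate_tokens_coords)  # {(row, col): token_id, ...}
print(tokenizer.coordinate_tokens_ids)     # {coordinate token string: token_id, ...}
```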
```python
	def summary(self) -> dict:
		"""returns a summary of the tokenization mode"""
		return {
			"tokenization_mode": self.tokenization_mode.value,
			"max_grid_size": self.max_grid_size,
			"vocab_size": self.vocab_size,
		}
```
returns a summary of the tokenization mode
```python
	def is_AOTP(self) -> bool:
		"""returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
		return self.tokenization_mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
			TokenizationMode.AOTP_CTT_indexed,
		)
```
returns true if a tokenization mode is Adjacency list, Origin, Target, Path
```python
	def is_UT(self) -> bool:
		"returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
		return is_UT(self.tokenization_mode)
```
returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)
```python
	def clear_cache(self) -> None:
		"""clears all cached properties"""
		# delete the properties only if they exist
		for name, prop in self.__class__.__dict__.items():
			if isinstance(prop, cached_property):
				# if the property exists, delete it
				try:  # noqa: SIM105
					delattr(self, name)
				except AttributeError:
					pass
```
clears all cached properties
```python
	# inherited from muutils.json_serialize.serializable_dataclass.SerializableDataclass
	def serialize(self) -> dict[str, Any]: ...
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
	# inherited from muutils.json_serialize.serializable_dataclass.SerializableDataclass
	@classmethod
	def load(cls, data: dict[str, Any] | T) -> Type[T]: ...
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
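Since `MazeTokenizer` is a `SerializableDataclass`, a config round trip looks roughly like this (a sketch; the serialized dict also contains the read-only properties listed in `_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE`):

```python
from maze_dataset.tokenization.maze_tokenizer_legacy import (
	MazeTokenizer,
	TokenizationMode,
)

tokenizer = MazeTokenizer(
	tokenization_mode=TokenizationMode.AOTP_CTT_indexed,
	max_grid_size=4,
)

data = tokenizer.serialize()        # plain dict, suitable for JSON storage
restored = MazeTokenizer.load(data)

assert restored.tokenization_mode == tokenizer.tokenization_mode
assert restored.max_grid_size == tokenizer.max_grid_size
```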
```python
def SerializableDataclass__validate_fields_types(
	self: SerializableDataclass,
	on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
	"""validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
	return all(
		SerializableDataclass__validate_fields_types__dict(
			self, on_typecheck_error=on_typecheck_error
		).values()
	)
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members:

- `muutils.json_serialize.serializable_dataclass.SerializableDataclass`
	- `validate_field_type`
	- `diff`
	- `update_from_nested_dict`