maze_dataset.tokenization

turning a maze into text

- `MazeTokenizerModular` is the new recommended way to do this as of 1.0.0
- legacy `TokenizationMode` enum and `MazeTokenizer` class for supporting existing code
- a variety of helper classes and functions
There are many algorithms by which one might tokenize a 2D maze into a 1D format usable by autoregressive text models. Training multiple models on the encodings output by each of these algorithms may produce very different internal representations, learned solution algorithms, and levels of performance. To explore how different maze tokenization algorithms affect these models, the `MazeTokenizerModular` class contains a rich set of options to customize how mazes are stringified. This class contains 19 discrete parameters, resulting in 5.9 million unique tokenizers. But wait, there's more! There are 6 additional parameters available in the library which are untested but further expand the number of tokenizers by a factor of $44/3$, to 86 million.
All output sequences consist of four token regions representing different features of the maze. These regions are distinguished by color in the figure below.
- Adjacency list: A text representation of the lattice graph
- Origin: Starting coordinate
- Target: Ending coordinate
- Path: Maze solution sequence from the start to the end
Each `MazeTokenizerModular` is constructed from a set of several `_TokenizerElement` objects, each of which specifies how different token regions or other elements of the stringification are produced.
Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named `pre`, `intra`, and `post` in various `_TokenizerElement` classes. Each option controls a unique delimiter token.
Here we describe each `_TokenizerElement` and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options.
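As a concrete starting point, the sketch below builds the default tokenizer and inspects its constituent elements. It uses only the `MazeTokenizerModular` API shown in the source later on this page (`name`, `summary`, `tokenizer_elements`):

```python
from maze_dataset.tokenization import MazeTokenizerModular

# default tokenizer: equivalent to the legacy TokenizationMode.AOTP_UT_uniform
tokenizer = MazeTokenizerModular()

# one-line name encoding the full _TokenizerElement configuration
print(tokenizer.name)

# flat mapping from each _TokenizerElement attribute to the chosen element
for key, elem_name in tokenizer.summary().items():
    print(f"{key}: {elem_name}")
```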
### Coordinates
The `_CoordTokenizer` object controls how coordinates in the lattice are represented across all token regions. Options include:
- **Unique tokens**: Each coordinate is represented as a single unique token `"(i,j)"`
- **Coordinate tuple tokens**: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: `["i", ",", "j"]`
### Adjacency List
The `_AdjListTokenizer` object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice.
- `_EdgeSubset`: Specifies the subset of lattice edges to be tokenized
  - **All edges**: Every edge in the lattice
  - **Connections**: Only edges which contain a connection
  - **Walls**: Only edges which contain a wall
- `_EdgePermuter`: Specifies how to sequence the two coordinates in each lattice edge
  - **Random**
  - **Sorted**: The smaller coordinate always comes first
  - **Both permutations**: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on whether the coordinate sequence in the path matches the sequence in the adjacency list.
- `shuffle_d0`: Whether to shuffle the edges randomly or sort them in the output by their first coordinate
- `connection_token_ordinal`: Location in the sequence of the token representing whether the edge is a connection or a wall
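The adjacency-list options a given tokenizer uses are easiest to see by inspecting its elements; a sketch using only the `summary` method documented below (the substring filter is a heuristic, since the exact attribute keys depend on the library version):

```python
from maze_dataset.tokenization import MazeTokenizerModular

tokenizer = MazeTokenizerModular()

# which _AdjListTokenizer (and nested _EdgeSubset / _EdgePermuter) is in use?
for key, elem_name in tokenizer.summary().items():
    if "adj" in key.lower() or "edge" in key.lower():
        print(f"{key}: {elem_name}")
```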
### Path
The `_PathTokenizer` object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position.
- `_StepSize`: Specifies the size of each step
  - **Singles**: Every coordinate traversed between start and end is directly represented
  - **Forks**: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path.
- `_StepTokenizer`: Specifies how an individual step is represented
  - **Coordinate**: The coordinates of each step are directly tokenized using a `_CoordTokenizer`
  - **Cardinal direction**: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., `NORTH`, `SOUTH`. If using a `_StepSize` other than **Singles**, this direction may not correspond to the final direction traveled to arrive at the end position of the step.
  - **Relative direction**: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., `RIGHT`, `LEFT`.
  - **Distance**: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a `_StepSize` of **Singles**, the **Distance** token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a `_StepSize` other than **Singles**.
A `_PathTokenizer` contains a sequence of one or more unique `_StepTokenizer` objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once.
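A sketch of configuring such a combination, assuming concrete element names along the lines of `PathTokenizers.StepSequence`, `StepSizes.Singles`, and `StepTokenizers.Coord`/`StepTokenizers.Cardinal` (treat these names as hypothetical; check `maze_dataset.tokenization.modular.elements` for the exact classes exported by your version):

```python
# hypothetical element names; the general pattern (a _PathTokenizer holding a
# _StepSize plus a sequence of _StepTokenizers) is what this page describes
from maze_dataset.tokenization.modular.elements import (
    PathTokenizers,
    StepSizes,
    StepTokenizers,
)

path_tokenizer = PathTokenizers.StepSequence(  # hypothetical constructor
    step_size=StepSizes.Singles(),
    step_tokenizers=(
        StepTokenizers.Coord(),     # each step emits its coordinate...
        StepTokenizers.Cardinal(),  # ...followed by its cardinal direction
    ),
)
```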
## Tokenized Outputs for Training and Evaluation {#token-training}
During deployment we provide only the prompt up to the `<PATH_START>` token.

Examples of usage of this dataset to train autoregressive transformers can be found in our `maze-transformer` library [@maze-transformer-github]. Other tokenization and vocabulary schemes are also included, such as representing each coordinate as a pair of $i,j$ index tokens.
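The module exports a `get_tokens_up_to_path_start` helper for this; a minimal hand-rolled equivalent, assuming `tokens` is the output of `MazeTokenizerModular.to_tokens` on a solved maze (the example token sequence is illustrative):

```python
def prompt_up_to_path_start(tokens: list[str]) -> list[str]:
    """Cut a full token sequence down to the deployment prompt,
    keeping everything up to and including the <PATH_START> token."""
    return tokens[: tokens.index("<PATH_START>") + 1]

example = ["<ADJLIST_START>", "...", "<PATH_START>", "(0,0)", "<PATH_END>"]
assert prompt_up_to_path_start(example)[-1] == "<PATH_START>"
```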
## Extensibility
The tokenizer architecture is purposefully designed such that adding and testing a wide variety of new tokenization algorithms is fast and minimizes disturbances to functioning code. This is enabled by the modular architecture and the automatic inclusion of any new tokenizers in integration tests. To create a new tokenizer, developers forking the library may simply create their own `_TokenizerElement` subclass and implement the abstract methods. If the behavior change is sufficiently small, adding a parameter to an existing `_TokenizerElement` subclass and updating its implementation will suffice; for such small additions, new cases in the existing unit tests are all the testing required.
The breadth of tokenizers is also easily scaled in the opposite direction. Due to the exponential scaling of parameter combinations, adding a small number of new features can significantly slow certain procedures which rely on constructing all possible tokenizers, such as integration tests. If any existing subclass contains features which aren't needed, a developer tool decorator is provided which can be applied to the unneeded `_TokenizerElement` subclasses to prune those features and compact the available space of tokenizers.
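A sketch of the extension pattern, using a hypothetical element; the actual abstract methods to implement are defined on `_TokenizerElement` in `maze_dataset.tokenization.modular.element_base`, so treat the method names and signatures below as illustrative assumptions:

```python
from maze_dataset.tokenization.modular.elements import CoordTokenizers

class ReversedCTT(CoordTokenizers._CoordTokenizer):  # hypothetical subclass
    """illustrative only: a coordinate tokenizer emitting column before row"""

    def to_tokens(self, coord) -> list[str]:  # signature assumed
        i, j = coord
        return ["(", str(j), ",", str(i), ")"]

    def is_valid(self, do_except: bool = False) -> bool:
        return True  # this element has no invalid configurations
```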
1"""turning a maze into text 2 3- `MazeTokenizerModular` is the new recommended way to do this as of 1.0.0 4- legacy `TokenizationMode` enum and `MazeTokenizer` class for supporting existing code 5- a variety of helper classes and functions 6 7There are many algorithms by which one might tokenize a 2D maze into a 1D format usable by autoregressive text models. Training multiple models on the encodings output from each of these algorithms may produce very different internal representations, learned solution algorithms, and levels of performance. To explore how different maze tokenization algorithms affect these models, the `MazeTokenizerModular` class contains a rich set of options to customize how mazes are stringified. This class contains 19 discrete parameters, resulting in 5.9 million unique tokenizers. But wait, there's more! There are 6 additional parameters available in the library which are untested but further expand the the number of tokenizers by a factor of $44/3$ to 86 million. 8 9All output sequences consist of four token regions representing different features of the maze. These regions are distinguished by color in Figure below. 10 11- <span style="background-color:rgb(217,210,233)">Adjacency list</span>: A text representation of the lattice graph 12- <span style="background-color:rgb(217,234,211)">Origin</span>: Starting coordinate 13- <span style="background-color:rgb(234,209,220)">Target</span>: Ending coordinate 14- <span style="background-color:rgb(207,226,243)">Path</span>: Maze solution sequence from the start to the end 15 16 17 18Each `MazeTokenizerModular` is constructed from a set of several `_TokenizerElement` objects, each of which specifies how different token regions or other elements of the stringification are produced. 19 20 21 22Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named `pre`, `intra`, and `post` in various `_TokenizerElement` classes. Each option controls a unique delimiter token. 23Here we describe each `_TokenizerElement` and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options. 24 25### Coordinates 26 27The `_CoordTokenizer` object controls how coordinates in the lattice are represented in across all token regions. Options include: 28 29- **Unique tokens**: Each coordinate is represented as a single unique token `"(i,j)"` 30- **Coordinate tuple tokens**: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: `["i", ",", "j"]` 31 32### Adjacency List 33 34The `_AdjListTokenizer` object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice. 35 36- `_EdgeSubset`: Specifies the subset of lattice edges to be tokenized 37 - **All edges**: Every edge in the lattice 38 - **Connections**: Only edges which contain a connection 39 - **Walls**: Only edges which contain a wall 40- `_EdgePermuter`: Specifies how to sequence the two coordinates in each lattice edge 41 - **Random** 42 - **Sorted**: The smaller coordinate always comes first 43 - **Both permutations**: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. 
Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on if the coordinate sequence in the path matches the sequence in the adjacency list. 44- `shuffle_d0`: Whether to shuffle the edges randomly or sort them in the output by their first coordinate 45- `connection_token_ordinal`: Location in the sequence of the token representing whether the edge is a connection or a wall 46 47### Path 48 49The `_PathTokenizer` object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position. 50 51- `_StepSize`: Specifies the size of each step 52 - **Singles**: Every coordinate traversed between start and end is directly represented 53 - **Forks**: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path. 54- `_StepTokenizer`: Specifies how an individual step is represented 55 - **Coordinate**: The coordinates of each step are directly tokenized using a `_CoordTokenizer` 56 - **Cardinal direction**: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., `NORTH`, `SOUTH`. If using a `_StepSize` other than **Singles**, this direction may not correspond to the final direction traveled to arrive at the end position of the step. 57 - **Relative direction**: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., `RIGHT`, `LEFT`. 58 - **Distance**: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a `_StepSize` of **Singles**, the **Distance** token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a `_StepSize` other than **Singles**. 59 60A `_PathTokenizer` contains a sequence of one or more unique `_StepTokenizer` objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once. 61 62## Tokenized Outputs for Training and Evaluation {#token-training} 63 64During deployment we provide only the prompt up to the `<PATH_START>` token. 65 66Examples of usage of this dataset to train autoregressive transformers can be found in our `maze-transformer` library [@maze-transformer-github]. Other tokenization and vocabulary schemes are also included, such as representing each coordinate as a pair of $i,j$ index tokens. 67 68## Extensibility 69 70The tokenizer architecture is purposefully designed such that adding and testing a wide variety of new tokenization algorithms is fast and minimizes disturbances to functioning code. This is enabled by the modular architecture and the automatic inclusion of any new tokenizers in integration tests. To create a new tokenizer, developers forking the library may simply create their own `_TokenizerElement` subclass and implement the abstract methods. If the behavior change is sufficiently small, simply adding a parameter to an existing `_TokenizerElement` subclass and updating its implementation will suffice. For small additions, simply adding new cases to existing unit tests will suffice. 
71 72The breadth of tokenizers is also easily scaled in the opposite direction. Due to the exponential scaling of parameter combinations, adding a small number of new features can significantly slow certain procedures which rely on constructing all possible tokenizers, such as integration tests. If any existing subclass contains features which aren't needed, a developer tool decorator is provided which can be applied to the unneeded `_TokenizerElement` subclasses to prune those features and compact the available space of tokenizers. 73 74""" 75 76from maze_dataset.tokenization.maze_tokenizer_legacy import ( 77 MazeTokenizer, 78 TokenizationMode, 79 get_tokens_up_to_path_start, 80) 81from maze_dataset.tokenization.modular.element_base import _TokenizerElement 82from maze_dataset.tokenization.modular.elements import ( 83 AdjListTokenizers, 84 CoordTokenizers, 85 EdgeGroupings, 86 EdgePermuters, 87 EdgeSubsets, 88 PathTokenizers, 89 PromptSequencers, 90 StepSizes, 91 StepTokenizers, 92 TargetTokenizers, 93) 94from maze_dataset.tokenization.modular.maze_tokenizer_modular import ( 95 MazeTokenizerModular, 96) 97 98# we don't sort alphabetically on purpose, we sort by the type 99__all__ = [ 100 # submodules 101 "modular", 102 "common", 103 "maze_tokenizer_legacy", 104 "maze_tokenizer", 105 # legacy tokenizer 106 "MazeTokenizer", 107 "TokenizationMode", 108 # MMT 109 "MazeTokenizerModular", 110 # element base 111 "_TokenizerElement", 112 # elements 113 "PromptSequencers", 114 "CoordTokenizers", 115 "AdjListTokenizers", 116 "EdgeGroupings", 117 "EdgePermuters", 118 "EdgeSubsets", 119 "TargetTokenizers", 120 "StepSizes", 121 "StepTokenizers", 122 "PathTokenizers", 123 # helpers 124 "get_tokens_up_to_path_start", 125]
```python
@serializable_dataclass(
    properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
    kw_only=True,
)
class MazeTokenizer(SerializableDataclass):
    """LEGACY Tokenizer for mazes"""  # full docstring rendered below

    # parameters
    # ============================================================

    tokenization_mode: TokenizationMode = serializable_field(
        default=TokenizationMode.AOTP_UT_uniform,
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
    )

    max_grid_size: int | None = serializable_field(default=None)

    # properties
    # ============================================================

    @property
    def name(self) -> str:
        """auto-generated name of the tokenizer from mode and size"""
        max_grid_size_str: str = (
            f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
        )
        return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

    @cached_property
    def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return Kappa(_coord_to_strings_UT)
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return Kappa(_coord_to_strings_indexed)
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

    @cached_property
    def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return None
        else:
            return self._node_strings_map

    # conditional properties (on max_grid_size existing)
    # ------------------------------------------------------------

    @cached_property
    def _token_arr(self) -> list[str]:
        """map from index to token"""
        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
            raise ValueError(err_msg)

        output: list[str] = list(SPECIAL_TOKENS.values())

        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            output.extend(
                [
                    self._node_strings_map[coord][0]
                    for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
                        self.max_grid_size,
                    )
                ],
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
            output.extend(
                [
                    "(",
                    ",",
                    ")",  # new special chars
                    *map(str, range(self.max_grid_size)),  # numbers
                ],
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

        return output

    @cached_property
    def token_arr(self) -> list[str] | None:
        "get the token array if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._token_arr

    @cached_property
    def _tokenizer_map(self) -> dict[str, int]:
        """map from token to index"""
        return {token: i for i, token in enumerate(self._token_arr)}

    @cached_property
    def tokenizer_map(self) -> dict[str, int] | None:
        "get the tokenizer map if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._tokenizer_map

    @property
    def _vocab_size(self) -> int:
        return len(self._token_arr)

    @property
    def vocab_size(self) -> int | None:
        "get the size of the vocabulary if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._vocab_size

    @property
    def _n_tokens(self) -> int:
        # TODO: deprecate
        return self._vocab_size

    @property
    def n_tokens(self) -> int | None:
        "get the number of tokens if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._n_tokens

    @cached_property
    def _padding_token_index(self) -> int:
        return self.tokenizer_map[SPECIAL_TOKENS.PADDING]

    @cached_property
    def padding_token_index(self) -> int | None:
        "get the index of the padding token if it exists"
        if self.max_grid_size is None:
            return None
        return self._padding_token_index

    # conversion functions
    # ============================================================

    @overload
    def coords_to_strings(
        self,
        coords: list[str | CoordTup],
        when_noncoord: Literal["include", "skip"] = "skip",
    ) -> list[str]: ...
    @overload
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: Literal["error"] = "error",
    ) -> list[str]: ...
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str]:
        """map a list of coordinate tuples (and maybe other tokens) to strings

        wraps `maze_dataset.token_utils.coords_to_strings` with either
        `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
        """
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_UT,
                when_noncoord=when_noncoord,
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_indexed,
                when_noncoord=when_noncoord,
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        "wrapper for `maze_dataset.token_utils.strings_to_coords`"
        return strings_to_coords(text=text, when_noncoord=when_noncoord)

    def encode(self, text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of tokens"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [self.tokenizer_map[token] for token in text]
        except KeyError as e:
            err_msg: str = (
                f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
            )
            raise TokenError(err_msg) from e

    def decode(
        self,
        tokens: Sequence[int],
        joined_tokens: bool = False,
    ) -> list[str] | str:
        """decode a list of tokens into a string or list of strings"""
        try:
            output: list[str] = [self.token_arr[token] for token in tokens]
        except IndexError as e:
            err_msg: str = (
                f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
            )
            raise TokenError(err_msg) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output

    # UT-only coordinate stuff
    # ============================================================

    @cached_property
    def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
        "map of coordinate tuples to their token ids, only valid for UT"
        if not self.is_UT():
            err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
            raise ValueError(err_msg)

        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
            raise ValueError(err_msg)

        raw_converted: list[CoordTup | str] = self.strings_to_coords(
            self.token_arr,
            when_noncoord="include",
        )

        # filter out non-coordinates
        return {
            coord: i
            for i, coord in enumerate(raw_converted)
            if not isinstance(coord, str)
        }

    @cached_property
    def coordinate_tokens_ids(self) -> dict[str, int]:
        "map of coordinate tokens to their token ids, only valid for UT"
        # checks performed in call
        output: dict[str, int] = dict()

        for coord, index in self.coordinate_tokens_coords.items():
            _for_key: list[str] = self.coords_to_strings([coord])
            assert len(_for_key) == 1
            output[_for_key[0]] = index

        return output

    # other
    # ============================================================

    def summary(self) -> dict:
        """returns a summary of the tokenization mode"""
        return {
            "tokenization_mode": self.tokenization_mode.value,
            "max_grid_size": self.max_grid_size,
            "vocab_size": self.vocab_size,
        }

    def is_AOTP(self) -> bool:
        """returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
        return self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
            TokenizationMode.AOTP_CTT_indexed,
        )

    def is_UT(self) -> bool:
        "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
        return is_UT(self.tokenization_mode)

    def clear_cache(self) -> None:
        """clears all cached properties"""
        # delete the cached properties only if they exist
        for name, prop in self.__class__.__dict__.items():
            if isinstance(prop, cached_property):
                try:  # noqa: SIM105
                    delattr(self, name)
                except AttributeError:
                    pass
```
LEGACY Tokenizer for mazes

> [!CAUTION]
> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended for use, but will remain for compatibility with existing code.

Parameters:
- `tokenization_mode: TokenizationMode`
  mode of tokenization. required.
- `max_grid_size: int | None`
  maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text

Properties
- `name: str`
  auto-generated name of the tokenizer from mode and size

Conditional Properties

- `node_strings_map: Mapping[CoordTup, str]`
  map from node to string. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode

these all return `None` if `max_grid_size` is `None`.
Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None`

- `token_arr: list[str]`
  list of tokens, in order of their indices in the vocabulary
- `tokenizer_map: Mapping[str, int]`
  map from token to index
- `vocab_size: int`
  size of the vocabulary
- `padding_token_index: int`
  index of the padding token

Methods
- `coords_to_strings(coords: list[CoordTup]) -> list[str]`
  convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates
- `strings_to_coords(strings: list[str]) -> list[CoordTup]`
  convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates
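A short usage sketch of the legacy tokenizer, using only the methods listed above (the exact token strings are illustrative):

```python
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode

tok = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,
    max_grid_size=5,  # required for text <-> token-id conversion
)

strings = tok.coords_to_strings([(0, 0), (0, 1)])  # e.g. ["(0,0)", "(0,1)"]
ids = tok.encode(strings)                          # indices into tok.token_arr
assert tok.decode(ids) == strings
assert tok.strings_to_coords(strings) == [(0, 0), (0, 1)]
```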
```python
@property
def name(self) -> str:
    """auto-generated name of the tokenizer from mode and size"""
    max_grid_size_str: str = (
        f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
    )
    return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"
```
auto-generated name of the tokenizer from mode and size
```python
@cached_property
def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
    """map a coordinate to a token"""
    if self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    ):
        return None
    else:
        return self._node_strings_map
```
map a coordinate to a token
```python
@cached_property
def token_arr(self) -> list[str] | None:
    "get the token array if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._token_arr
```
get the token array if the max_grid_size is specified
```python
@cached_property
def tokenizer_map(self) -> dict[str, int] | None:
    "get the tokenizer map if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._tokenizer_map
```
get the tokenizer map if the max_grid_size is specified
```python
@property
def vocab_size(self) -> int | None:
    "get the size of the vocabulary if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._vocab_size
```
get the size of the vocabulary if the max_grid_size is specified
```python
@property
def n_tokens(self) -> int | None:
    "get the number of tokens if the max_grid_size is specified"
    if self.max_grid_size is None:
        return None
    return self._n_tokens
```
get the number of tokens if the max_grid_size is specified
```python
@cached_property
def padding_token_index(self) -> int | None:
    "get the index of the padding token if it exists"
    if self.max_grid_size is None:
        return None
    return self._padding_token_index
```
get the index of the padding token if it exists
```python
def coords_to_strings(
    self,
    coords: list[CoordTup],
    when_noncoord: WhenMissing = "skip",
) -> list[str]:
    """map a list of coordinate tuples (and maybe other tokens) to strings

    wraps `maze_dataset.token_utils.coords_to_strings` with either
    `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
    """
    if self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    ):
        return coords_to_strings(
            coords=coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord=when_noncoord,
        )
    elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
        return coords_to_strings(
            coords=coords,
            coord_to_strings_func=_coord_to_strings_indexed,
            when_noncoord=when_noncoord,
        )
    else:
        err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
        raise ValueError(err_msg)
```
map a list of coordinate tuples (and maybe other tokens) to strings

wraps `maze_dataset.token_utils.coords_to_strings` with either `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
```python
@classmethod
def strings_to_coords(
    cls,
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    "wrapper for `maze_dataset.token_utils.strings_to_coords`"
    return strings_to_coords(text=text, when_noncoord=when_noncoord)
```
wrapper for `maze_dataset.token_utils.strings_to_coords`
```python
def encode(self, text: str | list[str]) -> list[int]:
    """encode a string or list of strings into a list of tokens"""
    try:
        if isinstance(text, str):
            text = text.split()
        return [self.tokenizer_map[token] for token in text]
    except KeyError as e:
        err_msg: str = (
            f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
        )
        raise TokenError(err_msg) from e
```
encode a string or list of strings into a list of tokens
```python
def decode(
    self,
    tokens: Sequence[int],
    joined_tokens: bool = False,
) -> list[str] | str:
    """decode a list of tokens into a string or list of strings"""
    try:
        output: list[str] = [self.token_arr[token] for token in tokens]
    except IndexError as e:
        err_msg: str = (
            f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
        )
        raise TokenError(err_msg) from e
    if joined_tokens:
        return " ".join(output)
    else:
        return output
```
decode a list of tokens into a string or list of strings
```python
@cached_property
def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
    "map of coordinate tuples to their token ids, only valid for UT"
    if not self.is_UT():
        err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
        raise ValueError(err_msg)

    if self.max_grid_size is None:
        err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
        raise ValueError(err_msg)

    raw_converted: list[CoordTup | str] = self.strings_to_coords(
        self.token_arr,
        when_noncoord="include",
    )

    # filter out non-coordinates
    return {
        coord: i
        for i, coord in enumerate(raw_converted)
        if not isinstance(coord, str)
    }
```
map of coordinate tuples to their token ids, only valid for UT
```python
@cached_property
def coordinate_tokens_ids(self) -> dict[str, int]:
    "map of coordinate tokens to their token ids, only valid for UT"
    # checks performed in call
    output: dict[str, int] = dict()

    for coord, index in self.coordinate_tokens_coords.items():
        _for_key: list[str] = self.coords_to_strings([coord])
        assert len(_for_key) == 1
        output[_for_key[0]] = index

    return output
```
map of coordinate tokens to their token ids, only valid for UT
```python
def summary(self) -> dict:
    """returns a summary of the tokenization mode"""
    return {
        "tokenization_mode": self.tokenization_mode.value,
        "max_grid_size": self.max_grid_size,
        "vocab_size": self.vocab_size,
    }
```
returns a summary of the tokenization mode
```python
def is_AOTP(self) -> bool:
    """returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
    return self.tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    )
```
returns true if a tokenization mode is Adjacency list, Origin, Target, Path
```python
def is_UT(self) -> bool:
    "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
    return is_UT(self.tokenization_mode)
```
returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)
```python
def clear_cache(self) -> None:
    """clears all cached properties"""
    # delete the cached properties only if they exist
    for name, prop in self.__class__.__dict__.items():
        if isinstance(prop, cached_property):
            try:  # noqa: SIM105
                delattr(self, name)
            except AttributeError:
                pass
```
clears all cached properties
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class TokenizationMode(Enum):
    """legacy tokenization modes

    > [!CAUTION]
    > Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use.
    > Use `MazeTokenizerModular` instead.

    # Abbreviations:
    - `AOTP`: Adjacency list, Origin, Target, Path
    - `UT`: Unique Token (for each coordinate)
    - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

    # Modes:
    - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization
        example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
    - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible
        uses the `corner_first_ndindex` function to order the tokens
    - `AOTP_CTT_indexed`: each coordinate is a tuple of integers
    """

    AOTP_UT_rasterized = "AOTP_UT_rasterized"
    AOTP_UT_uniform = "AOTP_UT_uniform"
    AOTP_CTT_indexed = "AOTP_CTT_indexed"

    def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
        "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
        return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
legacy tokenization modes

> [!CAUTION]
> Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. Use `MazeTokenizerModular` instead.

Abbreviations:
- `AOTP`: Adjacency list, Origin, Target, Path
- `UT`: Unique Token (for each coordinate)
- `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers)

Modes:
- `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization. example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)`
- `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and a 5x5 tokenization scheme are compatible. uses the `corner_first_ndindex` function to order the tokens
- `AOTP_CTT_indexed`: each coordinate is a tuple of integers
```python
def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer":
    "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`"
    return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size)
```
convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`
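For instance, a mode can be turned directly into a ready-to-use legacy tokenizer:

```python
from maze_dataset.tokenization import TokenizationMode

tok = TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(max_grid_size=5)
assert tok.max_grid_size == 5
```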
Inherited Members
- enum.Enum
- name
- value
```python
@serializable_dataclass(
    frozen=True,
    kw_only=True,
    properties_to_serialize=["tokenizer_element_tree_concrete", "name"],
)
class MazeTokenizerModular(SerializableDataclass):
    """Tokenizer for mazes

    # Parameters
    - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

    # Development
    - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
    - Furthermore, the mapping reflected in `from_legacy` must also be maintained.
    - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
    """

    prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
        default=PromptSequencers.AOTP(),
        loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
    )

    def hash_int(self) -> int:
        "return integer hash using blake2b"
        return _hash_tokenizer_name(self.name)

    def __hash__(self) -> int:
        "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
        return self.hash_int()

    def hash_b64(self, n_bytes: int = 8) -> str:
        """filename-safe base64 encoding of the hash"""
        # Use modulus to ensure the integer fits within n_bytes * 8 bits
        hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))

        encoded = base64.b64encode(
            hash_mod.to_bytes(n_bytes, byteorder="big"),
            altchars=b"-_",
        ).decode()

        # Remove any padding equals signs
        return encoded.rstrip("=")

    # Information Querying Methods

    @cached_property
    def tokenizer_elements(self) -> list[_TokenizerElement]:
        "returns a list of all the elements of this tokenizer"
        return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]

    def tokenizer_element_tree(self, abstract: bool = False) -> str:
        """Returns a string representation of the tree of tokenizer elements contained in `self`.

        # Parameters
        - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
        """
        return "\n".join(
            [
                type(self).__name__,
                self.prompt_sequencer.tokenizer_element_tree(
                    abstract=abstract,
                    depth=1,
                ),
            ],
        )

    @property
    def tokenizer_element_tree_concrete(self) -> str:
        """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
        return self.tokenizer_element_tree()

    def tokenizer_element_dict(self) -> dict:
        """Nested dictionary of the internal `TokenizerElement`s."""
        return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}

    @property
    def name(self) -> str:
        """Serializes MazeTokenizer into a key for encoding in zanj"""
        return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002

    def summary(self) -> dict[str, str]:
        """Single-level dictionary of the internal `TokenizerElement`s."""
        return {
            **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
        }

    @staticmethod
    def _type_check(obj: any) -> None:
        """Helper method for `has_element`"""
        if not (
            isinstance(obj, _TokenizerElement)
            or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
        ):
            err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
            raise TypeError(err_msg)

    def _has_element_singular(
        self,
        el: type[_TokenizerElement] | _TokenizerElement,
    ) -> bool:
        """Helper method for `has_element`"""
        self._type_check(el)
        if isinstance(el, type):
            return any(isinstance(e, el) for e in self.tokenizer_elements)
        else:
            return el in self.tokenizer_elements

    def has_element(
        self,
        *elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
    ) -> bool:
        """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

        Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
        To do such a query, assemble multiple calls to `has_element`.

        # Parameters
        - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
        If an instance is provided, then comparison is done via instance equality.
        If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
        """
        if len(elements) == 1 and isinstance(elements[0], Iterable):
            elements = elements[0]
        return all(self._has_element_singular(e) for e in elements)

    def is_valid(self, do_except: bool = False) -> bool:
        """Returns `True` if `self` is a valid tokenizer.

        Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
        """
        return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)

    def is_legacy_equivalent(self) -> bool:
        """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
        return any(
            self == MazeTokenizerModular.from_legacy(tok_mode)
            for tok_mode in TokenizationMode
        )

    def is_tested_tokenizer(self, do_except: bool = False) -> bool:
        """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

        uses an fst on the `name` attributes of all the tokenizers

        if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
        """
        is_valid: bool = self.is_valid(do_except=do_except)
        in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except)

        if do_except:
            assert is_valid, "self.is_valid returns False"
            return True
        else:
            return in_tested_fst and is_valid

    def is_AOTP(self) -> bool:
        "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path"
        return self.has_element(PromptSequencers.AOTP)

    def is_UT(self) -> bool:
        "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)"
        return self.has_element(CoordTokenizers.UT)

    # Alternate Constructors
    # ======================

    @classmethod
    def from_legacy(
        cls,
        legacy_maze_tokenizer: MazeTokenizer | TokenizationMode,
    ) -> "MazeTokenizerModular":
        """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance."""
        if isinstance(legacy_maze_tokenizer, MazeTokenizer):
            legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode
        return {
            TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(),
            TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(),
            TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular(
                prompt_sequencer=PromptSequencers.AOTP(
                    coord_tokenizer=CoordTokenizers.CTT(),
                ),
            ),
        }[legacy_maze_tokenizer]

    # Simple properties
    # =================
    @classmethod
    def from_tokens(
        cls,
        tokens: str | list[str],
    ) -> "MazeTokenizerModular":
        """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens."""
        raise NotImplementedError(
            "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported",
        )

    @property
    def token_arr(self) -> list[str] | None:
        """map from index to token"""
        return VOCAB_LIST

    @property
    def tokenizer_map(self) -> dict[str, int]:
        """map from token to index"""
        return VOCAB_TOKEN_TO_INDEX

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the static vocab"""
        return len(VOCAB_LIST)

    @property
    def n_tokens(self) -> int:
        "get the number of tokens in the vocabulary (deprecated)"
        err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead."
        raise NameError(err_msg)

    @property
    def padding_token_index(self) -> int:
        "get the index of the padding token"
        return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]

    # conversion functions
    # ============================================================

    def to_tokens(
        self,
        maze: LatticeMaze,
    ) -> list[str]:
        """Converts maze into a list of tokens."""
        return self.prompt_sequencer.to_tokens(maze)

    def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]:
        "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords"
        return list(
            flatten(
                [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords],
            ),
        )

    # TODO: unclear why we need to use `noqa: N805` here since its a classmethod
    # maybe we need to hit every overload with `@classmethod`?
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        "wrapper for `maze_dataset.token_utils.strings_to_coords`"
        warnings.warn(
            "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
            TokenizerPendingDeprecationWarning,
        )
        return strings_to_coords(text=text, when_noncoord=when_noncoord)

    @staticmethod
    def encode(text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of tokens"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
        except KeyError as e:
            err_msg: str = f"Token {e} not found in `VOCAB`."
            raise TokenError(err_msg) from e

    @staticmethod
    def decode(
        token_ids: Sequence[int],
        joined_tokens: bool = False,
    ) -> list[str] | str:
        """decode a list of tokens into a string or list of strings"""
        try:
            output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
        except IndexError as e:
            err_msg: str = f"Token index '{e}' not found in `VOCAB`."
            raise TokenError(err_msg) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output
```
Tokenizer for mazes

Parameters
- `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt.

Development
- To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
- Furthermore, the mapping reflected in `from_legacy` must also be maintained.
- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
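A usage sketch tying these pieces together; everything below uses methods shown in the source above, and the backwards-compatibility guarantee from the docstring is checked directly:

```python
from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode

# the default constructor must match the legacy AOTP_UT_uniform mode
assert MazeTokenizerModular() == MazeTokenizerModular.from_legacy(
    TokenizationMode.AOTP_UT_uniform,
)

ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
assert ctt.is_legacy_equivalent()
print(ctt.tokenizer_element_tree())  # tree of the _TokenizerElements in use
```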
```python
def hash_int(self) -> int:
    "return integer hash using blake2b"
    return _hash_tokenizer_name(self.name)
```
return integer hash using blake2b
```python
def hash_b64(self, n_bytes: int = 8) -> str:
    """filename-safe base64 encoding of the hash"""
    # Use modulus to ensure the integer fits within n_bytes * 8 bits
    hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))

    encoded = base64.b64encode(
        hash_mod.to_bytes(n_bytes, byteorder="big"),
        altchars=b"-_",
    ).decode()

    # Remove any padding equals signs
    return encoded.rstrip("=")
```
filename-safe base64 encoding of the hash
```python
@cached_property
def tokenizer_elements(self) -> list[_TokenizerElement]:
    "returns a list of all the elements of this tokenizer"
    return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]
```
returns a list of all the elements of this tokenizer
```python
def tokenizer_element_tree(self, abstract: bool = False) -> str:
    """Returns a string representation of the tree of tokenizer elements contained in `self`.

    # Parameters
    - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
    """
    return "\n".join(
        [
            type(self).__name__,
            self.prompt_sequencer.tokenizer_element_tree(
                abstract=abstract,
                depth=1,
            ),
        ],
    )
```
Returns a string representation of the tree of tokenizer elements contained in `self`.

Parameters
- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
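For example, a sketch of inspecting the element hierarchy of the default tokenizer (exact output depends on the defaults):

```python
from maze_dataset.tokenization import MazeTokenizerModular

tok = MazeTokenizerModular()

# flat list: the prompt sequencer plus every element nested inside it
for elem in tok.tokenizer_elements:
    print(elem.name)

# indented tree of the same elements; abstract=True shows base-class names instead
print(tok.tokenizer_element_tree(abstract=True))
```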
117 @property 118 def tokenizer_element_tree_concrete(self) -> str: 119 """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.""" 120 return self.tokenizer_element_tree()
Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.
122 def tokenizer_element_dict(self) -> dict: 123 """Nested dictionary of the internal `TokenizerElement`s.""" 124 return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}
Nested dictionary of the internal `TokenizerElement`s.
126 @property 127 def name(self) -> str: 128 """Serializes MazeTokenizer into a key for encoding in zanj""" 129 return "-".join([type(self).__name__, self.prompt_sequencer.name]) # noqa: FLY002
Serializes `MazeTokenizerModular` into a key for encoding in `zanj`.
131 def summary(self) -> dict[str, str]: 132 """Single-level dictionary of the internal `TokenizerElement`s.""" 133 return { 134 # "prompt_sequencer": self.prompt_sequencer.name, 135 **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements}, 136 }
Single-level dictionary of the internal `TokenizerElement`s.
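`name` and `summary()` give two string-valued views of the same configuration; a sketch:

```python
from maze_dataset.tokenization import MazeTokenizerModular

tok = MazeTokenizerModular()
print(tok.name)  # "MazeTokenizerModular-" followed by the prompt sequencer's name

# one entry per tokenizer element, keyed by each element's attribute_key()
for attr_key, elem_name in tok.summary().items():
    print(f"{attr_key}: {elem_name}")
```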
159 def has_element( 160 self, 161 *elements: Sequence[type[_TokenizerElement] | _TokenizerElement], 162 ) -> bool: 163 """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`. 164 165 Querying with a partial subset of `_TokenizerElement` fields is not currently supported. 166 To do such a query, assemble multiple calls to `has_elements`. 167 168 # Parameters 169 - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. 170 If an instance is provided, then comparison is done via instance equality. 171 If a class is provided, then comparison isdone via `isinstance`. I.e., any instance of that class is accepted. 172 """ 173 if len(elements) == 1 and isinstance(elements[0], Iterable): 174 elements = elements[0] 175 return all(self._has_element_singular(e) for e in elements)
Returns `True` if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

Querying with a partial subset of `_TokenizerElement` fields is not currently supported. To do such a query, assemble multiple calls to `has_element`.

Parameters
- `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison is done via `isinstance`; i.e., any instance of that class is accepted.
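Both query styles described above, sketched (assuming the element namespaces are importable from `maze_dataset.tokenization`):

```python
from maze_dataset.tokenization import (
    CoordTokenizers,
    MazeTokenizerModular,
    PromptSequencers,
)

tok = MazeTokenizerModular()
# class query: any instance of the class matches (isinstance)
assert tok.has_element(CoordTokenizers.UT)
# instance query: fields must match exactly (instance equality)
assert tok.has_element(CoordTokenizers.UT())
# multiple elements: ALL must be present
assert tok.has_element(CoordTokenizers.UT, PromptSequencers.AOTP)
```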
177 def is_valid(self, do_except: bool = False) -> bool: 178 """Returns `True` if `self` is a valid tokenizer. 179 180 Evaluates the validity of all of `self.tokenizer_elements` according to each one's method. 181 """ 182 return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)
Returns `True` if `self` is a valid tokenizer.

Evaluates the validity of all of `self.tokenizer_elements` according to each element's own `is_valid` method.
184 def is_legacy_equivalent(self) -> bool: 185 """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.""" 186 return any( 187 self == MazeTokenizerModular.from_legacy(tok_mode) 188 for tok_mode in TokenizationMode 189 )
Returns whether `self` has identical stringification behavior to any legacy `MazeTokenizer`.
191 def is_tested_tokenizer(self, do_except: bool = False) -> bool: 192 """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers. 193 194 uses an fst on the `name` attributes of all the tokenizers 195 196 if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested. 197 """ 198 is_valid: bool = self.is_valid(do_except=do_except) 199 in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except) 200 201 if do_except: 202 assert is_valid, "self.is_valid returns False" 203 return True 204 else: 205 return in_tested_fst and is_valid
Returns whether the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

Uses an FST over the `name` attributes of all the tokenizers.

If `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested or not valid.
207 def is_AOTP(self) -> bool: 208 "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path" 209 return self.has_element(PromptSequencers.AOTP)
is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path
211 def is_UT(self) -> bool: 212 "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)" 213 return self.has_element(CoordTokenizers.UT)
is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)
218 @classmethod 219 def from_legacy( 220 cls, 221 legacy_maze_tokenizer: MazeTokenizer | TokenizationMode, 222 ) -> "MazeTokenizerModular": 223 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" 224 if isinstance(legacy_maze_tokenizer, MazeTokenizer): 225 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode 226 return { 227 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), 228 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), 229 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( 230 prompt_sequencer=PromptSequencers.AOTP( 231 coord_tokenizer=CoordTokenizers.CTT(), 232 ), 233 ), 234 }[legacy_maze_tokenizer]
Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.
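For instance, the indexed-CTT legacy mode maps to an explicit `CoordTokenizers.CTT` configuration; a sketch:

```python
from maze_dataset.tokenization import (
    CoordTokenizers,
    MazeTokenizerModular,
    PromptSequencers,
    TokenizationMode,
)

tok_ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
assert tok_ctt == MazeTokenizerModular(
    prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()),
)
```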
238 @classmethod 239 def from_tokens( 240 cls, 241 tokens: str | list[str], 242 ) -> "MazeTokenizerModular": 243 """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.""" 244 raise NotImplementedError( 245 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported", 246 )
Infers most `MazeTokenizerModular` parameters from a full sequence of tokens. Not currently supported: always raises `NotImplementedError`.
248 @property 249 def token_arr(self) -> list[str] | None: 250 """map from index to token""" 251 return VOCAB_LIST
map from index to token
253 @property 254 def tokenizer_map(self) -> dict[str, int]: 255 """map from token to index""" 256 return VOCAB_TOKEN_TO_INDEX
map from token to index
258 @property 259 def vocab_size(self) -> int: 260 """Number of tokens in the static vocab""" 261 return len(VOCAB_LIST)
Number of tokens in the static vocab
263 @property 264 def n_tokens(self) -> int: 265 "get the number of tokens in the vocabulary (deprecated)" 266 err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 267 raise NameError(err_msg)
get the number of tokens in the vocabulary (deprecated)
269 @property 270 def padding_token_index(self) -> int: 271 "get the index of the padding token" 272 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING]
get the index of the padding token
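These vocabulary properties are all views of the same static vocab; a short sketch, assuming `VOCAB` and `VOCAB_LIST` are importable from `maze_dataset`:

```python
from maze_dataset import VOCAB, VOCAB_LIST
from maze_dataset.tokenization import MazeTokenizerModular

tok = MazeTokenizerModular()
assert tok.vocab_size == len(VOCAB_LIST)

pad_idx: int = tok.padding_token_index
assert tok.token_arr[pad_idx] == VOCAB.PADDING
assert tok.tokenizer_map[VOCAB.PADDING] == pad_idx
```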
277 def to_tokens( 278 self, 279 maze: LatticeMaze, 280 ) -> list[str]: 281 """Converts maze into a list of tokens.""" 282 return self.prompt_sequencer.to_tokens(maze)
Converts a maze into a list of tokens.
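A minimal end-to-end sketch; the dataset-generation names (`MazeDatasetConfig`, `MazeDataset`, `LatticeMazeGenerators`) are assumed from the wider `maze_dataset` library:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization import MazeTokenizerModular

# generate one small solved maze (names assumed, not defined in this module)
cfg = MazeDatasetConfig(
    name="demo", grid_n=3, n_mazes=1, maze_ctor=LatticeMazeGenerators.gen_dfs
)
maze = MazeDataset.from_config(cfg)[0]  # a SolvedMaze

tok = MazeTokenizerModular()
tokens: list[str] = tok.to_tokens(maze)
print(" ".join(tokens))  # <ADJLIST_START> ... <PATH_END>
```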
284 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: 285 "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords" 286 return list( 287 flatten( 288 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords], 289 ), 290 )
Calls `self.prompt_sequencer.coord_tokenizer.to_tokens(c)` for each `c` in `coords`.
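Sketched under the default unique-token coordinate tokenizer, where each coordinate maps to one token:

```python
from maze_dataset.tokenization import MazeTokenizerModular

tok = MazeTokenizerModular()
# each (row, col) tuple becomes a single "(i,j)" token under the default UT tokenizer
print(tok.coords_to_strings([(0, 0), (1, 2)]))  # e.g. ['(0,0)', '(1,2)']
```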
312 @classmethod 313 def strings_to_coords( 314 cls, 315 text: str | list[str], 316 when_noncoord: WhenMissing = "skip", 317 ) -> list[str | CoordTup]: 318 "wrapper for maze_dataset.token_utils.strings_to_coords" 319 warnings.warn( 320 "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.", 321 TokenizerPendingDeprecationWarning, 322 ) 323 return strings_to_coords(text=text, when_noncoord=when_noncoord)
Wrapper for `maze_dataset.token_utils.strings_to_coords`; only supports legacy UT strings and emits a `TokenizerPendingDeprecationWarning`.
325 @staticmethod 326 def encode(text: str | list[str]) -> list[int]: 327 """encode a string or list of strings into a list of tokens""" 328 try: 329 if isinstance(text, str): 330 text = text.split() 331 return [VOCAB_TOKEN_TO_INDEX[token] for token in text] 332 except KeyError as e: 333 err_msg: str = f"Token {e} not found in `VOCAB`." 334 raise TokenError(err_msg) from e
Encode a string or list of strings into a list of token indices.
336 @staticmethod 337 def decode( 338 token_ids: Sequence[int], 339 joined_tokens: bool = False, 340 ) -> list[str] | str: 341 """decode a list of tokens into a string or list of strings""" 342 try: 343 output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids] 344 except IndexError as e: 345 err_msg: str = f"Token index '{e}' not found in `VOCAB`." 346 raise TokenError(err_msg) from e 347 if joined_tokens: 348 return " ".join(output) 349 else: 350 return output
Decode a list of token indices into a string or a list of token strings.
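`encode` and `decode` are inverse lookups over the static vocab; a sketch:

```python
from maze_dataset import VOCAB
from maze_dataset.tokenization import MazeTokenizerModular

tokens: list[str] = [VOCAB.ADJLIST_START, VOCAB.ADJLIST_END]
ids: list[int] = MazeTokenizerModular.encode(tokens)
assert MazeTokenizerModular.decode(ids) == tokens

# joined_tokens=True returns a single whitespace-joined string instead
assert MazeTokenizerModular.decode(ids, joined_tokens=True) == " ".join(tokens)
```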
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
Returns the class as a dict, implemented by the `@serializable_dataclass` decorator.
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
Takes in an appropriately structured dict and returns an instance of the class, implemented by the `@serializable_dataclass` decorator.
283def SerializableDataclass__validate_fields_types( 284 self: SerializableDataclass, 285 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 286) -> bool: 287 """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" 288 return all( 289 SerializableDataclass__validate_fields_types__dict( 290 self, on_typecheck_error=on_typecheck_error 291 ).values() 292 )
Validate the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
29@serializable_dataclass(frozen=True, kw_only=True) 30class _TokenizerElement(SerializableDataclass, abc.ABC): 31 """Superclass for tokenizer elements. 32 33 Subclasses contain modular functionality for maze tokenization. 34 35 # Development 36 > [!TIP] 37 > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses 38 > may only contain fields of type `utils.FiniteValued`. 39 > Implementing a subclass with an `int` or `float`-typed field, for example, is not supported. 40 > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated. 41 42 """ 43 44 # TYPING: type hint `v` more specifically 45 @staticmethod 46 def _stringify(k: str, v: Any) -> str: # noqa: ANN401 47 if isinstance(v, bool): 48 return f"{k}={str(v)[0]}" 49 if isinstance(v, _TokenizerElement): 50 return v.name 51 if isinstance(v, tuple): 52 return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}" 53 else: 54 return f"{k}={v}" 55 56 @property 57 def name(self) -> str: 58 members_str: str = ", ".join( 59 [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"], 60 ) 61 output: str = f"{type(self).__name__}({members_str})" 62 if "." in output and output.index("(") > output.index("."): 63 return "".join(output.split(".")[1:]) 64 else: 65 return output 66 67 def __str__(self) -> str: 68 return self.name 69 70 # TYPING: type hints for `__init_subclass__`? 71 def __init_subclass__(cls, **kwargs): # noqa: ANN204 72 """Hack: dataclass hashes don't include the class itself in the hash function inputs. 73 74 This causes dataclasses with identical fields but different types to hash identically. 75 This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`. 76 To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value. 77 So we type it as a singleton `Literal` type. 78 muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`. 79 Ignore Pylance complaining about the arg to `Literal` being an expression. 80 """ 81 super().__init_subclass__(**kwargs) 82 # we are adding a new attr here intentionally 83 cls._type_ = serializable_field( # type: ignore[attr-defined] 84 init=True, 85 repr=False, 86 default=repr(cls), 87 assert_type=False, 88 ) 89 cls.__annotations__["_type_"] = Literal[repr(cls)] 90 91 def __hash__(self) -> int: 92 "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" 93 return _hash_tokenizer_name(self.name) 94 95 @classmethod 96 def _level_one_subclass(cls) -> type["_TokenizerElement"]: 97 """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance.""" 98 return ( 99 set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop() 100 ) 101 102 def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]: 103 """Returns a list of all `_TokenizerElement` instances contained in the subtree. 104 105 Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or 106 which sit inside a `tuple` without further nesting. 107 108 # Parameters 109 - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer. 
110 """ 111 if not any(type(el) == tuple for el in self.__dict__.values()): # noqa: E721 112 return list( 113 flatten( 114 [ 115 [el, *el.tokenizer_elements()] 116 for el in self.__dict__.values() 117 if isinstance(el, _TokenizerElement) 118 ], 119 ) 120 if deep 121 else filter( 122 lambda x: isinstance(x, _TokenizerElement), 123 self.__dict__.values(), 124 ), 125 ) 126 else: 127 non_tuple_elems: list[_TokenizerElement] = list( 128 flatten( 129 [ 130 [el, *el.tokenizer_elements()] 131 for el in self.__dict__.values() 132 if isinstance(el, _TokenizerElement) 133 ] 134 if deep 135 else filter( 136 lambda x: isinstance(x, _TokenizerElement), 137 self.__dict__.values(), 138 ), 139 ), 140 ) 141 tuple_elems: list[_TokenizerElement] = list( 142 flatten( 143 [ 144 ( 145 [ 146 [tup_el, *tup_el.tokenizer_elements()] 147 for tup_el in el 148 if isinstance(tup_el, _TokenizerElement) 149 ] 150 if deep 151 else filter(lambda x: isinstance(x, _TokenizerElement), el) 152 ) 153 for el in self.__dict__.values() 154 if isinstance(el, tuple) 155 ], 156 ), 157 ) 158 non_tuple_elems.extend(tuple_elems) 159 return non_tuple_elems 160 161 def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str: 162 """Returns a string representation of the tree of tokenizer elements contained in `self`. 163 164 # Parameters 165 - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify. 166 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 167 """ 168 name: str = "\t" * depth + ( 169 type(self).__name__ 170 if not abstract 171 else type(self)._level_one_subclass().__name__ 172 ) 173 return ( 174 name 175 + "\n" 176 + "".join( 177 el.tokenizer_element_tree(depth + 1, abstract) 178 for el in self.tokenizer_elements(deep=False) 179 ) 180 ) 181 182 def tokenizer_element_dict(self) -> dict: 183 """Returns a dictionary representation of the tree of tokenizer elements contained in `self`.""" 184 return { 185 type(self).__name__: { 186 key: ( 187 val.tokenizer_element_dict() 188 if isinstance(val, _TokenizerElement) 189 else ( 190 val 191 if not isinstance(val, tuple) 192 else [ 193 ( 194 el.tokenizer_element_dict() 195 if isinstance(el, _TokenizerElement) 196 else el 197 ) 198 for el in val 199 ] 200 ) 201 ) 202 for key, val in self.__dict__.items() 203 if key != "_type_" 204 }, 205 } 206 207 @classmethod 208 @abc.abstractmethod 209 def attribute_key(cls) -> str: 210 """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.""" 211 raise NotImplementedError 212 213 def to_tokens(self, *args, **kwargs) -> list[str]: 214 """Converts a maze element into a list of tokens. 215 216 Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. 217 Those subclasses which do produce tokens should override this method. 218 """ 219 raise NotImplementedError 220 221 @abc.abstractmethod 222 def is_valid(self, do_except: bool = False) -> bool: 223 """Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`. 224 225 Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. 226 `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. 227 If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass. 
228 229 # Types of Invalidity 230 In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. 231 Invalidity types, in ascending order of invalidity: 232 - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. 233 E.g., `_TokenizerElement`s which are strictly worse than some alternative. 234 - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers. 235 - Untrainable: Training functional models using these tokenizers would be (nearly) impossible. 236 - Erroneous: These tokenizers might raise exceptions during use. 237 238 # Development 239 `is_invalid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. 240 When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as neccesary. 241 242 ## Nesting 243 In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. 244 In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. 245 `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected. 246 247 ## Types of Invalidity 248 If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. 249 This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances. 250 """ 251 raise NotImplementedError
Superclass for tokenizer elements.
Subclasses contain modular functionality for maze tokenization.
Development
Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses may only contain fields of type `utils.FiniteValued`. Implementing a subclass with an `int`- or `float`-typed field, for example, is not supported. In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.
56 @property 57 def name(self) -> str: 58 members_str: str = ", ".join( 59 [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"], 60 ) 61 output: str = f"{type(self).__name__}({members_str})" 62 if "." in output and output.index("(") > output.index("."): 63 return "".join(output.split(".")[1:]) 64 else: 65 return output
102 def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]: 103 """Returns a list of all `_TokenizerElement` instances contained in the subtree. 104 105 Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or 106 which sit inside a `tuple` without further nesting. 107 108 # Parameters 109 - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer. 110 """ 111 if not any(type(el) == tuple for el in self.__dict__.values()): # noqa: E721 112 return list( 113 flatten( 114 [ 115 [el, *el.tokenizer_elements()] 116 for el in self.__dict__.values() 117 if isinstance(el, _TokenizerElement) 118 ], 119 ) 120 if deep 121 else filter( 122 lambda x: isinstance(x, _TokenizerElement), 123 self.__dict__.values(), 124 ), 125 ) 126 else: 127 non_tuple_elems: list[_TokenizerElement] = list( 128 flatten( 129 [ 130 [el, *el.tokenizer_elements()] 131 for el in self.__dict__.values() 132 if isinstance(el, _TokenizerElement) 133 ] 134 if deep 135 else filter( 136 lambda x: isinstance(x, _TokenizerElement), 137 self.__dict__.values(), 138 ), 139 ), 140 ) 141 tuple_elems: list[_TokenizerElement] = list( 142 flatten( 143 [ 144 ( 145 [ 146 [tup_el, *tup_el.tokenizer_elements()] 147 for tup_el in el 148 if isinstance(tup_el, _TokenizerElement) 149 ] 150 if deep 151 else filter(lambda x: isinstance(x, _TokenizerElement), el) 152 ) 153 for el in self.__dict__.values() 154 if isinstance(el, tuple) 155 ], 156 ), 157 ) 158 non_tuple_elems.extend(tuple_elems) 159 return non_tuple_elems
Returns a list of all `_TokenizerElement` instances contained in the subtree.

Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or which sit inside a `tuple` without further nesting.

Parameters
- `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
161 def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str: 162 """Returns a string representation of the tree of tokenizer elements contained in `self`. 163 164 # Parameters 165 - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify. 166 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 167 """ 168 name: str = "\t" * depth + ( 169 type(self).__name__ 170 if not abstract 171 else type(self)._level_one_subclass().__name__ 172 ) 173 return ( 174 name 175 + "\n" 176 + "".join( 177 el.tokenizer_element_tree(depth + 1, abstract) 178 for el in self.tokenizer_elements(deep=False) 179 ) 180 )
Returns a string representation of the tree of tokenizer elements contained in `self`.

Parameters
- `depth: int`: Current depth in the tree. Used internally for recursion; no need to specify.
- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
182 def tokenizer_element_dict(self) -> dict: 183 """Returns a dictionary representation of the tree of tokenizer elements contained in `self`.""" 184 return { 185 type(self).__name__: { 186 key: ( 187 val.tokenizer_element_dict() 188 if isinstance(val, _TokenizerElement) 189 else ( 190 val 191 if not isinstance(val, tuple) 192 else [ 193 ( 194 el.tokenizer_element_dict() 195 if isinstance(el, _TokenizerElement) 196 else el 197 ) 198 for el in val 199 ] 200 ) 201 ) 202 for key, val in self.__dict__.items() 203 if key != "_type_" 204 }, 205 }
Returns a dictionary representation of the tree of tokenizer elements contained in `self`.
207 @classmethod 208 @abc.abstractmethod 209 def attribute_key(cls) -> str: 210 """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.""" 211 raise NotImplementedError
Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.
213 def to_tokens(self, *args, **kwargs) -> list[str]: 214 """Converts a maze element into a list of tokens. 215 216 Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. 217 Those subclasses which do produce tokens should override this method. 218 """ 219 raise NotImplementedError
Converts a maze element into a list of tokens.
Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.
221 @abc.abstractmethod 222 def is_valid(self, do_except: bool = False) -> bool: 223 """Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`. 224 225 Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. 226 `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. 227 If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass. 228 229 # Types of Invalidity 230 In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. 231 Invalidity types, in ascending order of invalidity: 232 - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. 233 E.g., `_TokenizerElement`s which are strictly worse than some alternative. 234 - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers. 235 - Untrainable: Training functional models using these tokenizers would be (nearly) impossible. 236 - Erroneous: These tokenizers might raise exceptions during use. 237 238 # Development 239 `is_invalid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. 240 When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as neccesary. 241 242 ## Nesting 243 In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. 244 In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. 245 `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected. 246 247 ## Types of Invalidity 248 If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. 249 This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances. 250 """ 251 raise NotImplementedError
Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

Types of Invalidity
In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity as one of the types below. Invalidity types, in ascending order of invalidity:
- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative.
- Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
- Erroneous: These tokenizers might raise exceptions during use.

Development
`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting
In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree: `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

Types of Invalidity
If it's judged to be useful, the types of invalidity could be implemented with an `Enum` or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
1053class PromptSequencers(__TokenizerElementNamespace): 1054 """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.""" 1055 1056 key = "prompt_sequencer" 1057 1058 @serializable_dataclass(frozen=True, kw_only=True) 1059 class _PromptSequencer(_TokenizerElement, abc.ABC): 1060 """Sequences token regions into a complete maze tokenization. 1061 1062 # Parameters 1063 - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position. 1064 - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`. 1065 Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s. 1066 """ 1067 1068 coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field( 1069 default=CoordTokenizers.UT(), 1070 loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers), 1071 ) 1072 adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field( 1073 default=AdjListTokenizers.AdjListCoord(), 1074 loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers), 1075 ) 1076 1077 @classmethod 1078 def attribute_key(cls) -> str: 1079 return PromptSequencers.key 1080 1081 @staticmethod 1082 def _trim_if_unsolved_maze( 1083 untrimmed: list[str], 1084 is_untargeted: bool = False, 1085 is_unsolved: bool = False, 1086 ) -> list[str]: 1087 """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze. 1088 1089 # Development 1090 This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP. 1091 It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses. 1092 """ 1093 if is_untargeted: 1094 return tokens_between( 1095 untrimmed, 1096 VOCAB.ADJLIST_START, 1097 VOCAB.ADJLIST_END, 1098 include_start=True, 1099 include_end=True, 1100 ) 1101 if is_unsolved: 1102 if VOCAB.TARGET_END in untrimmed: 1103 return tokens_between( 1104 untrimmed, 1105 VOCAB.ADJLIST_START, 1106 VOCAB.TARGET_END, 1107 include_start=True, 1108 include_end=True, 1109 ) 1110 else: 1111 return tokens_between( 1112 untrimmed, 1113 VOCAB.ADJLIST_START, 1114 VOCAB.ORIGIN_END, 1115 include_start=True, 1116 include_end=True, 1117 ) 1118 return untrimmed 1119 1120 def to_tokens( 1121 self, 1122 maze: LatticeMaze, 1123 *args, 1124 **kwargs, 1125 ) -> list[str]: 1126 """Returns a complete list of tokens for a given set of maze elements.""" 1127 untrimmed: list[str] = self._sequence_tokens( 1128 *self._get_prompt_regions(maze), 1129 ) 1130 return self._trim_if_unsolved_maze( 1131 untrimmed, 1132 not hasattr(maze, "start_pos"), 1133 not hasattr(maze, "solution"), 1134 ) 1135 1136 def _get_prompt_regions( 1137 self, 1138 maze: LatticeMaze, 1139 *args, 1140 **kwargs, 1141 ) -> list[list[str]]: 1142 """Gets the prompt regions of a maze in a fixed sequence. 1143 1144 This method is NOT responsible for including/excluding any prompt regions. 1145 Always return according to the API described under Returns. 1146 This implementation is expected to be suitable for most `PromptSequencer` subclasses. 1147 Subclasses may override this method if needed for special behavior. 1148 1149 # Returns 1150 - [0]: list[str] Adjacency list tokens 1151 - [1]: list[str] Origin tokens 1152 - [2]: list[str] Target tokens 1153 - [3]: list[str] Path tokens 1154 1155 # `None`-valued Args 1156 If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized. 
1157 To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables. 1158 """ 1159 origin: Coord | None = getattr(maze, "start_pos", None) 1160 target: list[Coord] | None = [ 1161 getattr(maze, "end_pos", None), 1162 ] # TargetTokenizer requires target: Sequence[Coord] 1163 1164 return [ 1165 ( 1166 self.adj_list_tokenizer.to_tokens( 1167 maze, 1168 coord_tokenizer=self.coord_tokenizer, 1169 ) 1170 if hasattr(self, "adj_list_tokenizer") 1171 else [] 1172 ), 1173 self.coord_tokenizer.to_tokens(origin) if origin is not None else [], 1174 ( 1175 self.target_tokenizer.to_tokens( 1176 target, 1177 coord_tokenizer=self.coord_tokenizer, 1178 ) 1179 if target[0] is not None and hasattr(self, "target_tokenizer") 1180 else [] 1181 ), 1182 ( 1183 self.path_tokenizer.to_tokens( 1184 maze, 1185 coord_tokenizer=self.coord_tokenizer, 1186 ) 1187 if hasattr(maze, "solution") and hasattr(self, "path_tokenizer") 1188 else [] 1189 ), 1190 ] 1191 1192 @abc.abstractmethod 1193 def _sequence_tokens( 1194 self, 1195 adj_list: list[str], 1196 origin: list[str] | None, 1197 target: list[str] | None, 1198 path: list[str] | None, 1199 ) -> list[str]: 1200 """Sequences token regions into a complete prompt. 1201 1202 Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc. 1203 1204 # Parameters 1205 - `adj_list`: Tokens representing the adjacency list 1206 - `origin`: Tokens representing the origin 1207 - `target`: Tokens representing the target 1208 - `path`: Tokens representing the path 1209 """ 1210 pass 1211 1212 def is_valid(self, do_except: bool = False) -> bool: 1213 # No invalid instances possible within data member type hint bounds 1214 return True 1215 1216 @serializable_dataclass(frozen=True, kw_only=True) 1217 class AOTP(_PromptSequencer): 1218 """Sequences a prompt as [adjacency list, origin, target, path]. 1219 1220 # Parameters 1221 - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. 1222 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`. 1223 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 1224 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 1225 1226 """ 1227 1228 target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field( 1229 default=TargetTokenizers.Unlabeled(), 1230 loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers), 1231 ) 1232 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 1233 default=PathTokenizers.StepSequence(), 1234 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 1235 ) 1236 1237 def _sequence_tokens( 1238 self, 1239 adj_list: list[str], 1240 origin: list[str], 1241 target: list[str], 1242 path: list[str], 1243 ) -> list[str]: 1244 return [ 1245 VOCAB.ADJLIST_START, 1246 *adj_list, 1247 VOCAB.ADJLIST_END, 1248 VOCAB.ORIGIN_START, 1249 *origin, 1250 VOCAB.ORIGIN_END, 1251 VOCAB.TARGET_START, 1252 *target, 1253 VOCAB.TARGET_END, 1254 VOCAB.PATH_START, 1255 *path, 1256 VOCAB.PATH_END, 1257 ] 1258 1259 @serializable_dataclass(frozen=True, kw_only=True) 1260 class AOP(_PromptSequencer): 1261 """Sequences a prompt as [adjacency list, origin, path]. 1262 1263 Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself. 
1264 1265 # Parameters 1266 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 1267 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 1268 """ 1269 1270 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 1271 default=PathTokenizers.StepSequence(), 1272 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 1273 ) 1274 1275 def _sequence_tokens( 1276 self, 1277 adj_list: list[str], 1278 origin: list[str], 1279 # explicitly no target in this tokenizer 1280 target: list[str], 1281 path: list[str], 1282 ) -> list[str]: 1283 return [ 1284 VOCAB.ADJLIST_START, 1285 *adj_list, 1286 VOCAB.ADJLIST_END, 1287 VOCAB.ORIGIN_START, 1288 *origin, 1289 VOCAB.ORIGIN_END, 1290 VOCAB.TARGET_START, 1291 VOCAB.TARGET_END, 1292 VOCAB.PATH_START, 1293 *path, 1294 VOCAB.PATH_END, 1295 ]
Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.
1216 @serializable_dataclass(frozen=True, kw_only=True) 1217 class AOTP(_PromptSequencer): 1218 """Sequences a prompt as [adjacency list, origin, target, path]. 1219 1220 # Parameters 1221 - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. 1222 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`. 1223 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 1224 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 1225 1226 """ 1227 1228 target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field( 1229 default=TargetTokenizers.Unlabeled(), 1230 loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers), 1231 ) 1232 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 1233 default=PathTokenizers.StepSequence(), 1234 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 1235 ) 1236 1237 def _sequence_tokens( 1238 self, 1239 adj_list: list[str], 1240 origin: list[str], 1241 target: list[str], 1242 path: list[str], 1243 ) -> list[str]: 1244 return [ 1245 VOCAB.ADJLIST_START, 1246 *adj_list, 1247 VOCAB.ADJLIST_END, 1248 VOCAB.ORIGIN_START, 1249 *origin, 1250 VOCAB.ORIGIN_END, 1251 VOCAB.TARGET_START, 1252 *target, 1253 VOCAB.TARGET_END, 1254 VOCAB.PATH_START, 1255 *path, 1256 VOCAB.PATH_END, 1257 ]
Sequences a prompt as [adjacency list, origin, target, path].
Parameters
- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
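A sketch of a non-default `AOTP` sequencer, swapping in the coordinate-tuple tokenizer while spelling out the default target and path elements (element namespaces assumed importable as above):

```python
from maze_dataset.tokenization import (
    CoordTokenizers,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    TargetTokenizers,
)

custom = MazeTokenizerModular(
    prompt_sequencer=PromptSequencers.AOTP(
        coord_tokenizer=CoordTokenizers.CTT(),
        target_tokenizer=TargetTokenizers.Unlabeled(),
        path_tokenizer=PathTokenizers.StepSequence(),
    ),
)
assert custom.is_AOTP() and not custom.is_UT()
```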
Inherited Members
- maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer
- coord_tokenizer
- adj_list_tokenizer
- attribute_key
- to_tokens
- is_valid
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
1259 @serializable_dataclass(frozen=True, kw_only=True) 1260 class AOP(_PromptSequencer): 1261 """Sequences a prompt as [adjacency list, origin, path]. 1262 1263 Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself. 1264 1265 # Parameters 1266 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 1267 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 1268 """ 1269 1270 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 1271 default=PathTokenizers.StepSequence(), 1272 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 1273 ) 1274 1275 def _sequence_tokens( 1276 self, 1277 adj_list: list[str], 1278 origin: list[str], 1279 # explicitly no target in this tokenizer 1280 target: list[str], 1281 path: list[str], 1282 ) -> list[str]: 1283 return [ 1284 VOCAB.ADJLIST_START, 1285 *adj_list, 1286 VOCAB.ADJLIST_END, 1287 VOCAB.ORIGIN_START, 1288 *origin, 1289 VOCAB.ORIGIN_END, 1290 VOCAB.TARGET_START, 1291 VOCAB.TARGET_END, 1292 VOCAB.PATH_START, 1293 *path, 1294 VOCAB.PATH_END, 1295 ]
Sequences a prompt as [adjacency list, origin, path].
Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
Parameters
- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
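As a quick illustration, the sketch below slots `AOP` into a `MazeTokenizerModular`; this is an assumed API, and in particular the `prompt_sequencer` argument name is an assumption, not confirmed by this page.

```python
# Sketch: an AOP prompt sequencer produces [adjacency list, origin, path],
# with an empty target region ("<TARGET_START>" directly followed by "<TARGET_END>").
from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.modular.elements import PathTokenizers, PromptSequencers

tokenizer = MazeTokenizerModular(
    prompt_sequencer=PromptSequencers.AOP(  # `prompt_sequencer` name is an assumption
        path_tokenizer=PathTokenizers.StepSequence(),
    ),
)
```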
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
Also inherits `load` and `validate_fields_types` from `SerializableDataclass`; see the full listings earlier on this page.
Inherited Members
- maze_dataset.tokenization.modular.elements.PromptSequencers._PromptSequencer
- coord_tokenizer
- adj_list_tokenizer
- attribute_key
- to_tokens
- is_valid
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class CoordTokenizers(__TokenizerElementNamespace):
    """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "coord_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _CoordTokenizer(_TokenizerElement, abc.ABC):
        """Superclass for classes which tokenize singular coords in a maze."""

        @abc.abstractmethod
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return CoordTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class UT(_CoordTokenizer):
        """Unique token coordinate tokenizer."""

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

    @serializable_dataclass(frozen=True, kw_only=True)
    class CTT(_CoordTokenizer):
        """Coordinate tuple tokenizer

        # Parameters
        - `pre`: Whether all coords include an integral preceding delimiter token
        - `intra`: Whether all coords include a delimiter token between coordinates
        - `post`: Whether all coords include an integral following delimiter token
        """

        pre: bool = serializable_field(default=True)
        intra: bool = serializable_field(default=True)
        post: bool = serializable_field(default=True)
        # Implement methods

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return [
                *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
                str(coord[0]),
                *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
                str(coord[1]),
                *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
            ]
```
Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
`UT`: Unique token coordinate tokenizer.
`UT.to_tokens`: Converts a maze element into a list of tokens. Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.
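Following directly from the source above, a minimal usage sketch (the import path is assumed from this page's member listings):

```python
from maze_dataset.tokenization.modular.elements import CoordTokenizers

# UT emits exactly one unique token per coordinate
assert CoordTokenizers.UT().to_tokens((1, 2)) == ["(1,2)"]
```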
Also inherits `serialize`, `load`, and `validate_fields_types` from `SerializableDataclass`; see the full listings earlier on this page.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
`CTT`: Coordinate tuple tokenizer
Parameters
- `pre`: Whether all coords include an integral preceding delimiter token
- `intra`: Whether all coords include a delimiter token between coordinates
- `post`: Whether all coords include an integral following delimiter token
`CTT.to_tokens`: Converts a maze element into a list of tokens. Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. Those subclasses which do produce tokens should override this method.
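A minimal sketch of the delimiter flags, following the source above; the exact `VOCAB` import path is an assumption.

```python
from maze_dataset.constants import VOCAB  # exact import path is an assumption
from maze_dataset.tokenization.modular.elements import CoordTokenizers

# each coordinate becomes up to five tokens, depending on the delimiter flags
tokens = CoordTokenizers.CTT(pre=True, intra=True, post=True).to_tokens((1, 2))
assert tokens == [VOCAB.COORD_PRE, "1", VOCAB.COORD_INTRA, "2", VOCAB.COORD_POST]

# with all delimiters disabled, only the row and column tokens remain
assert CoordTokenizers.CTT(pre=False, intra=False, post=False).to_tokens((1, 2)) == ["1", "2"]
```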
Also inherits `serialize`, `load`, and `validate_fields_types` from `SerializableDataclass`; see the full listings earlier on this page.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class AdjListTokenizers(__TokenizerElementNamespace):
    """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "adj_list_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_adjlist_no_pre_unsupported)
    class _AdjListTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how the adjacency list is tokenized.

        Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
        See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.

        # Parameters
        - `pre`: Whether all edge groupings include a preceding delimiter token
        - `post`: Whether all edge groupings include a following delimiter token
        - `shuffle_d0`: Specifies how to sequence the edge groupings.
            If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
        - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
        - `edge_subset`: Specifies the subset of lattice edges to be tokenized.
        - `edge_permuter`: Specifies, in each edge tokenization, which coord either:
            1. Appears first in the tokenization, for `AdjListCoord`.
            2. Is tokenized directly as a coord, for `AdjListCardinal`.
            - `shuffle`: For each edge, the leading coord is selected randomly.
            - `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
            - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
        """

        pre: bool = serializable_field(default=False, assert_type=False)
        post: bool = serializable_field(default=True)
        shuffle_d0: bool = serializable_field(default=True)
        edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
            default=EdgeGroupings.Ungrouped(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
        )
        edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
            default=EdgeSubsets.ConnectionEdges(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
        )
        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return AdjListTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @abc.abstractmethod
        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            """Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.

            # Returns
            - `[0]`: leading coord tokens
            - `[1]`: connector tokens
            - `[2]`: trailing coord tokens
            """
            pass

        def _tokenize_edge_grouping(
            self,
            edges: ConnectionArray,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            group_params: EdgeGroupings._GroupingTokenParams,
        ) -> Sequence[str]:
            """Tokenizes a single edge grouping."""
            cxn_ord: int = group_params["connection_token_ordinal"]
            is_conn: Bool[np.ndarray, " edges"] = is_connection(
                edges,
                maze.connection_list,
            )
            tokenize_callables = self._tokenization_callables(
                edges,
                is_conn,
                coord_tokenizer,
            )

            if group_params["grouped"]:
                # If grouped
                callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
                repeated_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]
                return flatten(
                    [
                        tokenize_callables[0](0),
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in repeated_callables
                                ],
                                *(
                                    (VOCAB.ADJLIST_INTRA,)
                                    if group_params["intra"]
                                    else ()
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )
            else:
                # If ungrouped
                callable_permutation = [0, 2]
                callable_permutation.insert(cxn_ord, 1)
                tokenize_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]

                return flatten(
                    [
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in tokenize_callables
                                ],
                                *empty_sequence_if_attr_false(
                                    (VOCAB.ADJLIST_INTRA,),
                                    group_params,
                                    "intra",
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )

        def to_tokens(
            self,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            # Get the set of edges to be tokenized
            edges: ConnectionArray = self.edge_subset._get_edges(maze)
            # Systematically permute the leading coord of each edge
            edges: ConnectionArray = self.edge_permuter._permute(edges)
            group_params: EdgeGroupings._GroupingTokenParams = (
                self.edge_grouping._token_params()
            )
            # then, we need to group the edges
            groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
            # shuffle the groups if specified
            if self.shuffle_d0:
                if isinstance(groups, np.ndarray):
                    numpy_rng.shuffle(groups, axis=0)
                elif isinstance(groups, list):
                    random.shuffle(groups)
                else:
                    err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
                    raise TypeError(err_msg)
            # Tokenize each group with optional delimiters
            tokens: list[str] = list(
                flatten(
                    [
                        [
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJLIST_PRE,),
                                self,
                                "pre",
                            ),
                            *self._tokenize_edge_grouping(
                                group,
                                maze,
                                coord_tokenizer,
                                group_params,
                            ),
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJACENCY_ENDLINE,),
                                self,
                                "post",
                            ),
                        ]
                        for group in groups
                    ],
                ),
            )
            return tokens

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCoord(_AdjListTokenizer):
        """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCardinal(_AdjListTokenizer):
        """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

        # Parameters
        - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
        """

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.BothCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: get_cardinal_direction(edges[i]),
            ]
```
Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
`AdjListCoord`: Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.
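A configuration sketch using only elements shown on this page, with the defaults written out explicitly for illustration:

```python
from maze_dataset.tokenization.modular.elements import (
    AdjListTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
)

adj = AdjListTokenizers.AdjListCoord(
    pre=False,                                   # no leading delimiter per edge grouping
    post=True,                                   # trailing delimiter per edge grouping
    shuffle_d0=True,                             # shuffle the groupings randomly
    edge_grouping=EdgeGroupings.Ungrouped(),     # one edge per grouping
    edge_subset=EdgeSubsets.ConnectionEdges(),   # tokenize only connected edges
    edge_permuter=EdgePermuters.RandomCoords(),  # random leading coord per edge
)
```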
Also inherits `serialize`, `load`, and `validate_fields_types` from `SerializableDataclass`; see the full listings earlier on this page.
Inherited Members
- maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer
- pre
- post
- shuffle_d0
- edge_grouping
- edge_subset
- attribute_key
- is_valid
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
`AdjListCardinal`: Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
Parameters
- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
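A short sketch; note from the source above that `AdjListCardinal` defaults to `EdgePermuters.BothCoords()`, so each edge appears twice, once led by each endpoint:

```python
from maze_dataset.tokenization.modular.elements import AdjListTokenizers, EdgePermuters

# every connection is described from both sides, with a cardinal direction
# token standing in for the trailing coord
adj = AdjListTokenizers.AdjListCardinal(
    edge_permuter=EdgePermuters.BothCoords(),  # the default shown in the source
)
```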
Also inherits `serialize`, `load`, and `validate_fields_types` from `SerializableDataclass`; see the full listings earlier on this page.
Inherited Members
- maze_dataset.tokenization.modular.elements.AdjListTokenizers._AdjListTokenizer
- pre
- post
- shuffle_d0
- edge_grouping
- edge_subset
- attribute_key
- is_valid
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class EdgeGroupings(__TokenizerElementNamespace):
    """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_grouping"

    class _GroupingTokenParams(TypedDict):
        """A uniform private hyperparameter interface used by `AdjListTokenizer`."""

        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeGrouping(_TokenizerElement, abc.ABC):
        """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeGroupings.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            """Divides a ConnectionArray into groups of edges.

            Shuffles/sequences within each group if applicable.
            """
            pass

        @abc.abstractmethod
        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            """Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

            These hyperparameters are not used by `_EdgeGrouping` internally.
            They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
            since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
            This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses
            into a uniform private interface used by `AdjListTokenizer`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class Ungrouped(_EdgeGrouping):
        """No grouping occurs, each edge is tokenized individually.

        # Parameters
        - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
            Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
        """

        connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
            default=1,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=False,
                grouped=False,
            )

        def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
            return np.expand_dims(edges, 1)

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ByLeadingCoord(_EdgeGrouping):
        """All edges with the same leading coord are grouped together.

        # Parameters
        - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
            Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`)
        - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
            If false, the fixed order is lexicographical by (row, col).
            In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
        - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
            Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
        """

        intra: bool = serializable_field(default=True)
        shuffle_group: bool = serializable_field(default=True)
        connection_token_ordinal: Literal[0, 1] = serializable_field(
            default=0,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=self.intra,
                grouped=True,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
            index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
                (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
            )
            sorted_edges: ConnectionArray = edges[index_array, ...]
            groups: list[ConnectionArray] = np.split(
                sorted_edges,
                np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
            )
            if self.shuffle_group:
                [numpy_rng.shuffle(g, axis=0) for g in groups]
            return groups
```
Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`.
`Ungrouped`: No grouping occurs, each edge is tokenized individually.
Parameters
- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears. Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
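For example, a sketch using only names shown above, moving the connector token to the front of each edge tokenization:

```python
from maze_dataset.tokenization.modular.elements import EdgeGroupings

# connector (or wall) token first, then leading coord, then trailing coord/cardinal
grouping = EdgeGroupings.Ungrouped(connection_token_ordinal=0)
assert grouping._token_params()["grouped"] is False
```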
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class; implemented by the `@serializable_dataclass` decorator
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
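Taken together, `serialize` and `load` give any decorated class a JSON-safe round trip. A minimal sketch follows (the `Pos` class is hypothetical, purely for illustration; the import path matches muutils' documented usage):

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

# hypothetical example class, not part of maze_dataset or muutils
@serializable_dataclass(frozen=True, kw_only=True)
class Pos(SerializableDataclass):
    row: int = serializable_field(default=0)
    col: int = serializable_field(default=0)

p: Pos = Pos(row=2, col=3)
d: dict = p.serialize()  # plain dict: the fields plus a format key
p2: Pos = Pos.load(d)    # reconstruct an equivalent instance from the dict
assert p == p2
```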
```python
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class ByLeadingCoord(_EdgeGrouping):
    """All edges with the same leading coord are grouped together.

    # Parameters
    - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
      Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`)
    - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
      If false, the fixed order is lexicographical by (row, col).
      In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
    - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
      Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
    """

    intra: bool = serializable_field(default=True)
    shuffle_group: bool = serializable_field(default=True)
    connection_token_ordinal: Literal[0, 1] = serializable_field(
        default=0,
        assert_type=False,
    )

    def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
        return EdgeGroupings._GroupingTokenParams(
            connection_token_ordinal=self.connection_token_ordinal,
            intra=self.intra,
            grouped=True,
        )

    def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
        # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
        index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
            (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
        )
        sorted_edges: ConnectionArray = edges[index_array, ...]
        groups: list[ConnectionArray] = np.split(
            sorted_edges,
            np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
        )
        if self.shuffle_group:
            [numpy_rng.shuffle(g, axis=0) for g in groups]
        return groups
```
All edges with the same leading coord are grouped together.

Parameters
- `intra`: Whether all edge groupings include a delimiter token between individual edge representations. Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. If false, the fixed order is lexicographical by (row, col). In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
- `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears. Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
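The lexsort-then-split group-by trick in `_group_edges` above can be seen in isolation; here is a minimal sketch using plain numpy on a hand-built edge array of shape `(n_edges, 2, 2)`, where axis 1 indexes (leading coord, trailing coord) and axis 2 indexes (row, col):

```python
import numpy as np

# four lattice edges; axis 1 = (leading coord, trailing coord), axis 2 = (row, col)
edges = np.array([
    [[0, 1], [0, 2]],
    [[0, 0], [0, 1]],
    [[0, 0], [1, 0]],
    [[1, 0], [1, 1]],
])

# lexsort treats the LAST key as primary, so this sorts by leading (row, col),
# then trailing (row, col)
order = np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
sorted_edges = edges[order]

# split at the first occurrence of each new leading coord
split_at = np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:]
groups = np.split(sorted_edges, split_at)
for g in groups:
    print(g[0, 0].tolist(), "->", g[:, 1].tolist())
# [0, 0] -> [[0, 1], [1, 0]]
# [0, 1] -> [[0, 2]]
# [1, 0] -> [[1, 1]]
```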
```python
def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False
```
Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes
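A sketch of how this default interacts with the decorator, assuming `mark_as_unsupported(fn)` installs `fn` as the decorated class's `is_valid` (the subclass here is hypothetical, and the sketch runs in the context of this module's names):

```python
# hypothetical untested element, mirroring the decorator usage on ByLeadingCoord above
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class SomeUntestedGrouping(_EdgeGrouping):
    """Excluded from the tested tokenizer space."""

elem = SomeUntestedGrouping()
assert elem.is_valid() is False  # quietly reports invalidity
try:
    elem.is_valid(do_except=True)  # raises instead when asked to
except ValueError as e:
    print(e)
```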
```python
class EdgePermuters(__TokenizerElementNamespace):
    """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_permuter"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgePermuter(_TokenizerElement, abc.ABC):
        """Specifies how to sequence the two coords that encode a lattice edge."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgePermuters.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @staticmethod
        @abc.abstractmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            """Executes a permutation.

            Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

            # Parameters
            - `lattice_edges`: Array of lattice edges.
              The two coords in shape[1] must be adjacent in the lattice.

            # Returns
            - Array of lattice edges with entries along shape[1] systematically permuted.
            - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class SortedCoords(_EdgePermuter):
        """returns a sorted representation. useful for checking consistency"""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return lattice_edges[
                np.lexsort(
                    (
                        lattice_edges[:, 1, 1],
                        lattice_edges[:, 1, 0],
                        lattice_edges[:, 0, 1],
                        lattice_edges[:, 0, 0],
                    ),
                ),
                ...,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class RandomCoords(_EdgePermuter):
        """Permutes each edge randomly."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
            return lattice_edges

    @serializable_dataclass(frozen=True, kw_only=True)
    class BothCoords(_EdgePermuter):
        """Includes both possible permutations of every edge in the output.

        Since input ConnectionList has only 1 instance of each edge,
        a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
        """

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
```
Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`.
`SortedCoords`: returns a sorted representation; useful for checking consistency.
`RandomCoords`: permutes each edge randomly.
`BothCoords`: includes both possible permutations of every edge in the output. Since the input `ConnectionList` has only one instance of each edge, `BothCoords._permute` returns an array with double the `shape[0]` of its input.
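The doubling is easy to see on a toy array; here is a minimal sketch with plain numpy standing in for a `ConnectionArray`:

```python
import numpy as np

# a single edge between (0, 0) and (0, 1); shape (1, 2, 2)
edges = np.array([[[0, 0], [0, 1]]])

# same operation as BothCoords._permute: append the flipped copy
both = np.append(edges, np.flip(edges, axis=1), axis=0)
print(both.shape)        # (2, 2, 2) -- shape[0] doubled
print(both[1].tolist())  # [[0, 1], [0, 0]] -- the reversed orientation
```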
```python
class EdgeSubsets(__TokenizerElementNamespace):
    """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_subset"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeSubset(_TokenizerElement, abc.ABC):
        """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeSubsets.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            """Returns the set of lattice edges to be tokenized."""
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class AllLatticeEdges(_EdgeSubset):
        """All 2n**2-2n edges of the lattice are tokenized.

        If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
        """

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            return lattice_connection_array(maze.grid_n)

    @serializable_dataclass(frozen=True, kw_only=True)
    class ConnectionEdges(_EdgeSubset):
        """Only edges which contain a connection are tokenized.

        Alternatively, only edges which contain a wall are tokenized.

        # Parameters
        - `walls`: Whether wall edges or connection edges are tokenized.
          If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
        """

        walls: bool = serializable_field(default=False)

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            conn_list: ConnectionList = maze.connection_list
            if self.walls:
                conn_list = np.logical_not(conn_list)
                conn_list[0, -1, :] = False
                conn_list[1, :, -1] = False
            return connection_list_to_adj_list(
                conn_list,
                shuffle_d0=False,
                shuffle_d1=False,
            )
```
Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`.
`AllLatticeEdges`: all $2n^2 - 2n$ edges of the lattice are tokenized. If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
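The $2n^2 - 2n$ count follows from an $n \times n$ lattice having $n(n-1)$ horizontal and $n(n-1)$ vertical edges; a quick sanity check in plain Python:

```python
def n_lattice_edges(n: int) -> int:
    # n rows of (n - 1) horizontal edges, plus n columns of (n - 1) vertical edges
    return n * (n - 1) + n * (n - 1)

for n in (2, 3, 5):
    assert n_lattice_edges(n) == 2 * n**2 - 2 * n
print(n_lattice_edges(3))  # 12
```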
`ConnectionEdges`: only edges which contain a connection are tokenized. Alternatively, only edges which contain a wall are tokenized.

Parameters
- `walls`: Whether wall edges or connection edges are tokenized. If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
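The masking in `_get_edges` deserves a note: after negating the connection list, the entries that would point past the bottom row or the last column do not correspond to lattice edges, so they are cleared before extraction. A minimal numpy sketch (the axis convention, `[0]` for downward and `[1]` for rightward connections, is an assumption here):

```python
import numpy as np

n = 3
rng = np.random.default_rng(0)
# stand-in connection list, shape (2, n, n): [0] = down, [1] = right (assumed)
conn_list = rng.random((2, n, n)) < 0.5

walls = np.logical_not(conn_list)
walls[0, -1, :] = False  # no downward edge leaves the bottom row
walls[1, :, -1] = False  # no rightward edge leaves the last column
print(int(walls.sum()), "wall edges out of", 2 * n**2 - 2 * n, "lattice edges")
```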
```python
class TargetTokenizers(__TokenizerElementNamespace):
    """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "target_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _TargetTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze targets."""

        @abc.abstractmethod
        def to_tokens(
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the target."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return TargetTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class Unlabeled(_TargetTokenizer):
        """Targets are simply listed as coord tokens.

        - `post`: Whether all coords include an integral following delimiter token
        """

        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return list(
                flatten(
                    [
                        [
                            *coord_tokenizer.to_tokens(target),
                            *empty_sequence_if_attr_false(
                                [VOCAB.TARGET_POST],
                                self,
                                "post",
                            ),
                        ]
                        for target in targets
                    ],
                ),
            )

        # inherit docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            # No invalid instances possible within data member type hint bounds
            return True
```
Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
`Unlabeled`: targets are simply listed as coord tokens.
- `post`: Whether all coords include an integral following delimiter token
`to_tokens`: returns tokens representing the target.
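The resulting token pattern is simple; here is a plain-Python re-implementation sketch (the function name and the `<TARGET_POST>` string are illustrative stand-ins, not library API):

```python
from itertools import chain
from typing import Callable, Sequence

def unlabeled_target_tokens(
    targets: Sequence[tuple[int, int]],
    coord_to_tokens: Callable[[tuple[int, int]], list[str]],
    post: bool,
) -> list[str]:
    # per target: its coord tokens, then an optional trailing delimiter
    TARGET_POST = "<TARGET_POST>"  # stand-in for VOCAB.TARGET_POST
    return list(
        chain.from_iterable(
            [*coord_to_tokens(t), *([TARGET_POST] if post else [])]
            for t in targets
        )
    )

print(unlabeled_target_tokens([(1, 2)], lambda t: [f"({t[0]},{t[1]})"], post=True))
# ['(1,2)', '<TARGET_POST>']
```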
`is_valid`: returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`. Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return `True` for that subclass.

Types of Invalidity
In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:
- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative.
- Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
- Erroneous: These tokenizers might raise exceptions during use.

Development
`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

Nesting
In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

Types of Invalidity
If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
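A sketch of what a nontrivial override following these conventions might look like (the class and its validity condition are hypothetical, and the sketch runs in the context of this module's names):

```python
# hypothetical element; only the structure of the checks is the point here
@serializable_dataclass(frozen=True, kw_only=True)
class SomeElement(_TokenizerElement):
    """Illustrates the commented-invalidity-type convention."""

    width: int = serializable_field(default=1)

    def is_valid(self, do_except: bool = False) -> bool:
        if self.width > 64:
            # Untrainable: sequences this wide would be (nearly) impossible to train on
            if do_except:
                raise ValueError(f"invalid element: {self.width = }")
            return False
        return True
```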
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
Takes in an appropriately structured dict and returns an instance of the class; implemented via the @serializable_dataclass decorator. Note that loading_fn receives the whole data dict, while deserialize_fn receives only the field's own value.
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
Validates the types of all the fields on a SerializableDataclass; calls SerializableDataclass__validate_field_type for each field.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
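For orientation, here is a minimal round-trip sketch of how these inherited muutils methods behave. The MyConfig class and its fields are hypothetical, invented for illustration; the exact format key stored by serialize is an implementation detail of muutils.

```python
# Minimal round-trip sketch (hypothetical MyConfig class; not part of maze_dataset).
# serialize() walks the fields and returns a plain dict; load() rebuilds an instance.
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

@serializable_dataclass
class MyConfig(SerializableDataclass):
    name: str
    size: int = serializable_field(default=5)

cfg = MyConfig(name="demo", size=7)
data = cfg.serialize()  # dict with a format key plus {"name": "demo", "size": 7}
cfg2 = MyConfig.load(data)  # validates field types unless told to ignore mismatches
assert cfg == cfg2
```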
class StepSizes(__TokenizerElementNamespace):
    """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_size"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepSize(_TokenizerElement, abc.ABC):
        """Specifies which coords in `maze.solution` are used to represent the path."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepSizes.key

        @abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            raise NotImplementedError(
                "Subclasses must implement `StepSize.step_indices.",
            )

        def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
            """Returns steps as tuples of starting and ending positions for each step."""
            indices: list[int] = self._step_single_indices(maze)
            # TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
            return [
                (start, end)
                for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
            ]

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Singles(_StepSize):
        """Every coord in `maze.solution` is represented.

        Legacy tokenizers all use this behavior.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(range(maze.solution.shape[0]))

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class Straightaways(_StepSize):
        """Only coords where the path turns are represented in the path.

        I.e., the path is represented as a sequence of straightaways,
        specified by the coords at the turns.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            last_turn_coord: Coord = maze.solution[0, ...]
            indices: list[int] = [0]
            for i, coord in enumerate(maze.solution):
                if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
                    indices.append(i - 1)
                    last_turn_coord = maze.solution[i - 1, ...]
            indices.append(i)
            return indices

    @serializable_dataclass(frozen=True, kw_only=True)
    class Forks(_StepSize):
        """Only coords at forks, where the path has >=2 options for the next step are included.

        Excludes the option of backtracking.
        The starting and ending coords are always included.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return maze.get_solution_forking_points(always_include_endpoints=True)[0]

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ForksAndStraightaways(_StepSize):
        """Includes the union of the coords included by `Forks` and `Straightaways`.

        See documentation for those classes for details.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(
                np.unique(
                    np.concatenate(
                        (
                            StepSizes.Straightaways()._step_single_indices(maze),
                            StepSizes.Forks()._step_single_indices(maze),
                        ),
                    ),
                ),
            )
Namespace for the _StepSize subclass hierarchy used by MazeTokenizerModular.
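As a rough usage sketch, a _StepSize turns a solved maze into (start, end) index pairs over maze.solution. The dataset construction below follows the library's documented pattern, but the specific config arguments (name, grid_n, n_mazes, the gen_dfs generator) are illustrative choices, not requirements.

```python
# Hedged sketch: comparing which solution indices two _StepSize variants keep.
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import StepSizes

maze = MazeDataset.from_config(
    MazeDatasetConfig(
        name="demo",
        grid_n=5,
        n_mazes=1,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
)[0]

# Singles: every consecutive pair of solution coords is a step
steps_all = StepSizes.Singles().step_start_end_indices(maze)
# Forks: only decision points (plus the endpoints) delimit steps
steps_forks = StepSizes.Forks().step_start_end_indices(maze)
assert len(steps_forks) <= len(steps_all)
```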
Singles: Every coord in maze.solution is represented. Legacy tokenizers all use this behavior.
Inherited Members
- maze_dataset.tokenization.modular.elements.StepSizes._StepSize
- attribute_key
- step_start_end_indices
- is_valid
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Straightaways: Only the coords where the path turns are represented. I.e., the path is represented as a sequence of straightaways, specified by the coords at the turns.
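To illustrate the turn detection in the namespace source above, here is the same logic run on a toy solution array (illustrative only, not library code):

```python
# Toy walkthrough of the Straightaways logic: a coord counts as a turn
# when both its row and column differ from the last recorded turn coord.
import numpy as np

solution = np.array([[0, 0], [0, 1], [0, 2], [1, 2], [2, 2]])  # east twice, then south twice
last_turn = solution[0]
indices = [0]
for i, coord in enumerate(solution):
    if coord[0] != last_turn[0] and coord[1] != last_turn[1]:
        indices.append(i - 1)  # the previous coord was the turn
        last_turn = solution[i - 1]
indices.append(len(solution) - 1)  # always include the final coord
# indices == [0, 2, 4]: the start, the turn at (0, 2), and the end
```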
def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False
Default implementation of is_valid for mark_as_unsupported-decorated classes.
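Since Straightaways is decorated with @mark_as_unsupported, this override means instances always report themselves invalid rather than erroring on construction; a quick sketch:

```python
# Sketch: an unsupported element fails validation instead of raising on use.
from maze_dataset.tokenization.modular.elements import StepSizes

assert StepSizes.Straightaways().is_valid() is False
# with do_except=True, the same check raises ValueError instead:
# StepSizes.Straightaways().is_valid(do_except=True)
```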
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Forks: Only coords at forks, where the path has two or more options for the next step, are included. Excludes the option of backtracking. The starting and ending coords are always included.
Inherited Members
- maze_dataset.tokenization.modular.elements.StepSizes._StepSize
- attribute_key
- step_start_end_indices
- is_valid
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
ForksAndStraightaways: Includes the union of the coords included by Forks and Straightaways. See the documentation for those classes for details.
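The union is computed by concatenating the two index lists and deduplicating with np.unique, as in the namespace source above; on toy data (illustrative only):

```python
# Toy illustration of the index union performed by ForksAndStraightaways.
import numpy as np

straightaway_indices = [0, 2, 4]
fork_indices = [0, 3, 4]
union = list(np.unique(np.concatenate((straightaway_indices, fork_indices))))
assert union == [0, 2, 3, 4]  # np.unique also sorts the result
```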
is_valid: Default implementation of is_valid for mark_as_unsupported-decorated classes.
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class StepTokenizers(__TokenizerElementNamespace):
    """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_tokenizers"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepTokenizers.key

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            """Tokenizes a single step in the solution.

            # Parameters
            - `maze`: Maze to be tokenized
            - `start_index`: The index of the Coord in `maze.solution` at which the current step starts
            - `end_index`: The index of the Coord in `maze.solution` at which the current step ends
            """
            raise NotImplementedError(
                "Subclasses must implement `StepTokenizer.to_tokens.",
            )

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Coord(_StepTokenizer):
        """A direct tokenization of the end position coord represents the step."""

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

    @serializable_dataclass(frozen=True, kw_only=True)
    class Cardinal(_StepTokenizer):
        """A step is tokenized with a cardinal direction token.

        It is the direction of the step from the starting position along the solution.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            return [
                get_cardinal_direction(maze.solution[start_index : start_index + 2]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Relative(_StepTokenizer):
        """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

        To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
        Similarly to `Cardinal`, the direction is that of the step from the starting position.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            if start_index == 0:
                start = maze.solution[0]
                previous = start + np.array([1, 0])
                return [
                    get_relative_direction(
                        np.concatenate(
                            (
                                np.expand_dims(previous, 0),
                                maze.solution[start_index : start_index + 2],
                            ),
                            axis=0,
                        ),
                    ),
                ]
            return [
                get_relative_direction(
                    maze.solution[start_index - 1 : start_index + 2],
                ),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Distance(_StepTokenizer):
        """A count of the number of individual steps from the starting point to the end point.

        Contains no information about directionality, only the distance traveled in the step.
        `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
        This constraint is enforced in `_PathTokenizer.is_valid`.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            d: int = end_index - start_index
            return [getattr(VOCAB, f"I_{d:03}")]

    """
    `StepTokenizerPermutation`
    A sequence of unique `_StepTokenizer`s.
    This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
    """
    StepTokenizerPermutation: type = (
        tuple[_StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
    )
Namespace for the _StepTokenizer subclass hierarchy used by MazeTokenizerModular.
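As a rough sketch of the interface: each variant renders a single (start_index, end_index) step into tokens. This assumes a SolvedMaze named maze (e.g. from the StepSizes sketch earlier); CoordTokenizers.UT here stands for the library's unique-token coordinate tokenizer, and the exact token strings returned depend on the shared VOCAB.

```python
# Hedged sketch: the same step rendered by different _StepTokenizers.
from maze_dataset.tokenization.modular.elements import CoordTokenizers, StepTokenizers

# Coord needs a coord tokenizer to render the step's end position
coord_tokens = StepTokenizers.Coord().to_tokens(
    maze, 0, 1, coord_tokenizer=CoordTokenizers.UT()
)
# Cardinal and Distance only need the solution indices
cardinal_tokens = StepTokenizers.Cardinal().to_tokens(maze, 0, 1)
distance_tokens = StepTokenizers.Distance().to_tokens(maze, 0, 1)  # -> [VOCAB.I_001]
```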
Coord: The step is represented by a direct tokenization of its end position coord.
to_tokens: Tokenizes a single step in the solution.
Parameters
- maze: Maze to be tokenized
- start_index: The index of the Coord in maze.solution at which the current step starts
- end_index: The index of the Coord in maze.solution at which the current step ends
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Cardinal: A step is tokenized with a cardinal direction token: the direction of the step from the starting position along the solution.
to_tokens: Tokenizes a single step in the solution.
Parameters
- maze: Maze to be tokenized
- start_index: The index of the Coord in maze.solution at which the current step starts
- end_index: The index of the Coord in maze.solution at which the current step ends
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Relative: Tokenizes a solution step using relative first-person directions (right, left, forward, etc.). To resolve the indeterminacy at the start of a solution, the "agent" solving the maze is assumed to be facing NORTH; concretely, the implementation prepends a virtual "previous" coordinate one row offset from the start so the first step has a reference direction. As with Cardinal, the direction is that of the step from the starting position.
to_tokens: Tokenizes a single step in the solution.
Parameters
- maze: Maze to be tokenized
- start_index: The index of the Coord in maze.solution at which the current step starts
- end_index: The index of the Coord in maze.solution at which the current step ends
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
Takes in an appropriately structured dict and returns an instance of the class; implemented via the `@serializable_dataclass` decorator.
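To make the pair concrete, here is a minimal round-trip sketch; the `ExampleConfig` class is hypothetical, and only the `muutils` names shown in the source above are assumed:

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

# hypothetical dataclass, for illustration only
@serializable_dataclass(frozen=True, kw_only=True)
class ExampleConfig(SerializableDataclass):
    name: str = serializable_field(default="demo")
    size: int = serializable_field(default=5)

cfg = ExampleConfig(name="small", size=3)
data = cfg.serialize()               # plain JSON-compatible dict, includes a format key
restored = ExampleConfig.load(data)  # reconstructs an equal instance
assert restored == cfg
```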
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
Validates the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field.
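A quick sketch of the corresponding instance-level check, again with a hypothetical class, assuming the `validate_fields_types` method that `SerializableDataclass` exposes:

```python
from muutils.json_serialize import (
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)

# hypothetical dataclass, for illustration only
@serializable_dataclass(frozen=True, kw_only=True)
class Point(SerializableDataclass):
    x: int = serializable_field(default=0)
    y: int = serializable_field(default=0)

p = Point(x=1, y=2)
assert p.validate_fields_types()  # True: both fields match their `int` hints
```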
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
@serializable_dataclass(frozen=True, kw_only=True)
class Distance(_StepTokenizer):
    """A count of the number of individual steps from the starting point to the end point.

    Contains no information about directionality, only the distance traveled in the step.
    `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
    This constraint is enforced in `_PathTokenizer.is_valid`.
    """

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        start_index: int,
        end_index: int,
        **kwargs,
    ) -> list[str]:
        d: int = end_index - start_index
        return [getattr(VOCAB, f"I_{d:03}")]
```
A count of the number of individual steps from the starting point to the end point.

Contains no information about directionality, only the distance traveled in the step. `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`. This constraint is enforced in `_PathTokenizer.is_valid`.
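A sketch of this constraint in practice, assuming these classes are importable from `maze_dataset.tokenization` and that `StepSequence` is directly instantiable:

```python
from maze_dataset.tokenization import PathTokenizers, StepTokenizers

# Distance paired with Coord: valid
ok = PathTokenizers.StepSequence(
    step_tokenizers=(StepTokenizers.Coord(), StepTokenizers.Distance()),
)
assert ok.is_valid()

# Distance alone: rejected, since it cannot encode a path
bad = PathTokenizers.StepSequence(
    step_tokenizers=(StepTokenizers.Distance(),),
)
assert not bad.is_valid()
```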
```python
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    start_index: int,
    end_index: int,
    **kwargs,
) -> list[str]:
    d: int = end_index - start_index
    return [getattr(VOCAB, f"I_{d:03}")]
```
Tokenizes a single step in the solution.

Parameters
- `maze`: Maze to be tokenized
- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
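The returned token name follows the `I_{d:03}` pattern visible in the source; a trivial sketch of the formatting arithmetic:

```python
# a step spanning d solution indices maps to the vocabulary token "I_{d:03}"
start_index, end_index = 4, 11
d = end_index - start_index
token_name = f"I_{d:03}"  # -> "I_007"
assert token_name == "I_007"
```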
```python
class PathTokenizers(__TokenizerElementNamespace):
    """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "path_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PathTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze solution paths."""

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the solution path."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return PathTokenizers.key
```

(The `StepSequence` subclass defined in this namespace is shown with its own documentation below.)
Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
```python
@serializable_dataclass(frozen=True, kw_only=True)
class StepSequence(_PathTokenizer, abc.ABC):
    """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

    Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

    # Parameters
    - `step_size`: Selects the size of a single step in the sequence
    - `step_tokenizers`: Selects the combination and permutation of tokens
    - `pre`: Whether all steps include an integral preceding delimiter token
    - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization
    - `post`: Whether all steps include an integral following delimiter token
    """

    step_size: StepSizes._StepSize = serializable_field(
        default=StepSizes.Singles(),
        loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
    )
    step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
        default=(StepTokenizers.Coord(),),
        serialization_fn=lambda x: [y.serialize() for y in x],
        loading_fn=lambda x: tuple(x[StepTokenizers.key]),
    )
    pre: bool = serializable_field(default=False)
    intra: bool = serializable_field(default=False)
    post: bool = serializable_field(default=False)

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        return [
            *self._leading_tokens(maze, coord_tokenizer),
            *flatten(
                [
                    self._single_step_tokens(maze, start, end, coord_tokenizer)
                    for start, end in self.step_size.step_start_end_indices(maze)
                ],
            ),
            *self._trailing_tokens(maze, coord_tokenizer),
        ]

    def _single_step_tokens(
        self,
        maze: SolvedMaze,
        i: int,
        j: int,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns the token sequence representing a single step along the path."""
        step_rep_tokens: list[list[str]] = [
            step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
            for step_tokenizer in self.step_tokenizers
        ]
        if self.intra:
            # interleave a PATH_INTRA delimiter after each step tokenizer's tokens
            step_rep_tokens_and_intra: list[str] = [None] * (
                len(step_rep_tokens) * 2
            )
            step_rep_tokens_and_intra[::2] = step_rep_tokens
            step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                step_rep_tokens,
            )
            step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
        all_tokens: list[str] = [
            *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
            *flatten(step_rep_tokens),
            *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
        ]
        return all_tokens

    def _leading_tokens(
        self,
        maze: SolvedMaze,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns tokens preceding those from the sequence from `_single_step_tokens`.

        Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
        <PATH_START> should NOT be included.
        """
        if StepTokenizers.Coord() in self.step_tokenizers:
            return [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *coord_tokenizer.to_tokens(maze.solution[0, ...]),
                *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
            ]
        return []

    def _trailing_tokens(
        self,
        maze: SolvedMaze,  # annotation fixed to match the call in `to_tokens`
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns tokens following those from the sequence from `_single_step_tokens`.

        <PATH_END> should NOT be included.
        """
        return []

    # inherits docstring
    def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
        output: bool

        if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
            # Uninteresting: repeated elements are not useful
            output = False
        elif len(self.step_tokenizers) == 1 and isinstance(
            self.step_tokenizers[0],
            StepTokenizers.Distance,
        ):
            # Untrainable: `Distance` alone cannot encode a path.
            # >=1 `StepTokenizer` which indicates direction/location is required.
            output = False
        else:
            output = True

        if not output and do_except:
            raise ValueError(
                "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
            )

        return output
```
Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

Parameters
- `step_size`: Selects the size of a single step in the sequence
- `step_tokenizers`: Selects the combination and permutation of tokens
- `pre`: Whether all steps include an integral preceding delimiter token
- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization
- `post`: Whether all steps include an integral following delimiter token
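As a sketch of how these flags combine (import path assumed as in the earlier sketches), enabling `intra` and `post` appends `VOCAB.PATH_INTRA` after each `_StepTokenizer`'s token group within a step and `VOCAB.PATH_POST` after each step:

```python
from maze_dataset.tokenization import PathTokenizers, StepTokenizers

seq = PathTokenizers.StepSequence(
    step_tokenizers=(StepTokenizers.Coord(), StepTokenizers.Distance()),
    intra=True,  # delimiter after each _StepTokenizer's tokens within a step
    post=True,   # delimiter closing each step
)
assert seq.is_valid()
```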
```python
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    coord_tokenizer: CoordTokenizers._CoordTokenizer,
) -> list[str]:
    return [
        *self._leading_tokens(maze, coord_tokenizer),
        *flatten(
            [
                self._single_step_tokens(maze, start, end, coord_tokenizer)
                for start, end in self.step_size.step_start_end_indices(maze)
            ],
        ),
        *self._trailing_tokens(maze, coord_tokenizer),
    ]
```
Returns tokens representing the solution path.
```python
def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
    output: bool

    if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
        # Uninteresting: repeated elements are not useful
        output = False
    elif len(self.step_tokenizers) == 1 and isinstance(
        self.step_tokenizers[0],
        StepTokenizers.Distance,
    ):
        # Untrainable: `Distance` alone cannot encode a path.
        # >=1 `StepTokenizer` which indicates direction/location is required.
        output = False
    else:
        output = True

    if not output and do_except:
        raise ValueError(
            "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
        )

    return output
```
Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return `True` for that subclass.

Types of Invalidity

In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:

- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative.
- Duplicate: These tokenizers have tokenization behavior identical to some other valid tokenizer.
- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
- Erroneous: These tokenizers might raise exceptions during use.

Development

`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting

In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

Types of Invalidity

If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
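A sketch of the exception path (imports assumed as in the earlier sketches):

```python
from maze_dataset.tokenization import PathTokenizers, StepTokenizers

bad = PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Distance(),))
try:
    bad.is_valid(do_except=True)
except ValueError as e:
    print(e)  # reports the "Untrainable" invalidity described above
```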
```python
def get_tokens_up_to_path_start(
    tokens: list[str],
    include_start_coord: bool = True,
    tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
    """get tokens up to the path start token

    # Parameters:
    - `tokens : list[str]`
    - `include_start_coord : bool`
      (defaults to `True`)
    - `tokenization_mode : TokenizationMode`
      (defaults to `TokenizationMode.AOTP_UT_uniform`)

    # Returns:
    - `list[str]` subsequence of `tokens` up to the path start token

    # Raises:
    - `ValueError` : if `tokenization_mode` is invalid
    """
    warnings.warn(
        "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
    if include_start_coord:
        if is_UT(tokenization_mode):
            return tokens[: path_start_idx + 1]
        elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return tokens[: path_start_idx + 5]
        else:
            err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
            raise ValueError(err_msg)
    else:
        return tokens[:path_start_idx]
```
Get tokens up to the path start token.

Parameters:
- `tokens : list[str]`
- `include_start_coord : bool` (defaults to `True`)
- `tokenization_mode : TokenizationMode` (defaults to `TokenizationMode.AOTP_UT_uniform`)

Returns:
- `list[str]` subsequence of `tokens` up to the path start token

Raises:
- `ValueError`: if `tokenization_mode` is invalid
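A usage sketch; the token strings below are illustrative, and it is assumed that `SPECIAL_TOKENS.PATH_START` renders as `<PATH_START>` and that the function is importable from `maze_dataset.tokenization`:

```python
from maze_dataset.tokenization import get_tokens_up_to_path_start

tokens = [
    "<ADJLIST_START>", "(0,0)", "<-->", "(0,1)", ";", "<ADJLIST_END>",
    "<ORIGIN_START>", "(0,0)", "<ORIGIN_END>",
    "<TARGET_START>", "(0,1)", "<TARGET_END>",
    "<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>",
]

# UT mode with include_start_coord=True keeps one token past <PATH_START>,
# i.e. the single unique-token start coordinate "(0,0)"
prefix = get_tokens_up_to_path_start(tokens)
assert prefix[-2] == "<PATH_START>" and prefix[-1] == "(0,0)"
```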