maze_dataset.tokenization.modular.elements
implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`
1"""implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`""" 2 3import abc 4import random 5from typing import ( 6 Callable, 7 Literal, 8 Sequence, 9 TypedDict, 10) 11 12import numpy as np 13from jaxtyping import Bool, Int 14from muutils.json_serialize import ( 15 serializable_dataclass, 16 serializable_field, 17) 18from muutils.misc import empty_sequence_if_attr_false, flatten 19 20# from maze_dataset import SolvedMaze 21from maze_dataset.constants import ( 22 VOCAB, 23 ConnectionArray, 24 ConnectionList, 25 Coord, 26 CoordTup, 27) 28from maze_dataset.generation import numpy_rng 29from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze 30from maze_dataset.token_utils import ( 31 connection_list_to_adj_list, 32 get_cardinal_direction, 33 get_relative_direction, 34 is_connection, 35 tokens_between, 36) 37from maze_dataset.tokenization.modular.element_base import ( 38 __TokenizerElementNamespace, 39 _load_tokenizer_element, 40 _TokenizerElement, 41 _unsupported_is_invalid, 42 mark_as_unsupported, 43) 44from maze_dataset.utils import lattice_connection_array 45 46 47class CoordTokenizers(__TokenizerElementNamespace): 48 """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 49 50 key = "coord_tokenizer" 51 52 @serializable_dataclass(frozen=True, kw_only=True) 53 class _CoordTokenizer(_TokenizerElement, abc.ABC): 54 """Superclass for classes which tokenize singular coords in a maze.""" 55 56 @abc.abstractmethod 57 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: 58 pass 59 60 @classmethod 61 def attribute_key(cls) -> str: 62 return CoordTokenizers.key 63 64 def is_valid(self, do_except: bool = False) -> bool: 65 # No invalid instances possible within data member type hint bounds 66 return True 67 68 @serializable_dataclass(frozen=True, kw_only=True) 69 class UT(_CoordTokenizer): 70 """Unique token coordinate tokenizer.""" 71 72 # inherit docstring 73 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 74 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])] 75 76 @serializable_dataclass(frozen=True, kw_only=True) 77 class CTT(_CoordTokenizer): 78 """Coordinate tuple tokenizer 79 80 # Parameters 81 - `pre`: Whether all coords include an integral preceding delimiter token 82 - `intra`: Whether all coords include a delimiter token between coordinates 83 - `post`: Whether all coords include an integral following delimiter token 84 """ 85 86 pre: bool = serializable_field(default=True) 87 intra: bool = serializable_field(default=True) 88 post: bool = serializable_field(default=True) 89 # Implement methods 90 91 # inherit docstring 92 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 93 return [ 94 *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), 95 str(coord[0]), 96 *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), 97 str(coord[1]), 98 *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), 99 ] 100 101 102class EdgeGroupings(__TokenizerElementNamespace): 103 """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`.""" 104 105 key = "edge_grouping" 106 107 class _GroupingTokenParams(TypedDict): 108 """A uniform private hyperparameter interface used by `AdjListTokenizer`.""" 109 110 connection_token_ordinal: Literal[0, 1, 2] 111 intra: bool 112 grouped: bool 113 114 @serializable_dataclass(frozen=True, kw_only=True) 115 class _EdgeGrouping(_TokenizerElement, abc.ABC): 116 """Specifies if/how multiple 
class EdgeGroupings(__TokenizerElementNamespace):
	"""Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

	key = "edge_grouping"

	class _GroupingTokenParams(TypedDict):
		"""A uniform private hyperparameter interface used by `AdjListTokenizer`."""

		connection_token_ordinal: Literal[0, 1, 2]
		intra: bool
		grouped: bool

	@serializable_dataclass(frozen=True, kw_only=True)
	class _EdgeGrouping(_TokenizerElement, abc.ABC):
		"""Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

		@classmethod
		def attribute_key(cls) -> str:
			return EdgeGroupings.key

		def is_valid(self, do_except: bool = False) -> bool:
			return True

		@abc.abstractmethod
		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
			"""Divides a ConnectionArray into groups of edges.

			Shuffles/sequences within each group if applicable.
			"""
			pass

		@abc.abstractmethod
		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
			"""Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

			These hyperparameters are not used by `_EdgeGrouping` internally.
			They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
			since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
			This function resolves the `_EdgeGrouping` hyperparameter space, which is non-uniform across subclasses,
			into a uniform private interface used by `AdjListTokenizer`.
			"""
			pass

	@serializable_dataclass(frozen=True, kw_only=True)
	class Ungrouped(_EdgeGrouping):
		"""No grouping occurs, each edge is tokenized individually.

		# Parameters
		- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
			Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
		"""

		connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
			default=1,
			assert_type=False,
		)

		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
			return EdgeGroupings._GroupingTokenParams(
				connection_token_ordinal=self.connection_token_ordinal,
				intra=False,
				grouped=False,
			)

		def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
			return np.expand_dims(edges, 1)

	@serializable_dataclass(frozen=True, kw_only=True)
	@mark_as_unsupported(_unsupported_is_invalid)
	class ByLeadingCoord(_EdgeGrouping):
		"""All edges with the same leading coord are grouped together.

		# Parameters
		- `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
			Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
		- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
			If false, the fixed order is lexicographical by (row, col).
			In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
		- `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
			Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
		"""

		intra: bool = serializable_field(default=True)
		shuffle_group: bool = serializable_field(default=True)
		connection_token_ordinal: Literal[0, 1] = serializable_field(
			default=0,
			assert_type=False,
		)

		def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
			return EdgeGroupings._GroupingTokenParams(
				connection_token_ordinal=self.connection_token_ordinal,
				intra=self.intra,
				grouped=True,
			)

		def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
			# Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
			index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
				(edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
			)
			sorted_edges: ConnectionArray = edges[index_array, ...]
			groups: list[ConnectionArray] = np.split(
				sorted_edges,
				np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
			)
			if self.shuffle_group:
				[numpy_rng.shuffle(g, axis=0) for g in groups]
			return groups
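The `np.lexsort`/`np.split` idiom in `ByLeadingCoord._group_edges` can be seen in isolation with plain numpy; a sketch on a hand-built three-edge array:

import numpy as np

# shape (n_edges, 2, 2): [edge, (leading|trailing) coord, (row|col)]
edges = np.array(
	[
		[[0, 0], [0, 1]],
		[[1, 1], [1, 0]],
		[[0, 0], [1, 0]],
	],
)
# the primary sort key is the leading coord (the last key in the lexsort tuple)
index_array = np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
sorted_edges = edges[index_array, ...]
# split wherever a new leading coord first appears in the sorted array
groups = np.split(
	sorted_edges,
	np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
)
# -> [the two edges leading with (0,0), then the edge leading with (1,1)]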
182 """ 183 184 intra: bool = serializable_field(default=True) 185 shuffle_group: bool = serializable_field(default=True) 186 connection_token_ordinal: Literal[0, 1] = serializable_field( 187 default=0, 188 assert_type=False, 189 ) 190 191 def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": 192 return EdgeGroupings._GroupingTokenParams( 193 connection_token_ordinal=self.connection_token_ordinal, 194 intra=self.intra, 195 grouped=True, 196 ) 197 198 def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: 199 # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function 200 index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort( 201 (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]), 202 ) 203 sorted_edges: ConnectionArray = edges[index_array, ...] 204 groups: list[ConnectionArray] = np.split( 205 sorted_edges, 206 np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:], 207 ) 208 if self.shuffle_group: 209 [numpy_rng.shuffle(g, axis=0) for g in groups] 210 return groups 211 212 213class EdgePermuters(__TokenizerElementNamespace): 214 """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`.""" 215 216 key = "edge_permuter" 217 218 @serializable_dataclass(frozen=True, kw_only=True) 219 class _EdgePermuter(_TokenizerElement, abc.ABC): 220 """Specifies how to sequence the two coords that encode a lattice edge.""" 221 222 @classmethod 223 def attribute_key(cls) -> str: 224 return EdgePermuters.key 225 226 def is_valid(self, do_except: bool = False) -> bool: 227 # No invalid instances possible within data member type hint bounds 228 return True 229 230 @staticmethod 231 @abc.abstractmethod 232 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 233 """Executes a permutation. 234 235 Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation. 236 237 # Parameters 238 - `lattice_edges`: Array of lattice edges. 239 The two coords in shape[1] must be adjacent in the lattice. 240 241 # Returns 242 - Array of lattice edges with entries along shape[1] systematically permuted. 243 - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`. 244 """ 245 pass 246 247 @serializable_dataclass(frozen=True, kw_only=True) 248 class SortedCoords(_EdgePermuter): 249 """returns a sorted representation. useful for checking consistency""" 250 251 @staticmethod 252 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 253 return lattice_edges[ 254 np.lexsort( 255 ( 256 lattice_edges[:, 1, 1], 257 lattice_edges[:, 1, 0], 258 lattice_edges[:, 0, 1], 259 lattice_edges[:, 0, 0], 260 ), 261 ), 262 ..., 263 ] 264 265 @serializable_dataclass(frozen=True, kw_only=True) 266 class RandomCoords(_EdgePermuter): 267 """Permutes each edge randomly.""" 268 269 @staticmethod 270 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 271 numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges) 272 return lattice_edges 273 274 @serializable_dataclass(frozen=True, kw_only=True) 275 class BothCoords(_EdgePermuter): 276 """Includes both possible permutations of every edge in the output. 277 278 Since input ConnectionList has only 1 instance of each edge, 279 a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`. 
280 """ 281 282 @staticmethod 283 def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: 284 return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0) 285 286 287class EdgeSubsets(__TokenizerElementNamespace): 288 """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`.""" 289 290 key = "edge_subset" 291 292 @serializable_dataclass(frozen=True, kw_only=True) 293 class _EdgeSubset(_TokenizerElement, abc.ABC): 294 """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized.""" 295 296 @classmethod 297 def attribute_key(cls) -> str: 298 return EdgeSubsets.key 299 300 def is_valid(self, do_except: bool = False) -> bool: 301 return True 302 303 @abc.abstractmethod 304 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 305 """Returns the set of lattice edges to be tokenized.""" 306 pass 307 308 @serializable_dataclass(frozen=True, kw_only=True) 309 class AllLatticeEdges(_EdgeSubset): 310 """All 2n**2-2n edges of the lattice are tokenized. 311 312 If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`. 313 """ 314 315 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 316 return lattice_connection_array(maze.grid_n) 317 318 @serializable_dataclass(frozen=True, kw_only=True) 319 class ConnectionEdges(_EdgeSubset): 320 """Only edges which contain a connection are tokenized. 321 322 Alternatively, only edges which contain a wall are tokenized. 323 324 # Parameters 325 - `walls`: Whether wall edges or connection edges are tokenized. 326 If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`. 327 """ 328 329 walls: bool = serializable_field(default=False) 330 331 def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: 332 conn_list: ConnectionList = maze.connection_list 333 if self.walls: 334 conn_list = np.logical_not(conn_list) 335 conn_list[0, -1, :] = False 336 conn_list[1, :, -1] = False 337 return connection_list_to_adj_list( 338 conn_list, 339 shuffle_d0=False, 340 shuffle_d1=False, 341 ) 342 343 344def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool: # noqa: ANN001 345 """Returns False if `pre` is True, True otherwise.""" 346 output: bool = self_.pre is False 347 if do_except and not output: 348 raise ValueError( 349 "AdjListCoord does not support `pre == False`.", 350 ) 351 352 return output 353 354 355class AdjListTokenizers(__TokenizerElementNamespace): 356 """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 357 358 key = "adj_list_tokenizer" 359 360 @serializable_dataclass(frozen=True, kw_only=True) 361 @mark_as_unsupported(_adjlist_no_pre_unsupported) 362 class _AdjListTokenizer(_TokenizerElement, abc.ABC): 363 """Specifies how the adjacency list is tokenized. 364 365 Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations. 366 See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details. 367 368 # Parameters 369 - `pre`: Whether all edge groupings include a preceding delimiter token 370 - `post`: Whether all edge groupings include a following delimiter token 371 - `shuffle_d0`: Specifies how to sequence the edge groupings. 372 If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group. 
def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool:  # noqa: ANN001
	"""Returns False if `pre` is True, True otherwise."""
	output: bool = self_.pre is False
	if do_except and not output:
		raise ValueError(
			"AdjListCoord does not support `pre == True`.",
		)

	return output


class AdjListTokenizers(__TokenizerElementNamespace):
	"""Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "adj_list_tokenizer"

	@serializable_dataclass(frozen=True, kw_only=True)
	@mark_as_unsupported(_adjlist_no_pre_unsupported)
	class _AdjListTokenizer(_TokenizerElement, abc.ABC):
		"""Specifies how the adjacency list is tokenized.

		Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
		See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.

		# Parameters
		- `pre`: Whether all edge groupings include a preceding delimiter token
		- `post`: Whether all edge groupings include a following delimiter token
		- `shuffle_d0`: Specifies how to sequence the edge groupings.
			If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
		- `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
		- `edge_subset`: Specifies the subset of lattice edges to be tokenized.
		- `edge_permuter`: Specifies, in each edge tokenization, which coord either:
			1. Appears first in the tokenization, for `AdjListCoord`.
			2. Is tokenized directly as a coord, for `AdjListCardinal`.
			- `shuffle`: For each edge, the leading coord is selected randomly.
			- `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
			- `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
		"""

		pre: bool = serializable_field(default=False, assert_type=False)
		post: bool = serializable_field(default=True)
		shuffle_d0: bool = serializable_field(default=True)
		edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
			default=EdgeGroupings.Ungrouped(),
			loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
		)
		edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
			default=EdgeSubsets.ConnectionEdges(),
			loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
		)
		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
			default=EdgePermuters.RandomCoords(),
			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
		)

		@classmethod
		def attribute_key(cls) -> str:
			return AdjListTokenizers.key

		def is_valid(self, do_except: bool = False) -> bool:
			# No invalid instances possible within data member type hint bounds
			return True

		@abc.abstractmethod
		def _tokenization_callables(
			self,
			edges: ConnectionArray,
			is_conn: Bool[np.ndarray, " edges"],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
			*args,
			**kwargs,
		) -> list[Callable]:
			"""Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.

			# Returns
			- `[0]`: leading coord tokens
			- `[1]`: connector tokens
			- `[2]`: trailing coord tokens
			"""
			pass
		def _tokenize_edge_grouping(
			self,
			edges: ConnectionArray,
			maze: LatticeMaze,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
			group_params: EdgeGroupings._GroupingTokenParams,
		) -> Sequence[str]:
			"""Tokenizes a single edge grouping."""
			cxn_ord: int = group_params["connection_token_ordinal"]
			is_conn: Bool[np.ndarray, " edges"] = is_connection(
				edges,
				maze.connection_list,
			)
			tokenize_callables = self._tokenization_callables(
				edges,
				is_conn,
				coord_tokenizer,
			)

			if group_params["grouped"]:
				# If grouped
				callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
				repeated_callables = [
					tokenize_callables[i] for i in callable_permutation
				]
				return flatten(
					[
						tokenize_callables[0](0),
						[
							[
								*[
									tok_callable(i)
									for tok_callable in repeated_callables
								],
								*(
									(VOCAB.ADJLIST_INTRA,)
									if group_params["intra"]
									else ()
								),
							]
							for i in range(edges.shape[0])
						],
					],
				)
			else:
				# If ungrouped
				callable_permutation = [0, 2]
				callable_permutation.insert(cxn_ord, 1)
				tokenize_callables = [
					tokenize_callables[i] for i in callable_permutation
				]

				return flatten(
					[
						[
							[
								*[
									tok_callable(i)
									for tok_callable in tokenize_callables
								],
								*empty_sequence_if_attr_false(
									(VOCAB.ADJLIST_INTRA,),
									group_params,
									"intra",
								),
							]
							for i in range(edges.shape[0])
						],
					],
				)
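How `connection_token_ordinal` positions the connector among the three parts in the ungrouped branch above, sketched with placeholder strings:

parts = ["<lead coord>", "<connector>", "<trail coord>"]  # callables [0], [1], [2]
for cxn_ord in (0, 1, 2):
	permutation = [0, 2]
	permutation.insert(cxn_ord, 1)
	print([parts[i] for i in permutation])
# ['<connector>', '<lead coord>', '<trail coord>']
# ['<lead coord>', '<connector>', '<trail coord>']
# ['<lead coord>', '<trail coord>', '<connector>']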
		def to_tokens(
			self,
			maze: LatticeMaze,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			# Get the set of edges to be tokenized
			edges: ConnectionArray = self.edge_subset._get_edges(maze)
			# Systematically permute the leading coord of each edge
			edges: ConnectionArray = self.edge_permuter._permute(edges)
			group_params: EdgeGroupings._GroupingTokenParams = (
				self.edge_grouping._token_params()
			)
			# then, we need to group the edges
			groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
			# shuffle the groups if specified
			if self.shuffle_d0:
				if isinstance(groups, np.ndarray):
					numpy_rng.shuffle(groups, axis=0)
				elif isinstance(groups, list):
					random.shuffle(groups)
				else:
					err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
					raise TypeError(err_msg)
			# Tokenize each group with optional delimiters
			tokens: list[str] = list(
				flatten(
					[
						[
							*empty_sequence_if_attr_false(
								(VOCAB.ADJLIST_PRE,),
								self,
								"pre",
							),
							*self._tokenize_edge_grouping(
								group,
								maze,
								coord_tokenizer,
								group_params,
							),
							*empty_sequence_if_attr_false(
								(VOCAB.ADJACENCY_ENDLINE,),
								self,
								"post",
							),
						]
						for group in groups
					],
				),
			)
			return tokens

	@serializable_dataclass(frozen=True, kw_only=True)
	class AdjListCoord(_AdjListTokenizer):
		"""Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""

		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
			default=EdgePermuters.RandomCoords(),
			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
		)

		def _tokenization_callables(
			self,
			edges: ConnectionArray,
			is_conn: Bool[np.ndarray, " edges"],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
			*args,
			**kwargs,
		) -> list[Callable]:
			# Map from `is_conn` to the tokens which represent connections and walls
			conn_token_map: dict[bool, str] = {
				True: VOCAB.CONNECTOR,
				False: VOCAB.ADJLIST_WALL,
			}
			return [
				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
				lambda i: conn_token_map[is_conn[i]],
				lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
			]

	@serializable_dataclass(frozen=True, kw_only=True)
	class AdjListCardinal(_AdjListTokenizer):
		"""Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

		# Parameters
		- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
		"""

		edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
			default=EdgePermuters.BothCoords(),
			loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
		)

		def _tokenization_callables(
			self,
			edges: ConnectionArray,
			is_conn: Bool[np.ndarray, " edges"],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
			*args,
			**kwargs,
		) -> list[Callable]:
			# Map from `is_conn` to the tokens which represent connections and walls
			conn_token_map: dict[bool, str] = {
				True: VOCAB.CONNECTOR,
				False: VOCAB.ADJLIST_WALL,
			}
			return [
				lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
				lambda i: conn_token_map[is_conn[i]],
				lambda i: get_cardinal_direction(edges[i]),
			]
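An end-to-end sketch of adjacency-list tokenization. It assumes a maze generated with `LatticeMazeGenerators.gen_dfs`; any `LatticeMaze` would work:

import numpy as np
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import AdjListTokenizers, CoordTokenizers

maze = LatticeMazeGenerators.gen_dfs(np.array([3, 3]))
adj_tok = AdjListTokenizers.AdjListCoord()  # defaults: ConnectionEdges, Ungrouped, RandomCoords
tokens = adj_tok.to_tokens(maze, coord_tokenizer=CoordTokenizers.UT())
# a flat list like ["(0,0)", VOCAB.CONNECTOR, "(0,1)", VOCAB.ADJACENCY_ENDLINE, ...]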
class TargetTokenizers(__TokenizerElementNamespace):
	"""Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "target_tokenizer"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _TargetTokenizer(_TokenizerElement, abc.ABC):
		"""Superclass of tokenizers for maze targets."""

		@abc.abstractmethod
		def to_tokens(
			self,
			targets: Sequence[Coord],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			"""Returns tokens representing the target."""
			pass

		@classmethod
		def attribute_key(cls) -> str:
			return TargetTokenizers.key

	@serializable_dataclass(frozen=True, kw_only=True)
	class Unlabeled(_TargetTokenizer):
		"""Targets are simply listed as coord tokens.

		# Parameters
		- `post`: Whether all coords include an integral following delimiter token
		"""

		post: bool = serializable_field(default=False)

		# inherit docstring
		def to_tokens(  # noqa: D102
			self,
			targets: Sequence[Coord],
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			return list(
				flatten(
					[
						[
							*coord_tokenizer.to_tokens(target),
							*empty_sequence_if_attr_false(
								[VOCAB.TARGET_POST],
								self,
								"post",
							),
						]
						for target in targets
					],
				),
			)

		# inherit docstring
		def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
			# No invalid instances possible within data member type hint bounds
			return True


class StepSizes(__TokenizerElementNamespace):
	"""Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "step_size"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _StepSize(_TokenizerElement, abc.ABC):
		"""Specifies which coords in `maze.solution` are used to represent the path."""

		@classmethod
		def attribute_key(cls) -> str:
			return StepSizes.key

		@abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
			raise NotImplementedError(
				"Subclasses must implement `_StepSize._step_single_indices`.",
			)

		def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
			"""Returns steps as tuples of starting and ending positions for each step."""
			indices: list[int] = self._step_single_indices(maze)
			# TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
			return [
				(start, end)
				for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
			]

		def is_valid(self, do_except: bool = False) -> bool:
			# No invalid instances possible within data member type hint bounds
			return True

	@serializable_dataclass(frozen=True, kw_only=True)
	class Singles(_StepSize):
		"""Every coord in `maze.solution` is represented.

		Legacy tokenizers all use this behavior.
		"""

		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
			return list(range(maze.solution.shape[0]))

	@serializable_dataclass(frozen=True, kw_only=True)
	@mark_as_unsupported(_unsupported_is_invalid)
	class Straightaways(_StepSize):
		"""Only coords where the path turns are represented in the path.

		I.e., the path is represented as a sequence of straightaways,
		specified by the coords at the turns.
		"""

		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
			last_turn_coord: Coord = maze.solution[0, ...]
			indices: list[int] = [0]
			for i, coord in enumerate(maze.solution):
				if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
					indices.append(i - 1)
					last_turn_coord = maze.solution[i - 1, ...]
			indices.append(i)
			return indices

	@serializable_dataclass(frozen=True, kw_only=True)
	class Forks(_StepSize):
		"""Only coords at forks, where the path has >= 2 options for the next step, are included.

		Excludes the option of backtracking.
		The starting and ending coords are always included.
		"""

		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
			return maze.get_solution_forking_points(always_include_endpoints=True)[0]

	@serializable_dataclass(frozen=True, kw_only=True)
	@mark_as_unsupported(_unsupported_is_invalid)
	class ForksAndStraightaways(_StepSize):
		"""Includes the union of the coords included by `Forks` and `Straightaways`.

		See documentation for those classes for details.
		"""

		def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
			"""Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
			return list(
				np.unique(
					np.concatenate(
						(
							StepSizes.Straightaways()._step_single_indices(maze),
							StepSizes.Forks()._step_single_indices(maze),
						),
					),
				),
			)
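How `step_start_end_indices` turns the indices from a `_StepSize` into steps; for `Singles` on a 4-coord solution:

indices = [0, 1, 2, 3]  # what Singles returns for a 4-coord solution
pairs = list(zip(indices[:-1], indices[1:]))
# [(0, 1), (1, 2), (2, 3)] -- one tokenized step per consecutive index pair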
class StepTokenizers(__TokenizerElementNamespace):
	"""Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "step_tokenizers"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _StepTokenizer(_TokenizerElement, abc.ABC):
		"""Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""

		@classmethod
		def attribute_key(cls) -> str:
			return StepTokenizers.key

		@abc.abstractmethod
		def to_tokens(
			self,
			maze: SolvedMaze,
			start_index: int,
			end_index: int,
			**kwargs,
		) -> list[str]:
			"""Tokenizes a single step in the solution.

			# Parameters
			- `maze`: Maze to be tokenized
			- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
			- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
			"""
			raise NotImplementedError(
				"Subclasses must implement `_StepTokenizer.to_tokens`.",
			)

		def is_valid(self, do_except: bool = False) -> bool:
			# No invalid instances possible within data member type hint bounds
			return True

	@serializable_dataclass(frozen=True, kw_only=True)
	class Coord(_StepTokenizer):
		"""A direct tokenization of the end-position coord represents the step."""

		# inherit docstring
		def to_tokens(  # noqa: D102
			self,
			maze: SolvedMaze,
			start_index: int,
			end_index: int,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
820 """ 821 822 # inherit docstring 823 def to_tokens( # noqa: D102 824 self, 825 maze: SolvedMaze, 826 start_index: int, 827 end_index: int, 828 **kwargs, 829 ) -> list[str]: 830 return [ 831 get_cardinal_direction(maze.solution[start_index : start_index + 2]), 832 ] 833 834 @serializable_dataclass(frozen=True, kw_only=True) 835 class Relative(_StepTokenizer): 836 """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.). 837 838 To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH. 839 Similarly to `Cardinal`, the direction is that of the step from the starting position. 840 """ 841 842 # inherit docstring 843 def to_tokens( # noqa: D102 844 self, 845 maze: SolvedMaze, 846 start_index: int, 847 end_index: int, 848 **kwargs, 849 ) -> list[str]: 850 if start_index == 0: 851 start = maze.solution[0] 852 previous = start + np.array([1, 0]) 853 return [ 854 get_relative_direction( 855 np.concatenate( 856 ( 857 np.expand_dims(previous, 0), 858 maze.solution[start_index : start_index + 2], 859 ), 860 axis=0, 861 ), 862 ), 863 ] 864 return [ 865 get_relative_direction( 866 maze.solution[start_index - 1 : start_index + 2], 867 ), 868 ] 869 870 @serializable_dataclass(frozen=True, kw_only=True) 871 class Distance(_StepTokenizer): 872 """A count of the number of individual steps from the starting point to the end point. 873 874 Contains no information about directionality, only the distance traveled in the step. 875 `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`. 876 This constraint is enforced in `_PathTokenizer.is_valid`. 877 """ 878 879 # inherit docstring 880 def to_tokens( # noqa: D102 881 self, 882 maze: SolvedMaze, 883 start_index: int, 884 end_index: int, 885 **kwargs, 886 ) -> list[str]: 887 d: int = end_index - start_index 888 return [getattr(VOCAB, f"I_{d:03}")] 889 890 """ 891 `StepTokenizerPermutation` 892 A sequence of unique `_StepTokenizer`s. 893 This type exists mostly just for the clarity and convenience of `_PathTokenizer` code. 894 """ 895 StepTokenizerPermutation: type = ( 896 tuple[_StepTokenizer] 897 | tuple[_StepTokenizer, _StepTokenizer] 898 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer] 899 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer] 900 ) 901 902 903class PathTokenizers(__TokenizerElementNamespace): 904 """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 905 906 key = "path_tokenizer" 907 908 @serializable_dataclass(frozen=True, kw_only=True) 909 class _PathTokenizer(_TokenizerElement, abc.ABC): 910 """Superclass of tokenizers for maze solution paths.""" 911 912 @abc.abstractmethod 913 def to_tokens( 914 self, 915 maze: SolvedMaze, 916 coord_tokenizer: CoordTokenizers._CoordTokenizer, 917 ) -> list[str]: 918 """Returns tokens representing the solution path.""" 919 pass 920 921 @classmethod 922 def attribute_key(cls) -> str: 923 return PathTokenizers.key 924 925 @serializable_dataclass(frozen=True, kw_only=True) 926 class StepSequence(_PathTokenizer, abc.ABC): 927 """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path. 928 929 Allows for a sequence of leading and trailing tokens which don't fit the step pattern. 
class PathTokenizers(__TokenizerElementNamespace):
	"""Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

	key = "path_tokenizer"

	@serializable_dataclass(frozen=True, kw_only=True)
	class _PathTokenizer(_TokenizerElement, abc.ABC):
		"""Superclass of tokenizers for maze solution paths."""

		@abc.abstractmethod
		def to_tokens(
			self,
			maze: SolvedMaze,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			"""Returns tokens representing the solution path."""
			pass

		@classmethod
		def attribute_key(cls) -> str:
			return PathTokenizers.key

	@serializable_dataclass(frozen=True, kw_only=True)
	class StepSequence(_PathTokenizer, abc.ABC):
		"""Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

		Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

		# Parameters
		- `step_size`: Selects the size of a single step in the sequence
		- `step_tokenizers`: Selects the combination and permutation of tokens
		- `pre`: Whether all steps include an integral preceding delimiter token
		- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
		- `post`: Whether all steps include an integral following delimiter token
		"""

		step_size: StepSizes._StepSize = serializable_field(
			default=StepSizes.Singles(),
			loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
		)
		step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
			default=(StepTokenizers.Coord(),),
			serialization_fn=lambda x: [y.serialize() for y in x],
			loading_fn=lambda x: tuple(x[StepTokenizers.key]),
		)
		pre: bool = serializable_field(default=False)
		intra: bool = serializable_field(default=False)
		post: bool = serializable_field(default=False)

		# inherit docstring
		def to_tokens(  # noqa: D102
			self,
			maze: SolvedMaze,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			return [
				*self._leading_tokens(maze, coord_tokenizer),
				*flatten(
					[
						self._single_step_tokens(maze, start, end, coord_tokenizer)
						for start, end in self.step_size.step_start_end_indices(maze)
					],
				),
				*self._trailing_tokens(maze, coord_tokenizer),
			]

		def _single_step_tokens(
			self,
			maze: SolvedMaze,
			i: int,
			j: int,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			"""Returns the token sequence representing a single step along the path."""
			step_rep_tokens: list[list[str]] = [
				step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
				for step_tokenizer in self.step_tokenizers
			]
			if self.intra:
				step_rep_tokens_and_intra: list[str] = [None] * (
					len(step_rep_tokens) * 2
				)
				step_rep_tokens_and_intra[::2] = step_rep_tokens
				step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
					step_rep_tokens,
				)
				step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
			all_tokens: list[str] = [
				*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
				*flatten(step_rep_tokens),
				*empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
			]
			return all_tokens

		def _leading_tokens(
			self,
			maze: SolvedMaze,
			coord_tokenizer: CoordTokenizers._CoordTokenizer,
		) -> list[str]:
			"""Returns tokens preceding those from the sequence from `_single_step_tokens`.

			Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
			<PATH_START> should NOT be included.
			"""
			if StepTokenizers.Coord() in self.step_tokenizers:
				return [
					*empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
					*coord_tokenizer.to_tokens(maze.solution[0, ...]),
					*empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
				]
			return []
1023 """ 1024 return [] 1025 1026 # inherits docstring 1027 def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 1028 output: bool 1029 1030 if len(set(self.step_tokenizers)) != len(self.step_tokenizers): 1031 # Uninteresting: repeated elements are not useful 1032 output = False 1033 else: 1034 # we do noqa for the comment if false 1035 if len(self.step_tokenizers) == 1 and isinstance( 1036 self.step_tokenizers[0], 1037 StepTokenizers.Distance, 1038 ): 1039 # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required. 1040 output = False 1041 else: 1042 output = True 1043 1044 if not output and do_except: 1045 raise ValueError( 1046 "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.", 1047 ) 1048 1049 return output 1050 1051 1052class PromptSequencers(__TokenizerElementNamespace): 1053 """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.""" 1054 1055 key = "prompt_sequencer" 1056 1057 @serializable_dataclass(frozen=True, kw_only=True) 1058 class _PromptSequencer(_TokenizerElement, abc.ABC): 1059 """Sequences token regions into a complete maze tokenization. 1060 1061 # Parameters 1062 - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position. 1063 - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`. 1064 Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s. 1065 """ 1066 1067 coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field( 1068 default=CoordTokenizers.UT(), 1069 loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers), 1070 ) 1071 adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field( 1072 default=AdjListTokenizers.AdjListCoord(), 1073 loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers), 1074 ) 1075 1076 @classmethod 1077 def attribute_key(cls) -> str: 1078 return PromptSequencers.key 1079 1080 @staticmethod 1081 def _trim_if_unsolved_maze( 1082 untrimmed: list[str], 1083 is_untargeted: bool = False, 1084 is_unsolved: bool = False, 1085 ) -> list[str]: 1086 """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze. 1087 1088 # Development 1089 This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP. 1090 It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses. 
1091 """ 1092 if is_untargeted: 1093 return tokens_between( 1094 untrimmed, 1095 VOCAB.ADJLIST_START, 1096 VOCAB.ADJLIST_END, 1097 include_start=True, 1098 include_end=True, 1099 ) 1100 if is_unsolved: 1101 if VOCAB.TARGET_END in untrimmed: 1102 return tokens_between( 1103 untrimmed, 1104 VOCAB.ADJLIST_START, 1105 VOCAB.TARGET_END, 1106 include_start=True, 1107 include_end=True, 1108 ) 1109 else: 1110 return tokens_between( 1111 untrimmed, 1112 VOCAB.ADJLIST_START, 1113 VOCAB.ORIGIN_END, 1114 include_start=True, 1115 include_end=True, 1116 ) 1117 return untrimmed 1118 1119 def to_tokens( 1120 self, 1121 maze: LatticeMaze, 1122 *args, 1123 **kwargs, 1124 ) -> list[str]: 1125 """Returns a complete list of tokens for a given set of maze elements.""" 1126 untrimmed: list[str] = self._sequence_tokens( 1127 *self._get_prompt_regions(maze), 1128 ) 1129 return self._trim_if_unsolved_maze( 1130 untrimmed, 1131 not hasattr(maze, "start_pos"), 1132 not hasattr(maze, "solution"), 1133 ) 1134 1135 def _get_prompt_regions( 1136 self, 1137 maze: LatticeMaze, 1138 *args, 1139 **kwargs, 1140 ) -> list[list[str]]: 1141 """Gets the prompt regions of a maze in a fixed sequence. 1142 1143 This method is NOT responsible for including/excluding any prompt regions. 1144 Always return according to the API described under Returns. 1145 This implementation is expected to be suitable for most `PromptSequencer` subclasses. 1146 Subclasses may override this method if needed for special behavior. 1147 1148 # Returns 1149 - [0]: list[str] Adjacency list tokens 1150 - [1]: list[str] Origin tokens 1151 - [2]: list[str] Target tokens 1152 - [3]: list[str] Path tokens 1153 1154 # `None`-valued Args 1155 If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized. 1156 To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables. 1157 """ 1158 origin: Coord | None = getattr(maze, "start_pos", None) 1159 target: list[Coord] | None = [ 1160 getattr(maze, "end_pos", None), 1161 ] # TargetTokenizer requires target: Sequence[Coord] 1162 1163 return [ 1164 ( 1165 self.adj_list_tokenizer.to_tokens( 1166 maze, 1167 coord_tokenizer=self.coord_tokenizer, 1168 ) 1169 if hasattr(self, "adj_list_tokenizer") 1170 else [] 1171 ), 1172 self.coord_tokenizer.to_tokens(origin) if origin is not None else [], 1173 ( 1174 self.target_tokenizer.to_tokens( 1175 target, 1176 coord_tokenizer=self.coord_tokenizer, 1177 ) 1178 if target[0] is not None and hasattr(self, "target_tokenizer") 1179 else [] 1180 ), 1181 ( 1182 self.path_tokenizer.to_tokens( 1183 maze, 1184 coord_tokenizer=self.coord_tokenizer, 1185 ) 1186 if hasattr(maze, "solution") and hasattr(self, "path_tokenizer") 1187 else [] 1188 ), 1189 ] 1190 1191 @abc.abstractmethod 1192 def _sequence_tokens( 1193 self, 1194 adj_list: list[str], 1195 origin: list[str] | None, 1196 target: list[str] | None, 1197 path: list[str] | None, 1198 ) -> list[str]: 1199 """Sequences token regions into a complete prompt. 1200 1201 Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc. 
	@serializable_dataclass(frozen=True, kw_only=True)
	class AOTP(_PromptSequencer):
		"""Sequences a prompt as [adjacency list, origin, target, path].

		# Parameters
		- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
			Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
			Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
		"""

		target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
			default=TargetTokenizers.Unlabeled(),
			loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
		)
		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
			default=PathTokenizers.StepSequence(),
			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
		)

		def _sequence_tokens(
			self,
			adj_list: list[str],
			origin: list[str],
			target: list[str],
			path: list[str],
		) -> list[str]:
			return [
				VOCAB.ADJLIST_START,
				*adj_list,
				VOCAB.ADJLIST_END,
				VOCAB.ORIGIN_START,
				*origin,
				VOCAB.ORIGIN_END,
				VOCAB.TARGET_START,
				*target,
				VOCAB.TARGET_END,
				VOCAB.PATH_START,
				*path,
				VOCAB.PATH_END,
			]

	@serializable_dataclass(frozen=True, kw_only=True)
	class AOP(_PromptSequencer):
		"""Sequences a prompt as [adjacency list, origin, path].

		Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.

		# Parameters
		- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
			Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
		"""

		path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
			default=PathTokenizers.StepSequence(),
			loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
		)

		def _sequence_tokens(
			self,
			adj_list: list[str],
			origin: list[str],
			# explicitly no target in this tokenizer
			target: list[str],
			path: list[str],
		) -> list[str]:
			return [
				VOCAB.ADJLIST_START,
				*adj_list,
				VOCAB.ADJLIST_END,
				VOCAB.ORIGIN_START,
				*origin,
				VOCAB.ORIGIN_END,
				VOCAB.TARGET_START,
				VOCAB.TARGET_END,
				VOCAB.PATH_START,
				*path,
				VOCAB.PATH_END,
			]
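Finally, how `AOTP._sequence_tokens` assembles the regions; the method is private and is called here only for illustration, with placeholder region tokens:

from maze_dataset.tokenization.modular.elements import PromptSequencers

seq = PromptSequencers.AOTP()
seq._sequence_tokens(["<adj>"], ["<orig>"], ["<targ>"], ["<path>"])
# [VOCAB.ADJLIST_START, "<adj>", VOCAB.ADJLIST_END,
#  VOCAB.ORIGIN_START, "<orig>", VOCAB.ORIGIN_END,
#  VOCAB.TARGET_START, "<targ>", VOCAB.TARGET_END,
#  VOCAB.PATH_START, "<path>", VOCAB.PATH_END]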
48class CoordTokenizers(__TokenizerElementNamespace): 49 """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 50 51 key = "coord_tokenizer" 52 53 @serializable_dataclass(frozen=True, kw_only=True) 54 class _CoordTokenizer(_TokenizerElement, abc.ABC): 55 """Superclass for classes which tokenize singular coords in a maze.""" 56 57 @abc.abstractmethod 58 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: 59 pass 60 61 @classmethod 62 def attribute_key(cls) -> str: 63 return CoordTokenizers.key 64 65 def is_valid(self, do_except: bool = False) -> bool: 66 # No invalid instances possible within data member type hint bounds 67 return True 68 69 @serializable_dataclass(frozen=True, kw_only=True) 70 class UT(_CoordTokenizer): 71 """Unique token coordinate tokenizer.""" 72 73 # inherit docstring 74 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 75 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])] 76 77 @serializable_dataclass(frozen=True, kw_only=True) 78 class CTT(_CoordTokenizer): 79 """Coordinate tuple tokenizer 80 81 # Parameters 82 - `pre`: Whether all coords include an integral preceding delimiter token 83 - `intra`: Whether all coords include a delimiter token between coordinates 84 - `post`: Whether all coords include an integral following delimiter token 85 """ 86 87 pre: bool = serializable_field(default=True) 88 intra: bool = serializable_field(default=True) 89 post: bool = serializable_field(default=True) 90 # Implement methods 91 92 # inherit docstring 93 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 94 return [ 95 *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), 96 str(coord[0]), 97 *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), 98 str(coord[1]), 99 *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), 100 ]
Namespace for _CoordTokenizer
subclass hierarchy used by MazeTokenizerModular
.
69 @serializable_dataclass(frozen=True, kw_only=True) 70 class UT(_CoordTokenizer): 71 """Unique token coordinate tokenizer.""" 72 73 # inherit docstring 74 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 75 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
Unique token coordinate tokenizer.
74 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 75 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
Converts a maze element into a list of tokens.
Not all _TokenizerElement
subclasses produce tokens, so this is not an abstract method.
Those subclasses which do produce tokens should override this method.
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass
decorator
283def SerializableDataclass__validate_fields_types( 284 self: SerializableDataclass, 285 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 286) -> bool: 287 """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" 288 return all( 289 SerializableDataclass__validate_fields_types__dict( 290 self, on_typecheck_error=on_typecheck_error 291 ).values() 292 )
validate the types of all the fields on a SerializableDataclass
. calls SerializableDataclass__validate_field_type
for each field
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
77 @serializable_dataclass(frozen=True, kw_only=True) 78 class CTT(_CoordTokenizer): 79 """Coordinate tuple tokenizer 80 81 # Parameters 82 - `pre`: Whether all coords include an integral preceding delimiter token 83 - `intra`: Whether all coords include a delimiter token between coordinates 84 - `post`: Whether all coords include an integral following delimiter token 85 """ 86 87 pre: bool = serializable_field(default=True) 88 intra: bool = serializable_field(default=True) 89 post: bool = serializable_field(default=True) 90 # Implement methods 91 92 # inherit docstring 93 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 94 return [ 95 *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), 96 str(coord[0]), 97 *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), 98 str(coord[1]), 99 *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), 100 ]
Coordinate tuple tokenizer
Parameters
93 def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 94 return [ 95 *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), 96 str(coord[0]), 97 *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), 98 str(coord[1]), 99 *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), 100 ]
Converts a maze element into a list of tokens.
Not all _TokenizerElement
subclasses produce tokens, so this is not an abstract method.
Those subclasses which do produce tokens should override this method.
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
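As a brief illustration of the serialize/load round trip these methods provide, here is a minimal sketch (assumes `muutils` is installed; `MyElement` is a hypothetical class, not part of this module):

    from muutils.json_serialize import serializable_dataclass, serializable_field

    @serializable_dataclass(frozen=True, kw_only=True)
    class MyElement:
        pre: bool = serializable_field(default=True)
        intra: bool = serializable_field(default=False)

    elem = MyElement(pre=False, intra=True)
    data = elem.serialize()          # plain dict: a format key plus one entry per field
    restored = MyElement.load(data)  # reconstructs an instance from the dict
    assert restored == elem          # dataclass equality compares fields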
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class EdgeGroupings(__TokenizerElementNamespace):
    """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_grouping"

    class _GroupingTokenParams(TypedDict):
        """A uniform private hyperparameter interface used by `AdjListTokenizer`."""

        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeGrouping(_TokenizerElement, abc.ABC):
        """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeGroupings.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            """Divides a ConnectionArray into groups of edges.

            Shuffles/sequences within each group if applicable.
            """
            pass

        @abc.abstractmethod
        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            """Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

            These hyperparameters are not used by `_EdgeGrouping` internally.
            They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
            since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
            This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses
            into a uniform private interface used by `AdjListTokenizer`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class Ungrouped(_EdgeGrouping):
        """No grouping occurs, each edge is tokenized individually.

        # Parameters
        - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
            Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
        """

        connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
            default=1,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=False,
                grouped=False,
            )

        def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]:
            return np.expand_dims(edges, 1)

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ByLeadingCoord(_EdgeGrouping):
        """All edges with the same leading coord are grouped together.

        # Parameters
        - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
            Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `)
        - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
            If false, the fixed order is lexicographical by (row, col).
            In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
        - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
            Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
        """

        intra: bool = serializable_field(default=True)
        shuffle_group: bool = serializable_field(default=True)
        connection_token_ordinal: Literal[0, 1] = serializable_field(
            default=0,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=self.intra,
                grouped=True,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
            index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
                (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
            )
            sorted_edges: ConnectionArray = edges[index_array, ...]
            groups: list[ConnectionArray] = np.split(
                sorted_edges,
                np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
            )
            if self.shuffle_group:
                [numpy_rng.shuffle(g, axis=0) for g in groups]
            return groups
Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`.
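For reference, the uniform interface that each `_EdgeGrouping` subclass resolves into can be constructed directly (a sketch; `GroupingTokenParams` below is a standalone copy of the private `_GroupingTokenParams` TypedDict, and the values mirror what `Ungrouped._token_params` returns with default settings):

    from typing import Literal, TypedDict

    class GroupingTokenParams(TypedDict):
        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    # what Ungrouped._token_params() yields with default settings
    params: GroupingTokenParams = {
        "connection_token_ordinal": 1,
        "intra": False,
        "grouped": False,
    }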
No grouping occurs, each edge is tokenized individually.

Parameters
- `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears. Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
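Concretely, `Ungrouped._group_edges` just wraps each edge in its own singleton group via `np.expand_dims` (a minimal sketch with a hypothetical two-edge array):

    import numpy as np

    # shape (2, 2, 2): two edges, each a pair of (row, col) coords
    edges = np.array([
        [[0, 0], [0, 1]],
        [[1, 0], [1, 1]],
    ])
    groups = np.expand_dims(edges, 1)  # shape (2, 1, 2, 2): one group per edge
    assert groups.shape == (2, 1, 2, 2)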
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
All edges with the same leading coord are grouped together.

Parameters
- `intra`: Whether all edge groupings include a delimiter token between individual edge representations. Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `)
- `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. If false, the fixed order is lexicographical by (row, col). In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
- `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears. Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
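The grouping itself is a lexsort followed by a split at each first occurrence of a new leading coord; this sketch (hypothetical three-edge input) shows the mechanics without the shuffle step:

    import numpy as np

    edges = np.array([
        [[0, 0], [0, 1]],
        [[1, 1], [1, 0]],
        [[0, 0], [1, 0]],
    ])
    # sort by leading coord (row, col), then trailing coord (row, col)
    order = np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
    sorted_edges = edges[order, ...]
    # split wherever a new leading coord first appears
    split_at = np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:]
    groups = np.split(sorted_edges, split_at)
    assert [len(g) for g in groups] == [2, 1]  # both (0,0)-led edges share a group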
def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False
Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes.
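As a minimal illustration of the effect (a sketch assuming `maze_dataset` is installed and that the decorator substitutes `_unsupported_is_invalid` for `is_valid`):

    from maze_dataset.tokenization.modular.elements import EdgeGroupings

    unsupported = EdgeGroupings.ByLeadingCoord()
    assert unsupported.is_valid() is False     # unsupported classes report invalid
    # unsupported.is_valid(do_except=True) would raise ValueError instead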
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class EdgePermuters(__TokenizerElementNamespace):
    """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_permuter"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgePermuter(_TokenizerElement, abc.ABC):
        """Specifies how to sequence the two coords that encode a lattice edge."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgePermuters.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @staticmethod
        @abc.abstractmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            """Executes a permutation.

            Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

            # Parameters
            - `lattice_edges`: Array of lattice edges.
                The two coords in shape[1] must be adjacent in the lattice.

            # Returns
            - Array of lattice edges with entries along shape[1] systematically permuted.
            - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class SortedCoords(_EdgePermuter):
        """returns a sorted representation. useful for checking consistency"""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return lattice_edges[
                np.lexsort(
                    (
                        lattice_edges[:, 1, 1],
                        lattice_edges[:, 1, 0],
                        lattice_edges[:, 0, 1],
                        lattice_edges[:, 0, 0],
                    ),
                ),
                ...,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class RandomCoords(_EdgePermuter):
        """Permutes each edge randomly."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
            return lattice_edges

    @serializable_dataclass(frozen=True, kw_only=True)
    class BothCoords(_EdgePermuter):
        """Includes both possible permutations of every edge in the output.

        Since input ConnectionList has only 1 instance of each edge,
        a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
        """

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`.
Returns a sorted representation; useful for checking consistency.
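A short sketch of the canonical ordering this produces (hypothetical two-edge input; edges sorted by leading then trailing (row, col) coords):

    import numpy as np

    edges = np.array([
        [[1, 0], [1, 1]],
        [[0, 0], [0, 1]],
    ])
    order = np.lexsort((edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]))
    sorted_edges = edges[order, ...]
    assert (sorted_edges[0] == [[0, 0], [0, 1]]).all()  # (0,0)-led edge now first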
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Permutes each edge randomly.
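A sketch of the intended per-edge randomization, written with an explicit boolean mask rather than the `numpy_rng.permuted` call used in the source (hypothetical input; seeded for reproducibility):

    import numpy as np

    rng = np.random.default_rng(0)
    edges = np.array([
        [[0, 0], [0, 1]],
        [[1, 0], [1, 1]],
        [[2, 1], [2, 2]],
    ])
    flip = rng.random(len(edges)) < 0.5         # independent coin flip per edge
    edges[flip] = np.flip(edges[flip], axis=1)  # reverse the coord pair of flipped edges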
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Includes both possible permutations of every edge in the output. Since the input ConnectionList has only 1 instance of each edge, a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`.
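A minimal sketch of the doubling (hypothetical single-edge input):

    import numpy as np

    edges = np.array([[[0, 0], [0, 1]]])                    # shape (1, 2, 2)
    both = np.append(edges, np.flip(edges, axis=1), axis=0)  # shape (2, 2, 2)
    assert (both[1] == [[0, 1], [0, 0]]).all()               # reversed copy appended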
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class EdgeSubsets(__TokenizerElementNamespace):
    """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_subset"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeSubset(_TokenizerElement, abc.ABC):
        """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeSubsets.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            """Returns the set of lattice edges to be tokenized."""
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class AllLatticeEdges(_EdgeSubset):
        """All 2n**2-2n edges of the lattice are tokenized.

        If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
        """

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            return lattice_connection_array(maze.grid_n)

    @serializable_dataclass(frozen=True, kw_only=True)
    class ConnectionEdges(_EdgeSubset):
        """Only edges which contain a connection are tokenized.

        Alternatively, only edges which contain a wall are tokenized.

        # Parameters
        - `walls`: Whether wall edges or connection edges are tokenized.
            If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
        """

        walls: bool = serializable_field(default=False)

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            conn_list: ConnectionList = maze.connection_list
            if self.walls:
                conn_list = np.logical_not(conn_list)
                conn_list[0, -1, :] = False
                conn_list[1, :, -1] = False
            return connection_list_to_adj_list(
                conn_list,
                shuffle_d0=False,
                shuffle_d1=False,
            )
Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`.
All 2n**2-2n edges of the lattice are tokenized. If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
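A quick check of the edge count formula (assumes `maze_dataset` is installed and that `lattice_connection_array` returns one entry per lattice edge):

    from maze_dataset.utils import lattice_connection_array

    n = 3
    edges = lattice_connection_array(n)
    assert edges.shape[0] == 2 * n**2 - 2 * n  # 12 edges on a 3x3 lattice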
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
Only edges which contain a connection are tokenized. Alternatively, only edges which contain a wall are tokenized.

Parameters
- `walls`: Whether wall edges or connection edges are tokenized. If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
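The wall case inverts the connection list and then masks off the lattice boundary, since the last row has no downward edges and the last column no rightward edges. A sketch on a toy 2x2 maze (assumes the usual (2, n, n) ConnectionList layout, with down-connections in channel 0 and right-connections in channel 1):

    import numpy as np

    conn_list = np.zeros((2, 2, 2), dtype=bool)
    conn_list[0, 0, 0] = True    # connection from (0,0) down to (1,0)
    walls = np.logical_not(conn_list)
    walls[0, -1, :] = False      # bottom row: no downward edges exist
    walls[1, :, -1] = False      # rightmost column: no rightward edges exist
    assert not walls[0, 0, 0]    # the one connection is not a wall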
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by the `@serializable_dataclass` decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by the `@serializable_dataclass` decorator
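A minimal round-trip sketch: every tokenizer element is a `SerializableDataclass`, so `serialize` and `load` invert each other for valid instances:

from maze_dataset.tokenization.modular.elements import EdgeSubsets

subset = EdgeSubsets.ConnectionEdges(walls=True)
data = subset.serialize()  # plain, JSON-serializable dict
restored = EdgeSubsets.ConnectionEdges.load(data)
assert restored.walls == subset.walls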
356class AdjListTokenizers(__TokenizerElementNamespace): 357 """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 358 359 key = "adj_list_tokenizer" 360 361 @serializable_dataclass(frozen=True, kw_only=True) 362 @mark_as_unsupported(_adjlist_no_pre_unsupported) 363 class _AdjListTokenizer(_TokenizerElement, abc.ABC): 364 """Specifies how the adjacency list is tokenized. 365 366 Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations. 367 See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details. 368 369 # Parameters 370 - `pre`: Whether all edge groupings include a preceding delimiter token 371 - `post`: Whether all edge groupings include a following delimiter token 372 - `shuffle_d0`: Specifies how to sequence the edge groupings. 373 If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group. 374 - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping. 375 - `edge_subset`: Specifies the subset of lattice edges to be tokenized. 376 - `edge_permuter`: Specifies, in each edge tokenization, which coord either: 377 1. Appears first in the tokenization, for `AdjListCoord`. 378 2. Is tokenized directly as a coord, for `AdjListCardinal`. 379 - `shuffle`: For each edge, the leading coord is selected randomly. 380 - `all`: Each edge appears twice in the tokenization, appearing with both leading coords. 381 - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details. 382 """ 383 384 pre: bool = serializable_field(default=False, assert_type=False) 385 post: bool = serializable_field(default=True) 386 shuffle_d0: bool = serializable_field(default=True) 387 edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field( 388 default=EdgeGroupings.Ungrouped(), 389 loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings), 390 ) 391 edge_subset: EdgeSubsets._EdgeSubset = serializable_field( 392 default=EdgeSubsets.ConnectionEdges(), 393 loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets), 394 ) 395 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 396 default=EdgePermuters.RandomCoords(), 397 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 398 ) 399 400 @classmethod 401 def attribute_key(cls) -> str: 402 return AdjListTokenizers.key 403 404 def is_valid(self, do_except: bool = False) -> bool: 405 # No invalid instances possible within data member type hint bounds 406 return True 407 408 @abc.abstractmethod 409 def _tokenization_callables( 410 self, 411 edges: ConnectionArray, 412 is_conn: Bool[np.ndarray, " edges"], 413 coord_tokenizer: CoordTokenizers._CoordTokenizer, 414 *args, 415 **kwargs, 416 ) -> list[Callable]: 417 """Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization. 
418 419 # Returns 420 - `[0]`: leading coord tokens 421 - `[1]`: connector tokens 422 - `[2]`: trailing coord tokens 423 """ 424 pass 425 426 def _tokenize_edge_grouping( 427 self, 428 edges: ConnectionArray, 429 maze: LatticeMaze, 430 coord_tokenizer: CoordTokenizers._CoordTokenizer, 431 group_params: EdgeGroupings._GroupingTokenParams, 432 ) -> Sequence[str]: 433 """Tokenizes a single edge grouping.""" 434 cxn_ord: int = group_params["connection_token_ordinal"] 435 is_conn: Bool[np.ndarray, edges] = is_connection( 436 edges, 437 maze.connection_list, 438 ) 439 tokenize_callables = self._tokenization_callables( 440 edges, 441 is_conn, 442 coord_tokenizer, 443 ) 444 445 if group_params["grouped"]: 446 # If grouped 447 callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1] 448 repeated_callables = [ 449 tokenize_callables[i] for i in callable_permutation 450 ] 451 return flatten( 452 [ 453 tokenize_callables[0](0), 454 [ 455 [ 456 *[ 457 tok_callable(i) 458 for tok_callable in repeated_callables 459 ], 460 *( 461 (VOCAB.ADJLIST_INTRA,) 462 if group_params["intra"] 463 else () 464 ), 465 ] 466 for i in range(edges.shape[0]) 467 ], 468 ], 469 ) 470 else: 471 # If ungrouped 472 callable_permutation = [0, 2] 473 callable_permutation.insert(cxn_ord, 1) 474 tokenize_callables = [ 475 tokenize_callables[i] for i in callable_permutation 476 ] 477 478 return flatten( 479 [ 480 [ 481 [ 482 *[ 483 tok_callable(i) 484 for tok_callable in tokenize_callables 485 ], 486 *empty_sequence_if_attr_false( 487 (VOCAB.ADJLIST_INTRA,), 488 group_params, 489 "intra", 490 ), 491 ] 492 for i in range(edges.shape[0]) 493 ], 494 ], 495 ) 496 497 def to_tokens( 498 self, 499 maze: LatticeMaze, 500 coord_tokenizer: CoordTokenizers._CoordTokenizer, 501 ) -> list[str]: 502 # Get the set of edges to be tokenized 503 edges: ConnectionArray = self.edge_subset._get_edges(maze) 504 # Systematically permute the leading coord of each edge 505 edges: ConnectionArray = self.edge_permuter._permute(edges) 506 group_params: EdgeGroupings._GroupingTokenParams = ( 507 self.edge_grouping._token_params() 508 ) 509 # then, we need to group the edges 510 groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges) 511 # shuffle the groups if specified 512 if self.shuffle_d0: 513 if isinstance(groups, np.ndarray): 514 numpy_rng.shuffle(groups, axis=0) 515 elif isinstance(groups, list): 516 random.shuffle(groups) 517 else: 518 err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported." 
519 raise TypeError(err_msg) 520 # Tokenize each group with optional delimiters 521 tokens: list[str] = list( 522 flatten( 523 [ 524 [ 525 *empty_sequence_if_attr_false( 526 (VOCAB.ADJLIST_PRE,), 527 self, 528 "pre", 529 ), 530 *self._tokenize_edge_grouping( 531 group, 532 maze, 533 coord_tokenizer, 534 group_params, 535 ), 536 *empty_sequence_if_attr_false( 537 (VOCAB.ADJACENCY_ENDLINE,), 538 self, 539 "post", 540 ), 541 ] 542 for group in groups 543 ], 544 ), 545 ) 546 return tokens 547 548 @serializable_dataclass(frozen=True, kw_only=True) 549 class AdjListCoord(_AdjListTokenizer): 550 """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.""" 551 552 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 553 default=EdgePermuters.RandomCoords(), 554 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 555 ) 556 557 def _tokenization_callables( 558 self, 559 edges: ConnectionArray, 560 is_conn: Bool[np.ndarray, " edges"], 561 coord_tokenizer: CoordTokenizers._CoordTokenizer, 562 *args, 563 **kwargs, 564 ) -> list[Callable]: 565 # Map from `is_conn` to the tokens which represent connections and walls 566 conn_token_map: dict[bool, str] = { 567 True: VOCAB.CONNECTOR, 568 False: VOCAB.ADJLIST_WALL, 569 } 570 return [ 571 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 572 lambda i: conn_token_map[is_conn[i]], 573 lambda i: coord_tokenizer.to_tokens(edges[i, 1]), 574 ] 575 576 @serializable_dataclass(frozen=True, kw_only=True) 577 class AdjListCardinal(_AdjListTokenizer): 578 """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members. 579 580 # Parameters 581 - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens. 582 """ 583 584 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 585 default=EdgePermuters.BothCoords(), 586 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 587 ) 588 589 def _tokenization_callables( 590 self, 591 edges: ConnectionArray, 592 is_conn: Bool[np.ndarray, " edges"], 593 coord_tokenizer: CoordTokenizers._CoordTokenizer, 594 *args, 595 **kwargs, 596 ) -> list[Callable]: 597 # Map from `is_conn` to the tokens which represent connections and walls 598 conn_token_map: dict[bool, str] = { 599 True: VOCAB.CONNECTOR, 600 False: VOCAB.ADJLIST_WALL, 601 } 602 return [ 603 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 604 lambda i: conn_token_map[is_conn[i]], 605 lambda i: get_cardinal_direction(edges[i]), 606 ]
Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
548 @serializable_dataclass(frozen=True, kw_only=True) 549 class AdjListCoord(_AdjListTokenizer): 550 """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.""" 551 552 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 553 default=EdgePermuters.RandomCoords(), 554 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 555 ) 556 557 def _tokenization_callables( 558 self, 559 edges: ConnectionArray, 560 is_conn: Bool[np.ndarray, " edges"], 561 coord_tokenizer: CoordTokenizers._CoordTokenizer, 562 *args, 563 **kwargs, 564 ) -> list[Callable]: 565 # Map from `is_conn` to the tokens which represent connections and walls 566 conn_token_map: dict[bool, str] = { 567 True: VOCAB.CONNECTOR, 568 False: VOCAB.ADJLIST_WALL, 569 } 570 return [ 571 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 572 lambda i: conn_token_map[is_conn[i]], 573 lambda i: coord_tokenizer.to_tokens(edges[i, 1]), 574 ]
Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.
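An illustrative usage sketch (the maze construction is an assumption for illustration; `UT` is the unique-token coord tokenizer from this module):

import numpy as np
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import AdjListTokenizers, CoordTokenizers

maze = LatticeMazeGenerators.gen_dfs(np.array([3, 3]))  # assumed generator call
adjlist_tokenizer = AdjListTokenizers.AdjListCoord()  # defaults: ConnectionEdges subset, RandomCoords permuter
tokens = adjlist_tokenizer.to_tokens(maze, coord_tokenizer=CoordTokenizers.UT())
# each edge grouping renders as leading coord tokens, a connector token, and
# trailing coord tokens, followed by VOCAB.ADJACENCY_ENDLINE since post=True by default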
Inherited Members
- AdjListTokenizers._AdjListTokenizer
- pre
- post
- shuffle_d0
- edge_grouping
- edge_subset
- attribute_key
- is_valid
- to_tokens
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
576 @serializable_dataclass(frozen=True, kw_only=True) 577 class AdjListCardinal(_AdjListTokenizer): 578 """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members. 579 580 # Parameters 581 - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens. 582 """ 583 584 edge_permuter: EdgePermuters._EdgePermuter = serializable_field( 585 default=EdgePermuters.BothCoords(), 586 loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), 587 ) 588 589 def _tokenization_callables( 590 self, 591 edges: ConnectionArray, 592 is_conn: Bool[np.ndarray, " edges"], 593 coord_tokenizer: CoordTokenizers._CoordTokenizer, 594 *args, 595 **kwargs, 596 ) -> list[Callable]: 597 # Map from `is_conn` to the tokens which represent connections and walls 598 conn_token_map: dict[bool, str] = { 599 True: VOCAB.CONNECTOR, 600 False: VOCAB.ADJLIST_WALL, 601 } 602 return [ 603 lambda i: coord_tokenizer.to_tokens(edges[i, 0]), 604 lambda i: conn_token_map[is_conn[i]], 605 lambda i: get_cardinal_direction(edges[i]), 606 ]
Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.
Parameters
- `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
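A short sketch contrasting with `AdjListCoord`, under the same assumed `maze` and coord tokenizer as above: the trailing coord is replaced by a cardinal-direction token relative to the leading coord.

cardinal_tokenizer = AdjListTokenizers.AdjListCardinal()  # default permuter: BothCoords
tokens = cardinal_tokenizer.to_tokens(maze, coord_tokenizer=CoordTokenizers.UT())
# with BothCoords each edge appears twice, once led by each endpoint,
# so every connection is described from both directions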
609class TargetTokenizers(__TokenizerElementNamespace): 610 """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" 611 612 key = "target_tokenizer" 613 614 @serializable_dataclass(frozen=True, kw_only=True) 615 class _TargetTokenizer(_TokenizerElement, abc.ABC): 616 """Superclass of tokenizers for maze targets.""" 617 618 @abc.abstractmethod 619 def to_tokens( 620 self, 621 targets: Sequence[Coord], 622 coord_tokenizer: CoordTokenizers._CoordTokenizer, 623 ) -> list[str]: 624 """Returns tokens representing the target.""" 625 pass 626 627 @classmethod 628 def attribute_key(cls) -> str: 629 return TargetTokenizers.key 630 631 @serializable_dataclass(frozen=True, kw_only=True) 632 class Unlabeled(_TargetTokenizer): 633 """Targets are simply listed as coord tokens. 634 635 - `post`: Whether all coords include an integral following delimiter token 636 """ 637 638 post: bool = serializable_field(default=False) 639 640 # inherit docstring 641 def to_tokens( # noqa: D102 642 self, 643 targets: Sequence[Coord], 644 coord_tokenizer: CoordTokenizers._CoordTokenizer, 645 ) -> list[str]: 646 return list( 647 flatten( 648 [ 649 [ 650 *coord_tokenizer.to_tokens(target), 651 *empty_sequence_if_attr_false( 652 [VOCAB.TARGET_POST], 653 self, 654 "post", 655 ), 656 ] 657 for target in targets 658 ], 659 ), 660 ) 661 662 # inherit docstring 663 def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 664 # No invalid instances possible within data member type hint bounds 665 return True
Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
631 @serializable_dataclass(frozen=True, kw_only=True) 632 class Unlabeled(_TargetTokenizer): 633 """Targets are simply listed as coord tokens. 634 635 - `post`: Whether all coords include an integral following delimiter token 636 """ 637 638 post: bool = serializable_field(default=False) 639 640 # inherit docstring 641 def to_tokens( # noqa: D102 642 self, 643 targets: Sequence[Coord], 644 coord_tokenizer: CoordTokenizers._CoordTokenizer, 645 ) -> list[str]: 646 return list( 647 flatten( 648 [ 649 [ 650 *coord_tokenizer.to_tokens(target), 651 *empty_sequence_if_attr_false( 652 [VOCAB.TARGET_POST], 653 self, 654 "post", 655 ), 656 ] 657 for target in targets 658 ], 659 ), 660 ) 661 662 # inherit docstring 663 def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 664 # No invalid instances possible within data member type hint bounds 665 return True
Targets are simply listed as coord tokens.
- `post`: Whether all coords include an integral following delimiter token
641 def to_tokens( # noqa: D102 642 self, 643 targets: Sequence[Coord], 644 coord_tokenizer: CoordTokenizers._CoordTokenizer, 645 ) -> list[str]: 646 return list( 647 flatten( 648 [ 649 [ 650 *coord_tokenizer.to_tokens(target), 651 *empty_sequence_if_attr_false( 652 [VOCAB.TARGET_POST], 653 self, 654 "post", 655 ), 656 ] 657 for target in targets 658 ], 659 ), 660 )
Returns tokens representing the target.
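A sketch of the `post` flag's effect (the target coords here are illustrative tuples):

from maze_dataset.tokenization.modular.elements import CoordTokenizers, TargetTokenizers

targets = [(2, 0)]
TargetTokenizers.Unlabeled().to_tokens(targets, CoordTokenizers.UT())
# -> ["(2,0)"]
TargetTokenizers.Unlabeled(post=True).to_tokens(targets, CoordTokenizers.UT())
# -> ["(2,0)", VOCAB.TARGET_POST]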
663 def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 664 # No invalid instances possible within data member type hint bounds 665 return True
Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.
Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return `True` for that subclass.
Types of Invalidity
In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity:
- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative.
- Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
- Erroneous: These tokenizers might raise exceptions during use.
Development
`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.
Nesting
In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.
Types of Invalidity
If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
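As a hedged sketch of the commenting convention described above (not code from the library; the condition and its classification are hypothetical), a nontrivial `is_valid` might look like:

def is_valid(self, do_except: bool = False) -> bool:
    # hypothetical condition, for illustration only
    if self.pre and not self.post:
        # Uninteresting: strictly worse than the post-delimited equivalent (assumed classification)
        if do_except:
            err_msg = f"{type(self).__name__}: `pre` without `post` is not supported"
            raise ValueError(err_msg)
        return False
    return True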
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
668class StepSizes(__TokenizerElementNamespace): 669 """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`.""" 670 671 key = "step_size" 672 673 @serializable_dataclass(frozen=True, kw_only=True) 674 class _StepSize(_TokenizerElement, abc.ABC): 675 """Specifies which coords in `maze.solution` are used to represent the path.""" 676 677 @classmethod 678 def attribute_key(cls) -> str: 679 return StepSizes.key 680 681 @abc.abstractmethod # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call 682 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 683 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 684 raise NotImplementedError( 685 "Subclasses must implement `StepSize.step_indices.", 686 ) 687 688 def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]: 689 """Returns steps as tuples of starting and ending positions for each step.""" 690 indices: list[int] = self._step_single_indices(maze) 691 # TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs 692 return [ 693 (start, end) 694 for start, end in zip(indices[:-1], indices[1:], strict=False) # noqa: RUF007 695 ] 696 697 def is_valid(self, do_except: bool = False) -> bool: 698 # No invalid instances possible within data member type hint bounds 699 return True 700 701 @serializable_dataclass(frozen=True, kw_only=True) 702 class Singles(_StepSize): 703 """Every coord in `maze.solution` is represented. 704 705 Legacy tokenizers all use this behavior. 706 """ 707 708 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 709 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 710 return list(range(maze.solution.shape[0])) 711 712 @serializable_dataclass(frozen=True, kw_only=True) 713 @mark_as_unsupported(_unsupported_is_invalid) 714 class Straightaways(_StepSize): 715 """Only coords where the path turns are represented in the path. 716 717 I.e., the path is represented as a sequence of straightaways, 718 specified by the coords at the turns. 719 """ 720 721 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 722 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 723 last_turn_coord: Coord = maze.solution[0, ...] 724 indices: list[int] = [0] 725 for i, coord in enumerate(maze.solution): 726 if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]: 727 indices.append(i - 1) 728 last_turn_coord = maze.solution[i - 1, ...] 729 indices.append(i) 730 return indices 731 732 @serializable_dataclass(frozen=True, kw_only=True) 733 class Forks(_StepSize): 734 """Only coords at forks, where the path has >=2 options for the next step are included. 735 736 Excludes the option of backtracking. 737 The starting and ending coords are always included. 738 """ 739 740 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 741 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 742 return maze.get_solution_forking_points(always_include_endpoints=True)[0] 743 744 @serializable_dataclass(frozen=True, kw_only=True) 745 @mark_as_unsupported(_unsupported_is_invalid) 746 class ForksAndStraightaways(_StepSize): 747 """Includes the union of the coords included by `Forks` and `Straightaways`. 748 749 See documentation for those classes for details. 
750 """ 751 752 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 753 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 754 return list( 755 np.unique( 756 np.concatenate( 757 ( 758 StepSizes.Straightaways()._step_single_indices(maze), 759 StepSizes.Forks()._step_single_indices(maze), 760 ), 761 ), 762 ), 763 )
Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`.
701 @serializable_dataclass(frozen=True, kw_only=True) 702 class Singles(_StepSize): 703 """Every coord in `maze.solution` is represented. 704 705 Legacy tokenizers all use this behavior. 706 """ 707 708 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 709 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 710 return list(range(maze.solution.shape[0]))
Every coord in `maze.solution` is represented.
Legacy tokenizers all use this behavior.
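A sketch of the resulting indices, assuming a `SolvedMaze` instance `maze` whose solution has 5 coords:

step_size = StepSizes.Singles()
step_size._step_single_indices(maze)    # [0, 1, 2, 3, 4]
step_size.step_start_end_indices(maze)  # [(0, 1), (1, 2), (2, 3), (3, 4)]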
712 @serializable_dataclass(frozen=True, kw_only=True) 713 @mark_as_unsupported(_unsupported_is_invalid) 714 class Straightaways(_StepSize): 715 """Only coords where the path turns are represented in the path. 716 717 I.e., the path is represented as a sequence of straightaways, 718 specified by the coords at the turns. 719 """ 720 721 def _step_single_indices(self, maze: SolvedMaze) -> list[int]: 722 """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" 723 last_turn_coord: Coord = maze.solution[0, ...] 724 indices: list[int] = [0] 725 for i, coord in enumerate(maze.solution): 726 if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]: 727 indices.append(i - 1) 728 last_turn_coord = maze.solution[i - 1, ...] 729 indices.append(i) 730 return indices
Only coords where the path turns are represented in the path.
I.e., the path is represented as a sequence of straightaways, specified by the coords at the turns.
257def _unsupported_is_invalid(self, do_except: bool = False) -> bool: # noqa: ANN001 258 """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes""" 259 if do_except: 260 err_msg: str = ( 261 f"Class `{type(self).__name__ = }, marked as unsupported, is not valid." 262 f"{type(self) = }, {self = }" 263 ) 264 raise ValueError(err_msg) 265 266 return False
Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes
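A sketch of the resulting behavior on a class marked as unsupported, such as `StepSizes.Straightaways` above:

straightaways = StepSizes.Straightaways()
straightaways.is_valid()  # False
# with do_except=True, a ValueError explaining that the class is marked
# as unsupported is raised instead of returning False
straightaways.is_valid(do_except=True)  # raises ValueError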
714 def serialize(self) -> dict[str, Any]: 715 result: dict[str, Any] = { 716 _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)" 717 } 718 # for each field in the class 719 for field in dataclasses.fields(self): # type: ignore[arg-type] 720 # need it to be our special SerializableField 721 if not isinstance(field, SerializableField): 722 raise NotSerializableFieldException( 723 f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, " 724 f"but a {type(field)} " 725 "this state should be inaccessible, please report this bug!" 726 ) 727 728 # try to save it 729 if field.serialize: 730 try: 731 # get the val 732 value = getattr(self, field.name) 733 # if it is a serializable dataclass, serialize it 734 if isinstance(value, SerializableDataclass): 735 value = value.serialize() 736 # if the value has a serialization function, use that 737 if hasattr(value, "serialize") and callable(value.serialize): 738 value = value.serialize() 739 # if the field has a serialization function, use that 740 # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies! 741 elif field.serialization_fn: 742 value = field.serialization_fn(value) 743 744 # store the value in the result 745 result[field.name] = value 746 except Exception as e: 747 raise FieldSerializationError( 748 "\n".join( 749 [ 750 f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}", 751 f"{field = }", 752 f"{value = }", 753 f"{self = }", 754 ] 755 ) 756 ) from e 757 758 # store each property if we can get it 759 for prop in self._properties_to_serialize: 760 if hasattr(cls, prop): 761 value = getattr(self, prop) 762 result[prop] = value 763 else: 764 raise AttributeError( 765 f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}" 766 + f"but it is in {self._properties_to_serialize = }" 767 + f"\n{self = }" 768 ) 769 770 return result
returns the class as a dict, implemented by using @serializable_dataclass
decorator
777 @classmethod # type: ignore[misc] 778 def load(cls, data: dict[str, Any] | T) -> Type[T]: 779 # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ 780 if isinstance(data, cls): 781 return data 782 783 assert isinstance( 784 data, typing.Mapping 785 ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }" 786 787 cls_type_hints: dict[str, Any] = get_cls_type_hints(cls) 788 789 # initialize dict for keeping what we will pass to the constructor 790 ctor_kwargs: dict[str, Any] = dict() 791 792 # iterate over the fields of the class 793 for field in dataclasses.fields(cls): 794 # check if the field is a SerializableField 795 assert isinstance( 796 field, SerializableField 797 ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new" 798 799 # check if the field is in the data and if it should be initialized 800 if (field.name in data) and field.init: 801 # get the value, we will be processing it 802 value: Any = data[field.name] 803 804 # get the type hint for the field 805 field_type_hint: Any = cls_type_hints.get(field.name, None) 806 807 # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set 808 if field.deserialize_fn: 809 # if it has a deserialization function, use that 810 value = field.deserialize_fn(value) 811 elif field.loading_fn: 812 # if it has a loading function, use that 813 value = field.loading_fn(data) 814 elif ( 815 field_type_hint is not None 816 and hasattr(field_type_hint, "load") 817 and callable(field_type_hint.load) 818 ): 819 # if no loading function but has a type hint with a load method, use that 820 if isinstance(value, dict): 821 value = field_type_hint.load(value) 822 else: 823 raise FieldLoadingError( 824 f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }" 825 ) 826 else: 827 # assume no loading needs to happen, keep `value` as-is 828 pass 829 830 # store the value in the constructor kwargs 831 ctor_kwargs[field.name] = value 832 833 # create a new instance of the class with the constructor kwargs 834 output: cls = cls(**ctor_kwargs) 835 836 # validate the types of the fields if needed 837 if on_typecheck_mismatch != ErrorMode.IGNORE: 838 fields_valid: dict[str, bool] = ( 839 SerializableDataclass__validate_fields_types__dict( 840 output, 841 on_typecheck_error=on_typecheck_error, 842 ) 843 ) 844 845 # if there are any fields that are not valid, raise an error 846 if not all(fields_valid.values()): 847 msg: str = ( 848 f"Type mismatch in fields of {cls.__name__}:\n" 849 + "\n".join( 850 [ 851 f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }" 852 for k, v in fields_valid.items() 853 if not v 854 ] 855 ) 856 ) 857 858 on_typecheck_mismatch.process( 859 msg, except_cls=FieldTypeMismatchError 860 ) 861 862 # return the new instance 863 return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass
decorator
283def SerializableDataclass__validate_fields_types( 284 self: SerializableDataclass, 285 on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR, 286) -> bool: 287 """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field""" 288 return all( 289 SerializableDataclass__validate_fields_types__dict( 290 self, on_typecheck_error=on_typecheck_error 291 ).values() 292 )
Validate the types of all the fields on a `SerializableDataclass`. Calls `SerializableDataclass__validate_field_type` for each field.
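A small sketch of the bound-method form (assuming, as in `muutils`, that the decorator attaches this function as `validate_fields_types`):

from maze_dataset.tokenization.modular.elements import CoordTokenizers

ctt = CoordTokenizers.CTT()
# returns True when every field's value matches its type hint
assert ctt.validate_fields_types()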
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class Forks(_StepSize):
    """Only coords at forks, where the path has >=2 options for the next step are included.

    Excludes the option of backtracking.
    The starting and ending coords are always included.
    """

    def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
        """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
        return maze.get_solution_forking_points(always_include_endpoints=True)[0]
Only coords at forks, where the path has >=2 options for the next step, are included.
Excludes the option of backtracking. The starting and ending coords are always included.
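A sketch of how `Forks` is used (dataset construction follows the package's standard `MazeDatasetConfig` API; `_step_single_indices` is private and called here only for illustration):

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.tokenization.modular.elements import StepSizes

cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=1,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)
maze = MazeDataset.from_config(cfg)[0]

# indices into `maze.solution` at forking points, endpoints always included
indices = StepSizes.Forks()._step_single_indices(maze)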
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
Returns the instance as a dict; implemented via the `@serializable_dataclass` decorator.
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class ForksAndStraightaways(_StepSize):
    """Includes the union of the coords included by `Forks` and `Straightaways`.

    See documentation for those classes for details.
    """

    def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
        """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
        return list(
            np.unique(
                np.concatenate(
                    (
                        StepSizes.Straightaways()._step_single_indices(maze),
                        StepSizes.Forks()._step_single_indices(maze),
                    ),
                ),
            ),
        )
Includes the union of the coords included by `Forks` and `Straightaways`.
See documentation for those classes for details.
def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }, marked as unsupported, is not valid."
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False
Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes.
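Concretely (a sketch, assuming `mark_as_unsupported` swaps this function in as `is_valid`):

from maze_dataset.tokenization.modular.elements import StepSizes

elem = StepSizes.ForksAndStraightaways()
assert elem.is_valid() is False  # marked unsupported, so never valid
# elem.is_valid(do_except=True) would raise a ValueError instead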
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- to_tokens
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class StepTokenizers(__TokenizerElementNamespace):
    """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_tokenizers"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepTokenizers.key

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            """Tokenizes a single step in the solution.

            # Parameters
            - `maze`: Maze to be tokenized
            - `start_index`: The index of the Coord in `maze.solution` at which the current step starts
            - `end_index`: The index of the Coord in `maze.solution` at which the current step ends
            """
            raise NotImplementedError(
                "Subclasses must implement `StepTokenizer.to_tokens.",
            )

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Coord(_StepTokenizer):
        """A direct tokenization of the end position coord represents the step."""

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

    @serializable_dataclass(frozen=True, kw_only=True)
    class Cardinal(_StepTokenizer):
        """A step is tokenized with a cardinal direction token.

        It is the direction of the step from the starting position along the solution.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            return [
                get_cardinal_direction(maze.solution[start_index : start_index + 2]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Relative(_StepTokenizer):
        """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

        To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
        Similarly to `Cardinal`, the direction is that of the step from the starting position.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            if start_index == 0:
                start = maze.solution[0]
                previous = start + np.array([1, 0])
                return [
                    get_relative_direction(
                        np.concatenate(
                            (
                                np.expand_dims(previous, 0),
                                maze.solution[start_index : start_index + 2],
                            ),
                            axis=0,
                        ),
                    ),
                ]
            return [
                get_relative_direction(
                    maze.solution[start_index - 1 : start_index + 2],
                ),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Distance(_StepTokenizer):
        """A count of the number of individual steps from the starting point to the end point.

        Contains no information about directionality, only the distance traveled in the step.
        `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
        This constraint is enforced in `_PathTokenizer.is_valid`.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            d: int = end_index - start_index
            return [getattr(VOCAB, f"I_{d:03}")]

    """
    `StepTokenizerPermutation`
    A sequence of unique `_StepTokenizer`s.
    This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
    """
    StepTokenizerPermutation: type = (
        tuple[_StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
    )
Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
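For instance, a `StepTokenizerPermutation` is just a tuple of distinct step tokenizers (a sketch):

from maze_dataset.tokenization.modular.elements import StepTokenizers

# tokenize each step as a cardinal direction followed by its length
perm: StepTokenizers.StepTokenizerPermutation = (
    StepTokenizers.Cardinal(),
    StepTokenizers.Distance(),
)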
@serializable_dataclass(frozen=True, kw_only=True)
class Coord(_StepTokenizer):
    """A direct tokenization of the end position coord represents the step."""

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        start_index: int,
        end_index: int,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
A direct tokenization of the end position coord represents the step.
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    start_index: int,
    end_index: int,
    coord_tokenizer: CoordTokenizers._CoordTokenizer,
) -> list[str]:
    return coord_tokenizer.to_tokens(maze.solution[end_index, ...])
Tokenizes a single step in the solution.
Parameters
- `maze`: Maze to be tokenized
- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
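A sketch of the call, reusing the `maze` constructed in the `Forks` example above:

from maze_dataset.tokenization.modular.elements import CoordTokenizers, StepTokenizers

# tokenize the step ending at solution index 1 as its end coordinate, e.g. ["(1,0)"]
tokens = StepTokenizers.Coord().to_tokens(maze, 0, 1, coord_tokenizer=CoordTokenizers.UT())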
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class Cardinal(_StepTokenizer):
    """A step is tokenized with a cardinal direction token.

    It is the direction of the step from the starting position along the solution.
    """

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        start_index: int,
        end_index: int,
        **kwargs,
    ) -> list[str]:
        return [
            get_cardinal_direction(maze.solution[start_index : start_index + 2]),
        ]
A step is tokenized with a cardinal direction token.
It is the direction of the step from the starting position along the solution.
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    start_index: int,
    end_index: int,
    **kwargs,
) -> list[str]:
    return [
        get_cardinal_direction(maze.solution[start_index : start_index + 2]),
    ]
Tokenizes a single step in the solution.
Parameters
- `maze`: Maze to be tokenized
- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
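A sketch, again reusing `maze` from the earlier example; the step between solution indices 0 and 1 yields a single direction token via `get_cardinal_direction`:

from maze_dataset.tokenization.modular.elements import StepTokenizers

tokens = StepTokenizers.Cardinal().to_tokens(maze, 0, 1)
assert len(tokens) == 1  # one cardinal-direction token per step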
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class Relative(_StepTokenizer):
    """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

    To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
    Similarly to `Cardinal`, the direction is that of the step from the starting position.
    """

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        start_index: int,
        end_index: int,
        **kwargs,
    ) -> list[str]:
        if start_index == 0:
            start = maze.solution[0]
            previous = start + np.array([1, 0])
            return [
                get_relative_direction(
                    np.concatenate(
                        (
                            np.expand_dims(previous, 0),
                            maze.solution[start_index : start_index + 2],
                        ),
                        axis=0,
                    ),
                ),
            ]
        return [
            get_relative_direction(
                maze.solution[start_index - 1 : start_index + 2],
            ),
        ]
Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).
To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
Similarly to `Cardinal`, the direction is that of the step from the starting position.
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    start_index: int,
    end_index: int,
    **kwargs,
) -> list[str]:
    if start_index == 0:
        start = maze.solution[0]
        previous = start + np.array([1, 0])
        return [
            get_relative_direction(
                np.concatenate(
                    (
                        np.expand_dims(previous, 0),
                        maze.solution[start_index : start_index + 2],
                    ),
                    axis=0,
                ),
            ),
        ]
    return [
        get_relative_direction(
            maze.solution[start_index - 1 : start_index + 2],
        ),
    ]
Tokenizes a single step in the solution.
Parameters
- `maze`: Maze to be tokenized
- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
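A sketch, reusing `maze` from the earlier example; at `start_index == 0` the direction is computed against a synthetic "previous" position implied by the assumed NORTH facing:

from maze_dataset.tokenization.modular.elements import StepTokenizers

tokens = StepTokenizers.Relative().to_tokens(maze, 0, 1)
assert len(tokens) == 1  # one relative-direction token per step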
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
@serializable_dataclass(frozen=True, kw_only=True)
class Distance(_StepTokenizer):
    """A count of the number of individual steps from the starting point to the end point.

    Contains no information about directionality, only the distance traveled in the step.
    `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
    This constraint is enforced in `_PathTokenizer.is_valid`.
    """

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        start_index: int,
        end_index: int,
        **kwargs,
    ) -> list[str]:
        d: int = end_index - start_index
        return [getattr(VOCAB, f"I_{d:03}")]
A count of the number of individual steps from the starting point to the end point.
Contains no information about directionality, only the distance traveled in the step.
`Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`. This constraint is enforced in `_PathTokenizer.is_valid`.
def to_tokens(  # noqa: D102
    self,
    maze: SolvedMaze,
    start_index: int,
    end_index: int,
    **kwargs,
) -> list[str]:
    d: int = end_index - start_index
    return [getattr(VOCAB, f"I_{d:03}")]
Tokenizes a single step in the solution.
Parameters
- `maze`: Maze to be tokenized
- `start_index`: The index of the Coord in `maze.solution` at which the current step starts
- `end_index`: The index of the Coord in `maze.solution` at which the current step ends
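The token is just the zero-padded step count looked up on `VOCAB` (a sketch, assuming the `I_{nnn}` integer tokens present on `VOCAB`):

from maze_dataset.constants import VOCAB

d = 5
token = getattr(VOCAB, f"I_{d:03}")  # looks up the attribute named "I_005"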
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class PathTokenizers(__TokenizerElementNamespace):
    """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "path_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PathTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze solution paths."""

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the solution path."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return PathTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class StepSequence(_PathTokenizer, abc.ABC):
        """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

        Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

        # Parameters
        - `step_size`: Selects the size of a single step in the sequence
        - `step_tokenizers`: Selects the combination and permutation of tokens
        - `pre`: Whether all steps include an integral preceding delimiter token
        - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
        - `post`: Whether all steps include an integral following delimiter token
        """

        step_size: StepSizes._StepSize = serializable_field(
            default=StepSizes.Singles(),
            loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
        )
        step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
            default=(StepTokenizers.Coord(),),
            serialization_fn=lambda x: [y.serialize() for y in x],
            loading_fn=lambda x: tuple(x[StepTokenizers.key]),
        )
        pre: bool = serializable_field(default=False)
        intra: bool = serializable_field(default=False)
        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return [
                *self._leading_tokens(maze, coord_tokenizer),
                *flatten(
                    [
                        self._single_step_tokens(maze, start, end, coord_tokenizer)
                        for start, end in self.step_size.step_start_end_indices(maze)
                    ],
                ),
                *self._trailing_tokens(maze, coord_tokenizer),
            ]

        def _single_step_tokens(
            self,
            maze: SolvedMaze,
            i: int,
            j: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns the token sequence representing a single step along the path."""
            step_rep_tokens: list[list[str]] = [
                step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
                for step_tokenizer in self.step_tokenizers
            ]
            if self.intra:
                step_rep_tokens_and_intra: list[str] = [None] * (
                    len(step_rep_tokens) * 2
                )
                step_rep_tokens_and_intra[::2] = step_rep_tokens
                step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                    step_rep_tokens,
                )
                step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
            all_tokens: list[str] = [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *flatten(step_rep_tokens),
                *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
            ]
            return all_tokens

        def _leading_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens preceding those from the sequence from `_single_step_tokens`.

            Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
            <PATH_START> should NOT be included.
            """
            if StepTokenizers.Coord() in self.step_tokenizers:
                return [
                    *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                    *coord_tokenizer.to_tokens(maze.solution[0, ...]),
                    *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
                ]
            return []

        def _trailing_tokens(
            self,
            c: Coord,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens following those from the sequence from `_single_step_tokens`.

            <PATH_END> should NOT be included.
            """
            return []

        # inherits docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            output: bool

            if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
                # Uninteresting: repeated elements are not useful
                output = False
            else:
                # we do noqa for the comment if false
                if len(self.step_tokenizers) == 1 and isinstance(
                    self.step_tokenizers[0],
                    StepTokenizers.Distance,
                ):
                    # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
                    output = False
                else:
                    output = True

            if not output and do_except:
                raise ValueError(
                    "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
                )

            return output
Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`.
@serializable_dataclass(frozen=True, kw_only=True)
class StepSequence(_PathTokenizer, abc.ABC):
    """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

    Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

    # Parameters
    - `step_size`: Selects the size of a single step in the sequence
    - `step_tokenizers`: Selects the combination and permutation of tokens
    - `pre`: Whether all steps include an integral preceding delimiter token
    - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
    - `post`: Whether all steps include an integral following delimiter token
    """

    step_size: StepSizes._StepSize = serializable_field(
        default=StepSizes.Singles(),
        loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
    )
    step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
        default=(StepTokenizers.Coord(),),
        serialization_fn=lambda x: [y.serialize() for y in x],
        loading_fn=lambda x: tuple(x[StepTokenizers.key]),
    )
    pre: bool = serializable_field(default=False)
    intra: bool = serializable_field(default=False)
    post: bool = serializable_field(default=False)

    # inherit docstring
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        return [
            *self._leading_tokens(maze, coord_tokenizer),
            *flatten(
                [
                    self._single_step_tokens(maze, start, end, coord_tokenizer)
                    for start, end in self.step_size.step_start_end_indices(maze)
                ],
            ),
            *self._trailing_tokens(maze, coord_tokenizer),
        ]

    def _single_step_tokens(
        self,
        maze: SolvedMaze,
        i: int,
        j: int,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns the token sequence representing a single step along the path."""
        step_rep_tokens: list[list[str]] = [
            step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
            for step_tokenizer in self.step_tokenizers
        ]
        if self.intra:
            step_rep_tokens_and_intra: list[str] = [None] * (
                len(step_rep_tokens) * 2
            )
            step_rep_tokens_and_intra[::2] = step_rep_tokens
            step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                step_rep_tokens,
            )
            step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
        all_tokens: list[str] = [
            *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
            *flatten(step_rep_tokens),
            *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
        ]
        return all_tokens

    def _leading_tokens(
        self,
        maze: SolvedMaze,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns tokens preceding those from the sequence from `_single_step_tokens`.

        Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
        <PATH_START> should NOT be included.
        """
        if StepTokenizers.Coord() in self.step_tokenizers:
            return [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *coord_tokenizer.to_tokens(maze.solution[0, ...]),
                *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
            ]
        return []

    def _trailing_tokens(
        self,
        c: Coord,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        """Returns tokens following those from the sequence from `_single_step_tokens`.

        <PATH_END> should NOT be included.
        """
        return []

    # inherits docstring
    def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
        output: bool

        if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
            # Uninteresting: repeated elements are not useful
            output = False
        else:
            # we do noqa for the comment if false
            if len(self.step_tokenizers) == 1 and isinstance(
                self.step_tokenizers[0],
                StepTokenizers.Distance,
            ):
                # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
                output = False
            else:
                output = True

        if not output and do_except:
            raise ValueError(
                "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
            )

        return output
Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

Parameters
- `step_size`: Selects the size of a single step in the sequence
- `step_tokenizers`: Selects the combination and permutation of tokens
- `pre`: Whether all steps include an integral preceding delimiter token
- `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization
- `post`: Whether all steps include an integral following delimiter token
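For concreteness, a minimal construction sketch. `StepSizes.Singles()` and `StepTokenizers.Coord()` are the defaults visible in the field definitions above; only `intra` deviates from the defaults here.

# a StepSequence path tokenizer with an explicit intra-step delimiter
path_tok = PathTokenizers.StepSequence(
    step_size=StepSizes.Singles(),
    step_tokenizers=(StepTokenizers.Coord(),),
    pre=False,
    intra=True,  # emit VOCAB.PATH_INTRA after each _StepTokenizer's tokens
    post=False,
)

Since `StepSequence` implements every abstract method of `_PathTokenizer`, it is directly instantiable; it is also the default `path_tokenizer` of the prompt sequencers further below.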
    def to_tokens(  # noqa: D102
        self,
        maze: SolvedMaze,
        coord_tokenizer: CoordTokenizers._CoordTokenizer,
    ) -> list[str]:
        return [
            *self._leading_tokens(maze, coord_tokenizer),
            *flatten(
                [
                    self._single_step_tokens(maze, start, end, coord_tokenizer)
                    for start, end in self.step_size.step_start_end_indices(maze)
                ],
            ),
            *self._trailing_tokens(maze, coord_tokenizer),
        ]
Returns tokens representing the solution path.
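A usage sketch, assuming `maze` is a `SolvedMaze` instance obtained elsewhere (a hypothetical variable, not defined in this module); the coordinate tokenizer is passed in explicitly, as the path tokenizer does not own one:

# `maze` is assumed: any SolvedMaze instance
tokens: list[str] = path_tok.to_tokens(maze, coord_tokenizer=CoordTokenizers.UT())
# output: leading tokens (the start coord, since StepTokenizers.Coord is used),
# then one token group per (start, end) pair from step_size.step_start_end_indices(maze)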
    def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
        output: bool

        if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
            # Uninteresting: repeated elements are not useful
            output = False
        else:
            # we do noqa for the comment if false
            if len(self.step_tokenizers) == 1 and isinstance(
                self.step_tokenizers[0],
                StepTokenizers.Distance,
            ):
                # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
                output = False
            else:
                output = True

        if not output and do_except:
            raise ValueError(
                "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
            )

        return output
Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply return `True` for that subclass.

Types of Invalidity
In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity with one of the types below. Invalidity types, in ascending order of severity:
- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative.
- Duplicate: These tokenizers have tokenization behavior identical to some other valid tokenizer.
- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
- Erroneous: These tokenizers might raise exceptions during use.

Development
`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

Nesting
In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

Types of Invalidity
If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
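As a concrete illustration of the "Untrainable" invalidity type, a sketch assuming `StepTokenizers.Distance` is instantiable with no arguments, like the other step tokenizers:

# Distance alone cannot encode a path, so this instance is invalid
bad = PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Distance(),))
assert bad.is_valid() is False
# with do_except=True, the same check raises ValueError instead of returning False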
    def serialize(self) -> dict[str, Any]:
        result: dict[str, Any] = {
            _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
        }
        # for each field in the class
        for field in dataclasses.fields(self):  # type: ignore[arg-type]
            # need it to be our special SerializableField
            if not isinstance(field, SerializableField):
                raise NotSerializableFieldException(
                    f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                    f"but a {type(field)} "
                    "this state should be inaccessible, please report this bug!"
                )

            # try to save it
            if field.serialize:
                try:
                    # get the val
                    value = getattr(self, field.name)
                    # if it is a serializable dataclass, serialize it
                    if isinstance(value, SerializableDataclass):
                        value = value.serialize()
                    # if the value has a serialization function, use that
                    if hasattr(value, "serialize") and callable(value.serialize):
                        value = value.serialize()
                    # if the field has a serialization function, use that
                    # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                    elif field.serialization_fn:
                        value = field.serialization_fn(value)

                    # store the value in the result
                    result[field.name] = value
                except Exception as e:
                    raise FieldSerializationError(
                        "\n".join(
                            [
                                f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                f"{field = }",
                                f"{value = }",
                                f"{self = }",
                            ]
                        )
                    ) from e

        # store each property if we can get it
        for prop in self._properties_to_serialize:
            if hasattr(cls, prop):
                value = getattr(self, prop)
                result[prop] = value
            else:
                raise AttributeError(
                    f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                    + f"but it is in {self._properties_to_serialize = }"
                    + f"\n{self = }"
                )

        return result
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
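A serialization sketch using the `path_tok` element constructed above; plain bool fields are stored as-is, while nested elements serialize recursively:

data = path_tok.serialize()
assert data["intra"] is True  # plain field stored directly
# data["step_tokenizers"] holds serialized dicts, via the field's serialization_fn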
    @classmethod  # type: ignore[misc]
    def load(cls, data: dict[str, Any] | T) -> Type[T]:
        # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
        if isinstance(data, cls):
            return data

        assert isinstance(
            data, typing.Mapping
        ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

        cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

        # initialize dict for keeping what we will pass to the constructor
        ctor_kwargs: dict[str, Any] = dict()

        # iterate over the fields of the class
        for field in dataclasses.fields(cls):
            # check if the field is a SerializableField
            assert isinstance(
                field, SerializableField
            ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

            # check if the field is in the data and if it should be initialized
            if (field.name in data) and field.init:
                # get the value, we will be processing it
                value: Any = data[field.name]

                # get the type hint for the field
                field_type_hint: Any = cls_type_hints.get(field.name, None)

                # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                if field.deserialize_fn:
                    # if it has a deserialization function, use that
                    value = field.deserialize_fn(value)
                elif field.loading_fn:
                    # if it has a loading function, use that
                    value = field.loading_fn(data)
                elif (
                    field_type_hint is not None
                    and hasattr(field_type_hint, "load")
                    and callable(field_type_hint.load)
                ):
                    # if no loading function but has a type hint with a load method, use that
                    if isinstance(value, dict):
                        value = field_type_hint.load(value)
                    else:
                        raise FieldLoadingError(
                            f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                        )
                else:
                    # assume no loading needs to happen, keep `value` as-is
                    pass

                # store the value in the constructor kwargs
                ctor_kwargs[field.name] = value

        # create a new instance of the class with the constructor kwargs
        output: cls = cls(**ctor_kwargs)

        # validate the types of the fields if needed
        if on_typecheck_mismatch != ErrorMode.IGNORE:
            fields_valid: dict[str, bool] = (
                SerializableDataclass__validate_fields_types__dict(
                    output,
                    on_typecheck_error=on_typecheck_error,
                )
            )

            # if there are any fields that are not valid, raise an error
            if not all(fields_valid.values()):
                msg: str = (
                    f"Type mismatch in fields of {cls.__name__}:\n"
                    + "\n".join(
                        [
                            f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                            for k, v in fields_valid.items()
                            if not v
                        ]
                    )
                )

                on_typecheck_mismatch.process(
                    msg, except_cls=FieldTypeMismatchError
                )

        # return the new instance
        return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
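A loading sketch; `load` always returns an instance of the class it is called on, and the `loading_fn` hooks shown in the field definitions reconstruct nested elements (round-trip equality is the intent, though not asserted here):

restored = PathTokenizers.StepSequence.load(data)
assert isinstance(restored, PathTokenizers.StepSequence)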
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
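A sketch of calling the validator directly on the loaded element; the import path follows the module named under Inherited Members below, but is an assumption:

from muutils.json_serialize.serializable_dataclass import (
    SerializableDataclass__validate_fields_types,
)

# True only if every field's value matches its type hint
assert SerializableDataclass__validate_fields_types(restored)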
Inherited Members
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class PromptSequencers(__TokenizerElementNamespace):
    """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "prompt_sequencer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PromptSequencer(_TokenizerElement, abc.ABC):
        """Sequences token regions into a complete maze tokenization.

        # Parameters
        - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
        - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
          Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
        """

        coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
            default=CoordTokenizers.UT(),
            loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
        )
        adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
            default=AdjListTokenizers.AdjListCoord(),
            loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return PromptSequencers.key

        @staticmethod
        def _trim_if_unsolved_maze(
            untrimmed: list[str],
            is_untargeted: bool = False,
            is_unsolved: bool = False,
        ) -> list[str]:
            """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.

            # Development
            This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
            It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
            """
            if is_untargeted:
                return tokens_between(
                    untrimmed,
                    VOCAB.ADJLIST_START,
                    VOCAB.ADJLIST_END,
                    include_start=True,
                    include_end=True,
                )
            if is_unsolved:
                if VOCAB.TARGET_END in untrimmed:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.TARGET_END,
                        include_start=True,
                        include_end=True,
                    )
                else:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.ORIGIN_END,
                        include_start=True,
                        include_end=True,
                    )
            return untrimmed

        def to_tokens(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[str]:
            """Returns a complete list of tokens for a given set of maze elements."""
            untrimmed: list[str] = self._sequence_tokens(
                *self._get_prompt_regions(maze),
            )
            return self._trim_if_unsolved_maze(
                untrimmed,
                not hasattr(maze, "start_pos"),
                not hasattr(maze, "solution"),
            )

        def _get_prompt_regions(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[list[str]]:
            """Gets the prompt regions of a maze in a fixed sequence.

            This method is NOT responsible for including/excluding any prompt regions.
            Always return according to the API described under Returns.
            This implementation is expected to be suitable for most `PromptSequencer` subclasses.
            Subclasses may override this method if needed for special behavior.

            # Returns
            - [0]: list[str] Adjacency list tokens
            - [1]: list[str] Origin tokens
            - [2]: list[str] Target tokens
            - [3]: list[str] Path tokens

            # `None`-valued Args
            If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
            To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables.
            """
            origin: Coord | None = getattr(maze, "start_pos", None)
            target: list[Coord] | None = [
                getattr(maze, "end_pos", None),
            ]  # TargetTokenizer requires target: Sequence[Coord]

            return [
                (
                    self.adj_list_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(self, "adj_list_tokenizer")
                    else []
                ),
                self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
                (
                    self.target_tokenizer.to_tokens(
                        target,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if target[0] is not None and hasattr(self, "target_tokenizer")
                    else []
                ),
                (
                    self.path_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
                    else []
                ),
            ]

        @abc.abstractmethod
        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str] | None,
            target: list[str] | None,
            path: list[str] | None,
        ) -> list[str]:
            """Sequences token regions into a complete prompt.

            Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.

            # Parameters
            - `adj_list`: Tokens representing the adjacency list
            - `origin`: Tokens representing the origin
            - `target`: Tokens representing the target
            - `path`: Tokens representing the path
            """
            pass

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class AOTP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, target, path].

        # Parameters
        - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
            default=TargetTokenizers.Unlabeled(),
            loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
        )
        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                *target,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class AOP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, path].

        Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.

        # Parameters
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            # explicitly no target in this tokenizer
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]
Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.
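To make the trimming behavior of `_trim_if_unsolved_maze` concrete, a sketch assuming `lattice_maze` is a plain `LatticeMaze` (no `start_pos` or `solution` attributes) and `solved_maze` is a `SolvedMaze` (both hypothetical variables):

seq = PromptSequencers.AOTP()

# untargeted maze: the prompt is trimmed down to just the adjacency-list region
tokens_unsolved = seq.to_tokens(lattice_maze)
assert tokens_unsolved[0] == VOCAB.ADJLIST_START
assert tokens_unsolved[-1] == VOCAB.ADJLIST_END

# solved maze: the full [adjacency list, origin, target, path] prompt is kept
tokens_full = seq.to_tokens(solved_maze)
assert tokens_full[-1] == VOCAB.PATH_END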
    @serializable_dataclass(frozen=True, kw_only=True)
    class AOTP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, target, path].

        # Parameters
        - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
            default=TargetTokenizers.Unlabeled(),
            loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
        )
        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                *target,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]
Sequences a prompt as [adjacency list, origin, target, path].
Parameters
- `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
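A construction sketch with a non-default coordinate tokenizer (the `target_tokenizer` and `path_tokenizer` defaults are visible in the field definitions above); `maze` is again an assumed `SolvedMaze`:

seq = PromptSequencers.AOTP(
    coord_tokenizer=CoordTokenizers.CTT(),  # tuple-style coords instead of the UT default
    path_tokenizer=PathTokenizers.StepSequence(intra=True),
)
tokens = seq.to_tokens(maze)
# region order: <ADJLIST_START> ... <ADJLIST_END> <ORIGIN_START> ... <PATH_START> ... <PATH_END>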
Inherited Members
- PromptSequencers._PromptSequencer
- coord_tokenizer
- adj_list_tokenizer
- attribute_key
- to_tokens
- is_valid
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
    @serializable_dataclass(frozen=True, kw_only=True)
    class AOP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, path].

        Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.

        # Parameters
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            # explicitly no target in this tokenizer
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]
Sequences a prompt as [adjacency list, origin, path].
Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.
Parameters
- `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
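A sketch showing the empty target region, again assuming `maze` is a `SolvedMaze`:

seq = PromptSequencers.AOP()
tokens = seq.to_tokens(maze)
i = tokens.index(VOCAB.TARGET_START)
# the target delimiters are adjacent: nothing is emitted between them
assert tokens[i + 1] == VOCAB.TARGET_END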
Inherited Members
- PromptSequencers._PromptSequencer
- coord_tokenizer
- adj_list_tokenizer
- attribute_key
- to_tokens
- is_valid
- maze_dataset.tokenization.modular.element_base._TokenizerElement
- name
- tokenizer_elements
- tokenizer_element_tree
- tokenizer_element_dict
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict