1"""implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`"""
3import abc
4import random
5from typing import (
6 Callable,
7 Literal,
8 Sequence,
9 TypedDict,
10)
12import numpy as np
13from jaxtyping import Bool, Int
14from muutils.json_serialize import (
15 serializable_dataclass,
16 serializable_field,
17)
18from muutils.misc import empty_sequence_if_attr_false, flatten
20# from maze_dataset import SolvedMaze
21from maze_dataset.constants import (
22 VOCAB,
23 ConnectionArray,
24 ConnectionList,
25 Coord,
26 CoordTup,
27)
28from maze_dataset.generation import numpy_rng
29from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze
30from maze_dataset.token_utils import (
31 connection_list_to_adj_list,
32 get_cardinal_direction,
33 get_relative_direction,
34 is_connection,
35 tokens_between,
36)
37from maze_dataset.tokenization.modular.element_base import (
38 __TokenizerElementNamespace,
39 _load_tokenizer_element,
40 _TokenizerElement,
41 _unsupported_is_invalid,
42 mark_as_unsupported,
43)
44from maze_dataset.utils import lattice_connection_array


class CoordTokenizers(__TokenizerElementNamespace):
    """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "coord_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _CoordTokenizer(_TokenizerElement, abc.ABC):
        """Superclass for classes which tokenize singular coords in a maze."""

        @abc.abstractmethod
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return CoordTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class UT(_CoordTokenizer):
        """Unique token coordinate tokenizer."""

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]

    @serializable_dataclass(frozen=True, kw_only=True)
    class CTT(_CoordTokenizer):
        """Coordinate tuple tokenizer

        # Parameters
        - `pre`: Whether all coords include an integral preceding delimiter token
        - `intra`: Whether all coords include a delimiter token between coordinates
        - `post`: Whether all coords include an integral following delimiter token
        """

        pre: bool = serializable_field(default=True)
        intra: bool = serializable_field(default=True)
        post: bool = serializable_field(default=True)
        # Implement methods

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return [
                *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
                str(coord[0]),
                *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
                str(coord[1]),
                *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
            ]
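

# A rough usage sketch (an assumption for illustration, not from the source):
# `UT` emits a single unique token per coord, while `CTT` splits a coord into
# several tokens. Assuming VOCAB.COORD_PRE, VOCAB.COORD_INTRA, and
# VOCAB.COORD_POST render as "(", ",", and ")":
#
#     >>> CoordTokenizers.UT().to_tokens((1, 2))
#     ['(1,2)']
#     >>> CoordTokenizers.CTT().to_tokens((1, 2))
#     ['(', '1', ',', '2', ')']
#     >>> CoordTokenizers.CTT(pre=False, post=False).to_tokens((1, 2))
#     ['1', ',', '2']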


class EdgeGroupings(__TokenizerElementNamespace):
    """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_grouping"

    class _GroupingTokenParams(TypedDict):
        """A uniform private hyperparameter interface used by `AdjListTokenizer`."""

        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeGrouping(_TokenizerElement, abc.ABC):
        """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeGroupings.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            """Divides a ConnectionArray into groups of edges.

            Shuffles/sequences within each group if applicable.
            """
            pass

        @abc.abstractmethod
        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            """Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

            These hyperparameters are not used by `_EdgeGrouping` internally.
            They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
            since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
            This function resolves the `_EdgeGrouping` hyperparameter space, which is non-uniform across subclasses,
            into a uniform private interface used by `AdjListTokenizer`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class Ungrouped(_EdgeGrouping):
        """No grouping occurs; each edge is tokenized individually.

        # Parameters
        - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
          Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
        """

        connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
            default=1,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=False,
                grouped=False,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            return np.expand_dims(edges, 1)
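
        # Shape sketch (an assumption, not from the source): for a
        # ConnectionArray of shape (n_edges, 2, 2), `np.expand_dims(edges, 1)`
        # has shape (n_edges, 1, 2, 2), so iterating along axis 0 yields each
        # edge as its own singleton group.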

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ByLeadingCoord(_EdgeGrouping):
        """All edges with the same leading coord are grouped together.

        # Parameters
        - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
          Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
        - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
          If false, the fixed order is lexicographical by (row, col).
          In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
        - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
          Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
        """

        intra: bool = serializable_field(default=True)
        shuffle_group: bool = serializable_field(default=True)
        connection_token_ordinal: Literal[0, 1] = serializable_field(
            default=0,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=self.intra,
                grouped=True,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
            index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
                (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
            )
            sorted_edges: ConnectionArray = edges[index_array, ...]
            groups: list[ConnectionArray] = np.split(
                sorted_edges,
                np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
            )
            if self.shuffle_group:
                for group in groups:
                    numpy_rng.shuffle(group, axis=0)
            return groups
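
        # Grouping sketch (an assumption, not from the source): given edges
        #     [[(0,0),(0,1)], [(0,0),(1,0)], [(1,1),(1,0)]]
        # lexsorting by leading then trailing coord and splitting at the first
        # occurrence of each distinct leading coord yields two groups:
        #     [[(0,0),(0,1)], [(0,0),(1,0)]]  and  [[(1,1),(1,0)]]
        # `np.unique(..., return_index=True)` supplies those first-occurrence
        # indices; dropping index 0 gives the split points for `np.split`.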


class EdgePermuters(__TokenizerElementNamespace):
    """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_permuter"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgePermuter(_TokenizerElement, abc.ABC):
        """Specifies how to sequence the two coords that encode a lattice edge."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgePermuters.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @staticmethod
        @abc.abstractmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            """Executes a permutation.

            Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

            # Parameters
            - `lattice_edges`: Array of lattice edges.
              The two coords in shape[1] must be adjacent in the lattice.

            # Returns
            - Array of lattice edges with entries along shape[1] systematically permuted.
            - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[0]`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class SortedCoords(_EdgePermuter):
        """Returns a sorted representation. Useful for checking consistency."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return lattice_edges[
                np.lexsort(
                    (
                        lattice_edges[:, 1, 1],
                        lattice_edges[:, 1, 0],
                        lattice_edges[:, 0, 1],
                        lattice_edges[:, 0, 0],
                    ),
                ),
                ...,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class RandomCoords(_EdgePermuter):
        """Permutes each edge randomly."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
            return lattice_edges

    @serializable_dataclass(frozen=True, kw_only=True)
    class BothCoords(_EdgePermuter):
        """Includes both possible permutations of every edge in the output.

        Since the input ConnectionArray has only 1 instance of each edge,
        a call to `BothCoords._permute` returns a new array with double the `shape[0]` of its input.
        """

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
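
        # Sketch (an assumption, not from the source): for a single edge,
        # both orderings appear in the output:
        #     >>> edges = np.array([[[0, 0], [0, 1]]])
        #     >>> EdgePermuters.BothCoords._permute(edges)
        #     array([[[0, 0], [0, 1]],
        #            [[0, 1], [0, 0]]])
        # `np.flip(..., axis=1)` swaps the two coords of each edge and
        # `np.append(..., axis=0)` concatenates the flipped copy.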


class EdgeSubsets(__TokenizerElementNamespace):
    """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_subset"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeSubset(_TokenizerElement, abc.ABC):
        """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeSubsets.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            """Returns the set of lattice edges to be tokenized."""
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class AllLatticeEdges(_EdgeSubset):
        """All `2*n**2 - 2*n` edges of an `n x n` lattice are tokenized.

        If a wall exists on an edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
        """

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            return lattice_connection_array(maze.grid_n)

    @serializable_dataclass(frozen=True, kw_only=True)
    class ConnectionEdges(_EdgeSubset):
        """Only edges which contain a connection are tokenized.

        Alternatively, only edges which contain a wall are tokenized.

        # Parameters
        - `walls`: Whether wall edges or connection edges are tokenized.
          If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
        """

        walls: bool = serializable_field(default=False)

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            conn_list: ConnectionList = maze.connection_list
            if self.walls:
                conn_list = np.logical_not(conn_list)
                conn_list[0, -1, :] = False
                conn_list[1, :, -1] = False
            return connection_list_to_adj_list(
                conn_list,
                shuffle_d0=False,
                shuffle_d1=False,
            )
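
        # A minimal sketch of the wall-complement masking (an assumption, not
        # from the source). A ConnectionList has shape (2, n, n): [0] holds
        # down-connections, [1] holds right-connections. The last row of [0]
        # and the last column of [1] point outside the lattice, so after
        # negation they must be zeroed:
        #     >>> conn_list = np.zeros((2, 2, 2), dtype=bool)
        #     >>> walls = np.logical_not(conn_list)
        #     >>> walls[0, -1, :] = False
        #     >>> walls[1, :, -1] = False
        #     >>> int(walls.sum())  # 2*n**2 - 2*n = 4 valid wall edges for n=2
        #     4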


def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool:  # noqa: ANN001
    """Returns False if `pre` is True, True otherwise."""
    output: bool = self_.pre is False

    if do_except and not output:
        raise ValueError(
            "AdjListCoord does not support `pre == True`.",
        )

    return output


class AdjListTokenizers(__TokenizerElementNamespace):
    """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "adj_list_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_adjlist_no_pre_unsupported)
    class _AdjListTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how the adjacency list is tokenized.

        Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
        See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.

        # Parameters
        - `pre`: Whether all edge groupings include a preceding delimiter token
        - `post`: Whether all edge groupings include a following delimiter token
        - `shuffle_d0`: Specifies how to sequence the edge groupings.
          If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
        - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
        - `edge_subset`: Specifies the subset of lattice edges to be tokenized.
        - `edge_permuter`: Specifies, in each edge tokenization, which coord either:
          1. Appears first in the tokenization, for `AdjListCoord`.
          2. Is tokenized directly as a coord, for `AdjListCardinal`.
          - `shuffle`: For each edge, the leading coord is selected randomly.
          - `all`: Each edge appears twice in the tokenization, once with each of its coords leading.
          - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
        """

        pre: bool = serializable_field(default=False, assert_type=False)
        post: bool = serializable_field(default=True)
        shuffle_d0: bool = serializable_field(default=True)
        edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
            default=EdgeGroupings.Ungrouped(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
        )
        edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
            default=EdgeSubsets.ConnectionEdges(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
        )
        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return AdjListTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @abc.abstractmethod
        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            """Returns a sequence of callables which take an index into `edges` and return parts of that edge's tokenization.

            # Returns
            - `[0]`: leading coord tokens
            - `[1]`: connector tokens
            - `[2]`: trailing coord tokens
            """
            pass

        def _tokenize_edge_grouping(
            self,
            edges: ConnectionArray,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            group_params: EdgeGroupings._GroupingTokenParams,
        ) -> Sequence[str]:
            """Tokenizes a single edge grouping."""
            cxn_ord: int = group_params["connection_token_ordinal"]
            is_conn: Bool[np.ndarray, " edges"] = is_connection(
                edges,
                maze.connection_list,
            )
            tokenize_callables = self._tokenization_callables(
                edges,
                is_conn,
                coord_tokenizer,
            )

            if group_params["grouped"]:
                # If grouped: tokenize the leading coord once, then emit the
                # connector and trailing part for each edge in the group
                callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
                repeated_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]
                return flatten(
                    [
                        tokenize_callables[0](0),
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in repeated_callables
                                ],
                                *(
                                    (VOCAB.ADJLIST_INTRA,)
                                    if group_params["intra"]
                                    else ()
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )
            else:
                # If ungrouped: emit all three parts for every edge, with the
                # connector placed at index `connection_token_ordinal`
                callable_permutation = [0, 2]
                callable_permutation.insert(cxn_ord, 1)
                tokenize_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]
                return flatten(
                    [
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in tokenize_callables
                                ],
                                *empty_sequence_if_attr_false(
                                    (VOCAB.ADJLIST_INTRA,),
                                    group_params,
                                    "intra",
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )

        def to_tokens(
            self,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            # Get the set of edges to be tokenized
            edges: ConnectionArray = self.edge_subset._get_edges(maze)
            # Systematically permute the leading coord of each edge
            edges = self.edge_permuter._permute(edges)
            group_params: EdgeGroupings._GroupingTokenParams = (
                self.edge_grouping._token_params()
            )
            # Group the edges
            groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
            # Shuffle the groups if specified
            if self.shuffle_d0:
                if isinstance(groups, np.ndarray):
                    numpy_rng.shuffle(groups, axis=0)
                elif isinstance(groups, list):
                    random.shuffle(groups)
                else:
                    err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
                    raise TypeError(err_msg)
            # Tokenize each group with optional delimiters
            tokens: list[str] = list(
                flatten(
                    [
                        [
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJLIST_PRE,),
                                self,
                                "pre",
                            ),
                            *self._tokenize_edge_grouping(
                                group,
                                maze,
                                coord_tokenizer,
                                group_params,
                            ),
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJACENCY_ENDLINE,),
                                self,
                                "post",
                            ),
                        ]
                        for group in groups
                    ],
                ),
            )
            return tokens
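
        # The adjacency-list pipeline, end to end (a sketch, not normative):
        #     edges  = edge_subset._get_edges(maze)       # which lattice edges
        #     edges  = edge_permuter._permute(edges)      # which coord leads
        #     groups = edge_grouping._group_edges(edges)  # how edges batch
        #     # shuffle groups if `shuffle_d0`, then tokenize each group
        # With the defaults (ConnectionEdges, RandomCoords, Ungrouped), each
        # connection renders roughly as:
        #     [*leading_coord_tokens, VOCAB.CONNECTOR, *trailing_coord_tokens,
        #      VOCAB.ADJACENCY_ENDLINE]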

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCoord(_AdjListTokenizer):
        """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCardinal(_AdjListTokenizer):
        """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

        # Parameters
        - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
        """

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.BothCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: get_cardinal_direction(edges[i]),
            ]
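
        # Sketch (an assumption, not from the source): for a connected edge
        # ((0,0), (0,1)), `AdjListCardinal` emits the leading coord tokens,
        # then VOCAB.CONNECTOR, then the cardinal token for the trailing coord
        # relative to the leading coord (EAST here, since (0,1) is one column
        # to the right of (0,0)).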


class TargetTokenizers(__TokenizerElementNamespace):
    """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "target_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _TargetTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze targets."""

        @abc.abstractmethod
        def to_tokens(
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the target."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return TargetTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class Unlabeled(_TargetTokenizer):
        """Targets are simply listed as coord tokens.

        # Parameters
        - `post`: Whether all coords include an integral following delimiter token
        """

        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return list(
                flatten(
                    [
                        [
                            *coord_tokenizer.to_tokens(target),
                            *empty_sequence_if_attr_false(
                                [VOCAB.TARGET_POST],
                                self,
                                "post",
                            ),
                        ]
                        for target in targets
                    ],
                ),
            )

        # inherit docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            # No invalid instances possible within data member type hint bounds
            return True


class StepSizes(__TokenizerElementNamespace):
    """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_size"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepSize(_TokenizerElement, abc.ABC):
        """Specifies which coords in `maze.solution` are used to represent the path."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepSizes.key

        @abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            raise NotImplementedError(
                "Subclasses must implement `_StepSize._step_single_indices`.",
            )

        def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
            """Returns steps as tuples of starting and ending positions for each step."""
            indices: list[int] = self._step_single_indices(maze)
            # TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
            return [
                (start, end)
                for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
            ]
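
        # Pairing sketch (an assumption, not from the source): for indices
        # [0, 2, 5], the zip over successive pairs yields [(0, 2), (2, 5)];
        # each tuple is one step's (start, end) positions in `maze.solution`.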

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Singles(_StepSize):
        """Every coord in `maze.solution` is represented.

        Legacy tokenizers all use this behavior.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(range(maze.solution.shape[0]))

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class Straightaways(_StepSize):
        """Only coords where the path turns are represented in the path.

        I.e., the path is represented as a sequence of straightaways,
        specified by the coords at the turns.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            last_turn_coord: Coord = maze.solution[0, ...]
            indices: list[int] = [0]
            for i, coord in enumerate(maze.solution):
                if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
                    indices.append(i - 1)
                    last_turn_coord = maze.solution[i - 1, ...]
            indices.append(i)
            return indices

    @serializable_dataclass(frozen=True, kw_only=True)
    class Forks(_StepSize):
        """Only coords at forks, where the path has >= 2 options for the next step, are included.

        Excludes the option of backtracking.
        The starting and ending coords are always included.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return maze.get_solution_forking_points(always_include_endpoints=True)[0]

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ForksAndStraightaways(_StepSize):
        """Includes the union of the coords included by `Forks` and `Straightaways`.

        See documentation for those classes for details.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(
                np.unique(
                    np.concatenate(
                        (
                            StepSizes.Straightaways()._step_single_indices(maze),
                            StepSizes.Forks()._step_single_indices(maze),
                        ),
                    ),
                ),
            )
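
        # Sketch (an assumption, not from the source): for a solution that
        # goes east 3 cells then south 2 (coords at indices 0..5), `Singles`
        # yields [0, 1, 2, 3, 4, 5], while `Straightaways` would keep only
        # [0, 3, 5]: the start, the turn, and the end.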


class StepTokenizers(__TokenizerElementNamespace):
    """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_tokenizers"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepTokenizers.key

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            """Tokenizes a single step in the solution.

            # Parameters
            - `maze`: Maze to be tokenized
            - `start_index`: The index of the Coord in `maze.solution` at which the current step starts
            - `end_index`: The index of the Coord in `maze.solution` at which the current step ends
            """
            raise NotImplementedError(
                "Subclasses must implement `_StepTokenizer.to_tokens`.",
            )

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Coord(_StepTokenizer):
        """A direct tokenization of the end position coord represents the step."""

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

    @serializable_dataclass(frozen=True, kw_only=True)
    class Cardinal(_StepTokenizer):
        """A step is tokenized with a cardinal direction token.

        It is the direction of the step from the starting position along the solution.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            return [
                get_cardinal_direction(maze.solution[start_index : start_index + 2]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Relative(_StepTokenizer):
        """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

        To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
        Similarly to `Cardinal`, the direction is that of the step from the starting position.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            if start_index == 0:
                # the agent is assumed to face NORTH at the start, so fabricate
                # a "previous" position one row south of the starting coord
                start = maze.solution[0]
                previous = start + np.array([1, 0])
                return [
                    get_relative_direction(
                        np.concatenate(
                            (
                                np.expand_dims(previous, 0),
                                maze.solution[start_index : start_index + 2],
                            ),
                            axis=0,
                        ),
                    ),
                ]
            return [
                get_relative_direction(
                    maze.solution[start_index - 1 : start_index + 2],
                ),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Distance(_StepTokenizer):
        """A count of the number of individual steps from the starting point to the end point.

        Contains no information about directionality, only the distance traveled in the step.
        `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
        This constraint is enforced in `_PathTokenizer.is_valid`.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            d: int = end_index - start_index
            return [getattr(VOCAB, f"I_{d:03}")]
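
        # Sketch (an assumption, not from the source): a step spanning
        # solution indices 2 -> 5 has d = 3 and is tokenized as
        # [getattr(VOCAB, "I_003")], a single distance token.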
890 """
891 `StepTokenizerPermutation`
892 A sequence of unique `_StepTokenizer`s.
893 This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
894 """
895 StepTokenizerPermutation: type = (
896 tuple[_StepTokenizer]
897 | tuple[_StepTokenizer, _StepTokenizer]
898 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
899 | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
900 )


class PathTokenizers(__TokenizerElementNamespace):
    """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "path_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PathTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze solution paths."""

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the solution path."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return PathTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class StepSequence(_PathTokenizer, abc.ABC):
        """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

        Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

        # Parameters
        - `step_size`: Selects the size of a single step in the sequence
        - `step_tokenizers`: Selects the combination and permutation of tokens
        - `pre`: Whether all steps include an integral preceding delimiter token
        - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
        - `post`: Whether all steps include an integral following delimiter token
        """

        step_size: StepSizes._StepSize = serializable_field(
            default=StepSizes.Singles(),
            loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
        )
        step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
            default=(StepTokenizers.Coord(),),
            serialization_fn=lambda x: [y.serialize() for y in x],
            loading_fn=lambda x: tuple(x[StepTokenizers.key]),
        )
        pre: bool = serializable_field(default=False)
        intra: bool = serializable_field(default=False)
        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return [
                *self._leading_tokens(maze, coord_tokenizer),
                *flatten(
                    [
                        self._single_step_tokens(maze, start, end, coord_tokenizer)
                        for start, end in self.step_size.step_start_end_indices(maze)
                    ],
                ),
                *self._trailing_tokens(maze, coord_tokenizer),
            ]

        def _single_step_tokens(
            self,
            maze: SolvedMaze,
            i: int,
            j: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns the token sequence representing a single step along the path."""
            step_rep_tokens: list[list[str]] = [
                step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
                for step_tokenizer in self.step_tokenizers
            ]
            if self.intra:
                # interleave a PATH_INTRA delimiter after each tokenizer's output
                step_rep_tokens_and_intra: list[list[str] | str] = [None] * (
                    len(step_rep_tokens) * 2
                )
                step_rep_tokens_and_intra[::2] = step_rep_tokens
                step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                    step_rep_tokens,
                )
                step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
            all_tokens: list[str] = [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *flatten(step_rep_tokens),
                *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
            ]
            return all_tokens
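
        # Interleaving sketch (an assumption, not from the source): with
        # step_tokenizers=(Cardinal(), Distance()) and intra=True, one step
        # renders roughly as
        #     [*cardinal_tokens, VOCAB.PATH_INTRA, *distance_tokens,
        #      VOCAB.PATH_INTRA]
        # wrapped in optional VOCAB.PATH_PRE / VOCAB.PATH_POST delimiters.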

        def _leading_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens preceding those from the sequence from `_single_step_tokens`.

            Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
            <PATH_START> should NOT be included.
            """
            if StepTokenizers.Coord() in self.step_tokenizers:
                return [
                    *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                    *coord_tokenizer.to_tokens(maze.solution[0, ...]),
                    *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
                ]
            return []

        def _trailing_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens following those from the sequence from `_single_step_tokens`.

            <PATH_END> should NOT be included.
            """
            return []

        # inherits docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            output: bool

            if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
                # Uninteresting: repeated elements are not useful
                output = False
            else:
                # nested `if` rather than `elif` so the explanatory comment below stays attached to its branch
                if len(self.step_tokenizers) == 1 and isinstance(
                    self.step_tokenizers[0],
                    StepTokenizers.Distance,
                ):
                    # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
                    output = False
                else:
                    output = True

            if not output and do_except:
                raise ValueError(
                    "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
                )

            return output


class PromptSequencers(__TokenizerElementNamespace):
    """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "prompt_sequencer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PromptSequencer(_TokenizerElement, abc.ABC):
        """Sequences token regions into a complete maze tokenization.

        # Parameters
        - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
        - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
          Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
        """

        coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
            default=CoordTokenizers.UT(),
            loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
        )
        adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
            default=AdjListTokenizers.AdjListCoord(),
            loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return PromptSequencers.key

        @staticmethod
        def _trim_if_unsolved_maze(
            untrimmed: list[str],
            is_untargeted: bool = False,
            is_unsolved: bool = False,
        ) -> list[str]:
            """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.

            # Development
            This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
            It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
            """
            if is_untargeted:
                return tokens_between(
                    untrimmed,
                    VOCAB.ADJLIST_START,
                    VOCAB.ADJLIST_END,
                    include_start=True,
                    include_end=True,
                )
            if is_unsolved:
                if VOCAB.TARGET_END in untrimmed:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.TARGET_END,
                        include_start=True,
                        include_end=True,
                    )
                else:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.ORIGIN_END,
                        include_start=True,
                        include_end=True,
                    )
            return untrimmed

        def to_tokens(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[str]:
            """Returns a complete list of tokens for a given set of maze elements."""
            untrimmed: list[str] = self._sequence_tokens(
                *self._get_prompt_regions(maze),
            )
            return self._trim_if_unsolved_maze(
                untrimmed,
                is_untargeted=not hasattr(maze, "start_pos"),
                is_unsolved=not hasattr(maze, "solution"),
            )

        def _get_prompt_regions(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[list[str]]:
            """Gets the prompt regions of a maze in a fixed sequence.

            This method is NOT responsible for including/excluding any prompt regions.
            Always return according to the API described under Returns.
            This implementation is expected to be suitable for most `PromptSequencer` subclasses.
            Subclasses may override this method if needed for special behavior.

            # Returns
            - `[0]`: list[str] Adjacency list tokens
            - `[1]`: list[str] Origin tokens
            - `[2]`: list[str] Target tokens
            - `[3]`: list[str] Path tokens

            # `None`-valued Args
            If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
            To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables.
            """
            origin: Coord | None = getattr(maze, "start_pos", None)
            target: list[Coord] | None = [
                getattr(maze, "end_pos", None),
            ]  # TargetTokenizer requires target: Sequence[Coord]

            return [
                (
                    self.adj_list_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(self, "adj_list_tokenizer")
                    else []
                ),
                self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
                (
                    self.target_tokenizer.to_tokens(
                        target,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if target[0] is not None and hasattr(self, "target_tokenizer")
                    else []
                ),
                (
                    self.path_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
                    else []
                ),
            ]

        @abc.abstractmethod
        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str] | None,
            target: list[str] | None,
            path: list[str] | None,
        ) -> list[str]:
            """Sequences token regions into a complete prompt.

            Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.

            # Parameters
            - `adj_list`: Tokens representing the adjacency list
            - `origin`: Tokens representing the origin
            - `target`: Tokens representing the target
            - `path`: Tokens representing the path
            """
            pass

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class AOTP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, target, path].

        # Parameters
        - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`.
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field(
            default=TargetTokenizers.Unlabeled(),
            loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers),
        )
        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                *target,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]
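
        # Prompt layout sketch (an assumption, not from the source): AOTP
        # prompts are shaped like
        #     <ADJLIST_START> ... <ADJLIST_END>
        #     <ORIGIN_START> ... <ORIGIN_END>
        #     <TARGET_START> ... <TARGET_END>
        #     <PATH_START> ... <PATH_END>
        # with the literal token strings determined by VOCAB.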

    @serializable_dataclass(frozen=True, kw_only=True)
    class AOP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, path].

        Still includes the `<TARGET_START>` and `<TARGET_END>` tokens, but no representation of the target itself.

        # Parameters
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
          Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            # explicitly no target in this tokenizer
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]