# NOTE: this file was recovered from a coverage.py (v7.6.12) HTML report of
# maze_dataset/constants.py (16% coverage, 85 statements, 2025-03-11).
1"""constants and type hints used accross the package"""
3import warnings
4from dataclasses import dataclass, field, make_dataclass
5from typing import Iterator
7import numpy as np
8from jaxtyping import Bool, Int8
10from maze_dataset.utils import corner_first_ndindex
# various type hints for coordinates, connections, etc.
# NOTE: jaxtyping annotations document intended dtype and shape (axis labels
# like "row_col=2") but are not enforced at runtime
Coord = Int8[np.ndarray, "row_col=2"]
"single coordinate as array"
CoordTup = tuple[int, int]
"single coordinate as tuple"
CoordArray = Int8[np.ndarray, "coord row_col=2"]
"array of coordinates"
CoordList = list[CoordTup]
"list of tuple coordinates"
Connection = Int8[np.ndarray, "coord=2 row_col=2"]
"single connection (pair of coords) as array"
ConnectionList = Bool[np.ndarray, "lattice_dim=2 row col"]
"internal representation used in `LatticeMaze`"
ConnectionArray = Int8[np.ndarray, "edges leading_trailing_coord=2 row_col=2"]
"n_edges * 2 * 2 array of connections, like an adjacency list"
class SpecialTokensError(Exception):
	"""(unused!) errors related to special tokens"""
# used by `_SPECIAL_TOKENS_BASE.get_abbrev` to shorten token strings
_SPECIAL_TOKENS_ABBREVIATIONS: dict[str, str] = {
	"<ADJLIST_START>": "<A_S>",
	"<ADJLIST_END>": "<A_E>",
	"<TARGET_START>": "<T_S>",
	"<TARGET_END>": "<T_E>",
	"<ORIGIN_START>": "<O_S>",
	"<ORIGIN_END>": "<O_E>",
	"<PATH_START>": "<P_S>",
	"<PATH_END>": "<P_E>",
	# connector and endline tokens have no shorter form -- mapped to themselves
	"<-->": "<-->",
	";": ";",
	"<PADDING>": "<PAD>",
}
"map abbreviations for (some) special tokens"
58@dataclass(frozen=True)
59class _SPECIAL_TOKENS_BASE: # noqa: N801
60 "special dataclass used for handling special tokens"
62 ADJLIST_START: str = "<ADJLIST_START>"
63 ADJLIST_END: str = "<ADJLIST_END>"
64 TARGET_START: str = "<TARGET_START>"
65 TARGET_END: str = "<TARGET_END>"
66 ORIGIN_START: str = "<ORIGIN_START>"
67 ORIGIN_END: str = "<ORIGIN_END>"
68 PATH_START: str = "<PATH_START>"
69 PATH_END: str = "<PATH_END>"
70 CONNECTOR: str = "<-->"
71 ADJACENCY_ENDLINE: str = ";"
72 PADDING: str = "<PADDING>"
74 def __getitem__(self, key: str) -> str:
75 key_upper: str = key.upper()
77 if not isinstance(key, str):
78 err_msg: str = f"key must be str, not {type(key)}"
79 raise TypeError(err_msg)
81 # error checking for old lowercase format
82 if key != key_upper:
83 warnings.warn(
84 f"Accessing special token '{key}' without uppercase. this is deprecated and will be removed in the future.",
85 DeprecationWarning,
86 )
87 key = key_upper
89 # `ADJLIST` used to be `adj_list`, changed to match actual token content
90 if key_upper not in self.keys():
91 key_upper_modified: str = key_upper.replace("ADJ_LIST", "ADJLIST")
92 if key_upper_modified in self.keys():
93 warnings.warn(
94 f"Accessing '{key}' in old format, should use {key_upper_modified}. this is deprecated and will be removed in the future.",
95 DeprecationWarning,
96 )
97 return getattr(self, key_upper_modified)
98 else:
99 err_msg: str = f"invalid special token '{key}'"
100 raise KeyError(err_msg)
102 # normal return
103 return getattr(self, key.upper())
105 def get_abbrev(self, key: str) -> str:
106 return _SPECIAL_TOKENS_ABBREVIATIONS[self[key]]
108 def __iter__(self) -> Iterator[str]:
109 return iter(self.__dict__.keys())
111 def __len__(self) -> int:
112 return len(self.__dict__.keys())
114 def __contains__(self, key: str) -> bool:
115 return key in self.__dict__
117 def values(self) -> Iterator[str]:
118 return self.__dict__.values()
120 def items(self) -> Iterator[tuple[str, str]]:
121 return self.__dict__.items()
123 def keys(self) -> Iterator[str]:
124 return self.__dict__.keys()
# module-level singleton; access tokens as attributes or via `__getitem__`
SPECIAL_TOKENS: _SPECIAL_TOKENS_BASE = _SPECIAL_TOKENS_BASE()
"special tokens"
# NOTE(review): despite the `Int8` annotation, `np.array` over Python ints
# produces the platform default integer dtype (e.g. int64) -- confirm whether
# an explicit `dtype=np.int8` is intended here
DIRECTIONS_MAP: Int8[np.ndarray, "direction axes"] = np.array(
	[
		[0, 1],  # down
		[0, -1],  # up
		[1, 1],  # right
		[1, -1],  # left
	],
)
"down, up, right, left directions for when inside a `ConnectionList`"
# offset vectors for the four lattice neighbors of a coordinate
# NOTE(review): dtype is the numpy default int, not int8, despite the annotation
NEIGHBORS_MASK: Int8[np.ndarray, "coord point"] = np.array(
	[
		[0, 1],  # down
		[0, -1],  # up
		[1, 0],  # right
		[-1, 0],  # left
	],
)
"down, up, right, left as vectors"
# last element of the tuple is actually a Field[str], but mypy complains
# NOTE: field order here determines token indices in `VOCAB_LIST` and
# `VOCAB_TOKEN_TO_INDEX` -- do not reorder or insert except at the end
_VOCAB_FIELDS: list[tuple[str, type[str], str]] = [
	# special tokens are contributed via inheritance from `_SPECIAL_TOKENS_BASE`
	# *[(k, str, field(default=v)) for k, v in SPECIAL_TOKENS.items()],
	# delimiter tokens for coordinates, targets, and paths
	("COORD_PRE", str, field(default="(")),
	("COORD_INTRA", str, field(default=",")),
	("COORD_POST", str, field(default=")")),
	("TARGET_INTRA", str, field(default="=")),
	("TARGET_POST", str, field(default="||")),
	("PATH_INTRA", str, field(default=":")),
	("PATH_POST", str, field(default="THEN")),
	("NEGATIVE", str, field(default="-")),
	("UNKNOWN", str, field(default="<UNK>")),
	# lettered target tokens TARGET_A .. TARGET_Z
	*[
		(f"TARGET_{a}", str, field(default=f"TARGET_{a}"))
		for a in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	],
	# compass-direction target tokens
	("TARGET_NORTH", str, field(default="TARGET_NORTH")),
	("TARGET_SOUTH", str, field(default="TARGET_SOUTH")),
	("TARGET_EAST", str, field(default="TARGET_EAST")),
	("TARGET_WEST", str, field(default="TARGET_WEST")),
	("TARGET_NORTHEAST", str, field(default="TARGET_NORTHEAST")),
	("TARGET_NORTHWEST", str, field(default="TARGET_NORTHWEST")),
	("TARGET_SOUTHEAST", str, field(default="TARGET_SOUTHEAST")),
	("TARGET_SOUTHWEST", str, field(default="TARGET_SOUTHWEST")),
	("TARGET_CENTER", str, field(default="TARGET_CENTER")),
	# path step tokens: absolute (compass) and relative (forward/backward/...)
	("PATH_NORTH", str, field(default="NORTH")),
	("PATH_SOUTH", str, field(default="SOUTH")),
	("PATH_EAST", str, field(default="EAST")),
	("PATH_WEST", str, field(default="WEST")),
	("PATH_FORWARD", str, field(default="FORWARD")),
	("PATH_BACKWARD", str, field(default="BACKWARD")),
	("PATH_LEFT", str, field(default="LEFT")),
	("PATH_RIGHT", str, field(default="RIGHT")),
	("PATH_STAY", str, field(default="STAY")),
	*[
		(f"I_{i:03}", str, field(default=f"+{i}")) for i in range(256)
	],  # General purpose positive int tokens. Used by `StepTokenizers.Distance`.
	*[
		(f"CTT_{i}", str, field(default=f"{i}")) for i in range(128)
	],  # Coord tuple tokens
	*[
		(f"I_N{-i:03}", str, field(default=f"{i}")) for i in range(-256, 0)
	],  # General purpose negative int tokens (names I_N256 .. I_N001)
	("PATH_PRE", str, field(default="STEP")),
	("ADJLIST_PRE", str, field(default="ADJ_GROUP")),
	("ADJLIST_INTRA", str, field(default="&")),
	("ADJLIST_WALL", str, field(default="<XX>")),
	# reserved slots so future tokens can be added without shifting indices
	*[(f"RESERVE_{i}", str, field(default=f"<RESERVE_{i}>")) for i in range(708, 1596)],
	# unique coordinate-tuple tokens "(x,y)" for a 50x50 grid, corner-first order
	*[
		(f"UT_{x:02}_{y:02}", str, field(default=f"({x},{y})"))
		for x, y in corner_first_ndindex(50)
	],
]
"fields for the `MazeTokenizerModular` style combined vocab"
# dynamically built dataclass: special tokens (inherited) + `_VOCAB_FIELDS`
_VOCAB_BASE: type = make_dataclass(
	"_VOCAB_BASE",
	fields=_VOCAB_FIELDS,
	bases=(_SPECIAL_TOKENS_BASE,),  # inherits fields and dict-like access
	frozen=True,
)
"combined vocab class, private"
# TODO: edit __getitem__ to add warning for accessing a RESERVE token
# HACK: mypy doesn't recognize the fields in this dataclass
VOCAB: _VOCAB_BASE = _VOCAB_BASE()  # type: ignore
"public access to universal vocabulary for `MazeTokenizerModular`"
# token order (and therefore index) follows dataclass field order
VOCAB_LIST: list[str] = list(VOCAB.values())
"list of `VOCAB` tokens, in order"
VOCAB_TOKEN_TO_INDEX: dict[str, int] = {token: i for i, token in enumerate(VOCAB_LIST)}
"map of `VOCAB` tokens to their indices"
# CARDINAL_MAP: Maps tuple(coord1 - coord0) : cardinal direction
# keys are (row_delta, col_delta); a row decrease maps to NORTH here --
# NOTE(review): confirm this matches the package's row/col orientation
CARDINAL_MAP: dict[tuple[int, int], str] = {
	(-1, 0): VOCAB.PATH_NORTH,
	(1, 0): VOCAB.PATH_SOUTH,
	(0, -1): VOCAB.PATH_WEST,
	(0, 1): VOCAB.PATH_EAST,
}
"map of cardinal directions to appropriate tokens"