Coverage for maze_dataset/constants.py: 16%

85 statements  

coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1"""constants and type hints used accross the package""" 

import warnings
from dataclasses import dataclass, field, make_dataclass
from typing import Iterator

import numpy as np
from jaxtyping import Bool, Int8

from maze_dataset.utils import corner_first_ndindex

# various type hints for coordinates, connections, etc.

Coord = Int8[np.ndarray, "row_col=2"]
"single coordinate as array"

CoordTup = tuple[int, int]
"single coordinate as tuple"

CoordArray = Int8[np.ndarray, "coord row_col=2"]
"array of coordinates"

CoordList = list[CoordTup]
"list of tuple coordinates"

Connection = Int8[np.ndarray, "coord=2 row_col=2"]
"single connection (pair of coords) as array"

ConnectionList = Bool[np.ndarray, "lattice_dim=2 row col"]
"internal representation used in `LatticeMaze`"

ConnectionArray = Int8[np.ndarray, "edges leading_trailing_coord=2 row_col=2"]
"n_edges * 2 * 2 array of connections, like an adjacency list"
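
# Illustrative sketch (not part of the original module): how these aliases are
# typically instantiated. The arrays below are hand-built examples only.
#
#   >>> import numpy as np
#   >>> coord: Coord = np.array([1, 2], dtype=np.int8)
#   >>> coord_tup: CoordTup = (1, 2)
#   >>> # a single connection between (0, 0) and (0, 1), stored as a 2x2 array
#   >>> connection: Connection = np.array([[0, 0], [0, 1]], dtype=np.int8)
#   >>> connection.shape
#   (2, 2)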


class SpecialTokensError(Exception):
    "(unused!) errors related to special tokens"

    pass


_SPECIAL_TOKENS_ABBREVIATIONS: dict[str, str] = {
    "<ADJLIST_START>": "<A_S>",
    "<ADJLIST_END>": "<A_E>",
    "<TARGET_START>": "<T_S>",
    "<TARGET_END>": "<T_E>",
    "<ORIGIN_START>": "<O_S>",
    "<ORIGIN_END>": "<O_E>",
    "<PATH_START>": "<P_S>",
    "<PATH_END>": "<P_E>",
    "<-->": "<-->",
    ";": ";",
    "<PADDING>": "<PAD>",
}
"map abbreviations for (some) special tokens"


@dataclass(frozen=True)
class _SPECIAL_TOKENS_BASE:  # noqa: N801
    "special dataclass used for handling special tokens"

    ADJLIST_START: str = "<ADJLIST_START>"
    ADJLIST_END: str = "<ADJLIST_END>"
    TARGET_START: str = "<TARGET_START>"
    TARGET_END: str = "<TARGET_END>"
    ORIGIN_START: str = "<ORIGIN_START>"
    ORIGIN_END: str = "<ORIGIN_END>"
    PATH_START: str = "<PATH_START>"
    PATH_END: str = "<PATH_END>"
    CONNECTOR: str = "<-->"
    ADJACENCY_ENDLINE: str = ";"
    PADDING: str = "<PADDING>"

    def __getitem__(self, key: str) -> str:
        "get a special token by (uppercase) name, with deprecation warnings for old key formats"
        if not isinstance(key, str):
            err_msg: str = f"key must be str, not {type(key)}"
            raise TypeError(err_msg)

        # error checking for old lowercase format
        key_upper: str = key.upper()
        if key != key_upper:
            warnings.warn(
                f"Accessing special token '{key}' without uppercase. This is deprecated and will be removed in the future.",
                DeprecationWarning,
            )
            key = key_upper

        # `ADJLIST` used to be `adj_list`, changed to match actual token content
        if key_upper not in self.keys():
            key_upper_modified: str = key_upper.replace("ADJ_LIST", "ADJLIST")
            if key_upper_modified in self.keys():
                warnings.warn(
                    f"Accessing '{key}' in old format, should use {key_upper_modified}. This is deprecated and will be removed in the future.",
                    DeprecationWarning,
                )
                return getattr(self, key_upper_modified)
            else:
                err_msg: str = f"invalid special token '{key}'"
                raise KeyError(err_msg)

        # normal return
        return getattr(self, key_upper)

    def get_abbrev(self, key: str) -> str:
        "get the abbreviated form of the special token looked up by `key`"
        return _SPECIAL_TOKENS_ABBREVIATIONS[self[key]]

    # dict-like access, backed by the instance `__dict__`
    def __iter__(self) -> Iterator[str]:
        return iter(self.__dict__.keys())

    def __len__(self) -> int:
        return len(self.__dict__.keys())

    def __contains__(self, key: str) -> bool:
        return key in self.__dict__

    def values(self) -> Iterator[str]:
        return self.__dict__.values()

    def items(self) -> Iterator[tuple[str, str]]:
        return self.__dict__.items()

    def keys(self) -> Iterator[str]:
        return self.__dict__.keys()


SPECIAL_TOKENS: _SPECIAL_TOKENS_BASE = _SPECIAL_TOKENS_BASE()
"special tokens"
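
# Illustrative usage (not part of the original module), based on the class above:
#
#   >>> SPECIAL_TOKENS.PATH_START
#   '<PATH_START>'
#   >>> SPECIAL_TOKENS["PATH_START"]
#   '<PATH_START>'
#   >>> SPECIAL_TOKENS.get_abbrev("PATH_START")
#   '<P_S>'
#   >>> len(SPECIAL_TOKENS)
#   11
#   >>> "ADJLIST_END" in SPECIAL_TOKENS
#   True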


DIRECTIONS_MAP: Int8[np.ndarray, "direction axes"] = np.array(
    [
        [0, 1],  # down
        [0, -1],  # up
        [1, 1],  # right
        [1, -1],  # left
    ],
)
"down, up, right, left directions for when inside a `ConnectionList`"


NEIGHBORS_MASK: Int8[np.ndarray, "coord point"] = np.array(
    [
        [0, 1],  # down
        [0, -1],  # up
        [1, 0],  # right
        [-1, 0],  # left
    ],
)
"down, up, right, left as vectors"
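
# Illustrative usage (not part of the original module): adding `NEIGHBORS_MASK`
# to a coordinate yields that coordinate's four lattice neighbors.
#
#   >>> import numpy as np
#   >>> np.array([2, 3]) + NEIGHBORS_MASK
#   array([[2, 4],
#          [2, 2],
#          [3, 3],
#          [1, 3]])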


# last element of the tuple is actually a Field[str], but mypy complains
_VOCAB_FIELDS: list[tuple[str, type[str], str]] = [
    # *[(k, str, field(default=v)) for k, v in SPECIAL_TOKENS.items()],
    ("COORD_PRE", str, field(default="(")),
    ("COORD_INTRA", str, field(default=",")),
    ("COORD_POST", str, field(default=")")),
    ("TARGET_INTRA", str, field(default="=")),
    ("TARGET_POST", str, field(default="||")),
    ("PATH_INTRA", str, field(default=":")),
    ("PATH_POST", str, field(default="THEN")),
    ("NEGATIVE", str, field(default="-")),
    ("UNKNOWN", str, field(default="<UNK>")),
    *[
        (f"TARGET_{a}", str, field(default=f"TARGET_{a}"))
        for a in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    ],
    ("TARGET_NORTH", str, field(default="TARGET_NORTH")),
    ("TARGET_SOUTH", str, field(default="TARGET_SOUTH")),
    ("TARGET_EAST", str, field(default="TARGET_EAST")),
    ("TARGET_WEST", str, field(default="TARGET_WEST")),
    ("TARGET_NORTHEAST", str, field(default="TARGET_NORTHEAST")),
    ("TARGET_NORTHWEST", str, field(default="TARGET_NORTHWEST")),
    ("TARGET_SOUTHEAST", str, field(default="TARGET_SOUTHEAST")),
    ("TARGET_SOUTHWEST", str, field(default="TARGET_SOUTHWEST")),
    ("TARGET_CENTER", str, field(default="TARGET_CENTER")),
    ("PATH_NORTH", str, field(default="NORTH")),
    ("PATH_SOUTH", str, field(default="SOUTH")),
    ("PATH_EAST", str, field(default="EAST")),
    ("PATH_WEST", str, field(default="WEST")),
    ("PATH_FORWARD", str, field(default="FORWARD")),
    ("PATH_BACKWARD", str, field(default="BACKWARD")),
    ("PATH_LEFT", str, field(default="LEFT")),
    ("PATH_RIGHT", str, field(default="RIGHT")),
    ("PATH_STAY", str, field(default="STAY")),
    *[
        (f"I_{i:03}", str, field(default=f"+{i}")) for i in range(256)
    ],  # General purpose positive int tokens. Used by `StepTokenizers.Distance`.
    *[
        (f"CTT_{i}", str, field(default=f"{i}")) for i in range(128)
    ],  # Coord tuple tokens
    *[
        (f"I_N{-i:03}", str, field(default=f"{i}")) for i in range(-256, 0)
    ],  # General purpose negative int tokens
    ("PATH_PRE", str, field(default="STEP")),
    ("ADJLIST_PRE", str, field(default="ADJ_GROUP")),
    ("ADJLIST_INTRA", str, field(default="&")),
    ("ADJLIST_WALL", str, field(default="<XX>")),
    *[(f"RESERVE_{i}", str, field(default=f"<RESERVE_{i}>")) for i in range(708, 1596)],
    *[
        (f"UT_{x:02}_{y:02}", str, field(default=f"({x},{y})"))
        for x, y in corner_first_ndindex(50)
    ],
]
"fields for the `MazeTokenizerModular` style combined vocab"
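
# The `make_dataclass` call below consumes these (name, type, Field) tuples.
# Minimal stand-in sketch of the same pattern (the `_Demo`/`FOO` names are
# hypothetical, for illustration only):
#
#   >>> from dataclasses import field, make_dataclass
#   >>> _Demo = make_dataclass(
#   ...     "_Demo",
#   ...     fields=[("FOO", str, field(default="<FOO>"))],
#   ...     frozen=True,
#   ... )
#   >>> _Demo().FOO
#   '<FOO>'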

_VOCAB_BASE: type = make_dataclass(
    "_VOCAB_BASE",
    fields=_VOCAB_FIELDS,
    bases=(_SPECIAL_TOKENS_BASE,),
    frozen=True,
)
"combined vocab class, private"
# TODO: edit __getitem__ to add warning for accessing a RESERVE token

# HACK: mypy doesn't recognize the fields in this dataclass
VOCAB: _VOCAB_BASE = _VOCAB_BASE()  # type: ignore
"public access to universal vocabulary for `MazeTokenizerModular`"
VOCAB_LIST: list[str] = list(VOCAB.values())
"list of `VOCAB` tokens, in order"
VOCAB_TOKEN_TO_INDEX: dict[str, int] = {token: i for i, token in enumerate(VOCAB_LIST)}
"map of `VOCAB` tokens to their indices"
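
# Illustrative round-trip (not part of the original module): `VOCAB_LIST` and
# `VOCAB_TOKEN_TO_INDEX` are inverse lookups over the same token ordering.
#
#   >>> tok = VOCAB.PATH_NORTH  # "NORTH"
#   >>> VOCAB_LIST[VOCAB_TOKEN_TO_INDEX[tok]] == tok
#   True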

# CARDINAL_MAP: maps tuple(coord1 - coord0) to a cardinal direction token
CARDINAL_MAP: dict[tuple[int, int], str] = {
    (-1, 0): VOCAB.PATH_NORTH,
    (1, 0): VOCAB.PATH_SOUTH,
    (0, -1): VOCAB.PATH_WEST,
    (0, 1): VOCAB.PATH_EAST,
}
"map of coordinate deltas (coord1 - coord0) to the corresponding cardinal direction tokens"
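
# Illustrative usage (not part of the original module): stepping from (2, 3) to
# (1, 3) decreases the row index, which `CARDINAL_MAP` labels as north.
#
#   >>> c0, c1 = (2, 3), (1, 3)
#   >>> CARDINAL_MAP[(c1[0] - c0[0], c1[1] - c0[1])]
#   'NORTH'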