Coverage for maze_dataset/tokenization/modular/maze_tokenizer_modular.py: 82%

128 statements  

coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"implements the actual `MazeTokenizerModular` class" 

2 

3import base64 

4import warnings 

5from functools import cached_property 

6from typing import ( 

7 Iterable, 

8 Literal, 

9 Sequence, 

10 overload, 

11) 

12 

13from muutils.json_serialize import ( 

14 SerializableDataclass, 

15 serializable_dataclass, 

16 serializable_field, 

17) 

18from muutils.misc import flatten 

19from muutils.misc.sequence import WhenMissing 

20 

21# from maze_dataset import SolvedMaze 

22from maze_dataset.constants import ( 

23 VOCAB, 

24 VOCAB_LIST, 

25 VOCAB_TOKEN_TO_INDEX, 

26 Coord, 

27 CoordTup, 

28) 

29from maze_dataset.maze.lattice_maze import LatticeMaze 

30from maze_dataset.token_utils import ( 

31 TokenizerPendingDeprecationWarning, 

32 strings_to_coords, 

33) 

34from maze_dataset.tokenization.common import TokenError 

35from maze_dataset.tokenization.maze_tokenizer_legacy import ( 

36 MazeTokenizer, 

37 TokenizationMode, 

38) 

39from maze_dataset.tokenization.modular.element_base import ( 

40 _load_tokenizer_element, 

41 _TokenizerElement, 

42) 

43from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers 

44from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst 

45from maze_dataset.tokenization.modular.hashing import ( 

46 _hash_tokenizer_name, 

47) 

48 

49 

50@serializable_dataclass( 

51 frozen=True, 

52 kw_only=True, 

53 properties_to_serialize=["tokenizer_element_tree_concrete", "name"], 

54) 

55class MazeTokenizerModular(SerializableDataclass): 

56 """Tokenizer for mazes 

57 

58 # Parameters 

59 - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt. 

60 

61 # Development 

    - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_uniform`.
    - Furthermore, the mapping reflected in `from_legacy` must also be maintained.
    - Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior.
    """

    prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field(
        default=PromptSequencers.AOTP(),
        loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers),
    )

    def hash_int(self) -> int:
        "return integer hash using blake2b"
        return _hash_tokenizer_name(self.name)

    def __hash__(self) -> int:
        "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
        return self.hash_int()

    def hash_b64(self, n_bytes: int = 8) -> str:
        """filename-safe base64 encoding of the hash"""
        # Use modulus to ensure the integer fits within n_bytes * 8 bits
        hash_mod: int = self.hash_int() % (1 << (n_bytes * 8))

        encoded = base64.b64encode(
            hash_mod.to_bytes(n_bytes, byteorder="big"),
            altchars=b"-_",
        ).decode()

        # Remove any padding equals signs
        return encoded.rstrip("=")

    # Information Querying Methods

    @cached_property
    def tokenizer_elements(self) -> list[_TokenizerElement]:
        "returns a list of all the elements of this tokenizer"
        return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()]

    def tokenizer_element_tree(self, abstract: bool = False) -> str:
        """Returns a string representation of the tree of tokenizer elements contained in `self`.

        # Parameters
        - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
        """
        return "\n".join(
            [
                type(self).__name__,
                self.prompt_sequencer.tokenizer_element_tree(
                    abstract=abstract,
                    depth=1,
                ),
            ],
        )

    @property
    def tokenizer_element_tree_concrete(self) -> str:
        """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`."""
        return self.tokenizer_element_tree()

    def tokenizer_element_dict(self) -> dict:
        """Nested dictionary of the internal `TokenizerElement`s."""
        return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()}

    @property
    def name(self) -> str:
        """Serializes MazeTokenizer into a key for encoding in zanj"""
        return "-".join([type(self).__name__, self.prompt_sequencer.name])  # noqa: FLY002

    def summary(self) -> dict[str, str]:
        """Single-level dictionary of the internal `TokenizerElement`s."""
        return {
            # "prompt_sequencer": self.prompt_sequencer.name,
            **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements},
        }

    @staticmethod
    def _type_check(obj: any) -> None:
        """Helper method for `has_element`"""
        if not (
            isinstance(obj, _TokenizerElement)
            or (isinstance(obj, type) and issubclass(obj, _TokenizerElement))
        ):
            err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass."
            raise TypeError(err_msg)

    def _has_element_singular(
        self,
        el: type[_TokenizerElement] | _TokenizerElement,
    ) -> bool:
        """Helper method for `has_element`"""
        self._type_check(el)
        if isinstance(el, type):
            return any(isinstance(e, el) for e in self.tokenizer_elements)
        else:
            return el in self.tokenizer_elements

    def has_element(
        self,
        *elements: Sequence[type[_TokenizerElement] | _TokenizerElement],
    ) -> bool:
        """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`.

        Querying with a partial subset of `_TokenizerElement` fields is not currently supported.
        To do such a query, assemble multiple calls to `has_elements`.

        # Parameters
        - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes.
        If an instance is provided, then comparison is done via instance equality.
        If a class is provided, then comparison is done via `isinstance`. I.e., any instance of that class is accepted.
        """
        if len(elements) == 1 and isinstance(elements[0], Iterable):
            elements = elements[0]
        return all(self._has_element_singular(e) for e in elements)

    def is_valid(self, do_except: bool = False) -> bool:
        """Returns `True` if `self` is a valid tokenizer.

        Evaluates the validity of all of `self.tokenizer_elements` according to each one's method.
        """
        return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements)

    def is_legacy_equivalent(self) -> bool:
        """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`."""
        return any(
            self == MazeTokenizerModular.from_legacy(tok_mode)
            for tok_mode in TokenizationMode
        )

    def is_tested_tokenizer(self, do_except: bool = False) -> bool:
        """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers.

        uses an fst on the `name` attributes of all the tokenizers

        if `do_except` is `True`, raises an `AssertionError` if the tokenizer is not tested.
196 """ 

197 is_valid: bool = self.is_valid(do_except=do_except) 

198 in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except) 

199 

200 if do_except: 

201 assert is_valid, "self.is_valid returns False" 

202 return True 

203 else: 

204 return in_tested_fst and is_valid 

205 

206 def is_AOTP(self) -> bool: 

207 "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path" 

208 return self.has_element(PromptSequencers.AOTP) 

209 

210 def is_UT(self) -> bool: 

211 "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)" 

212 return self.has_element(CoordTokenizers.UT) 

213 

214 # Alternate Constructors 

215 # ====================== 

216 

217 @classmethod 

218 def from_legacy( 

219 cls, 

220 legacy_maze_tokenizer: MazeTokenizer | TokenizationMode, 

221 ) -> "MazeTokenizerModular": 

222 """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" 

223 if isinstance(legacy_maze_tokenizer, MazeTokenizer): 

224 legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode 

225 return { 

226 TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), 

227 TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), 

228 TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( 

229 prompt_sequencer=PromptSequencers.AOTP( 

230 coord_tokenizer=CoordTokenizers.CTT(), 

231 ), 

232 ), 

233 }[legacy_maze_tokenizer] 

234 

235 # Simple properties 

236 # ================= 

237 @classmethod 

238 def from_tokens( 

239 cls, 

240 tokens: str | list[str], 

241 ) -> "MazeTokenizerModular": 

242 """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.""" 

243 raise NotImplementedError( 

244 "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported", 

245 ) 

246 

247 @property 

248 def token_arr(self) -> list[str] | None: 

249 """map from index to token""" 

250 return VOCAB_LIST 

251 

252 @property 

253 def tokenizer_map(self) -> dict[str, int]: 

254 """map from token to index""" 

255 return VOCAB_TOKEN_TO_INDEX 

256 

257 @property 

258 def vocab_size(self) -> int: 

259 """Number of tokens in the static vocab""" 

260 return len(VOCAB_LIST) 

261 

262 @property 

263 def n_tokens(self) -> int: 

264 "get the number of tokens in the vocabulary (deprecated)" 

265 err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 

266 raise NameError(err_msg) 

267 

268 @property 

269 def padding_token_index(self) -> int: 

270 "get the index of the padding token" 

271 return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING] 

272 

273 # conversion functions 

274 # ============================================================ 

275 

276 def to_tokens( 

277 self, 

278 maze: LatticeMaze, 

279 ) -> list[str]: 

280 """Converts maze into a list of tokens.""" 

281 return self.prompt_sequencer.to_tokens(maze) 

282 

283 def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: 

284 "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords" 

285 return list( 

286 flatten( 

287 [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords], 

288 ), 

289 ) 

290 

    # TODO: unclear why we need to use `noqa: N805` here since it's a classmethod
    # maybe we need to hit every overload with `@classmethod`?
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        "wrapper for maze_dataset.token_utils.strings_to_coords"
        warnings.warn(
            "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.",
            TokenizerPendingDeprecationWarning,
        )
        return strings_to_coords(text=text, when_noncoord=when_noncoord)

    @staticmethod
    def encode(text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of tokens"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [VOCAB_TOKEN_TO_INDEX[token] for token in text]
        except KeyError as e:
            err_msg: str = f"Token {e} not found in `VOCAB`."
            raise TokenError(err_msg) from e

    @staticmethod
    def decode(
        token_ids: Sequence[int],
        joined_tokens: bool = False,
    ) -> list[str] | str:
        """decode a list of tokens into a string or list of strings"""
        try:
            output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids]
        except IndexError as e:
            err_msg: str = f"Token index '{e}' not found in `VOCAB`."
            raise TokenError(err_msg) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output
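
The sketch below (not part of the covered module) is a minimal usage example of the class shown above: the default constructor, the `from_legacy` mapping, element queries, and the `encode`/`decode` round trip. It only uses names that appear in this listing (`MazeTokenizerModular`, `TokenizationMode`, `CoordTokenizers`, `PromptSequencers`, `VOCAB`); the printed values and maze construction are left as comments because they depend on code outside this file.

```python
# Illustrative sketch only; consult the maze_dataset documentation for maze
# construction and dataset generation, which are outside this module.
from maze_dataset.constants import VOCAB
from maze_dataset.tokenization.maze_tokenizer_legacy import TokenizationMode
from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers
from maze_dataset.tokenization.modular.maze_tokenizer_modular import MazeTokenizerModular

# default constructor: equivalent to the legacy AOTP_UT_uniform tokenizer
tok_default = MazeTokenizerModular()
assert tok_default.is_legacy_equivalent()
assert tok_default.is_AOTP() and tok_default.is_UT()

# map a legacy TokenizationMode to its modular equivalent
tok_ctt = MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed)
assert tok_ctt.has_element(PromptSequencers.AOTP, CoordTokenizers.CTT)

# the name and hash are derived from the tokenizer element tree
print(tok_ctt.name)        # string key used for serialization
print(tok_ctt.hash_b64())  # short filename-safe digest of the name hash

# encode/decode round trip over the static vocabulary
ids = MazeTokenizerModular.encode([VOCAB.PADDING])
assert MazeTokenizerModular.decode(ids) == [VOCAB.PADDING]

# tokenizing an actual maze requires a LatticeMaze instance, e.g.:
#     tokens = tok_default.to_tokens(maze)
```

Everything in the sketch stays within the API surface of this file; the `to_tokens` call is commented out because building a `LatticeMaze` is covered elsewhere in the package.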