Coverage for maze_dataset/tokenization/maze_tokenizer_legacy.py: 78%

172 statements  

coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"""legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class 

2 

3> [!CAUTION] 

4> `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended 

5> for use, but will remain for compatibility with existing code. 

6 

7""" 

8 

9import warnings 

10from enum import Enum 

11from functools import cached_property 

12from typing import ( 

13 Callable, 

14 Iterable, 

15 Literal, 

16 Mapping, 

17 Sequence, 

18 overload, 

19) 

20 

21import numpy as np 

22from muutils.json_serialize import ( 

23 SerializableDataclass, 

24 serializable_dataclass, 

25 serializable_field, 

26) 

27from muutils.kappa import Kappa 

28from muutils.misc.sequence import WhenMissing 

29 

30# from maze_dataset import SolvedMaze 

31from maze_dataset.constants import ( 

32 SPECIAL_TOKENS, 

33 CoordTup, 

34) 

35from maze_dataset.token_utils import ( 

36 TokenizerPendingDeprecationWarning, 

37 _coord_to_strings_indexed, 

38 _coord_to_strings_UT, 

39 coords_to_strings, 

40 strings_to_coords, 

41) 

42from maze_dataset.tokenization.common import TokenError 

43from maze_dataset.utils import corner_first_ndindex 

44 

45 

46class TokenizationMode(Enum): 

47 """legacy tokenization modes 

48 

49 > [!CAUTION] 

50 > Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. 

51 > Use `MazeTokenizerModular` instead. 

52 

53 # Abbreviations: 

54 - `AOTP`: Ajacency list, Origin, Target, Path 

55 - `UT`: Unique Token (for each coordiate) 

56 - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers) 

57 

58 # Modes: 

59 - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization 

60 example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)` 

61 - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and 5x5 tokenizations scheme are compatible 

62 uses `corner_first_ndindex` function to order the tokens 

63 - `AOTP_CTT_indexed`: each coordinate is a tuple of integers 

64 """ 

65 

66 AOTP_UT_rasterized = "AOTP_UT_rasterized" 

67 AOTP_UT_uniform = "AOTP_UT_uniform" 

68 AOTP_CTT_indexed = "AOTP_CTT_indexed" 

69 

70 def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer": 

71 "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`" 

72 return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size) 
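
    # Editor's sketch (not in the original source): converting a mode into a
    # legacy tokenizer; `max_grid_size=5` is an arbitrary illustrative choice.
    # The expected name follows from the `name` property defined below:
    #
    #   >>> TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer(5).name
    #   'maze_tokenizer-AOTP_UT_uniform-g5'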


_NDINDEX_FUNC_MAP: dict[
    TokenizationMode,
    Callable[[int], Iterable[tuple[int, ...]]],
] = {
    TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)),
    TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2),
}
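
# Editor's sketch (not in the original source): the two UT orderings above differ
# in how coordinate tokens are laid out in the vocabulary. Rasterized ordering is
# row-major (the 3x3 example in the `TokenizationMode` docstring), while
# `corner_first_ndindex` orders coordinates so that the vocabulary of a smaller
# grid is a prefix of the vocabulary of a larger one:
#
#   >>> list(_NDINDEX_FUNC_MAP[TokenizationMode.AOTP_UT_rasterized](2))
#   [(0, 0), (0, 1), (1, 0), (1, 1)]
#   >>> # corner-first: the first 4 entries for n=3 cover the same coordinates as n=2
#   >>> list(_NDINDEX_FUNC_MAP[TokenizationMode.AOTP_UT_uniform](3))[:4]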


def is_UT(tokenization_mode: TokenizationMode) -> bool:
    "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
    return tokenization_mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
    )


def get_tokens_up_to_path_start(
    tokens: list[str],
    include_start_coord: bool = True,
    tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform,
) -> list[str]:
    """get tokens up to the path start token

    # Parameters:
    - `tokens : list[str]`
    - `include_start_coord : bool`
        (defaults to `True`)
    - `tokenization_mode : TokenizationMode`
        (defaults to `TokenizationMode.AOTP_UT_uniform`)

    # Returns:
    - `list[str]` subsequence of `tokens` up to the path start token

    # Raises:
    - `ValueError` : if `tokenization_mode` is invalid
    """
    warnings.warn(
        "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1
    if include_start_coord:
        if is_UT(tokenization_mode):
            return tokens[: path_start_idx + 1]
        elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return tokens[: path_start_idx + 5]
        else:
            err_msg: str = f"Invalid tokenization mode: {tokenization_mode}"
            raise ValueError(err_msg)
    else:
        return tokens[:path_start_idx]
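
# Editor's sketch (not in the original source): in a UT mode the path-start
# coordinate is a single token, so one extra token is kept after the path-start
# marker; in the CTT-indexed mode a coordinate like `( 0 , 0 )` spans five
# tokens, hence the `+ 5` above. The token literals below are assumptions:
#
#   >>> toks = ["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]
#   >>> get_tokens_up_to_path_start(toks)                            # keeps start coord
#   ['<PATH_START>', '(0,0)']
#   >>> get_tokens_up_to_path_start(toks, include_start_coord=False)
#   ['<PATH_START>']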


_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [
    "name",
    "max_grid_size",
    "token_arr",
    "tokenizer_map",
    "vocab_size",
    "padding_token_index",
]


@serializable_dataclass(
    properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE,
    kw_only=True,
)
class MazeTokenizer(SerializableDataclass):
    """LEGACY tokenizer for mazes

    > [!CAUTION]
    > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended
    > for use, but will remain for compatibility with existing code.

    # Parameters:
    - `tokenization_mode: TokenizationMode`
        mode of tokenization. Required.
    - `max_grid_size: int | None`
        maximum grid size. Required for turning text tokens into numerical token indices, but not for converting between coordinates/mazes and text

    # Properties
    - `name: str`
        auto-generated name of the tokenizer from mode and size

    ## Conditional Properties

    - `node_strings_map: Mapping[CoordTup, list[str]] | None`
        map from node to token strings. This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. Returns `None` for `UT` modes.

    The following all return `None` if `max_grid_size` is `None`.
    Prepend `_` to the name to get a version with a guaranteed type, which raises an exception if `max_grid_size` is `None`.

    - `token_arr: list[str]`
        list of tokens, in order of their indices in the vocabulary
    - `tokenizer_map: Mapping[str, int]`
        map from token to index
    - `vocab_size: int`
        size of the vocabulary
    - `padding_token_index: int`
        index of the padding token

    # Methods
    - `coords_to_strings(coords: list[CoordTup]) -> list[str]`
        convert a list of coordinates to a list of tokens. Non-coordinates can raise an error, be skipped, or be included, depending on `when_noncoord`
    - `strings_to_coords(strings: list[str]) -> list[CoordTup]`
        convert a list of tokens to a list of coordinates. Non-coordinates can raise an error, be skipped, or be included, depending on `when_noncoord`

    """

    # parameters
    # ============================================================

    tokenization_mode: TokenizationMode = serializable_field(
        default=TokenizationMode.AOTP_UT_uniform,
        serialization_fn=lambda x: x.value,
        loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]],
    )

    max_grid_size: int | None = serializable_field(default=None)

    # properties
    # ============================================================

    @property
    def name(self) -> str:
        """auto-generated name of the tokenizer from mode and size"""
        max_grid_size_str: str = (
            f"-g{self.max_grid_size}" if self.max_grid_size is not None else ""
        )
        return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}"

    @cached_property
    def _node_strings_map(self) -> Mapping[CoordTup, list[str]]:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return Kappa(_coord_to_strings_UT)
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return Kappa(_coord_to_strings_indexed)
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

    @cached_property
    def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None:
        """map a coordinate to a token"""
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return None
        else:
            return self._node_strings_map

    # conditional properties (on max_grid_size existing)
    # ------------------------------------------------------------

    @cached_property
    def _token_arr(self) -> list[str]:
        """map from index to token"""
        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }"
            raise ValueError(err_msg)

        output: list[str] = list(SPECIAL_TOKENS.values())

        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            output.extend(
                [
                    self._node_strings_map[coord][0]
                    for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode](
                        self.max_grid_size,
                    )
                ],
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models
            output.extend(
                [
                    "(",
                    ",",
                    ")",  # new special chars
                    *map(str, range(self.max_grid_size)),  # numbers
                ],
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

        return output

    @cached_property
    def token_arr(self) -> list[str] | None:
        "get the token array if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._token_arr
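
    # Editor's sketch (not in the original source): for a UT mode with
    # `max_grid_size=2`, the vocabulary built above is the special tokens
    # followed by one unique token per coordinate, roughly:
    #
    #   >>> tok = MazeTokenizer(
    #   ...     tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=2
    #   ... )
    #   >>> tok.token_arr[-4:]  # the 4 coordinate tokens for a 2x2 grid
    #   ['(0,0)', '(0,1)', '(1,0)', '(1,1)']
    #
    # The exact string form of the coordinate tokens comes from
    # `_coord_to_strings_UT`; the `'(i,j)'` spelling shown here is an assumption.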

    @cached_property
    def _tokenizer_map(self) -> dict[str, int]:
        """map from token to index"""
        return {token: i for i, token in enumerate(self._token_arr)}

    @cached_property
    def tokenizer_map(self) -> dict[str, int] | None:
        "get the tokenizer map if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._tokenizer_map

    @property
    def _vocab_size(self) -> int:
        return len(self._token_arr)

    @property
    def vocab_size(self) -> int | None:
        "get the size of the vocabulary if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._vocab_size

    @property
    def _n_tokens(self) -> int:
        # TODO: deprecate
        return self._vocab_size

    @property
    def n_tokens(self) -> int | None:
        "get the number of tokens if the max_grid_size is specified"
        if self.max_grid_size is None:
            return None
        return self._n_tokens

    @cached_property
    def _padding_token_index(self) -> int:
        return self.tokenizer_map[SPECIAL_TOKENS.PADDING]

    @cached_property
    def padding_token_index(self) -> int | None:
        "get the index of the padding token if it exists"
        if self.max_grid_size is None:
            return None
        return self._padding_token_index

    # conversion functions
    # ============================================================

    @overload
    def coords_to_strings(
        self,
        coords: list[str | CoordTup],
        when_noncoord: Literal["include", "skip"] = "skip",
    ) -> list[str]: ...
    @overload
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: Literal["error"] = "error",
    ) -> list[str]: ...
    def coords_to_strings(
        self,
        coords: list[CoordTup],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str]:
        """map a list of coordinate tuples (and maybe other tokens) to strings

        wraps `maze_dataset.token_utils.coords_to_strings` with either
        `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode
        """
        if self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_UT,
                when_noncoord=when_noncoord,
            )
        elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed:
            return coords_to_strings(
                coords=coords,
                coord_to_strings_func=_coord_to_strings_indexed,
                when_noncoord=when_noncoord,
            )
        else:
            err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}"
            raise ValueError(err_msg)

    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["skip"] = "skip",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["error"] = "error",
    ) -> list[CoordTup]: ...
    @overload
    def strings_to_coords(
        cls,  # noqa: N805
        text: str | list[str],
        when_noncoord: Literal["include"] = "include",
    ) -> list[str | CoordTup]: ...
    @classmethod
    def strings_to_coords(
        cls,
        text: str | list[str],
        when_noncoord: WhenMissing = "skip",
    ) -> list[str | CoordTup]:
        "wrapper for `maze_dataset.token_utils.strings_to_coords`"
        return strings_to_coords(text=text, when_noncoord=when_noncoord)
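
    # Editor's sketch (not in the original source): converting between coordinate
    # tuples and token strings. Neither direction requires `max_grid_size`,
    # matching the class docstring; the '(1,2)' and '<PATH_START>' literals are
    # assumptions about the token spellings:
    #
    #   >>> tok = MazeTokenizer(tokenization_mode=TokenizationMode.AOTP_UT_uniform)
    #   >>> tok.coords_to_strings([(1, 2)])
    #   ['(1,2)']
    #   >>> tok.strings_to_coords(['(1,2)', '<PATH_START>'], when_noncoord="skip")
    #   [(1, 2)]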

    def encode(self, text: str | list[str]) -> list[int]:
        """encode a string or list of strings into a list of token indices"""
        try:
            if isinstance(text, str):
                text = text.split()
            return [self.tokenizer_map[token] for token in text]
        except KeyError as e:
            err_msg: str = (
                f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}"
            )
            raise TokenError(err_msg) from e

    def decode(
        self,
        tokens: Sequence[int],
        joined_tokens: bool = False,
    ) -> list[str] | str:
        """decode a list of token indices into a string or list of strings"""
        try:
            output: list[str] = [self.token_arr[token] for token in tokens]
        except IndexError as e:
            err_msg: str = (
                f"Token index '{e}' not found in vocabulary of length {self.vocab_size}"
            )
            raise TokenError(err_msg) from e
        if joined_tokens:
            return " ".join(output)
        else:
            return output
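
    # Editor's sketch (not in the original source): encode/decode round trip.
    # `max_grid_size` must be set, since encoding needs the numerical vocabulary;
    # the token literals are assumptions:
    #
    #   >>> tok = MazeTokenizer(
    #   ...     tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=3
    #   ... )
    #   >>> ids = tok.encode("<PATH_START> (0,0) <PATH_END>")
    #   >>> tok.decode(ids, joined_tokens=True)
    #   '<PATH_START> (0,0) <PATH_END>'
    #
    # Unknown tokens or out-of-range indices raise `TokenError`.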

    # UT-only coordinate stuff
    # ============================================================

    @cached_property
    def coordinate_tokens_coords(self) -> dict[CoordTup, int]:
        "map of coordinate tuples to their token ids, only valid for UT"
        # print(f"{self.tokenization_mode = }")
        if not self.is_UT():
            err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }"
            raise ValueError(err_msg)

        if self.max_grid_size is None:
            err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }"
            raise ValueError(err_msg)

        raw_converted: list[CoordTup | str] = self.strings_to_coords(
            self.token_arr,
            when_noncoord="include",
        )

        # filter out non-coordinates
        return {
            coord: i
            for i, coord in enumerate(raw_converted)
            if not isinstance(coord, str)
        }

    @cached_property
    def coordinate_tokens_ids(self) -> dict[str, int]:
        "map of coordinate tokens to their token ids, only valid for UT"
        # checks performed in call
        output: dict[str, int] = dict()

        for coord, index in self.coordinate_tokens_coords.items():
            _for_key: list[str] = self.coords_to_strings([coord])
            assert len(_for_key) == 1
            output[_for_key[0]] = index

        return output
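
    # Editor's sketch (not in the original source): the two UT-only lookup tables
    # above differ only in their keys (coordinate tuples vs. token strings); the
    # '(0,0)' key spelling assumes the usual UT coordinate-token format:
    #
    #   >>> tok = MazeTokenizer(
    #   ...     tokenization_mode=TokenizationMode.AOTP_UT_uniform, max_grid_size=2
    #   ... )
    #   >>> tok.coordinate_tokens_coords[(0, 0)] == tok.coordinate_tokens_ids["(0,0)"]
    #   True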

    # other
    # ============================================================

    def summary(self) -> dict:
        """returns a summary of the tokenizer"""
        return {
            "tokenization_mode": self.tokenization_mode.value,
            "max_grid_size": self.max_grid_size,
            "vocab_size": self.vocab_size,
        }

    def is_AOTP(self) -> bool:
        """returns true if a tokenization mode is Adjacency list, Origin, Target, Path"""
        return self.tokenization_mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
            TokenizationMode.AOTP_CTT_indexed,
        )

    def is_UT(self) -> bool:
        "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)"
        return is_UT(self.tokenization_mode)

    def clear_cache(self) -> None:
        """clears all cached properties"""
        # delete the properties only if they exist
        for name, prop in self.__class__.__dict__.items():
            if isinstance(prop, cached_property):
                # if the property exists, delete it
                try:  # noqa: SIM105
                    delattr(self, name)
                except AttributeError:
                    pass
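

# Editor's sketch (not part of the original module): a minimal end-to-end usage
# example for the legacy tokenizer. The maze token strings passed to `encode`
# are illustrative assumptions; the API calls themselves are taken from this file.
if __name__ == "__main__":
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=TokenizationMode.AOTP_UT_uniform,
        max_grid_size=5,
    )
    print(tokenizer.name)        # auto-generated name, e.g. ends in '-AOTP_UT_uniform-g5'
    print(tokenizer.vocab_size)  # special tokens plus one token per coordinate
    # encode/decode round trip on a tiny (assumed) token sequence
    ids: list[int] = tokenizer.encode(["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"])
    print(tokenizer.decode(ids, joined_tokens=True))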