# maze_dataset/tokenization/modular/elements.py

1"""implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`""" 

2 

3import abc 

4import random 

5from typing import ( 

6 Callable, 

7 Literal, 

8 Sequence, 

9 TypedDict, 

10) 

11 

12import numpy as np 

13from jaxtyping import Bool, Int 

14from muutils.json_serialize import ( 

15 serializable_dataclass, 

16 serializable_field, 

17) 

18from muutils.misc import empty_sequence_if_attr_false, flatten 

19 

20# from maze_dataset import SolvedMaze 

21from maze_dataset.constants import ( 

22 VOCAB, 

23 ConnectionArray, 

24 ConnectionList, 

25 Coord, 

26 CoordTup, 

27) 

28from maze_dataset.generation import numpy_rng 

29from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze 

30from maze_dataset.token_utils import ( 

31 connection_list_to_adj_list, 

32 get_cardinal_direction, 

33 get_relative_direction, 

34 is_connection, 

35 tokens_between, 

36) 

37from maze_dataset.tokenization.modular.element_base import ( 

38 __TokenizerElementNamespace, 

39 _load_tokenizer_element, 

40 _TokenizerElement, 

41 _unsupported_is_invalid, 

42 mark_as_unsupported, 

43) 

44from maze_dataset.utils import lattice_connection_array 

45 

46 

class CoordTokenizers(__TokenizerElementNamespace):
    """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "coord_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _CoordTokenizer(_TokenizerElement, abc.ABC):
        """Superclass for classes which tokenize singular coords in a maze."""

        @abc.abstractmethod
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return CoordTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class UT(_CoordTokenizer):
        """Unique token coordinate tokenizer."""

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])]
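
    # Illustrative sketch (not from the source): `UT` emits exactly one token
    # per coord, so the coord (1, 2) becomes the single token "(1,2)":
    #   CoordTokenizers.UT().to_tokens((1, 2))  # -> ["(1,2)"]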

    @serializable_dataclass(frozen=True, kw_only=True)
    class CTT(_CoordTokenizer):
        """Coordinate tuple tokenizer

        # Parameters
        - `pre`: Whether all coords include an integral preceding delimiter token
        - `intra`: Whether all coords include a delimiter token between coordinates
        - `post`: Whether all coords include an integral following delimiter token
        """

        pre: bool = serializable_field(default=True)
        intra: bool = serializable_field(default=True)
        post: bool = serializable_field(default=True)

        # inherit docstring
        def to_tokens(self, coord: Coord | CoordTup) -> list[str]:  # noqa: D102
            return [
                *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"),
                str(coord[0]),
                *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"),
                str(coord[1]),
                *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"),
            ]
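
    # Illustrative sketch (not from the source): with the three delimiter flags
    # at their defaults, the coord (1, 2) tokenizes to
    #   [VOCAB.COORD_PRE, "1", VOCAB.COORD_INTRA, "2", VOCAB.COORD_POST]
    # while `CTT(pre=False, intra=False, post=False)` would yield just ["1", "2"].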



class EdgeGroupings(__TokenizerElementNamespace):
    """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_grouping"

    class _GroupingTokenParams(TypedDict):
        """A uniform private hyperparameter interface used by `AdjListTokenizer`."""

        connection_token_ordinal: Literal[0, 1, 2]
        intra: bool
        grouped: bool

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeGrouping(_TokenizerElement, abc.ABC):
        """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeGroupings.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            """Divides a ConnectionArray into groups of edges.

            Shuffles/sequences within each group if applicable.
            """
            pass

        @abc.abstractmethod
        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            """Returns the tokenization hyperparameters necessary for an `AdjListTokenizer` to tokenize.

            These hyperparameters are not used by `_EdgeGrouping` internally.
            They are located in `_EdgeGrouping` rather than in `AdjListTokenizer`
            since the hyperparameter space is a function of the `_EdgeGrouping` subclass.
            This function resolves the `_EdgeGrouping` hyperparameter space, which is non-uniform across subclasses,
            into a uniform private interface used by `AdjListTokenizer`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class Ungrouped(_EdgeGrouping):
        """No grouping occurs, each edge is tokenized individually.

        # Parameters
        - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears.
            Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization.
        """

        connection_token_ordinal: Literal[0, 1, 2] = serializable_field(
            default=1,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=False,
                grouped=False,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # each edge becomes its own singleton group: (n, 2, 2) -> (n, 1, 2, 2)
            return np.expand_dims(edges, 1)

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ByLeadingCoord(_EdgeGrouping):
        """All edges with the same leading coord are grouped together.

        # Parameters
        - `intra`: Whether all edge groupings include a delimiter token between individual edge representations.
            Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `VOCAB.ADJLIST_WALL`).
        - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order.
            If false, the fixed order is lexicographical by (row, col).
            In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord.
        - `connection_token_ordinal`: At which index in the token sequence representing a single edge the connector (or wall) token appears.
            Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization.
        """

        intra: bool = serializable_field(default=True)
        shuffle_group: bool = serializable_field(default=True)
        connection_token_ordinal: Literal[0, 1] = serializable_field(
            default=0,
            assert_type=False,
        )

        def _token_params(self) -> "EdgeGroupings._GroupingTokenParams":
            return EdgeGroupings._GroupingTokenParams(
                connection_token_ordinal=self.connection_token_ordinal,
                intra=self.intra,
                grouped=True,
            )

        def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]:
            # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function
            index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort(
                (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]),
            )
            sorted_edges: ConnectionArray = edges[index_array, ...]
            groups: list[ConnectionArray] = np.split(
                sorted_edges,
                np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:],
            )
            if self.shuffle_group:
                for group in groups:
                    numpy_rng.shuffle(group, axis=0)
            return groups
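
    # A minimal sketch (assumed shapes, not from the source) of the group-by
    # above: `np.unique(..., return_index=True)` gives the first index of each
    # distinct leading coord in the lexsorted array, and `np.split` cuts there.
    #   edges = np.array([[[0, 0], [0, 1]], [[0, 0], [1, 0]], [[1, 1], [1, 0]]])
    #   -> groups: [the two edges leading from (0, 0), the one edge leading from (1, 1)]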



class EdgePermuters(__TokenizerElementNamespace):
    """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_permuter"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgePermuter(_TokenizerElement, abc.ABC):
        """Specifies how to sequence the two coords that encode a lattice edge."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgePermuters.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @staticmethod
        @abc.abstractmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            """Executes a permutation.

            Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation.

            # Parameters
            - `lattice_edges`: Array of lattice edges.
                The two coords in shape[1] must be adjacent in the lattice.

            # Returns
            - Array of lattice edges with entries along shape[1] systematically permuted.
            - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[0]`.
            """
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class SortedCoords(_EdgePermuter):
        """Returns a sorted representation; useful for checking consistency."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return lattice_edges[
                np.lexsort(
                    (
                        lattice_edges[:, 1, 1],
                        lattice_edges[:, 1, 0],
                        lattice_edges[:, 0, 1],
                        lattice_edges[:, 0, 0],
                    ),
                ),
                ...,
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class RandomCoords(_EdgePermuter):
        """Permutes each edge randomly."""

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            # shuffles the two coords of each edge in-place via `out=`
            numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges)
            return lattice_edges

    @serializable_dataclass(frozen=True, kw_only=True)
    class BothCoords(_EdgePermuter):
        """Includes both possible permutations of every edge in the output.

        Since the input ConnectionArray has only 1 instance of each edge,
        a call to `BothCoords._permute` returns a new array with double the `shape[0]` of its input.
        """

        @staticmethod
        def _permute(lattice_edges: ConnectionArray) -> ConnectionArray:
            return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0)
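
    # Illustrative sketch (not from the source): for a single edge
    #   lattice_edges = np.array([[[0, 0], [0, 1]]])        # shape (1, 2, 2)
    # `BothCoords._permute` returns both orderings,
    #   [[[0, 0], [0, 1]], [[0, 1], [0, 0]]]                # shape (2, 2, 2)
    # so each edge later appears once with each coord as the leading coord.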



class EdgeSubsets(__TokenizerElementNamespace):
    """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`."""

    key = "edge_subset"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _EdgeSubset(_TokenizerElement, abc.ABC):
        """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return EdgeSubsets.key

        def is_valid(self, do_except: bool = False) -> bool:
            return True

        @abc.abstractmethod
        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            """Returns the set of lattice edges to be tokenized."""
            pass

    @serializable_dataclass(frozen=True, kw_only=True)
    class AllLatticeEdges(_EdgeSubset):
        """All 2n**2-2n edges of the lattice are tokenized.

        If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`.
        """

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            return lattice_connection_array(maze.grid_n)

    @serializable_dataclass(frozen=True, kw_only=True)
    class ConnectionEdges(_EdgeSubset):
        """Only edges which contain a connection are tokenized.

        Alternatively, only edges which contain a wall are tokenized.

        # Parameters
        - `walls`: Whether wall edges or connection edges are tokenized.
            If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`.
        """

        walls: bool = serializable_field(default=False)

        def _get_edges(self, maze: LatticeMaze) -> ConnectionArray:
            conn_list: ConnectionList = maze.connection_list
            if self.walls:
                conn_list = np.logical_not(conn_list)
                # the complement marks "edges" leading off the last row/col of
                # the lattice as walls; mask them off since they aren't real edges
                conn_list[0, -1, :] = False
                conn_list[1, :, -1] = False
            return connection_list_to_adj_list(
                conn_list,
                shuffle_d0=False,
                shuffle_d1=False,
            )
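
    # Illustrative sketch (not from the source): on a 2x2 maze with every cell
    # connected, `ConnectionEdges()._get_edges(maze)` yields all 4 connected
    # edges, `ConnectionEdges(walls=True)._get_edges(maze)` yields an empty
    # array, and `AllLatticeEdges` always yields all 2*2**2 - 2*2 = 4 edges.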



def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool:  # noqa: ANN001
    """Returns False if `pre` is True, True otherwise."""
    output: bool = self_.pre is False
    if do_except and not output:
        raise ValueError(
            "AdjListCoord does not support `pre == True`.",
        )

    return output



class AdjListTokenizers(__TokenizerElementNamespace):
    """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "adj_list_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_adjlist_no_pre_unsupported)
    class _AdjListTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how the adjacency list is tokenized.

        Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations.
        See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details.

        # Parameters
        - `pre`: Whether all edge groupings include a preceding delimiter token
        - `post`: Whether all edge groupings include a following delimiter token
        - `shuffle_d0`: Specifies how to sequence the edge groupings.
            If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group.
        - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping.
        - `edge_subset`: Specifies the subset of lattice edges to be tokenized.
        - `edge_permuter`: Specifies, in each edge tokenization, which coord either:
            1. Appears first in the tokenization, for `AdjListCoord`.
            2. Is tokenized directly as a coord, for `AdjListCardinal`.
            - `shuffle`: For each edge, the leading coord is selected randomly.
            - `all`: Each edge appears twice in the tokenization, appearing with both leading coords.
            - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details.
        """

        pre: bool = serializable_field(default=False, assert_type=False)
        post: bool = serializable_field(default=True)
        shuffle_d0: bool = serializable_field(default=True)
        edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field(
            default=EdgeGroupings.Ungrouped(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings),
        )
        edge_subset: EdgeSubsets._EdgeSubset = serializable_field(
            default=EdgeSubsets.ConnectionEdges(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets),
        )
        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return AdjListTokenizers.key

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

        @abc.abstractmethod
        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            """Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization.

            # Returns
            - `[0]`: leading coord tokens
            - `[1]`: connector tokens
            - `[2]`: trailing coord tokens
            """
            pass

        def _tokenize_edge_grouping(
            self,
            edges: ConnectionArray,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            group_params: EdgeGroupings._GroupingTokenParams,
        ) -> Sequence[str]:
            """Tokenizes a single edge grouping."""
            cxn_ord: int = group_params["connection_token_ordinal"]
            is_conn: Bool[np.ndarray, " edges"] = is_connection(
                edges,
                maze.connection_list,
            )
            tokenize_callables = self._tokenization_callables(
                edges,
                is_conn,
                coord_tokenizer,
            )

            if group_params["grouped"]:
                # If grouped: tokenize the leading coord once, then repeat the
                # connector and trailing parts for each edge in the group
                callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1]
                repeated_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]
                return flatten(
                    [
                        tokenize_callables[0](0),
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in repeated_callables
                                ],
                                *(
                                    (VOCAB.ADJLIST_INTRA,)
                                    if group_params["intra"]
                                    else ()
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )
            else:
                # If ungrouped: repeat all three parts for each edge, with the
                # connector inserted at index `cxn_ord`
                callable_permutation = [0, 2]
                callable_permutation.insert(cxn_ord, 1)
                tokenize_callables = [
                    tokenize_callables[i] for i in callable_permutation
                ]

                return flatten(
                    [
                        [
                            [
                                *[
                                    tok_callable(i)
                                    for tok_callable in tokenize_callables
                                ],
                                *empty_sequence_if_attr_false(
                                    (VOCAB.ADJLIST_INTRA,),
                                    group_params,
                                    "intra",
                                ),
                            ]
                            for i in range(edges.shape[0])
                        ],
                    ],
                )
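
        # Illustrative sketch (not from the source): with `Ungrouped()` defaults
        # (`connection_token_ordinal=1`), each connected edge tokenizes as
        #   [*leading_coord_tokens, VOCAB.CONNECTOR, *trailing_coord_tokens]
        # while `connection_token_ordinal=0` would move VOCAB.CONNECTOR to the front.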

        def to_tokens(
            self,
            maze: LatticeMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            # Get the set of edges to be tokenized
            edges: ConnectionArray = self.edge_subset._get_edges(maze)
            # Systematically permute the leading coord of each edge
            edges = self.edge_permuter._permute(edges)
            group_params: EdgeGroupings._GroupingTokenParams = (
                self.edge_grouping._token_params()
            )
            # then, we need to group the edges
            groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges)
            # shuffle the groups if specified
            if self.shuffle_d0:
                if isinstance(groups, np.ndarray):
                    numpy_rng.shuffle(groups, axis=0)
                elif isinstance(groups, list):
                    random.shuffle(groups)
                else:
                    err_msg: str = f"`groups` is an unexpected type {type(groups)}. Only types `list` and `np.ndarray` are currently supported."
                    raise TypeError(err_msg)
            # Tokenize each group with optional delimiters
            tokens: list[str] = list(
                flatten(
                    [
                        [
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJLIST_PRE,),
                                self,
                                "pre",
                            ),
                            *self._tokenize_edge_grouping(
                                group,
                                maze,
                                coord_tokenizer,
                                group_params,
                            ),
                            *empty_sequence_if_attr_false(
                                (VOCAB.ADJACENCY_ENDLINE,),
                                self,
                                "post",
                            ),
                        ]
                        for group in groups
                    ],
                ),
            )
            return tokens

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCoord(_AdjListTokenizer):
        """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members."""

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.RandomCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: coord_tokenizer.to_tokens(edges[i, 1]),
            ]
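
    # Illustrative sketch (not from the source): for a connected edge
    # (0,0)-(0,1) under `AdjListCoord` with the default `CoordTokenizers.UT`,
    # the three callables produce
    #   ["(0,0)"], VOCAB.CONNECTOR, ["(0,1)"]
    # which `_tokenize_edge_grouping` flattens into the edge's token sequence.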

    @serializable_dataclass(frozen=True, kw_only=True)
    class AdjListCardinal(_AdjListTokenizer):
        """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members.

        # Parameters
        - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens.
        """

        edge_permuter: EdgePermuters._EdgePermuter = serializable_field(
            default=EdgePermuters.BothCoords(),
            loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters),
        )

        def _tokenization_callables(
            self,
            edges: ConnectionArray,
            is_conn: Bool[np.ndarray, " edges"],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
            *args,
            **kwargs,
        ) -> list[Callable]:
            # Map from `is_conn` to the tokens which represent connections and walls
            conn_token_map: dict[bool, str] = {
                True: VOCAB.CONNECTOR,
                False: VOCAB.ADJLIST_WALL,
            }
            return [
                lambda i: coord_tokenizer.to_tokens(edges[i, 0]),
                lambda i: conn_token_map[is_conn[i]],
                lambda i: get_cardinal_direction(edges[i]),
            ]



class TargetTokenizers(__TokenizerElementNamespace):
    """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "target_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _TargetTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze targets."""

        @abc.abstractmethod
        def to_tokens(
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the target."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return TargetTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class Unlabeled(_TargetTokenizer):
        """Targets are simply listed as coord tokens.

        # Parameters
        - `post`: Whether all coords include an integral following delimiter token
        """

        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            targets: Sequence[Coord],
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return list(
                flatten(
                    [
                        [
                            *coord_tokenizer.to_tokens(target),
                            *empty_sequence_if_attr_false(
                                [VOCAB.TARGET_POST],
                                self,
                                "post",
                            ),
                        ]
                        for target in targets
                    ],
                ),
            )
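
        # Illustrative sketch (not from the source): with `UT` coords and
        # `post=True`, targets [(2, 2)] tokenize to ["(2,2)", VOCAB.TARGET_POST].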

        # inherit docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            # No invalid instances possible within data member type hint bounds
            return True


class StepSizes(__TokenizerElementNamespace):
    """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_size"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepSize(_TokenizerElement, abc.ABC):
        """Specifies which coords in `maze.solution` are used to represent the path."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepSizes.key

        @abc.abstractmethod  # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call
        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            raise NotImplementedError(
                "Subclasses must implement `_StepSize._step_single_indices`.",
            )

        def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]:
            """Returns steps as tuples of starting and ending positions for each step."""
            indices: list[int] = self._step_single_indices(maze)
            # TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs
            return [
                (start, end)
                for start, end in zip(indices[:-1], indices[1:], strict=False)  # noqa: RUF007
            ]

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Singles(_StepSize):
        """Every coord in `maze.solution` is represented.

        Legacy tokenizers all use this behavior.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(range(maze.solution.shape[0]))

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class Straightaways(_StepSize):
        """Only coords where the path turns are represented in the path.

        I.e., the path is represented as a sequence of straightaways,
        specified by the coords at the turns.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            last_turn_coord: Coord = maze.solution[0, ...]
            indices: list[int] = [0]
            for i, coord in enumerate(maze.solution):
                # a turn occurred iff the current coord shares neither row nor
                # col with the last turn coord; record the coord just before it
                if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]:
                    indices.append(i - 1)
                    last_turn_coord = maze.solution[i - 1, ...]
            indices.append(i)
            return indices
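
    # Illustrative sketch (not from the source): for a solution
    #   (0,0) -> (0,1) -> (0,2) -> (1,2) -> (2,2)
    # `Straightaways` keeps indices [0, 2, 4]: the start, the turn at (0,2),
    # and the end, while `Singles` would keep every index [0, 1, 2, 3, 4].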

    @serializable_dataclass(frozen=True, kw_only=True)
    class Forks(_StepSize):
        """Only coords at forks, where the path has >=2 options for the next step, are included.

        Excludes the option of backtracking.
        The starting and ending coords are always included.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return maze.get_solution_forking_points(always_include_endpoints=True)[0]

    @serializable_dataclass(frozen=True, kw_only=True)
    @mark_as_unsupported(_unsupported_is_invalid)
    class ForksAndStraightaways(_StepSize):
        """Includes the union of the coords included by `Forks` and `Straightaways`.

        See documentation for those classes for details.
        """

        def _step_single_indices(self, maze: SolvedMaze) -> list[int]:
            """Returns the indices of `maze.solution` corresponding to the steps to be tokenized."""
            return list(
                np.unique(
                    np.concatenate(
                        (
                            StepSizes.Straightaways()._step_single_indices(maze),
                            StepSizes.Forks()._step_single_indices(maze),
                        ),
                    ),
                ),
            )



class StepTokenizers(__TokenizerElementNamespace):
    """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "step_tokenizers"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _StepTokenizer(_TokenizerElement, abc.ABC):
        """Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized."""

        @classmethod
        def attribute_key(cls) -> str:
            return StepTokenizers.key

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            """Tokenizes a single step in the solution.

            # Parameters
            - `maze`: Maze to be tokenized
            - `start_index`: The index of the Coord in `maze.solution` at which the current step starts
            - `end_index`: The index of the Coord in `maze.solution` at which the current step ends
            """
            raise NotImplementedError(
                "Subclasses must implement `_StepTokenizer.to_tokens`.",
            )

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

    @serializable_dataclass(frozen=True, kw_only=True)
    class Coord(_StepTokenizer):
        """A direct tokenization of the end position coord represents the step."""

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return coord_tokenizer.to_tokens(maze.solution[end_index, ...])

    @serializable_dataclass(frozen=True, kw_only=True)
    class Cardinal(_StepTokenizer):
        """A step is tokenized with a cardinal direction token.

        It is the direction of the step from the starting position along the solution.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            return [
                get_cardinal_direction(maze.solution[start_index : start_index + 2]),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Relative(_StepTokenizer):
        """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.).

        To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH.
        Similarly to `Cardinal`, the direction is that of the step from the starting position.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            if start_index == 0:
                start = maze.solution[0]
                # fabricate a "previous" position one row south of the start,
                # so that the agent begins facing NORTH
                previous = start + np.array([1, 0])
                return [
                    get_relative_direction(
                        np.concatenate(
                            (
                                np.expand_dims(previous, 0),
                                maze.solution[start_index : start_index + 2],
                            ),
                            axis=0,
                        ),
                    ),
                ]
            return [
                get_relative_direction(
                    maze.solution[start_index - 1 : start_index + 2],
                ),
            ]

    @serializable_dataclass(frozen=True, kw_only=True)
    class Distance(_StepTokenizer):
        """A count of the number of individual steps from the starting point to the end point.

        Contains no information about directionality, only the distance traveled in the step.
        `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`.
        This constraint is enforced in `_PathTokenizer.is_valid`.
        """

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            start_index: int,
            end_index: int,
            **kwargs,
        ) -> list[str]:
            d: int = end_index - start_index
            return [getattr(VOCAB, f"I_{d:03}")]
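
    # Illustrative sketch (not from the source): a step spanning solution
    # indices 2 through 5 has d = 3, so `Distance` emits [VOCAB.I_003].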

    """
    `StepTokenizerPermutation`
    A sequence of unique `_StepTokenizer`s.
    This type exists mostly just for the clarity and convenience of `_PathTokenizer` code.
    """
    StepTokenizerPermutation: type = (
        tuple[_StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer]
        | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer]
    )
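
    # Illustrative sketch (not from the source): a valid permutation pairing a
    # direction-bearing tokenizer with `Distance`:
    #   step_tokenizers = (StepTokenizers.Cardinal(), StepTokenizers.Distance())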



class PathTokenizers(__TokenizerElementNamespace):
    """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "path_tokenizer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PathTokenizer(_TokenizerElement, abc.ABC):
        """Superclass of tokenizers for maze solution paths."""

        @abc.abstractmethod
        def to_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens representing the solution path."""
            pass

        @classmethod
        def attribute_key(cls) -> str:
            return PathTokenizers.key

    @serializable_dataclass(frozen=True, kw_only=True)
    class StepSequence(_PathTokenizer, abc.ABC):
        """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path.

        Allows for a sequence of leading and trailing tokens which don't fit the step pattern.

        # Parameters
        - `step_size`: Selects the size of a single step in the sequence
        - `step_tokenizers`: Selects the combination and permutation of tokens
        - `pre`: Whether all steps include an integral preceding delimiter token
        - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization.
        - `post`: Whether all steps include an integral following delimiter token
        """

        step_size: StepSizes._StepSize = serializable_field(
            default=StepSizes.Singles(),
            loading_fn=lambda x: _load_tokenizer_element(x, StepSizes),
        )
        step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field(
            default=(StepTokenizers.Coord(),),
            serialization_fn=lambda x: [y.serialize() for y in x],
            loading_fn=lambda x: tuple(x[StepTokenizers.key]),
        )
        pre: bool = serializable_field(default=False)
        intra: bool = serializable_field(default=False)
        post: bool = serializable_field(default=False)

        # inherit docstring
        def to_tokens(  # noqa: D102
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            return [
                *self._leading_tokens(maze, coord_tokenizer),
                *flatten(
                    [
                        self._single_step_tokens(maze, start, end, coord_tokenizer)
                        for start, end in self.step_size.step_start_end_indices(maze)
                    ],
                ),
                *self._trailing_tokens(maze, coord_tokenizer),
            ]

        def _single_step_tokens(
            self,
            maze: SolvedMaze,
            i: int,
            j: int,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns the token sequence representing a single step along the path."""
            step_rep_tokens: list[list[str]] = [
                step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer)
                for step_tokenizer in self.step_tokenizers
            ]
            if self.intra:
                # interleave each step tokenizer's output with VOCAB.PATH_INTRA
                step_rep_tokens_and_intra: list[list[str] | str] = [None] * (
                    len(step_rep_tokens) * 2
                )
                step_rep_tokens_and_intra[::2] = step_rep_tokens
                step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len(
                    step_rep_tokens,
                )
                step_rep_tokens = list(flatten(step_rep_tokens_and_intra))
            all_tokens: list[str] = [
                *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                *flatten(step_rep_tokens),
                *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"),
            ]
            return all_tokens
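
        # Illustrative sketch (not from the source): with
        # `step_tokenizers=(StepTokenizers.Cardinal(), StepTokenizers.Distance())`
        # and `intra=True`, a 2-cell step east might render as
        #   [<east cardinal token>, VOCAB.PATH_INTRA, VOCAB.I_002, VOCAB.PATH_INTRA]
        # with `pre`/`post` optionally wrapping it in VOCAB.PATH_PRE / VOCAB.PATH_POST.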

        def _leading_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens preceding those from the sequence from `_single_step_tokens`.

            Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`.
            <PATH_START> should NOT be included.
            """
            if StepTokenizers.Coord() in self.step_tokenizers:
                return [
                    *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"),
                    *coord_tokenizer.to_tokens(maze.solution[0, ...]),
                    *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"),
                ]
            return []

        def _trailing_tokens(
            self,
            maze: SolvedMaze,
            coord_tokenizer: CoordTokenizers._CoordTokenizer,
        ) -> list[str]:
            """Returns tokens following those from the sequence from `_single_step_tokens`.

            <PATH_END> should NOT be included.
            """
            return []

        # inherits docstring
        def is_valid(self, do_except: bool = False) -> bool:  # noqa: D102
            output: bool

            if len(set(self.step_tokenizers)) != len(self.step_tokenizers):
                # Uninteresting: repeated elements are not useful
                output = False
            elif len(self.step_tokenizers) == 1 and isinstance(
                self.step_tokenizers[0],
                StepTokenizers.Distance,
            ):
                # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required.
                output = False
            else:
                output = True

            if not output and do_except:
                raise ValueError(
                    "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.",
                )

            return output



class PromptSequencers(__TokenizerElementNamespace):
    """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`."""

    key = "prompt_sequencer"

    @serializable_dataclass(frozen=True, kw_only=True)
    class _PromptSequencer(_TokenizerElement, abc.ABC):
        """Sequences token regions into a complete maze tokenization.

        # Parameters
        - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position.
        - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`.
            Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s.
        """

        coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field(
            default=CoordTokenizers.UT(),
            loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers),
        )
        adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field(
            default=AdjListTokenizers.AdjListCoord(),
            loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers),
        )

        @classmethod
        def attribute_key(cls) -> str:
            return PromptSequencers.key

        @staticmethod
        def _trim_if_unsolved_maze(
            untrimmed: list[str],
            is_untargeted: bool = False,
            is_unsolved: bool = False,
        ) -> list[str]:
            """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze.

            # Development
            This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP.
            It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses.
            """
            if is_untargeted:
                return tokens_between(
                    untrimmed,
                    VOCAB.ADJLIST_START,
                    VOCAB.ADJLIST_END,
                    include_start=True,
                    include_end=True,
                )
            if is_unsolved:
                if VOCAB.TARGET_END in untrimmed:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.TARGET_END,
                        include_start=True,
                        include_end=True,
                    )
                else:
                    return tokens_between(
                        untrimmed,
                        VOCAB.ADJLIST_START,
                        VOCAB.ORIGIN_END,
                        include_start=True,
                        include_end=True,
                    )
            return untrimmed

        def to_tokens(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[str]:
            """Returns a complete list of tokens for a given set of maze elements."""
            untrimmed: list[str] = self._sequence_tokens(
                *self._get_prompt_regions(maze),
            )
            return self._trim_if_unsolved_maze(
                untrimmed,
                not hasattr(maze, "start_pos"),
                not hasattr(maze, "solution"),
            )

        def _get_prompt_regions(
            self,
            maze: LatticeMaze,
            *args,
            **kwargs,
        ) -> list[list[str]]:
            """Gets the prompt regions of a maze in a fixed sequence.

            This method is NOT responsible for including/excluding any prompt regions.
            Always return according to the API described under Returns.
            This implementation is expected to be suitable for most `PromptSequencer` subclasses.
            Subclasses may override this method if needed for special behavior.

            # Returns
            - `[0]`: list[str] Adjacency list tokens
            - `[1]`: list[str] Origin tokens
            - `[2]`: list[str] Target tokens
            - `[3]`: list[str] Path tokens

            # `None`-valued Args
            If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized.
            To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables.
            """
            origin: Coord | None = getattr(maze, "start_pos", None)
            target: list[Coord] | None = [
                getattr(maze, "end_pos", None),
            ]  # TargetTokenizer requires target: Sequence[Coord]

            return [
                (
                    self.adj_list_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(self, "adj_list_tokenizer")
                    else []
                ),
                self.coord_tokenizer.to_tokens(origin) if origin is not None else [],
                (
                    self.target_tokenizer.to_tokens(
                        target,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if target[0] is not None and hasattr(self, "target_tokenizer")
                    else []
                ),
                (
                    self.path_tokenizer.to_tokens(
                        maze,
                        coord_tokenizer=self.coord_tokenizer,
                    )
                    if hasattr(maze, "solution") and hasattr(self, "path_tokenizer")
                    else []
                ),
            ]

        @abc.abstractmethod
        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str] | None,
            target: list[str] | None,
            path: list[str] | None,
        ) -> list[str]:
            """Sequences token regions into a complete prompt.

            Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as <ADJLIST_START>, <ORIGIN_END>, etc.

            # Parameters
            - `adj_list`: Tokens representing the adjacency list
            - `origin`: Tokens representing the origin
            - `target`: Tokens representing the target
            - `path`: Tokens representing the path
            """
            pass

        def is_valid(self, do_except: bool = False) -> bool:
            # No invalid instances possible within data member type hint bounds
            return True

1215 @serializable_dataclass(frozen=True, kw_only=True) 

1216 class AOTP(_PromptSequencer): 

1217 """Sequences a prompt as [adjacency list, origin, target, path]. 

1218 

1219 # Parameters 

1220 - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. 

1221 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`. 

1222 - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. 

1223 Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 

1224 

1225 """ 

1226 

1227 target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field( 

1228 default=TargetTokenizers.Unlabeled(), 

1229 loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers), 

1230 ) 

1231 path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( 

1232 default=PathTokenizers.StepSequence(), 

1233 loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), 

1234 ) 

1235 

1236 def _sequence_tokens( 

1237 self, 

1238 adj_list: list[str], 

1239 origin: list[str], 

1240 target: list[str], 

1241 path: list[str], 

1242 ) -> list[str]: 

1243 return [ 

1244 VOCAB.ADJLIST_START, 

1245 *adj_list, 

1246 VOCAB.ADJLIST_END, 

1247 VOCAB.ORIGIN_START, 

1248 *origin, 

1249 VOCAB.ORIGIN_END, 

1250 VOCAB.TARGET_START, 

1251 *target, 

1252 VOCAB.TARGET_END, 

1253 VOCAB.PATH_START, 

1254 *path, 

1255 VOCAB.PATH_END, 

1256 ] 
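
    # Illustrative sketch (not from the source): an `AOTP` prompt has the shape
    #   <ADJLIST_START> ... <ADJLIST_END> <ORIGIN_START> "(0,0)" <ORIGIN_END>
    #   <TARGET_START> "(2,2)" <TARGET_END> <PATH_START> ... <PATH_END>
    # where each region is filled by the corresponding tokenizer element.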

    @serializable_dataclass(frozen=True, kw_only=True)
    class AOP(_PromptSequencer):
        """Sequences a prompt as [adjacency list, origin, path].

        Still includes "<TARGET_START>" and "<TARGET_END>" tokens, but no representation of the target itself.

        # Parameters
        - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`.
            Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`.
        """

        path_tokenizer: PathTokenizers._PathTokenizer = serializable_field(
            default=PathTokenizers.StepSequence(),
            loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers),
        )

        def _sequence_tokens(
            self,
            adj_list: list[str],
            origin: list[str],
            # explicitly no target in this tokenizer
            target: list[str],
            path: list[str],
        ) -> list[str]:
            return [
                VOCAB.ADJLIST_START,
                *adj_list,
                VOCAB.ADJLIST_END,
                VOCAB.ORIGIN_START,
                *origin,
                VOCAB.ORIGIN_END,
                VOCAB.TARGET_START,
                VOCAB.TARGET_END,
                VOCAB.PATH_START,
                *path,
                VOCAB.PATH_END,
            ]