Coverage for tests/unit/tokenization/test_tokenizer.py: 87% (254 statements)
coverage.py v7.6.12, created at 2025-03-27 23:43 -0600
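"""Tests for maze tokenization: the legacy `MazeTokenizer` / `TokenizationMode` interface and the
modular `MazeTokenizerModular` with its `_TokenizerElement` components (coord, adjacency-list,
target, and path tokenizers), including serialization and maze-to-token round trips."""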

import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
import pytest
from jaxtyping import Int
from muutils.misc import flatten

from maze_dataset import (
	VOCAB,
	ConnectionArray,
	Coord,
	CoordArray,
	CoordTup,
	LatticeMaze,
	MazeDataset,
	MazeDatasetConfig,
	SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
	ASCII_MAZES,
	LEGACY_AND_EQUIVALENT_TOKENIZERS,
	MANUAL_MAZE,
	MAZE_DATASET,
	MIXED_MAZES,
)
from maze_dataset.token_utils import (
	connection_list_to_adj_list,
	equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
	AdjListTokenizers,
	CoordTokenizers,
	EdgeGroupings,
	EdgePermuters,
	EdgeSubsets,
	MazeTokenizer,
	MazeTokenizerModular,
	PathTokenizers,
	PromptSequencers,
	StepSizes,
	StepTokenizers,
	TargetTokenizers,
	TokenizationMode,
	_TokenizerElement,
)
from maze_dataset.tokenization.modular.all_instances import all_instances
from maze_dataset.utils import lattice_max_degrees, manhattan_distance

# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100
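# A sketch of the fuzzing pattern used by the parametrizations further down (illustrative only;
# `SomeElement` stands in for whichever `_TokenizerElement` subclass a given test targets):
#   sampled = random.sample(
#       list(all_instances(SomeElement, {_TokenizerElement: lambda x: x.is_valid()})),
#       NUM_TOKENIZERS_TO_TEST,
#   )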


@pytest.mark.parametrize(
	("tok_mode", "max_grid_size"),
	list(
		product(
			[
				TokenizationMode.AOTP_UT_rasterized,
				TokenizationMode.AOTP_UT_uniform,
				TokenizationMode.AOTP_CTT_indexed,
			],
			[None, 3, 100],
		),
	),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
	tokenizer: MazeTokenizer = MazeTokenizer(
		tokenization_mode=tok_mode,
		max_grid_size=max_grid_size,
	)

	serialized: dict = tokenizer.serialize()
	print(serialized)
	tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

	assert tokenizer == tokenizer_loaded


def test_tokenizer():
	cfg: MazeDatasetConfig = MazeDatasetConfig(
		name="test",
		grid_n=5,
		n_mazes=3,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
	)
	# to create a dataset, just call MazeDataset.from_config
	dataset: MazeDataset = MazeDataset.from_config(
		cfg,
		do_download=False,
		load_local=False,
		do_generate=True,
		save_local=False,
		verbose=True,
		gen_parallel=False,
	)

	for mode in (
		TokenizationMode.AOTP_UT_rasterized,
		TokenizationMode.AOTP_UT_uniform,
		TokenizationMode.AOTP_CTT_indexed,
	):
		tokenizer: MazeTokenizer = MazeTokenizer(
			tokenization_mode=mode,
			max_grid_size=100,
		)

		assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

		if mode == TokenizationMode.AOTP_CTT_indexed:
			assert tokenizer.node_strings_map is not None
			assert 100 < tokenizer.vocab_size < 200
		elif mode in (
			TokenizationMode.AOTP_UT_rasterized,
			TokenizationMode.AOTP_UT_uniform,
		):
			assert tokenizer.node_strings_map is None
			assert tokenizer.vocab_size > 10000

		assert isinstance(tokenizer.token_arr, Iterable)
		assert all(isinstance(token, str) for token in tokenizer.token_arr)
		assert len(tokenizer.token_arr) == tokenizer.vocab_size

		print(tokenizer.summary())

		for maze in dataset:
			# clear the cache here so we test if it works fine on the next loop
			tokenizer.clear_cache()

			maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

			maze_encoded = tokenizer.encode(maze_tok)
			maze_decoded = tokenizer.decode(maze_encoded)

			assert maze_tok == maze_decoded

			# you can view the tokens directly
			print("\nRaw tokens:\n")
			print(" ".join(maze_tok))

			maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

			assert (maze.connection_list == maze_recovered.connection_list).all()

			# or color and print them in various formats
			print("\nColored tokens, raw html:\n")
			print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
			print("\nColored tokens, raw latex:\n")
			print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
			print("\nColored tokens, terminal:\n")
			print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))


@pytest.mark.parametrize(
	("maze_ascii", "tokenizer", "tokens"),
	[
		pytest.param(
			ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
			tokenizer,  # tokenizer
			ASCII_MAZES[maze_ascii_key][0],  # tokens
			id=f"{tokenizer.name}_{maze_ascii_key}",
		)
		for maze_ascii_key, tokenizer in product(
			["small_3x3", "big_10x10"],
			LEGACY_AND_EQUIVALENT_TOKENIZERS,
		)
	],
)
def test_maze_to_tokens_roundtrip(
	maze_ascii: list[str],
	tokenizer: MazeTokenizer | MazeTokenizerModular,
	tokens: str,
):
	if not tokenizer.is_UT():
		# The hardcoded `tokens` assumes a UT tokenizer.
		# Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
		tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
		tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
		tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
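		# for example, "(1,2)" becomes "( 1 , 2 )" after the three substitutions above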

	tokens_original_split: list[str] = tokens.split()

	# join into a single string, and get a maze out
	ascii_str: str = "\n".join(maze_ascii)
	maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

	# maze as tokens
	tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

	# maze round trip
	maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
	tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

	# check that the mazes and tokens are all equivalent
	assert maze == maze_roundtrip
	assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
	assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)


@pytest.mark.parametrize(
	("tok_mode", "max_grid_size", "result"),
	[
		pytest.param(
			tok_mode,
			max_grid_size,
			MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
			id=f"{tok_mode}-{max_grid_size}",
		)
		for tok_mode, max_grid_size in [
			(TokenizationMode.AOTP_CTT_indexed, None),
			(TokenizationMode.AOTP_UT_rasterized, None),
			(TokenizationMode.AOTP_UT_uniform, None),
			(TokenizationMode.AOTP_CTT_indexed, 5),
		]
	],
)
def test_to_legacy_tokenizer(
	tok_mode: TokenizationMode,
	max_grid_size: int | None,
	result: MazeTokenizer,
):
	assert tok_mode.to_legacy_tokenizer(max_grid_size) == result


# MazeTokenizerModular tests
# ==========================

# Backwards compatibility tests
# =============================


@pytest.mark.parametrize(
	("maze", "legacy_tokenizer"),
	[
		pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
		for maze, tok_spec in itertools.product(
			[(maze, i) for i, maze in enumerate(MIXED_MAZES)],
			[tok_mode for tok_mode in TokenizationMode],  # noqa: C416
		)
	],
)
def test_to_tokens_backwards_compatible(
	maze: SolvedMaze,
	legacy_tokenizer: TokenizationMode,
):
	tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
	toks: list[str] = maze.as_tokens(tokenizer)
	toks2: list[str] = tokenizer.to_tokens(maze)
	toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

	try:
		assert equal_except_adj_list_sequence(toks, toks_legacy)
		assert equal_except_adj_list_sequence(toks2, toks_legacy)
	except AssertionError as e:
		msg: str = (
			"Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
			f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
			f"{toks = }\n{toks2 = }\n{toks_legacy = }"
		)
		raise AssertionError(msg) from e


@pytest.mark.parametrize(
	("coords", "legacy_tok_mode"),
	[
		pytest.param(
			coords,
			tok_mode,
			id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
		)
		for tok_mode, coords in itertools.product(
			[tok_mode for tok_mode in TokenizationMode],  # noqa: C416
			[
				*[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
				[maze.start_pos for maze in MAZE_DATASET.mazes],
				*[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
				[tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
			],
		)
	],
)
def test_coords_to_strings_backwards_compatible(
	coords: list[Coord | CoordTup],
	legacy_tok_mode: TokenizationMode,
):
	tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
	legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
	strings: list[str] = tokenizer.coords_to_strings(coords)
	strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
	assert strings == strings_legacy


@pytest.mark.parametrize(
	("maze", "tok_mode"),
	[
		pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
		for maze, tok_spec in itertools.product(
			[(maze, i) for i, maze in enumerate(MIXED_MAZES)],
			[tok_mode for tok_mode in TokenizationMode],  # noqa: C416
		)
	],
)
def test_from_tokens_backwards_compatible(
	maze: LatticeMaze,
	tok_mode: TokenizationMode,
):
	tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
	toks = maze.as_tokens(tok_mode)
	# Equality test of `as_tokens` output done in a separate unit test
	maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
	maze: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
	assert maze == maze_legacy


# General functionality tests
# ===========================


@pytest.mark.parametrize(
	("el", "result"),
	[
		pytest.param(elem, result, id=elem.name)
		for elem, result in [
			(CoordTokenizers.CTT(), True),
			(CoordTokenizers.CTT(intra=True), True),
			(CoordTokenizers.UT(), True),
			(AdjListTokenizers.AdjListCoord(), True),
			(AdjListTokenizers.AdjListCoord(post=True), True),
			(TargetTokenizers.Unlabeled(post=True), True),
			(PathTokenizers.StepSequence(), True),
			(
				PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
				True,
			),
			(
				PathTokenizers.StepSequence(
					step_tokenizers=(
						StepTokenizers.Coord(),
						StepTokenizers.Coord(),
					),
				),
				False,
			),
			(PromptSequencers.AOP(), True),
			(PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
			(
				PromptSequencers.AOP(
					path_tokenizer=PathTokenizers.StepSequence(
						step_tokenizers=(StepTokenizers.Coord(),),
					),
				),
				True,
			),
			(
				PromptSequencers.AOP(
					path_tokenizer=PathTokenizers.StepSequence(
						step_tokenizers=(
							StepTokenizers.Coord(),
							StepTokenizers.Coord(),
						),
					),
				),
				True,
			),
		]
	],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
	assert el.is_valid() == result


@pytest.mark.parametrize(
	("tokenizer", "result"),
	[
		pytest.param(tokenizer, result, id=str(tokenizer))
		for tokenizer, result in [
			(MazeTokenizerModular(), True),
			(MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
			(MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
		]
	],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
	assert tokenizer.is_legacy_equivalent() == result


def _helper_test_path_tokenizers(
	pt: PathTokenizers._PathTokenizer,
	maze: SolvedMaze,
	footprint_inds: Sequence[int],
):
	ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
	path_toks: list[str] = pt.to_tokens(maze, ct)
	path_toks_set: set[str] = set(path_toks)
	footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
	footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
		footprint_inds
	]
	if StepTokenizers.Coord() in pt.step_tokenizers:
		non_steps: set[CoordTup] = {tuple(c) for c in maze.solution} - {
			tuple(c) for c in footprints
		}
		assert all(ct.to_tokens(coord)[0] in path_toks_set for coord in footprints)
		assert all(ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps)
	if StepTokenizers.Distance() in pt.step_tokenizers:
		distances = footprint_inds[1:] - footprint_inds[:-1]
		assert (
			len(
				Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
				- Counter(path_toks),
			)
			== 0
		)
	if StepTokenizers.Cardinal() in pt.step_tokenizers:
		c = Counter(path_toks)
		assert (
			c[VOCAB.PATH_NORTH]
			+ c[VOCAB.PATH_SOUTH]
			+ c[VOCAB.PATH_EAST]
			+ c[VOCAB.PATH_WEST]
			== len(footprint_inds) - 1
		)
	if StepTokenizers.Relative() in pt.step_tokenizers:
		c = Counter(path_toks)
		assert (
			c[VOCAB.PATH_LEFT]
			+ c[VOCAB.PATH_RIGHT]
			+ c[VOCAB.PATH_FORWARD]
			+ c[VOCAB.PATH_BACKWARD]
			== len(footprint_inds) - 1
		)


@pytest.mark.parametrize(
	("pt", "manual_maze"),
	[
		pytest.param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
		for maze_kv, tokenizer in itertools.product(
			ASCII_MAZES.items(),
			random.sample(
				list(
					all_instances(
						PathTokenizers._PathTokenizer,
						{_TokenizerElement: lambda x: x.is_valid()},
					),
				),
				NUM_TOKENIZERS_TO_TEST,
			),
		)
	],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
	solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
	match type(pt.step_size):
		case StepSizes.Singles:
			footprint_inds = range(solved_maze.solution.shape[0])
		case StepSizes.Straightaways:
			swy_coordtup_set: set[CoordTup] = {
				tuple(c) for c in manual_maze.straightaway_footprints
			}
			footprint_inds: list[int] = [
				i
				for i, c in enumerate(solved_maze.solution)
				if tuple(c) in swy_coordtup_set
			]
		case StepSizes.Forks:
			footprint_inds = solved_maze.get_solution_forking_points(
				always_include_endpoints=True,
			)[0]
		case StepSizes.ForksAndStraightaways:
			swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
				solved_maze,
			)
			footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
				(
					solved_maze.get_solution_forking_points(
						always_include_endpoints=True,
					)[0],
					swy_step_inds,
				),
			)
			footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
	_helper_test_path_tokenizers(
		pt,
		solved_maze,
		footprint_inds,
	)


@pytest.mark.parametrize(
	("ep", "maze"),
	[
		pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
		for (i, maze), tokenizer in itertools.product(
			enumerate(MIXED_MAZES[:6]),
			all_instances(
				EdgePermuters._EdgePermuter,
				frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
			),
		)
	],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
	edges: ConnectionArray = connection_list_to_adj_list(
		maze.connection_list,
		shuffle_d0=False,
		shuffle_d1=False,
	)
	edges_copy: ConnectionArray = connection_list_to_adj_list(
		maze.connection_list,
		shuffle_d0=False,
		shuffle_d1=False,
	)
	assert np.array_equal(edges, edges_copy)
	old_shape = edges.shape
	permuted: ConnectionArray = ep._permute(edges)
	match ep:
		case EdgePermuters.RandomCoords():
			assert permuted.shape == old_shape
			assert edges is permuted
			i = 0
			while np.array_equal(permuted, edges_copy) and i < 2:
				# Permute again: for small mazes, the random selection can happen to leave everything unchanged
				permuted: ConnectionArray = ep._permute(permuted)
				i += 1
			assert not np.array_equal(permuted, edges_copy)
		case EdgePermuters.BothCoords():
			new_shape = old_shape[0] * 2, *old_shape[1:]
			n = old_shape[0]
			assert permuted.shape == new_shape
			assert np.array_equal(permuted[:n, ...], edges_copy)
			assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
			assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
			assert edges is not permuted


@pytest.mark.parametrize(
	("es", "maze"),
	[
		pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
		for (i, maze), tokenizer in itertools.product(
			enumerate(MIXED_MAZES[:6]),
			all_instances(
				EdgeSubsets._EdgeSubset,
				frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
			),
		)
	],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
	edges: ConnectionArray = es._get_edges(maze)
	n: int = maze.grid_n
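	# note: an n x n lattice has n * (n - 1) horizontal and n * (n - 1) vertical adjacent
	# coordinate pairs, hence the 2 * n * (n - 1) totals in the expected shapes below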

	match type(es):
		case EdgeSubsets.AllLatticeEdges:
			assert_shape: tuple = (2 * n * (n - 1), 2, 2)
		case EdgeSubsets.ConnectionEdges:
			if not es.walls:
				assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
			else:
				assert_shape: tuple = (
					2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
					2,
					2,
				)
	assert edges.dtype == np.int8
	assert assert_shape == tuple(edges.shape)
	assert assert_shape == tuple(
		np.unique(edges, axis=0).shape,
	)  # All edges are unique (swapping leading/trailing coords is considered different)
	assert np.array_equal(
		manhattan_distance(edges),
		np.array([1] * assert_shape[0], dtype=np.int8),
	)


@pytest.mark.parametrize(
	("tok_elem", "es", "maze"),
	[
		# we do a little accessing private members here
		pytest.param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
		for (i, maze), tok_elem, es in itertools.product(
			enumerate(MIXED_MAZES[:6]),
			all_instances(
				EdgeGroupings._EdgeGrouping,
				frozendict.frozendict(
					{
						_TokenizerElement: lambda x: x.is_valid(),
						# Add a condition to prune the range space that doesn't affect functionality being tested
						EdgeGroupings.ByLeadingCoord: lambda x: x.intra
						and x.connection_token_ordinal == 1,
					},
				),
			),
			all_instances(
				EdgeSubsets._EdgeSubset,
				frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
			),
		)
	],
)
def test_edge_groupings(
	tok_elem: EdgeGroupings._EdgeGrouping,
	es: EdgeSubsets._EdgeSubset,
	maze: LatticeMaze,
):
	# we do a little more accessing private members here
	edges: ConnectionArray = es._get_edges(maze)
	# n: int = maze.grid_n
	groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

	assert all(
		not np.any(np.diff(g[:, 0], axis=0)) for g in groups
	)  # Asserts that the leading coord is the same for all edges within each group
	match type(tok_elem):
		case EdgeGroupings.Ungrouped:
			assert_shape = edges.shape[0], 1, 2, 2
			assert tuple(groups.shape) == assert_shape
		case EdgeGroupings.ByLeadingCoord:
			assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
			assert sum(g.shape[0] for g in groups) == edges.shape[0]
			# trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
			# vector_diffs is the position vector difference between the trailing coords of each group
			# These are stacked into a single array since we don't care about maintaining group separation
			vector_diffs: CoordArray = np.stack(
				list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1)),
			)
			if tok_elem.shuffle_group:
				allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
				# The set of all 2D vectors between any 2 coords adjacent to a central coord
				allowed_diffs = allowed_diffs.union(
					{(-d[0], -d[1]) for d in allowed_diffs},
				)
			else:
				# If vector_diffs are lexicographically sorted, these are the only possible values. Any other value indicates an error in sorting
				allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
			assert all(
				tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
			)


random.seed(GLOBAL_SEED)
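# note: seeding here presumably keeps the `random.sample(...)` call in the next
# parametrization deterministic across collection runs; GLOBAL_SEED is imported above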


@pytest.mark.parametrize(
	("tok_elem", "maze"),
	[
		pytest.param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
		for (i, maze), tok_elem in itertools.product(
			enumerate(MAZE_DATASET),
			random.sample(
				list(
					all_instances(
						# yes we access a private member
						AdjListTokenizers._AdjListTokenizer,
						{
							_TokenizerElement: lambda x: x.is_valid(),
						},
					),
				),
				100,
			),
		)
	],
)
# too many branches and "too complex" but whatever
def test_adjlist_tokenizers(  # noqa: PLR0912,C901
	tok_elem: AdjListTokenizers._AdjListTokenizer,
	maze: LatticeMaze,
):
	toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
	tok_counter: Counter = Counter(toks)
	n: int = maze.grid_n
	edge_count: int = 1  # To be updated in match/case blocks
	group_count: int = 1  # To be updated in match/case blocks

	match tok_elem.edge_subset:
		case EdgeSubsets.AllLatticeEdges():
			edge_count *= n * (n - 1) * 2
		case EdgeSubsets.ConnectionEdges(walls=False):
			edge_count *= np.count_nonzero(maze.connection_list)
		case EdgeSubsets.ConnectionEdges(walls=True):
			edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
		case _:
			msg: str = f"`match` case missing for {tok_elem.edge_subset = }"
			raise NotImplementedError(msg)

	match tok_elem.edge_permuter:
		case EdgePermuters.BothCoords():
			edge_count *= 2
			if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
				group_count *= np.count_nonzero(
					lattice_max_degrees(n) - maze.coord_degrees() > 0,
				)  # All coords with at least 1 adjacent wall, not counting outer boundaries
			else:
				group_count *= np.count_nonzero(
					maze.coord_degrees() > 0,
				)  # All coords with >0 connections
		case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
			edge_count *= 1
			group_count = None  # Group count is stochastic

	match type(tok_elem.edge_grouping):
		case EdgeGroupings.Ungrouped:
			group_count = edge_count  # Override all above cases
		case EdgeGroupings.ByLeadingCoord:
			if group_count is not None:
				group_count *= 1
			if tok_elem.edge_grouping.intra:
				assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
		case _:
			msg: str = f"`match` case missing for {tok_elem.edge_grouping = }"
			raise NotImplementedError(msg)

	match type(tok_elem):
		case AdjListTokenizers.AdjListCoord:
			pass
		case AdjListTokenizers.AdjListCardinal:
			assert (
				tok_counter[VOCAB.PATH_NORTH]
				+ tok_counter[VOCAB.PATH_SOUTH]
				+ tok_counter[VOCAB.PATH_EAST]
				+ tok_counter[VOCAB.PATH_WEST]
				== edge_count
			)

	if group_count is not None:
		if tok_elem.pre:
			assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
		if tok_elem.post:
			assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

	assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count


@pytest.mark.parametrize(
	("tok_elem", "valid"),
	[
		pytest.param(
			tok_elem,
			valid,
			id=f"{tok_elem!r}",
		)
		for tok_elem, valid in (
			[
				(StepSizes.ForksAndStraightaways(), False),
				(StepSizes.Straightaways(), False),
				(StepSizes.Forks(), True),
				(AdjListTokenizers.AdjListCoord(), True),
				(AdjListTokenizers.AdjListCoord(pre=True), False),
				(AdjListTokenizers.AdjListCardinal(), True),
				(AdjListTokenizers.AdjListCardinal(pre=True), False),
				(EdgeGroupings.Ungrouped(), True),
				(EdgeGroupings.ByLeadingCoord(), False),
				(EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
			]
		)
	],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
	assert tok_elem.is_valid() == valid