Coverage for tests/unit/tokenization/test_tokenizer.py: 87%
254 statements

import itertools
import random
import re
from collections import Counter
from itertools import product
from typing import Iterable, Sequence

import frozendict
import numpy as np
import pytest
from jaxtyping import Int
from muutils.misc import flatten

from maze_dataset import (
    VOCAB,
    ConnectionArray,
    Coord,
    CoordArray,
    CoordTup,
    LatticeMaze,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.generation.seed import GLOBAL_SEED
from maze_dataset.plotting.print_tokens import color_maze_tokens_AOTP
from maze_dataset.testing_utils import (
    ASCII_MAZES,
    LEGACY_AND_EQUIVALENT_TOKENIZERS,
    MANUAL_MAZE,
    MAZE_DATASET,
    MIXED_MAZES,
)
from maze_dataset.token_utils import (
    connection_list_to_adj_list,
    equal_except_adj_list_sequence,
)
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    EdgeSubsets,
    MazeTokenizer,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    TargetTokenizers,
    TokenizationMode,
    _TokenizerElement,
)
from maze_dataset.tokenization.modular.all_instances import all_instances
from maze_dataset.utils import lattice_max_degrees, manhattan_distance

# Use for test fuzzing when there are too many possible tokenizers
NUM_TOKENIZERS_TO_TEST = 100

@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size"),
    list(
        product(
            [
                TokenizationMode.AOTP_UT_rasterized,
                TokenizationMode.AOTP_UT_uniform,
                TokenizationMode.AOTP_CTT_indexed,
            ],
            [None, 3, 100],
        ),
    ),
)
def test_tokenizer_serialization(tok_mode: TokenizationMode, max_grid_size: int | None):
    tokenizer: MazeTokenizer = MazeTokenizer(
        tokenization_mode=tok_mode,
        max_grid_size=max_grid_size,
    )

    serialized: dict = tokenizer.serialize()
    print(serialized)
    tokenizer_loaded: MazeTokenizer = MazeTokenizer.load(serialized)

    assert tokenizer == tokenizer_loaded
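
# `test_tokenizer` generates a small DFS dataset, then for each legacy
# `TokenizationMode` checks vocabulary invariants, encode/decode round-trips,
# and that a `SolvedMaze` can be recovered from its own token sequence.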
def test_tokenizer():
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=3,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    # to create a dataset, just call MazeDataset.from_config
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        do_generate=True,
        save_local=False,
        verbose=True,
        gen_parallel=False,
    )

    for mode in (
        TokenizationMode.AOTP_UT_rasterized,
        TokenizationMode.AOTP_UT_uniform,
        TokenizationMode.AOTP_CTT_indexed,
    ):
        tokenizer: MazeTokenizer = MazeTokenizer(
            tokenization_mode=mode,
            max_grid_size=100,
        )

        assert tokenizer.name == f"maze_tokenizer-{mode.name}-g{100}"

        if mode == TokenizationMode.AOTP_CTT_indexed:
            assert tokenizer.node_strings_map is not None
            assert 100 < tokenizer.vocab_size < 200
        elif mode in (
            TokenizationMode.AOTP_UT_rasterized,
            TokenizationMode.AOTP_UT_uniform,
        ):
            assert tokenizer.node_strings_map is None
            assert tokenizer.vocab_size > 10000

        assert isinstance(tokenizer.token_arr, Iterable)
        assert all(isinstance(token, str) for token in tokenizer.token_arr)
        assert len(tokenizer.token_arr) == tokenizer.vocab_size

        print(tokenizer.summary())

        for maze in dataset:
            # clear the cache here so we test if it works fine on the next loop
            tokenizer.clear_cache()

            maze_tok = maze.as_tokens(maze_tokenizer=tokenizer)

            maze_encoded = tokenizer.encode(maze_tok)
            maze_decoded = tokenizer.decode(maze_encoded)

            assert maze_tok == maze_decoded

            # you can view the tokens directly
            print("\nRaw tokens:\n")
            print(" ".join(maze_tok))

            maze_recovered = SolvedMaze.from_tokens(maze_tok, maze_tokenizer=tokenizer)

            assert (maze.connection_list == maze_recovered.connection_list).all()

            # or color and print them in various formats
            print("\nColored tokens, raw html:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="html"))
            print("\nColored tokens, raw latex:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="latex"))
            print("\nColored tokens, terminal:\n")
            print(color_maze_tokens_AOTP(maze_tok, fmt="terminal"))
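
# `test_maze_to_tokens_roundtrip` parses a hardcoded ASCII maze, tokenizes it,
# reconstructs a maze from those tokens, and checks that maze and token
# sequences survive the round trip (up to adjacency-list ordering).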
@pytest.mark.parametrize(
    ("maze_ascii", "tokenizer", "tokens"),
    [
        pytest.param(
            ASCII_MAZES[maze_ascii_key][1],  # maze_ascii
            tokenizer,  # tokenizer
            ASCII_MAZES[maze_ascii_key][0],  # tokens
            id=f"{tokenizer.name}_{maze_ascii_key}",
        )
        for maze_ascii_key, tokenizer in product(
            ["small_3x3", "big_10x10"],
            LEGACY_AND_EQUIVALENT_TOKENIZERS,
        )
    ],
)
def test_maze_to_tokens_roundtrip(
    maze_ascii: list[str],
    tokenizer: MazeTokenizer | MazeTokenizerModular,
    tokens: str,
):
    if not tokenizer.is_UT():
        # The hardcoded `tokens` assumes a UT tokenizer.
        # Here we modify `tokens` to match what an `AOTP_CTT_indexed` tokenizer would produce.
        tokens = re.sub(r"\(([0-9]),([0-9])\)", r"(\1 , \2)", tokens)
        tokens = re.sub(r"\(([0-9]+ ,)", r"( \1", tokens)
        tokens = re.sub(r"(, [0-9]+)\)", r"\1 )", tokens)
    tokens_original_split: list[str] = tokens.split()

    # join into a single string, and get a maze out
    ascii_str: str = "\n".join(maze_ascii)
    maze: SolvedMaze = SolvedMaze.from_ascii(ascii_str)

    # maze as tokens
    tokens_from_maze: list[str] = maze.as_tokens(tokenizer)

    # maze round trip
    maze_roundtrip: SolvedMaze = SolvedMaze.from_tokens(tokens_from_maze, tokenizer)
    tokens_roundtrip: list[str] = maze_roundtrip.as_tokens(tokenizer)

    # check that the mazes and tokens are all equivalent
    assert maze == maze_roundtrip
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_from_maze)
    assert equal_except_adj_list_sequence(tokens_original_split, tokens_roundtrip)
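
# `test_to_legacy_tokenizer` checks that `TokenizationMode.to_legacy_tokenizer`
# constructs the same `MazeTokenizer` as direct instantiation.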
@pytest.mark.parametrize(
    ("tok_mode", "max_grid_size", "result"),
    [
        pytest.param(
            tok_mode,
            max_grid_size,
            MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=max_grid_size),
            id=f"{tok_mode}-{max_grid_size}",
        )
        for tok_mode, max_grid_size in [
            (TokenizationMode.AOTP_CTT_indexed, None),
            (TokenizationMode.AOTP_UT_rasterized, None),
            (TokenizationMode.AOTP_UT_uniform, None),
            (TokenizationMode.AOTP_CTT_indexed, 5),
        ]
    ],
)
def test_to_legacy_tokenizer(
    tok_mode: TokenizationMode,
    max_grid_size: int | None,
    result: MazeTokenizer,
):
    assert tok_mode.to_legacy_tokenizer(max_grid_size) == result

# MazeTokenizerModular tests
# ==========================

# Backwards compatibility tests
# =============================

@pytest.mark.parametrize(
    ("maze", "legacy_tokenizer"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_to_tokens_backwards_compatible(
    maze: SolvedMaze,
    legacy_tokenizer: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tokenizer)
    toks: list[str] = maze.as_tokens(tokenizer)
    toks2: list[str] = tokenizer.to_tokens(maze)
    toks_legacy: list[str] = maze.as_tokens(legacy_tokenizer)

    try:
        assert equal_except_adj_list_sequence(toks, toks_legacy)
        assert equal_except_adj_list_sequence(toks2, toks_legacy)
    except AssertionError as e:
        msg: str = (
            "Tokens from `as_tokens` and `to_tokens` should be equal to tokens from `as_tokens` with the legacy tokenizer.\n"
            f"{len(toks) = }, {len(toks2) = }, {len(toks_legacy) = }\n"
            f"{toks = }\n{toks2 = }\n{toks_legacy = }"
        )
        raise AssertionError(msg) from e
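
# `test_coords_to_strings_backwards_compatible` checks that the modular tokenizer
# stringifies coordinates (arrays or tuples) identically to the legacy tokenizer.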
@pytest.mark.parametrize(
    ("coords", "legacy_tok_mode"),
    [
        pytest.param(
            coords,
            tok_mode,
            id=f"{tok_mode.value}-coords(type={type(coords[0])},len={len(coords)})",
        )
        for tok_mode, coords in itertools.product(
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
            [
                *[[maze.start_pos] for maze in MAZE_DATASET.mazes[:2]],
                [maze.start_pos for maze in MAZE_DATASET.mazes],
                *[[tuple(maze.start_pos)] for maze in MAZE_DATASET.mazes[:2]],
                [tuple(maze.start_pos) for maze in MAZE_DATASET.mazes],
            ],
        )
    ],
)
def test_coords_to_strings_backwards_compatible(
    coords: list[Coord | CoordTup],
    legacy_tok_mode: TokenizationMode,
):
    tokenizer: MazeTokenizerModular = MazeTokenizerModular.from_legacy(legacy_tok_mode)
    legacy_tokenizer = MazeTokenizer(tokenization_mode=legacy_tok_mode)
    strings: list[str] = tokenizer.coords_to_strings(coords)
    strings_legacy: list[str] = legacy_tokenizer.coords_to_strings(coords)
    assert strings == strings_legacy

@pytest.mark.parametrize(
    ("maze", "tok_mode"),
    [
        pytest.param(maze[0], tok_spec, id=f"{tok_spec.value}-maze{maze[1]}")
        for maze, tok_spec in itertools.product(
            [(maze, i) for i, maze in enumerate(MIXED_MAZES)],
            [tok_mode for tok_mode in TokenizationMode],  # noqa: C416
        )
    ],
)
def test_from_tokens_backwards_compatible(
    maze: LatticeMaze,
    tok_mode: TokenizationMode,
):
    tokenizer = MazeTokenizerModular.from_legacy(tok_mode)
    toks = maze.as_tokens(tok_mode)
    # Equality test of `as_tokens` output done in a separate unit test
    maze_legacy: LatticeMaze = LatticeMaze.from_tokens(toks, tok_mode)
    maze_modular: LatticeMaze = LatticeMaze.from_tokens(toks, tokenizer)
    assert maze_modular == maze_legacy

# General functionality tests
# ===========================

@pytest.mark.parametrize(
    ("el", "result"),
    [
        pytest.param(elem, result, id=elem.name)
        for elem, result in [
            (CoordTokenizers.CTT(), True),
            (CoordTokenizers.CTT(intra=True), True),
            (CoordTokenizers.UT(), True),
            (AdjListTokenizers.AdjListCoord(), True),
            (AdjListTokenizers.AdjListCoord(post=True), True),
            (TargetTokenizers.Unlabeled(post=True), True),
            (PathTokenizers.StepSequence(), True),
            (
                PathTokenizers.StepSequence(step_tokenizers=(StepTokenizers.Coord(),)),
                True,
            ),
            (
                PathTokenizers.StepSequence(
                    step_tokenizers=(
                        StepTokenizers.Coord(),
                        StepTokenizers.Coord(),
                    ),
                ),
                False,
            ),
            (PromptSequencers.AOP(), True),
            (PromptSequencers.AOP(path_tokenizer=PathTokenizers.StepSequence()), True),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Coord(),),
                    ),
                ),
                True,
            ),
            (
                PromptSequencers.AOP(
                    path_tokenizer=PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    ),
                ),
                True,
            ),
        ]
    ],
)
def test_tokenizer_element_is_valid(el: _TokenizerElement, result: bool):
    assert el.is_valid() == result

@pytest.mark.parametrize(
    ("tokenizer", "result"),
    [
        pytest.param(tokenizer, result, id=str(tokenizer))
        for tokenizer, result in [
            (MazeTokenizerModular(), True),
            (MazeTokenizerModular.from_legacy(TokenizationMode.AOTP_CTT_indexed), True),
            (MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOP()), False),
        ]
    ],
)
def test_is_legacy_equivalent(tokenizer: MazeTokenizerModular, result: bool):
    assert tokenizer.is_legacy_equivalent() == result
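
# `_helper_test_path_tokenizers` checks a path tokenizer's output against the
# solution coords at `footprint_inds`: each configured `StepTokenizer`
# (Coord, Distance, Cardinal, Relative) must contribute the expected tokens.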
def _helper_test_path_tokenizers(
    pt: PathTokenizers._PathTokenizer,
    maze: SolvedMaze,
    footprint_inds: Sequence[int],
):
    ct: CoordTokenizers._CoordTokenizer = CoordTokenizers.UT()
    path_toks: list[str] = pt.to_tokens(maze, ct)
    path_toks_set: set[str] = set(path_toks)
    footprint_inds: Int[np.ndarray, " footprint_index"] = np.array(footprint_inds)
    footprints: Int[np.ndarray, "footprint_index row_col=2"] = maze.solution[
        footprint_inds
    ]
    if StepTokenizers.Coord() in pt.step_tokenizers:
        non_steps: set[CoordTup] = {tuple(c) for c in maze.solution} - {
            tuple(c) for c in footprints
        }
        assert all(ct.to_tokens(coord)[0] in path_toks_set for coord in footprints)
        assert all(ct.to_tokens(coord)[0] not in path_toks_set for coord in non_steps)
    if StepTokenizers.Distance() in pt.step_tokenizers:
        distances: np.ndarray = footprint_inds[1:] - footprint_inds[:-1]
        assert (
            len(
                Counter(getattr(VOCAB, f"I_{d:03}") for d in distances)
                - Counter(path_toks),
            )
            == 0
        )
    if StepTokenizers.Cardinal() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_NORTH]
            + c[VOCAB.PATH_SOUTH]
            + c[VOCAB.PATH_EAST]
            + c[VOCAB.PATH_WEST]
            == len(footprint_inds) - 1
        )
    if StepTokenizers.Relative() in pt.step_tokenizers:
        c = Counter(path_toks)
        assert (
            c[VOCAB.PATH_LEFT]
            + c[VOCAB.PATH_RIGHT]
            + c[VOCAB.PATH_FORWARD]
            + c[VOCAB.PATH_BACKWARD]
            == len(footprint_inds) - 1
        )
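
# `test_path_tokenizers` fuzzes sampled path tokenizers against hand-written
# mazes: it computes which solution indices should be footprints for each
# `StepSize`, then delegates the token-level checks to the helper above.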
@pytest.mark.parametrize(
    ("pt", "manual_maze"),
    [
        pytest.param(tokenizer, maze_kv[1], id=f"{tokenizer.name}-{maze_kv[0]}")
        for maze_kv, tokenizer in itertools.product(
            ASCII_MAZES.items(),
            random.sample(
                list(
                    all_instances(
                        PathTokenizers._PathTokenizer,
                        {_TokenizerElement: lambda x: x.is_valid()},
                    ),
                ),
                NUM_TOKENIZERS_TO_TEST,
            ),
        )
    ],
)
def test_path_tokenizers(pt: PathTokenizers._PathTokenizer, manual_maze: MANUAL_MAZE):
    solved_maze: SolvedMaze = SolvedMaze.from_ascii("\n".join(manual_maze.ascii))
    match type(pt.step_size):
        case StepSizes.Singles:
            footprint_inds = range(solved_maze.solution.shape[0])
        case StepSizes.Straightaways:
            swy_coordtup_set: set[CoordTup] = {
                tuple(c) for c in manual_maze.straightaway_footprints
            }
            footprint_inds: list[int] = [
                i
                for i, c in enumerate(solved_maze.solution)
                if tuple(c) in swy_coordtup_set
            ]
        case StepSizes.Forks:
            footprint_inds = solved_maze.get_solution_forking_points(
                always_include_endpoints=True,
            )[0]
        case StepSizes.ForksAndStraightaways:
            swy_step_inds: list[int] = StepSizes.Straightaways()._step_single_indices(
                solved_maze,
            )
            footprint_inds: Int[np.ndarray, " footprint_index"] = np.concatenate(
                (
                    solved_maze.get_solution_forking_points(
                        always_include_endpoints=True,
                    )[0],
                    swy_step_inds,
                ),
            )
            footprint_inds, _ = np.unique(footprint_inds, axis=0, return_index=True)
    _helper_test_path_tokenizers(
        pt,
        solved_maze,
        footprint_inds,
    )
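
# `test_edge_permuters` checks each `_EdgePermuter` against the asserted contract:
# `RandomCoords` returns the same array object (same shape, contents permuted),
# while `BothCoords` doubles the edge count by appending each edge reversed.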
@pytest.mark.parametrize(
    ("ep", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgePermuters._EdgePermuter,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_permuters(ep: EdgePermuters._EdgePermuter, maze: LatticeMaze):
    edges: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    edges_copy: ConnectionArray = connection_list_to_adj_list(
        maze.connection_list,
        shuffle_d0=False,
        shuffle_d1=False,
    )
    assert np.array_equal(edges, edges_copy)
    old_shape = edges.shape
    permuted: ConnectionArray = ep._permute(edges)
    match ep:
        case EdgePermuters.RandomCoords():
            assert permuted.shape == old_shape
            assert edges is permuted
            i = 0
            while np.array_equal(permuted, edges_copy) and i < 2:
                # Permute again in case the random permutation of a small maze happened not to change anything
                permuted: ConnectionArray = ep._permute(permuted)
                i += 1
            assert not np.array_equal(permuted, edges_copy)
        case EdgePermuters.BothCoords():
            new_shape = old_shape[0] * 2, *old_shape[1:]
            n = old_shape[0]
            assert permuted.shape == new_shape
            assert np.array_equal(permuted[:n, ...], edges_copy)
            assert np.array_equal(permuted[:n, 0, :], permuted[n:, 1, :])
            assert np.array_equal(permuted[:n, 1, :], permuted[n:, 0, :])
            assert edges is not permuted
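
# `test_edge_subsets` verifies edge counts and shapes for each `_EdgeSubset`.
# An n x n lattice has 2 * n * (n - 1) undirected edges, so connection edges
# plus wall edges must always sum to that total.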
@pytest.mark.parametrize(
    ("es", "maze"),
    [
        pytest.param(tokenizer, maze, id=f"{tokenizer.name}-maze[{i}]")
        for (i, maze), tokenizer in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_subsets(es: EdgeSubsets._EdgeSubset, maze: LatticeMaze):
    edges: ConnectionArray = es._get_edges(maze)
    n: int = maze.grid_n
    match type(es):
        case EdgeSubsets.AllLatticeEdges:
            assert_shape: tuple = (2 * n * (n - 1), 2, 2)
        case EdgeSubsets.ConnectionEdges:
            if not es.walls:
                assert_shape: tuple = (np.count_nonzero(maze.connection_list), 2, 2)
            else:
                assert_shape: tuple = (
                    2 * n * (n - 1) - np.count_nonzero(maze.connection_list),
                    2,
                    2,
                )
    assert edges.dtype == np.int8
    assert assert_shape == tuple(edges.shape)
    assert assert_shape == tuple(
        np.unique(edges, axis=0).shape,
    )  # All edges are unique (swapping leading/trailing coords is considered different)
    assert np.array_equal(
        manhattan_distance(edges),
        np.array([1] * assert_shape[0], dtype=np.int8),
    )
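
# `test_edge_groupings` checks that `ByLeadingCoord` produces one group per
# distinct leading coord and that, within each group, consecutive trailing
# coords differ only by vectors possible between two neighbors of a shared coord.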
@pytest.mark.parametrize(
    ("tok_elem", "es", "maze"),
    [
        # we do a little accessing private members here
        pytest.param(tok_elem, es, maze, id=f"{tok_elem.name}-{es.name}-maze[{i}]")
        for (i, maze), tok_elem, es in itertools.product(
            enumerate(MIXED_MAZES[:6]),
            all_instances(
                EdgeGroupings._EdgeGrouping,
                frozendict.frozendict(
                    {
                        _TokenizerElement: lambda x: x.is_valid(),
                        # Add a condition to prune the range space that doesn't affect functionality being tested
                        EdgeGroupings.ByLeadingCoord: lambda x: x.intra
                        and x.connection_token_ordinal == 1,
                    },
                ),
            ),
            all_instances(
                EdgeSubsets._EdgeSubset,
                frozendict.frozendict({_TokenizerElement: lambda x: x.is_valid()}),
            ),
        )
    ],
)
def test_edge_groupings(
    tok_elem: EdgeGroupings._EdgeGrouping,
    es: EdgeSubsets._EdgeSubset,
    maze: LatticeMaze,
):
    # we do a little more accessing private members here
    edges: ConnectionArray = es._get_edges(maze)
    # n: int = maze.grid_n
    groups: Sequence[ConnectionArray] = tok_elem._group_edges(edges)

    assert all(
        not np.any(np.diff(g[:, 0], axis=0)) for g in groups
    )  # Asserts that the leading coord is the same for all edges within each group
    match type(tok_elem):
        case EdgeGroupings.Ungrouped:
            assert_shape = edges.shape[0], 1, 2, 2
            assert tuple(groups.shape) == assert_shape
        case EdgeGroupings.ByLeadingCoord:
            assert len(groups) == np.unique(edges[:, 0, :], axis=0).shape[0]
            assert sum(g.shape[0] for g in groups) == edges.shape[0]
            # trailing_coords: list[CoordArray] = [g[:, 1, :] for g in groups]
            # vector_diffs is the position vector difference between the trailing coords of each group
            # These are stacked into a single array since we don't care about maintaining group separation
            vector_diffs: CoordArray = np.stack(
                list(flatten([np.diff(g[:, 1, :], axis=0) for g in groups], 1)),
            )
            if tok_elem.shuffle_group:
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
                # The set of all 2D vectors between any 2 coords adjacent to a central coord
                allowed_diffs = allowed_diffs.union(
                    {(-d[0], -d[1]) for d in allowed_diffs},
                )
            else:
                # If vector_diffs are lexicographically sorted, these are the only possible values.
                # Any other value indicates an error in sorting.
                allowed_diffs = {(1, -1), (1, 1), (0, 2), (2, 0)}
            assert all(
                tuple(diff) in allowed_diffs for diff in np.unique(vector_diffs, axis=0)
            )

random.seed(GLOBAL_SEED)
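
# `test_adjlist_tokenizers` fuzzes sampled adjacency-list tokenizers and checks
# token counts: expected edge and group counts are derived from the edge subset,
# permuter, and grouping, then compared against the emitted vocabulary tokens.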
@pytest.mark.parametrize(
    ("tok_elem", "maze"),
    [
        pytest.param(tok_elem, maze, id=f"{tok_elem.name}-maze[{i}]")
        for (i, maze), tok_elem in itertools.product(
            enumerate(MAZE_DATASET),
            random.sample(
                list(
                    all_instances(
                        # yes we access a private member
                        AdjListTokenizers._AdjListTokenizer,
                        {
                            _TokenizerElement: lambda x: x.is_valid(),
                        },
                    ),
                ),
                100,
            ),
        )
    ],
)
# too many branches and "too complex" but whatever
def test_adjlist_tokenizers(  # noqa: PLR0912,C901
    tok_elem: AdjListTokenizers._AdjListTokenizer,
    maze: LatticeMaze,
):
    toks: list[str] = tok_elem.to_tokens(maze, CoordTokenizers.UT())
    tok_counter: Counter = Counter(toks)
    n: int = maze.grid_n
    edge_count: int = 1  # To be updated in match/case blocks
    group_count: int = 1  # To be updated in match/case blocks

    match tok_elem.edge_subset:
        case EdgeSubsets.AllLatticeEdges():
            edge_count *= n * (n - 1) * 2
        case EdgeSubsets.ConnectionEdges(walls=False):
            edge_count *= np.count_nonzero(maze.connection_list)
        case EdgeSubsets.ConnectionEdges(walls=True):
            edge_count *= n * (n - 1) * 2 - np.count_nonzero(maze.connection_list)
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_subset = }"
            raise NotImplementedError(msg)

    match tok_elem.edge_permuter:
        case EdgePermuters.BothCoords():
            edge_count *= 2
            if tok_elem.edge_subset == EdgeSubsets.ConnectionEdges(walls=True):
                group_count *= np.count_nonzero(
                    lattice_max_degrees(n) - maze.coord_degrees() > 0,
                )  # All coords with at least 1 adjacent wall, not counting the outer boundary
            else:
                group_count *= np.count_nonzero(
                    maze.coord_degrees() > 0,
                )  # All coords with >0 connections
        case EdgePermuters.RandomCoords() | EdgePermuters.SortedCoords():
            edge_count *= 1
            group_count = None  # Group count is stochastic

    match type(tok_elem.edge_grouping):
        case EdgeGroupings.Ungrouped:
            group_count = edge_count  # Override all above cases
        case EdgeGroupings.ByLeadingCoord:
            if group_count is not None:
                group_count *= 1
            if tok_elem.edge_grouping.intra:
                assert tok_counter[VOCAB.ADJLIST_INTRA] == edge_count
        case _:
            msg: str = f"`match` case missing for {tok_elem.edge_grouping = }"
            raise NotImplementedError(msg)

    match type(tok_elem):
        case AdjListTokenizers.AdjListCoord:
            pass
        case AdjListTokenizers.AdjListCardinal:
            assert (
                tok_counter[VOCAB.PATH_NORTH]
                + tok_counter[VOCAB.PATH_SOUTH]
                + tok_counter[VOCAB.PATH_EAST]
                + tok_counter[VOCAB.PATH_WEST]
                == edge_count
            )

    if group_count is not None:
        if tok_elem.pre:
            assert tok_counter[VOCAB.ADJLIST_PRE] == group_count
        if tok_elem.post:
            assert tok_counter[VOCAB.ADJACENCY_ENDLINE] == group_count

    assert tok_counter[VOCAB.CONNECTOR] + tok_counter[VOCAB.ADJLIST_WALL] == edge_count
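
# `test_unsupported_elements` checks `is_valid` for a hand-picked set of
# elements, several of which are intentionally unsupported.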
@pytest.mark.parametrize(
    ("tok_elem", "valid"),
    [
        pytest.param(
            tok_elem,
            valid,
            id=f"{tok_elem!r}",
        )
        for tok_elem, valid in (
            [
                (StepSizes.ForksAndStraightaways(), False),
                (StepSizes.Straightaways(), False),
                (StepSizes.Forks(), True),
                (AdjListTokenizers.AdjListCoord(), True),
                (AdjListTokenizers.AdjListCoord(pre=True), False),
                (AdjListTokenizers.AdjListCardinal(), True),
                (AdjListTokenizers.AdjListCardinal(pre=True), False),
                (EdgeGroupings.Ungrouped(), True),
                (EdgeGroupings.ByLeadingCoord(), False),
                (EdgeGroupings.ByLeadingCoord(connection_token_ordinal=0), False),
            ]
        )
    ],
)
def test_unsupported_elements(tok_elem: _TokenizerElement, valid: bool):
    assert tok_elem.is_valid() == valid