Coverage for tests/unit/tokenization/test_maze_tokenization.py: 100%
14 statements
coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
import pytest

from maze_dataset import (
    LatticeMazeGenerators,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.testing_utils import LEGACY_AND_EQUIVALENT_TOKENIZERS
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular


@pytest.mark.parametrize(
    "tokenizer",
    [
        pytest.param(tokenizer, id=tokenizer.name)
        for tokenizer in LEGACY_AND_EQUIVALENT_TOKENIZERS
    ],
)
def test_tokenization_roundtrip(tokenizer: MazeTokenizer | MazeTokenizerModular) -> None:
    dataset: MazeDataset = MazeDataset.from_config(
        MazeDatasetConfig(
            name="test",
            grid_n=5,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        allow_generation_metadata_filter_mismatch=True,
    )

    dataset_tokenized: list[list[str]] = dataset.as_tokens(tokenizer)
    # dataset_tokenized_joined: list[str] = dataset.as_tokens(
    #     tokenizer, join_tokens_individual_maze=True
    # )

    # TODO: we can't test that these tokenizations match exactly, because the
    # order of edges in the adjacency list is random
    # (see the order-insensitive comparison sketch after this test)

    dataset_tokenized_individual: list[list[str]] = [
        maze.as_tokens(tokenizer) for maze in dataset.mazes
    ]

    # we can't easily type-hint that `from_tokens` returns a `SolvedMaze` here
    mazes_roundtrip: list[SolvedMaze] = [
        SolvedMaze.from_tokens(  # type: ignore[misc]
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized
    ]

    mazes_roundtrip_individual: list[SolvedMaze] = [
        SolvedMaze.from_tokens(  # type: ignore[misc]
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized_individual
    ]

    # NOTE: we can't compare the tokenizations explicitly, since adjacency-list
    # order is random; the commented-out check below would compare the
    # whole-dataset and per-maze tokenizations token by token
    # for maze_tok, maze_tok_indiv in zip(dataset_tokenized, dataset_tokenized_individual):
    #     assert all(
    #         x == y
    #         for x, y in zip(maze_tok, maze_tok_indiv)
    #     ), f"maze_tok: {' '.join(maze_tok)}, maze_tok_indiv: {' '.join(maze_tok_indiv)}"

    # test the roundtrip for both tokenizations
    # strict=True: all three lists have the same length by construction,
    # and a silent truncation here would mask a failure
    for maze, maze_rt, maze_rt_indiv in zip(
        dataset.mazes,
        mazes_roundtrip,
        mazes_roundtrip_individual,
        strict=True,
    ):
        assert maze == maze_rt, f"maze: {maze}, maze_rt: {maze_rt}"
        assert maze == maze_rt_indiv, f"maze: {maze}, maze_rt_indiv: {maze_rt_indiv}"
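

# An order-insensitive comparison could make the TODO/NOTE checks above
# testable despite the random adjacency-list order. The helper below is an
# untested sketch, not part of maze_dataset: it assumes the adjacency list is
# delimited by the "<ADJLIST_START>" / "<ADJLIST_END>" special tokens and that
# each edge group ends with the ";" separator token (exact token names and
# edge encoding may differ between tokenizers).
def _adjlist_edge_set(tokens: list[str]) -> set[tuple[str, ...]]:
    """hypothetical helper: collect adjacency-list edge groups as a set"""
    start: int = tokens.index("<ADJLIST_START>") + 1
    end: int = tokens.index("<ADJLIST_END>")
    edges: set[tuple[str, ...]] = set()
    group: list[str] = []
    for tok in tokens[start:end]:
        group.append(tok)
        if tok == ";":
            edges.add(tuple(group))
            group = []
    return edges


# two tokenizations of the same maze could then be compared with, e.g.:
# assert _adjlist_edge_set(maze_tok) == _adjlist_edge_set(maze_tok_indiv)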