Coverage for tests/unit/tokenization/test_maze_tokenization.py: 100% (14 statements)

import pytest

from maze_dataset import (
    LatticeMazeGenerators,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.testing_utils import LEGACY_AND_EQUIVALENT_TOKENIZERS
from maze_dataset.tokenization import MazeTokenizer, MazeTokenizerModular


@pytest.mark.parametrize(
    "tokenizer",
    [
        pytest.param(tokenizer, id=tokenizer.name)
        for tokenizer in LEGACY_AND_EQUIVALENT_TOKENIZERS
    ],
)
def test_tokenization_roundtrip(tokenizer: MazeTokenizer | MazeTokenizerModular):
    dataset: MazeDataset = MazeDataset.from_config(
        MazeDatasetConfig(
            name="test",
            grid_n=5,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        allow_generation_metadata_filter_mismatch=True,
    )

    dataset_tokenized: list[list[str]] = dataset.as_tokens(tokenizer)
    # dataset_tokenized_joined: list[str] = dataset.as_tokens(
    #     tokenizer, join_tokens_individual_maze=True
    # )

    # TODO: can't test that these match because the order of the adjacency list is random

    dataset_tokenized_individual: list[list[str]] = [
        maze.as_tokens(tokenizer) for maze in dataset.mazes
    ]

    # we can't easily type-hint that from_tokens usually returns a SolvedMaze
    mazes_roundtrip: list[SolvedMaze] = [
        SolvedMaze.from_tokens(  # type: ignore[misc]
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized
    ]

    mazes_roundtrip_individual: list[SolvedMaze] = [
        SolvedMaze.from_tokens(  # type: ignore[misc]
            tokens=maze_tokens,
            maze_tokenizer=tokenizer,
        )
        for maze_tokens in dataset_tokenized_individual
    ]

    # NOTE: can't compare the token sequences directly because the order of the
    # adjacency list is random; this would test both tokenized as a whole and
    # tokenized individually
    # for maze_tok, maze_tok_indiv in zip(dataset_tokenized, dataset_tokenized_individual):
    #     assert all(
    #         x == y
    #         for x, y in zip(maze_tok, maze_tok_indiv)
    #     ), f"maze_tok: {' '.join(maze_tok)}, maze_tok_indiv: {' '.join(maze_tok_indiv)}"

    # test roundtrip
    for maze, maze_rt, maze_rt_indiv in zip(
        dataset.mazes,
        mazes_roundtrip,
        mazes_roundtrip_individual,
        strict=False,
    ):
        assert maze == maze_rt, f"maze: {maze}, maze_rt: {maze_rt}"
        assert maze == maze_rt_indiv, f"maze: {maze}, maze_rt_indiv: {maze_rt_indiv}"
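
For reference, the roundtrip property this test exercises can be sketched outside of pytest. Every call below appears in the test above; the one assumption is that MazeTokenizerModular() is constructible with its default arguments.

# standalone sketch of the roundtrip property checked by test_tokenization_roundtrip;
# assumes MazeTokenizerModular() works with default arguments
from maze_dataset import (
    LatticeMazeGenerators,
    MazeDataset,
    MazeDatasetConfig,
    SolvedMaze,
)
from maze_dataset.tokenization import MazeTokenizerModular

tokenizer = MazeTokenizerModular()  # assumption: defaults give a usable tokenizer
dataset = MazeDataset.from_config(
    MazeDatasetConfig(
        name="demo",
        grid_n=5,
        n_mazes=1,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    allow_generation_metadata_filter_mismatch=True,
)
maze = dataset.mazes[0]
tokens: list[str] = maze.as_tokens(tokenizer)
maze_roundtrip = SolvedMaze.from_tokens(tokens=tokens, maze_tokenizer=tokenizer)
assert maze_roundtrip == maze  # tokenize -> detokenize recovers the same maze

Because each pytest.param is given the tokenizer's name as its id, a single parametrized case can be selected from the command line with pytest's -k filter on that name.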