Coverage for tests/all_tokenizers/test_all_tokenizers.py: 99%
91 statements
import itertools
import os
from collections import Counter
from typing import Callable, Iterable

import pytest
from zanj import ZANJ

from maze_dataset import VOCAB, VOCAB_LIST
from maze_dataset.maze.lattice_maze import SolvedMaze
from maze_dataset.testing_utils import MIXED_MAZES
from maze_dataset.token_utils import equal_except_adj_list_sequence
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    _TokenizerElement,
)
from maze_dataset.tokenization.modular.all_instances import all_instances
from maze_dataset.tokenization.modular.all_tokenizers import (
    EVERY_TEST_TOKENIZERS,
    MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
    sample_tokenizers_for_test,
    save_hashes,
)

# Size of the sample from `all_tokenizers.get_all_tokenizers()` to test.
# Read from the environment, falling back to a default of 100; a non-numeric
# value yields `None`.
_os_env_num_tokenizers: str = os.getenv("NUM_TOKENIZERS_TO_TEST", "100")
NUM_TOKENIZERS_TO_TEST: int | None = (
    int(_os_env_num_tokenizers) if _os_env_num_tokenizers.isdigit() else None
)
print(f"{NUM_TOKENIZERS_TO_TEST = }")

SAMPLED_TOKENIZERS: list[MazeTokenizerModular] = sample_tokenizers_for_test(
    NUM_TOKENIZERS_TO_TEST,
)

SAMPLED_MAZES: list[SolvedMaze] = MIXED_MAZES[:6]


@pytest.fixture(scope="session")
def save_tokenizer_hashes():
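    """Save tokenizer hashes once per session (see `all_tokenizers.save_hashes`)."""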
    save_hashes()


@pytest.mark.parametrize(
    "class_",
    [pytest.param(c, id=c.__name__) for c in _TokenizerElement.__subclasses__()],
)
def test_all_instances_tokenizerelement(class_: type):
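    """`all_instances` should yield no duplicate instances (by hash) for any `_TokenizerElement` subclass."""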
    all_vals = list(
        all_instances(
            class_,
            validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
        ),
    )
    assert len({hash(elem) for elem in all_vals}) == len(all_vals)


SAMPLE_MIN: int = len(EVERY_TEST_TOKENIZERS)
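# `sample_tokenizers_for_test` requires a sample size of at least `SAMPLE_MIN`,
# since every tokenizer in `EVERY_TEST_TOKENIZERS` must appear in the sample
# (see `test_sample_tokenizers_for_test` below)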


@pytest.mark.parametrize(
    ("n", "result"),
    [
        pytest.param(i, result)
        for i, result in [
            (SAMPLE_MIN - 1, ValueError),
            (SAMPLE_MIN, None),
            (SAMPLE_MIN + 5, None),
            (SAMPLE_MIN + 200, None),
        ]
    ],
)
def test_sample_tokenizers_for_test(n: int, result: type[Exception] | None):
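    """`sample_tokenizers_for_test` should raise for undersized `n`, return unique tokenizers containing all of `EVERY_TEST_TOKENIZERS`, and vary between calls."""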
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            sample_tokenizers_for_test(n)
        return
    mts: list[MazeTokenizerModular] = sample_tokenizers_for_test(n)
    mts_set: set[MazeTokenizerModular] = set(mts)
    assert len(mts) == len(mts_set)
    assert set(EVERY_TEST_TOKENIZERS).issubset(mts_set)
    if n > SAMPLE_MIN + 1:
        mts2: list[MazeTokenizerModular] = sample_tokenizers_for_test(n)
        assert set(mts2) != mts_set  # check that successive samples are different


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_token_region_delimiters(tokenizer: MazeTokenizerModular):
    """<PATH_START> and similar token region delimiters should appear at most once, regardless of tokenizer."""
    for maze in SAMPLED_MAZES:
        counts: Counter = Counter(maze.as_tokens(tokenizer))
        assert all(counts[tok] < 2 for tok in VOCAB_LIST[:8])


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_token_stability(tokenizer: MazeTokenizerModular):
    """Tests consistency of tokenizations over multiple method calls."""
    for maze in SAMPLED_MAZES:
        tokens1: list[str] = maze.as_tokens(tokenizer)
        tokens2: list[str] = maze.as_tokens(tokenizer)
        if tokenizer.has_element(
            EdgeGroupings.ByLeadingCoord,
            EdgePermuters.RandomCoords,
        ) or tokenizer.has_element(
            AdjListTokenizers.AdjListCardinal,
            EdgePermuters.RandomCoords,
        ):
            # In this case, the adjlist is expected to have different token counts
            # over multiple calls; exclude that region from the test
            non_adjlist1 = tokens1[: tokens1.index(VOCAB.ADJLIST_START)]
            non_adjlist1.extend(tokens1[tokens1.index(VOCAB.ADJLIST_END) :])
            non_adjlist2 = tokens2[: tokens2.index(VOCAB.ADJLIST_START)]
            non_adjlist2.extend(tokens2[tokens2.index(VOCAB.ADJLIST_END) :])
            assert non_adjlist1 == non_adjlist2
        else:
            assert equal_except_adj_list_sequence(tokens1, tokens2)


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_tokenizer_properties(tokenizer: MazeTokenizerModular):
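    """Sanity-check basic tokenizer properties: name, vocab size, token array, and padding token."""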
    # just make sure the call doesn't raise an exception
    assert len(tokenizer.name) > 5

    assert tokenizer.vocab_size == 4096
    assert isinstance(tokenizer.token_arr, Iterable)
    assert all(isinstance(token, str) for token in tokenizer.token_arr)
    assert tokenizer.token_arr[tokenizer.padding_token_index] == VOCAB.PADDING

    # just make sure the call doesn't raise an exception
    print(tokenizer.summary())


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_encode_decode(tokenizer: MazeTokenizerModular):
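    """Encoding tokens to ids and decoding back should round-trip exactly."""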
    for maze in SAMPLED_MAZES:
        maze_tok: list[str] = maze.as_tokens(maze_tokenizer=tokenizer)
        maze_encoded: list[int] = tokenizer.encode(maze_tok)
        maze_decoded: list[str] = tokenizer.decode(maze_encoded)
        assert maze_tok == maze_decoded


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_zanj_save_read(tokenizer: MazeTokenizerModular):
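    """A tokenizer saved to disk with ZANJ should read back as an equal object."""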
    path = os.path.abspath(
        os.path.join(
            os.path.curdir,
            "data",
            f"mmt.{tokenizer.hash_b64()}.zanj",
        ),
    )
    zanj = ZANJ()
    zanj.save(tokenizer, path)
    assert zanj.read(path) == tokenizer


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_is_AOTP(tokenizer: MazeTokenizerModular):
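    """`is_AOTP` should be true exactly when the prompt sequencer is a `PromptSequencers.AOTP`."""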
    if isinstance(tokenizer.prompt_sequencer, PromptSequencers.AOTP):
        assert tokenizer.is_AOTP()
    else:
        assert not tokenizer.is_AOTP()


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_is_UT(tokenizer: MazeTokenizerModular):
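    """`is_UT` should be true exactly when the coord tokenizer is a `CoordTokenizers.UT`."""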
    if isinstance(tokenizer.prompt_sequencer.coord_tokenizer, CoordTokenizers.UT):
        assert tokenizer.is_UT()
    else:
        assert not tokenizer.is_UT()
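

# type alias for the `elems` argument accepted by `MazeTokenizerModular.has_element`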
_has_elems_type = (
    type[_TokenizerElement]
    | _TokenizerElement
    | Iterable[type[_TokenizerElement] | _TokenizerElement]
)


@pytest.mark.parametrize(
    ("tokenizer", "elems", "result_func"),
    [
        pytest.param(
            tokenizer,
            elems_tuple[0],
            elems_tuple[1],
            id=f"{tokenizer.name}-{elems_tuple[0]}",
        )
        for tokenizer, elems_tuple in itertools.product(
            SAMPLED_TOKENIZERS,
            [
                (
                    [PromptSequencers.AOTP()],
                    lambda mt, els: mt.prompt_sequencer == els[0],
                ),
                (PromptSequencers.AOTP(), lambda mt, els: mt.prompt_sequencer == els),
                (
                    [CoordTokenizers.CTT()],
                    lambda mt, els: mt.prompt_sequencer.coord_tokenizer == els[0],
                ),
                (
                    CoordTokenizers.CTT(intra=False),
                    lambda mt, els: mt.prompt_sequencer.coord_tokenizer == els,
                ),
                (
                    [CoordTokenizers.CTT],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    ),
                ),
                (
                    CoordTokenizers._CoordTokenizer,
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els,
                    ),
                ),
                (
                    StepSizes.Singles,
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.path_tokenizer.step_size,
                        els,
                    ),
                ),
                (
                    StepTokenizers.Coord,
                    lambda mt, els: any(
                        isinstance(step_tok, els)
                        for step_tok in mt.prompt_sequencer.path_tokenizer.step_tokenizers
                    ),
                ),
                (
                    [CoordTokenizers.CTT, PathTokenizers.StepSequence],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and isinstance(mt.prompt_sequencer.path_tokenizer, els[1]),
                ),
                # ((a for a in [CoordTokenizers.CTT, PathTokenizers.Coords]),
                # lambda mt, els: isinstance(mt.coord_tokenizer, list(els)[0]) and isinstance(mt.path_tokenizer, list(els)[1])
                # ),
                (
                    [CoordTokenizers.CTT, PathTokenizers.StepSequence(post=False)],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and mt.prompt_sequencer.path_tokenizer == els[1],
                ),
                (
                    [
                        CoordTokenizers.CTT,
                        PathTokenizers.StepSequence,
                        PromptSequencers.AOP(),
                    ],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and isinstance(mt.prompt_sequencer.path_tokenizer, els[1])
                    and mt.prompt_sequencer == els[2],
                ),
            ],
        )
    ],
)
def test_has_element(
    tokenizer: MazeTokenizerModular,
    elems: _has_elems_type,
    result_func: Callable[[MazeTokenizerModular, _has_elems_type], bool],
):
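    """`has_element` should agree with a manual check (`result_func`) of the tokenizer's elements."""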
    assert tokenizer.has_element(elems) == result_func(tokenizer, elems)