Coverage for tests/unit/tokenization/test_token_utils.py: 97% (175 statements)
import itertools
from typing import Callable

import frozendict
import numpy as np
import pytest
from jaxtyping import Int

from maze_dataset import LatticeMaze
from maze_dataset.constants import VOCAB, Connection, ConnectionArray
from maze_dataset.generation import numpy_rng
from maze_dataset.testing_utils import GRID_N, MAZE_DATASET
from maze_dataset.token_utils import (
    _coord_to_strings_UT,
    coords_to_strings,
    equal_except_adj_list_sequence,
    get_adj_list_tokens,
    get_origin_tokens,
    get_path_tokens,
    get_relative_direction,
    get_target_tokens,
    is_connection,
    strings_to_coords,
    tokens_between,
)
from maze_dataset.tokenization import (
    PathTokenizers,
    StepTokenizers,
    get_tokens_up_to_path_start,
)
from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances
from maze_dataset.utils import (
    lattice_connection_array,
    manhattan_distance,
)

MAZE_TOKENS: tuple[list[str], str] = (
    "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    "AOTP_UT",
)
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
    "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    "AOTP_CTT_indexed",
)
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
    MAZE_TOKENS,
    MAZE_TOKENS_AOTP_CTT_indexed,
]
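
# Two renderings of the same 2x2 maze: "AOTP_UT" gives each coordinate a
# single unique token such as "(1,0)", while "AOTP_CTT_indexed" spells each
# coordinate out as the separate tokens "(", "1", ",", "0", ")".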


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between(toks: list[str], tokenizer_name: str):
    result = tokens_between(toks, "<PATH_START>", "<PATH_END>")
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"]

    # Normal case
    tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
    start_value = "quick"
    end_value = "over"
    assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"]

    # Including start and end values
    assert tokens_between(tokens, start_value, end_value, True, True) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
    ]

    # When start_value or end_value is not unique and except_when_tokens_not_unique is True
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "the", "dog", False, False, True)

    # When start_value or end_value is not unique and except_when_tokens_not_unique is False
    assert tokens_between(tokens, "the", "dog", False, False, False) == [
        "quick",
        "brown",
        "fox",
        "jumps",
        "over",
        "the",
        "lazy",
    ]

    # Empty tokens list
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between([], "start", "end")

    # start_value and end_value are the same
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "fox", "fox")

    # start_value or end_value not in the tokens list
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(tokens, "start", "end")

    # start_value comes after end_value in the tokens list
    with pytest.raises(AssertionError):
        tokens_between(tokens, "over", "quick")

    # start_value and end_value are at the beginning and end of the tokens list, respectively
    assert tokens_between(tokens, "the", "dog", True, True) == tokens

    # Single element in the tokens list, which is the same as start_value and end_value
    with pytest.raises(ValueError):  # noqa: PT011
        tokens_between(["fox"], "fox", "fox", True, True)
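

# A minimal sketch of the slicing behavior the assertions above pin down; this
# is an illustration, not the library's implementation, and it assumes unique
# start/end markers (the real `tokens_between` additionally raises ValueError
# on empty input, identical markers, and optionally on non-unique markers).
def _tokens_between_sketch(
    tokens: list[str],
    start_value: str,
    end_value: str,
    include_start: bool = False,
    include_end: bool = False,
) -> list[str]:
    start_idx: int = tokens.index(start_value)
    end_idx: int = tokens.index(end_value)
    # the start marker must precede the end marker, mirroring the
    # AssertionError cases exercised above
    assert start_idx < end_idx
    lo: int = start_idx if include_start else start_idx + 1
    hi: int = end_idx + 1 if include_end else end_idx
    return tokens[lo:hi]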


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    with pytest.raises(AssertionError):
        tokens_between(toks, "<PATH_END>", "<PATH_START>")


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str):
    result = get_adj_list_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            expected = (
                "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split()
            )
        case "AOTP_CTT_indexed":
            expected = "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split()
    assert result == expected
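

# Judging by the expected values above, `get_adj_list_tokens(toks)` behaves
# like `tokens_between(toks, "<ADJLIST_START>", "<ADJLIST_END>")`; that
# equivalence is inferred from these assertions, not from the implementation.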


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_path_tokens(toks: list[str], tokenizer_name: str):
    result_notrim = get_path_tokens(toks)
    result_trim = get_path_tokens(toks, trim_end=True)
    match tokenizer_name:
        case "AOTP_UT":
            assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"]
            assert result_trim == ["(1,0)", "(1,1)"]
        case "AOTP_CTT_indexed":
            assert (
                result_notrim == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
            )
            assert result_trim == "( 1 , 0 ) ( 1 , 1 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
    result = get_origin_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,0)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 0 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_target_tokens(toks: list[str], tokenizer_name: str):
    result = get_target_tokens(toks)
    match tokenizer_name:
        case "AOTP_UT":
            assert result == ["(1,1)"]
        case "AOTP_CTT_indexed":
            assert result == "( 1 , 1 )".split()


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in [MAZE_TOKENS]
    ],
)
def test_get_tokens_up_to_path_start_including_start(
    toks: list[str],
    tokenizer_name: str,
):
    # Don't test on `MAZE_TOKENS_AOTP_CTT_indexed`, because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`.
    result = get_tokens_up_to_path_start(toks, include_start_coord=True)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 )".split()
    assert result == expected


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_get_tokens_up_to_path_start_excluding_start(
    toks: list[str],
    tokenizer_name: str,
):
    result = get_tokens_up_to_path_start(toks, include_start_coord=False)
    match tokenizer_name:
        case "AOTP_UT":
            expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split()
        case "AOTP_CTT_indexed":
            expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split()
    assert result == expected


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    skipped = strings_to_coords(adj_list, when_noncoord="skip")
    included = strings_to_coords(adj_list, when_noncoord="include")

    assert skipped == [
        (0, 1),
        (1, 1),
        (1, 0),
        (1, 1),
        (0, 1),
        (0, 0),
    ]

    assert included == [
        (0, 1),
        "<-->",
        (1, 1),
        ";",
        (1, 0),
        "<-->",
        (1, 1),
        ";",
        (0, 1),
        "<-->",
        (0, 0),
        ";",
    ]

    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords(adj_list, when_noncoord="error")

    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)") == [(1, 2), (5, 6)]
    assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="skip") == [
        (1, 2),
        (5, 6),
    ]
    assert strings_to_coords(
        "(1,2) <ADJLIST_START> (5,6)",
        when_noncoord="include",
    ) == [(1, 2), "<ADJLIST_START>", (5, 6)]
    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="error")
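

# The `when_noncoord` contract exercised above: "skip" drops non-coordinate
# tokens, "include" passes them through unchanged as strings, and "error"
# raises ValueError on the first non-coordinate token. A whitespace-joined
# string is accepted in place of a token list, as the
# "(1,2) <ADJLIST_START> (5,6)" cases show.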


@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [
        pytest.param(
            token_list[0],
            token_list[1],
            id=f"{token_list[1]}",
        )
        for token_list in TEST_TOKEN_LISTS
    ],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    # config = MazeDatasetConfig(name="test", grid_n=2, n_mazes=1)
    coords = strings_to_coords(adj_list, when_noncoord="include")

    skipped = coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="skip",
    )
    included = coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="include",
    )

    assert skipped == [
        "(0,1)",
        "(1,1)",
        "(1,0)",
        "(1,1)",
        "(0,1)",
        "(0,0)",
    ]

    assert included == [
        "(0,1)",
        "<-->",
        "(1,1)",
        ";",
        "(1,0)",
        "<-->",
        "(1,1)",
        ";",
        "(0,1)",
        "<-->",
        "(0,0)",
        ";",
    ]

    with pytest.raises(ValueError):  # noqa: PT011
        coords_to_strings(
            coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord="error",
        )
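

# `coords_to_strings` is the inverse direction: it re-renders coordinate
# tuples via an explicit `coord_to_strings_func` (here `_coord_to_strings_UT`,
# which renders (1, 0) as "(1,0)") and handles non-coordinate entries with the
# same "skip" / "include" / "error" modes as `strings_to_coords`.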


def test_equal_except_adj_list_sequence():
    assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
    assert not equal_except_adj_list_sequence(
        MAZE_TOKENS[0],
        MAZE_TOKENS_AOTP_CTT_indexed[0],
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )

    # CTT
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
    # See function documentation for details.
    # assert not equal_except_adj_list_sequence(
    #     "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    #     "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
    # )


# @mivanit: this was really difficult to understand
@pytest.mark.parametrize(
    ("type_", "validation_funcs", "assertion"),
    [
        pytest.param(
            type_,
            vfs,
            assertion,
            id=f"{i}-{type_.__name__}",
        )
        for i, (type_, vfs, assertion) in enumerate(
            [
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    dict(),
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    in x,
                ),
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    not in x
                    and PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    )
                    not in x,
                ),
            ],
        )
    ],
)
def test_all_instances2(
    type_: FiniteValued,
    validation_funcs: frozendict.frozendict[
        FiniteValued,
        Callable[[FiniteValued], bool],
    ],
    assertion: Callable[[list[FiniteValued]], bool],
):
    assert assertion(all_instances(type_, validation_funcs))
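

# `all_instances(type_, validation_funcs)` enumerates every instance of a
# `FiniteValued` type, filtered by per-type predicates in `validation_funcs`.
# The two cases above check that with no filter a degenerate path tokenizer
# (a lone `Distance` step) is enumerated, and that an `is_valid()` filter
# removes it along with a duplicated `Coord` step sequence.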


@pytest.mark.parametrize(
    ("coords", "result"),
    [
        pytest.param(
            np.array(coords),
            res,
            id=f"{coords}",
        )
        for coords, res in (
            [
                ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
                ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
                ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
                ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
                ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 1], [0, 1], [0, 0]], ValueError),
                ([[0, 1], [1, 1], [0, 0]], ValueError),
                ([[1, 0], [1, 1], [0, 0]], ValueError),
                ([[0, 1], [0, 2], [0, 0]], ValueError),
                ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 0], [0, 1]], ValueError),
                ([[1, 1], [0, 0], [1, 0]], ValueError),
                ([[0, 2], [0, 0], [0, 1]], ValueError),
                ([[0, 0], [0, 0], [0, 1]], ValueError),
                ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
                ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
                ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
                ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
                ([[-1, 0], [0, 0]], ValueError),
                ([[-1, 0, 0], [0, 0, 0]], ValueError),
            ]
        )
    ],
)
def test_get_relative_direction(
    coords: Int[np.ndarray, "prev_cur_next=3 axis=2"],
    result: str | type[Exception],
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            get_relative_direction(coords)
        return
    assert get_relative_direction(coords) == result
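

# `get_relative_direction` classifies a (previous, current, next) triple as a
# turn token: per the cases above, the previous -> current step must move to
# an adjacent cell, the current -> next step may move or stay in place, and
# anything else (non-adjacent steps, a repeated previous/current cell, extra
# coordinates, or 3D coordinates) raises ValueError.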


@pytest.mark.parametrize(
    ("edges", "result"),
    [
        pytest.param(
            edges,
            res,
            id=f"{edges}",
        )
        for edges, res in (
            [
                (np.array([[0, 0], [0, 1]]), 1),
                (np.array([[1, 0], [0, 1]]), 2),
                (np.array([[-1, 0], [0, 1]]), 2),
                (np.array([[0, 0], [5, 3]]), 8),
                (
                    np.array(
                        [
                            [[0, 0], [0, 1]],
                            [[1, 0], [0, 1]],
                            [[-1, 0], [0, 1]],
                            [[0, 0], [5, 3]],
                        ],
                    ),
                    [1, 2, 2, 8],
                ),
                (np.array([[[0, 0], [5, 3]]]), [8]),
            ]
        )
    ],
)
def test_manhattan_distance(
    edges: ConnectionArray | Connection,
    result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            manhattan_distance(edges)
        return
    assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8))
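

# Worked example from the cases above: the edge [[0, 0], [5, 3]] has distance
# |5 - 0| + |3 - 0| = 8. A single edge of shape (2, 2) yields a scalar, while
# a batch of shape (n, 2, 2) yields one distance per edge.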


@pytest.mark.parametrize(
    "n",
    [pytest.param(n) for n in [2, 3, 5, 20]],
)
def test_lattice_connection_array(n):
    edges = lattice_connection_array(n)
    assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2)
    assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1))
    assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2)
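

# Edge-count check: an n x n lattice has n rows of (n - 1) horizontal edges
# plus n columns of (n - 1) vertical edges, i.e. 2 * n * (n - 1) edges total.
# The other assertions verify each edge is sorted (the second endpoint has the
# larger coordinate sum) and that no edge appears twice.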


@pytest.mark.parametrize(
    ("edges", "maze"),
    [
        pytest.param(
            edges(),
            maze,
            id=f"edges[{i}]; maze[{j}]",
        )
        for (i, edges), (j, maze) in itertools.product(
            enumerate(
                [
                    lambda: lattice_connection_array(GRID_N),
                    lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
                    lambda: lattice_connection_array(GRID_N - 1),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        2 * GRID_N,
                        axis=0,
                    ),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        1,
                        axis=0,
                    ),
                ],
            ),
            enumerate(MAZE_DATASET.mazes),
        )
    ],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
    output = is_connection(edges, maze.connection_list)
    sorted_edges = np.sort(edges, axis=1)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    assert np.array_equal(
        output,
        maze.connection_list[
            edge_direction,
            sorted_edges[:, 0, 0],
            sorted_edges[:, 0, 1],
        ],
    )
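

# The cross-check recomputes `is_connection` by hand: after sorting each edge,
# a zero row difference means the two cells share a row (a rightward
# connection, index 1 in `connection_list`), otherwise the connection is
# downward (index 0). The expected value is then read from
# `connection_list[direction, row, col]` using the lexically smaller endpoint.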