Coverage for tests/unit/tokenization/test_token_utils.py: 97%
175 statements
coverage.py v7.10.1, created at 2025-08-03 21:38 -0700
1import itertools
2from typing import Callable
4import frozendict
5import numpy as np
6import pytest
7from jaxtyping import Int
9from maze_dataset import LatticeMaze
10from maze_dataset.constants import VOCAB, Connection, ConnectionArray
11from maze_dataset.generation import _NUMPY_RNG
12from maze_dataset.testing_utils import GRID_N, MAZE_DATASET
13from maze_dataset.token_utils import (
14 _coord_to_strings_UT,
15 coords_to_strings,
16 equal_except_adj_list_sequence,
17 get_adj_list_tokens,
18 get_origin_tokens,
19 get_path_tokens,
20 get_relative_direction,
21 get_target_tokens,
22 is_connection,
23 strings_to_coords,
24 tokens_between,
25)
26from maze_dataset.tokenization import (
27 PathTokenizers,
28 StepTokenizers,
29 get_tokens_up_to_path_start,
30)
31from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances
32from maze_dataset.utils import (
33 lattice_connection_array,
34 manhattan_distance,
35)
37MAZE_TOKENS: tuple[list[str], str] = (
38 [
39 "<ADJLIST_START>",
40 "(0,1)",
41 "<-->",
42 "(1,1)",
43 ";",
44 "(1,0)",
45 "<-->",
46 "(1,1)",
47 ";",
48 "(0,1)",
49 "<-->",
50 "(0,0)",
51 ";",
52 "<ADJLIST_END>",
53 "<ORIGIN_START>",
54 "(1,0)",
55 "<ORIGIN_END>",
56 "<TARGET_START>",
57 "(1,1)",
58 "<TARGET_END>",
59 "<PATH_START>",
60 "(1,0)",
61 "(1,1)",
62 "<PATH_END>",
63 ],
64 "AOTP_UT",
65)
66MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
67 [
68 "<ADJLIST_START>",
69 "(",
70 "0",
71 ",",
72 "1",
73 ")",
74 "<-->",
75 "(",
76 "1",
77 ",",
78 "1",
79 ")",
80 ";",
81 "(",
82 "1",
83 ",",
84 "0",
85 ")",
86 "<-->",
87 "(",
88 "1",
89 ",",
90 "1",
91 ")",
92 ";",
93 "(",
94 "0",
95 ",",
96 "1",
97 ")",
98 "<-->",
99 "(",
100 "0",
101 ",",
102 "0",
103 ")",
104 ";",
105 "<ADJLIST_END>",
106 "<ORIGIN_START>",
107 "(",
108 "1",
109 ",",
110 "0",
111 ")",
112 "<ORIGIN_END>",
113 "<TARGET_START>",
114 "(",
115 "1",
116 ",",
117 "1",
118 ")",
119 "<TARGET_END>",
120 "<PATH_START>",
121 "(",
122 "1",
123 ",",
124 "0",
125 ")",
126 "(",
127 "1",
128 ",",
129 "1",
130 ")",
131 "<PATH_END>",
132 ],
133 "AOTP_CTT_indexed",
134)
135TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
136 MAZE_TOKENS,
137 MAZE_TOKENS_AOTP_CTT_indexed,
138]
141@pytest.mark.parametrize(
142 ("toks", "tokenizer_name"),
143 [
144 pytest.param(
145 token_list[0],
146 token_list[1],
147 id=f"{token_list[1]}",
148 )
149 for token_list in TEST_TOKEN_LISTS
150 ],
151)
152def test_tokens_between(toks: list[str], tokenizer_name: str):
153 result = tokens_between(toks, "<PATH_START>", "<PATH_END>")
154 match tokenizer_name:
155 case "AOTP_UT":
156 assert result == ["(1,0)", "(1,1)"]
157 case "AOTP_CTT_indexed":
158 assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"]
160 # Normal case
161 tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]
162 start_value = "quick"
163 end_value = "over"
164 assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"]
166 # Including start and end values
167 assert tokens_between(tokens, start_value, end_value, True, True) == [
168 "quick",
169 "brown",
170 "fox",
171 "jumps",
172 "over",
173 ]
175 # When start_value or end_value is not unique and except_when_tokens_not_unique is True
176 with pytest.raises(ValueError): # noqa: PT011
177 tokens_between(tokens, "the", "dog", False, False, True)
179 # When start_value or end_value is not unique and except_when_tokens_not_unique is False
180 assert tokens_between(tokens, "the", "dog", False, False, False) == [
181 "quick",
182 "brown",
183 "fox",
184 "jumps",
185 "over",
186 "the",
187 "lazy",
188 ]
190 # Empty tokens list
191 with pytest.raises(ValueError): # noqa: PT011
192 tokens_between([], "start", "end")
194 # start_value and end_value are the same
195 with pytest.raises(ValueError): # noqa: PT011
196 tokens_between(tokens, "fox", "fox")
198 # start_value or end_value not in the tokens list
199 with pytest.raises(ValueError): # noqa: PT011
200 tokens_between(tokens, "start", "end")
202 # start_value comes after end_value in the tokens list
203 with pytest.raises(AssertionError):
204 tokens_between(tokens, "over", "quick")
206 # start_value and end_value are at the beginning and end of the tokens list, respectively
207 assert tokens_between(tokens, "the", "dog", True, True) == tokens
209 # Single element in the tokens list, which is the same as start_value and end_value
210 with pytest.raises(ValueError): # noqa: PT011
211 tokens_between(["fox"], "fox", "fox", True, True)
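# Illustrative usage sketch (not part of the test suite): the calls below restate the
# core `tokens_between` behavior asserted above on a fresh token list, with the 4th
# and 5th positional arguments meaning include-start and include-end as exercised in
# the test. Relies on the module-level import of `tokens_between`.
def _example_tokens_between() -> None:
    toks = ["<A>", "x", "y", "<B>"]
    # delimiters are excluded by default
    assert tokens_between(toks, "<A>", "<B>") == ["x", "y"]
    # passing the include-start and include-end flags keeps both delimiters
    assert tokens_between(toks, "<A>", "<B>", True, True) == toks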
214@pytest.mark.parametrize(
215 ("toks", "tokenizer_name"),
216 [
217 pytest.param(
218 token_list[0],
219 token_list[1],
220 id=f"{token_list[1]}",
221 )
222 for token_list in TEST_TOKEN_LISTS
223 ],
224)
225def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
226 assert tokenizer_name
227 with pytest.raises(AssertionError):
228 tokens_between(toks, "<PATH_END>", "<PATH_START>")
231@pytest.mark.parametrize(
232 ("toks", "tokenizer_name"),
233 [
234 pytest.param(
235 token_list[0],
236 token_list[1],
237 id=f"{token_list[1]}",
238 )
239 for token_list in TEST_TOKEN_LISTS
240 ],
241)
242def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str):
243 result = get_adj_list_tokens(toks)
244 match tokenizer_name:
245 case "AOTP_UT":
246 expected = [
247 "(0,1)",
248 "<-->",
249 "(1,1)",
250 ";",
251 "(1,0)",
252 "<-->",
253 "(1,1)",
254 ";",
255 "(0,1)",
256 "<-->",
257 "(0,0)",
258 ";",
259 ]
260 case "AOTP_CTT_indexed":
261 expected = [
262 "(",
263 "0",
264 ",",
265 "1",
266 ")",
267 "<-->",
268 "(",
269 "1",
270 ",",
271 "1",
272 ")",
273 ";",
274 "(",
275 "1",
276 ",",
277 "0",
278 ")",
279 "<-->",
280 "(",
281 "1",
282 ",",
283 "1",
284 ")",
285 ";",
286 "(",
287 "0",
288 ",",
289 "1",
290 ")",
291 "<-->",
292 "(",
293 "0",
294 ",",
295 "0",
296 ")",
297 ";",
298 ]
299 assert result == expected
302@pytest.mark.parametrize(
303 ("toks", "tokenizer_name"),
304 [
305 pytest.param(
306 token_list[0],
307 token_list[1],
308 id=f"{token_list[1]}",
309 )
310 for token_list in TEST_TOKEN_LISTS
311 ],
312)
313def test_get_path_tokens(toks: list[str], tokenizer_name: str):
314 result_notrim = get_path_tokens(toks)
315 result_trim = get_path_tokens(toks, trim_end=True)
316 match tokenizer_name:
317 case "AOTP_UT":
318 assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"]
319 assert result_trim == ["(1,0)", "(1,1)"]
320 case "AOTP_CTT_indexed":
321 assert result_notrim == [
322 "<PATH_START>",
323 "(",
324 "1",
325 ",",
326 "0",
327 ")",
328 "(",
329 "1",
330 ",",
331 "1",
332 ")",
333 "<PATH_END>",
334 ]
335 assert result_trim == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"]
338@pytest.mark.parametrize(
339 ("toks", "tokenizer_name"),
340 [
341 pytest.param(
342 token_list[0],
343 token_list[1],
344 id=f"{token_list[1]}",
345 )
346 for token_list in TEST_TOKEN_LISTS
347 ],
348)
349def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
350 result = get_origin_tokens(toks)
351 match tokenizer_name:
352 case "AOTP_UT":
353 assert result == ["(1,0)"]
354 case "AOTP_CTT_indexed":
355 assert result == ["(", "1", ",", "0", ")"]
358@pytest.mark.parametrize(
359 ("toks", "tokenizer_name"),
360 [
361 pytest.param(
362 token_list[0],
363 token_list[1],
364 id=f"{token_list[1]}",
365 )
366 for token_list in TEST_TOKEN_LISTS
367 ],
368)
369def test_get_target_tokens(toks: list[str], tokenizer_name: str):
370 result = get_target_tokens(toks)
371 match tokenizer_name:
372 case "AOTP_UT":
373 assert result == ["(1,1)"]
374 case "AOTP_CTT_indexed":
375 assert result == ["(", "1", ",", "1", ")"]
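# Hedged sketch (not the library implementation): the section getters exercised above
# behave like slices between the matching delimiter tokens. A hypothetical
# reimplementation of `get_origin_tokens` in terms of `tokens_between`, consistent
# with the ["(1,0)"] result asserted above, would be:
def _example_origin_tokens(toks: list[str]) -> list[str]:
    # delimiters excluded, mirroring get_origin_tokens / get_target_tokens /
    # get_adj_list_tokens; get_path_tokens keeps its delimiters unless trim_end=True
    return tokens_between(toks, "<ORIGIN_START>", "<ORIGIN_END>")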
378@pytest.mark.parametrize(
379 ("toks", "tokenizer_name"),
380 [
381 pytest.param(
382 token_list[0],
383 token_list[1],
384 id=f"{token_list[1]}",
385 )
386 for token_list in [MAZE_TOKENS]
387 ],
388)
389def test_get_tokens_up_to_path_start_including_start(
390 toks: list[str],
391 tokenizer_name: str,
392):
393 # Don't test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`.
393 # Don't test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`.
394 result = get_tokens_up_to_path_start(toks, include_start_coord=True)
395 match tokenizer_name:
396 case "AOTP_UT":
397 expected = [
398 "<ADJLIST_START>",
399 "(0,1)",
400 "<-->",
401 "(1,1)",
402 ";",
403 "(1,0)",
404 "<-->",
405 "(1,1)",
406 ";",
407 "(0,1)",
408 "<-->",
409 "(0,0)",
410 ";",
411 "<ADJLIST_END>",
412 "<ORIGIN_START>",
413 "(1,0)",
414 "<ORIGIN_END>",
415 "<TARGET_START>",
416 "(1,1)",
417 "<TARGET_END>",
418 "<PATH_START>",
419 "(1,0)",
420 ]
421 case "AOTP_CTT_indexed":
422 expected = [
423 "<ADJLIST_START>",
424 "(",
425 "0",
426 ",",
427 "1",
428 ")",
429 "<-->",
430 "(",
431 "1",
432 ",",
433 "1",
434 ")",
435 ";",
436 "(",
437 "1",
438 ",",
439 "0",
440 ")",
441 "<-->",
442 "(",
443 "1",
444 ",",
445 "1",
446 ")",
447 ";",
448 "(",
449 "0",
450 ",",
451 "1",
452 ")",
453 "<-->",
454 "(",
455 "0",
456 ",",
457 "0",
458 ")",
459 ";",
460 "<ADJLIST_END>",
461 "<ORIGIN_START>",
462 "(",
463 "1",
464 ",",
465 "0",
466 ")",
467 "<ORIGIN_END>",
468 "<TARGET_START>",
469 "(",
470 "1",
471 ",",
472 "1",
473 ")",
474 "<TARGET_END>",
475 "<PATH_START>",
476 "(",
477 "1",
478 ",",
479 "0",
480 ")",
481 ]
482 assert result == expected
485@pytest.mark.parametrize(
486 ("toks", "tokenizer_name"),
487 [
488 pytest.param(
489 token_list[0],
490 token_list[1],
491 id=f"{token_list[1]}",
492 )
493 for token_list in TEST_TOKEN_LISTS
494 ],
495)
496def test_get_tokens_up_to_path_start_excluding_start(
497 toks: list[str],
498 tokenizer_name: str,
499):
500 result = get_tokens_up_to_path_start(toks, include_start_coord=False)
501 match tokenizer_name:
502 case "AOTP_UT":
503 expected = [
504 "<ADJLIST_START>",
505 "(0,1)",
506 "<-->",
507 "(1,1)",
508 ";",
509 "(1,0)",
510 "<-->",
511 "(1,1)",
512 ";",
513 "(0,1)",
514 "<-->",
515 "(0,0)",
516 ";",
517 "<ADJLIST_END>",
518 "<ORIGIN_START>",
519 "(1,0)",
520 "<ORIGIN_END>",
521 "<TARGET_START>",
522 "(1,1)",
523 "<TARGET_END>",
524 "<PATH_START>",
525 ]
526 case "AOTP_CTT_indexed":
527 expected = [
528 "<ADJLIST_START>",
529 "(",
530 "0",
531 ",",
532 "1",
533 ")",
534 "<-->",
535 "(",
536 "1",
537 ",",
538 "1",
539 ")",
540 ";",
541 "(",
542 "1",
543 ",",
544 "0",
545 ")",
546 "<-->",
547 "(",
548 "1",
549 ",",
550 "1",
551 ")",
552 ";",
553 "(",
554 "0",
555 ",",
556 "1",
557 ")",
558 "<-->",
559 "(",
560 "0",
561 ",",
562 "0",
563 ")",
564 ";",
565 "<ADJLIST_END>",
566 "<ORIGIN_START>",
567 "(",
568 "1",
569 ",",
570 "0",
571 ")",
572 "<ORIGIN_END>",
573 "<TARGET_START>",
574 "(",
575 "1",
576 ",",
577 "1",
578 ")",
579 "<TARGET_END>",
580 "<PATH_START>",
581 ]
582 assert result == expected
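# Hedged summary sketch of `get_tokens_up_to_path_start` as exercised above
# (illustrative, not part of the test suite): with include_start_coord=False the
# returned prefix ends at "<PATH_START>"; with include_start_coord=True the first path
# coordinate is appended as well, which for the UT format is the single token "(1,0)".
# As noted in the earlier test, the include_start_coord=True case is not supported for
# the AOTP_CTT_indexed format.
def _example_up_to_path_start() -> None:
    toks = MAZE_TOKENS[0]
    prompt = get_tokens_up_to_path_start(toks, include_start_coord=False)
    assert prompt[-1] == "<PATH_START>"
    assert get_tokens_up_to_path_start(toks, include_start_coord=True) == [*prompt, "(1,0)"]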
585@pytest.mark.parametrize(
586 ("toks", "tokenizer_name"),
587 [
588 pytest.param(
589 token_list[0],
590 token_list[1],
591 id=f"{token_list[1]}",
592 )
593 for token_list in TEST_TOKEN_LISTS
594 ],
595)
596def test_strings_to_coords(toks: list[str], tokenizer_name: str):
597 assert tokenizer_name
598 adj_list = get_adj_list_tokens(toks)
599 skipped = strings_to_coords(adj_list, when_noncoord="skip")
600 included = strings_to_coords(adj_list, when_noncoord="include")
602 assert skipped == [
603 (0, 1),
604 (1, 1),
605 (1, 0),
606 (1, 1),
607 (0, 1),
608 (0, 0),
609 ]
611 assert included == [
612 (0, 1),
613 "<-->",
614 (1, 1),
615 ";",
616 (1, 0),
617 "<-->",
618 (1, 1),
619 ";",
620 (0, 1),
621 "<-->",
622 (0, 0),
623 ";",
624 ]
626 with pytest.raises(ValueError): # noqa: PT011
627 strings_to_coords(adj_list, when_noncoord="error")
629 assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)") == [(1, 2), (5, 6)]
630 assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="skip") == [
631 (1, 2),
632 (5, 6),
633 ]
634 assert strings_to_coords(
635 "(1,2) <ADJLIST_START> (5,6)",
636 when_noncoord="include",
637 ) == [(1, 2), "<ADJLIST_START>", (5, 6)]
638 with pytest.raises(ValueError): # noqa: PT011
639 strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="error")
642@pytest.mark.parametrize(
643 ("toks", "tokenizer_name"),
644 [
645 pytest.param(
646 token_list[0],
647 token_list[1],
648 id=f"{token_list[1]}",
649 )
650 for token_list in TEST_TOKEN_LISTS
651 ],
652)
653def test_coords_to_strings(toks: list[str], tokenizer_name: str):
654 assert tokenizer_name
655 adj_list = get_adj_list_tokens(toks)
656 # config = MazeDatasetConfig(name="test", grid_n=2, n_mazes=1)
657 coords = strings_to_coords(adj_list, when_noncoord="include")
659 skipped = coords_to_strings(
660 coords,
661 coord_to_strings_func=_coord_to_strings_UT,
662 when_noncoord="skip",
663 )
664 included = coords_to_strings(
665 coords,
666 coord_to_strings_func=_coord_to_strings_UT,
667 when_noncoord="include",
668 )
670 assert skipped == [
671 "(0,1)",
672 "(1,1)",
673 "(1,0)",
674 "(1,1)",
675 "(0,1)",
676 "(0,0)",
677 ]
679 assert included == [
680 "(0,1)",
681 "<-->",
682 "(1,1)",
683 ";",
684 "(1,0)",
685 "<-->",
686 "(1,1)",
687 ";",
688 "(0,1)",
689 "<-->",
690 "(0,0)",
691 ";",
692 ]
694 with pytest.raises(ValueError): # noqa: PT011
695 coords_to_strings(
696 coords,
697 coord_to_strings_func=_coord_to_strings_UT,
698 when_noncoord="error",
699 )
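# Hedged round-trip note (illustrative, not part of the test suite): combining the two
# conversions tested above, `coords_to_strings` with `_coord_to_strings_UT` inverts
# `strings_to_coords` on the unique-token adjacency list when non-coordinate tokens
# are carried through with when_noncoord="include".
def _example_coord_string_roundtrip() -> None:
    adj_list = get_adj_list_tokens(MAZE_TOKENS[0])
    coords = strings_to_coords(adj_list, when_noncoord="include")
    assert (
        coords_to_strings(
            coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord="include",
        )
        == adj_list
    )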
702def test_equal_except_adj_list_sequence():
703 assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
704 assert not equal_except_adj_list_sequence(
705 MAZE_TOKENS[0],
706 MAZE_TOKENS_AOTP_CTT_indexed[0],
707 )
708 assert equal_except_adj_list_sequence(
709 [
710 "<ADJLIST_START>",
711 "(0,1)",
712 "<-->",
713 "(1,1)",
714 ";",
715 "(1,0)",
716 "<-->",
717 "(1,1)",
718 ";",
719 "(0,1)",
720 "<-->",
721 "(0,0)",
722 ";",
723 "<ADJLIST_END>",
724 "<ORIGIN_START>",
725 "(1,0)",
726 "<ORIGIN_END>",
727 "<TARGET_START>",
728 "(1,1)",
729 "<TARGET_END>",
730 "<PATH_START>",
731 "(1,0)",
732 "(1,1)",
733 "<PATH_END>",
734 ],
735 [
736 "<ADJLIST_START>",
737 "(0,1)",
738 "<-->",
739 "(1,1)",
740 ";",
741 "(1,0)",
742 "<-->",
743 "(1,1)",
744 ";",
745 "(0,1)",
746 "<-->",
747 "(0,0)",
748 ";",
749 "<ADJLIST_END>",
750 "<ORIGIN_START>",
751 "(1,0)",
752 "<ORIGIN_END>",
753 "<TARGET_START>",
754 "(1,1)",
755 "<TARGET_END>",
756 "<PATH_START>",
757 "(1,0)",
758 "(1,1)",
759 "<PATH_END>",
760 ],
761 )
762 assert equal_except_adj_list_sequence(
763 [
764 "<ADJLIST_START>",
765 "(0,1)",
766 "<-->",
767 "(1,1)",
768 ";",
769 "(1,0)",
770 "<-->",
771 "(1,1)",
772 ";",
773 "(0,1)",
774 "<-->",
775 "(0,0)",
776 ";",
777 "<ADJLIST_END>",
778 "<ORIGIN_START>",
779 "(1,0)",
780 "<ORIGIN_END>",
781 "<TARGET_START>",
782 "(1,1)",
783 "<TARGET_END>",
784 "<PATH_START>",
785 "(1,0)",
786 "(1,1)",
787 "<PATH_END>",
788 ],
789 [
790 "<ADJLIST_START>",
791 "(1,0)",
792 "<-->",
793 "(1,1)",
794 ";",
795 "(0,1)",
796 "<-->",
797 "(0,0)",
798 ";",
799 "(0,1)",
800 "<-->",
801 "(1,1)",
802 ";",
803 "<ADJLIST_END>",
804 "<ORIGIN_START>",
805 "(1,0)",
806 "<ORIGIN_END>",
807 "<TARGET_START>",
808 "(1,1)",
809 "<TARGET_END>",
810 "<PATH_START>",
811 "(1,0)",
812 "(1,1)",
813 "<PATH_END>",
814 ],
815 )
816 assert equal_except_adj_list_sequence(
817 [
818 "<ADJLIST_START>",
819 "(0,1)",
820 "<-->",
821 "(1,1)",
822 ";",
823 "(1,0)",
824 "<-->",
825 "(1,1)",
826 ";",
827 "(0,1)",
828 "<-->",
829 "(0,0)",
830 ";",
831 "<ADJLIST_END>",
832 "<ORIGIN_START>",
833 "(1,0)",
834 "<ORIGIN_END>",
835 "<TARGET_START>",
836 "(1,1)",
837 "<TARGET_END>",
838 "<PATH_START>",
839 "(1,0)",
840 "(1,1)",
841 "<PATH_END>",
842 ],
843 [
844 "<ADJLIST_START>",
845 "(1,1)",
846 "<-->",
847 "(0,1)",
848 ";",
849 "(1,0)",
850 "<-->",
851 "(1,1)",
852 ";",
853 "(0,1)",
854 "<-->",
855 "(0,0)",
856 ";",
857 "<ADJLIST_END>",
858 "<ORIGIN_START>",
859 "(1,0)",
860 "<ORIGIN_END>",
861 "<TARGET_START>",
862 "(1,1)",
863 "<TARGET_END>",
864 "<PATH_START>",
865 "(1,0)",
866 "(1,1)",
867 "<PATH_END>",
868 ],
869 )
870 assert not equal_except_adj_list_sequence(
871 [
872 "<ADJLIST_START>",
873 "(0,1)",
874 "<-->",
875 "(1,1)",
876 ";",
877 "(1,0)",
878 "<-->",
879 "(1,1)",
880 ";",
881 "(0,1)",
882 "<-->",
883 "(0,0)",
884 ";",
885 "<ADJLIST_END>",
886 "<ORIGIN_START>",
887 "(1,0)",
888 "<ORIGIN_END>",
889 "<TARGET_START>",
890 "(1,1)",
891 "<TARGET_END>",
892 "<PATH_START>",
893 "(1,0)",
894 "(1,1)",
895 "<PATH_END>",
896 ],
897 [
898 "<ADJLIST_START>",
899 "(1,0)",
900 "<-->",
901 "(1,1)",
902 ";",
903 "(0,1)",
904 "<-->",
905 "(0,0)",
906 ";",
907 "(0,1)",
908 "<-->",
909 "(1,1)",
910 ";",
911 "<ADJLIST_END>",
912 "<ORIGIN_START>",
913 "(1,0)",
914 "<ORIGIN_END>",
915 "<TARGET_START>",
916 "(1,1)",
917 "<TARGET_END>",
918 "<PATH_START>",
919 "(1,1)",
920 "(1,0)",
921 "<PATH_END>",
922 ],
923 )
924 assert not equal_except_adj_list_sequence(
925 [
926 "<ADJLIST_START>",
927 "(0,1)",
928 "<-->",
929 "(1,1)",
930 ";",
931 "(1,0)",
932 "<-->",
933 "(1,1)",
934 ";",
935 "(0,1)",
936 "<-->",
937 "(0,0)",
938 ";",
939 "<ADJLIST_END>",
940 "<ORIGIN_START>",
941 "(1,0)",
942 "<ORIGIN_END>",
943 "<TARGET_START>",
944 "(1,1)",
945 "<TARGET_END>",
946 "<PATH_START>",
947 "(1,0)",
948 "(1,1)",
949 "<PATH_END>",
950 ],
951 [
952 "<ADJLIST_START>",
953 "(0,1)",
954 "<-->",
955 "(1,1)",
956 ";",
957 "(1,0)",
958 "<-->",
959 "(1,1)",
960 ";",
961 "(0,1)",
962 "<-->",
963 "(0,0)",
964 ";",
965 "<ADJLIST_END>",
966 "<ORIGIN_START>",
967 "(1,0)",
968 "<ORIGIN_END>",
969 "<TARGET_START>",
970 "(1,1)",
971 "<TARGET_END>",
972 "<PATH_START>",
973 "(1,0)",
974 "(1,1)",
975 "<PATH_END>",
976 "<PATH_END>",
977 ],
978 )
979 assert not equal_except_adj_list_sequence(
980 [
981 "<ADJLIST_START>",
982 "(0,1)",
983 "<-->",
984 "(1,1)",
985 ";",
986 "(1,0)",
987 "<-->",
988 "(1,1)",
989 ";",
990 "(0,1)",
991 "<-->",
992 "(0,0)",
993 ";",
994 "<ADJLIST_END>",
995 "(1,0)",
996 "<ORIGIN_END>",
997 "<TARGET_START>",
998 "(1,1)",
999 "<TARGET_END>",
1000 "<PATH_START>",
1001 "(1,0)",
1002 "(1,1)",
1003 "<PATH_END>",
1004 ],
1005 [
1006 "<ADJLIST_START>",
1007 "(0,1)",
1008 "<-->",
1009 "(1,1)",
1010 ";",
1011 "(1,0)",
1012 "<-->",
1013 "(1,1)",
1014 ";",
1015 "(0,1)",
1016 "<-->",
1017 "(0,0)",
1018 ";",
1019 "<ADJLIST_END>",
1020 "<ORIGIN_START>",
1021 "(1,0)",
1022 "<ORIGIN_END>",
1023 "<TARGET_START>",
1024 "(1,1)",
1025 "<TARGET_END>",
1026 "<PATH_START>",
1027 "(1,0)",
1028 "(1,1)",
1029 "<PATH_END>",
1030 ],
1031 )
1032 assert not equal_except_adj_list_sequence(
1033 [
1034 "<ADJLIST_START>",
1035 "(0,1)",
1036 "<-->",
1037 "(1,1)",
1038 ";",
1039 "(1,0)",
1040 "<-->",
1041 "(1,1)",
1042 ";",
1043 "(0,1)",
1044 "<-->",
1045 "(0,0)",
1046 ";",
1047 "<ADJLIST_END>",
1048 "<ORIGIN_START>",
1049 "(1,0)",
1050 "<ORIGIN_END>",
1051 "<TARGET_START>",
1052 "(1,1)",
1053 "<TARGET_END>",
1054 "<PATH_START>",
1055 "(1,0)",
1056 "(1,1)",
1057 "<PATH_END>",
1058 ],
1059 [
1060 "(0,1)",
1061 "<-->",
1062 "(1,1)",
1063 ";",
1064 "(1,0)",
1065 "<-->",
1066 "(1,1)",
1067 ";",
1068 "(0,1)",
1069 "<-->",
1070 "(0,0)",
1071 ";",
1072 "<ADJLIST_END>",
1073 "<ORIGIN_START>",
1074 "(1,0)",
1075 "<ORIGIN_END>",
1076 "<TARGET_START>",
1077 "(1,1)",
1078 "<TARGET_END>",
1079 "<PATH_START>",
1080 "(1,0)",
1081 "(1,1)",
1082 "<PATH_END>",
1083 ],
1084 )
1085 with pytest.raises(ValueError): # noqa: PT011
1086 equal_except_adj_list_sequence(
1087 [
1088 "(0,1)",
1089 "<-->",
1090 "(1,1)",
1091 ";",
1092 "(1,0)",
1093 "<-->",
1094 "(1,1)",
1095 ";",
1096 "(0,1)",
1097 "<-->",
1098 "(0,0)",
1099 ";",
1100 "<ADJLIST_END>",
1101 "<ORIGIN_START>",
1102 "(1,0)",
1103 "<ORIGIN_END>",
1104 "<TARGET_START>",
1105 "(1,1)",
1106 "<TARGET_END>",
1107 "<PATH_START>",
1108 "(1,0)",
1109 "(1,1)",
1110 "<PATH_END>",
1111 ],
1112 [
1113 "(0,1)",
1114 "<-->",
1115 "(1,1)",
1116 ";",
1117 "(1,0)",
1118 "<-->",
1119 "(1,1)",
1120 ";",
1121 "(0,1)",
1122 "<-->",
1123 "(0,0)",
1124 ";",
1125 "<ADJLIST_END>",
1126 "<ORIGIN_START>",
1127 "(1,0)",
1128 "<ORIGIN_END>",
1129 "<TARGET_START>",
1130 "(1,1)",
1131 "<TARGET_END>",
1132 "<PATH_START>",
1133 "(1,0)",
1134 "(1,1)",
1135 "<PATH_END>",
1136 ],
1137 )
1138 with pytest.raises(ValueError): # noqa: PT011
1139 equal_except_adj_list_sequence(
1140 [
1141 "<ADJLIST_START>",
1142 "(0,1)",
1143 "<-->",
1144 "(1,1)",
1145 ";",
1146 "(1,0)",
1147 "<-->",
1148 "(1,1)",
1149 ";",
1150 "(0,1)",
1151 "<-->",
1152 "(0,0)",
1153 ";",
1154 "<ORIGIN_START>",
1155 "(1,0)",
1156 "<ORIGIN_END>",
1157 "<TARGET_START>",
1158 "(1,1)",
1159 "<TARGET_END>",
1160 "<PATH_START>",
1161 "(1,0)",
1162 "(1,1)",
1163 "<PATH_END>",
1164 ],
1165 [
1166 "<ADJLIST_START>",
1167 "(0,1)",
1168 "<-->",
1169 "(1,1)",
1170 ";",
1171 "(1,0)",
1172 "<-->",
1173 "(1,1)",
1174 ";",
1175 "(0,1)",
1176 "<-->",
1177 "(0,0)",
1178 ";",
1179 "<ORIGIN_START>",
1180 "(1,0)",
1181 "<ORIGIN_END>",
1182 "<TARGET_START>",
1183 "(1,1)",
1184 "<TARGET_END>",
1185 "<PATH_START>",
1186 "(1,0)",
1187 "(1,1)",
1188 "<PATH_END>",
1189 ],
1190 )
1191 assert not equal_except_adj_list_sequence(
1192 [
1193 "<ADJLIST_START>",
1194 "(0,1)",
1195 "<-->",
1196 "(1,1)",
1197 ";",
1198 "(1,0)",
1199 "<-->",
1200 "(1,1)",
1201 ";",
1202 "(0,1)",
1203 "<-->",
1204 "(0,0)",
1205 ";",
1206 "<ADJLIST_END>",
1207 "<ORIGIN_START>",
1208 "(1,0)",
1209 "<ORIGIN_END>",
1210 "<TARGET_START>",
1211 "(1,1)",
1212 "<TARGET_END>",
1213 "<PATH_START>",
1214 "(1,0)",
1215 "(1,1)",
1216 "<PATH_END>",
1217 ],
1218 [
1219 "<ADJLIST_START>",
1220 "(0,1)",
1221 "<-->",
1222 "(1,1)",
1223 ";",
1224 "(1,0)",
1225 "<-->",
1226 "(1,1)",
1227 ";",
1228 "(0,1)",
1229 "<-->",
1230 "(0,0)",
1231 ";",
1232 "<ORIGIN_START>",
1233 "(1,0)",
1234 "<ORIGIN_END>",
1235 "<TARGET_START>",
1236 "(1,1)",
1237 "<TARGET_END>",
1238 "<PATH_START>",
1239 "(1,0)",
1240 "(1,1)",
1241 "<PATH_END>",
1242 ],
1243 )
1245 # CTT
1246 assert equal_except_adj_list_sequence(
1247 [
1248 "<ADJLIST_START>",
1249 "(",
1250 "0",
1251 ",",
1252 "1",
1253 ")",
1254 "<-->",
1255 "(",
1256 "1",
1257 ",",
1258 "1",
1259 ")",
1260 ";",
1261 "(",
1262 "1",
1263 ",",
1264 "0",
1265 ")",
1266 "<-->",
1267 "(",
1268 "1",
1269 ",",
1270 "1",
1271 ")",
1272 ";",
1273 "(",
1274 "0",
1275 ",",
1276 "1",
1277 ")",
1278 "<-->",
1279 "(",
1280 "0",
1281 ",",
1282 "0",
1283 ")",
1284 ";",
1285 "<ADJLIST_END>",
1286 "<ORIGIN_START>",
1287 "(",
1288 "1",
1289 ",",
1290 "0",
1291 ")",
1292 "<ORIGIN_END>",
1293 "<TARGET_START>",
1294 "(",
1295 "1",
1296 ",",
1297 "1",
1298 ")",
1299 "<TARGET_END>",
1300 "<PATH_START>",
1301 "(",
1302 "1",
1303 ",",
1304 "0",
1305 ")",
1306 "(",
1307 "1",
1308 ",",
1309 "1",
1310 ")",
1311 "<PATH_END>",
1312 ],
1313 [
1314 "<ADJLIST_START>",
1315 "(",
1316 "0",
1317 ",",
1318 "1",
1319 ")",
1320 "<-->",
1321 "(",
1322 "1",
1323 ",",
1324 "1",
1325 ")",
1326 ";",
1327 "(",
1328 "1",
1329 ",",
1330 "0",
1331 ")",
1332 "<-->",
1333 "(",
1334 "1",
1335 ",",
1336 "1",
1337 ")",
1338 ";",
1339 "(",
1340 "0",
1341 ",",
1342 "1",
1343 ")",
1344 "<-->",
1345 "(",
1346 "0",
1347 ",",
1348 "0",
1349 ")",
1350 ";",
1351 "<ADJLIST_END>",
1352 "<ORIGIN_START>",
1353 "(",
1354 "1",
1355 ",",
1356 "0",
1357 ")",
1358 "<ORIGIN_END>",
1359 "<TARGET_START>",
1360 "(",
1361 "1",
1362 ",",
1363 "1",
1364 ")",
1365 "<TARGET_END>",
1366 "<PATH_START>",
1367 "(",
1368 "1",
1369 ",",
1370 "0",
1371 ")",
1372 "(",
1373 "1",
1374 ",",
1375 "1",
1376 ")",
1377 "<PATH_END>",
1378 ],
1379 )
1380 assert equal_except_adj_list_sequence(
1381 [
1382 "<ADJLIST_START>",
1383 "(",
1384 "0",
1385 ",",
1386 "1",
1387 ")",
1388 "<-->",
1389 "(",
1390 "1",
1391 ",",
1392 "1",
1393 ")",
1394 ";",
1395 "(",
1396 "1",
1397 ",",
1398 "0",
1399 ")",
1400 "<-->",
1401 "(",
1402 "1",
1403 ",",
1404 "1",
1405 ")",
1406 ";",
1407 "(",
1408 "0",
1409 ",",
1410 "1",
1411 ")",
1412 "<-->",
1413 "(",
1414 "0",
1415 ",",
1416 "0",
1417 ")",
1418 ";",
1419 "<ADJLIST_END>",
1420 "<ORIGIN_START>",
1421 "(",
1422 "1",
1423 ",",
1424 "0",
1425 ")",
1426 "<ORIGIN_END>",
1427 "<TARGET_START>",
1428 "(",
1429 "1",
1430 ",",
1431 "1",
1432 ")",
1433 "<TARGET_END>",
1434 "<PATH_START>",
1435 "(",
1436 "1",
1437 ",",
1438 "0",
1439 ")",
1440 "(",
1441 "1",
1442 ",",
1443 "1",
1444 ")",
1445 "<PATH_END>",
1446 ],
1447 [
1448 "<ADJLIST_START>",
1449 "(",
1450 "1",
1451 ",",
1452 "1",
1453 ")",
1454 "<-->",
1455 "(",
1456 "0",
1457 ",",
1458 "1",
1459 ")",
1460 ";",
1461 "(",
1462 "1",
1463 ",",
1464 "0",
1465 ")",
1466 "<-->",
1467 "(",
1468 "1",
1469 ",",
1470 "1",
1471 ")",
1472 ";",
1473 "(",
1474 "0",
1475 ",",
1476 "1",
1477 ")",
1478 "<-->",
1479 "(",
1480 "0",
1481 ",",
1482 "0",
1483 ")",
1484 ";",
1485 "<ADJLIST_END>",
1486 "<ORIGIN_START>",
1487 "(",
1488 "1",
1489 ",",
1490 "0",
1491 ")",
1492 "<ORIGIN_END>",
1493 "<TARGET_START>",
1494 "(",
1495 "1",
1496 ",",
1497 "1",
1498 ")",
1499 "<TARGET_END>",
1500 "<PATH_START>",
1501 "(",
1502 "1",
1503 ",",
1504 "0",
1505 ")",
1506 "(",
1507 "1",
1508 ",",
1509 "1",
1510 ")",
1511 "<PATH_END>",
1512 ],
1513 )
1514 # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
1515 # See function documentation for details.
1516 # assert not equal_except_adj_list_sequence(
1517 # "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
1518 # "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
1519 # )
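# Hedged summary sketch of the behavior pinned down above (illustrative, not part of
# the test suite): adjacency-list entries may be reordered and either endpoint order
# within an entry is accepted, while any difference outside the adjacency-list section
# breaks equality and a missing <ADJLIST_START>/<ADJLIST_END> delimiter raises
# ValueError.
def _example_equal_except_adj_list() -> None:
    toks = MAZE_TOKENS[0]
    # swap the first two adjacency-list entries (token slices 1:5 and 5:9)
    reordered = [toks[0], *toks[5:9], *toks[1:5], *toks[9:]]
    assert equal_except_adj_list_sequence(toks, reordered)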
1522# @mivanit: this was really difficult to understand
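# Reading guide for the parametrization below: both cases enumerate
# `PathTokenizers._PathTokenizer` with `all_instances`. The first passes no
# validation_funcs and asserts that a StepSequence built from StepTokenizers.Distance()
# alone appears in the enumeration; the second filters with `is_valid()` and asserts
# that this tokenizer, as well as one with a duplicated StepTokenizers.Coord() step,
# is excluded.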
1523@pytest.mark.parametrize(
1524 ("type_", "validation_funcs", "assertion"),
1525 [
1526 pytest.param(
1527 type_,
1528 vfs,
1529 assertion,
1530 id=f"{i}-{type_.__name__}",
1531 )
1532 for i, (type_, vfs, assertion) in enumerate(
1533 [
1534 (
1535 # type
1536 PathTokenizers._PathTokenizer,
1537 # validation_funcs
1538 dict(),
1539 # assertion
1540 lambda x: PathTokenizers.StepSequence(
1541 step_tokenizers=(StepTokenizers.Distance(),),
1542 )
1543 in x,
1544 ),
1545 (
1546 # type
1547 PathTokenizers._PathTokenizer,
1548 # validation_funcs
1549 {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
1550 # assertion
1551 lambda x: PathTokenizers.StepSequence(
1552 step_tokenizers=(StepTokenizers.Distance(),),
1553 )
1554 not in x
1555 and PathTokenizers.StepSequence(
1556 step_tokenizers=(
1557 StepTokenizers.Coord(),
1558 StepTokenizers.Coord(),
1559 ),
1560 )
1561 not in x,
1562 ),
1563 ],
1564 )
1565 ],
1566)
1567def test_all_instances2(
1568 type_: FiniteValued,
1569 validation_funcs: frozendict.frozendict[
1570 FiniteValued,
1571 Callable[[FiniteValued], bool],
1572 ],
1573 assertion: Callable[[list[FiniteValued]], bool],
1574):
1575 assert assertion(all_instances(type_, validation_funcs))
1578@pytest.mark.parametrize(
1579 ("coords", "result"),
1580 [
1581 pytest.param(
1582 np.array(coords),
1583 res,
1584 id=f"{coords}",
1585 )
1586 for coords, res in (
1587 [
1588 ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
1589 ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
1590 ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
1591 ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
1592 ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
1593 ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
1594 ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
1595 ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
1596 ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
1597 ([[0, 1], [0, 1], [0, 0]], ValueError),
1598 ([[0, 1], [1, 1], [0, 0]], ValueError),
1599 ([[1, 0], [1, 1], [0, 0]], ValueError),
1600 ([[0, 1], [0, 2], [0, 0]], ValueError),
1601 ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
1602 ([[1, 1], [0, 0], [0, 1]], ValueError),
1603 ([[1, 1], [0, 0], [1, 0]], ValueError),
1604 ([[0, 2], [0, 0], [0, 1]], ValueError),
1605 ([[0, 0], [0, 0], [0, 1]], ValueError),
1606 ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
1607 ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
1608 ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
1609 ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
1610 ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
1611 ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
1612 ([[-1, 0], [0, 0]], ValueError),
1613 ([[-1, 0, 0], [0, 0, 0]], ValueError),
1614 ]
1615 )
1616 ],
1617)
1618def test_get_relative_direction(
1619 coords: Int[np.ndarray, "prev_cur_next=3 axis=2"],
1620 result: str | type[Exception],
1621):
1622 if isinstance(result, type) and issubclass(result, Exception):
1623 with pytest.raises(result):
1624 get_relative_direction(coords)
1625 return
1626 assert get_relative_direction(coords) == result
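# Hedged sketch of the direction rule that the cases above pin down; this is an
# illustrative reimplementation, not the library's `get_relative_direction`. The turn
# is taken relative to the incoming step d1 = cur - prev and the outgoing step
# d2 = next - cur, with the left/right convention implied by the expected values above.
def _example_relative_direction(coords: np.ndarray) -> str:
    if coords.shape != (3, 2):
        raise ValueError("expected exactly three 2D coordinates")
    d1 = coords[1] - coords[0]
    d2 = coords[2] - coords[1]
    if np.abs(d1).sum() != 1:
        raise ValueError("incoming step must be a unit lattice step")
    if np.abs(d2).sum() == 0:
        return VOCAB.PATH_STAY
    if np.abs(d2).sum() != 1:
        raise ValueError("outgoing step must be a unit lattice step")
    if np.array_equal(d2, d1):
        return VOCAB.PATH_FORWARD
    if np.array_equal(d2, -d1):
        return VOCAB.PATH_BACKWARD
    # the sign of the 2D cross product separates left turns from right turns
    return VOCAB.PATH_LEFT if d1[0] * d2[1] - d1[1] * d2[0] > 0 else VOCAB.PATH_RIGHT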
1629@pytest.mark.parametrize(
1630 ("edges", "result"),
1631 [
1632 pytest.param(
1633 edges,
1634 res,
1635 id=f"{edges}",
1636 )
1637 for edges, res in (
1638 [
1639 (np.array([[0, 0], [0, 1]]), 1),
1640 (np.array([[1, 0], [0, 1]]), 2),
1641 (np.array([[-1, 0], [0, 1]]), 2),
1642 (np.array([[0, 0], [5, 3]]), 8),
1643 (
1644 np.array(
1645 [
1646 [[0, 0], [0, 1]],
1647 [[1, 0], [0, 1]],
1648 [[-1, 0], [0, 1]],
1649 [[0, 0], [5, 3]],
1650 ],
1651 ),
1652 [1, 2, 2, 8],
1653 ),
1654 (np.array([[[0, 0], [5, 3]]]), [8]),
1655 ]
1656 )
1657 ],
1658)
1659def test_manhattan_distance(
1660 edges: ConnectionArray | Connection,
1661 result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
1662):
1663 if isinstance(result, type) and issubclass(result, Exception):
1664 with pytest.raises(result):
1665 manhattan_distance(edges)
1666 return
1667 assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8))
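# Hedged sketch (illustrative, not the library implementation): the Manhattan distance
# of an edge is the sum of absolute coordinate differences between its endpoints,
# computed over the last axis so that a single (2, 2) connection and a batched
# (n, 2, 2) connection array are both handled, matching the cases above.
def _example_manhattan_distance(edges: np.ndarray) -> np.ndarray:
    return np.abs(edges[..., 1, :] - edges[..., 0, :]).sum(axis=-1).astype(np.int8)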
1670@pytest.mark.parametrize(
1671 "n",
1672 [pytest.param(n) for n in [2, 3, 5, 20]],
1673)
1674def test_lattice_connection_array(n):
1675 edges = lattice_connection_array(n)
1676 assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2)
1677 assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1))
1678 assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2)
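# Hedged sketch of what the assertions above require of `lattice_connection_array(n)`:
# every edge of an n-by-n lattice appears exactly once, stored as a (2, 2) pair whose
# second endpoint has the larger coordinate sum, for n*(n-1) rightward plus n*(n-1)
# downward = 2*n*(n-1) edges in total. An illustrative (not library) construction:
def _example_lattice_edges(n: int) -> np.ndarray:
    edges = []
    for row in range(n):
        for col in range(n):
            if col + 1 < n:
                edges.append([[row, col], [row, col + 1]])  # rightward edge
            if row + 1 < n:
                edges.append([[row, col], [row + 1, col]])  # downward edge
    return np.array(edges)  # shape (2 * n * (n - 1), 2, 2)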
1681@pytest.mark.parametrize(
1682 ("edges", "maze"),
1683 [
1684 pytest.param(
1685 edges(),
1686 maze,
1687 id=f"edges[{i}]; maze[{j}]",
1688 )
1689 for (i, edges), (j, maze) in itertools.product(
1690 enumerate(
1691 [
1692 lambda: lattice_connection_array(GRID_N),
1693 lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
1694 lambda: lattice_connection_array(GRID_N - 1),
1695 lambda: _NUMPY_RNG.choice(
1696 lattice_connection_array(GRID_N),
1697 2 * GRID_N,
1698 axis=0,
1699 ),
1700 lambda: _NUMPY_RNG.choice(
1701 lattice_connection_array(GRID_N),
1702 1,
1703 axis=0,
1704 ),
1705 ],
1706 ),
1707 enumerate(MAZE_DATASET.mazes),
1708 )
1709 ],
1710)
1711def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
1712 output = is_connection(edges, maze.connection_list)
1713 sorted_edges = np.sort(edges, axis=1)
1714 edge_direction = (
1715 (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
1716 ).astype(np.int8)
1717 assert np.array_equal(
1718 output,
1719 maze.connection_list[
1720 edge_direction,
1721 sorted_edges[:, 0, 0],
1722 sorted_edges[:, 0, 1],
1723 ],
1724 )
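# Hedged note on the cross-check above (inferred from this test, not from library
# documentation): `maze.connection_list` appears to be indexed as
# connection_list[direction, row, col] at the lexicographically smaller endpoint of an
# edge, with direction 0 marking a connection to the cell below (row changes) and
# direction 1 a connection to the cell to the right (row unchanged), which is exactly
# the indexing the expected value above reproduces.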