Coverage for maze_dataset/testing_utils.py: 100%
20 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
1"""Shared utilities for tests only.
3Do not import into any module outside of the tests directory
4"""
6import itertools
7from typing import Final, NamedTuple, Sequence
9import frozendict
10import numpy as np
12from maze_dataset import (
13 CoordArray,
14 LatticeMaze,
15 LatticeMazeGenerators,
16 MazeDataset,
17 MazeDatasetConfig,
18 SolvedMaze,
19 TargetedLatticeMaze,
20)
21from maze_dataset.tokenization import (
22 MazeTokenizer,
23 MazeTokenizerModular,
24 TokenizationMode,
25)
27GRID_N: Final[int] = 5
28N_MAZES: Final[int] = 5
29CFG: Final[MazeDatasetConfig] = MazeDatasetConfig(
30 name="test",
31 grid_n=GRID_N,
32 n_mazes=N_MAZES,
33 maze_ctor=LatticeMazeGenerators.gen_dfs,
34)
35MAZE_DATASET: Final[MazeDataset] = MazeDataset.from_config(
36 CFG,
37 do_download=False,
38 load_local=False,
39 do_generate=True,
40 save_local=False,
41 verbose=True,
42 gen_parallel=False,
43)
44LATTICE_MAZES: Final[tuple[LatticeMaze, ...]] = tuple(
45 LatticeMazeGenerators.gen_dfs(np.array([GRID_N, GRID_N])) for _ in range(N_MAZES)
46)
47_PATHS = tuple(maze.generate_random_path() for maze in LATTICE_MAZES)
48TARGETED_MAZES: Final[tuple[TargetedLatticeMaze, ...]] = tuple(
49 TargetedLatticeMaze.from_lattice_maze(maze, path[0], path[-1])
50 for maze, path in zip(LATTICE_MAZES, _PATHS, strict=False)
51)
52# MIXED_MAZES alternates the maze types, so you can slice a contiguous subset and still get all types
53MIXED_MAZES: Final[tuple[LatticeMaze | TargetedLatticeMaze | SolvedMaze, ...]] = tuple(
54 x
55 for x in itertools.chain.from_iterable(
56 itertools.zip_longest(MAZE_DATASET.mazes, TARGETED_MAZES, LATTICE_MAZES),
57 )
58)
61class MANUAL_MAZE(NamedTuple): # noqa: N801
62 """A named tuple for manual maze definitions"""
64 tokens: str
65 ascii: Sequence[str]
66 straightaway_footprints: CoordArray
69ASCII_MAZES: Final[frozendict.frozendict[str, MANUAL_MAZE]] = frozendict.frozendict(
70 small_3x3=MANUAL_MAZE(
71 tokens="<ADJLIST_START> (2,0) <--> (2,1) ; (0,0) <--> (0,1) ; (0,0) <--> (1,0) ; (0,2) <--> (1,2) ; (1,0) <--> (2,0) ; (0,2) <--> (0,1) ; (2,2) <--> (2,1) ; (1,1) <--> (2,1) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>",
72 ascii=(
73 "#######",
74 "#S #",
75 "#X### #",
76 "#X# # #",
77 "#X# ###",
78 "#XXE #",
79 "#######",
80 ),
81 straightaway_footprints=np.array(
82 [
83 [0, 0],
84 [2, 0],
85 [2, 1],
86 ],
87 ),
88 ),
89 big_10x10=MANUAL_MAZE(
90 tokens="<ADJLIST_START> (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) <--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; <ADJLIST_END> <ORIGIN_START> (6,2) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) <PATH_END>",
91 ascii=(
92 "#####################",
93 "# # # #",
94 "# # # # ### # # #####",
95 "# # # # # # #",
96 "# ####### ##### # # #",
97 "# #E # # # #",
98 "###X# ########### # #",
99 "#XXX# # # #",
100 "#X##### ########### #",
101 "#X# # # #",
102 "#X# ######### ### # #",
103 "#X# # # # #",
104 "#X######### # # ### #",
105 "#XXXXS# # # #",
106 "# ########### #######",
107 "# # # # #",
108 "# # ####### ### # ###",
109 "# # # # # #",
110 "# # # ####### ##### #",
111 "# # #",
112 "#####################",
113 ),
114 straightaway_footprints=np.array(
115 [
116 [6, 2],
117 [6, 0],
118 [3, 0],
119 [3, 1],
120 [2, 1],
121 ],
122 ),
123 ),
124 longer_10x10=MANUAL_MAZE(
125 tokens="<ADJLIST_START> (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) <--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; <ADJLIST_END> <ORIGIN_START> (6,2) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) (2,2) (2,3) (2,4) (1,4) (1,3) (0,3) (0,4) (0,5) (1,5) (1,6) (0,6) (0,7) (0,8) <PATH_END>",
126 ascii=(
127 "#####################",
128 "# # XXXXX#XXXXE #",
129 "# # # #X###X#X# #####",
130 "# # #XXX#XXX# # #",
131 "# #######X##### # # #",
132 "# #XXXXXXX# # # #",
133 "###X# ########### # #",
134 "#XXX# # # #",
135 "#X##### ########### #",
136 "#X# # # #",
137 "#X# ######### ### # #",
138 "#X# # # # #",
139 "#X######### # # ### #",
140 "#XXXXS# # # #",
141 "# ########### #######",
142 "# # # # #",
143 "# # ####### ### # ###",
144 "# # # # # #",
145 "# # # ####### ##### #",
146 "# # #",
147 "#####################",
148 ),
149 straightaway_footprints=np.array(
150 [
151 [6, 2],
152 [6, 0],
153 [3, 0],
154 [3, 1],
155 [2, 1],
156 [2, 4],
157 [1, 4],
158 [1, 3],
159 [0, 3],
160 [0, 5],
161 [1, 5],
162 [1, 6],
163 [0, 6],
164 [0, 8],
165 ],
166 ),
167 ),
168)
170# A list of legacy `MazeTokenizer`s and their `MazeTokenizerModular` equivalents.
171# Used for unit tests where both versions are supported
172LEGACY_AND_EQUIVALENT_TOKENIZERS: list[MazeTokenizer | MazeTokenizerModular] = [
173 *[
174 MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=20)
175 for tok_mode in TokenizationMode
176 ],
177 *[MazeTokenizerModular.from_legacy(tok_mode) for tok_mode in TokenizationMode],
178]