Coverage for tests/unit/generation/test_latticemaze.py: 100%
110 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
1import numpy as np
2import pytest
4from maze_dataset.constants import CoordArray
5from maze_dataset.generation.default_generators import DEFAULT_GENERATORS
6from maze_dataset.generation.generators import GENERATORS_MAP
7from maze_dataset.maze import LatticeMaze, PixelColors, SolvedMaze, TargetedLatticeMaze
8from maze_dataset.utils import adj_list_to_nested_set, bool_array_from_string
11# thanks to gpt for these tests of _from_pixel_grid
@pytest.fixture
def example_pixel_grid():
    """Return a 5x5 boolean pixel grid for the black-and-white parsing test.

    The literal below is written with 1/0 entries and then negated
    element-wise, so entries written as 1 come out ``False`` and entries
    written as 0 come out ``True``.
    """
    raw_cells = np.array(
        [
            [1, 1, 1, 1, 1],
            [1, 0, 0, 0, 1],
            [1, 1, 1, 0, 1],
            [1, 0, 0, 0, 1],
            [1, 1, 1, 1, 1],
        ],
        dtype=bool,
    )
    # element-wise boolean negation of the literal above
    return np.logical_not(raw_cells)
@pytest.fixture
def example_rgb_pixel_grid():
    """Return a 5x5 grid of ``PixelColors`` entries as a ``uint8`` array.

    The outer ring and the middle row are ``PixelColors.WALL``; the second
    row has three ``PixelColors.OPEN`` cells in a row, and the fourth row
    has two ``OPEN`` cells separated by a ``WALL`` cell.
    """
    wall = PixelColors.WALL
    open_ = PixelColors.OPEN
    # rows made entirely of wall pixels (np.array copies, so sharing is safe)
    wall_row = [wall] * 5
    return np.array(
        [
            wall_row,
            [wall, open_, open_, open_, wall],
            wall_row,
            [wall, open_, wall, open_, wall],
            wall_row,
        ],
        dtype=np.uint8,
    )
def test_from_pixel_grid_bw(example_pixel_grid):
    """Check ``LatticeMaze._from_pixel_grid_bw`` on the boolean fixture grid.

    The call must yield a ``(2, 2, 2)`` numpy connection list with the
    expected entries along its first axis, plus a ``(2, 2)`` grid shape.
    """
    connection_list, grid_shape = LatticeMaze._from_pixel_grid_bw(example_pixel_grid)

    expected_first = np.array([[False, True], [False, False]])
    expected_second = np.array([[True, False], [True, False]])

    assert isinstance(connection_list, np.ndarray)
    assert connection_list.shape == (2, 2, 2)
    assert np.all(connection_list[0] == expected_first)
    assert np.all(connection_list[1] == expected_second)
    assert grid_shape == (2, 2)
def test_from_pixel_grid_with_positions(example_rgb_pixel_grid):
    """Check ``LatticeMaze._from_pixel_grid_with_positions`` on an unmarked grid.

    The RGB fixture contains only WALL and OPEN pixels, so every requested
    marker ("start", "end", "path") is expected to come back as an empty
    array of shape ``(0,)``.
    """
    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    result = LatticeMaze._from_pixel_grid_with_positions(
        example_rgb_pixel_grid,
        marked_positions,
    )
    connection_list, grid_shape, out_positions = result

    # connection structure
    expected_first = np.array([[False, False], [False, False]])
    expected_second = np.array([[True, False], [False, False]])
    assert isinstance(connection_list, np.ndarray)
    assert connection_list.shape == (2, 2, 2)
    assert np.all(connection_list[0] == expected_first)
    assert np.all(connection_list[1] == expected_second)
    assert grid_shape == (2, 2)

    # marker dictionary
    assert isinstance(out_positions, dict)
    assert len(out_positions) == 3

    for key in ("start", "end"):
        assert key in out_positions

    for key in ("start", "end", "path"):
        assert isinstance(out_positions[key], np.ndarray)

    # no marker colors are present in the fixture grid
    for key in ("start", "end", "path"):
        assert out_positions[key].shape == (0,)
def test_find_start_end_points_in_rgb_pixel_grid():
    """Check that START and END pixels are reported as lattice coordinates.

    The grid built here matches the plain RGB fixture layout except that the
    first and last open cells of the second pixel row are colored START and
    END. The test expects "start" to compare equal to ``[[0, 0]]``, "end" to
    ``[[0, 1]]``, and "path" to be empty.
    """
    wall = PixelColors.WALL
    open_ = PixelColors.OPEN
    wall_row = [wall] * 5
    rgb_pixel_grid_with_positions = np.array(
        [
            wall_row,
            [wall, PixelColors.START, open_, PixelColors.END, wall],
            wall_row,
            [wall, open_, wall, open_, wall],
            wall_row,
        ],
        dtype=np.uint8,
    )

    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    result = LatticeMaze._from_pixel_grid_with_positions(
        rgb_pixel_grid_with_positions,
        marked_positions,
    )
    # connection list and grid shape are unpacked but not inspected here
    connection_list, grid_shape, out_positions = result

    print(f"{out_positions = }")

    assert isinstance(out_positions, dict)
    assert len(out_positions) == 3
    for key in ("start", "end"):
        assert key in out_positions
    for key in ("start", "end", "path"):
        assert isinstance(out_positions[key], np.ndarray)

    expected_start = np.array([[0, 0]])
    expected_end = np.array([[0, 1]])
    assert np.all(out_positions["start"] == expected_start)
    assert np.all(out_positions["end"] == expected_end)
    assert out_positions["path"].shape == (0,)
@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_pixels_ascii_roundtrip(gfunc_name, kwargs):
    """Every default generator yields a maze that survives pixel and ascii roundtrips.

    Also checks that the pixel image is ``(2n+1, 2n+1, 3)`` and that every
    ascii line is ``2n+1`` characters long for an ``n x n`` maze.
    """
    n: int = 5
    side: int = n * 2 + 1

    generate = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = generate(np.array([n, n]), **kwargs)

    maze_pixels: np.ndarray = maze.as_pixels()
    maze_ascii: str = maze.as_ascii()

    # roundtrip through both serializations
    assert maze == LatticeMaze.from_pixels(maze_pixels)
    assert maze == LatticeMaze.from_ascii(maze_ascii)

    # dimensions of both serializations
    expected_shape: tuple = (side, side, 3)
    assert maze_pixels.shape == expected_shape, (
        f"{maze_pixels.shape} != {expected_shape}"
    )
    assert all(side == len(line) for line in maze_ascii.splitlines()), (
        f"{maze_ascii}"
    )
@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_targeted_solved_maze(gfunc_name, kwargs):
    """Targeted and solved mazes survive pixel and ascii roundtrips.

    For each default generator, a 5x5 maze is generated and a random path is
    drawn from it. The path endpoints build a ``TargetedLatticeMaze``, which
    is then converted to a ``SolvedMaze``. Both objects must:

    - compare equal to themselves after ``as_pixels`` / ``from_pixels`` and
      ``as_ascii`` / ``from_ascii``;
    - render to a ``(2n+1, 2n+1, 3)`` pixel array;
    - render to ascii whose lines are all ``2n+1`` characters long.
    """
    n: int = 5
    expected_shape: tuple = (n * 2 + 1, n * 2 + 1, 3)

    maze_gen_func = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = maze_gen_func(np.array([n, n]), **kwargs)
    solution: CoordArray = maze.generate_random_path()

    # --- targeted maze: endpoints of the random path are start and end ---
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze,
        solution[0],
        solution[-1],
    )

    tgt_maze_pixels: np.ndarray = tgt_maze.as_pixels()
    tgt_maze_ascii: str = tgt_maze.as_ascii()

    assert tgt_maze == TargetedLatticeMaze.from_pixels(tgt_maze_pixels)
    assert tgt_maze == TargetedLatticeMaze.from_ascii(tgt_maze_ascii)

    assert tgt_maze_pixels.shape == expected_shape, (
        f"{tgt_maze_pixels.shape} != {expected_shape}"
    )
    assert all(n * 2 + 1 == len(line) for line in tgt_maze_ascii.splitlines()), (
        f"{tgt_maze_ascii}"
    )

    # --- solved maze ---
    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)

    solved_maze_pixels: np.ndarray = solved_maze.as_pixels()
    solved_maze_ascii: str = solved_maze.as_ascii()

    assert solved_maze == SolvedMaze.from_pixels(solved_maze_pixels)
    assert solved_maze == SolvedMaze.from_ascii(solved_maze_ascii)

    # check the *solved* maze's pixels here; the previous version re-checked
    # tgt_maze_pixels, leaving the solved maze's pixel shape unverified
    assert solved_maze_pixels.shape == expected_shape, (
        f"{solved_maze_pixels.shape} != {expected_shape}"
    )
    assert all(n * 2 + 1 == len(line) for line in solved_maze_ascii.splitlines()), (
        f"{solved_maze_ascii}"
    )
def test_as_adj_list():
    """Check ``LatticeMaze.as_adj_list`` against a hand-written adjacency list.

    The maze is built from a ``[2, 2, 2]`` connection list spelled out with
    T/F symbols; the unshuffled adjacency list must match the expected edges
    after both are normalized with ``adj_list_to_nested_set``.
    """
    # NOTE(review): whitespace inside this literal is reconstructed; presumably
    # bool_array_from_string ignores whitespace - confirm against its definition
    maze = LatticeMaze(
        connection_list=bool_array_from_string(
            """
            F T
            F F

            T F
            T F
            """,
            shape=[2, 2, 2],
        ),
    )

    expected_edges = [[[0, 1], [1, 1]], [[0, 0], [0, 1]], [[1, 0], [1, 1]]]
    actual_edges = maze.as_adj_list(shuffle_d0=False, shuffle_d1=False)

    assert adj_list_to_nested_set(expected_edges) == adj_list_to_nested_set(
        actual_edges
    )
@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_get_nodes(gfunc_name, kwargs):
    """``get_nodes`` on a 3x2 maze lists all six coordinates, first index outermost."""
    generate = GENERATORS_MAP[gfunc_name]
    maze = generate(np.array((3, 2)), **kwargs)

    # [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]
    expected_nodes = [[row, col] for row in range(3) for col in range(2)]

    assert maze.get_nodes().tolist() == expected_nodes
@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_generate_random_path(gfunc_name, kwargs):
    """A random path on a 2x2 maze contains more than one node."""
    grid_size = np.array((2, 2))
    maze = GENERATORS_MAP[gfunc_name](grid_size, **kwargs)

    random_path = maze.generate_random_path()

    # more than one node means the start and end nodes are distinct
    assert len(random_path) > 1
@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_generate_random_path_size_1(gfunc_name, kwargs):
    """Requesting a random path on a 1x1 maze must raise ``AssertionError``."""
    grid_size = np.array((1, 1))
    single_cell_maze = GENERATORS_MAP[gfunc_name](grid_size, **kwargs)

    with pytest.raises(AssertionError):
        single_cell_maze.generate_random_path()