Coverage report for tests/unit/generation/test_maze_dataset.py: 100% (139 statements).
Generated by coverage.py v7.6.12 on 2025-03-24 00:33 -0600.
1import copy
2from pathlib import Path
4import numpy as np
5import pytest
6from zanj import ZANJ
8from maze_dataset import (
9 MazeDataset,
10 MazeDatasetConfig,
11 register_maze_filter,
12 set_serialize_minimal_threshold,
13)
14from maze_dataset.constants import CoordArray
15from maze_dataset.dataset.dataset import (
16 register_dataset_filter,
17 register_filter_namespace_for_dataset,
18)
19from maze_dataset.generation.generators import GENERATORS_MAP
20from maze_dataset.maze import SolvedMaze
21from maze_dataset.utils import bool_array_from_string
class TestMazeDatasetConfig:
    # Placeholder: no MazeDatasetConfig-specific tests implemented yet.
    pass
# Parameter table for the shared test configs: (grid size, maze count, generator kwargs).
_CONFIG_PARAMS: list[tuple[int, int, dict]] = [
    (3, 5, {}),
    (3, 1, {}),
    (5, 5, {"do_forks": False}),
]

# Dataset configurations exercised throughout this module, all using the DFS generator.
TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_size,
        n_mazes=maze_count,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=ctor_kwargs,
    )
    for grid_size, maze_count, ctor_kwargs in _CONFIG_PARAMS
]
def test_generate_serial():
    """Serial generation yields the configured number of mazes, each 3x3."""
    ds = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(ds) == 5
    assert all(m.grid_shape == (3, 3) for m in ds)
def test_generate_parallel():
    """Parallel generation (2 worker processes) produces the same count/shape as serial."""
    ds = MazeDataset.generate(
        TEST_CONFIGS[0],
        gen_parallel=True,
        verbose=True,
        pool_kwargs={"processes": 2},
    )

    assert len(ds) == 5
    assert all(m.grid_shape == (3, 3) for m in ds)
def test_data_hash_wip():
    """Smoke test: generation succeeds and yields a truthy dataset."""
    # TODO: dataset.data_hash doesn't work right now
    generated = MazeDataset.generate(TEST_CONFIGS[0])
    assert generated
def test_download():
    """Downloading datasets is unimplemented and must raise NotImplementedError."""
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])
def test_serialize_load():
    """Round-trip through serialize()/load() preserves the config and every maze.

    Uses ``zip(..., strict=True)`` so that a length mismatch between the
    original and reloaded datasets fails loudly instead of being silently
    truncated (the previous ``strict=False`` would skip extra mazes).
    """
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    for maze, maze_copy in zip(dataset, dataset_copy, strict=True):
        assert maze == maze_copy
@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
    """_serialize_minimal() round-trips to a dataset equal to the original."""
    original = MazeDataset.generate(config, gen_parallel=False)
    restored = MazeDataset.load(original._serialize_minimal())
    restored.assert_equal(original)
    assert restored == original
@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
    """Save to disk and read back under both full and minimal serialization.

    `set_serialize_minimal_threshold` toggles a global serialization mode:
    `None` forces full serialization, `0` forces minimal serialization for
    every dataset size.
    """

    def save_and_read(d: MazeDataset, p: Path) -> None:
        # round-trip: save, then read back two ways and compare to the original
        d.save(file_path=p)
        # read as MazeDataset
        roundtrip = MazeDataset.read(p)
        assert roundtrip == d
        # read from zanj
        z = ZANJ()
        roundtrip_zanj = z.read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # Test with full serialization
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # Test with minimal serialization
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)

    # repeat the minimal-serialization round trip with finer-grained assertions,
    # so a failure pinpoints whether cfg or mazes diverged
    d.save(file_path=p)
    # read as MazeDataset
    roundtrip = MazeDataset.read(p)
    # cfg diff checked in both directions, then compared wholesale
    assert d.cfg.diff(roundtrip.cfg) == dict()
    cfg_diff = roundtrip.cfg.diff(d.cfg)
    assert cfg_diff == {}
    assert roundtrip.cfg == d.cfg
    assert roundtrip.mazes == d.mazes
    assert roundtrip == d
    # read from zanj
    z = ZANJ()
    roundtrip_zanj = z.read(p)
    assert roundtrip_zanj == d
def test_custom_maze_filter():
    """custom_maze_filter accepts both a lambda and a kwargs-taking function."""
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    # three solutions of lengths 3, 2, and 1
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def has_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=s) for s in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    # both call styles must select exactly the length-1 maze
    filtered_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    filtered_func = dataset.custom_maze_filter(
        has_solution_length,
        solution_length=1,
    )

    assert filtered_lambda.mazes == filtered_func.mazes == [mazes[2]]
class TestMazeDatasetFilters:
    """Tests for built-in dataset filters and custom filter registration."""

    # shared fixtures: a 3x3 config and a 2x2 connection list
    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
        """A registered filter namespace exposes maze-level and dataset-level filters
        on the dataset's `filter_by` accessor, callable positionally or by keyword."""

        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Test for solution equality"""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Filter mazes by path length"""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg,
                        [maze for i, maze in enumerate(dataset) if i != n],
                    ),
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0]]),
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 1]]),
        )

        dataset = TestDataset(self.config, [maze1, maze2])

        # each filter is exercised twice: keyword and positional argument passing
        maze_filter = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        maze_filter2 = dataset.filter_by.solution_match(np.array([[0, 0]]))

        dataset_filter = dataset.filter_by.drop_nth(n=0)
        dataset_filter2 = dataset.filter_by.drop_nth(0)

        assert maze_filter.mazes == maze_filter2.mazes == [maze1]
        assert dataset_filter.mazes == dataset_filter2.mazes == [maze2]

    def test_path_length(self):
        """path_length / start_end_distance keep only the sufficiently long maze."""
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )

        short_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1]]),
        )

        dataset = MazeDataset(self.config, [long_maze, short_maze])
        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        # filtering returns the same concrete dataset type
        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        # the source dataset itself is left unmodified
        assert dataset.mazes == [long_maze, short_maze]

    def test_cut_percentile_shortest(self):
        """cut_percentile_shortest(49.0) drops the shortest of three mazes."""
        # solution lengths 3, 2, 1 — the length-1 maze falls below the percentile
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]

        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=solution)
            for solution in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)
        filtered = dataset.filter_by.cut_percentile_shortest(49.0)

        assert filtered.mazes == mazes[:2]
# Five 5x5 ASCII mazes (S = start, E = end, X = solution path) used by the
# deduplication tests below. Mazes at indices 0 and 2 are byte-identical
# duplicates; see test_remove_duplicates / test_remove_duplicates_fast for
# which entries each filter keeps.
DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]
def _helper_dataset_from_ascii(ascii_rep: list[str]) -> MazeDataset:
    """Build a MazeDataset from a list of ASCII maze representations.

    The config's grid size is taken from the first maze, so every entry in
    `ascii_rep` is expected to share one grid shape. Requires at least one maze.
    """
    # comprehension instead of append-in-loop (resolves the PERF401 TODO);
    # annotation fixed: the argument is iterated as a list of strings, not a str
    mazes: list[SolvedMaze] = [
        SolvedMaze.from_ascii(maze_ascii.strip()) for maze_ascii in ascii_rep
    ]

    return MazeDataset(
        MazeDatasetConfig(
            name="test",
            grid_n=mazes[0].grid_shape[0],
            n_mazes=len(mazes),
        ),
        mazes,
    )
def test_remove_duplicates():
    """remove_duplicates keeps only the last two mazes of DUPE_DATASET."""
    full_dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    deduped: MazeDataset = full_dataset.filter_by.remove_duplicates()

    assert len(full_dataset) == 5
    assert deduped.mazes == [full_dataset.mazes[3], full_dataset.mazes[4]]
def test_data_hash():
    """data_hash is deterministic: two calls on the same dataset agree."""
    ds: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    first = ds.data_hash()
    second = ds.data_hash()

    assert first == second
def test_remove_duplicates_fast():
    """remove_duplicates_fast drops only exact duplicates (index 2 repeats index 0)."""
    ds: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    deduped: MazeDataset = ds.filter_by.remove_duplicates_fast()

    assert len(ds) == 5
    expected = [ds.mazes[i] for i in (0, 1, 3, 4)]
    assert deduped.mazes == expected