Coverage for tests/unit/generation/test_maze_dataset.py: 100%

139 statements  

coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

import copy
from pathlib import Path

import numpy as np
import pytest
from zanj import ZANJ

from maze_dataset import (
    MazeDataset,
    MazeDatasetConfig,
    register_maze_filter,
    set_serialize_minimal_threshold,
)
from maze_dataset.constants import CoordArray
from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)
from maze_dataset.generation.generators import GENERATORS_MAP
from maze_dataset.maze import SolvedMaze
from maze_dataset.utils import bool_array_from_string

class TestMazeDatasetConfig:
    pass

# shared test configurations: two small gen_dfs mazes and a 5x5 maze generated without forks
TEST_CONFIGS = [
    MazeDatasetConfig(
        name="test",
        grid_n=grid_n,
        n_mazes=n_mazes,
        maze_ctor=GENERATORS_MAP["gen_dfs"],
        maze_ctor_kwargs=maze_ctor_kwargs,
    )
    for grid_n, n_mazes, maze_ctor_kwargs in [
        (3, 5, {}),
        (3, 1, {}),
        (5, 5, dict(do_forks=False)),
    ]
]

def test_generate_serial():
    dataset = MazeDataset.generate(TEST_CONFIGS[0], gen_parallel=False)

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)

def test_generate_parallel():
    dataset = MazeDataset.generate(
        TEST_CONFIGS[0],
        gen_parallel=True,
        verbose=True,
        pool_kwargs=dict(processes=2),
    )

    assert len(dataset) == 5
    for maze in dataset:
        assert maze.grid_shape == (3, 3)
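    # same expectations as the serial test: the generation backend (serial vs. a
    # 2-process pool) should not affect the resulting dataset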

def test_data_hash_wip():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    # TODO: dataset.data_hash doesn't work right now
    assert dataset

def test_download():
    with pytest.raises(NotImplementedError):
        MazeDataset.download(TEST_CONFIGS[0])

def test_serialize_load():
    dataset = MazeDataset.generate(TEST_CONFIGS[0])
    dataset_copy = MazeDataset.load(dataset.serialize())

    assert dataset.cfg == dataset_copy.cfg
    # strict=True: a round-trip that drops mazes should fail here, not pass silently
    for maze, maze_copy in zip(dataset, dataset_copy, strict=True):
        assert maze == maze_copy

@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_serialize_load_minimal(config):
    d = MazeDataset.generate(config, gen_parallel=False)
    d_loaded = MazeDataset.load(d._serialize_minimal())
    d_loaded.assert_equal(d)
    assert d_loaded == d
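    # assert_equal and the plain == comparison must agree on the round-trip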

@pytest.mark.parametrize(
    "config",
    [
        pytest.param(
            c,
            id=f"{c.grid_n=}; {c.n_mazes=}; {c.maze_ctor_kwargs=}",
        )
        for c in TEST_CONFIGS
    ],
)
def test_save_read_minimal(config):
    def save_and_read(d: MazeDataset, p: Path):
        d.save(file_path=p)
        # read as MazeDataset
        roundtrip = MazeDataset.read(p)
        assert roundtrip == d
        # read from zanj
        z = ZANJ()
        roundtrip_zanj = z.read(p)
        assert roundtrip_zanj == d

    d = MazeDataset.generate(config, gen_parallel=False)
    p = Path("tests/_temp/test_maze_dataset/") / (d.cfg.to_fname() + ".zanj")

    # full serialization (a threshold of None disables the minimal format)
    set_serialize_minimal_threshold(None)
    save_and_read(d, p)

    # minimal serialization (a threshold of 0 forces it regardless of dataset size)
    set_serialize_minimal_threshold(0)
    save_and_read(d, p)

    # repeat the minimal round-trip with finer-grained asserts, so a failure
    # pinpoints whether the cfg or the mazes diverged
    d.save(file_path=p)
    roundtrip = MazeDataset.read(p)
    # diff the configs in both directions
    assert d.cfg.diff(roundtrip.cfg) == {}
    assert roundtrip.cfg.diff(d.cfg) == {}
    assert roundtrip.cfg == d.cfg
    assert roundtrip.mazes == d.mazes
    assert roundtrip == d
    # read from zanj
    z = ZANJ()
    roundtrip_zanj = z.read(p)
    assert roundtrip_zanj == d
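    # set_serialize_minimal_threshold mutates global serialization state and is
    # not reset here, so the last setting may leak into later tests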

def test_custom_maze_filter():
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )
    solutions = [
        [[0, 0], [0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[0, 0]],
    ]

    def custom_filter_solution_length(maze: SolvedMaze, solution_length: int) -> bool:
        return len(maze.solution) == solution_length

    mazes = [
        SolvedMaze(connection_list=connection_list, solution=solution)
        for solution in solutions
    ]
    dataset = MazeDataset(cfg=TEST_CONFIGS[0], mazes=mazes)

    filtered_lambda = dataset.custom_maze_filter(lambda m: len(m.solution) == 1)
    filtered_func = dataset.custom_maze_filter(
        custom_filter_solution_length,
        solution_length=1,
    )

    assert filtered_lambda.mazes == filtered_func.mazes == [mazes[2]]
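    # only mazes[2] has a length-1 solution ([[0, 0]]), so it alone survives both filters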

class TestMazeDatasetFilters:
    config = MazeDatasetConfig(name="test", grid_n=3, n_mazes=5)
    connection_list = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    def test_filters(self):
        class TestDataset(MazeDataset): ...

        @register_filter_namespace_for_dataset(TestDataset)
        class TestFilters:
            @register_maze_filter
            @staticmethod
            def solution_match(maze: SolvedMaze, solution: CoordArray) -> bool:
                """Test for solution equality"""
                return (maze.solution == solution).all()

            @register_dataset_filter
            @staticmethod
            def drop_nth(dataset: TestDataset, n: int) -> TestDataset:
                """Drop the nth maze from the dataset"""
                return copy.deepcopy(
                    TestDataset(
                        dataset.cfg,
                        [maze for i, maze in enumerate(dataset) if i != n],
                    ),
                )

        maze1 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0]]),
        )
        maze2 = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 1]]),
        )

        dataset = TestDataset(self.config, [maze1, maze2])

        maze_filter = dataset.filter_by.solution_match(solution=np.array([[0, 0]]))
        maze_filter2 = dataset.filter_by.solution_match(np.array([[0, 0]]))

        dataset_filter = dataset.filter_by.drop_nth(n=0)
        dataset_filter2 = dataset.filter_by.drop_nth(0)

        assert maze_filter.mazes == maze_filter2.mazes == [maze1]
        assert dataset_filter.mazes == dataset_filter2.mazes == [maze2]
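        # each registered filter accepts its argument positionally or by keyword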

    def test_path_length(self):
        long_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1], [1, 1]]),
        )

        short_maze = SolvedMaze(
            connection_list=self.connection_list,
            solution=np.array([[0, 0], [0, 1]]),
        )

        dataset = MazeDataset(self.config, [long_maze, short_maze])
        path_length_filtered = dataset.filter_by.path_length(3)
        start_end_filtered = dataset.filter_by.start_end_distance(2)

        assert type(path_length_filtered) == type(dataset)  # noqa: E721
        assert path_length_filtered.mazes == [long_maze]
        assert start_end_filtered.mazes == [long_maze]
        assert dataset.mazes == [long_maze, short_maze]
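        # filter_by returns a new dataset of the same type; the original is untouched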

    def test_cut_percentile_shortest(self):
        solutions = [
            [[0, 0], [0, 1], [1, 1]],
            [[0, 0], [0, 1]],
            [[0, 0]],
        ]

        mazes = [
            SolvedMaze(connection_list=self.connection_list, solution=solution)
            for solution in solutions
        ]
        dataset = MazeDataset(cfg=self.config, mazes=mazes)
        filtered = dataset.filter_by.cut_percentile_shortest(49.0)

        assert filtered.mazes == mazes[:2]
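        # cutting at the 49th percentile drops exactly the shortest of the three solutions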

DUPE_DATASET = [
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
#SXE#
### #
#   #
#####
""",
    """
#####
#  E#
###X#
#SXX#
#####
""",
    """
#####
# # #
# # #
#EXS#
#####
""",
    """
#####
#SXX#
###X#
#EXX#
#####
""",
]
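# ascii legend: '#' wall, 'S' start, 'E' end, 'X' solution path; note that
# entries 0 and 2 are identical, which the dedup tests below rely on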

def _helper_dataset_from_ascii(ascii_rep: list[str]) -> MazeDataset:
    mazes: list[SolvedMaze] = [
        SolvedMaze.from_ascii(maze_ascii.strip()) for maze_ascii in ascii_rep
    ]

    return MazeDataset(
        MazeDatasetConfig(
            name="test",
            grid_n=mazes[0].grid_shape[0],
            n_mazes=len(mazes),
        ),
        mazes,
    )

def test_remove_duplicates():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [dataset.mazes[3], dataset.mazes[4]]
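    # besides the exact copy (mazes[2]), this filter evidently also drops mazes
    # it judges too similar, keeping only the last two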

def test_data_hash():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    hash_1 = dataset.data_hash()
    hash_2 = dataset.data_hash()

    assert hash_1 == hash_2

def test_remove_duplicates_fast():
    dataset: MazeDataset = _helper_dataset_from_ascii(DUPE_DATASET)
    dataset_deduped: MazeDataset = dataset.filter_by.remove_duplicates_fast()

    assert len(dataset) == 5
    assert dataset_deduped.mazes == [
        dataset.mazes[0],
        dataset.mazes[1],
        dataset.mazes[3],
        dataset.mazes[4],
    ]
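    # the fast variant appears to drop only exact duplicates: mazes[2] is a
    # verbatim copy of mazes[0]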