Coverage for tests/unit/generation/test_latticemaze.py: 100%

110 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1import numpy as np 

2import pytest 

3 

4from maze_dataset.constants import CoordArray 

5from maze_dataset.generation.default_generators import DEFAULT_GENERATORS 

6from maze_dataset.generation.generators import GENERATORS_MAP 

7from maze_dataset.maze import LatticeMaze, PixelColors, SolvedMaze, TargetedLatticeMaze 

8from maze_dataset.utils import adj_list_to_nested_set, bool_array_from_string 

9 

10 

11# thanks to gpt for these tests of _from_pixel_grid 

@pytest.fixture
def example_pixel_grid():
    """Black/white pixel grid for a 2x2 maze, where ``True`` marks a non-wall pixel.

    Layout (``#`` -> ``False``, ``.`` -> ``True``)::

        #####
        #...#
        ###.#
        #...#
        #####
    """
    layout = [
        "#####",
        "#...#",
        "###.#",
        "#...#",
        "#####",
    ]
    return np.array([[ch == "." for ch in row] for row in layout], dtype=bool)

24 

25 

@pytest.fixture
def example_rgb_pixel_grid():
    """RGB pixel grid for a 2x2 maze made only of WALL and OPEN pixels.

    Layout (``#`` = ``PixelColors.WALL``, ``.`` = ``PixelColors.OPEN``)::

        #####
        #...#
        #####
        #.#.#
        #####
    """
    legend = {
        "#": PixelColors.WALL,
        ".": PixelColors.OPEN,
    }
    layout = [
        "#####",
        "#...#",
        "#####",
        "#.#.#",
        "#####",
    ]
    return np.array(
        [[legend[ch] for ch in row] for row in layout],
        dtype=np.uint8,
    )

68 

69 

def test_from_pixel_grid_bw(example_pixel_grid):
    """Decode the black/white fixture into a 2x2 maze's connection array."""
    connections, shape = LatticeMaze._from_pixel_grid_bw(example_pixel_grid)

    assert isinstance(connections, np.ndarray)
    assert connections.shape == (2, 2, 2)
    # shape is pinned above, so array_equal is an exact element-wise check
    assert np.array_equal(connections[0], [[False, True], [False, False]])
    assert np.array_equal(connections[1], [[True, False], [True, False]])
    assert shape == (2, 2)

78 

79 

def test_from_pixel_grid_with_positions(example_rgb_pixel_grid):
    """Parse an RGB grid that contains no marker pixels.

    Checks the decoded connection array and grid shape, and that every
    requested marker key ("start", "end", "path") comes back as an empty
    ``np.ndarray`` of shape ``(0,)``.
    """
    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    (
        connection_list,
        grid_shape,
        out_positions,
    ) = LatticeMaze._from_pixel_grid_with_positions(
        example_rgb_pixel_grid,
        marked_positions,
    )

    assert isinstance(connection_list, np.ndarray)
    assert connection_list.shape == (2, 2, 2)
    assert np.array_equal(connection_list[0], [[False, False], [False, False]])
    assert np.array_equal(connection_list[1], [[True, False], [False, False]])
    assert grid_shape == (2, 2)

    # Exactly the requested keys. Checked explicitly so that a missing key
    # (previously "path" was indexed without a check) fails as an assertion
    # rather than as a KeyError.
    assert isinstance(out_positions, dict)
    assert set(out_positions) == set(marked_positions)

    # The fixture is built only from WALL/OPEN pixels, so every entry is empty.
    for key in marked_positions:
        assert isinstance(out_positions[key], np.ndarray), key
        assert out_positions[key].shape == (0,), key

115 

116 

def test_find_start_end_points_in_rgb_pixel_grid():
    """Locate START/END marker pixels in an RGB grid and map them to cell coordinates.

    START sits at pixel (1, 1) and END at pixel (1, 3). The test expects them
    at cells (0, 0) and (0, 1), and expects no PATH pixels.
    """
    legend = {
        "#": PixelColors.WALL,
        ".": PixelColors.OPEN,
        "S": PixelColors.START,
        "E": PixelColors.END,
    }
    layout = [
        "#####",
        "#S.E#",
        "#####",
        "#.#.#",
        "#####",
    ]
    rgb_pixel_grid_with_positions = np.array(
        [[legend[ch] for ch in row] for row in layout],
        dtype=np.uint8,
    )

    marked_positions = {
        "start": PixelColors.START,
        "end": PixelColors.END,
        "path": PixelColors.PATH,
    }

    (
        connection_list,
        grid_shape,
        out_positions,
    ) = LatticeMaze._from_pixel_grid_with_positions(
        rgb_pixel_grid_with_positions,
        marked_positions,
    )

    print(f"{out_positions = }")

    assert isinstance(out_positions, dict)
    # exactly the requested keys; also guards the "path" lookup below
    assert set(out_positions) == set(marked_positions)
    for key in marked_positions:
        assert isinstance(out_positions[key], np.ndarray), key

    # array_equal instead of np.all(a == b): the latter broadcasts, so an
    # empty (0, 2) result or duplicated matches would pass vacuously
    assert np.array_equal(out_positions["start"], np.array([[0, 0]]))
    assert np.array_equal(out_positions["end"], np.array([[0, 1]]))
    assert out_positions["path"].shape == (0,)

187 

188 

@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_pixels_ascii_roundtrip(gfunc_name, kwargs):
    """Every default generator yields a maze that round-trips through pixels and ASCII.

    For an ``n x n`` maze, also checks that the pixel rendering is
    ``(2n + 1, 2n + 1, 3)`` and every ASCII line has length ``2n + 1``.
    """
    size: int = 5
    side = 2 * size + 1
    generate = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = generate(np.array([size, size]), **kwargs)

    pixels: np.ndarray = maze.as_pixels()
    ascii_art: str = maze.as_ascii()

    assert maze == LatticeMaze.from_pixels(pixels)
    assert maze == LatticeMaze.from_ascii(ascii_art)

    expected_shape: tuple = (side, side, 3)
    assert pixels.shape == expected_shape, f"{pixels.shape} != {expected_shape}"
    assert all(len(row) == side for row in ascii_art.splitlines()), f"{ascii_art}"

209 

210 

def _assert_pixels_ascii_roundtrip(maze_obj, maze_cls, n: int) -> None:
    """Assert ``maze_obj`` round-trips through ``maze_cls`` via pixels and ASCII.

    For an ``n x n`` maze, also checks that the pixel rendering has shape
    ``(2n + 1, 2n + 1, 3)`` and that every ASCII line has length ``2n + 1``.
    """
    pixels: np.ndarray = maze_obj.as_pixels()
    ascii_str: str = maze_obj.as_ascii()

    assert maze_obj == maze_cls.from_pixels(pixels)
    assert maze_obj == maze_cls.from_ascii(ascii_str)

    expected_shape: tuple = (n * 2 + 1, n * 2 + 1, 3)
    assert pixels.shape == expected_shape, f"{pixels.shape} != {expected_shape}"
    assert all(n * 2 + 1 == len(line) for line in ascii_str.splitlines()), (
        f"{ascii_str}"
    )


@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_targeted_solved_maze(gfunc_name, kwargs):
    """Targeted and solved mazes derived from a random path survive pixel/ASCII round trips.

    Both the targeted maze and the solved maze get the full set of checks,
    including their own pixel-shape check.
    """
    n: int = 5
    maze_gen_func = GENERATORS_MAP[gfunc_name]
    maze: LatticeMaze = maze_gen_func(np.array([n, n]), **kwargs)
    solution: CoordArray = maze.generate_random_path()

    # the random path's first and last coordinates become the target endpoints
    tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze(
        maze,
        solution[0],
        solution[-1],
    )
    _assert_pixels_ascii_roundtrip(tgt_maze, TargetedLatticeMaze, n)

    solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze)
    _assert_pixels_ascii_roundtrip(solved_maze, SolvedMaze, n)

252 

253 

def test_as_adj_list():
    """Unshuffled ``as_adj_list`` on a hand-written 2x2 connection array gives the expected edges."""
    connections = bool_array_from_string(
        """
        F T
        F F

        T F
        T F
        """,
        shape=[2, 2, 2],
    )

    maze = LatticeMaze(connection_list=connections)
    edges = maze.as_adj_list(shuffle_d0=False, shuffle_d1=False)

    # one coordinate pair per True entry in ``connections`` above
    expected_edges = [
        [[0, 1], [1, 1]],
        [[0, 0], [0, 1]],
        [[1, 0], [1, 1]],
    ]

    # compared via nested sets -- presumably so edge/endpoint order is ignored
    assert adj_list_to_nested_set(expected_edges) == adj_list_to_nested_set(edges)

273 

274 

@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_get_nodes(gfunc_name, kwargs):
    """``get_nodes`` lists every cell of a 3x2 grid, first coordinate varying slowest."""
    generate = GENERATORS_MAP[gfunc_name]
    maze = generate(np.array((3, 2)), **kwargs)

    expected_nodes = [[row, col] for row in range(3) for col in range(2)]
    assert maze.get_nodes().tolist() == expected_nodes

283 

284 

@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_generate_random_path(gfunc_name, kwargs):
    """A random path on a 2x2 maze contains more than one coordinate."""
    generate = GENERATORS_MAP[gfunc_name]
    path = generate(np.array((2, 2)), **kwargs).generate_random_path()

    # more than one node -> the path has distinct start and end nodes
    assert len(path) > 1

293 

294 

@pytest.mark.parametrize(("gfunc_name", "kwargs"), DEFAULT_GENERATORS)
def test_generate_random_path_size_1(gfunc_name, kwargs):
    """A 1x1 maze has a single node, so ``generate_random_path`` must reject it."""
    single_cell_maze = GENERATORS_MAP[gfunc_name](np.array((1, 1)), **kwargs)

    # NOTE(review): the library signals this with AssertionError; if that comes
    # from an ``assert`` statement it disappears under ``python -O`` -- confirm
    with pytest.raises(AssertionError):
        single_cell_maze.generate_random_path()