Coverage for tests/unit/generation/test_custom_generator.py: 81%

119 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-03 22:20 -0700

1"""Tests for custom maze generator registration system""" 

2 

3from pathlib import Path 

4 

5import numpy as np 

6import pytest 

7from zanj import ZANJ 

8 

9from maze_dataset import LatticeMaze, MazeDataset, MazeDatasetConfig 

10from maze_dataset.constants import Coord, CoordTup 

11from maze_dataset.generation import ( 

12 GENERATORS_MAP, 

13 LatticeMazeGenerators, 

14 get_maze_with_solution, 

15) 

16from maze_dataset.generation.registration import ( 

17 MazeGeneratorRegistrationError, 

18 register_maze_generator, 

19) 

20 

21# Temp directory for file operations 

22TEMP_PATH = Path("tests/_temp/test_custom_generator/") 

23 

24 

def test_register_valid_function():
    """Register a well-formed generator and check that it is reachable by name.

    After decoration the function name must be a key of ``GENERATORS_MAP`` and
    an attribute of ``LatticeMazeGenerators``; it is then invoked once through
    ``get_maze_with_solution`` and once through ``LatticeMazeGenerators``.
    """

    @register_maze_generator
    def gen_test_valid(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
    ) -> LatticeMaze:
        """Generator used only by this test: a grid with all connections set."""
        shape_arr: Coord = np.array(grid_shape)
        n_rows = shape_arr[0]
        n_cols = shape_arr[1]
        connections: np.ndarray = np.zeros(
            (lattice_dim, *shape_arr),
            dtype=np.bool_,
        )

        # channel 0: set everything except the last index along the first grid axis
        if n_rows > 1:
            connections[0, : n_rows - 1, :] = True
        # channel 1: set everything except the last index along the second grid axis
        if n_cols > 1:
            connections[1, :, : n_cols - 1] = True

        meta = {
            "func_name": "gen_test_valid",
            "grid_shape": shape_arr,
            "fully_connected": True,
        }
        return LatticeMaze(connection_list=connections, generation_meta=meta)

    # the decorator must have exposed the generator in both registries
    assert "gen_test_valid" in GENERATORS_MAP
    assert hasattr(LatticeMazeGenerators, "gen_test_valid")

    # lookup by name
    by_name = get_maze_with_solution("gen_test_valid", (5, 5))
    assert by_name.grid_shape == (5, 5)

    # lookup as an attribute of LatticeMazeGenerators
    by_attr = LatticeMazeGenerators.gen_test_valid((4, 4))
    assert by_attr.grid_shape == (4, 4)

65 

66 

def test_maze_dataset_config_with_custom_generator():
    """Round-trip a ``MazeDatasetConfig`` that points at a custom generator.

    The config is checked after an in-memory ``serialize``/``load`` cycle and
    after being written to and read back from a ``.zanj`` file under
    ``TEMP_PATH``.
    """

    @register_maze_generator
    def gen_test_config(
        grid_shape: Coord | CoordTup,
        custom_param: float = 0.5,
    ) -> LatticeMaze:
        """Generator used only by this test; records ``custom_param`` in its metadata."""
        shape_arr: Coord = np.array(grid_shape)
        n_rows = shape_arr[0]
        n_cols = shape_arr[1]
        connections: np.ndarray = np.zeros((2, *shape_arr), dtype=np.bool_)

        # same pattern in both channels: everything but the last index is set
        if n_rows > 1:
            connections[0, : n_rows - 1, :] = True
        if n_cols > 1:
            connections[1, :, : n_cols - 1] = True

        meta = {
            "func_name": "gen_test_config",
            "grid_shape": shape_arr,
            "custom_param": custom_param,
            "fully_connected": True,
        }
        return LatticeMaze(connection_list=connections, generation_meta=meta)

    # config that uses the generator registered above
    cfg = MazeDatasetConfig(
        name="test_custom",
        grid_n=5,
        n_mazes=3,
        maze_ctor=gen_test_config,
        maze_ctor_kwargs={"custom_param": 0.7},
    )

    # in-memory round trip
    cfg_roundtrip = MazeDatasetConfig.load(cfg.serialize())

    assert cfg_roundtrip.name == cfg.name
    assert cfg_roundtrip.grid_n == cfg.grid_n
    assert cfg_roundtrip.n_mazes == cfg.n_mazes
    assert cfg_roundtrip.maze_ctor_kwargs == cfg.maze_ctor_kwargs

    # on-disk round trip through ZANJ
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    cfg_file = TEMP_PATH / "test_config.zanj"

    zanj = ZANJ()
    zanj.save(cfg, cfg_file)
    cfg_from_file = zanj.read(cfg_file)

    assert cfg_from_file.name == cfg.name
    assert cfg_from_file.maze_ctor_kwargs == cfg.maze_ctor_kwargs

123 

124 

def test_maze_dataset_with_custom_generator():
    """Generate a ``MazeDataset`` from a custom generator and round-trip it on disk.

    The dataset is generated with ``gen_parallel=False``, saved under
    ``TEMP_PATH``, read back, and compared maze by maze with the original.
    """

    @register_maze_generator
    def gen_test_dataset(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
    ) -> LatticeMaze:
        """Generator used only by this test to populate the dataset."""
        shape_arr: Coord = np.array(grid_shape)
        n_rows = shape_arr[0]
        n_cols = shape_arr[1]
        connections: np.ndarray = np.zeros(
            (lattice_dim, *shape_arr),
            dtype=np.bool_,
        )

        # set every entry except the last index along each grid axis
        if n_rows > 1:
            connections[0, : n_rows - 1, :] = True
        if n_cols > 1:
            connections[1, :, : n_cols - 1] = True

        meta = {
            "func_name": "gen_test_dataset",
            "grid_shape": shape_arr,
            "fully_connected": True,
        }
        return LatticeMaze(connection_list=connections, generation_meta=meta)

    # build the config, then the dataset
    cfg = MazeDatasetConfig(
        name="test_dataset",
        grid_n=4,
        n_mazes=2,
        maze_ctor=gen_test_dataset,
        maze_ctor_kwargs={},
    )
    generated = MazeDataset.generate(cfg, gen_parallel=False)

    # size and per-maze shape
    assert len(generated) == 2
    for generated_maze in generated:
        assert generated_maze.grid_shape == (4, 4)

    # write to disk and read back
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    target = TEMP_PATH / "test_dataset.zanj"
    generated.save(target)
    reloaded = MazeDataset.read(target)

    assert len(reloaded) == len(generated)
    assert reloaded.cfg.name == generated.cfg.name
    for before, after in zip(generated, reloaded, strict=True):
        assert before.grid_shape == after.grid_shape
        assert np.array_equal(before.connection_list, after.connection_list)

182 

183 

184# bunch of type ignores here, because we are testing to make sure that 

185# the registration system raises errors for invalid function signatures 

186 

187 

def test_registration_error_missing_grid_shape():
    """Registration must be rejected when there is no ``grid_shape`` parameter."""
    expected_message = "must have 'grid_shape' as its first parameter"

    def invalid_missing_grid_shape(x):
        assert x  # touch the parameter so it is not reported as unused
        empty_connections = np.zeros((2, 3, 3), dtype=np.bool_)
        return LatticeMaze(empty_connections, {})

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_missing_grid_shape)  # type: ignore[type-var]

200 

201 

def test_registration_error_wrong_param_name():
    """Registration must be rejected when the first parameter is not named ``grid_shape``."""
    expected_message = "must have 'grid_shape' as its first parameter"

    def invalid_wrong_param_name(shape):
        assert shape  # touch the parameter so it is not reported as unused
        empty_connections = np.zeros((2, 3, 3), dtype=np.bool_)
        return LatticeMaze(empty_connections, {})

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_wrong_param_name)  # type: ignore[type-var]

214 

215 

def test_registration_error_missing_type_annotation():
    """Registration must be rejected when ``grid_shape`` carries no type annotation."""
    # raw string: the pipe is escaped because `match` is treated as a regex
    expected_message = r"must be typed as 'Coord \| CoordTup' or compatible type"

    def invalid_missing_type_annotation(grid_shape):
        assert grid_shape  # touch the parameter so it is not reported as unused
        empty_connections = np.zeros((2, 3, 3), dtype=np.bool_)
        return LatticeMaze(empty_connections, {})

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_missing_type_annotation)  # type: ignore[type-var]

228 

229 

def test_registration_error_missing_return_annotation():
    """Registration must be rejected when the return annotation is absent."""
    expected_message = "must have a return type annotation of LatticeMaze"

    def invalid_missing_return_annotation(grid_shape: Coord | CoordTup):
        assert grid_shape is not None  # touch the parameter so it is not reported as unused
        empty_connections = np.zeros((2, 3, 3), dtype=np.bool_)
        return LatticeMaze(empty_connections, {})

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_missing_return_annotation)  # type: ignore[type-var]

242 

243 

def test_registration_error_wrong_return_type():
    """Registration must be rejected when the return annotation is not ``LatticeMaze``."""
    expected_message = "must return LatticeMaze"

    def invalid_wrong_return_type(grid_shape: Coord | CoordTup) -> str:
        assert grid_shape is not None  # touch the parameter so it is not reported as unused
        return "wrong"

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_wrong_return_type)  # type: ignore[type-var]

253 

254 

def test_registration_error_invalid_grid_shape_type():
    """Registration must be rejected when ``grid_shape`` is annotated as ``str``."""
    # raw string: the pipe is escaped because `match` is treated as a regex
    expected_message = r"must be typed as 'Coord \| CoordTup' or compatible type"

    def invalid_grid_shape_type(grid_shape: str) -> LatticeMaze:
        assert grid_shape  # touch the parameter so it is not reported as unused
        empty_connections = np.zeros((2, 3, 3), dtype=np.bool_)
        return LatticeMaze(empty_connections, {})

    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_grid_shape_type)  # type: ignore[type-var]

267 

268 

def test_duplicate_registration_error():
    """Registering a second function under an already-registered name must fail.

    The first definition is registered through the decorator; a second
    function with the identical name is then passed to
    ``register_maze_generator`` directly and is expected to raise ``ValueError``.
    """

    @register_maze_generator
    def gen_test_duplicate_unique(
        grid_shape: Coord | CoordTup,
    ) -> LatticeMaze:
        """Original definition; this one is registered successfully."""
        assert grid_shape is not None  # touch the parameter so it is not reported as unused
        meta = {
            "func_name": "gen_test_duplicate_unique",
            "fully_connected": True,
        }
        return LatticeMaze(
            np.zeros((2, 3, 3), dtype=np.bool_),
            generation_meta=meta,
        )

    # Deliberately shadow the name above with a new function object;
    # the type-checker and linter complaints are silenced on purpose.
    def gen_test_duplicate_unique(  # type: ignore[no-redef] # noqa: F811
        grid_shape: Coord | CoordTup,
    ) -> LatticeMaze:
        """Redefinition under the same name; registering it should be refused."""
        assert grid_shape is not None  # touch the parameter so it is not reported as unused
        meta = {
            "func_name": "gen_test_duplicate_unique",
            "fully_connected": True,
        }
        return LatticeMaze(
            np.zeros((2, 3, 3), dtype=np.bool_),
            generation_meta=meta,
        )

    expected_message = "already exists in GENERATORS_MAP"
    with pytest.raises(ValueError, match=expected_message):
        register_maze_generator(gen_test_duplicate_unique)