Coverage for tests/unit/generation/test_custom_generator.py: 81%
119 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 22:20 -0700
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 22:20 -0700
1"""Tests for custom maze generator registration system"""
3from pathlib import Path
5import numpy as np
6import pytest
7from zanj import ZANJ
9from maze_dataset import LatticeMaze, MazeDataset, MazeDatasetConfig
10from maze_dataset.constants import Coord, CoordTup
11from maze_dataset.generation import (
12 GENERATORS_MAP,
13 LatticeMazeGenerators,
14 get_maze_with_solution,
15)
16from maze_dataset.generation.registration import (
17 MazeGeneratorRegistrationError,
18 register_maze_generator,
19)
# Scratch directory for files written by the save/load tests in this module
TEMP_PATH = Path("tests/_temp/test_custom_generator/")


@pytest.fixture(autouse=True)
def _restore_generator_registry():
    """Undo any generator registrations made while a test in this module runs.

    `register_maze_generator` adds entries to the module-level
    `GENERATORS_MAP` and attributes to `LatticeMazeGenerators`, and it raises
    when asked to register a name that is already present. Without this
    cleanup, those registrations leak into every later test in the process.
    Re-running a test from this module in the same process (e.g. via a rerun
    plugin) would then fail on the second registration of the same name.
    """
    keys_before = set(GENERATORS_MAP)
    attrs_before = set(vars(LatticeMazeGenerators))
    yield
    # remove only what the test added, so built-in generators stay untouched
    for added_key in set(GENERATORS_MAP) - keys_before:
        del GENERATORS_MAP[added_key]
    for added_attr in set(vars(LatticeMazeGenerators)) - attrs_before:
        delattr(LatticeMazeGenerators, added_attr)
def test_register_valid_function():
    """A well-formed generator is registered and callable through both entry points.

    Registration must expose the generator by name in `GENERATORS_MAP` and as
    an attribute of `LatticeMazeGenerators`. Generating through either route
    must honour the requested grid shape.
    """

    @register_maze_generator
    def gen_test_valid(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
    ) -> LatticeMaze:
        """Simple test generator - fully connected grid"""
        shape: Coord = np.array(grid_shape)
        n_rows, n_cols = shape[0], shape[1]
        edges: np.ndarray = np.zeros((lattice_dim, *shape), dtype=bool)

        # axis-1 connections: every column except the last one
        if n_cols > 1:
            edges[1, :, : n_cols - 1] = True
        # axis-0 connections: every row except the last one
        if n_rows > 1:
            edges[0, : n_rows - 1, :] = True

        meta = {
            "func_name": "gen_test_valid",
            "grid_shape": shape,
            "fully_connected": True,
        }
        return LatticeMaze(connection_list=edges, generation_meta=meta)

    generator_name = "gen_test_valid"

    # the decorator must publish the generator in both registries
    assert generator_name in GENERATORS_MAP
    assert hasattr(LatticeMazeGenerators, generator_name)

    # string-keyed lookup through the public helper
    solved = get_maze_with_solution(generator_name, (5, 5))
    assert solved.grid_shape == (5, 5)

    # direct attribute access on the generator namespace
    direct = LatticeMazeGenerators.gen_test_valid((4, 4))
    assert direct.grid_shape == (4, 4)
def test_maze_dataset_config_with_custom_generator():
    """Round-trip a MazeDatasetConfig that uses a custom generator.

    Covers the in-memory round trip (`serialize` / `MazeDatasetConfig.load`)
    and the on-disk round trip (ZANJ save/read). Besides the plain fields, it
    checks that the custom `maze_ctor` itself is restored. Otherwise a silent
    fallback to another constructor would go unnoticed.
    """

    @register_maze_generator
    def gen_test_config(
        grid_shape: Coord | CoordTup,
        custom_param: float = 0.5,
    ) -> LatticeMaze:
        """Test generator with custom parameter"""
        grid_shape_: Coord = np.array(grid_shape)
        connection_list: np.ndarray = np.zeros((2, *grid_shape_), dtype=np.bool_)

        # Simple fully connected pattern
        if grid_shape_[1] > 1:
            connection_list[1, :, : grid_shape_[1] - 1] = True
        if grid_shape_[0] > 1:
            connection_list[0, : grid_shape_[0] - 1, :] = True

        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_test_config",
                grid_shape=grid_shape_,
                custom_param=custom_param,
                fully_connected=True,
            ),
        )

    # Create config with custom generator
    config = MazeDatasetConfig(
        name="test_custom",
        grid_n=5,
        n_mazes=3,
        maze_ctor=gen_test_config,
        maze_ctor_kwargs={"custom_param": 0.7},
    )

    # Test serialization/deserialization
    serialized = config.serialize()
    loaded_config = MazeDatasetConfig.load(serialized)

    assert loaded_config.name == config.name
    assert loaded_config.grid_n == config.grid_n
    assert loaded_config.n_mazes == config.n_mazes
    assert loaded_config.maze_ctor_kwargs == config.maze_ctor_kwargs
    # the custom constructor itself must survive the round trip
    assert loaded_config.maze_ctor.__name__ == "gen_test_config"

    # Test save/load to file using ZANJ
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    config_path = TEMP_PATH / "test_config.zanj"

    z = ZANJ()
    z.save(config, config_path)
    file_loaded_config = z.read(config_path)

    assert file_loaded_config.name == config.name
    assert file_loaded_config.maze_ctor_kwargs == config.maze_ctor_kwargs
    assert file_loaded_config.maze_ctor.__name__ == "gen_test_config"
def test_maze_dataset_with_custom_generator():
    """Generate a MazeDataset from a custom generator, then save and reload it.

    The reloaded dataset must have the same length and config name. Each maze
    must keep its grid shape and connection list.
    """

    @register_maze_generator
    def gen_test_dataset(
        grid_shape: Coord | CoordTup,
        lattice_dim: int = 2,
    ) -> LatticeMaze:
        """Test generator for dataset creation"""
        dims: Coord = np.array(grid_shape)
        rows, cols = dims[0], dims[1]
        conn: np.ndarray = np.zeros((lattice_dim, *dims), dtype=bool)

        # axis-1 links for every column but the last
        if cols > 1:
            conn[1, :, : cols - 1] = True
        # axis-0 links for every row but the last
        if rows > 1:
            conn[0, : rows - 1, :] = True

        return LatticeMaze(
            connection_list=conn,
            generation_meta={
                "func_name": "gen_test_dataset",
                "grid_shape": dims,
                "fully_connected": True,
            },
        )

    cfg = MazeDatasetConfig(
        name="test_dataset",
        grid_n=4,
        n_mazes=2,
        maze_ctor=gen_test_dataset,
        maze_ctor_kwargs={},
    )
    generated = MazeDataset.generate(cfg, gen_parallel=False)

    # every generated maze should match the configured size
    assert len(generated) == 2
    assert all(m.grid_shape == (4, 4) for m in generated)

    # persist to disk and read it straight back
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    out_file = TEMP_PATH / "test_dataset.zanj"
    generated.save(out_file)
    reloaded = MazeDataset.read(out_file)

    assert len(reloaded) == len(generated)
    assert reloaded.cfg.name == generated.cfg.name
    for before, after in zip(generated, reloaded, strict=True):
        assert before.grid_shape == after.grid_shape
        assert np.array_equal(before.connection_list, after.connection_list)
# The `# type: ignore` markers below are intentional: these tests pass
# functions with invalid signatures to check that registration rejects them.
def test_registration_error_missing_grid_shape():
    """Registration rejects a generator whose only parameter is not `grid_shape`."""

    def invalid_missing_grid_shape(x):
        assert x  # reference the parameter so linters stay quiet
        return LatticeMaze(np.zeros((2, 3, 3), dtype=bool), {})

    expected_message = "must have 'grid_shape' as its first parameter"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_missing_grid_shape)  # type: ignore[type-var]
def test_registration_error_wrong_param_name():
    """Registration rejects a generator whose first parameter is misnamed."""

    def invalid_wrong_param_name(shape):
        assert shape  # reference the parameter so linters stay quiet
        return LatticeMaze(np.zeros((2, 3, 3), dtype=bool), {})

    expected_message = "must have 'grid_shape' as its first parameter"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_wrong_param_name)  # type: ignore[type-var]
def test_registration_error_missing_type_annotation():
    """Registration rejects a generator whose `grid_shape` has no annotation."""

    def invalid_missing_type_annotation(grid_shape):
        assert grid_shape  # reference the parameter so linters stay quiet
        return LatticeMaze(np.zeros((2, 3, 3), dtype=bool), {})

    expected_pattern = r"must be typed as 'Coord \| CoordTup' or compatible type"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_pattern):
        register_maze_generator(invalid_missing_type_annotation)  # type: ignore[type-var]
def test_registration_error_missing_return_annotation():
    """Registration rejects a generator that declares no return type."""

    def invalid_missing_return_annotation(grid_shape: Coord | CoordTup):
        assert grid_shape is not None  # reference the parameter so linters stay quiet
        return LatticeMaze(np.zeros((2, 3, 3), dtype=bool), {})

    expected_message = "must have a return type annotation of LatticeMaze"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_missing_return_annotation)  # type: ignore[type-var]
def test_registration_error_wrong_return_type():
    """Registration rejects a generator annotated to return something other than LatticeMaze."""

    def invalid_wrong_return_type(grid_shape: Coord | CoordTup) -> str:
        assert grid_shape is not None  # reference the parameter so linters stay quiet
        return "wrong"

    expected_message = "must return LatticeMaze"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_message):
        register_maze_generator(invalid_wrong_return_type)  # type: ignore[type-var]
def test_registration_error_invalid_grid_shape_type():
    """Registration rejects a `grid_shape` annotated with an incompatible type."""

    def invalid_grid_shape_type(grid_shape: str) -> LatticeMaze:
        assert grid_shape  # reference the parameter so linters stay quiet
        return LatticeMaze(np.zeros((2, 3, 3), dtype=bool), {})

    expected_pattern = r"must be typed as 'Coord \| CoordTup' or compatible type"
    with pytest.raises(MazeGeneratorRegistrationError, match=expected_pattern):
        register_maze_generator(invalid_grid_shape_type)  # type: ignore[type-var]
def test_duplicate_registration_error():
    """A second generator registered under an existing name is refused with ValueError."""
    shared_meta = {
        "func_name": "gen_test_duplicate_unique",
        "fully_connected": True,
    }

    @register_maze_generator
    def gen_test_duplicate_unique(grid_shape: Coord | CoordTup) -> LatticeMaze:
        """First registration"""
        assert grid_shape is not None  # reference the parameter so linters stay quiet
        return LatticeMaze(
            np.zeros((2, 3, 3), dtype=bool),
            generation_meta=dict(shared_meta),
        )

    # Redefining under the same name is the point of this test, hence the
    # type-checker and linter suppressions on the def line.
    def gen_test_duplicate_unique(  # type: ignore[no-redef] # noqa: F811
        grid_shape: Coord | CoordTup,
    ) -> LatticeMaze:
        """Second registration attempt with same name"""
        assert grid_shape is not None  # reference the parameter so linters stay quiet
        return LatticeMaze(
            np.zeros((2, 3, 3), dtype=bool),
            generation_meta=dict(shared_meta),
        )

    with pytest.raises(ValueError, match="already exists in GENERATORS_MAP"):
        register_maze_generator(gen_test_duplicate_unique)