Coverage for maze_dataset/dataset/configs.py: 47%
47 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:42 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:42 -0600
1"`MAZE_DATASET_CONFIGS` contains some default configs for tests and demos"
3import copy
4from typing import Callable, Iterator, Mapping
6from maze_dataset.dataset.maze_dataset import MazeDatasetConfig
7from maze_dataset.generation.generators import LatticeMazeGenerators
# Source of truth for the default configs: each config is stored under the
# string returned by its own `to_fname()`. Later entries would overwrite
# earlier ones if two configs ever produced the same key.
_MAZE_DATASET_CONFIGS_SRC: dict[str, MazeDatasetConfig] = dict(
    (default_cfg.to_fname(), default_cfg)
    for default_cfg in (
        # tiny DFS dataset
        MazeDatasetConfig(
            name="test",
            grid_n=3,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        # tiny DFS-with-percolation dataset
        MazeDatasetConfig(
            name="test-perc",
            grid_n=3,
            n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
            maze_ctor_kwargs={"p": 0.7},
        ),
        # small demo: 100 mazes on a 3x3 grid
        MazeDatasetConfig(
            name="demo_small",
            grid_n=3,
            n_mazes=100,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        # full demo: 10000 mazes on a 6x6 grid
        MazeDatasetConfig(
            name="demo",
            grid_n=6,
            n_mazes=10000,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
    )
)
class _MazeDatsetConfigsWrapper(Mapping[str, MazeDatasetConfig]):
    """wrap the default configs in a read-only dict-like object

    Every accessor that hands out a config (`__getitem__`, `items`, `values`,
    and the `Mapping` mixin methods built on `__getitem__`, such as `get`)
    returns a deep copy, so mutating a returned config does not alter the
    wrapped defaults. The wrapped dict itself is stored by reference, not
    copied.
    """

    def __init__(self, configs: dict[str, MazeDatasetConfig]) -> None:
        "initialize with a dict of configs"
        self._configs = configs

    def __getitem__(self, item: str) -> MazeDatasetConfig:
        "return a deep copy of the config stored under `item`; `KeyError` if absent"
        # copy here as well, to stay consistent with `items()` / `values()`:
        # otherwise `wrapper[key].attr = ...` would edit the shared default
        return copy.deepcopy(self._configs[item])

    def __contains__(self, item: object) -> bool:
        "membership test against the wrapped dict, without copying a config"
        # the `Mapping` mixin would call `__getitem__` and pay for a deep copy
        return item in self._configs

    def __len__(self) -> int:
        "number of wrapped configs"
        return len(self._configs)

    def __iter__(self) -> Iterator[str]:
        "iterate over the keys"
        return iter(self._configs)

    # TYPING: error: Return type "list[str]" of "keys" incompatible with return type "KeysView[str]" in supertype "Mapping"  [override]
    def keys(self) -> list[str]:  # type: ignore[override]
        "return the keys"
        return list(self._configs.keys())

    # TYPING: error: Return type "list[tuple[str, MazeDatasetConfig]]" of "items" incompatible with return type "ItemsView[str, MazeDatasetConfig]" in supertype "Mapping"  [override]
    def items(self) -> list[tuple[str, MazeDatasetConfig]]:  # type: ignore[override]
        "return the items, with each config deep-copied"
        return [(k, copy.deepcopy(v)) for k, v in self._configs.items()]

    # TYPING: error: Return type "list[MazeDatasetConfig]" of "values" incompatible with return type "ValuesView[MazeDatasetConfig]" in supertype "Mapping"  [override]
    def values(self) -> list[MazeDatasetConfig]:  # type: ignore[override]
        "return deep copies of the configs"
        return [copy.deepcopy(v) for v in self._configs.values()]
# Public handle on the default configs: a mapping wrapper around
# `_MAZE_DATASET_CONFIGS_SRC`.
MAZE_DATASET_CONFIGS: _MazeDatsetConfigsWrapper = _MazeDatsetConfigsWrapper(
    configs=_MAZE_DATASET_CONFIGS_SRC,
)
def _get_configs_for_examples() -> list[dict]:
    """Generate a comprehensive list of diverse maze configurations.

    The list is built in a fixed order: basic algorithms for each grid size,
    pure and DFS percolation for each grid size and probability, one
    combined-constraints DFS config, the single-constraint DFS configs,
    the deadend-endpoint DFS variations, and finally a DFS-percolation config
    with deadend endpoints.

    Each entry has the keys `name`, `grid_n`, `maze_ctor`, `maze_ctor_kwargs`,
    `description` and `tags`; the deadend-endpoint entries also carry
    `endpoint_kwargs`. Every entry gets its own `maze_ctor_kwargs` dict, so
    mutating one entry's kwargs does not affect any other entry.

    # Returns:
     - `list[dict]`
        List of configuration dictionaries for maze generation
    """
    configs: list[dict] = []

    # Define the grid sizes to test
    grid_sizes: list[int] = [5, 8, 12, 15, 20]

    # Define percolation probabilities
    percolation_probs: list[float] = [0.3, 0.5, 0.7]

    # Core algorithms with basic configurations
    basic_algorithms: dict[str, tuple[Callable, dict]] = {
        "dfs": (LatticeMazeGenerators.gen_dfs, {}),
        "wilson": (LatticeMazeGenerators.gen_wilson, {}),
        "kruskal": (LatticeMazeGenerators.gen_kruskal, {}),
        "recursive_division": (LatticeMazeGenerators.gen_recursive_division, {}),
    }

    # Generate basic configurations for each algorithm and grid size
    for grid_n in grid_sizes:
        for algo_name, (maze_ctor, base_kwargs) in basic_algorithms.items():
            configs.append(
                dict(
                    name="basic",
                    grid_n=grid_n,
                    maze_ctor=maze_ctor,
                    # copy: `base_kwargs` is reused for every grid size, and
                    # the configs must not share one mutable dict
                    maze_ctor_kwargs=dict(base_kwargs),
                    description=f"Basic {algo_name.upper()} maze ({grid_n}x{grid_n})",
                    tags=[f"algo:{algo_name}", "basic", f"grid:{grid_n}"],
                )
            )

    # Generate percolation configurations
    for grid_n in grid_sizes:
        for p in percolation_probs:
            # Pure percolation
            configs.append(
                dict(
                    name=f"p{p}",
                    grid_n=grid_n,
                    maze_ctor=LatticeMazeGenerators.gen_percolation,
                    maze_ctor_kwargs=dict(p=p),
                    description=f"Pure percolation (p={p}) ({grid_n}x{grid_n})",
                    tags=[
                        "algo:percolation",
                        "percolation",
                        f"percolation:{p}",
                        f"grid:{grid_n}",
                    ],
                )
            )

            # DFS with percolation
            configs.append(
                dict(
                    name=f"p{p}",
                    grid_n=grid_n,
                    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
                    maze_ctor_kwargs=dict(p=p),
                    description=f"DFS with percolation (p={p}) ({grid_n}x{grid_n})",
                    tags=[
                        "algo:dfs_percolation",
                        "dfs",
                        "percolation",
                        f"percolation:{p}",
                        f"grid:{grid_n}",
                    ],
                )
            )

    # Generate specialized constraint configurations
    constraint_base_config: dict = dict(
        grid_n=10,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    )
    constraint_base_tags: list[str] = [
        "algo:dfs",
        "dfs",
        "constrained_dfs",
        f"grid:{constraint_base_config['grid_n']}",
    ]

    constraint_configs: list[dict] = [
        # DFS without forks (simple path)
        dict(
            name="forkless",
            maze_ctor_kwargs=dict(do_forks=False),
            description="DFS without forks (10x10)",
            tags=["forkless"],
        ),
        # Accessible cells constraints
        dict(
            name="accessible_cells_count",
            maze_ctor_kwargs=dict(accessible_cells=50),
            description="DFS with limited accessible cells (50)",
            tags=["limited:cells", "limited:absolute"],
        ),
        dict(
            name="accessible_cells_ratio",
            maze_ctor_kwargs=dict(accessible_cells=0.6),
            description="DFS with 60% accessible cells",
            tags=["limited:cells", "limited:ratio"],
        ),
        # Tree depth constraints
        dict(
            name="max_tree_depth_absolute",
            maze_ctor_kwargs=dict(max_tree_depth=10),
            description="DFS with max tree depth of 10",
            tags=["limited:depth", "limited:absolute"],
        ),
        dict(
            name="max_tree_depth_ratio",
            maze_ctor_kwargs=dict(max_tree_depth=0.3),
            description="DFS with max tree depth 30% of grid size",
            tags=["limited:depth", "limited:ratio"],
        ),
        # Start position constraint
        dict(
            name="start_center",
            maze_ctor_kwargs=dict(start_coord=[5, 5]),
            description="DFS starting from center of grid",
            tags=["custom_start"],
        ),
        dict(
            name="start_corner",
            maze_ctor_kwargs=dict(start_coord=[0, 0]),
            description="DFS starting from corner of grid",
            tags=["custom_start"],
        ),
    ]

    # Add combined constraints as special case
    # (kept ahead of the single-constraint configs to preserve output order)
    configs.append(
        dict(
            name="combined_constraints",
            grid_n=15,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
            maze_ctor_kwargs=dict(
                accessible_cells=100,
                max_tree_depth=25,
                start_coord=[7, 7],
            ),
            description="DFS with multiple constraints (100 cells, depth 25, center start)",
            tags=["algo:dfs", "dfs", "constrained_dfs", "grid:15"],
        )
    )

    # Apply the base config to all constraint configs and add to main configs list
    for config in constraint_configs:
        full_config = constraint_base_config.copy()
        full_config.update(config)
        # list concatenation builds a fresh list, so tags are never shared
        full_config["tags"] = constraint_base_tags + config["tags"]
        configs.append(full_config)

    # Generate endpoint options
    endpoint_variations: list[tuple[bool, bool, str]] = [
        (True, False, "deadend start only"),
        (False, True, "deadend end only"),
        (True, True, "deadend start and end"),
    ]

    for deadend_start, deadend_end, desc in endpoint_variations:
        configs.append(
            dict(
                # booleans are rendered as 0/1 in the name
                name=f"deadend_s{int(deadend_start)}_e{int(deadend_end)}",
                grid_n=8,
                maze_ctor=LatticeMazeGenerators.gen_dfs,
                maze_ctor_kwargs={},
                endpoint_kwargs=dict(
                    deadend_start=deadend_start,
                    deadend_end=deadend_end,
                    endpoints_not_equal=True,
                ),
                description=f"DFS with {desc}",
                tags=["algo:dfs", "dfs", "deadend_endpoints", "grid:8"],
            )
        )

    # Add percolation with deadend endpoints
    configs.append(
        dict(
            name="deadends",
            grid_n=8,
            maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
            maze_ctor_kwargs=dict(p=0.3),
            endpoint_kwargs=dict(
                deadend_start=True,
                deadend_end=True,
                endpoints_not_equal=True,
                except_on_no_valid_endpoint=False,
            ),
            description="DFS percolation (p=0.3) with deadend endpoints",
            tags=[
                "algo:dfs_percolation",
                "dfs",
                "percolation",
                "deadend_endpoints",
                "grid:8",
            ],
        )
    )

    return configs