Coverage for maze_dataset/dataset/configs.py: 47%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 14:42 -0600

1"`MAZE_DATASET_CONFIGS` contains some default configs for tests and demos" 

2 

3import copy 

4from typing import Callable, Iterator, Mapping 

5 

6from maze_dataset.dataset.maze_dataset import MazeDatasetConfig 

7from maze_dataset.generation.generators import LatticeMazeGenerators 

8 

9_MAZE_DATASET_CONFIGS_SRC: dict[str, MazeDatasetConfig] = { 

10 cfg.to_fname(): cfg 

11 for cfg in [ 

12 MazeDatasetConfig( 

13 name="test", 

14 grid_n=3, 

15 n_mazes=5, 

16 maze_ctor=LatticeMazeGenerators.gen_dfs, 

17 ), 

18 MazeDatasetConfig( 

19 name="test-perc", 

20 grid_n=3, 

21 n_mazes=5, 

22 maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, 

23 maze_ctor_kwargs={"p": 0.7}, 

24 ), 

25 MazeDatasetConfig( 

26 name="demo_small", 

27 grid_n=3, 

28 n_mazes=100, 

29 maze_ctor=LatticeMazeGenerators.gen_dfs, 

30 ), 

31 MazeDatasetConfig( 

32 name="demo", 

33 grid_n=6, 

34 n_mazes=10000, 

35 maze_ctor=LatticeMazeGenerators.gen_dfs, 

36 ), 

37 ] 

38} 

39 

40 

41class _MazeDatsetConfigsWrapper(Mapping[str, MazeDatasetConfig]): 

42 "wrap the default configs in a read-only dict-like object" 

43 

44 def __init__(self, configs: dict[str, MazeDatasetConfig]) -> None: 

45 "initialize with a dict of configs" 

46 self._configs = configs 

47 

48 def __getitem__(self, item: str) -> MazeDatasetConfig: 

49 return self._configs[item] 

50 

51 def __len__(self) -> int: 

52 return len(self._configs) 

53 

54 def __iter__(self) -> Iterator: 

55 "iterate over the keys" 

56 return iter(self._configs) 

57 

58 # TYPING: error: Return type "list[str]" of "keys" incompatible with return type "KeysView[str]" in supertype "Mapping" [override] 

59 def keys(self) -> list[str]: # type: ignore[override] 

60 "return the keys" 

61 return list(self._configs.keys()) 

62 

63 # TYPING: error: Return type "list[tuple[str, MazeDatasetConfig]]" of "items" incompatible with return type "ItemsView[str, MazeDatasetConfig]" in supertype "Mapping" [override] 

64 def items(self) -> list[tuple[str, MazeDatasetConfig]]: # type: ignore[override] 

65 "return the items" 

66 return [(k, copy.deepcopy(v)) for k, v in self._configs.items()] 

67 

68 # TYPING: error: Return type "list[MazeDatasetConfig]" of "values" incompatible with return type "ValuesView[MazeDatasetConfig]" in supertype "Mapping" [override] 

69 def values(self) -> list[MazeDatasetConfig]: # type: ignore[override] 

70 return [copy.deepcopy(v) for v in self._configs.values()] 

71 

72 

73MAZE_DATASET_CONFIGS: _MazeDatsetConfigsWrapper = _MazeDatsetConfigsWrapper( 

74 _MAZE_DATASET_CONFIGS_SRC, 

75) 

76 

77 

78def _get_configs_for_examples() -> list[dict]: 

79 """Generate a comprehensive list of diverse maze configurations. 

80 

81 # Returns: 

82 - `list[dict]` 

83 List of configuration dictionaries for maze generation 

84 """ 

85 configs: list[dict] = [] 

86 

87 # Define the grid sizes to test 

88 grid_sizes: list[int] = [5, 8, 12, 15, 20] 

89 

90 # Define percolation probabilities 

91 percolation_probs: list[float] = [0.3, 0.5, 0.7] 

92 

93 # Core algorithms with basic configurations 

94 basic_algorithms: dict[str, tuple[Callable, dict]] = { 

95 "dfs": (LatticeMazeGenerators.gen_dfs, {}), 

96 "wilson": (LatticeMazeGenerators.gen_wilson, {}), 

97 "kruskal": (LatticeMazeGenerators.gen_kruskal, {}), 

98 "recursive_division": (LatticeMazeGenerators.gen_recursive_division, {}), 

99 } 

100 

101 # Generate basic configurations for each algorithm and grid size 

102 for grid_n in grid_sizes: 

103 for algo_name, (maze_ctor, base_kwargs) in basic_algorithms.items(): 

104 configs.append( 

105 dict( 

106 name="basic", 

107 grid_n=grid_n, 

108 maze_ctor=maze_ctor, 

109 maze_ctor_kwargs=base_kwargs, 

110 description=f"Basic {algo_name.upper()} maze ({grid_n}x{grid_n})", 

111 tags=[f"algo:{algo_name}", "basic", f"grid:{grid_n}"], 

112 ) 

113 ) 

114 

115 # Generate percolation configurations 

116 for grid_n in grid_sizes: 

117 for p in percolation_probs: 

118 # Pure percolation 

119 configs.append( 

120 dict( 

121 name=f"p{p}", 

122 grid_n=grid_n, 

123 maze_ctor=LatticeMazeGenerators.gen_percolation, 

124 maze_ctor_kwargs=dict(p=p), 

125 description=f"Pure percolation (p={p}) ({grid_n}x{grid_n})", 

126 tags=[ 

127 "algo:percolation", 

128 "percolation", 

129 f"percolation:{p}", 

130 f"grid:{grid_n}", 

131 ], 

132 ) 

133 ) 

134 

135 # DFS with percolation 

136 configs.append( 

137 dict( 

138 name=f"p{p}", 

139 grid_n=grid_n, 

140 maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, 

141 maze_ctor_kwargs=dict(p=p), 

142 description=f"DFS with percolation (p={p}) ({grid_n}x{grid_n})", 

143 tags=[ 

144 "algo:dfs_percolation", 

145 "dfs", 

146 "percolation", 

147 f"percolation:{p}", 

148 f"grid:{grid_n}", 

149 ], 

150 ) 

151 ) 

152 

153 # Generate specialized constraint configurations 

154 constraint_base_config: dict = dict( 

155 grid_n=10, 

156 maze_ctor=LatticeMazeGenerators.gen_dfs, 

157 ) 

158 constraint_base_tags: list[str] = [ 

159 "algo:dfs", 

160 "dfs", 

161 "constrained_dfs", 

162 f"grid:{constraint_base_config['grid_n']}", 

163 ] 

164 

165 constraint_configs: list[dict] = [ 

166 # DFS without forks (simple path) 

167 dict( 

168 name="forkless", 

169 maze_ctor_kwargs=dict(do_forks=False), 

170 description="DFS without forks (10x10)", 

171 tags=["forkless"], 

172 ), 

173 # Accessible cells constraints 

174 dict( 

175 name="accessible_cells_count", 

176 maze_ctor_kwargs=dict(accessible_cells=50), 

177 description="DFS with limited accessible cells (50)", 

178 tags=["limited:cells", "limited:absolute"], 

179 ), 

180 dict( 

181 name="accessible_cells_ratio", 

182 maze_ctor_kwargs=dict(accessible_cells=0.6), 

183 description="DFS with 60% accessible cells", 

184 tags=["limited:cells", "limited:ratio"], 

185 ), 

186 # Tree depth constraints 

187 dict( 

188 name="max_tree_depth_absolute", 

189 maze_ctor_kwargs=dict(max_tree_depth=10), 

190 description="DFS with max tree depth of 10", 

191 tags=["limited:depth", "limited:absolute"], 

192 ), 

193 dict( 

194 name="max_tree_depth_ratio", 

195 maze_ctor_kwargs=dict(max_tree_depth=0.3), 

196 description="DFS with max tree depth 30% of grid size", 

197 tags=["limited:depth", "limited:ratio"], 

198 ), 

199 # Start position constraint 

200 dict( 

201 name="start_center", 

202 maze_ctor_kwargs=dict(start_coord=[5, 5]), 

203 description="DFS starting from center of grid", 

204 tags=["custom_start"], 

205 ), 

206 dict( 

207 name="start_corner", 

208 maze_ctor_kwargs=dict(start_coord=[0, 0]), 

209 description="DFS starting from corner of grid", 

210 tags=["custom_start"], 

211 ), 

212 ] 

213 

214 # Add combined constraints as special case 

215 configs.append( 

216 dict( 

217 name="combined_constraints", 

218 grid_n=15, 

219 maze_ctor=LatticeMazeGenerators.gen_dfs, 

220 maze_ctor_kwargs=dict( 

221 accessible_cells=100, 

222 max_tree_depth=25, 

223 start_coord=[7, 7], 

224 ), 

225 description="DFS with multiple constraints (100 cells, depth 25, center start)", 

226 tags=["algo:dfs", "dfs", "constrained_dfs", "grid:15"], 

227 ) 

228 ) 

229 

230 # Apply the base config to all constraint configs and add to main configs list 

231 for config in constraint_configs: 

232 full_config = constraint_base_config.copy() 

233 full_config.update(config) 

234 full_config["tags"] = constraint_base_tags + config["tags"] 

235 configs.append(full_config) 

236 

237 # Generate endpoint options 

238 endpoint_variations: list[tuple[bool, bool, str]] = [ 

239 (True, False, "deadend start only"), 

240 (False, True, "deadend end only"), 

241 (True, True, "deadend start and end"), 

242 ] 

243 

244 for deadend_start, deadend_end, desc in endpoint_variations: 

245 configs.append( 

246 dict( 

247 name=f"deadend_s{int(deadend_start)}_e{int(deadend_end)}", 

248 grid_n=8, 

249 maze_ctor=LatticeMazeGenerators.gen_dfs, 

250 maze_ctor_kwargs={}, 

251 endpoint_kwargs=dict( 

252 deadend_start=deadend_start, 

253 deadend_end=deadend_end, 

254 endpoints_not_equal=True, 

255 ), 

256 description=f"DFS with {desc}", 

257 tags=["algo:dfs", "dfs", "deadend_endpoints", "grid:8"], 

258 ) 

259 ) 

260 

261 # Add percolation with deadend endpoints 

262 configs.append( 

263 dict( 

264 name="deadends", 

265 grid_n=8, 

266 maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, 

267 maze_ctor_kwargs=dict(p=0.3), 

268 endpoint_kwargs=dict( 

269 deadend_start=True, 

270 deadend_end=True, 

271 endpoints_not_equal=True, 

272 except_on_no_valid_endpoint=False, 

273 ), 

274 description="DFS percolation (p=0.3) with deadend endpoints", 

275 tags=[ 

276 "algo:dfs_percolation", 

277 "dfs", 

278 "percolation", 

279 "deadend_endpoints", 

280 "grid:8", 

281 ], 

282 ) 

283 ) 

284 

285 return configs