Coverage for maze_dataset/benchmark/config_sweep.py: 0%

171 statements  


1"""Benchmarking of how successful maze generation is for various values of percolation""" 

2 

3import functools 

4import json 

5import warnings 

6from pathlib import Path 

7from typing import Any, Callable, Generic, Literal, Sequence, TypeVar 

8 

9import matplotlib.pyplot as plt 

10import numpy as np 

11from jaxtyping import Float 

12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict 

13from muutils.json_serialize import ( 

14 JSONitem, 

15 SerializableDataclass, 

16 json_serialize, 

17 serializable_dataclass, 

18 serializable_field, 

19) 

20from muutils.parallel import run_maybe_parallel 

21from zanj import ZANJ 

22 

23from maze_dataset import MazeDataset, MazeDatasetConfig 

24from maze_dataset.generation import LatticeMazeGenerators 

25 

26SweepReturnType = TypeVar("SweepReturnType") 

27ParamType = TypeVar("ParamType") 

28AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType] 

29 

30 

31def dataset_success_fraction(cfg: MazeDatasetConfig) -> float: 

32 """empirical success fraction of maze generation 

33 

34 for use as an `analyze_func` in `sweep()` 

35 """ 

36 dataset: MazeDataset = MazeDataset.from_config( 

37 cfg, 

38 do_download=False, 

39 load_local=False, 

40 save_local=False, 

41 verbose=False, 

42 ) 

43 

44 return len(dataset) / cfg.n_mazes 

45 

46 

47ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict( 

48 dataset_success_fraction=dataset_success_fraction, 

49) 
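# Example (illustrative sketch; the config values below are arbitrary):
# `dataset_success_fraction` generates a dataset for a config and reports what
# fraction of the requested mazes were actually produced, e.g.
#
#     cfg = MazeDatasetConfig(
#         name="demo",
#         grid_n=5,
#         n_mazes=32,
#         maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
#         maze_ctor_kwargs=dict(p=0.4),
#     )
#     frac: float = dataset_success_fraction(cfg)  # 1.0 means every maze was generated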

def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """given a base config, a list of parameter values, a parameter key, and an analysis function, return the result of the analysis function for each parameter value

    # Parameters:
    - `cfg_base : MazeDatasetConfig`
        base config on which we will modify the value at `param_key` with values from `param_values`
    - `param_values : list[ParamType]`
        list of values to try
    - `param_key : str`
        dotlist key of the value to modify in `cfg_base`
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function which analyzes the resulting config. originally built for `dataset_success_fraction`

    # Returns:
    - `list[SweepReturnType]`
        output of `analyze_func` for each value in `param_values`, in order
    """
    outputs: list[SweepReturnType] = []

    for p in param_values:
        # update the config
        cfg_dict: dict = cfg_base.serialize()
        update_with_nested_dict(
            cfg_dict,
            dotlist_to_nested_dict({param_key: p}),
        )
        cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)

        outputs.append(analyze_func(cfg_test))

    return outputs
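# Example (illustrative sketch; parameter values are arbitrary): given a base
# `MazeDatasetConfig` named `cfg`, sweep the percolation probability using the
# dotlist key format that `sweep` expects:
#
#     fractions: list[float] = sweep(
#         cfg_base=cfg,
#         param_values=[0.0, 0.25, 0.5, 0.75, 1.0],
#         param_key="maze_ctor_kwargs.p",
#         analyze_func=dataset_success_fraction,
#     )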

@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep"""

    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    param_key: str
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )

    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs keyed by name"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs keyed by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}

    def configs_shared(self) -> dict[str, Any]:
        "return key: value pairs that are shared across all configs"
        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[Any]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                if k not in config_vals:
                    config_vals[k] = set()
                config_vals[k].add(json.dumps(v))

        shared_vals: dict[str, Any] = dict()

        cfg_ser: dict = self.configs[0].serialize()
        for k, v in config_vals.items():
            if len(v) == 1:
                shared_vals[k] = cfg_ser[k]

        return shared_vals

    def configs_differing_keys(self) -> set[str]:
        "return keys that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        differing_keys: set[str] = set()

        for k in MazeDatasetConfig.__dataclass_fields__:
            if k not in shared_vals:
                differing_keys.add(k)

        return differing_keys

    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key"
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())

    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        "get a subset of this `SweepResult` containing only the configs whose `key` satisfies `val_check`"
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
        result_values: dict[str, Sequence[SweepReturnType]] = {
            k: self.result_values[k] for k in configs_keys
        }

        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )
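    # Example (illustrative sketch; the filename and the predicate are arbitrary):
    # a `SweepResult` can be round-tripped through zanj and filtered by config
    # attributes, e.g.
    #
    #     result.save("percolation_sweep.zanj")
    #     loaded = SweepResult.read("percolation_sweep.zanj")
    #     small_grids = loaded.get_where("grid_n", lambda n: n <= 8)
    #     print(loaded.configs_shared(), loaded.configs_differing_keys())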

    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """Analyze success rate of maze generation for different percolation values

        # Parameters:
        - `configs : list[MazeDatasetConfig]`
            configs to try
        - `param_values : list[ParamType]`
            parameter values to try

        # Returns:
        - `SweepResult`
        """
        n_pvals: int = len(param_values)

        result_values_list: list[float] = run_maybe_parallel(
            # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type]
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type]
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )

    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
        minify_title: bool = False,
        legend_kwargs: dict[str, Any] | None = None,
    ) -> plt.Axes:
        """Plot the results of percolation analysis"""
        # set up figure
        if not ax:
            fig: plt.Figure
            ax_: plt.Axes
            fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # plot
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        for i, (ep_cfg_name, result_values) in enumerate(
            sorted(
                self.result_values.items(),
                # HACK: sort by grid size, parsed from the config name:
                # `x[0]` is the name of the config, `.split("-")[0]` gets 'g{n}', `[1:]` gets '{n}'
                key=lambda x: int(x[0].split("-")[0][1:]),
            ),
        ):
            ax_.plot(
                # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                self.param_values,  # type: ignore[arg-type]
                # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=self.configs_by_key()[ep_cfg_name].name,
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config
        cfg_shared: dict = self.configs_shared()
        if minify_title:
            cfg_shared["endpoint_kwargs"] = {
                k: v
                for k, v in cfg_shared["endpoint_kwargs"].items()
                if k != "except_on_no_valid_endpoint"
            }
        cfg_repr: str = (
            str(cfg_shared)
            if cfg_keys is None
            else (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        # TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo" [arg-type]
                        if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )
        )

        # add title and stuff
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            # ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1))
            legend_kwargs = {
                **dict(loc="center left"),
                **(legend_kwargs or dict()),
            }
            ax_.legend(**legend_kwargs)

        # save and show
        if save_path:
            plt.savefig(save_path)

        if show:
            plt.show()

        return ax_
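# Example (illustrative sketch; `my_configs` is a hypothetical list of
# `MazeDatasetConfig` and the values are arbitrary): run the sweep over several
# configs at once, then plot the resulting success-fraction curves:
#
#     result: SweepResult = SweepResult.analyze(
#         configs=my_configs,
#         param_values=np.linspace(0.0, 1.0, 11).tolist(),
#         param_key="maze_ctor_kwargs.p",
#         analyze_func=dataset_success_fraction,
#         parallel=False,
#     )
#     result.plot(save_path="percolation_sweep.svg", show=False)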

DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
    (
        "any",
        dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False),
    ),
    (
        "deadends",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=False,
            except_on_no_valid_endpoint=False,
        ),
    ),
    (
        "deadends_unique",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=True,
            except_on_no_valid_endpoint=False,
        ),
    ),
]


def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """convert endpoint kwargs options to a human-readable name"""
    if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False):
        if ep_kwargs.get("endpoints_not_equal", False):
            return "deadends_unique"
        else:
            return "deadends"
    else:
        return "any"
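# Example (illustrative sketch): the names in `DEFAULT_ENDPOINT_KWARGS` are
# consistent with what `endpoint_kwargs_to_name` produces for each kwargs dict:
#
#     for name, ep_kw in DEFAULT_ENDPOINT_KWARGS:
#         assert endpoint_kwargs_to_name(ep_kw) == name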

def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    "run the full analysis of how percolation affects maze generation success"
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # configs
    configs: list[MazeDatasetConfig] = list()

    # TODO: B007 noqa'd because we don't use `ep_kw_name` or `gf_idx`
    for ep_kw_name, ep_kw in ep_kwargs:  # noqa: B007
        for gf_idx, gen_func in enumerate(generators):  # noqa: B007
            configs.extend(
                [
                    MazeDatasetConfig(
                        name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
                        grid_n=grid_n,
                        n_mazes=n_mazes,
                        maze_ctor=gen_func,
                        maze_ctor_kwargs=dict(p=float("nan")),
                        endpoint_kwargs=ep_kw,
                    )
                    for grid_n in grid_sizes
                ],
            )

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type]
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    result.save(results_path)

    return result
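# Example (illustrative sketch; counts, grid sizes, and the save directory are
# arbitrary): run the full benchmark and save a `.zanj` result file:
#
#     result: SweepResult = full_percolation_analysis(
#         n_mazes=64,
#         p_val_count=11,
#         grid_sizes=[4, 8, 16],
#         save_dir=Path("./percolation_fractions"),
#         parallel=False,
#     )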

def _is_eq(a, b) -> bool:  # noqa: ANN001
    """check if two objects are equal"""
    return a == b


def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
    save_fmt: str = "svg",
    figsize: tuple[int, int] = (22, 10),
    minify_title: bool = False,
    legend_kwargs: dict[str, Any] | None = None,
    manual_titles: dict[Literal["x", "y", "title"], str] | None = None,
) -> None:
    """Plot grouped sweep results over percolation values, one figure per distinct `endpoint_kwargs` in the configs

    with separate colormaps for each maze generator function

    # Parameters:
    - `results : SweepResult`
        The sweep results to plot
    - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, predictions are plotted as dashed lines.
    - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
    - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving)
    - `show : bool`
        Whether to display the plots (defaults to `True`)

    # Usage:
    ```python
    >>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # groups
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    generator_funcs_names: list[str] = list(
        {cfg.maze_ctor.__name__ for cfg in results.configs},
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
            # lambda x: x == ep_kw,
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # HACK: big hassle to do this without a lambda, is it really that bad?
                lambda x: x.__name__ == gen_func,  # noqa: B023
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            cmap_name = "Reds" if gf_idx == 0 else "Blues"
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                cfg_keys=list(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
                minify_title=minify_title,
                legend_kwargs=legend_kwargs,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                for cfg_idx, cfg in enumerate(results_filtered.configs):
                    predictions = []
                    for p in p_dense:
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # Get the same color as the actual data
                    n_cfgs: int = len(results_filtered.configs)
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        if manual_titles:
            ax.set_xlabel(manual_titles["x"])
            ax.set_ylabel(manual_titles["y"])
            ax.set_title(manual_titles["title"])

        # save and show
        if save_dir:
            save_path: Path = (
                save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}"
            )
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.savefig(save_path)

        if show:
            plt.show()
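# Example (illustrative sketch; `my_predict_fn` is a hypothetical
# `Callable[[MazeDatasetConfig], float]`, not part of this module): group the
# sweep results by endpoint kwargs and overlay predicted success rates as
# dashed lines:
#
#     plot_grouped(
#         result,
#         predict_fn=my_predict_fn,
#         save_dir=Path("./plots"),
#         show=False,
#     )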