Coverage for maze_dataset/benchmark/config_sweep.py: 0%
171 statements
coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"""Benchmarking of how successful maze generation is for various values of percolation"""
3import functools
4import json
5import warnings
6from pathlib import Path
7from typing import Any, Callable, Generic, Literal, Sequence, TypeVar
9import matplotlib.pyplot as plt
10import numpy as np
11from jaxtyping import Float
12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict
13from muutils.json_serialize import (
14 JSONitem,
15 SerializableDataclass,
16 json_serialize,
17 serializable_dataclass,
18 serializable_field,
19)
20from muutils.parallel import run_maybe_parallel
21from zanj import ZANJ
23from maze_dataset import MazeDataset, MazeDatasetConfig
24from maze_dataset.generation import LatticeMazeGenerators
26SweepReturnType = TypeVar("SweepReturnType")
27ParamType = TypeVar("ParamType")
28AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]
def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
    """empirical success fraction of maze generation

    for use as an `analyze_func` in `sweep()`
    """
    dataset: MazeDataset = MazeDataset.from_config(
        cfg,
        do_download=False,
        load_local=False,
        save_local=False,
        verbose=False,
    )

    return len(dataset) / cfg.n_mazes


ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict(
    dataset_success_fraction=dataset_success_fraction,
)
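
# Example sketch: any callable taking a `MazeDatasetConfig` can serve as an
# `analyze_func`, and registering it in `ANALYSIS_FUNCS` lets `SweepResult`
# recover it by name when deserializing. `mean_solution_length` below is a
# hypothetical metric, assuming `dataset.mazes` yields `SolvedMaze` objects
# whose `solution` attribute holds the path coordinates.
def _example_register_analysis_func() -> None:
    def mean_solution_length(cfg: MazeDatasetConfig) -> float:
        "mean solution length of the mazes generated from `cfg` (illustrative)"
        dataset: MazeDataset = MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False, verbose=False
        )
        mazes = dataset.mazes
        if len(mazes) == 0:
            return float("nan")
        return sum(len(maze.solution) for maze in mazes) / len(mazes)

    ANALYSIS_FUNCS["mean_solution_length"] = mean_solution_length
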
def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
    """given a base config, a list of parameter values, a key, and an analysis function, return the result of the analysis function for each parameter value

    # Parameters:
    - `cfg_base : MazeDatasetConfig`
        base config; the value at `param_key` is replaced with each entry of `param_values`
    - `param_values : list[ParamType]`
        list of values to try
    - `param_key : str`
        dotlist key of the value to modify in `cfg_base` (e.g. `"maze_ctor_kwargs.p"`)
    - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
        function which analyzes the resulting config. originally built for `dataset_success_fraction`

    # Returns:
    - `list[SweepReturnType]`
        output of `analyze_func` for each modified config, in the same order as `param_values`
    """
    outputs: list[SweepReturnType] = []

    for p in param_values:
        # update the config
        cfg_dict: dict = cfg_base.serialize()
        update_with_nested_dict(
            cfg_dict,
            dotlist_to_nested_dict({param_key: p}),
        )
        cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)

        outputs.append(analyze_func(cfg_test))

    return outputs
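
# Example sketch: sweeping the percolation probability `p` of a single small
# config. The grid size, maze count, and p values are arbitrary examples.
def _example_sweep_usage() -> list[float]:
    cfg_base: MazeDatasetConfig = MazeDatasetConfig(
        name="demo",
        grid_n=5,
        n_mazes=32,
        maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
        maze_ctor_kwargs=dict(p=0.0),
    )
    return sweep(
        cfg_base=cfg_base,
        param_values=[0.0, 0.25, 0.5, 0.75, 1.0],
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
    )
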
@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
    """result of a parameter sweep"""

    configs: list[MazeDatasetConfig] = serializable_field(
        serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
        deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
    )
    param_values: list[ParamType] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
        serialization_fn=lambda x: json_serialize(x),
        deserialize_fn=lambda x: x,
        assert_type=False,
    )
    param_key: str
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
        serialization_fn=lambda f: f.__name__,
        deserialize_fn=ANALYSIS_FUNCS.get,
        assert_type=False,
    )
    def summary(self) -> JSONitem:
        "human-readable and json-dumpable short summary of the result"
        return {
            "len(configs)": len(self.configs),
            "len(param_values)": len(self.param_values),
            "len(result_values)": len(self.result_values),
            "param_key": self.param_key,
            "analyze_func": self.analyze_func.__name__,
        }

    def save(self, path: str | Path, z: ZANJ | None = None) -> None:
        "save to a file with zanj"
        if z is None:
            z = ZANJ()

        z.save(self, path)

    @classmethod
    def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
        "read from a file with zanj"
        if z is None:
            z = ZANJ()

        return z.read(path)

    def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
        "return configs by name"
        return {cfg.name: cfg for cfg in self.configs}

    def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
        "return configs by the key used in `result_values`, which is the filename of the config"
        return {cfg.to_fname(): cfg for cfg in self.configs}
    def configs_shared(self) -> dict[str, Any]:
        "return key: value pairs that are shared across all configs"
        # we know that the configs all have the same keys,
        # so this way of doing it is fine
        config_vals: dict[str, set[Any]] = dict()
        for cfg in self.configs:
            for k, v in cfg.serialize().items():
                if k not in config_vals:
                    config_vals[k] = set()
                config_vals[k].add(json.dumps(v))

        shared_vals: dict[str, Any] = dict()

        cfg_ser: dict = self.configs[0].serialize()
        for k, v in config_vals.items():
            if len(v) == 1:
                shared_vals[k] = cfg_ser[k]

        return shared_vals

    def configs_differing_keys(self) -> set[str]:
        "return keys that differ across configs"
        shared_vals: dict[str, Any] = self.configs_shared()
        differing_keys: set[str] = set()

        for k in MazeDatasetConfig.__dataclass_fields__:
            if k not in shared_vals:
                differing_keys.add(k)

        return differing_keys
    def configs_value_set(self, key: str) -> list[Any]:
        "return a list of the unique values for a given key"
        d: dict[str, Any] = {
            json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
            for cfg in self.configs
        }

        return list(d.values())
    def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
        "get a subset of this `SweepResult` containing only the configs whose value at `key` satisfies `val_check`"
        configs_list: list[MazeDatasetConfig] = [
            cfg for cfg in self.configs if val_check(getattr(cfg, key))
        ]
        configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
        result_values: dict[str, Sequence[SweepReturnType]] = {
            k: self.result_values[k] for k in configs_keys
        }

        return SweepResult(
            configs=configs_list,
            param_values=self.param_values,
            result_values=result_values,
            param_key=self.param_key,
            analyze_func=self.analyze_func,
        )
    @classmethod
    def analyze(
        cls,
        configs: list[MazeDatasetConfig],
        param_values: list[ParamType],
        param_key: str,
        analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
        parallel: bool | int = False,
        **kwargs,
    ) -> "SweepResult":
        """Analyze success rate of maze generation for different percolation values

        # Parameters:
        - `configs : list[MazeDatasetConfig]`
            configs to try
        - `param_values : list[ParamType]`
            values to try for `param_key`
        - `param_key : str`
            dotlist key of the value to modify in each config
        - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
            function which analyzes each resulting config
        - `parallel : bool | int`
            passed to `run_maybe_parallel`; whether to parallelize over configs (defaults to `False`)

        # Returns:
        - `SweepResult`
        """
        n_pvals: int = len(param_values)

        result_values_list: list[float] = run_maybe_parallel(
            # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type]
            func=functools.partial(  # type: ignore[arg-type]
                sweep,
                param_values=param_values,
                param_key=param_key,
                analyze_func=analyze_func,
            ),
            iterable=configs,
            keep_ordered=True,
            parallel=parallel,
            pbar_kwargs=dict(total=len(configs)),
            **kwargs,
        )
        result_values: dict[str, Float[np.ndarray, n_pvals]] = {
            cfg.to_fname(): np.array(res)
            for cfg, res in zip(configs, result_values_list, strict=False)
        }
        return cls(
            configs=configs,
            param_values=param_values,
            # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type]
            result_values=result_values,  # type: ignore[arg-type]
            param_key=param_key,
            analyze_func=analyze_func,
        )
    def plot(
        self,
        save_path: str | None = None,
        cfg_keys: list[str] | None = None,
        cmap_name: str | None = "viridis",
        plot_only: bool = False,
        show: bool = True,
        ax: plt.Axes | None = None,
        minify_title: bool = False,
        legend_kwargs: dict[str, Any] | None = None,
    ) -> plt.Axes:
        """Plot the results of percolation analysis"""
        # set up figure
        if not ax:
            fig: plt.Figure
            ax_: plt.Axes
            fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
        else:
            ax_ = ax

        # plot
        cmap = plt.get_cmap(cmap_name)
        n_cfgs: int = len(self.result_values)
        for i, (ep_cfg_name, result_values) in enumerate(
            sorted(
                self.result_values.items(),
                # HACK: sort by grid size. config fnames look like "g{n}-...",
                # so x[0].split("-")[0] gives "g{n}" and [1:] strips the "g"
                key=lambda x: int(x[0].split("-")[0][1:]),
            ),
        ):
            ax_.plot(
                # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                self.param_values,  # type: ignore[arg-type]
                # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type]
                result_values,  # type: ignore[arg-type]
                ".-",
                label=self.configs_by_key()[ep_cfg_name].name,
                color=cmap((i + 0.5) / (n_cfgs - 0.5)),
            )

        # repr of config
        cfg_shared: dict = self.configs_shared()
        if minify_title:
            cfg_shared["endpoint_kwargs"] = {
                k: v
                for k, v in cfg_shared["endpoint_kwargs"].items()
                if k != "except_on_no_valid_endpoint"
            }
        cfg_repr: str = (
            str(cfg_shared)
            if cfg_keys is None
            else (
                "MazeDatasetConfig("
                + ", ".join(
                    [
                        f"{k}={cfg_shared[k].__name__}"
                        # TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo" [arg-type]
                        if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
                        else f"{k}={cfg_shared[k]}"
                        for k in cfg_keys
                    ],
                )
                + ")"
            )
        )

        # add title and stuff
        if not plot_only:
            ax_.set_xlabel(self.param_key)
            ax_.set_ylabel(self.analyze_func.__name__)
            ax_.set_title(
                f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
            )
            ax_.grid(True)
            # ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1))
            legend_kwargs = {
                **dict(loc="center left"),
                **(legend_kwargs or dict()),
            }
            ax_.legend(**legend_kwargs)

        # save and show
        if save_path:
            plt.savefig(save_path)

        if show:
            plt.show()

        return ax_
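
# Example sketch: building configs, running `SweepResult.analyze`, filtering
# with `get_where`, and plotting. Grid sizes, maze count, and the output
# filename are arbitrary; the names follow the "g{n}-..." pattern that
# `SweepResult.plot` sorts by.
def _example_sweep_result_usage() -> None:
    configs: list[MazeDatasetConfig] = [
        MazeDatasetConfig(
            name=f"g{grid_n}-dfs_perc",
            grid_n=grid_n,
            n_mazes=16,
            maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
            maze_ctor_kwargs=dict(p=0.0),
        )
        for grid_n in (4, 6, 8)
    ]
    result: SweepResult = SweepResult.analyze(
        configs=configs,
        param_values=np.linspace(0.0, 1.0, 5).tolist(),
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=False,
    )
    # keep only the smaller grids, then plot without displaying
    result_small: SweepResult = result.get_where("grid_n", lambda g: g <= 6)
    result_small.plot(save_path="percolation_sweep.svg", show=False)
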
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
    (
        "any",
        dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False),
    ),
    (
        "deadends",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=False,
            except_on_no_valid_endpoint=False,
        ),
    ),
    (
        "deadends_unique",
        dict(
            deadend_start=True,
            deadend_end=True,
            endpoints_not_equal=True,
            except_on_no_valid_endpoint=False,
        ),
    ),
]
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
    """convert endpoint kwargs options to a human-readable name"""
    if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False):
        if ep_kwargs.get("endpoints_not_equal", False):
            return "deadends_unique"
        else:
            return "deadends"
    else:
        return "any"
def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult:
    "run the full analysis of how percolation affects maze generation success"
    if ep_kwargs is None:
        ep_kwargs = DEFAULT_ENDPOINT_KWARGS

    # configs
    configs: list[MazeDatasetConfig] = list()

    # TODO: B007 noqa'd because we don't use `ep_kw_name` or `gf_idx`
    for ep_kw_name, ep_kw in ep_kwargs:  # noqa: B007
        for gf_idx, gen_func in enumerate(generators):  # noqa: B007
            configs.extend(
                [
                    MazeDatasetConfig(
                        name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
                        grid_n=grid_n,
                        n_mazes=n_mazes,
                        maze_ctor=gen_func,
                        maze_ctor_kwargs=dict(p=float("nan")),
                        endpoint_kwargs=ep_kw,
                    )
                    for grid_n in grid_sizes
                ],
            )

    # get results
    result: SweepResult = SweepResult.analyze(
        configs=configs,  # type: ignore[misc]
        # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type]
        param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
        param_key="maze_ctor_kwargs.p",
        analyze_func=dataset_success_fraction,
        parallel=parallel,
        **analyze_kwargs,
    )

    # save the result
    results_path: Path = (
        save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
    )
    print(f"Saving results to {results_path.as_posix()}")
    result.save(results_path)

    return result
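
# Example sketch: a small end-to-end run of the percolation analysis. The
# maze count, p-value count, grid sizes, and output directory are arbitrary.
def _example_full_percolation_analysis() -> SweepResult:
    return full_percolation_analysis(
        n_mazes=64,
        p_val_count=11,
        grid_sizes=[4, 8],
        save_dir=Path("./percolation_fractions"),
        parallel=False,
    )
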
def _is_eq(a, b) -> bool:  # noqa: ANN001
    """check if two objects are equal"""
    return a == b
def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
    save_fmt: str = "svg",
    figsize: tuple[int, int] = (22, 10),
    minify_title: bool = False,
    legend_kwargs: dict[str, Any] | None = None,
    manual_titles: dict[Literal["x", "y", "title"], str] | None = None,
) -> None:
    """Plot grouped sweep results, one figure per distinct `endpoint_kwargs` in the configs

    with a separate colormap for each maze generator function

    # Parameters:
    - `results : SweepResult`
        The sweep results to plot
    - `predict_fn : Callable[[MazeDatasetConfig], float] | None`
        Optional function that predicts success rate from a config. If provided, predictions are plotted as dashed lines.
    - `prediction_density : int`
        Number of points to use for prediction curves (default: 50)
    - `save_dir : Path | None`
        Directory to save plots (defaults to `None`, meaning no saving)
    - `show : bool`
        Whether to display the plots (defaults to `True`)

    # Usage:
    ```python
    >>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8, 16])
    >>> plot_grouped(result, save_dir=Path("./plots"), show=False)
    ```
    """
    # groups
    endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
    generator_funcs_names: list[str] = list(
        {cfg.maze_ctor.__name__ for cfg in results.configs},
    )

    # if predicting, create denser p values
    if predict_fn is not None:
        p_dense = np.linspace(0.0, 1.0, prediction_density)

    # separate plot for each set of endpoint kwargs
    for ep_kw in endpoint_kwargs_set:
        results_epkw: SweepResult = results.get_where(
            "endpoint_kwargs",
            functools.partial(_is_eq, b=ep_kw),
            # lambda x: x == ep_kw,
        )
        shared_keys: set[str] = set(results_epkw.configs_shared().keys())
        cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
        fig, ax = plt.subplots(1, 1, figsize=figsize)
        for gf_idx, gen_func in enumerate(generator_funcs_names):
            results_filtered: SweepResult = results_epkw.get_where(
                "maze_ctor",
                # HACK: big hassle to do this without a lambda, is it really that bad?
                lambda x: x.__name__ == gen_func,  # noqa: B023
            )
            if len(results_filtered.configs) < 1:
                warnings.warn(
                    f"No results for {gen_func} and {ep_kw}. Skipping.",
                )
                continue

            cmap_name = "Reds" if gf_idx == 0 else "Blues"
            cmap = plt.get_cmap(cmap_name)

            # Plot actual results
            ax = results_filtered.plot(
                cfg_keys=list(cfg_keys),
                ax=ax,
                show=False,
                cmap_name=cmap_name,
                minify_title=minify_title,
                legend_kwargs=legend_kwargs,
            )
            if logy:
                ax.set_yscale("log")

            # Plot predictions if function provided
            if predict_fn is not None:
                for cfg_idx, cfg in enumerate(results_filtered.configs):
                    predictions = []
                    for p in p_dense:
                        cfg_temp = MazeDatasetConfig.load(cfg.serialize())
                        cfg_temp.maze_ctor_kwargs["p"] = p
                        predictions.append(predict_fn(cfg_temp))

                    # Get the same color as the actual data
                    n_cfgs: int = len(results_filtered.configs)
                    color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

                    # Plot prediction as dashed line
                    ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

        if manual_titles:
            ax.set_xlabel(manual_titles["x"])
            ax.set_ylabel(manual_titles["y"])
            ax.set_title(manual_titles["title"])

        # save and show
        if save_dir:
            save_path: Path = (
                save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}"
            )
            print(f"Saving plot to {save_path.as_posix()}")
            save_path.parent.mkdir(exist_ok=True, parents=True)
            plt.savefig(save_path)

        if show:
            plt.show()
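
# Example sketch: plotting grouped results with a toy prediction function.
# `toy_predict_success` is hypothetical; any `Callable[[MazeDatasetConfig], float]`
# can be passed as `predict_fn`.
def _example_plot_grouped(results: SweepResult) -> None:
    def toy_predict_success(cfg: MazeDatasetConfig) -> float:
        "toy prediction: success decays linearly with percolation probability p"
        return max(0.0, 1.0 - float(cfg.maze_ctor_kwargs.get("p", 0.0)))

    plot_grouped(
        results,
        predict_fn=toy_predict_success,
        prediction_density=50,
        save_dir=Path("./plots"),
        show=False,
    )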