maze_dataset.benchmark.config_sweep
Benchmarking of how successful maze generation is for various values of percolation
1"""Benchmarking of how successful maze generation is for various values of percolation""" 2 3import functools 4import json 5import warnings 6from pathlib import Path 7from typing import Any, Callable, Generic, Literal, Sequence, TypeVar 8 9import matplotlib.pyplot as plt 10import numpy as np 11from jaxtyping import Float 12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict 13from muutils.json_serialize import ( 14 JSONitem, 15 SerializableDataclass, 16 json_serialize, 17 serializable_dataclass, 18 serializable_field, 19) 20from muutils.parallel import run_maybe_parallel 21from zanj import ZANJ 22 23from maze_dataset import MazeDataset, MazeDatasetConfig 24from maze_dataset.generation import LatticeMazeGenerators 25 26SweepReturnType = TypeVar("SweepReturnType") 27ParamType = TypeVar("ParamType") 28AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType] 29 30 31def dataset_success_fraction(cfg: MazeDatasetConfig) -> float: 32 """empirical success fraction of maze generation 33 34 for use as an `analyze_func` in `sweep()` 35 """ 36 dataset: MazeDataset = MazeDataset.from_config( 37 cfg, 38 do_download=False, 39 load_local=False, 40 save_local=False, 41 verbose=False, 42 ) 43 44 return len(dataset) / cfg.n_mazes 45 46 47ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict( 48 dataset_success_fraction=dataset_success_fraction, 49) 50 51 52def sweep( 53 cfg_base: MazeDatasetConfig, 54 param_values: list[ParamType], 55 param_key: str, 56 analyze_func: Callable[[MazeDatasetConfig], SweepReturnType], 57) -> list[SweepReturnType]: 58 """given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value 59 60 # Parameters: 61 - `cfg_base : MazeDatasetConfig` 62 base config on which we will modify the value at `param_key` with values from `param_values` 63 - `param_values : list[ParamType]` 64 list of values to try 65 - `param_key : str` 66 value to modify in `cfg_base` 67 - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]` 68 function which analyzes the resulting config. 
originally built for `dataset_success_fraction` 69 70 # Returns: 71 - `list[SweepReturnType]` 72 _description_ 73 """ 74 outputs: list[SweepReturnType] = [] 75 76 for p in param_values: 77 # update the config 78 cfg_dict: dict = cfg_base.serialize() 79 update_with_nested_dict( 80 cfg_dict, 81 dotlist_to_nested_dict({param_key: p}), 82 ) 83 cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict) 84 85 outputs.append(analyze_func(cfg_test)) 86 87 return outputs 88 89 90@serializable_dataclass() 91class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]): 92 """result of a parameter sweep""" 93 94 configs: list[MazeDatasetConfig] = serializable_field( 95 serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs], 96 deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs], 97 ) 98 param_values: list[ParamType] = serializable_field( 99 serialization_fn=lambda x: json_serialize(x), 100 deserialize_fn=lambda x: x, 101 assert_type=False, 102 ) 103 result_values: dict[str, Sequence[SweepReturnType]] = serializable_field( 104 serialization_fn=lambda x: json_serialize(x), 105 deserialize_fn=lambda x: x, 106 assert_type=False, 107 ) 108 param_key: str 109 analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field( 110 serialization_fn=lambda f: f.__name__, 111 deserialize_fn=ANALYSIS_FUNCS.get, 112 assert_type=False, 113 ) 114 115 def summary(self) -> JSONitem: 116 "human-readable and json-dumpable short summary of the result" 117 return { 118 "len(configs)": len(self.configs), 119 "len(param_values)": len(self.param_values), 120 "len(result_values)": len(self.result_values), 121 "param_key": self.param_key, 122 "analyze_func": self.analyze_func.__name__, 123 } 124 125 def save(self, path: str | Path, z: ZANJ | None = None) -> None: 126 "save to a file with zanj" 127 if z is None: 128 z = ZANJ() 129 130 z.save(self, path) 131 132 @classmethod 133 def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult": 134 "read from a file with zanj" 135 if z is None: 136 z = ZANJ() 137 138 return z.read(path) 139 140 def configs_by_name(self) -> dict[str, MazeDatasetConfig]: 141 "return configs by name" 142 return {cfg.name: cfg for cfg in self.configs} 143 144 def configs_by_key(self) -> dict[str, MazeDatasetConfig]: 145 "return configs by the key used in `result_values`, which is the filename of the config" 146 return {cfg.to_fname(): cfg for cfg in self.configs} 147 148 def configs_shared(self) -> dict[str, Any]: 149 "return key: value pairs that are shared across all configs" 150 # we know that the configs all have the same keys, 151 # so this way of doing it is fine 152 config_vals: dict[str, set[Any]] = dict() 153 for cfg in self.configs: 154 for k, v in cfg.serialize().items(): 155 if k not in config_vals: 156 config_vals[k] = set() 157 config_vals[k].add(json.dumps(v)) 158 159 shared_vals: dict[str, Any] = dict() 160 161 cfg_ser: dict = self.configs[0].serialize() 162 for k, v in config_vals.items(): 163 if len(v) == 1: 164 shared_vals[k] = cfg_ser[k] 165 166 return shared_vals 167 168 def configs_differing_keys(self) -> set[str]: 169 "return keys that differ across configs" 170 shared_vals: dict[str, Any] = self.configs_shared() 171 differing_keys: set[str] = set() 172 173 for k in MazeDatasetConfig.__dataclass_fields__: 174 if k not in shared_vals: 175 differing_keys.add(k) 176 177 return differing_keys 178 179 def configs_value_set(self, key: str) -> list[Any]: 180 "return a list of the unique values for a given key" 
181 d: dict[str, Any] = { 182 json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key) 183 for cfg in self.configs 184 } 185 186 return list(d.values()) 187 188 def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult": 189 "get a subset of this `Result` where the configs has `key` satisfying `val_check`" 190 configs_list: list[MazeDatasetConfig] = [ 191 cfg for cfg in self.configs if val_check(getattr(cfg, key)) 192 ] 193 configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list} 194 result_values: dict[str, Sequence[SweepReturnType]] = { 195 k: self.result_values[k] for k in configs_keys 196 } 197 198 return SweepResult( 199 configs=configs_list, 200 param_values=self.param_values, 201 result_values=result_values, 202 param_key=self.param_key, 203 analyze_func=self.analyze_func, 204 ) 205 206 @classmethod 207 def analyze( 208 cls, 209 configs: list[MazeDatasetConfig], 210 param_values: list[ParamType], 211 param_key: str, 212 analyze_func: Callable[[MazeDatasetConfig], SweepReturnType], 213 parallel: bool | int = False, 214 **kwargs, 215 ) -> "SweepResult": 216 """Analyze success rate of maze generation for different percolation values 217 218 # Parameters: 219 - `configs : list[MazeDatasetConfig]` 220 configs to try 221 - `param_values : np.ndarray` 222 numpy array of values to try 223 224 # Returns: 225 - `SweepResult` 226 """ 227 n_pvals: int = len(param_values) 228 229 result_values_list: list[float] = run_maybe_parallel( 230 # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type] 231 func=functools.partial( # type: ignore[arg-type] 232 sweep, 233 param_values=param_values, 234 param_key=param_key, 235 analyze_func=analyze_func, 236 ), 237 iterable=configs, 238 keep_ordered=True, 239 parallel=parallel, 240 pbar_kwargs=dict(total=len(configs)), 241 **kwargs, 242 ) 243 result_values: dict[str, Float[np.ndarray, n_pvals]] = { 244 cfg.to_fname(): np.array(res) 245 for cfg, res in zip(configs, result_values_list, strict=False) 246 } 247 return cls( 248 configs=configs, 249 param_values=param_values, 250 # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type] 251 result_values=result_values, # type: ignore[arg-type] 252 param_key=param_key, 253 analyze_func=analyze_func, 254 ) 255 256 def plot( 257 self, 258 save_path: str | None = None, 259 cfg_keys: list[str] | None = None, 260 cmap_name: str | None = "viridis", 261 plot_only: bool = False, 262 show: bool = True, 263 ax: plt.Axes | None = None, 264 minify_title: bool = False, 265 legend_kwargs: dict[str, Any] | None = None, 266 ) -> plt.Axes: 267 """Plot the results of percolation analysis""" 268 # set up figure 269 if not ax: 270 fig: plt.Figure 271 ax_: plt.Axes 272 fig, ax_ = plt.subplots(1, 1, figsize=(22, 10)) 273 else: 274 ax_ = ax 275 276 # plot 277 cmap = plt.get_cmap(cmap_name) 278 n_cfgs: int = len(self.result_values) 279 for i, (ep_cfg_name, result_values) in enumerate( 280 sorted( 281 self.result_values.items(), 282 # HACK: sort by grid size 283 # |--< name of config 284 # | |-----------< gets 'g{n}' 285 # | | |--< gets '{n}' 286 # | | | 287 key=lambda x: int(x[0].split("-")[0][1:]), 288 ), 289 ): 290 ax_.plot( 291 # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | 
_NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type] 292 self.param_values, # type: ignore[arg-type] 293 # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type] 294 result_values, # type: ignore[arg-type] 295 ".-", 296 label=self.configs_by_key()[ep_cfg_name].name, 297 color=cmap((i + 0.5) / (n_cfgs - 0.5)), 298 ) 299 300 # repr of config 301 cfg_shared: dict = self.configs_shared() 302 if minify_title: 303 cfg_shared["endpoint_kwargs"] = { 304 k: v 305 for k, v in cfg_shared["endpoint_kwargs"].items() 306 if k != "except_on_no_valid_endpoint" 307 } 308 cfg_repr: str = ( 309 str(cfg_shared) 310 if cfg_keys is None 311 else ( 312 "MazeDatasetConfig(" 313 + ", ".join( 314 [ 315 f"{k}={cfg_shared[k].__name__}" 316 # TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo" [arg-type] 317 if isinstance(cfg_shared[k], Callable) # type: ignore[arg-type] 318 else f"{k}={cfg_shared[k]}" 319 for k in cfg_keys 320 ], 321 ) 322 + ")" 323 ) 324 ) 325 326 # add title and stuff 327 if not plot_only: 328 ax_.set_xlabel(self.param_key) 329 ax_.set_ylabel(self.analyze_func.__name__) 330 ax_.set_title( 331 f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}", 332 ) 333 ax_.grid(True) 334 # ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1)) 335 legend_kwargs = { 336 **dict(loc="center left"), 337 **(legend_kwargs or dict()), 338 } 339 ax_.legend(**legend_kwargs) 340 341 # save and show 342 if save_path: 343 plt.savefig(save_path) 344 345 if show: 346 plt.show() 347 348 return ax_ 349 350 351DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [ 352 ( 353 "any", 354 dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False), 355 ), 356 ( 357 "deadends", 358 dict( 359 deadend_start=True, 360 deadend_end=True, 361 endpoints_not_equal=False, 362 except_on_no_valid_endpoint=False, 363 ), 364 ), 365 ( 366 "deadends_unique", 367 dict( 368 deadend_start=True, 369 deadend_end=True, 370 endpoints_not_equal=True, 371 except_on_no_valid_endpoint=False, 372 ), 373 ), 374] 375 376 377def endpoint_kwargs_to_name(ep_kwargs: dict) -> str: 378 """convert endpoint kwargs options to a human-readable name""" 379 if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False): 380 if ep_kwargs.get("endpoints_not_equal", False): 381 return "deadends_unique" 382 else: 383 return "deadends" 384 else: 385 return "any" 386 387 388def full_percolation_analysis( 389 n_mazes: int, 390 p_val_count: int, 391 grid_sizes: list[int], 392 ep_kwargs: list[tuple[str, dict]] | None = None, 393 generators: Sequence[Callable] = ( 394 LatticeMazeGenerators.gen_percolation, 395 LatticeMazeGenerators.gen_dfs_percolation, 396 ), 397 save_dir: Path = Path("../docs/benchmarks/percolation_fractions"), 398 parallel: bool | int = False, 399 **analyze_kwargs, 400) -> SweepResult: 401 "run the full analysis of how percolation affects maze generation success" 402 if ep_kwargs is None: 403 ep_kwargs = DEFAULT_ENDPOINT_KWARGS 404 405 # configs 406 configs: list[MazeDatasetConfig] = list() 407 408 # TODO: B007 noqaed because we dont use `ep_kw_name` or `gf_idx` 409 for 
ep_kw_name, ep_kw in ep_kwargs: # noqa: B007 410 for gf_idx, gen_func in enumerate(generators): # noqa: B007 411 configs.extend( 412 [ 413 MazeDatasetConfig( 414 name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}", 415 grid_n=grid_n, 416 n_mazes=n_mazes, 417 maze_ctor=gen_func, 418 maze_ctor_kwargs=dict(p=float("nan")), 419 endpoint_kwargs=ep_kw, 420 ) 421 for grid_n in grid_sizes 422 ], 423 ) 424 425 # get results 426 result: SweepResult = SweepResult.analyze( 427 configs=configs, # type: ignore[misc] 428 # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type] 429 param_values=np.linspace(0.0, 1.0, p_val_count).tolist(), # type: ignore[arg-type] 430 param_key="maze_ctor_kwargs.p", 431 analyze_func=dataset_success_fraction, 432 parallel=parallel, 433 **analyze_kwargs, 434 ) 435 436 # save the result 437 results_path: Path = ( 438 save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj" 439 ) 440 print(f"Saving results to {results_path.as_posix()}") 441 result.save(results_path) 442 443 return result 444 445 446def _is_eq(a, b) -> bool: # noqa: ANN001 447 """check if two objects are equal""" 448 return a == b 449 450 451def plot_grouped( # noqa: C901 452 results: SweepResult, 453 predict_fn: Callable[[MazeDatasetConfig], float] | None = None, 454 prediction_density: int = 50, 455 save_dir: Path | None = None, 456 show: bool = True, 457 logy: bool = False, 458 save_fmt: str = "svg", 459 figsize: tuple[int, int] = (22, 10), 460 minify_title: bool = False, 461 legend_kwargs: dict[str, Any] | None = None, 462 manual_titles: dict[Literal["x", "y", "title"], str] | None = None, 463) -> None: 464 """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs 465 466 with separate colormaps for each maze generator function 467 468 # Parameters: 469 - `results : SweepResult` 470 The sweep results to plot 471 - `predict_fn : Callable[[MazeDatasetConfig], float] | None` 472 Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines. 
473 - `prediction_density : int` 474 Number of points to use for prediction curves (default: 50) 475 - `save_dir : Path | None` 476 Directory to save plots (defaults to `None`, meaning no saving) 477 - `show : bool` 478 Whether to display the plots (defaults to `True`) 479 480 # Usage: 481 ```python 482 >>> result = full_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16]) 483 >>> plot_grouped(result, save_dir=Path("./plots"), show=False) 484 ``` 485 """ 486 # groups 487 endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs") # type: ignore[assignment] 488 generator_funcs_names: list[str] = list( 489 {cfg.maze_ctor.__name__ for cfg in results.configs}, 490 ) 491 492 # if predicting, create denser p values 493 if predict_fn is not None: 494 p_dense = np.linspace(0.0, 1.0, prediction_density) 495 496 # separate plot for each set of endpoint kwargs 497 for ep_kw in endpoint_kwargs_set: 498 results_epkw: SweepResult = results.get_where( 499 "endpoint_kwargs", 500 functools.partial(_is_eq, b=ep_kw), 501 # lambda x: x == ep_kw, 502 ) 503 shared_keys: set[str] = set(results_epkw.configs_shared().keys()) 504 cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"}) 505 fig, ax = plt.subplots(1, 1, figsize=figsize) 506 for gf_idx, gen_func in enumerate(generator_funcs_names): 507 results_filtered: SweepResult = results_epkw.get_where( 508 "maze_ctor", 509 # HACK: big hassle to do this without a lambda, is it really that bad? 510 lambda x: x.__name__ == gen_func, # noqa: B023 511 ) 512 if len(results_filtered.configs) < 1: 513 warnings.warn( 514 f"No results for {gen_func} and {ep_kw}. Skipping.", 515 ) 516 continue 517 518 cmap_name = "Reds" if gf_idx == 0 else "Blues" 519 cmap = plt.get_cmap(cmap_name) 520 521 # Plot actual results 522 ax = results_filtered.plot( 523 cfg_keys=list(cfg_keys), 524 ax=ax, 525 show=False, 526 cmap_name=cmap_name, 527 minify_title=minify_title, 528 legend_kwargs=legend_kwargs, 529 ) 530 if logy: 531 ax.set_yscale("log") 532 533 # Plot predictions if function provided 534 if predict_fn is not None: 535 for cfg_idx, cfg in enumerate(results_filtered.configs): 536 predictions = [] 537 for p in p_dense: 538 cfg_temp = MazeDatasetConfig.load(cfg.serialize()) 539 cfg_temp.maze_ctor_kwargs["p"] = p 540 predictions.append(predict_fn(cfg_temp)) 541 542 # Get the same color as the actual data 543 n_cfgs: int = len(results_filtered.configs) 544 color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5)) 545 546 # Plot prediction as dashed line 547 ax.plot(p_dense, predictions, "--", color=color, alpha=0.8) 548 549 if manual_titles: 550 ax.set_xlabel(manual_titles["x"]) 551 ax.set_ylabel(manual_titles["y"]) 552 ax.set_title(manual_titles["title"]) 553 554 # save and show 555 if save_dir: 556 save_path: Path = ( 557 save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}" 558 ) 559 print(f"Saving plot to {save_path.as_posix()}") 560 save_path.parent.mkdir(exist_ok=True, parents=True) 561 plt.savefig(save_path) 562 563 if show: 564 plt.show()
```python
def dataset_success_fraction(cfg: MazeDatasetConfig) -> float: ...
```
empirical success fraction of maze generation, for use as an `analyze_func` in `sweep()`
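A minimal sketch of calling it directly; the specific `grid_n`, `n_mazes`, and `p` values below are illustrative, not taken from the benchmark itself:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.benchmark.config_sweep import dataset_success_fraction

# hypothetical config: 5x5 DFS-percolation mazes with p = 0.4
cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=32,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs=dict(p=0.4),
)

# fraction of the 32 requested mazes that were successfully generated
print(dataset_success_fraction(cfg))
```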
```python
def sweep(
    cfg_base: MazeDatasetConfig,
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]: ...
```
given a base config, a list of parameter values, a key, and an analysis function, return the result of the analysis function for each parameter value

Parameters:
- `cfg_base : MazeDatasetConfig`
  base config on which we will modify the value at `param_key` with values from `param_values`
- `param_values : list[ParamType]`
  list of values to try
- `param_key : str`
  key of the value to modify in `cfg_base`
- `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
  function which analyzes the resulting config; originally built for `dataset_success_fraction`

Returns:
- `list[SweepReturnType]`
  list of results from `analyze_func`, one per value in `param_values`
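For illustration, a sweep of percolation probability on a single base config might look like the following sketch (assumes `cfg` is a `MazeDatasetConfig`, e.g. the one constructed above):

```python
import numpy as np

from maze_dataset.benchmark.config_sweep import dataset_success_fraction, sweep

p_values = np.linspace(0.0, 1.0, 11).tolist()
success_fractions = sweep(
    cfg_base=cfg,
    param_values=p_values,
    param_key="maze_ctor_kwargs.p",  # dotlist key into the serialized config
    analyze_func=dataset_success_fraction,
)
# success_fractions[i] is the success fraction at p_values[i]
```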
```python
@serializable_dataclass()
class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]): ...
```
result of a parameter sweep
```python
def summary(self) -> JSONitem: ...
```
human-readable and json-dumpable short summary of the result
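The returned dictionary has exactly the keys shown in the source above; for instance (assuming `result` is a populated `SweepResult`, with illustrative counts):

```python
>>> result.summary()
{'len(configs)': 6, 'len(param_values)': 11, 'len(result_values)': 6,
 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'}
```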
```python
def save(self, path: str | Path, z: ZANJ | None = None) -> None: ...
```
save to a file with zanj
```python
@classmethod
def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult": ...
```
read from a file with zanj
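A save/load round trip, sketched with a hypothetical file name:

```python
# write the sweep result to disk, then read it back
result.save("percolation_sweep.zanj")
restored = SweepResult.read("percolation_sweep.zanj")
assert restored.summary() == result.summary()
```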
```python
def configs_by_name(self) -> dict[str, MazeDatasetConfig]: ...
```
return configs by name
```python
def configs_by_key(self) -> dict[str, MazeDatasetConfig]: ...
```
return configs by the key used in `result_values`, which is the filename of the config
```python
def configs_differing_keys(self) -> set[str]: ...
```
return keys that differ across configs
```python
def configs_value_set(self, key: str) -> list[Any]: ...
```
return a list of the unique values for a given key
```python
def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult": ...
```
get a subset of this `SweepResult` where each config's `key` satisfies `val_check`
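For example, to keep only the larger mazes (a sketch, assuming `result` is a `SweepResult` whose configs vary in `grid_n`):

```python
large_only = result.get_where("grid_n", lambda g: g >= 8)
print(large_only.summary())
```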
```python
@classmethod
def analyze(
    cls,
    configs: list[MazeDatasetConfig],
    param_values: list[ParamType],
    param_key: str,
    analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
    parallel: bool | int = False,
    **kwargs,
) -> "SweepResult": ...
```
Analyze success rate of maze generation for different percolation values

Parameters:
- `configs : list[MazeDatasetConfig]`
  configs to try
- `param_values : list[ParamType]`
  list of parameter values to try

Returns:
- `SweepResult`
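A sketch of building configs and running the sweep over percolation probability `p`; the grid sizes and maze counts are illustrative, and the `g{n}-...` naming follows the convention `plot()` uses to sort by grid size:

```python
import numpy as np

from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.benchmark.config_sweep import SweepResult, dataset_success_fraction

configs = [
    MazeDatasetConfig(
        name=f"g{grid_n}-dfs_perc",
        grid_n=grid_n,
        n_mazes=64,
        maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
        maze_ctor_kwargs=dict(p=float("nan")),  # overwritten by each sweep value
    )
    for grid_n in (4, 8)
]

result = SweepResult.analyze(
    configs=configs,
    param_values=np.linspace(0.0, 1.0, 11).tolist(),
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
    parallel=False,
)
print(result.summary())
```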
```python
def plot(
    self,
    save_path: str | None = None,
    cfg_keys: list[str] | None = None,
    cmap_name: str | None = "viridis",
    plot_only: bool = False,
    show: bool = True,
    ax: plt.Axes | None = None,
    minify_title: bool = False,
    legend_kwargs: dict[str, Any] | None = None,
) -> plt.Axes: ...
```
Plot the results of percolation analysis
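For instance, continuing from the `analyze()` sketch above (file name is hypothetical):

```python
ax = result.plot(save_path="percolation_sweep.svg", show=False)
```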
Inherited Members:
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
  - serialize
  - load
  - validate_fields_types
  - validate_field_type
  - diff
  - update_from_nested_dict
```python
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str: ...
```
convert endpoint kwargs options to a human-readable name
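Applied to the entries of `DEFAULT_ENDPOINT_KWARGS`, it recovers their labels:

```python
>>> [endpoint_kwargs_to_name(kw) for _label, kw in DEFAULT_ENDPOINT_KWARGS]
['any', 'deadends', 'deadends_unique']
```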
```python
def full_percolation_analysis(
    n_mazes: int,
    p_val_count: int,
    grid_sizes: list[int],
    ep_kwargs: list[tuple[str, dict]] | None = None,
    generators: Sequence[Callable] = (
        LatticeMazeGenerators.gen_percolation,
        LatticeMazeGenerators.gen_dfs_percolation,
    ),
    save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
    parallel: bool | int = False,
    **analyze_kwargs,
) -> SweepResult: ...
```
run the full analysis of how percolation affects maze generation success
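A small illustrative run (the counts, grid sizes, and output directory are made up); the result is returned and also written as a `.zanj` file under `save_dir`:

```python
from pathlib import Path

from maze_dataset.benchmark.config_sweep import full_percolation_analysis

result = full_percolation_analysis(
    n_mazes=64,
    p_val_count=11,
    grid_sizes=[4, 8],
    save_dir=Path("benchmarks/percolation_fractions"),
    parallel=False,
)
```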
```python
def plot_grouped(  # noqa: C901
    results: SweepResult,
    predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
    prediction_density: int = 50,
    save_dir: Path | None = None,
    show: bool = True,
    logy: bool = False,
    save_fmt: str = "svg",
    figsize: tuple[int, int] = (22, 10),
    minify_title: bool = False,
    legend_kwargs: dict[str, Any] | None = None,
    manual_titles: dict[Literal["x", "y", "title"], str] | None = None,
) -> None: ...
```
Plot grouped sweep percolation results for each distinct `endpoint_kwargs` in the configs, with separate colormaps for each maze generator function

Parameters:
- `results : SweepResult`
  The sweep results to plot
- `predict_fn : Callable[[MazeDatasetConfig], float] | None`
  Optional function that predicts success rate from a config. If provided, predictions are plotted as dashed lines.
- `prediction_density : int`
  Number of points to use for prediction curves (default: 50)
- `save_dir : Path | None`
  Directory to save plots (defaults to `None`, meaning no saving)
- `show : bool`
  Whether to display the plots (defaults to `True`)

Usage:
```python
>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8, 16])
>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
```
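To overlay predictions as dashed lines, pass any callable mapping a config to a predicted success rate. The predictor below is a placeholder for illustration only, not part of the library; it assumes `result` comes from `full_percolation_analysis` as above:

```python
from pathlib import Path

from maze_dataset import MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import plot_grouped


def linear_guess(cfg: MazeDatasetConfig) -> float:
    # placeholder prediction: success falls off linearly with percolation probability
    return 1.0 - cfg.maze_ctor_kwargs["p"]


plot_grouped(
    result,
    predict_fn=linear_guess,
    save_dir=Path("./plots"),
    show=False,
)
```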