docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.benchmark.config_sweep

Benchmarks how often maze generation succeeds across various values of percolation


  1"""Benchmarking of how successful maze generation is for various values of percolation"""
  2
  3import functools
  4import json
  5import warnings
  6from pathlib import Path
  7from typing import Any, Callable, Generic, Literal, Sequence, TypeVar
  8
  9import matplotlib.pyplot as plt
 10import numpy as np
 11from jaxtyping import Float
 12from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict
 13from muutils.json_serialize import (
 14	JSONitem,
 15	SerializableDataclass,
 16	json_serialize,
 17	serializable_dataclass,
 18	serializable_field,
 19)
 20from muutils.parallel import run_maybe_parallel
 21from zanj import ZANJ
 22
 23from maze_dataset import MazeDataset, MazeDatasetConfig
 24from maze_dataset.generation import LatticeMazeGenerators
 25
 26SweepReturnType = TypeVar("SweepReturnType")
 27ParamType = TypeVar("ParamType")
 28AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType]
 29
 30
 31def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
 32	"""empirical success fraction of maze generation
 33
 34	for use as an `analyze_func` in `sweep()`
 35	"""
 36	dataset: MazeDataset = MazeDataset.from_config(
 37		cfg,
 38		do_download=False,
 39		load_local=False,
 40		save_local=False,
 41		verbose=False,
 42	)
 43
 44	return len(dataset) / cfg.n_mazes
 45
 46
 47ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict(
 48	dataset_success_fraction=dataset_success_fraction,
 49)
 50
 51
 52def sweep(
 53	cfg_base: MazeDatasetConfig,
 54	param_values: list[ParamType],
 55	param_key: str,
 56	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
 57) -> list[SweepReturnType]:
 58	"""given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value
 59
 60	# Parameters:
 61	- `cfg_base : MazeDatasetConfig`
 62		base config on which we will modify the value at `param_key` with values from `param_values`
 63	- `param_values : list[ParamType]`
 64		list of values to try
 65	- `param_key : str`
 66		value to modify in `cfg_base`
 67	- `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
 68		function which analyzes the resulting config. originally built for `dataset_success_fraction`
 69
 70	# Returns:
 71	- `list[SweepReturnType]`
 72		_description_
 73	"""
 74	outputs: list[SweepReturnType] = []
 75
 76	for p in param_values:
 77		# update the config
 78		cfg_dict: dict = cfg_base.serialize()
 79		update_with_nested_dict(
 80			cfg_dict,
 81			dotlist_to_nested_dict({param_key: p}),
 82		)
 83		cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)
 84
 85		outputs.append(analyze_func(cfg_test))
 86
 87	return outputs
 88
 89
 90@serializable_dataclass()
 91class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
 92	"""result of a parameter sweep"""
 93
 94	configs: list[MazeDatasetConfig] = serializable_field(
 95		serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
 96		deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
 97	)
 98	param_values: list[ParamType] = serializable_field(
 99		serialization_fn=lambda x: json_serialize(x),
100		deserialize_fn=lambda x: x,
101		assert_type=False,
102	)
103	result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
104		serialization_fn=lambda x: json_serialize(x),
105		deserialize_fn=lambda x: x,
106		assert_type=False,
107	)
108	param_key: str
109	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
110		serialization_fn=lambda f: f.__name__,
111		deserialize_fn=ANALYSIS_FUNCS.get,
112		assert_type=False,
113	)
114
115	def summary(self) -> JSONitem:
116		"human-readable and json-dumpable short summary of the result"
117		return {
118			"len(configs)": len(self.configs),
119			"len(param_values)": len(self.param_values),
120			"len(result_values)": len(self.result_values),
121			"param_key": self.param_key,
122			"analyze_func": self.analyze_func.__name__,
123		}
124
125	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
126		"save to a file with zanj"
127		if z is None:
128			z = ZANJ()
129
130		z.save(self, path)
131
132	@classmethod
133	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
134		"read from a file with zanj"
135		if z is None:
136			z = ZANJ()
137
138		return z.read(path)
139
140	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
141		"return configs by name"
142		return {cfg.name: cfg for cfg in self.configs}
143
144	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
145		"return configs by the key used in `result_values`, which is the filename of the config"
146		return {cfg.to_fname(): cfg for cfg in self.configs}
147
148	def configs_shared(self) -> dict[str, Any]:
149		"return key: value pairs that are shared across all configs"
150		# we know that the configs all have the same keys,
151		# so this way of doing it is fine
152		config_vals: dict[str, set[Any]] = dict()
153		for cfg in self.configs:
154			for k, v in cfg.serialize().items():
155				if k not in config_vals:
156					config_vals[k] = set()
157				config_vals[k].add(json.dumps(v))
158
159		shared_vals: dict[str, Any] = dict()
160
161		cfg_ser: dict = self.configs[0].serialize()
162		for k, v in config_vals.items():
163			if len(v) == 1:
164				shared_vals[k] = cfg_ser[k]
165
166		return shared_vals
167
168	def configs_differing_keys(self) -> set[str]:
169		"return keys that differ across configs"
170		shared_vals: dict[str, Any] = self.configs_shared()
171		differing_keys: set[str] = set()
172
173		for k in MazeDatasetConfig.__dataclass_fields__:
174			if k not in shared_vals:
175				differing_keys.add(k)
176
177		return differing_keys
178
179	def configs_value_set(self, key: str) -> list[Any]:
180		"return a list of the unique values for a given key"
181		d: dict[str, Any] = {
182			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
183			for cfg in self.configs
184		}
185
186		return list(d.values())
187
188	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
189		"get a subset of this `Result` where the configs has `key` satisfying `val_check`"
190		configs_list: list[MazeDatasetConfig] = [
191			cfg for cfg in self.configs if val_check(getattr(cfg, key))
192		]
193		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
194		result_values: dict[str, Sequence[SweepReturnType]] = {
195			k: self.result_values[k] for k in configs_keys
196		}
197
198		return SweepResult(
199			configs=configs_list,
200			param_values=self.param_values,
201			result_values=result_values,
202			param_key=self.param_key,
203			analyze_func=self.analyze_func,
204		)
205
206	@classmethod
207	def analyze(
208		cls,
209		configs: list[MazeDatasetConfig],
210		param_values: list[ParamType],
211		param_key: str,
212		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
213		parallel: bool | int = False,
214		**kwargs,
215	) -> "SweepResult":
216		"""Analyze success rate of maze generation for different percolation values
217
218		# Parameters:
219		- `configs : list[MazeDatasetConfig]`
220		configs to try
221		- `param_values : np.ndarray`
222		numpy array of values to try
223
224		# Returns:
225		- `SweepResult`
226		"""
227		n_pvals: int = len(param_values)
228
229		result_values_list: list[float] = run_maybe_parallel(
230			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
231			func=functools.partial(  # type: ignore[arg-type]
232				sweep,
233				param_values=param_values,
234				param_key=param_key,
235				analyze_func=analyze_func,
236			),
237			iterable=configs,
238			keep_ordered=True,
239			parallel=parallel,
240			pbar_kwargs=dict(total=len(configs)),
241			**kwargs,
242		)
243		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
244			cfg.to_fname(): np.array(res)
245			for cfg, res in zip(configs, result_values_list, strict=False)
246		}
247		return cls(
248			configs=configs,
249			param_values=param_values,
250			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
251			result_values=result_values,  # type: ignore[arg-type]
252			param_key=param_key,
253			analyze_func=analyze_func,
254		)
255
256	def plot(
257		self,
258		save_path: str | None = None,
259		cfg_keys: list[str] | None = None,
260		cmap_name: str | None = "viridis",
261		plot_only: bool = False,
262		show: bool = True,
263		ax: plt.Axes | None = None,
264		minify_title: bool = False,
265		legend_kwargs: dict[str, Any] | None = None,
266	) -> plt.Axes:
267		"""Plot the results of percolation analysis"""
268		# set up figure
269		if not ax:
270			fig: plt.Figure
271			ax_: plt.Axes
272			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
273		else:
274			ax_ = ax
275
276		# plot
277		cmap = plt.get_cmap(cmap_name)
278		n_cfgs: int = len(self.result_values)
279		for i, (ep_cfg_name, result_values) in enumerate(
280			sorted(
281				self.result_values.items(),
282				# HACK: sort by grid size
283				#                 |--< name of config
284				#                 |    |-----------< gets 'g{n}'
285				#                 |    |            |--< gets '{n}'
286				#                 |    |            |
287				key=lambda x: int(x[0].split("-")[0][1:]),
288			),
289		):
290			ax_.plot(
291				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
292				self.param_values,  # type: ignore[arg-type]
293				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
294				result_values,  # type: ignore[arg-type]
295				".-",
296				label=self.configs_by_key()[ep_cfg_name].name,
297				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
298			)
299
300		# repr of config
301		cfg_shared: dict = self.configs_shared()
302		if minify_title:
303			cfg_shared["endpoint_kwargs"] = {
304				k: v
305				for k, v in cfg_shared["endpoint_kwargs"].items()
306				if k != "except_on_no_valid_endpoint"
307			}
308		cfg_repr: str = (
309			str(cfg_shared)
310			if cfg_keys is None
311			else (
312				"MazeDatasetConfig("
313				+ ", ".join(
314					[
315						f"{k}={cfg_shared[k].__name__}"
316						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
317						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
318						else f"{k}={cfg_shared[k]}"
319						for k in cfg_keys
320					],
321				)
322				+ ")"
323			)
324		)
325
326		# add title and stuff
327		if not plot_only:
328			ax_.set_xlabel(self.param_key)
329			ax_.set_ylabel(self.analyze_func.__name__)
330			ax_.set_title(
331				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
332			)
333			ax_.grid(True)
334			# ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1))
335			legend_kwargs = {
336				**dict(loc="center left"),
337				**(legend_kwargs or dict()),
338			}
339			ax_.legend(**legend_kwargs)
340
341		# save and show
342		if save_path:
343			plt.savefig(save_path)
344
345		if show:
346			plt.show()
347
348		return ax_
349
350
# (name, endpoint_kwargs) pairs used by `full_percolation_analysis` when no
# `ep_kwargs` are given; each name matches what `endpoint_kwargs_to_name`
# returns for its dict. none of them raise when no valid endpoint exists
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [
	# endpoints anywhere
	(
		"any",
		{
			"deadend_start": False,
			"deadend_end": False,
			"except_on_no_valid_endpoint": False,
		},
	),
	# endpoints on dead ends, possibly the same cell
	(
		"deadends",
		{
			"deadend_start": True,
			"deadend_end": True,
			"endpoints_not_equal": False,
			"except_on_no_valid_endpoint": False,
		},
	),
	# endpoints on dead ends, required to differ
	(
		"deadends_unique",
		{
			"deadend_start": True,
			"deadend_end": True,
			"endpoints_not_equal": True,
			"except_on_no_valid_endpoint": False,
		},
	),
]
375
376
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
	"""map an endpoint kwargs dict to a short human-readable label

	- `"any"` when neither `deadend_start` nor `deadend_end` is truthy
	- `"deadends_unique"` when a dead-end flag is set and `endpoints_not_equal` is truthy
	- `"deadends"` when a dead-end flag is set otherwise

	missing keys count as `False`
	"""
	uses_deadends: bool = bool(
		ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False)
	)
	if not uses_deadends:
		return "any"
	return "deadends_unique" if ep_kwargs.get("endpoints_not_equal", False) else "deadends"
386
387
def full_percolation_analysis(
	n_mazes: int,
	p_val_count: int,
	grid_sizes: list[int],
	ep_kwargs: list[tuple[str, dict]] | None = None,
	generators: Sequence[Callable] = (
		LatticeMazeGenerators.gen_percolation,
		LatticeMazeGenerators.gen_dfs_percolation,
	),
	save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
	parallel: bool | int = False,
	**analyze_kwargs,
) -> SweepResult:
	"""run the full analysis of how percolation affects maze generation success

	builds one config per (endpoint kwargs, generator, grid size) combination,
	sweeps `maze_ctor_kwargs.p` over `p_val_count` evenly spaced values in
	[0, 1] measuring `dataset_success_fraction`, and saves the result as a
	`.zanj` file in `save_dir`

	# Parameters:
	- `n_mazes : int`
		number of mazes requested per config
	- `p_val_count : int`
		number of values of `p` to try
	- `grid_sizes : list[int]`
		grid sizes to test
	- `ep_kwargs : list[tuple[str, dict]] | None`
		(name, endpoint_kwargs) pairs; only the dicts are used.
		defaults to `DEFAULT_ENDPOINT_KWARGS`
	- `generators : Sequence[Callable]`
		maze generator functions to test
	- `save_dir : Path`
		directory for the result file; created if missing (a `str` is also accepted)
	- `parallel : bool | int`
		forwarded to `SweepResult.analyze`
	- `**analyze_kwargs`
		forwarded to `SweepResult.analyze`

	# Returns:
	- `SweepResult`
		the sweep result, also written to `save_dir`
	"""
	if ep_kwargs is None:
		ep_kwargs = DEFAULT_ENDPOINT_KWARGS

	# one config per (endpoint kwargs, generator, grid size); `p` is a NaN
	# placeholder that `sweep()` overwrites through `param_key`
	configs: list[MazeDatasetConfig] = [
		MazeDatasetConfig(
			# e.g. "g16-perc" for `gen_percolation`, "g16-dfs_perc" for `gen_dfs_percolation`
			name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
			grid_n=grid_n,
			n_mazes=n_mazes,
			maze_ctor=gen_func,
			maze_ctor_kwargs=dict(p=float("nan")),
			endpoint_kwargs=ep_kw,
		)
		for _ep_kw_name, ep_kw in ep_kwargs
		for gen_func in generators
		for grid_n in grid_sizes
	]

	# get results
	result: SweepResult = SweepResult.analyze(
		configs=configs,  # type: ignore[misc]
		# TYPING: `tolist()` is typed as a union of nested lists, not list[float]
		param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
		param_key="maze_ctor_kwargs.p",
		analyze_func=dataset_success_fraction,
		parallel=parallel,
		**analyze_kwargs,
	)

	# save the result, making sure the target directory exists
	save_dir = Path(save_dir)
	save_dir.mkdir(parents=True, exist_ok=True)
	results_path: Path = (
		save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
	)
	print(f"Saving results to {results_path.as_posix()}")
	result.save(results_path)

	return result
444
445
446def _is_eq(a, b) -> bool:  # noqa: ANN001
447	"""check if two objects are equal"""
448	return a == b
449
450
def plot_grouped(  # noqa: C901
	results: SweepResult,
	predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
	prediction_density: int = 50,
	save_dir: Path | None = None,
	show: bool = True,
	logy: bool = False,
	save_fmt: str = "svg",
	figsize: tuple[int, int] = (22, 10),
	minify_title: bool = False,
	legend_kwargs: dict[str, Any] | None = None,
	manual_titles: dict[Literal["x", "y", "title"], str] | None = None,
) -> None:
	"""Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

	one figure per distinct `endpoint_kwargs`, with separate colormaps for each
	maze generator function

	# Parameters:
	- `results : SweepResult`
		The sweep results to plot
	- `predict_fn : Callable[[MazeDatasetConfig], float] | None`
		Optional function that predicts success rate from a config. If provided, will plot
		predictions as dashed lines in the color of the matching data curve.
	- `prediction_density : int`
		Number of points to use for prediction curves (default: 50)
	- `save_dir : Path | None`
		Directory to save plots, created if missing (defaults to `None`, meaning no saving).
		Each figure is saved as `ep_{endpoint_kwargs_to_name(...)}.{save_fmt}`
	- `show : bool`
		Whether to display the plots (defaults to `True`)
	- `logy : bool`
		Use a logarithmic y axis (defaults to `False`)
	- `save_fmt : str`
		File extension of saved plots (defaults to `"svg"`)
	- `figsize : tuple[int, int]`
		Size of each figure (defaults to `(22, 10)`)
	- `minify_title : bool`
		Passed to `SweepResult.plot`
	- `legend_kwargs : dict[str, Any] | None`
		Passed to `SweepResult.plot`
	- `manual_titles : dict[Literal["x", "y", "title"], str] | None`
		Overrides for the x label, y label and/or title; only the provided keys are applied

	# Usage:
	```python
	>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
	>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
	```
	"""
	# one colormap per generator, cycled if there are more generators than entries;
	# the first two keep the original Reds / Blues assignment
	cmap_names: tuple[str, ...] = (
		"Reds",
		"Blues",
		"Greens",
		"Purples",
		"Oranges",
		"Greys",
	)

	# groups
	endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
	# sorted: iterating a set of strings gives a run-dependent order, which would
	# make the generator -> colormap assignment change between runs
	generator_funcs_names: list[str] = sorted(
		{cfg.maze_ctor.__name__ for cfg in results.configs},
	)

	# if predicting, create denser p values
	p_dense: np.ndarray | None = (
		np.linspace(0.0, 1.0, prediction_density) if predict_fn is not None else None
	)

	# separate plot for each set of endpoint kwargs
	for ep_kw in endpoint_kwargs_set:
		results_epkw: SweepResult = results.get_where(
			"endpoint_kwargs",
			functools.partial(_is_eq, b=ep_kw),
		)
		shared_keys: set[str] = set(results_epkw.configs_shared().keys())
		# sorted so the title lists the keys in a stable order
		cfg_keys: list[str] = sorted(
			shared_keys.intersection({"n_mazes", "endpoint_kwargs"}),
		)
		fig, ax = plt.subplots(1, 1, figsize=figsize)
		for gf_idx, gen_func in enumerate(generator_funcs_names):
			results_filtered: SweepResult = results_epkw.get_where(
				"maze_ctor",
				# the default argument binds the current `gen_func` (no late binding)
				lambda x, name=gen_func: x.__name__ == name,
			)
			if not results_filtered.configs:
				warnings.warn(
					f"No results for {gen_func} and {ep_kw}. Skipping.",
				)
				continue

			cmap_name: str = cmap_names[gf_idx % len(cmap_names)]
			cmap = plt.get_cmap(cmap_name)

			# Plot actual results
			ax = results_filtered.plot(
				cfg_keys=cfg_keys,
				ax=ax,
				show=False,
				cmap_name=cmap_name,
				minify_title=minify_title,
				legend_kwargs=legend_kwargs,
			)
			if logy:
				ax.set_yscale("log")

			# Plot predictions if function provided
			if predict_fn is not None and p_dense is not None:
				# `SweepResult.plot` colors its curves in grid-size order (parsed from
				# names of the form 'g{grid_n}-...'), so iterate in the same order to
				# give each prediction the color of its data curve
				configs_sorted: list[MazeDatasetConfig] = sorted(
					results_filtered.configs,
					key=lambda c: c.grid_n,
				)
				n_cfgs: int = len(configs_sorted)
				for cfg_idx, cfg in enumerate(configs_sorted):
					predictions: list[float] = []
					for p in p_dense:
						cfg_temp = MazeDatasetConfig.load(cfg.serialize())
						cfg_temp.maze_ctor_kwargs["p"] = float(p)
						predictions.append(predict_fn(cfg_temp))

					color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

					# Plot prediction as dashed line
					ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

			# `plot` resets labels/title on every call, so overrides are reapplied each time
			if manual_titles:
				if "x" in manual_titles:
					ax.set_xlabel(manual_titles["x"])
				if "y" in manual_titles:
					ax.set_ylabel(manual_titles["y"])
				if "title" in manual_titles:
					ax.set_title(manual_titles["title"])

		# save and show
		if save_dir:
			save_path: Path = (
				Path(save_dir) / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}"
			)
			print(f"Saving plot to {save_path.as_posix()}")
			save_path.parent.mkdir(exist_ok=True, parents=True)
			# save this figure explicitly rather than whatever pyplot considers current
			fig.savefig(save_path)

		if show:
			plt.show()

AnalysisFunc = typing.Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]
def dataset_success_fraction(cfg: maze_dataset.MazeDatasetConfig) -> float:
def dataset_success_fraction(cfg: MazeDatasetConfig) -> float:
	"""fraction of the requested `cfg.n_mazes` mazes that generation actually produced

	the dataset is built in memory only (no download, no loading from or saving
	to disk, no verbose output) and the result is `len(dataset) / cfg.n_mazes`.
	meant to be passed as `analyze_func` to `sweep()`
	"""
	generated = MazeDataset.from_config(
		cfg,
		do_download=False,
		load_local=False,
		save_local=False,
		verbose=False,
	)
	n_generated: int = len(generated)
	return n_generated / cfg.n_mazes

empirical success fraction of maze generation

for use as an analyze_func in sweep()

ANALYSIS_FUNCS: dict[str, typing.Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]] = {'dataset_success_fraction': <function dataset_success_fraction>}
def sweep( cfg_base: maze_dataset.MazeDatasetConfig, param_values: list[~ParamType], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]) -> list[~SweepReturnType]:
def sweep(
	cfg_base: MazeDatasetConfig,
	param_values: list[ParamType],
	param_key: str,
	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
) -> list[SweepReturnType]:
	"""given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

	# Parameters:
	- `cfg_base : MazeDatasetConfig`
		base config on which we will modify the value at `param_key` with values from `param_values`
	- `param_values : list[ParamType]`
		list of values to try
	- `param_key : str`
		value to modify in `cfg_base`; may be a dotted path into the serialized
		config, e.g. `"maze_ctor_kwargs.p"`
	- `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]`
		function which analyzes the resulting config. originally built for `dataset_success_fraction`

	# Returns:
	- `list[SweepReturnType]`
		output of `analyze_func` for each value in `param_values`, in the same order
	"""
	outputs: list[SweepReturnType] = []

	for p in param_values:
		# start from a fresh serialization of `cfg_base` for every value
		# (assumes `serialize()` returns a new dict each call), so edits do not accumulate
		cfg_dict: dict = cfg_base.serialize()
		# expand a dotted `param_key` into a nested dict and merge it into the config
		update_with_nested_dict(
			cfg_dict,
			dotlist_to_nested_dict({param_key: p}),
		)
		cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict)

		outputs.append(analyze_func(cfg_test))

	return outputs

given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value

Parameters:

  • cfg_base : MazeDatasetConfig base config on which we will modify the value at param_key with values from param_values
  • param_values : list[ParamType] list of values to try
  • param_key : str value to modify in cfg_base
  • analyze_func : Callable[[MazeDatasetConfig], SweepReturnType] function which analyzes the resulting config. originally built for dataset_success_fraction

Returns:

  • list[SweepReturnType] output of analyze_func for each value in param_values, in the same order
@serializable_dataclass()
class SweepResult(muutils.json_serialize.serializable_dataclass.SerializableDataclass, typing.Generic[~ParamType, ~SweepReturnType]):
 91@serializable_dataclass()
 92class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]):
 93	"""result of a parameter sweep"""
 94
 95	configs: list[MazeDatasetConfig] = serializable_field(
 96		serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs],
 97		deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs],
 98	)
 99	param_values: list[ParamType] = serializable_field(
100		serialization_fn=lambda x: json_serialize(x),
101		deserialize_fn=lambda x: x,
102		assert_type=False,
103	)
104	result_values: dict[str, Sequence[SweepReturnType]] = serializable_field(
105		serialization_fn=lambda x: json_serialize(x),
106		deserialize_fn=lambda x: x,
107		assert_type=False,
108	)
109	param_key: str
110	analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field(
111		serialization_fn=lambda f: f.__name__,
112		deserialize_fn=ANALYSIS_FUNCS.get,
113		assert_type=False,
114	)
115
116	def summary(self) -> JSONitem:
117		"human-readable and json-dumpable short summary of the result"
118		return {
119			"len(configs)": len(self.configs),
120			"len(param_values)": len(self.param_values),
121			"len(result_values)": len(self.result_values),
122			"param_key": self.param_key,
123			"analyze_func": self.analyze_func.__name__,
124		}
125
126	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
127		"save to a file with zanj"
128		if z is None:
129			z = ZANJ()
130
131		z.save(self, path)
132
133	@classmethod
134	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
135		"read from a file with zanj"
136		if z is None:
137			z = ZANJ()
138
139		return z.read(path)
140
141	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
142		"return configs by name"
143		return {cfg.name: cfg for cfg in self.configs}
144
145	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
146		"return configs by the key used in `result_values`, which is the filename of the config"
147		return {cfg.to_fname(): cfg for cfg in self.configs}
148
149	def configs_shared(self) -> dict[str, Any]:
150		"return key: value pairs that are shared across all configs"
151		# we know that the configs all have the same keys,
152		# so this way of doing it is fine
153		config_vals: dict[str, set[Any]] = dict()
154		for cfg in self.configs:
155			for k, v in cfg.serialize().items():
156				if k not in config_vals:
157					config_vals[k] = set()
158				config_vals[k].add(json.dumps(v))
159
160		shared_vals: dict[str, Any] = dict()
161
162		cfg_ser: dict = self.configs[0].serialize()
163		for k, v in config_vals.items():
164			if len(v) == 1:
165				shared_vals[k] = cfg_ser[k]
166
167		return shared_vals
168
169	def configs_differing_keys(self) -> set[str]:
170		"return keys that differ across configs"
171		shared_vals: dict[str, Any] = self.configs_shared()
172		differing_keys: set[str] = set()
173
174		for k in MazeDatasetConfig.__dataclass_fields__:
175			if k not in shared_vals:
176				differing_keys.add(k)
177
178		return differing_keys
179
180	def configs_value_set(self, key: str) -> list[Any]:
181		"return a list of the unique values for a given key"
182		d: dict[str, Any] = {
183			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
184			for cfg in self.configs
185		}
186
187		return list(d.values())
188
189	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
190		"get a subset of this `Result` where the configs has `key` satisfying `val_check`"
191		configs_list: list[MazeDatasetConfig] = [
192			cfg for cfg in self.configs if val_check(getattr(cfg, key))
193		]
194		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
195		result_values: dict[str, Sequence[SweepReturnType]] = {
196			k: self.result_values[k] for k in configs_keys
197		}
198
199		return SweepResult(
200			configs=configs_list,
201			param_values=self.param_values,
202			result_values=result_values,
203			param_key=self.param_key,
204			analyze_func=self.analyze_func,
205		)
206
207	@classmethod
208	def analyze(
209		cls,
210		configs: list[MazeDatasetConfig],
211		param_values: list[ParamType],
212		param_key: str,
213		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
214		parallel: bool | int = False,
215		**kwargs,
216	) -> "SweepResult":
217		"""Analyze success rate of maze generation for different percolation values
218
219		# Parameters:
220		- `configs : list[MazeDatasetConfig]`
221		configs to try
222		- `param_values : np.ndarray`
223		numpy array of values to try
224
225		# Returns:
226		- `SweepResult`
227		"""
228		n_pvals: int = len(param_values)
229
230		result_values_list: list[float] = run_maybe_parallel(
231			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
232			func=functools.partial(  # type: ignore[arg-type]
233				sweep,
234				param_values=param_values,
235				param_key=param_key,
236				analyze_func=analyze_func,
237			),
238			iterable=configs,
239			keep_ordered=True,
240			parallel=parallel,
241			pbar_kwargs=dict(total=len(configs)),
242			**kwargs,
243		)
244		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
245			cfg.to_fname(): np.array(res)
246			for cfg, res in zip(configs, result_values_list, strict=False)
247		}
248		return cls(
249			configs=configs,
250			param_values=param_values,
251			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
252			result_values=result_values,  # type: ignore[arg-type]
253			param_key=param_key,
254			analyze_func=analyze_func,
255		)
256
257	def plot(
258		self,
259		save_path: str | None = None,
260		cfg_keys: list[str] | None = None,
261		cmap_name: str | None = "viridis",
262		plot_only: bool = False,
263		show: bool = True,
264		ax: plt.Axes | None = None,
265		minify_title: bool = False,
266		legend_kwargs: dict[str, Any] | None = None,
267	) -> plt.Axes:
268		"""Plot the results of percolation analysis"""
269		# set up figure
270		if not ax:
271			fig: plt.Figure
272			ax_: plt.Axes
273			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
274		else:
275			ax_ = ax
276
277		# plot
278		cmap = plt.get_cmap(cmap_name)
279		n_cfgs: int = len(self.result_values)
280		for i, (ep_cfg_name, result_values) in enumerate(
281			sorted(
282				self.result_values.items(),
283				# HACK: sort by grid size
284				#                 |--< name of config
285				#                 |    |-----------< gets 'g{n}'
286				#                 |    |            |--< gets '{n}'
287				#                 |    |            |
288				key=lambda x: int(x[0].split("-")[0][1:]),
289			),
290		):
291			ax_.plot(
292				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
293				self.param_values,  # type: ignore[arg-type]
294				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
295				result_values,  # type: ignore[arg-type]
296				".-",
297				label=self.configs_by_key()[ep_cfg_name].name,
298				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
299			)
300
301		# repr of config
302		cfg_shared: dict = self.configs_shared()
303		if minify_title:
304			cfg_shared["endpoint_kwargs"] = {
305				k: v
306				for k, v in cfg_shared["endpoint_kwargs"].items()
307				if k != "except_on_no_valid_endpoint"
308			}
309		cfg_repr: str = (
310			str(cfg_shared)
311			if cfg_keys is None
312			else (
313				"MazeDatasetConfig("
314				+ ", ".join(
315					[
316						f"{k}={cfg_shared[k].__name__}"
317						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
318						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
319						else f"{k}={cfg_shared[k]}"
320						for k in cfg_keys
321					],
322				)
323				+ ")"
324			)
325		)
326
327		# add title and stuff
328		if not plot_only:
329			ax_.set_xlabel(self.param_key)
330			ax_.set_ylabel(self.analyze_func.__name__)
331			ax_.set_title(
332				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
333			)
334			ax_.grid(True)
335			# ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1))
336			legend_kwargs = {
337				**dict(loc="center left"),
338				**(legend_kwargs or dict()),
339			}
340			ax_.legend(**legend_kwargs)
341
342		# save and show
343		if save_path:
344			plt.savefig(save_path)
345
346		if show:
347			plt.show()
348
349		return ax_

result of a parameter sweep

SweepResult( configs: list[maze_dataset.MazeDatasetConfig], param_values: list[~ParamType], result_values: dict[str, typing.Sequence[~SweepReturnType]], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType])
param_values: list[~ParamType]
result_values: dict[str, typing.Sequence[~SweepReturnType]]
param_key: str
analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType]
def summary( self) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
116	def summary(self) -> JSONitem:
117		"human-readable and json-dumpable short summary of the result"
118		return {
119			"len(configs)": len(self.configs),
120			"len(param_values)": len(self.param_values),
121			"len(result_values)": len(self.result_values),
122			"param_key": self.param_key,
123			"analyze_func": self.analyze_func.__name__,
124		}

human-readable and json-dumpable short summary of the result

def save(self, path: str | pathlib.Path, z: zanj.zanj.ZANJ | None = None) -> None:
126	def save(self, path: str | Path, z: ZANJ | None = None) -> None:
127		"save to a file with zanj"
128		if z is None:
129			z = ZANJ()
130
131		z.save(self, path)

save to a file with zanj

@classmethod
def read( cls, path: str | pathlib.Path, z: zanj.zanj.ZANJ | None = None) -> SweepResult:
133	@classmethod
134	def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult":
135		"read from a file with zanj"
136		if z is None:
137			z = ZANJ()
138
139		return z.read(path)

read from a file with zanj

def configs_by_name( self) -> dict[str, maze_dataset.MazeDatasetConfig]:
141	def configs_by_name(self) -> dict[str, MazeDatasetConfig]:
142		"return configs by name"
143		return {cfg.name: cfg for cfg in self.configs}

return configs by name

def configs_by_key( self) -> dict[str, maze_dataset.MazeDatasetConfig]:
145	def configs_by_key(self) -> dict[str, MazeDatasetConfig]:
146		"return configs by the key used in `result_values`, which is the filename of the config"
147		return {cfg.to_fname(): cfg for cfg in self.configs}

return configs by the key used in result_values, which is the filename of the config

def configs_shared(self) -> dict[str, typing.Any]:
149	def configs_shared(self) -> dict[str, Any]:
150		"return key: value pairs that are shared across all configs"
151		# we know that the configs all have the same keys,
152		# so this way of doing it is fine
153		config_vals: dict[str, set[Any]] = dict()
154		for cfg in self.configs:
155			for k, v in cfg.serialize().items():
156				if k not in config_vals:
157					config_vals[k] = set()
158				config_vals[k].add(json.dumps(v))
159
160		shared_vals: dict[str, Any] = dict()
161
162		cfg_ser: dict = self.configs[0].serialize()
163		for k, v in config_vals.items():
164			if len(v) == 1:
165				shared_vals[k] = cfg_ser[k]
166
167		return shared_vals

return key: value pairs that are shared across all configs

def configs_differing_keys(self) -> set[str]:
169	def configs_differing_keys(self) -> set[str]:
170		"return keys that differ across configs"
171		shared_vals: dict[str, Any] = self.configs_shared()
172		differing_keys: set[str] = set()
173
174		for k in MazeDatasetConfig.__dataclass_fields__:
175			if k not in shared_vals:
176				differing_keys.add(k)
177
178		return differing_keys

return keys that differ across configs

def configs_value_set(self, key: str) -> list[typing.Any]:
180	def configs_value_set(self, key: str) -> list[Any]:
181		"return a list of the unique values for a given key"
182		d: dict[str, Any] = {
183			json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key)
184			for cfg in self.configs
185		}
186
187		return list(d.values())

return a list of the unique values for a given key

def get_where( self, key: str, val_check: Callable[[Any], bool]) -> SweepResult:
189	def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult":
190		"get a subset of this `Result` where the configs has `key` satisfying `val_check`"
191		configs_list: list[MazeDatasetConfig] = [
192			cfg for cfg in self.configs if val_check(getattr(cfg, key))
193		]
194		configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list}
195		result_values: dict[str, Sequence[SweepReturnType]] = {
196			k: self.result_values[k] for k in configs_keys
197		}
198
199		return SweepResult(
200			configs=configs_list,
201			param_values=self.param_values,
202			result_values=result_values,
203			param_key=self.param_key,
204			analyze_func=self.analyze_func,
205		)

get a subset of this Result containing only the configs whose key attribute satisfies val_check

@classmethod
def analyze( cls, configs: list[maze_dataset.MazeDatasetConfig], param_values: list[~ParamType], param_key: str, analyze_func: Callable[[maze_dataset.MazeDatasetConfig], ~SweepReturnType], parallel: bool | int = False, **kwargs) -> SweepResult:
207	@classmethod
208	def analyze(
209		cls,
210		configs: list[MazeDatasetConfig],
211		param_values: list[ParamType],
212		param_key: str,
213		analyze_func: Callable[[MazeDatasetConfig], SweepReturnType],
214		parallel: bool | int = False,
215		**kwargs,
216	) -> "SweepResult":
217		"""Analyze success rate of maze generation for different percolation values
218
219		# Parameters:
220		- `configs : list[MazeDatasetConfig]`
221		configs to try
222		- `param_values : np.ndarray`
223		numpy array of values to try
224
225		# Returns:
226		- `SweepResult`
227		"""
228		n_pvals: int = len(param_values)
229
230		result_values_list: list[float] = run_maybe_parallel(
231			# TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]"  [arg-type]
232			func=functools.partial(  # type: ignore[arg-type]
233				sweep,
234				param_values=param_values,
235				param_key=param_key,
236				analyze_func=analyze_func,
237			),
238			iterable=configs,
239			keep_ordered=True,
240			parallel=parallel,
241			pbar_kwargs=dict(total=len(configs)),
242			**kwargs,
243		)
244		result_values: dict[str, Float[np.ndarray, n_pvals]] = {
245			cfg.to_fname(): np.array(res)
246			for cfg, res in zip(configs, result_values_list, strict=False)
247		}
248		return cls(
249			configs=configs,
250			param_values=param_values,
251			# TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]"  [arg-type]
252			result_values=result_values,  # type: ignore[arg-type]
253			param_key=param_key,
254			analyze_func=analyze_func,
255		)

Run sweep() over param_values of param_key for every config, collecting the results of analyze_func

Parameters:

  • configs : list[MazeDatasetConfig] configs to try
  • param_values : list[ParamType] values of param_key to try, shared by all configs

Returns:

  • SweepResult

def plot( self, save_path: str | None = None, cfg_keys: list[str] | None = None, cmap_name: str | None = 'viridis', plot_only: bool = False, show: bool = True, ax: matplotlib.axes._axes.Axes | None = None, minify_title: bool = False, legend_kwargs: dict[str, typing.Any] | None = None) -> matplotlib.axes._axes.Axes:
257	def plot(
258		self,
259		save_path: str | None = None,
260		cfg_keys: list[str] | None = None,
261		cmap_name: str | None = "viridis",
262		plot_only: bool = False,
263		show: bool = True,
264		ax: plt.Axes | None = None,
265		minify_title: bool = False,
266		legend_kwargs: dict[str, Any] | None = None,
267	) -> plt.Axes:
268		"""Plot the results of percolation analysis"""
269		# set up figure
270		if not ax:
271			fig: plt.Figure
272			ax_: plt.Axes
273			fig, ax_ = plt.subplots(1, 1, figsize=(22, 10))
274		else:
275			ax_ = ax
276
277		# plot
278		cmap = plt.get_cmap(cmap_name)
279		n_cfgs: int = len(self.result_values)
280		for i, (ep_cfg_name, result_values) in enumerate(
281			sorted(
282				self.result_values.items(),
283				# HACK: sort by grid size
284				#                 |--< name of config
285				#                 |    |-----------< gets 'g{n}'
286				#                 |    |            |--< gets '{n}'
287				#                 |    |            |
288				key=lambda x: int(x[0].split("-")[0][1:]),
289			),
290		):
291			ax_.plot(
292				# TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
293				self.param_values,  # type: ignore[arg-type]
294				# TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str"  [arg-type]
295				result_values,  # type: ignore[arg-type]
296				".-",
297				label=self.configs_by_key()[ep_cfg_name].name,
298				color=cmap((i + 0.5) / (n_cfgs - 0.5)),
299			)
300
301		# repr of config
302		cfg_shared: dict = self.configs_shared()
303		if minify_title:
304			cfg_shared["endpoint_kwargs"] = {
305				k: v
306				for k, v in cfg_shared["endpoint_kwargs"].items()
307				if k != "except_on_no_valid_endpoint"
308			}
309		cfg_repr: str = (
310			str(cfg_shared)
311			if cfg_keys is None
312			else (
313				"MazeDatasetConfig("
314				+ ", ".join(
315					[
316						f"{k}={cfg_shared[k].__name__}"
317						# TYPING: error: Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo"  [arg-type]
318						if isinstance(cfg_shared[k], Callable)  # type: ignore[arg-type]
319						else f"{k}={cfg_shared[k]}"
320						for k in cfg_keys
321					],
322				)
323				+ ")"
324			)
325		)
326
327		# add title and stuff
328		if not plot_only:
329			ax_.set_xlabel(self.param_key)
330			ax_.set_ylabel(self.analyze_func.__name__)
331			ax_.set_title(
332				f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}",
333			)
334			ax_.grid(True)
335			# ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1))
336			legend_kwargs = {
337				**dict(loc="center left"),
338				**(legend_kwargs or dict()),
339			}
340			ax_.legend(**legend_kwargs)
341
342		# save and show
343		if save_path:
344			plt.savefig(save_path)
345
346		if show:
347			plt.show()
348
349		return ax_

Plot the results of the parameter sweep: result values against parameter values, one line per config

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
serialize
load
validate_fields_types
validate_field_type
diff
update_from_nested_dict
DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [('any', {'deadend_start': False, 'deadend_end': False, 'except_on_no_valid_endpoint': False}), ('deadends', {'deadend_start': True, 'deadend_end': True, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False}), ('deadends_unique', {'deadend_start': True, 'deadend_end': True, 'endpoints_not_equal': True, 'except_on_no_valid_endpoint': False})]
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
def endpoint_kwargs_to_name(ep_kwargs: dict) -> str:
	"""convert endpoint kwargs options to a human-readable name

	- `"deadends_unique"`: `deadend_start` or `deadend_end` is set, and so is `endpoints_not_equal`
	- `"deadends"`: `deadend_start` or `deadend_end` is set, but `endpoints_not_equal` is not
	- `"any"`: neither `deadend_start` nor `deadend_end` is set
	"""
	wants_deadends: bool = bool(
		ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False),
	)
	if not wants_deadends:
		return "any"
	return "deadends_unique" if ep_kwargs.get("endpoints_not_equal", False) else "deadends"

convert endpoint kwargs options to a human-readable name

def full_percolation_analysis( n_mazes: int, p_val_count: int, grid_sizes: list[int], ep_kwargs: list[tuple[str, dict]] | None = None, generators: Sequence[Callable] = (<function LatticeMazeGenerators.gen_percolation>, <function LatticeMazeGenerators.gen_dfs_percolation>), save_dir: pathlib.Path = PosixPath('../docs/benchmarks/percolation_fractions'), parallel: bool | int = False, **analyze_kwargs) -> SweepResult:
def full_percolation_analysis(
	n_mazes: int,
	p_val_count: int,
	grid_sizes: list[int],
	ep_kwargs: list[tuple[str, dict]] | None = None,
	generators: Sequence[Callable] = (
		LatticeMazeGenerators.gen_percolation,
		LatticeMazeGenerators.gen_dfs_percolation,
	),
	save_dir: Path = Path("../docs/benchmarks/percolation_fractions"),
	parallel: bool | int = False,
	**analyze_kwargs,
) -> SweepResult:
	"""run the full analysis of how percolation affects maze generation success

	builds one config per combination of endpoint kwargs, generator, and grid
	size, sweeps `maze_ctor_kwargs.p` over `p_val_count` evenly spaced values in
	`[0, 1]` using `dataset_success_fraction`, and saves the result to `save_dir`

	# Parameters:
	- `n_mazes : int`
		`n_mazes` for every config
	- `p_val_count : int`
		number of percolation values, evenly spaced over `[0, 1]`
	- `grid_sizes : list[int]`
		grid sizes to try
	- `ep_kwargs : list[tuple[str, dict]] | None`
		`(name, endpoint_kwargs)` pairs; only the kwargs are used here.
		defaults to `DEFAULT_ENDPOINT_KWARGS`
	- `generators : Sequence[Callable]`
		maze generation functions to try
	- `save_dir : Path`
		directory the result is saved to; created if missing. a `str` is also accepted
	- `parallel : bool | int`
		passed to `SweepResult.analyze`
	- `**analyze_kwargs`
		passed through to `SweepResult.analyze`

	# Returns:
	- `SweepResult`
		the result of the analysis, also saved as a `.zanj` file in `save_dir`
	"""
	if ep_kwargs is None:
		ep_kwargs = DEFAULT_ENDPOINT_KWARGS

	# create the output directory up front, so that an unwritable location
	# fails before the (potentially long) analysis rather than after it
	save_dir = Path(save_dir)
	save_dir.mkdir(parents=True, exist_ok=True)

	# one config per (endpoint kwargs, generator, grid size), in that nesting order
	configs: list[MazeDatasetConfig] = [
		MazeDatasetConfig(
			name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}",
			grid_n=grid_n,
			n_mazes=n_mazes,
			maze_ctor=gen_func,
			# placeholder: the sweep over `maze_ctor_kwargs.p` supplies the real values
			maze_ctor_kwargs=dict(p=float("nan")),
			endpoint_kwargs=ep_kw,
		)
		for _, ep_kw in ep_kwargs
		for gen_func in generators
		for grid_n in grid_sizes
	]

	# get results
	result: SweepResult = SweepResult.analyze(
		configs=configs,  # type: ignore[misc]
		# TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]"  [arg-type]
		param_values=np.linspace(0.0, 1.0, p_val_count).tolist(),  # type: ignore[arg-type]
		param_key="maze_ctor_kwargs.p",
		analyze_func=dataset_success_fraction,
		parallel=parallel,
		**analyze_kwargs,
	)

	# save the result
	results_path: Path = (
		save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj"
	)
	print(f"Saving results to {results_path.as_posix()}")
	result.save(results_path)

	return result

run the full analysis of how percolation affects maze generation success

def plot_grouped( results: SweepResult, predict_fn: Optional[Callable[[maze_dataset.MazeDatasetConfig], float]] = None, prediction_density: int = 50, save_dir: pathlib.Path | None = None, show: bool = True, logy: bool = False, save_fmt: str = 'svg', figsize: tuple[int, int] = (22, 10), minify_title: bool = False, legend_kwargs: dict[str, typing.Any] | None = None, manual_titles: dict[typing.Literal['x', 'y', 'title'], str] | None = None) -> None:
def plot_grouped(  # noqa: C901
	results: SweepResult,
	predict_fn: Callable[[MazeDatasetConfig], float] | None = None,
	prediction_density: int = 50,
	save_dir: Path | None = None,
	show: bool = True,
	logy: bool = False,
	save_fmt: str = "svg",
	figsize: tuple[int, int] = (22, 10),
	minify_title: bool = False,
	legend_kwargs: dict[str, Any] | None = None,
	manual_titles: dict[Literal["x", "y", "title"], str] | None = None,
) -> None:
	"""Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs

	with separate colormaps for each maze generator function

	# Parameters:
	- `results : SweepResult`
		The sweep results to plot
	- `predict_fn : Callable[[MazeDatasetConfig], float] | None`
		Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
	- `prediction_density : int`
		Number of points to use for prediction curves (default: 50)
	- `save_dir : Path | None`
		Directory to save plots (defaults to `None`, meaning no saving). A `str` is also accepted.
	- `show : bool`
		Whether to display the plots (defaults to `True`)
	- `logy : bool`
		Whether to use a log scale on the y axis (defaults to `False`)
	- `save_fmt : str`
		File extension of saved plots (defaults to `"svg"`)
	- `figsize : tuple[int, int]`
		Size of each figure (defaults to `(22, 10)`)
	- `minify_title : bool`
		Passed to `SweepResult.plot`
	- `legend_kwargs : dict[str, Any] | None`
		Passed to `SweepResult.plot`
	- `manual_titles : dict[Literal["x", "y", "title"], str] | None`
		Overrides for the x label, y label, and/or title. Keys that are absent keep the automatic value.

	# Usage:
	```python
	>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
	>>> plot_grouped(result, save_dir=Path("./plots"), show=False)
	```
	"""
	# colormap per generator, by position: the first two match the historical
	# Reds/Blues choice, further generators get distinct colormaps (cycling)
	generator_cmap_names: tuple[str, ...] = (
		"Reds",
		"Blues",
		"Greens",
		"Purples",
		"Oranges",
		"Greys",
	)

	# groups
	endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs")  # type: ignore[assignment]
	# `dict.fromkeys` dedupes while keeping first-seen order, so each generator
	# gets the same colormap on every run (set iteration order of strings varies)
	generator_funcs_names: list[str] = list(
		dict.fromkeys(cfg.maze_ctor.__name__ for cfg in results.configs),
	)

	# if predicting, create denser p values
	if predict_fn is not None:
		p_dense = np.linspace(0.0, 1.0, prediction_density)

	# separate plot for each set of endpoint kwargs
	for ep_kw in endpoint_kwargs_set:
		results_epkw: SweepResult = results.get_where(
			"endpoint_kwargs",
			functools.partial(_is_eq, b=ep_kw),
		)
		shared_keys: set[str] = set(results_epkw.configs_shared().keys())
		cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"})
		fig, ax = plt.subplots(1, 1, figsize=figsize)
		for gf_idx, gen_func in enumerate(generator_funcs_names):
			results_filtered: SweepResult = results_epkw.get_where(
				"maze_ctor",
				# HACK: big hassle to do this without a lambda, is it really that bad?
				lambda x: x.__name__ == gen_func,  # noqa: B023
			)
			if len(results_filtered.configs) < 1:
				warnings.warn(
					f"No results for {gen_func} and {ep_kw}. Skipping.",
				)
				continue

			cmap_name = generator_cmap_names[gf_idx % len(generator_cmap_names)]
			cmap = plt.get_cmap(cmap_name)

			# Plot actual results
			ax = results_filtered.plot(
				cfg_keys=list(cfg_keys),
				ax=ax,
				show=False,
				cmap_name=cmap_name,
				minify_title=minify_title,
				legend_kwargs=legend_kwargs,
			)
			if logy:
				ax.set_yscale("log")

			# Plot predictions if function provided
			if predict_fn is not None:
				# `SweepResult.plot` assigns colors in order of the grid size parsed
				# from each config's `to_fname()` key; order the predictions the same
				# way so each dashed line matches the color of its measured line
				cfgs_sorted: list[MazeDatasetConfig] = sorted(
					results_filtered.configs,
					key=lambda c: int(c.to_fname().split("-")[0][1:]),
				)
				n_cfgs: int = len(cfgs_sorted)
				for cfg_idx, cfg in enumerate(cfgs_sorted):
					predictions: list[float] = []
					for p in p_dense:
						# work on a copy rebuilt from `cfg.serialize()` rather than mutating `cfg`
						cfg_temp = MazeDatasetConfig.load(cfg.serialize())
						cfg_temp.maze_ctor_kwargs["p"] = p
						predictions.append(predict_fn(cfg_temp))

					color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5))

					# Plot prediction as dashed line
					ax.plot(p_dense, predictions, "--", color=color, alpha=0.8)

			if manual_titles:
				# only override the labels that were actually provided
				if "x" in manual_titles:
					ax.set_xlabel(manual_titles["x"])
				if "y" in manual_titles:
					ax.set_ylabel(manual_titles["y"])
				if "title" in manual_titles:
					ax.set_title(manual_titles["title"])

		# save and show
		if save_dir:
			save_path: Path = (
				Path(save_dir) / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}"
			)
			print(f"Saving plot to {save_path.as_posix()}")
			save_path.parent.mkdir(exist_ok=True, parents=True)
			# save this endpoint group's figure explicitly, not whichever is current
			fig.savefig(save_path)

		if show:
			plt.show()

Plot grouped sweep percolation value results for each distinct endpoint_kwargs in the configs

with separate colormaps for each maze generator function

Parameters:

  • results : SweepResult The sweep results to plot
  • predict_fn : Callable[[MazeDatasetConfig], float] | None Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines.
  • prediction_density : int Number of points to use for prediction curves (default: 50)
  • save_dir : Path | None Directory to save plots (defaults to None, meaning no saving)
  • show : bool Whether to display the plots (defaults to True)

Usage:

>>> result = full_percolation_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16])
>>> plot_grouped(result, save_dir=Path("./plots"), show=False)