Coverage for maze_dataset/benchmark/speed.py: 0%
38 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
1"benchmark the speed of maze generation"
3import functools
4import random
5import timeit
6from pathlib import Path
7from typing import Any, Sequence
9from tqdm import tqdm
11from maze_dataset import MazeDataset, MazeDatasetConfig
12from maze_dataset.generation.default_generators import DEFAULT_GENERATORS
13from maze_dataset.generation.generators import GENERATORS_MAP
15_BASE_CFG_KWARGS: dict = dict(
16 grid_n=None,
17 n_mazes=None,
18)
20_GENERATE_KWARGS: dict = dict(
21 gen_parallel=False,
22 pool_kwargs=None,
23 verbose=False,
24 # do_generate = True,
25 # load_local = False,
26 # save_local = False,
27 # zanj = None,
28 # do_download = False,
29 # local_base_path = "INVALID",
30 # except_on_config_mismatch = True,
31 # verbose = False,
32)
def time_generation(
    base_configs: list[tuple[str, dict]],
    grid_n_vals: list[int],
    n_mazes_vals: list[int],
    trials: int = 10,
    verbose: bool = False,
) -> list[dict[str, Any]]:
    """Time maze generation over the cartesian product of the given parameters.

    # Parameters:
    - `base_configs`: pairs of (generator name, a key into `GENERATORS_MAP`;
      kwargs dict passed to that generator)
    - `grid_n_vals`: grid sizes to benchmark
    - `n_mazes_vals`: dataset sizes (number of mazes) to benchmark
    - `trials`: timed repetitions per config; the reported time is the mean
    - `verbose`: print per-config progress instead of showing a tqdm bar

    # Returns:
    one record per config, containing the config parameters, `trials`,
    and the mean generation time in seconds under key `"time"`
    """
    # assemble one config per (generator, grid_n, n_mazes) combination
    configs: list[MazeDatasetConfig] = [
        MazeDatasetConfig(
            name="benchmark",
            grid_n=grid_n,
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[gen_name],
            maze_ctor_kwargs=gen_kwargs,
        )
        for gen_name, gen_kwargs in base_configs
        for grid_n in grid_n_vals
        for n_mazes in n_mazes_vals
    ]

    # shuffle configs (in place) so the slow configs don't cluster,
    # which would make the progress bar's ETA misleading
    random.shuffle(configs)

    # time generation for each config
    times: list[dict[str, Any]] = []
    total: int = len(configs)
    for idx, cfg in tqdm(
        enumerate(configs),
        desc="Timing generation",
        unit="config",
        total=total,
        disable=verbose,
    ):
        if verbose:
            print(f"Timing generation for config {idx + 1}/{total}\n{cfg}")

        # mean wall-clock time over `trials` runs of MazeDataset.generate
        t: float = (
            timeit.timeit(
                stmt=functools.partial(MazeDataset.generate, cfg, **_GENERATE_KWARGS),  # type: ignore[arg-type]
                number=trials,
            )
            / trials
        )

        if verbose:
            print(f"avg time: {t:.3f} s")

        times.append(
            dict(
                cfg_name=cfg.name,
                grid_n=cfg.grid_n,
                n_mazes=cfg.n_mazes,
                maze_ctor=cfg.maze_ctor.__name__,
                maze_ctor_kwargs=cfg.maze_ctor_kwargs,
                trials=trials,
                time=t,
            ),
        )

    return times
def run_benchmark(
    save_path: str,
    base_configs: list[tuple[str, dict]] | None = None,
    grid_n_vals: Sequence[int] = (2, 3, 4, 5, 8, 10, 16, 25, 32),
    n_mazes_vals: Sequence[int] = tuple(range(1, 12, 2)),
    trials: int = 10,
    verbose: bool = True,
) -> "pd.DataFrame":  # type: ignore[name-defined] # noqa: F821
    "run the benchmark and save the results to a file"
    # imported lazily so pandas is only needed when actually benchmarking
    import pandas as pd

    if base_configs is None:
        base_configs = DEFAULT_GENERATORS

    timing_records: list[dict] = time_generation(
        base_configs=base_configs,
        grid_n_vals=list(grid_n_vals),
        n_mazes_vals=list(n_mazes_vals),
        trials=trials,
        verbose=verbose,
    )

    results_df: pd.DataFrame = pd.DataFrame(timing_records)

    # dump the full results to console as csv for quick inspection
    print(results_df.to_csv())

    # write line-delimited json, creating parent directories as needed
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    results_df.to_json(save_path, orient="records", lines=True)

    return results_df