Coverage for maze_dataset/benchmark/sweep_fit.py: 0%
139 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
1"""Fit a PySR model to a sweep result and plot the results"""
3from pathlib import Path
4from typing import TYPE_CHECKING, Callable
6import numpy as np
7import sympy as sp # type: ignore[import-untyped]
8from jaxtyping import Float
9from pysr import PySRRegressor # type: ignore[import-untyped]
11from maze_dataset import MazeDatasetConfig
12from maze_dataset.benchmark.config_sweep import (
13 SweepResult,
14 plot_grouped,
15)
def extract_training_data(
    sweep_result: SweepResult,
) -> tuple[Float[np.ndarray, "num_rows 5"], Float[np.ndarray, " num_rows"]]:
    """Flatten a `SweepResult` into a regression dataset `(X, y)`.

    Every config is paired with every swept parameter value, producing one row
    per pair: the config's `_to_ps_array()` with slot 0 overwritten by that
    parameter value, together with the matching entry of the config's results.

    # Parameters:
     - `sweep_result : SweepResult`
        holds the configs, the swept `param_values`, and `result_values`
        keyed by `cfg.to_fname()`

    # Returns:
     - `X : Float[np.ndarray, "num_rows 5"]`
        stacked [p, grid_n, deadends, endpoints_not_equal, generator_func] rows,
        one per (config, param value) pair, as float64
     - `y : Float[np.ndarray, "num_rows"]`
        the corresponding success rates, as float64
    """
    rows: list[np.ndarray] = []
    rates: list[float] = []
    for config in sweep_result.configs:
        # success rates for this config, indexed in step with `param_values`
        config_rates = sweep_result.result_values[config.to_fname()]
        for param_idx, param_value in enumerate(sweep_result.param_values):
            # work on a copy so the override never touches the config's own array
            row = config._to_ps_array().copy()
            row[0] = param_value  # slot 0 holds `p`
            rows.append(row)
            rates.append(config_rates[param_idx])

    features = np.array(rows, dtype=np.float64)
    targets = np.array(rates, dtype=np.float64)
    return features, targets
# Baseline `PySRRegressor` settings used by `train_pysr_model`; keyword
# arguments passed to that function take precedence over these.
# `populations` is intentionally not set here (the CLI entry point supplies it).
DEFAULT_PYSR_KWARGS: dict = {
    "niterations": 50,
    # custom operators are declared as Julia-style definitions inside the strings
    "unary_operators": [
        "exp",
        "log",
        "square(x) = x^2",
        "cube(x) = x^3",
        "sigmoid(x) = 1/(1 + exp(-x))",
    ],
    # sympy equivalents of the custom unary operators declared above
    "extra_sympy_mappings": {
        "square": lambda x: x**2,
        "cube": lambda x: x**3,
        "sigmoid": lambda x: 1 / (1 + sp.exp(-x)),
    },
    "binary_operators": ["+", "-", "*", "/", "^"],
    "progress": True,
    "model_selection": "best",
}
def train_pysr_model(
    data: SweepResult,
    **pysr_kwargs,
) -> PySRRegressor:
    """Fit a `PySRRegressor` to the success rates stored in a sweep result.

    # Parameters:
     - `data : SweepResult`
        sweep to learn from; flattened into `(X, y)` by `extract_training_data`
     - `**pysr_kwargs`
        extra `PySRRegressor` arguments; these override `DEFAULT_PYSR_KWARGS`

    # Returns:
     - `PySRRegressor`
        the fitted model
    """
    x, y = extract_training_data(data)
    print(f"training data extracted: {x.shape = }, {y.shape = }")

    # right-hand operand wins on key collisions: caller settings beat the defaults
    regressor_kwargs: dict = DEFAULT_PYSR_KWARGS | pysr_kwargs
    model: PySRRegressor = PySRRegressor(**regressor_kwargs)
    model.fit(x, y)
    return model
def plot_model(
    data: SweepResult,
    model: PySRRegressor,
    save_dir: Path,
    show: bool = True,
) -> None:
    """Save the fitted equations and plot model predictions against the sweep data.

    # Parameters:
     - `data : SweepResult`
        sweep results to plot; passed through to `plot_grouped`
     - `model : PySRRegressor`
        fitted model; the equation chosen by `model.get_best()` is used for predictions
     - `save_dir : Path`
        directory for `equations.txt` and the plots; created if it does not exist
     - `show : bool`
        forwarded to `plot_grouped` (defaults to `True`)
    """
    # save all the equations
    save_dir.mkdir(parents=True, exist_ok=True)
    equations_file: Path = save_dir / "equations.txt"
    # explicit encoding so the output does not depend on the platform default
    equations_file.write_text(repr(model), encoding="utf-8")
    print(f"Equations saved to: {equations_file = }")

    # `lambda_format` evaluates the best equation on a 2D feature matrix of
    # shape (n_samples, n_features) -- the same layout `model.fit` received
    predict_fn: Callable = model.get_best()["lambda_format"]
    print(f"Best PySR Equation: {model.get_best()['equation'] = }")
    print(f"{predict_fn =}")

    def predict_config(cfg: MazeDatasetConfig) -> float:
        # build a single-row feature matrix with the float64 dtype used for
        # training (see `extract_training_data`), rather than relying on a
        # 1D array happening to broadcast
        features = np.asarray(cfg._to_ps_array(), dtype=np.float64).reshape(1, -1)
        return float(predict_fn(features)[0])

    plot_grouped(
        data,
        predict_fn=predict_config,
        save_dir=save_dir,
        show=show,
    )
def sweep_fit(
    data_path: Path,
    save_dir: Path,
    **pysr_kwargs,
) -> None:
    """Load a sweep result from disk, fit a PySR model to it, and save the plots.

    # Parameters:
     - `data_path : Path`
        file to load with `SweepResult.read`
     - `save_dir : Path`
        directory where `plot_model` writes the equations and plots
     - `**pysr_kwargs`
        forwarded to `train_pysr_model` (overrides for `DEFAULT_PYSR_KWARGS`)
    """
    data: SweepResult = SweepResult.read(data_path)
    print(f"loaded data: {data.summary() = }")

    # fit first, then plot without opening interactive windows
    plot_model(
        data,
        train_pysr_model(data, **pysr_kwargs),
        save_dir,
        show=False,
    )
if __name__ == "__main__":
    import argparse

    # the module docstring doubles as the `--help` description
    argparser: argparse.ArgumentParser = argparse.ArgumentParser(description=__doc__)
    argparser.add_argument(
        "data_path",
        type=Path,
        help="Path to the sweep result file",
    )
    argparser.add_argument(
        "--save_dir",
        type=Path,
        default=Path("tests/_temp/percolation_fractions/fit_plots/"),
        help="Path to save the plots",
    )
    argparser.add_argument(
        "--niterations",
        type=int,
        default=50,
        help="Number of iterations for PySR",
    )
    args: argparse.Namespace = argparser.parse_args()

    # the keyword arguments below are forwarded to `PySRRegressor` and take
    # precedence over `DEFAULT_PYSR_KWARGS`; add more here for CLI runs
    sweep_fit(
        args.data_path,
        args.save_dir,
        niterations=args.niterations,
        populations=50,
        # ^ number of independent populations being evolved
        population_size=50,
        # ^ number of individuals in each population
        timeout_in_seconds=60 * 60 * 7,
        # ^ stop after 7 hours have passed
        maxsize=50,
        # ^ maximum equation complexity; allows more complex equations
        weight_randomize=0.01,
        # ^ relative likelihood of the mutation that replaces a tree with a random one
        turbo=True,
        # ^ faster evaluation (experimental)
    )
def create_interactive_plot(heatmap: bool = True) -> None:  # noqa: C901, PLR0915
    """Create an interactive plot of `soft_step` with the specified grid layout

    Meant to be called from a Jupyter notebook: four sliders (`x`, `p`, `alpha`,
    `w`) drive a figure that is redrawn through `ipywidgets.interactive_output`.

    Panels:
    - top left: `f(x)` vs `x` for the current `p`, `w`, `alpha`
    - top right: `f(x, p)` vs `p` for the current `x` and a few reference `x` values
    - bottom left (only if `heatmap`): `f(x, p)` over the unit square
    - bottom right (only if `heatmap`): `f(x, p)` over a grid of `w` and `alpha`

    # Parameters:
     - `heatmap : bool`
        Whether to show heatmaps (defaults to `True`). When `False`, the figure
        contains only the top row of panels.
    """
    import ipywidgets as widgets  # type: ignore[import-untyped]
    import matplotlib.pyplot as plt
    from ipywidgets import FloatSlider, HBox, Layout, VBox
    from matplotlib.gridspec import GridSpec

    from maze_dataset.dataset.success_predict_math import soft_step

    # Create sliders with better layout
    x_slider = FloatSlider(
        min=0.0,
        max=1.0,
        step=0.01,
        value=0.5,
        description="x:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    p_slider = FloatSlider(
        min=0.0,
        max=1.0,
        step=0.01,
        value=0.5,
        description="p:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    alpha_slider = FloatSlider(
        min=0.1,
        max=30.0,
        step=0.1,
        value=10.0,
        description="α:",  # noqa: RUF001
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    # NOTE(review): the initial value 4/7 is not a multiple of `step`, and `max`
    # (20) lies far outside the 0..1 w-range of the bottom-right heatmap and
    # the 0..1 x-axis of the top-left panel -- confirm this range is intended
    w_slider = FloatSlider(
        min=0.0,
        max=20,
        step=0.5,
        value=4.0 / 7.0,
        description="w:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    # Slider layout control
    slider_box = VBox(
        [
            widgets.Label("Adjust parameters:"),
            HBox(
                [x_slider, w_slider],
                layout=Layout(width="100%", justify_content="space-between"),
            ),
            HBox(
                [p_slider, alpha_slider],
                layout=Layout(width="100%", justify_content="space-between"),
            ),
        ],
    )

    def update_plot(x: float, p: float, alpha: float, w: float) -> None:  # noqa: PLR0915
        """Redraw the whole figure for the current slider values

        # Parameters:
         - `x : float`
            x value
         - `p : float`
            p value
         - `alpha : float`
            alpha value
         - `w : float`
            w value
        """
        # 2x2 grid when heatmaps are shown, otherwise a single row of two panels
        n_rows: int = 2 if heatmap else 1
        fig = plt.figure(figsize=(14, 5 * n_rows))
        gs = GridSpec(n_rows, 2, height_ratios=[1] * n_rows, width_ratios=[1, 1])

        # Create x and p values focused on [0,1] range
        xs = np.linspace(0.0, 1.0, 500)
        ps = np.linspace(0.0, 1.0, 500)

        # reference x values for the top-right curves and the x/p heatmap lines;
        # a reference is skipped when it is within `ref_min_gap` of the current x
        x_values = [0.2, 0.4, 0.6, 0.8]
        colors = ["purple", "orange", "magenta", "green"]
        ref_min_gap: float = 0.05

        # Plot 1: f(x) vs x (top left)
        ax1 = fig.add_subplot(gs[0, 0])
        ys = soft_step(xs, p, alpha, w)
        ax1.plot(xs, ys, "b-", linewidth=2.5)

        # Add guidelines
        ax1.axvline(x=p, color="red", linestyle="--", alpha=0.7, label=f"p = {p:.2f}")
        ax1.axvline(x=w, color="green", linestyle="--", alpha=0.7, label=f"w = {w:.2f}")
        ax1.axvline(x=x, color="blue", linestyle=":", alpha=0.7, label=f"x = {x:.2f}")

        # Add identity line for reference
        ax1.plot(xs, xs, "k--", alpha=0.3, label="f(x) = x")

        ax1.set_xlim(0, 1)
        ax1.set_ylim(0, 1)
        ax1.set_xlabel("x")
        ax1.set_ylabel("f(x)")
        ax1.set_title(f"f(x) with p={p:.2f}, w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
        ax1.grid(True, alpha=0.3)
        ax1.legend(loc="best")

        # Plot 2: f(p) vs p with fixed x (top right)
        ax2 = fig.add_subplot(gs[0, 1])

        # `p` is evaluated one value at a time: only the `x` argument of
        # `soft_step` is passed as an array in this function
        f_p_values = np.array([soft_step(x, p_val, alpha, w) for p_val in ps])
        ax2.plot(ps, f_p_values, "blue", linewidth=2.5, label=f"x = {x:.2f}")

        # Create additional curves for different x values
        for x_val, color in zip(x_values, colors, strict=False):
            # Don't draw if too close to current x
            if abs(x_val - x) > ref_min_gap:
                f_p_values = np.array(
                    [soft_step(x_val, p_val, alpha, w) for p_val in ps],
                )
                ax2.plot(
                    ps,
                    f_p_values,
                    color=color,
                    linewidth=1.5,
                    alpha=0.4,
                    label=f"x = {x_val}",
                )

        # Add guideline for current p value
        ax2.axvline(x=p, color="red", linestyle="--", alpha=0.7)

        ax2.set_xlim(0, 1)
        ax2.set_ylim(0, 1)
        ax2.set_xlabel("p")
        ax2.set_ylabel("f(x,p)")
        ax2.set_title(f"f(x,p) for fixed x={x:.2f}, w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
        ax2.grid(True, alpha=0.3)
        ax2.legend(loc="best")

        if heatmap:
            # Plot 3: Heatmap of f(x,p) (bottom left)
            ax3 = fig.add_subplot(gs[1, 0])
            # meshgrid(xs, ps): rows index p, columns index x
            X, P = np.meshgrid(xs, ps)  # noqa: N806
            Z = np.zeros_like(X)  # noqa: N806

            # `soft_step` already takes an array of x values (see Plot 1), so
            # fill one constant-p row per call instead of one point per call
            for i, p_val in enumerate(ps):
                Z[i, :] = soft_step(xs, p_val, alpha, w)

            c = ax3.pcolormesh(X, P, Z, cmap="viridis", shading="auto")

            # Add current parameter values as lines
            ax3.axhline(y=p, color="red", linestyle="--", label=f"p = {p:.2f}")
            ax3.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}")
            ax3.axvline(x=x, color="blue", linestyle="--", label=f"x = {x:.2f}")

            # Add lines for the reference x values used in the top-right plot
            for x_val, color in zip(x_values, colors, strict=False):
                if abs(x_val - x) > ref_min_gap:
                    ax3.axvline(x=x_val, color=color, linestyle=":", alpha=0.4)

            # Mark the specific point corresponding to the current x and p values
            ax3.plot(x, p, "ro", markersize=8)

            ax3.set_xlabel("x")
            ax3.set_ylabel("p")
            # yes we mean to use alpha here (RUF001)
            ax3.set_title(f"f(x,p) heatmap with w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
            fig.colorbar(c, ax=ax3, label="f(x,p)")

            # Plot 4: heatmap of f(x,p) as a function of w and alpha (bottom right)
            ax4 = fig.add_subplot(gs[1, 1])

            # w and alpha ranges for the heatmap axes
            ws = np.linspace(0.0, 1.0, 100)
            alphas = np.linspace(0.1, 30.0, 100)

            W, A = np.meshgrid(ws, alphas)  # noqa: N806
            Z_wa = np.zeros_like(W)  # noqa: N806

            # f(x,p) at the current x and p for every (w, alpha) pair; kept
            # point-by-point since alpha and w are only ever passed as scalars
            for i, alpha_val in enumerate(alphas):
                for j, w_val in enumerate(ws):
                    Z_wa[i, j] = soft_step(x, p, alpha_val, w_val)

            c2 = ax4.pcolormesh(W, A, Z_wa, cmap="plasma", shading="auto")

            # Add current parameter values as lines
            # yes we mean to use alpha here (RUF001)
            ax4.axhline(
                y=alpha,
                color="purple",
                linestyle="--",
                label=f"α = {alpha:.1f}",  # noqa: RUF001
            )
            ax4.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}")

            # Mark the specific point corresponding to the current w and alpha values
            ax4.plot(w, alpha, "ro", markersize=8)

            ax4.set_xlabel("w")
            ax4.set_ylabel("α")  # noqa: RUF001
            ax4.set_title(f"f(x,p) heatmap with fixed x={x:.2f}, p={p:.2f}")
            fig.colorbar(c2, ax=ax4, label="f(x,p,w,α)")  # noqa: RUF001

        plt.tight_layout()
        plt.show()

    # re-run `update_plot` whenever any slider changes
    interactive_output = widgets.interactive_output(
        update_plot,
        {"x": x_slider, "p": p_slider, "w": w_slider, "alpha": alpha_slider},
    )

    # `display` comes from the IPython/Jupyter namespace; this function is only
    # meant to be called inside a notebook, hence the guard and the noqa
    if not TYPE_CHECKING:
        display(VBox([slider_box, interactive_output]))  # noqa: F821