Coverage for maze_dataset/benchmark/sweep_fit.py: 0%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"""Fit a PySR model to a sweep result and plot the results""" 

2 

3from pathlib import Path 

4from typing import TYPE_CHECKING, Callable 

5 

6import numpy as np 

7import sympy as sp # type: ignore[import-untyped] 

8from jaxtyping import Float 

9from pysr import PySRRegressor # type: ignore[import-untyped] 

10 

11from maze_dataset import MazeDatasetConfig 

12from maze_dataset.benchmark.config_sweep import ( 

13 SweepResult, 

14 plot_grouped, 

15) 

16 

17 

def extract_training_data(
    sweep_result: SweepResult,
) -> tuple[Float[np.ndarray, "num_rows 5"], Float[np.ndarray, " num_rows"]]:
    """Build the (X, y) regression dataset from a SweepResult.

    One row is produced per (config, param value) pair: the config's
    `_to_ps_array()` features with element 0 (`p`) replaced by the swept
    parameter value, paired with the success rate recorded for that value.

    # Parameters:
     - `sweep_result : SweepResult`
        The sweep result holding configs and success arrays.

    # Returns:
     - `X : Float[np.ndarray, "num_rows 5"]`
        Stacked [p, grid_n, deadends, endpoints_not_equal, generator_func] for each config & param-value
     - `y : Float[np.ndarray, "num_rows"]`
        The corresponding success rate
    """
    feature_rows: list = []
    targets: list[float] = []
    for cfg in sweep_result.configs:
        # success rates for this config, indexed in step with `param_values`
        rates_for_cfg = sweep_result.result_values[cfg.to_fname()]
        for idx, p_value in enumerate(sweep_result.param_values):
            # copy before overriding `p` so the array returned by
            # `_to_ps_array` itself is never mutated
            row = cfg._to_ps_array().copy()
            row[0] = p_value  # element 0 holds `p`
            feature_rows.append(row)
            targets.append(rates_for_cfg[idx])

    features_arr = np.array(feature_rows, dtype=np.float64)
    targets_arr = np.array(targets, dtype=np.float64)
    return features_arr, targets_arr

46 

47 

# Default keyword arguments for `PySRRegressor`. `train_pysr_model` merges
# caller-supplied kwargs on top of these. `populations` is deliberately not
# set here; the CLI entry point passes its own value.
DEFAULT_PYSR_KWARGS: dict = {
    "niterations": 50,
    # builtin operators plus custom ones defined inline as `name(x) = expr`
    "unary_operators": [
        "exp",
        "log",
        "square(x) = x^2",
        "cube(x) = x^3",
        "sigmoid(x) = 1/(1 + exp(-x))",
    ],
    # sympy equivalents of the custom unary operators declared above
    "extra_sympy_mappings": {
        "square": lambda val: val**2,
        "cube": lambda val: val**3,
        "sigmoid": lambda val: 1 / (1 + sp.exp(-val)),
    },
    "binary_operators": ["+", "-", "*", "/", "^"],
    "progress": True,
    "model_selection": "best",
}

67 

68 

def train_pysr_model(
    data: SweepResult,
    **pysr_kwargs,
) -> PySRRegressor:
    """Train a PySR model on the given sweep result data

    # Parameters:
     - `data : SweepResult`
        sweep result whose (config, param value) -> success rate pairs are
        turned into training data by `extract_training_data`
     - `**pysr_kwargs`
        forwarded to `PySRRegressor`; any key given here overrides the same
        key in `DEFAULT_PYSR_KWARGS`

    # Returns:
     - `PySRRegressor`
        the fitted model
    """
    x, y = extract_training_data(data)
    print(f"training data extracted: {x.shape = }, {y.shape = }")

    # right-hand operand of `|` wins, so caller kwargs override the defaults
    regressor_kwargs: dict = DEFAULT_PYSR_KWARGS | pysr_kwargs
    regressor: PySRRegressor = PySRRegressor(**regressor_kwargs)
    regressor.fit(x, y)
    return regressor

84 

85 

def plot_model(
    data: SweepResult,
    model: PySRRegressor,
    save_dir: Path,
    show: bool = True,
) -> None:
    """Plot the model predictions against the sweep data

    Writes `repr(model)` to `save_dir / "equations.txt"`, then calls
    `plot_grouped` with a predictor built from the model's best equation.

    # Parameters:
     - `data : SweepResult`
        sweep result to plot
     - `model : PySRRegressor`
        fitted PySR model; the equation returned by `model.get_best()` is used
        for predictions
     - `save_dir : Path`
        directory for `equations.txt` (created if missing); also passed to
        `plot_grouped`
     - `show : bool`
        forwarded to `plot_grouped` (defaults to `True`)
    """
    # save all the equations
    save_dir.mkdir(parents=True, exist_ok=True)
    equations_file: Path = save_dir / "equations.txt"
    equations_file.write_text(repr(model))
    print(f"Equations saved to: {equations_file = }")

    # Create a callable that predicts from MazeDatasetConfig
    predict_fn: Callable = model.get_best()["lambda_format"]
    print(f"Best PySR Equation: {model.get_best()['equation'] = }")
    print(f"{predict_fn =}")

    def predict_config(cfg: MazeDatasetConfig) -> float:
        # PySR's `lambda_format` callable takes a 2D (n_samples, n_features)
        # array, the same layout as the float64 training matrix built by
        # `extract_training_data`. Wrap this config's feature vector as a
        # single row instead of passing a bare 1D array.
        row = np.asarray(cfg._to_ps_array(), dtype=np.float64).reshape(1, -1)
        return float(predict_fn(row)[0])

    plot_grouped(
        data,
        predict_fn=predict_config,
        save_dir=save_dir,
        show=show,
    )

115 

116 

def sweep_fit(
    data_path: Path,
    save_dir: Path,
    **pysr_kwargs,
) -> None:
    """read a sweep result, train a PySR model, and plot the results

    # Parameters:
     - `data_path : Path`
        file to load with `SweepResult.read`
     - `save_dir : Path`
        output directory handed to `plot_model`
     - `**pysr_kwargs`
        forwarded to `train_pysr_model`

    Plots are produced with `show=False`.
    """
    data: SweepResult = SweepResult.read(data_path)
    print(f"loaded data: {data.summary() = }")

    # the model is trained (argument evaluation) before plotting starts
    plot_model(
        data,
        train_pysr_model(data, **pysr_kwargs),
        save_dir,
        show=False,
    )

132 

133 

if __name__ == "__main__":
    # CLI: fit a PySR model to a saved sweep result and write the plots
    import argparse

    argparser: argparse.ArgumentParser = argparse.ArgumentParser()
    argparser.add_argument(
        "data_path",
        type=Path,
        help="Path to the sweep result file",
    )
    argparser.add_argument(
        "--save_dir",
        type=Path,
        default=Path("tests/_temp/percolation_fractions/fit_plots/"),
        help="Path to save the plots",
    )
    argparser.add_argument(
        "--niterations",
        type=int,
        # keep the CLI default in sync with the library default
        default=DEFAULT_PYSR_KWARGS["niterations"],
        help="Number of iterations for PySR",
    )
    args: argparse.Namespace = argparser.parse_args()

    sweep_fit(
        args.data_path,
        args.save_dir,
        niterations=args.niterations,
        # extra PySR options used only when running from the CLI; they are
        # merged over DEFAULT_PYSR_KWARGS by `train_pysr_model`
        populations=50,
        # ^ number of populations evolved in parallel
        population_size=50,
        # ^ number of individuals in each population
        timeout_in_seconds=60 * 60 * 7,
        # ^ stop the search after 7 hours have passed
        maxsize=50,
        # ^ maximum equation complexity; allow greater complexity
        weight_randomize=0.01,
        # ^ relative likelihood of the "randomize tree" mutation; intended to
        #   randomize trees more often than PySR's default
        turbo=True,
        # ^ faster evaluation (experimental PySR option)
    )

175 

176 

def create_interactive_plot(heatmap: bool = True) -> None:  # noqa: C901, PLR0915
    """Create an interactive plot with the specified grid layout

    Builds sliders for `x`, `p`, `alpha` and `w`. Whenever one changes, a 2x2
    grid of plots of `soft_step(x, p, alpha, w)` (from
    `maze_dataset.dataset.success_predict_math`) is redrawn:

    - top left: `f(x)` against `x` for the current `p`, `alpha`, `w`
    - top right: `f(x, p)` against `p` for the current `x` and a few reference
      `x` values
    - bottom left (only if `heatmap`): heatmap of `f(x, p)` over `x, p` in `[0, 1]`
    - bottom right (only if `heatmap`): heatmap of `f` over `w` in `[0, 1]` and
      `alpha` in `[0.1, 30]`, for the current `x` and `p`

    Meant to be called inside a Jupyter notebook: the widget is shown with the
    notebook-provided `display` function, which is not imported here.

    # Parameters:
     - `heatmap : bool`
        Whether to show heatmaps (defaults to `True`)
    """
    import ipywidgets as widgets  # type: ignore[import-untyped]
    import matplotlib.pyplot as plt
    from ipywidgets import FloatSlider, HBox, Layout, VBox
    from matplotlib.gridspec import GridSpec

    from maze_dataset.dataset.success_predict_math import soft_step

    # Create sliders with better layout
    x_slider = FloatSlider(
        min=0.0,
        max=1.0,
        step=0.01,
        value=0.5,
        description="x:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    p_slider = FloatSlider(
        min=0.0,
        max=1.0,
        step=0.01,
        value=0.5,
        description="p:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    alpha_slider = FloatSlider(
        min=0.1,
        max=30.0,
        step=0.1,
        value=10.0,
        description="α:",  # noqa: RUF001
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    w_slider = FloatSlider(
        min=0.0,
        max=20,
        # fine step, matching the x/p sliders: the plots show w on [0, 1], and
        # a 0.5 step could not represent the 4/7 default or anything between
        # 0, 0.5 and 1
        step=0.01,
        value=4.0 / 7.0,
        description="w:",
        style={"description_width": "30px"},
        layout=Layout(width="98%"),
    )

    # Slider layout control
    slider_box = VBox(
        [
            widgets.Label("Adjust parameters:"),
            HBox(
                [x_slider, w_slider],
                layout=Layout(width="100%", justify_content="space-between"),
            ),
            HBox(
                [p_slider, alpha_slider],
                layout=Layout(width="100%", justify_content="space-between"),
            ),
        ],
    )

    def update_plot(x: float, p: float, alpha: float, w: float) -> None:  # noqa: PLR0915
        """Update the plot with current slider values

        # Parameters:
         - `x : float`
            x value
         - `p : float`
            p value
         - `alpha : float`
            alpha value
         - `w : float`
            w value
        """
        # Set up the figure and grid - 2x2 grid
        fig = plt.figure(figsize=(14, 10))
        gs = GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

        # Create x and p values focused on [0,1] range
        xs = np.linspace(0.0, 1.0, 500)
        ps = np.linspace(0.0, 1.0, 500)

        # Plot 1: f(x) vs x (top left)
        ax1 = fig.add_subplot(gs[0, 0])
        ys = soft_step(xs, p, alpha, w)
        ax1.plot(xs, ys, "b-", linewidth=2.5)

        # Add guidelines
        ax1.axvline(x=p, color="red", linestyle="--", alpha=0.7, label=f"p = {p:.2f}")
        ax1.axvline(x=w, color="green", linestyle="--", alpha=0.7, label=f"w = {w:.2f}")
        ax1.axvline(x=x, color="blue", linestyle=":", alpha=0.7, label=f"x = {x:.2f}")

        # Add identity line for reference
        ax1.plot(xs, xs, "k--", alpha=0.3, label="f(x) = x")

        ax1.set_xlim(0, 1)
        ax1.set_ylim(0, 1)
        ax1.set_xlabel("x")
        ax1.set_ylabel("f(x)")
        ax1.set_title(f"f(x) with p={p:.2f}, w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
        ax1.grid(True, alpha=0.3)
        ax1.legend(loc="best")

        # Plot 2: f(p) vs p with fixed x (top right)
        ax2 = fig.add_subplot(gs[0, 1])

        # Plot the main curve with current x value
        f_p_values = np.array([soft_step(x, p_val, alpha, w) for p_val in ps])
        ax2.plot(ps, f_p_values, "blue", linewidth=2.5, label=f"x = {x:.2f}")

        # Create additional curves for different x values
        x_values = [0.2, 0.4, 0.6, 0.8]
        colors = ["purple", "orange", "magenta", "green"]

        for x_val, color in zip(x_values, colors, strict=False):
            # Don't draw if too close to current x
            if abs(x_val - x) > 0.05:  # noqa: PLR2004
                f_p_values = np.array(
                    [soft_step(x_val, p_val, alpha, w) for p_val in ps],
                )
                ax2.plot(
                    ps,
                    f_p_values,
                    color=color,
                    linewidth=1.5,
                    alpha=0.4,
                    label=f"x = {x_val}",
                )

        # Add guideline for current p value
        ax2.axvline(x=p, color="red", linestyle="--", alpha=0.7)

        ax2.set_xlim(0, 1)
        ax2.set_ylim(0, 1)
        ax2.set_xlabel("p")
        ax2.set_ylabel("f(x,p)")
        ax2.set_title(f"f(x,p) for fixed x={x:.2f}, w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
        ax2.grid(True, alpha=0.3)
        ax2.legend(loc="best")

        if heatmap:
            # Plot 3: Heatmap of f(x,p) (bottom left)
            ax3 = fig.add_subplot(gs[1, 0])
            X, P = np.meshgrid(xs, ps)  # noqa: N806
            Z = np.zeros_like(X)  # noqa: N806

            # Row i of the grid is p = ps[i] across all xs. soft_step accepts
            # an array of x values (see Plot 1), so fill a whole row per call
            # instead of calling it once per cell.
            for i, p_val in enumerate(ps):
                Z[i, :] = soft_step(xs, p_val, alpha, w)

            c = ax3.pcolormesh(X, P, Z, cmap="viridis", shading="auto")

            # Add current parameter values as lines
            ax3.axhline(y=p, color="red", linestyle="--", label=f"p = {p:.2f}")
            ax3.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}")
            ax3.axvline(x=x, color="blue", linestyle="--", label=f"x = {x:.2f}")

            # Add lines for the reference x values used in the top-right plot
            for x_val, color in zip(x_values, colors, strict=False):
                # Don't draw if too close to current x, magic value is fine
                if abs(x_val - x) > 0.05:  # noqa: PLR2004
                    ax3.axvline(x=x_val, color=color, linestyle=":", alpha=0.4)

            # Mark the specific point corresponding to the current x and p values
            ax3.plot(x, p, "ro", markersize=8)

            # yes we mean to use alpha here (RUF001)
            ax3.set_xlabel("x")
            ax3.set_ylabel("p")
            ax3.set_title(f"f(x,p) heatmap with w={w:.2f}, α={alpha:.1f}")  # noqa: RUF001
            fig.colorbar(c, ax=ax3, label="f(x,p)")

            # Plot 4: heatmap of f as a function of w and alpha for the
            # current x and p (bottom right)
            ax4 = fig.add_subplot(gs[1, 1])

            # Create w and alpha ranges
            ws = np.linspace(0.0, 1.0, 100)
            alphas = np.linspace(0.1, 30.0, 100)

            W, A = np.meshgrid(ws, alphas)  # noqa: N806
            Z_wa = np.zeros_like(W)  # noqa: N806

            # Calculate f(x,p) for all combinations of w and alpha
            for i, alpha_val in enumerate(alphas):
                for j, w_val in enumerate(ws):
                    Z_wa[i, j] = soft_step(x, p, alpha_val, w_val)

            c2 = ax4.pcolormesh(W, A, Z_wa, cmap="plasma", shading="auto")

            # Add current parameter values as lines
            # yes we mean to use alpha here (RUF001)
            ax4.axhline(
                y=alpha,
                color="purple",
                linestyle="--",
                label=f"α = {alpha:.1f}",  # noqa: RUF001
            )
            ax4.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}")

            # Mark the specific point corresponding to the current w and alpha values
            ax4.plot(w, alpha, "ro", markersize=8)

            # yes we mean to use alpha here (RUF001)
            ax4.set_xlabel("w")
            ax4.set_ylabel("α")  # noqa: RUF001
            ax4.set_title(f"f(x,p) heatmap with fixed x={x:.2f}, p={p:.2f}")
            fig.colorbar(c2, ax=ax4, label="f(x,p,w,α)")  # noqa: RUF001

        plt.tight_layout()
        plt.show()

    # Display the interactive widget
    interactive_output = widgets.interactive_output(
        update_plot,
        {"x": x_slider, "p": p_slider, "w": w_slider, "alpha": alpha_slider},
    )

    # we noqa here because we will only call this function inside a notebook
    if not TYPE_CHECKING:
        display(VBox([slider_box, interactive_output]))  # noqa: F821