In [1]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Flag: is the `__vsc_ipynb_file__` marker present in the global namespace?
# NOTE(review): presumably this marker is injected by the VS Code notebook
# kernel, so the flag is False for scripted/CI runs -- confirm.
IN_NOTEBOOK: bool = "__vsc_ipynb_file__" in globals()

# pysr must be imported before torch, otherwise we get:
# UserWarning: torch was imported before juliacall. This may cause a segfault.
# To avoid this, import juliacall before importing torch.
# For updates, see https://github.com/pytorch/pytorch/issues/78829.
#
# The PySR fit is expensive and takes *forever* in CI and tests, so pysr is
# only imported (and the fit only performed) in the notebook itself.
if IN_NOTEBOOK:
    import pysr  # noqa: F401
from zanj import ZANJ
from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import (
SweepResult,
dataset_success_fraction,
full_percolation_analysis,
plot_grouped,
)
from maze_dataset.benchmark.sweep_fit import sweep_fit
# we want to ensure matplotlib is imported at the top and the formatter doesn't remove it
# so that figures are not created during notebook tests
assert plt
Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython
run a basic analysis¶
In [2]:
# One percolation config per grid size. `maze_ctor_kwargs` starts empty;
# presumably the sweep supplies `p` through `param_key` below.
percolation_configs = [
    MazeDatasetConfig(
        name=f"g{grid_n}-perc",
        grid_n=grid_n,
        n_mazes=32,
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={},
        endpoint_kwargs={
            "deadend_start": False,
            "deadend_end": False,
            "endpoints_not_equal": False,
            "except_on_no_valid_endpoint": False,
        },
    )
    for grid_n in (2, 4, 6)
]

# Run the analysis: 16 evenly spaced values of `maze_ctor_kwargs.p` in [0, 1],
# scored with `dataset_success_fraction`, without parallelism.
results: SweepResult = SweepResult.analyze(
    configs=percolation_configs,
    param_values=np.linspace(0.0, 1.0, num=16).tolist(),
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
    parallel=False,
)

# Plot results (nothing is written to disk since `save_path` is None)
results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"])
Processing 3 items: 100%|██████████| 3/3 [00:01<00:00, 2.56it/s]
Out[2]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
check saving/loading¶
In [3]:
# Round trip: write the sweep result to disk, read it back through ZANJ,
# and plot the loaded copy with the same config keys as the original plot.
path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj")
results.save(path)
results_loaded = ZANJ().read(path)
results_loaded.plot(cfg_keys=["n_mazes", "endpoint_kwargs"])
Out[3]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
sweep across all endpoint kwargs and generator funcs¶
In [4]:
# Deliberately tiny run of the full percolation analysis (n_mazes=8,
# p_val_count=11, grid sizes 2 and 4), executed serially; results are saved
# under `save_dir`.
results_sweep: SweepResult = full_percolation_analysis(
    grid_sizes=[2, 4],
    n_mazes=8,
    p_val_count=11,
    save_dir=Path("tests/_temp/dataset_frac_sweep"),
    parallel=False,
)
Processing 12 items: 100%|██████████| 12/12 [00:01<00:00, 8.41it/s]
Saving results to tests/_temp/dataset_frac_sweep/result-n8-c12-p11.zanj
In [5]:
results_medium: SweepResult = SweepResult.read(
"../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj",
# "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj"
)
# magic autoreload
%load_ext autoreload
%autoreload 2
plot_grouped(
results_medium,
predict_fn=lambda x: x.success_fraction_estimate(),
prediction_density=10,
figsize=(10, 4),
# for paper version
# prediction_density=100,
# figsize=(11, 4),
# save_fmt="pdf",
# save_dir=Path("../docs/paper/figures/ep"),
# minify_title=True,
# legend_kwargs=dict(loc="center right", ncol=1, bbox_to_anchor=(1.175, 0.5)),
# manual_titles=dict(
# x="percolation probability $p$",
# y="success fraction",
# title="",
# ),
)
perform a pysr regression on a dataset we load¶
In [6]:
# Directory containing the saved percolation sweep results.
DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/")

# Sweep result to fit. `DATA_PATH_DIR / "..."` produces a `Path`, so the
# variable is annotated as `Path` rather than `str`.
# Uncomment one of the alternatives to fit a different sweep instead.
# DATA_PATH: Path = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj"
DATA_PATH: Path = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj"

# The fit is only run when IN_NOTEBOOK is set, so scripted runs skip it.
if IN_NOTEBOOK:
    sweep_fit(
        DATA_PATH,
        Path("tests/_temp/fit_plots/"),
        niterations=3,
    )
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.13/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off. warnings.warn( Compiling Julia backend...
loaded data: data.summary() = {'len(configs)': 30, 'len(param_values)': 25, 'len(result_values)': 30, 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'}
training data extracted: x.shape = (750, 5), y.shape = (750,)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.13/site-packages/pysr/sr.py:96: UserWarning: You are using the `^` operator, but have not set up `constraints` for it. This may lead to overly complex expressions. One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which will allow arbitrary-complexity base (-1) but only powers such as a constant or variable (1). For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/
warnings.warn(
[ Info: Started!
[ Info: Final population:
[ Info: Results saved to:
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity Loss Score Equation
1 1.151e-01 0.000e+00 y = 0.65076
3 9.782e-02 8.151e-02 y = 0.54285 ^ x₂
4 9.379e-02 4.201e-02 y = 1.3093 - sigmoid(x₂)
5 9.187e-02 2.071e-02 y = (0.67643 ^ x₂) + -0.13356
8 8.157e-02 3.966e-02 y = 0.62299 ^ square(x₂ + (x₀ ^ x₁))
11 7.360e-02 3.426e-02 y = sigmoid((1.9875 - x₂) - cube((x₂ + x₀) - 0.67704))
20 6.255e-02 1.807e-02 y = (sigmoid(x₃ * cube(sigmoid(x₄ + -0.38019) * -1.3509)) ...
^ (((x₀ * 1.6707) ^ x₄) ^ 2.148)) ^ x₂
24 6.118e-02 5.528e-03 y = (sigmoid(x₃ * cube(sigmoid(((x₄ + -0.85588) + x₄) + -0...
.15546) * -1.3509)) ^ (((x₀ * 1.6707) ^ x₄) ^ 2.148)) ^ x₂
26 6.113e-02 4.297e-04 y = (sigmoid(cube(x₃ * (sigmoid(((x₄ + -0.486) + x₄) + -0....
15546) * -1.3509))) ^ ((((x₀ * 1.0115) / 0.60542) ^ x₄) ^ ...
2.148)) ^ x₂
27 6.107e-02 9.486e-04 y = (sigmoid(x₃ * cube(-1.3509 * sigmoid((x₄ + -0.486) + (...
x₄ + log(0.74351))))) ^ (((x₀ * (1.0115 / 0.60542)) ^ x₄) ...
^ 2.148)) ^ x₂
28 6.086e-02 3.459e-03 y = (sigmoid(x₃ * cube(sigmoid((x₄ + -0.486) + (x₄ + (-0.8...
1661 + x₄))) * -1.3509)) ^ ((((x₀ * 1.0115) / 0.60542) ^ x...
₄) ^ 2.148)) ^ x₂
───────────────────────────────────────────────────────────────────────────────────────────────────
Equations saved to: equations_file = PosixPath('tests/_temp/fit_plots/equations.txt')
Best PySR Equation: model.get_best()['equation'] = '0.6229922 ^ square(x2 + (x0 ^ x1))'
predict_fn =PySRFunction(X=>0.6229922**((x0**x1 + x2)**2))
Saving plot to tests/_temp/fit_plots/ep_any.svg
Saving plot to tests/_temp/fit_plots/ep_deadends.svg
Saving plot to tests/_temp/fit_plots/ep_deadends_unique.svg
- outputs/20251012_223100_eKCAEL/hall_of_fame.csv
interactive plots for figuring out maze_dataset.math.soft_step()¶
In [7]:
# Show the interactive plot only when IN_NOTEBOOK is set; the import lives
# inside the guard so it is skipped entirely otherwise.
if IN_NOTEBOOK:
    from maze_dataset.benchmark.sweep_fit import create_interactive_plot

    create_interactive_plot(True)
VBox(children=(VBox(children=(Label(value='Adjust parameters:'), HBox(children=(FloatSlider(value=0.5, descrip…
In [8]:
# Config with endpoint constraints and `except_on_no_valid_endpoint=False`.
cfg = MazeDatasetConfig(
    name="test",
    seed=3,
    grid_n=5,
    n_mazes=10,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs={"p": 0.7},
    endpoint_kwargs={
        "deadend_start": True,
        # "deadend_end": True,
        "endpoints_not_equal": True,
        "except_on_no_valid_endpoint": False,
    },
)

# Print the estimated success fraction for this config.
print(f"{cfg.success_fraction_estimate() = }")

# `success_fraction_compensate()` returns a new config; in the recorded run
# its `n_mazes` is 27, up from the 10 requested above.
cfg_new = cfg.success_fraction_compensate()
print(f"{cfg_new.n_mazes = }")
/home/miv/projects/mazes/maze-dataset/maze_dataset/dataset/dataset.py:95: UserWarning: in GPTDatasetConfig self.name='test', self.seed=3 is trying to override GLOBAL_SEED = 42 warnings.warn(
cfg.success_fraction_estimate() = np.float64(0.45037493871086454) cfg_new.n_mazes = 27
In [9]:
# Build the dataset from the compensated config and show how many mazes it holds.
dataset_compensated = MazeDataset.from_config(cfg_new)
len(dataset_compensated)
Out[9]:
12