In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

IN_NOTEBOOK: bool = "__vsc_ipynb_file__" in globals()

# import pysr before torch to avoid
# UserWarning: torch was imported before juliacall. This may cause a segfault. To avoid this, import juliacall before importing torch. For updates, see https://github.com/pytorch/pytorch/issues/78829.
# doing the PySR fit is expensive and takes *forever* in CI and tests, so we only do it in the notebook itself
if IN_NOTEBOOK:
    import pysr  # noqa: F401

from zanj import ZANJ

from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import (
    SweepResult,
    dataset_success_fraction,
    full_percolation_analysis,
    plot_grouped,
)
from maze_dataset.benchmark.sweep_fit import sweep_fit

# keep matplotlib imported at the top (the assert stops the formatter from removing it as unused),
# since figures are not created during notebook tests
assert plt
Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython
run a basic analysis¶
In [2]:
# Run the analysis
results: SweepResult = SweepResult.analyze(
    configs=[
        MazeDatasetConfig(
            name=f"g{grid_n}-perc",
            grid_n=grid_n,
            n_mazes=32,
            maze_ctor=LatticeMazeGenerators.gen_percolation,
            maze_ctor_kwargs=dict(),
            endpoint_kwargs=dict(
                deadend_start=False,
                deadend_end=False,
                endpoints_not_equal=False,
                except_on_no_valid_endpoint=False,
            ),
        )
        for grid_n in [2, 4, 6]
    ],
    param_values=np.linspace(0.0, 1.0, 16).tolist(),
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
    parallel=False,
)

# Plot results
results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"])
Processing 3 items: 100%|██████████| 3/3 [00:02<00:00, 1.26it/s]
Out[2]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
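The "success fraction" plotted here is the fraction of requested mazes that could actually be generated for a given config and percolation probability. A minimal sketch of what a single point on that curve measures, assuming dataset_success_fraction is essentially len(dataset) / cfg.n_mazes (the real helper lives in maze_dataset.benchmark.config_sweep; the config values below are illustrative):

# hedged sketch: manually estimate the success fraction for one (grid_n, p) pair
manual_cfg = MazeDatasetConfig(
    name="manual-check",
    grid_n=4,
    n_mazes=32,
    maze_ctor=LatticeMazeGenerators.gen_percolation,
    maze_ctor_kwargs=dict(p=0.5),
    endpoint_kwargs=dict(
        deadend_start=False,
        deadend_end=False,
        endpoints_not_equal=False,
        except_on_no_valid_endpoint=False,
    ),
)
manual_dataset = MazeDataset.from_config(manual_cfg)
# fraction of the requested 32 mazes that survived generation/filtering
print(f"manual success fraction ~= {len(manual_dataset) / manual_cfg.n_mazes:.2f}")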
check saving/loading¶
In [3]:
path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj")
results.save(path)
ZANJ().read(path).plot(cfg_keys=["n_mazes", "endpoint_kwargs"])
Out[3]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
sweep across all endpoint kwargs and generator funcs¶
In [4]:
results_sweep: SweepResult = full_percolation_analysis(
    n_mazes=8,
    p_val_count=11,
    grid_sizes=[2, 4],
    parallel=False,
    save_dir=Path("tests/_temp/dataset_frac_sweep"),
)
Processing 12 items: 100%|██████████| 12/12 [00:02<00:00, 4.14it/s]
Saving results to tests/_temp/dataset_frac_sweep/result-n8-c12-p11.zanj
In [5]:
results_medium: SweepResult = SweepResult.read(
    "../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj",
    # "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj"
)
plot_grouped(
    results_medium,
    predict_fn=lambda x: x.success_fraction_estimate(),
    prediction_density=10,
    figsize=(10, 4),
    # for paper version
    # prediction_density=100,
    # figsize=(11, 4),
    # save_fmt="pdf",
    # save_dir=Path("../docs/paper/figures/ep"),
    # minify_title=True,
    # legend_kwargs=dict(loc="center right", ncol=1, bbox_to_anchor=(1.14, 0.5)),
    # manual_titles=dict(
    #     x="percolation probability $p$",
    #     y="success fraction",
    #     title="",
    # ),
)
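plot_grouped evaluates predict_fn per config, overlaying the analytic estimate from MazeDatasetConfig.success_fraction_estimate() on the measured fractions. A minimal sketch of what that overlay does for a single config family (grid size, p-grid, and name below are illustrative, not taken from the loaded results):

# hedged sketch: evaluate the analytic success-fraction estimate across p
p_grid = np.linspace(0.0, 1.0, 10)
estimates = [
    MazeDatasetConfig(
        name="estimate-sketch",
        grid_n=4,
        n_mazes=1,  # presumably irrelevant to the per-maze estimate
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs=dict(p=float(p)),
        endpoint_kwargs=dict(
            deadend_start=False,
            deadend_end=False,
            endpoints_not_equal=False,
            except_on_no_valid_endpoint=False,
        ),
    ).success_fraction_estimate()
    for p in p_grid
]
plt.plot(p_grid, estimates)
plt.xlabel("maze_ctor_kwargs.p")
plt.ylabel("estimated success fraction")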
perform a pysr regression on a dataset we load¶
In [6]:
DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/")
# DATA_PATH: Path = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj"
DATA_PATH: Path = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj"

if IN_NOTEBOOK:
    sweep_fit(
        DATA_PATH,
        Path("tests/_temp/fit_plots/"),
        niterations=3,
    )
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.13/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.
  warnings.warn(
Compiling Julia backend...
loaded data: data.summary() = {'len(configs)': 30, 'len(param_values)': 25, 'len(result_values)': 30, 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'}
training data extracted: x.shape = (750, 5), y.shape = (750,)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.13/site-packages/pysr/sr.py:96: UserWarning: You are using the `^` operator, but have not set up `constraints` for it. This may lead to overly complex expressions. One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which will allow arbitrary-complexity base (-1) but only powers such as a constant or variable (1). For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/
  warnings.warn(
[ Info: Started!
[ Info: Final population:
[ Info: Results saved to:
  - outputs/20250729_224811_XuajNs/hall_of_fame.csv

Complexity  Loss       Score      Equation
1           1.151e-01  0.000e+00  y = 0.65071
3           9.782e-02  8.151e-02  y = 0.54285 ^ x₂
4           9.377e-02  4.223e-02  y = 1.3048 - sigmoid(x₂)
6           9.187e-02  1.024e-02  y = sigmoid(-0.16822 ^ x₂) * 1.1852
8           8.795e-02  2.183e-02  y = sigmoid((x₀ - x₄) ^ (x₂ + x₂))
9           8.474e-02  3.720e-02  y = (sigmoid(square(x₄ - x₀)) ^ x₂) + -0.11459
10          8.363e-02  1.310e-02  y = (sigmoid(-0.92301 * (x₀ + -0.92301)) ^ x₂) / 1.105
11          7.360e-02  1.278e-01  y = sigmoid((1.9875 - x₂) - cube((x₂ + x₀) - 0.67704))
15          6.575e-02  2.819e-02  y = sigmoid(((x₂ * square(x₀)) * (-1.0812 - x₄)) + exp(x₄ - square(x₂)))
17          6.548e-02  2.018e-03  y = sigmoid((((-1.2604 - x₄) * x₂) * square(x₀)) + exp(square(x₄) - square(cube(x₂))))
18          6.500e-02  7.371e-03  y = sigmoid((((-1.5774 - x₄) * square(x₄)) * square(x₀)) + exp(square(x₄) - square(cube(x₂))))
21          6.243e-02  1.346e-02  y = (sigmoid(x₄ * cube((x₃ * -1.3509) * x₃)) ^ (((x₀ * (1.0115 / 0.60542)) ^ x₄) ^ 2.148)) ^ x₂
23          6.084e-02  1.294e-02  y = (sigmoid((x₄ * sigmoid(x₃ * (x₄ + 0.42308))) * (x₃ * -1.3509)) ^ (((x₀ * 1.6707) ^ x₄) ^ 2.148)) ^ x₂
24          6.072e-02  1.920e-03  y = (sigmoid(cube((x₄ * sigmoid(x₃ * (x₂ + 0.43005))) * (x₃ * -1.3774))) ^ (((x₀ * 1.6489) ^ x₄) ^ 2.1599)) ^ x₂
29          5.858e-02  7.183e-03  y = (sigmoid(cube(x₃ * (x₄ + (((x₄ + x₂) + -0.68531) * -1.5203)))) ^ ((x₀ * (x₂ / 0.60003)) ^ (x₄ / 0.45258))) ^ ((x₄ ^ x₄) ^ x₂)

Equations saved to: equations_file = PosixPath('tests/_temp/fit_plots/equations.txt')
Best PySR Equation: model.get_best()['equation'] = 'sigmoid((1.9875373 - x2) - cube((x2 + x0) - 0.6770367))'
predict_fn = PySRFunction(X=>1/(0.137032480131706*exp(x2 + (x0 + x2 - 1*0.6770367)**3) + 1))
Saving plot to tests/_temp/fit_plots/ep_any.svg
Saving plot to tests/_temp/fit_plots/ep_deadends.svg
Saving plot to tests/_temp/fit_plots/ep_deadends_unique.svg
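The best equation reported above can be sanity-checked directly with numpy. A minimal sketch, using the expression exactly as printed by PySR (the meaning of the features x0…x4 depends on how sweep_fit builds its training matrix, which is not shown here, so the variable names are taken at face value):

# hedged sketch: evaluate the printed best equation,
# sigmoid((1.9875373 - x2) - cube((x2 + x0) - 0.6770367))
def best_equation(x0: float, x2: float) -> float:
    z = (1.9875373 - x2) - ((x2 + x0) - 0.6770367) ** 3
    return 1.0 / (1.0 + np.exp(-z))  # sigmoid

# algebraically the same as the sympy form PySR reports:
# 1 / (0.137032480131706 * exp(x2 + (x0 + x2 - 0.6770367)**3) + 1)
print(best_equation(0.5, 0.5))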
interactive plots for figuring out maze_dataset.math.soft_step()¶
In [7]:
# Run the interactive visualization if in a Jupyter notebook
if IN_NOTEBOOK:
    from maze_dataset.benchmark.sweep_fit import create_interactive_plot

    create_interactive_plot(True)
VBox(children=(VBox(children=(Label(value='Adjust parameters:'), HBox(children=(FloatSlider(value=0.5, descrip…
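create_interactive_plot wires ipywidgets sliders to the soft-step family used in the success-fraction fit. The exact signature of maze_dataset.math.soft_step() is not shown here; as a purely illustrative stand-in (not the library's function), a sigmoid-based soft step with a midpoint and sharpness parameter looks like this:

# purely illustrative soft step (NOT the library's soft_step signature):
# a sigmoid in x with adjustable midpoint `p` and sharpness `alpha`
def soft_step_sketch(x: np.ndarray, p: float = 0.5, alpha: float = 10.0) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-alpha * (x - p)))

xs = np.linspace(0.0, 1.0, 100)
plt.plot(xs, soft_step_sketch(xs, p=0.4, alpha=15.0))
plt.xlabel("x")
plt.ylabel("soft step")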
In [8]:
cfg = MazeDatasetConfig(
    name="test",
    seed=3,
    grid_n=5,
    n_mazes=10,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs=dict(p=0.7),
    endpoint_kwargs=dict(
        deadend_start=True,
        # deadend_end=True,
        endpoints_not_equal=True,
        except_on_no_valid_endpoint=False,
    ),
)

print(f"{cfg.success_fraction_estimate() = }")
cfg_new = cfg.success_fraction_compensate()
print(f"{cfg_new.n_mazes = }")
/home/miv/projects/mazes/maze-dataset/maze_dataset/dataset/dataset.py:95: UserWarning: in GPTDatasetConfig self.name='test', self.seed=3 is trying to override GLOBAL_SEED = 42
  warnings.warn(
cfg.success_fraction_estimate() = np.float64(0.45037493871086454)
cfg_new.n_mazes = 27
In [9]:
len(MazeDataset.from_config(cfg_new))
Out[9]:
12
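This is consistent with what success_fraction_compensate() appears to do: scale n_mazes up so that, after mazes with no valid endpoints are filtered out, roughly the originally requested count survives (seemingly with a safety margin on top of the raw 10 / 0.45 ≈ 22 estimate, hence 27). The generated dataset then contains about 27 × 0.45 ≈ 12 mazes, which meets the original target of 10.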