In [1]:
from pathlib import Path
import numpy as np
# import pysr before torch to avoid
# UserWarning: torch was imported before juliacall. This may cause a segfault. To avoid this, import juliacall before importing torch. For updates, see https://github.com/pytorch/pytorch/issues/78829.
import pysr # noqa: F401
from zanj import ZANJ
from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import (
SweepResult,
dataset_success_fraction,
full_percolation_analysis,
plot_grouped,
)
from maze_dataset.benchmark.sweep_fit import sweep_fit
Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython
run a basic analysis¶
In [2]:
# Basic sweep: vary the percolation probability `p` over a few small grids
# and record `dataset_success_fraction` for every (config, p) combination.

# One config per grid size; every endpoint_kwargs flag is set to False and
# `maze_ctor_kwargs` starts empty because the sweep fills in `p` through
# `param_key` below. A fresh kwargs dict is built for each config.
percolation_configs = [
    MazeDatasetConfig(
        name=f"g{grid_n}-perc",
        grid_n=grid_n,
        n_mazes=32,
        maze_ctor=LatticeMazeGenerators.gen_percolation,
        maze_ctor_kwargs={},
        endpoint_kwargs={
            "deadend_start": False,
            "deadend_end": False,
            "endpoints_not_equal": False,
            "except_on_no_valid_endpoint": False,
        },
    )
    for grid_n in (2, 4, 6)
]

# 16 evenly spaced values of p covering [0, 1], as a plain Python list
p_values = np.linspace(0.0, 1.0, 16).tolist()

results: SweepResult = SweepResult.analyze(
    configs=percolation_configs,
    param_values=p_values,
    param_key="maze_ctor_kwargs.p",
    analyze_func=dataset_success_fraction,
    parallel=False,
)

# Plot the sweep; as the last expression of the cell, the returned object
# is what the notebook displays.
results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"])
tqdm_allowed_kwargs = {'leave', 'dynamic_ncols', 'unit_scale', 'iterable', 'gui', 'smoothing', 'file', 'ascii', 'unit_divisor', 'write_bytes', 'lock_args', 'miniters', 'self', 'nrows', 'bar_format', 'colour', 'position', 'disable', 'initial', 'desc', 'total', 'postfix', 'delay', 'unit', 'ncols', 'maxinterval', 'mininterval'} mapped_kwargs = {'total': 3, 'desc': 'Processing 3 items'}
Processing 3 items: 100%|██████████| 3/3 [00:01<00:00, 2.81it/s]
Out[2]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
check saving/loading¶
In [3]:
# Round-trip check: write the sweep result to disk, read it back through
# ZANJ, and plot the reloaded object with the same config keys as before.
path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj")
results.save(path)

results_reloaded = ZANJ().read(path)
results_reloaded.plot(cfg_keys=["n_mazes", "endpoint_kwargs"])
Out[3]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>
sweep across all endpoint kwargs and generator funcs¶
In [4]:
# Full percolation analysis on a deliberately small setup (few mazes,
# few p values, small grids), run serially.
# NOTE(review): `save_dir` is relative to the current working directory,
# while the save path used in the save/load cell starts with "../" --
# confirm both are meant to end up under the same tests/_temp directory.
full_sweep_kwargs = {
    "n_mazes": 16,
    "p_val_count": 11,
    "grid_sizes": [2, 4, 6],
    "parallel": False,
    "save_dir": Path("tests/_temp/dataset_frac_sweep"),
}
results_sweep: SweepResult = full_percolation_analysis(**full_sweep_kwargs)
tqdm_allowed_kwargs = {'leave', 'dynamic_ncols', 'unit_scale', 'iterable', 'gui', 'smoothing', 'file', 'ascii', 'unit_divisor', 'write_bytes', 'lock_args', 'miniters', 'self', 'nrows', 'bar_format', 'colour', 'position', 'disable', 'initial', 'desc', 'total', 'postfix', 'delay', 'unit', 'ncols', 'maxinterval', 'mininterval'} mapped_kwargs = {'total': 18, 'desc': 'Processing 18 items'}
Processing 18 items: 100%|██████████| 18/18 [00:05<00:00, 3.46it/s]
Saving results to tests/_temp/dataset_frac_sweep/result-n16-c18-p11.zanj
In [5]:
# Load a previously saved sweep result from the docs benchmark directory.
# Alternative (larger) result file:
# "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj"
medium_results_path = (
    "../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj"
)
results_medium: SweepResult = SweepResult.read(medium_results_path)
In [6]:
def _predict_success_fraction(cfg):
    """Return `cfg.success_fraction_estimate()`; used as `predict_fn` below."""
    return cfg.success_fraction_estimate()


# Grouped plot of the loaded sweep, with the per-config estimate passed in
# as the prediction function.
plot_grouped(
    results_medium,
    predict_fn=_predict_success_fraction,
    prediction_density=100,
    # options used for the paper version of the figure:
    # figsize=(11, 4),
    # save_fmt="pdf",
    # save_dir=Path("../docs/paper/figures/ep"),
    # minify_title=True,
    # legend_kwargs=dict(loc="center right", ncol=1, bbox_to_anchor=(1.14, 0.5)),
    # manual_titles=dict(
    # 	x="percolation probability $p$",
    # 	y="success fraction",
    # 	title="",
    # ),
)
perform a pysr regression on a dataset we load¶
In [7]:
# Run a PySR fit on a saved sweep result.
# `DATA_PATH` is built with the `/` operator on a `Path`, so it is a `Path`
# (the previous `str` annotation was incorrect).
DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/")

# Pick exactly one input file; the others are kept as ready-made alternatives.
# DATA_PATH: Path = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj"
DATA_PATH: Path = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj"
# DATA_PATH: Path = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj"

sweep_fit(
    DATA_PATH,
    # second positional argument: a directory path; the recorded run saved
    # its plots and equations file under it
    Path("tests/_temp/fit_plots/"),
    # presumably kept small so the notebook runs quickly -- TODO confirm
    niterations=3,
)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:2774: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off. warnings.warn( Compiling Julia backend...
loaded data: data.summary() = {'len(configs)': 30, 'len(param_values)': 25, 'len(result_values)': 30, 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'} training data extracted: x.shape = (750, 5), y.shape = (750,)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:87: UserWarning: You are using the `^` operator, but have not set up `constraints` for it. This may lead to overly complex expressions. One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which will allow arbitrary-complexity base (-1) but only powers such as a constant or variable (1). For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/ warnings.warn( [ Info: Started! [ Info: Final population: [ Info: Results saved to:
─────────────────────────────────────────────────────────────────────────────────────────────────── Complexity Loss Score Equation 1 1.151e-01 1.594e+01 y = 0.65075 3 9.782e-02 8.151e-02 y = 0.54285 ^ x₂ 4 9.272e-02 5.349e-02 y = 0.41565 / sigmoid(x₂) 5 9.194e-02 8.443e-03 y = 0.54285 ^ (x₂ ^ x₄) 6 9.187e-02 7.806e-04 y = sigmoid(0.092093 ^ (x₂ + -0.26235)) 8 8.157e-02 5.949e-02 y = 0.62299 ^ square(x₂ + (x₀ ^ x₁)) 11 7.360e-02 3.426e-02 y = sigmoid((1.9875 - x₂) - cube((x₂ + x₀) - 0.67704)) 21 6.855e-02 7.105e-03 y = (sigmoid(cube(sigmoid(-0.51816) * (x₃ * -1.0527))) ^ (... (((x₀ * square(-1.0187)) / 0.53774) ^ x₄) ^ 1.0218)) ^ x₂ 22 6.385e-02 7.099e-02 y = (sigmoid(cube(x₃ * (sigmoid((x₄ + x₄) + -0.70577) * -1... .0527))) ^ (((x₀ * 1.9298) ^ x₄) ^ 1.2272)) ^ x₂ 26 6.113e-02 1.089e-02 y = (sigmoid(cube(x₃ * (sigmoid(((x₄ + -0.486) + x₄) + -0.... 15546) * -1.3509))) ^ ((((x₀ * 1.0115) / 0.60542) ^ x₄) ^ ... 2.148)) ^ x₂ 29 5.678e-02 2.461e-02 y = log(((sigmoid(((x₃ + cube(x₀)) + -0.21807) * ((x₄ + 0.... 24383) * -1.6438)) ^ (square(((x₀ * -1.006) / 0.54027) ^ x... ₄) ^ 1.4216)) ^ x₂) + 1.2992) ─────────────────────────────────────────────────────────────────────────────────────────────────── Equations saved to: equations_file = PosixPath('tests/_temp/fit_plots/equations.txt') Best PySR Equation: model.get_best()['equation'] = '(sigmoid(cube(x3 * (sigmoid((x4 + x4) + -0.70577157) * -1.0527287))) ^ (((x0 * 1.9297646) ^ x4) ^ 1.2272323)) ^ x2' predict_fn =PySRFunction(X=>((1/(exp(0.140414285413667*x3**3/(0.493727481942107 + exp(-2*x4))**3) + 1))**(((x0*1.9297646)**x4)**1.2272323))**x2) Saving plot to tests/_temp/fit_plots/ep_any.svg Saving plot to tests/_temp/fit_plots/ep_deadends.svg Saving plot to tests/_temp/fit_plots/ep_deadends_unique.svg - outputs/20250408_231019_6fG9HE/hall_of_fame.csv
interactive plots for figuring out maze_dataset.math.soft_step()
¶
In [8]:
# Launch the interactive visualization only when the `__vsc_ipynb_file__`
# global is present.
# NOTE(review): this looks like a marker injected by VS Code's notebook
# kernel, so other Jupyter front ends would skip this cell -- confirm that
# is the intended behavior.
running_in_vscode_notebook = "__vsc_ipynb_file__" in globals()

if running_in_vscode_notebook:
    # imported here so the import only happens when the plot is actually shown
    from maze_dataset.benchmark.sweep_fit import create_interactive_plot

    create_interactive_plot(True)
VBox(children=(VBox(children=(Label(value='Adjust parameters:'), HBox(children=(FloatSlider(value=0.5, descrip…
In [9]:
# Build a small config, print its estimated success fraction, then derive a
# compensated config and print how many mazes that config asks for.
cfg = MazeDatasetConfig(
    name="test",
    seed=3,
    grid_n=5,
    n_mazes=10,
    maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
    maze_ctor_kwargs={"p": 0.7},
    endpoint_kwargs={
        "deadend_start": True,
        # "deadend_end": True,
        "endpoints_not_equal": True,
        "except_on_no_valid_endpoint": False,
    },
)

print(f"{cfg.success_fraction_estimate() = }")

# `success_fraction_compensate()` returns a config object; its `n_mazes`
# is what gets reported here.
cfg_new = cfg.success_fraction_compensate()

print(f"{cfg_new.n_mazes = }")
/home/miv/projects/mazes/maze-dataset/maze_dataset/dataset/dataset.py:94: UserWarning: in GPTDatasetConfig self.name='test', self.seed=3 is trying to override GLOBAL_SEED = 42 warnings.warn(
cfg.success_fraction_estimate() = np.float64(0.45037493871086454) cfg_new.n_mazes = 27
In [10]:
# Build the dataset from the compensated config and display how many mazes
# it actually contains.
dataset_compensated = MazeDataset.from_config(cfg_new)
len(dataset_compensated)
Out[10]:
12