In [1]:
from pathlib import Path

import numpy as np

# import pysr before torch to avoid
# UserWarning: torch was imported before juliacall. This may cause a segfault. To avoid this, import juliacall before importing torch. For updates, see https://github.com/pytorch/pytorch/issues/78829.
import pysr  # noqa: F401
from zanj import ZANJ

from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig
from maze_dataset.benchmark.config_sweep import (
	SweepResult,
	dataset_success_fraction,
	full_percolation_analysis,
	plot_grouped,
)
from maze_dataset.benchmark.sweep_fit import sweep_fit
Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython

run a basic analysis¶

In [2]:
# Run the analysis
results: SweepResult = SweepResult.analyze(
	configs=[
		MazeDatasetConfig(
			name=f"g{grid_n}-perc",
			grid_n=grid_n,
			n_mazes=32,
			maze_ctor=LatticeMazeGenerators.gen_percolation,
			maze_ctor_kwargs=dict(),
			endpoint_kwargs=dict(
				deadend_start=False,
				deadend_end=False,
				endpoints_not_equal=False,
				except_on_no_valid_endpoint=False,
			),
		)
		for grid_n in [2, 4, 6]
	],
	param_values=np.linspace(0.0, 1.0, 16).tolist(),
	param_key="maze_ctor_kwargs.p",
	analyze_func=dataset_success_fraction,
	parallel=False,
)

# Plot results
results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"])
tqdm_allowed_kwargs = {'leave', 'dynamic_ncols', 'unit_scale', 'iterable', 'gui', 'smoothing', 'file', 'ascii', 'unit_divisor', 'write_bytes', 'lock_args', 'miniters', 'self', 'nrows', 'bar_format', 'colour', 'position', 'disable', 'initial', 'desc', 'total', 'postfix', 'delay', 'unit', 'ncols', 'maxinterval', 'mininterval'}
mapped_kwargs = {'total': 3, 'desc': 'Processing 3 items'}
Processing 3 items: 100%|██████████| 3/3 [00:01<00:00,  2.81it/s]
No description has been provided for this image
Out[2]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>

check saving/loading¶

In [3]:
path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj")
results.save(path)
ZANJ().read(path).plot(cfg_keys=["n_mazes", "endpoint_kwargs"])
No description has been provided for this image
Out[3]:
<Axes: title={'center': "maze_ctor_kwargs.p vs dataset_success_fraction\nMazeDatasetConfig(n_mazes=32, endpoint_kwargs={'deadend_start': False, 'deadend_end': False, 'endpoints_not_equal': False, 'except_on_no_valid_endpoint': False})"}, xlabel='maze_ctor_kwargs.p', ylabel='dataset_success_fraction'>

sweep acrossall endpoint kwargs and generator funcs¶

In [4]:
results_sweep: SweepResult = full_percolation_analysis(
	n_mazes=16,
	p_val_count=11,
	grid_sizes=[2, 4, 6],
	parallel=False,
	save_dir=Path("tests/_temp/dataset_frac_sweep"),
)
tqdm_allowed_kwargs = {'leave', 'dynamic_ncols', 'unit_scale', 'iterable', 'gui', 'smoothing', 'file', 'ascii', 'unit_divisor', 'write_bytes', 'lock_args', 'miniters', 'self', 'nrows', 'bar_format', 'colour', 'position', 'disable', 'initial', 'desc', 'total', 'postfix', 'delay', 'unit', 'ncols', 'maxinterval', 'mininterval'}
mapped_kwargs = {'total': 18, 'desc': 'Processing 18 items'}
Processing 18 items: 100%|██████████| 18/18 [00:05<00:00,  3.46it/s]
Saving results to tests/_temp/dataset_frac_sweep/result-n16-c18-p11.zanj
In [5]:
results_medium: SweepResult = SweepResult.read(
	"../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj",
	# "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj"
)
In [6]:
plot_grouped(
	results_medium,
	predict_fn=lambda x: x.success_fraction_estimate(),
	prediction_density=100,
	# for paper version
	# figsize=(11, 4),
	# save_fmt="pdf",
	# save_dir=Path("../docs/paper/figures/ep"),
	# minify_title=True,
	# legend_kwargs=dict(loc="center right", ncol=1, bbox_to_anchor=(1.14, 0.5)),
	# manual_titles=dict(
	# 	x="percolation probability $p$",
	# 	y="success fraction",
	# 	title="",
	# ),
)
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

perform a pysr regression on a dataset we load¶

In [7]:
DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/")
# DATA_PATH: str = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj"
# DATA_PATH: str = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj"
DATA_PATH: str = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj"
# DATA_PATH: str = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj"

sweep_fit(
	DATA_PATH,
	Path("tests/_temp/fit_plots/"),
	niterations=3,
)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:2774: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.
  warnings.warn(
Compiling Julia backend...
loaded data: data.summary() = {'len(configs)': 30, 'len(param_values)': 25, 'len(result_values)': 30, 'param_key': 'maze_ctor_kwargs.p', 'analyze_func': 'dataset_success_fraction'}
training data extracted: x.shape = (750, 5), y.shape = (750,)
/home/miv/projects/mazes/maze-dataset/.venv/lib/python3.12/site-packages/pysr/sr.py:87: UserWarning: You are using the `^` operator, but have not set up `constraints` for it. This may lead to overly complex expressions. One typical constraint is to use `constraints={..., '^': (-1, 1)}`, which will allow arbitrary-complexity base (-1) but only powers such as a constant or variable (1). For more tips, please see https://ai.damtp.cam.ac.uk/pysr/tuning/
  warnings.warn(
[ Info: Started!
[ Info: Final population:
[ Info: Results saved to:
───────────────────────────────────────────────────────────────────────────────────────────────────
Complexity  Loss       Score      Equation
1           1.151e-01  1.594e+01  y = 0.65075
3           9.782e-02  8.151e-02  y = 0.54285 ^ x₂
4           9.272e-02  5.349e-02  y = 0.41565 / sigmoid(x₂)
5           9.194e-02  8.443e-03  y = 0.54285 ^ (x₂ ^ x₄)
6           9.187e-02  7.806e-04  y = sigmoid(0.092093 ^ (x₂ + -0.26235))
8           8.157e-02  5.949e-02  y = 0.62299 ^ square(x₂ + (x₀ ^ x₁))
11          7.360e-02  3.426e-02  y = sigmoid((1.9875 - x₂) - cube((x₂ + x₀) - 0.67704))
21          6.855e-02  7.105e-03  y = (sigmoid(cube(sigmoid(-0.51816) * (x₃ * -1.0527))) ^ (...
                                      (((x₀ * square(-1.0187)) / 0.53774) ^ x₄) ^ 1.0218)) ^ x₂
22          6.385e-02  7.099e-02  y = (sigmoid(cube(x₃ * (sigmoid((x₄ + x₄) + -0.70577) * -1...
                                      .0527))) ^ (((x₀ * 1.9298) ^ x₄) ^ 1.2272)) ^ x₂
26          6.113e-02  1.089e-02  y = (sigmoid(cube(x₃ * (sigmoid(((x₄ + -0.486) + x₄) + -0....
                                      15546) * -1.3509))) ^ ((((x₀ * 1.0115) / 0.60542) ^ x₄) ^ ...
                                      2.148)) ^ x₂
29          5.678e-02  2.461e-02  y = log(((sigmoid(((x₃ + cube(x₀)) + -0.21807) * ((x₄ + 0....
                                      24383) * -1.6438)) ^ (square(((x₀ * -1.006) / 0.54027) ^ x...
                                      ₄) ^ 1.4216)) ^ x₂) + 1.2992)
───────────────────────────────────────────────────────────────────────────────────────────────────
Equations saved to: equations_file = PosixPath('tests/_temp/fit_plots/equations.txt')
Best PySR Equation: model.get_best()['equation'] = '(sigmoid(cube(x3 * (sigmoid((x4 + x4) + -0.70577157) * -1.0527287))) ^ (((x0 * 1.9297646) ^ x4) ^ 1.2272323)) ^ x2'
predict_fn =PySRFunction(X=>((1/(exp(0.140414285413667*x3**3/(0.493727481942107 + exp(-2*x4))**3) + 1))**(((x0*1.9297646)**x4)**1.2272323))**x2)
Saving plot to tests/_temp/fit_plots/ep_any.svg
Saving plot to tests/_temp/fit_plots/ep_deadends.svg
Saving plot to tests/_temp/fit_plots/ep_deadends_unique.svg
  - outputs/20250408_231019_6fG9HE/hall_of_fame.csv
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image

interactive plots for figuring out maze_dataset.math.soft_step()¶

In [8]:
# Run the interactive visualization if in a Jupyter notebook
if "__vsc_ipynb_file__" in globals():
	from maze_dataset.benchmark.sweep_fit import create_interactive_plot

	create_interactive_plot(True)
VBox(children=(VBox(children=(Label(value='Adjust parameters:'), HBox(children=(FloatSlider(value=0.5, descrip…
In [9]:
cfg = MazeDatasetConfig(
	name="test",
	seed=3,
	grid_n=5,
	n_mazes=10,
	maze_ctor=LatticeMazeGenerators.gen_dfs_percolation,
	maze_ctor_kwargs=dict(p=0.7),
	endpoint_kwargs=dict(
		deadend_start=True,
		# deadend_end=True,
		endpoints_not_equal=True,
		except_on_no_valid_endpoint=False,
	),
)

print(f"{cfg.success_fraction_estimate() = }")
cfg_new = cfg.success_fraction_compensate()
print(f"{cfg_new.n_mazes = }")
/home/miv/projects/mazes/maze-dataset/maze_dataset/dataset/dataset.py:94: UserWarning: in GPTDatasetConfig self.name='test', self.seed=3 is trying to override GLOBAL_SEED = 42
  warnings.warn(
cfg.success_fraction_estimate() = np.float64(0.45037493871086454)
cfg_new.n_mazes = 27
In [10]:
len(MazeDataset.from_config(cfg_new))
Out[10]:
12