docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.success_predict_math

Math used to make the MazeDatasetConfig.success_fraction_estimate() function work.

Desmos link: https://www.desmos.com/calculator/qllvhwftvy


  1"""math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work
  2
  3Desmos link: https://www.desmos.com/calculator/qllvhwftvy
  4"""
  5
  6import numpy as np
  7from jaxtyping import Float
  8
  9
 10def sigmoid(x: float) -> float:
 11	r"$\sigma(x) = \frac{1}{1 + e^{-x}}$"
 12	return 1 / (1 + np.exp(-x))
 13
 14
 15# sigmoid_shifted = lambda x: 1 / (1 + np.exp(-1000 * (x - 0.5)))
 16# r"sigmoid(x)= 1 / (1 + e^{-b(x-0.5)})"
 17
 18# g_poly = lambda q, a: 1 - np.abs(2 * q - 1) ** a
 19# r"g(q,a) = 1 - (|2q-1|)^{a}"
 20
 21# f_poly = lambda q, a: q * g_poly(q, a)
 22# r"f(q,a) = q * g(q,a)"
 23
 24# h_func = lambda q, a: f_poly(q, a) * (1 - sigmoid_shifted(q)) + (1 - f_poly(1 - q, a)) * sigmoid_shifted(q)
 25# r"h(q,a,b) = f(q,a) * (1-s(q,b)) + (1-f(1-q,a)) * s(q,b)"
 26
 27# A_scaling = lambda q, a, w: w * g_poly(q, a)
 28# r"A(q) = w * g(q, a)"
 29
 30
 31def sigmoid_shifted(x: float) -> float:
 32	r"\sigma_s(x)= \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}"
 33	return 1 / (1 + np.exp(-1000 * (x - 0.5)))
 34
 35
 36def g_poly(q: float, a: float) -> float:
 37	r"$g(q,a) = 1 - (|2q-1|)^{a}$"
 38	return 1 - np.abs(2 * q - 1) ** a
 39
 40
 41def f_poly(q: float, a: float) -> float:
 42	r"$f(q,a) = q \cdot g(q,a)$"
 43	return q * g_poly(q, a)
 44
 45
 46def h_func(q: float, a: float) -> float:
 47	r"""$h(q,a,b) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$"""
 48	return f_poly(q, a) * (1 - sigmoid_shifted(q)) + (
 49		1 - f_poly(1 - q, a)
 50	) * sigmoid_shifted(q)
 51
 52
 53def A_scaling(q: float, a: float, w: float) -> float:
 54	r"$A(q) = w \cdot g(q, a)$"
 55	return w * g_poly(q, a)
 56
 57
 58def soft_step(
 59	x: float | np.floating | Float[np.ndarray, " n"],
 60	p: float | np.floating,
 61	alpha: float | np.floating = 5,
 62	w: float | np.floating = 50,
 63) -> float:
 64	"""when p is close to 0.5 acts like the identity wrt x, but when p is close to 0 or 1, pushes x to 0 or 1 (whichever is closest)
 65
 66	https://www.desmos.com/calculator/qllvhwftvy
 67	"""
 68	# TYPING: this is messed up, some of these args can be arrays but i dont remember which?
 69	return h_func(
 70		x,  # type: ignore[arg-type]
 71		A_scaling(p, alpha, w),  # type: ignore[arg-type]
 72	)
 73
 74
 75# `cfg: MazeDatasetConfig` but we can't import that because it would create a circular import
 76def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
 77	"learned by pysr, see `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmark.config_fit`"
 78	x = cfg._to_ps_array()
 79	raw_val: float = sigmoid(
 80		(
 81			(
 82				((sigmoid((x[1] - x[3]) ** 3) * -4.721228) - (x[3] * 1.4636494))
 83				* (
 84					x[2]
 85					* (
 86						x[4]
 87						+ (((x[0] + 0.048765484) ** 9.746339) + (0.8998194 ** x[1]))
 88					)
 89				)
 90			)
 91			+ (2.4524326 ** (2.9501643 - x[0]))
 92		)
 93		* (
 94			(
 95				(((0.9077277 - x[0]) * ((x[4] * 1.0520288) ** x[1])) + x[0])
 96				* sigmoid(x[1]) ** 3
 97			)
 98			+ -0.18268494
 99		),
100	)
101	return soft_step(
102		x=raw_val,
103		p=x[0],
104		alpha=5,  # manually tuned
105		w=10,  # manually tuned
106	)

def sigmoid(x: float) -> float:
def sigmoid(x: float) -> float:
	r"""Logistic sigmoid, $\sigma(x) = \frac{1}{1 + e^{-x}}$.

	Built only from elementwise numpy operations, so it also accepts numpy arrays.
	"""
	denominator = 1 + np.exp(-x)
	return 1 / denominator

$\sigma(x) = \frac{1}{1 + e^{-x}}$

def sigmoid_shifted(x: float) -> float:
def sigmoid_shifted(x: float) -> float:
	r"""Very steep sigmoid centred at 0.5, $\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$.

	Behaves almost like a hard step: inputs clearly below 0.5 map to ~0,
	inputs clearly above 0.5 map to ~1, and exactly 0.5 maps to 0.5.
	"""
	exponent = -1000 * (x - 0.5)
	return 1 / (1 + np.exp(exponent))

$\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$

def g_poly(q: float, a: float) -> float:
def g_poly(q: float, a: float) -> float:
	r"""Bump polynomial $g(q,a) = 1 - (|2q-1|)^{a}$.

	For `a > 0` this is 1 at `q = 0.5` and 0 at `q = 0` and `q = 1`;
	larger `a` keeps it close to 1 over a wider range around 0.5.
	"""
	scaled_offset = np.abs(2 * q - 1)  # |2q - 1|, i.e. twice the distance from 0.5
	return 1 - scaled_offset**a

$g(q,a) = 1 - (|2q-1|)^{a}$

def f_poly(q: float, a: float) -> float:
def f_poly(q: float, a: float) -> float:
	r"""$f(q,a) = q \cdot g(q,a)$

	`q` scaled by the bump `g_poly`: close to `q` wherever `g_poly(q, a)` is
	close to 1, and exactly 0 at `q = 0` and (for `a > 0`) at `q = 1`.
	"""
	return g_poly(q, a) * q

$f(q,a) = q \cdot g(q,a)$

def h_func(q: float, a: float) -> float:
def h_func(q: float, a: float) -> float:
	r"""$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$

	Blend of two branches switched by the steep sigmoid `sigmoid_shifted`.
	- For `q` clearly below 0.5 the result is essentially `f_poly(q, a)`.
	- For `q` clearly above 0.5 it is essentially the mirrored branch `1 - f_poly(1 - q, a)`.
	- At `q = 0.5` both branches are weighted equally.
	"""
	step = sigmoid_shifted(q)
	lower_branch = f_poly(q, a)
	upper_branch = 1 - f_poly(1 - q, a)
	return lower_branch * (1 - step) + upper_branch * step

$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$

def A_scaling(q: float, a: float, w: float) -> float:
def A_scaling(q: float, a: float, w: float) -> float:
	r"""$A(q) = w \cdot g(q, a)$

	The bump `g_poly` scaled by `w`. For `a > 0` this gives `w` at `q = 0.5`
	and 0 at `q = 0` and `q = 1`. `soft_step` uses it to compute the exponent
	it passes to `h_func`.
	"""
	return g_poly(q, a) * w

$A(q) = w \cdot g(q, a)$

def soft_step( x: float | numpy.floating | jaxtyping.Float[ndarray, 'n'], p: float | numpy.floating, alpha: float | numpy.floating = 5, w: float | numpy.floating = 50) -> float:
def soft_step(
	x: float | np.floating | Float[np.ndarray, " n"],
	p: float | np.floating,
	alpha: float | np.floating = 5,
	w: float | np.floating = 50,
) -> float:
	"""Softly snap `x` toward 0 or 1, with a strength controlled by `p`.

	When `p` is close to 0.5 this acts like the identity with respect to `x`;
	when `p` is close to 0 or 1 it pushes `x` toward 0 or 1 (whichever is closer).

	https://www.desmos.com/calculator/qllvhwftvy

	# Parameters:
	- `x : float | np.floating | Float[np.ndarray, "n"]`
		value(s) to transform
	- `p : float | np.floating`
		mapped through `A_scaling` to the exponent used by `h_func`
	- `alpha : float | np.floating`
		exponent of the bump polynomial applied to `p`
		(defaults to `5`)
	- `w : float | np.floating`
		exponent used when `p == 0.5`
		(defaults to `50`)

	# Returns:
	- `float`
		the transformed value (see the note below for array inputs)
	"""
	# NOTE: `h_func` / `A_scaling` are annotated with `float`, but they are built
	# only from elementwise numpy operations, so `x` may be a numpy array (and `p`
	# should broadcast against it too). The `type: ignore`s silence those mismatches.
	# NOTE(review): for array `x` the `-> float` return annotation is inaccurate;
	# it is kept because callers such as `cfg_success_predict_fn` rely on it.
	return h_func(
		q=x,  # type: ignore[arg-type]
		a=A_scaling(q=p, a=alpha, w=w),  # type: ignore[arg-type]
	)

When p is close to 0.5, this acts like the identity with respect to x; when p is close to 0 or 1, it pushes x toward 0 or 1 (whichever is closer).

https://www.desmos.com/calculator/qllvhwftvy

def cfg_success_predict_fn(cfg) -> float:
 77def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
 78	"learned by pysr, see `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmark.config_fit`"
 79	x = cfg._to_ps_array()
 80	raw_val: float = sigmoid(
 81		(
 82			(
 83				((sigmoid((x[1] - x[3]) ** 3) * -4.721228) - (x[3] * 1.4636494))
 84				* (
 85					x[2]
 86					* (
 87						x[4]
 88						+ (((x[0] + 0.048765484) ** 9.746339) + (0.8998194 ** x[1]))
 89					)
 90				)
 91			)
 92			+ (2.4524326 ** (2.9501643 - x[0]))
 93		)
 94		* (
 95			(
 96				(((0.9077277 - x[0]) * ((x[4] * 1.0520288) ** x[1])) + x[0])
 97				* sigmoid(x[1]) ** 3
 98			)
 99			+ -0.18268494
100		),
101	)
102	return soft_step(
103		x=raw_val,
104		p=x[0],
105		alpha=5,  # manually tuned
106		w=10,  # manually tuned
107	)

Learned by PySR; see estimate_dataset_fractions.ipynb and maze_dataset.benchmark.config_fit.