maze_dataset.dataset.success_predict_math
Math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work.
Desmos link: https://www.desmos.com/calculator/qllvhwftvy
"""Math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work.

Desmos link: https://www.desmos.com/calculator/qllvhwftvy
"""

import numpy as np
from jaxtyping import Float

# The general shifted sigmoid is $\sigma_s(x) = \frac{1}{1 + e^{-b(x-0.5)}}$ with a
# sharpness `b`; this module fixes `b = 1000` inside `sigmoid_shifted`, so none of the
# functions below take `b` as a parameter.


def sigmoid(x: float) -> float:
    r"$\sigma(x) = \frac{1}{1 + e^{-x}}$"
    return 1 / (1 + np.exp(-x))


def sigmoid_shifted(x: float) -> float:
    r"$\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$"
    # logistic centred at 0.5 with slope 1000: an almost-hard step from 0 to 1.
    # For x below about -0.21, `np.exp` overflows to inf (numpy warns) and the result is 0.0.
    return 1 / (1 + np.exp(-1000 * (x - 0.5)))


def g_poly(q: float, a: float) -> float:
    r"$g(q,a) = 1 - (|2q-1|)^{a}$"
    # equals 1 at q = 0.5 (for a > 0) and 0 at q = 0 and q = 1
    return 1 - np.abs(2 * q - 1) ** a


def f_poly(q: float, a: float) -> float:
    r"$f(q,a) = q \cdot g(q,a)$"
    return q * g_poly(q, a)


def h_func(q: float, a: float) -> float:
    r"""$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$"""
    # The steep step switches between f(q, a) for q < 0.5 and its point reflection
    # 1 - f(1 - q, a) for q > 0.5. The step is a pure function of q, so evaluate it once.
    step = sigmoid_shifted(q)
    return f_poly(q, a) * (1 - step) + (1 - f_poly(1 - q, a)) * step


def A_scaling(q: float, a: float, w: float) -> float:
    r"$A(q) = w \cdot g(q, a)$"
    return w * g_poly(q, a)


def soft_step(
    x: float | np.floating | Float[np.ndarray, " n"],
    p: float | np.floating,
    alpha: float | np.floating = 5,
    w: float | np.floating = 50,
) -> float | np.floating | Float[np.ndarray, " n"]:
    """Softly push `x` towards 0 or 1, with a strength controlled by `p`.

    When `p` is close to 0.5, this acts like the identity with respect to `x`;
    when `p` is close to 0 or 1, it pushes `x` to 0 or 1 (whichever is closer).

    https://www.desmos.com/calculator/qllvhwftvy

    # Parameters:
     - `x : float | np.floating | Float[np.ndarray, "n"]`
        value(s) to transform. Every downstream operation is an elementwise numpy
        op, so an array works and the result broadcasts against `p`.
     - `p : float | np.floating`
        strength control: no push at 0.5, strongest push at 0 and 1
     - `alpha : float | np.floating`
        exponent passed to `g_poly` when mapping `p` to the shape parameter
        (defaults to `5`)
     - `w : float | np.floating`
        value of the shape parameter at `p = 0.5` for `alpha > 0` (defaults to `50`)

    # Returns:
     - `float | np.floating | Float[np.ndarray, "n"]`
        a scalar for scalar inputs, otherwise an array
    """
    # A_scaling maps p to the shape parameter of h_func: w at p = 0.5, 0 at p = 0 or 1.
    # A large shape parameter keeps h_func close to the identity; a shape parameter of 0
    # turns it into the hard step sigmoid_shifted(x).
    # The helpers are annotated scalar-only but are elementwise, hence the ignores.
    shape = A_scaling(p, alpha, w)  # type: ignore[arg-type]
    return h_func(x, shape)  # type: ignore[arg-type]


# `cfg: MazeDatasetConfig` but we can't import that because it would create a circular import
def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
    """Estimate the success fraction for a maze dataset config.

    The raw formula was learned by PySR -- see `estimate_dataset_fractions.ipynb` and
    `maze_dataset.benchmark.config_fit` -- and its output is then passed through
    `soft_step` with hand-tuned `alpha` and `w`.
    """
    features = cfg._to_ps_array()
    # NOTE(review): the meaning of each entry is defined by `_to_ps_array` (not visible
    # here); entries 0 through 4 are read
    f0, f1, f2, f3, f4 = features[0], features[1], features[2], features[3], features[4]

    # operation order below matches the PySR expression exactly
    lhs_coef = sigmoid((f1 - f3) ** 3) * -4.721228 - f3 * 1.4636494
    lhs_scale = f2 * (f4 + ((f0 + 0.048765484) ** 9.746339 + 0.8998194**f1))
    lhs = lhs_coef * lhs_scale + 2.4524326 ** (2.9501643 - f0)
    rhs = ((0.9077277 - f0) * (f4 * 1.0520288) ** f1 + f0) * sigmoid(f1) ** 3 + -0.18268494
    raw_val: float = sigmoid(lhs * rhs)

    return soft_step(
        x=raw_val,
        p=f0,
        alpha=5,  # manually tuned
        w=10,  # manually tuned
    )
def
sigmoid(x: float) -> float:
def sigmoid(x: float) -> float:
    r"$\sigma(x) = \frac{1}{1 + e^{-x}}$"
    # standard logistic function; elementwise for numpy inputs
    denominator = 1 + np.exp(-x)
    return 1 / denominator
$\sigma(x) = \frac{1}{1 + e^{-x}}$
def
sigmoid_shifted(x: float) -> float:
def sigmoid_shifted(x: float) -> float:
    r"$\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$"
    # logistic centred at 0.5 with slope 1000: an almost-hard step from 0 to 1.
    # For x below about -0.21, `np.exp` overflows to inf (numpy warns) and the result is 0.0.
    return 1 / (1 + np.exp(-1000 * (x - 0.5)))
$\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$
def
g_poly(q: float, a: float) -> float:
def g_poly(q: float, a: float) -> float:
    r"$g(q,a) = 1 - (|2q-1|)^{a}$"
    # |2q - 1| is 0 at q = 0.5 and 1 at q = 0 or 1, so g peaks at the centre (for a > 0)
    distance_from_center = np.abs(2 * q - 1)
    return 1 - distance_from_center**a
$g(q,a) = 1 - (|2q-1|)^{a}$
def
f_poly(q: float, a: float) -> float:
def f_poly(q: float, a: float) -> float:
    r"$f(q,a) = q \cdot g(q,a)$"
    # scale the bump g(q, a) by q itself
    bump = g_poly(q, a)
    return q * bump
$f(q,a) = q \cdot g(q,a)$
def
h_func(q: float, a: float) -> float:
def h_func(q: float, a: float) -> float:
    r"""$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$"""
    # The steep step sigma_s switches between two branches at q = 0.5:
    # f(q, a) for q < 0.5 and its point reflection 1 - f(1 - q, a) for q > 0.5.
    # The step is a pure function of q, so evaluate it once.
    step = sigmoid_shifted(q)
    lower = f_poly(q, a)
    upper = 1 - f_poly(1 - q, a)
    return lower * (1 - step) + upper * step
$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$
def
A_scaling(q: float, a: float, w: float) -> float:
def A_scaling(q: float, a: float, w: float) -> float:
    r"$A(q) = w \cdot g(q, a)$"
    # the weight w is modulated by the bump g, so A equals w at q = 0.5 (for a > 0)
    return g_poly(q, a) * w
$A(q) = w \cdot g(q, a)$
def
soft_step( x: float | numpy.floating | jaxtyping.Float[ndarray, 'n'], p: float | numpy.floating, alpha: float | numpy.floating = 5, w: float | numpy.floating = 50) -> float:
def soft_step(
    x: float | np.floating | Float[np.ndarray, " n"],
    p: float | np.floating,
    alpha: float | np.floating = 5,
    w: float | np.floating = 50,
) -> float | np.floating | Float[np.ndarray, " n"]:
    """Softly push `x` towards 0 or 1, with a strength controlled by `p`.

    When `p` is close to 0.5, this acts like the identity with respect to `x`;
    when `p` is close to 0 or 1, it pushes `x` to 0 or 1 (whichever is closer).

    https://www.desmos.com/calculator/qllvhwftvy

    # Parameters:
     - `x : float | np.floating | Float[np.ndarray, "n"]`
        value(s) to transform. Every downstream operation is an elementwise numpy
        op, so an array works and the result broadcasts against `p`.
     - `p : float | np.floating`
        strength control: no push at 0.5, strongest push at 0 and 1
     - `alpha : float | np.floating`
        exponent passed to `g_poly` when mapping `p` to the shape parameter
        (defaults to `5`)
     - `w : float | np.floating`
        value of the shape parameter at `p = 0.5` for `alpha > 0` (defaults to `50`)

    # Returns:
     - `float | np.floating | Float[np.ndarray, "n"]`
        a scalar for scalar inputs, otherwise an array
    """
    # A_scaling maps p to the shape parameter of h_func: w at p = 0.5, 0 at p = 0 or 1.
    # A large shape parameter keeps h_func close to the identity; a shape parameter of 0
    # turns it into the hard step sigmoid_shifted(x).
    # The helpers are annotated scalar-only but are elementwise, hence the ignores.
    shape = A_scaling(p, alpha, w)  # type: ignore[arg-type]
    return h_func(x, shape)  # type: ignore[arg-type]
When `p` is close to 0.5, acts like the identity with respect to `x`; when `p` is close to 0 or 1, pushes `x` to 0 or 1 (whichever is closer).
def
cfg_success_predict_fn(cfg) -> float:
def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
    """Estimate the success fraction for a maze dataset config.

    The raw formula was learned by PySR -- see `estimate_dataset_fractions.ipynb` and
    `maze_dataset.benchmark.config_fit` -- and its output is then passed through
    `soft_step` with hand-tuned `alpha` and `w`.

    `cfg` should be a `MazeDatasetConfig`; it is left unannotated because importing
    that class here would create a circular import.
    """
    features = cfg._to_ps_array()
    # NOTE(review): the meaning of each entry is defined by `_to_ps_array` (not visible
    # here); entries 0 through 4 are read
    f0, f1, f2, f3, f4 = features[0], features[1], features[2], features[3], features[4]

    # operation order below matches the PySR expression exactly
    lhs_coef = sigmoid((f1 - f3) ** 3) * -4.721228 - f3 * 1.4636494
    lhs_scale = f2 * (f4 + ((f0 + 0.048765484) ** 9.746339 + 0.8998194**f1))
    lhs = lhs_coef * lhs_scale + 2.4524326 ** (2.9501643 - f0)
    rhs = ((0.9077277 - f0) * (f4 * 1.0520288) ** f1 + f0) * sigmoid(f1) ** 3 + -0.18268494
    raw_val: float = sigmoid(lhs * rhs)

    return soft_step(
        x=raw_val,
        p=f0,
        alpha=5,  # manually tuned
        w=10,  # manually tuned
    )
Learned by PySR; see `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmark.config_fit`.