Coverage for maze_dataset/dataset/success_predict_math.py: 0%
20 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 01:43 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 01:43 -0600
1"""math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work
3Desmos link: https://www.desmos.com/calculator/qllvhwftvy
4"""
6import numpy as np
7from jaxtyping import Float
def sigmoid(x: float) -> float:
    r"""Logistic sigmoid, $\sigma(x) = \frac{1}{1 + e^{-x}}$.

    Args:
        x: input value. The expression is built from numpy operations, so a
            numpy array is evaluated elementwise.

    Returns:
        `1 / (1 + exp(-x))`, a value in `[0, 1]`.
    """
    # For large negative `x`, `np.exp(-x)` overflows to `inf`. The quotient is
    # then exactly 0.0, which is the correct limit, so the overflow is benign:
    # silence it locally instead of leaking a RuntimeWarning (or an exception
    # under `np.errstate(over="raise")`) to the caller.
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
15# sigmoid_shifted = lambda x: 1 / (1 + np.exp(-1000 * (x - 0.5)))
16# r"sigmoid(x)= 1 / (1 + e^{-b(x-0.5)})"
18# g_poly = lambda q, a: 1 - np.abs(2 * q - 1) ** a
19# r"g(q,a) = 1 - (|2q-1|)^{a}"
21# f_poly = lambda q, a: q * g_poly(q, a)
22# r"f(q,a) = q * g(q,a)"
24# h_func = lambda q, a: f_poly(q, a) * (1 - sigmoid_shifted(q)) + (1 - f_poly(1 - q, a)) * sigmoid_shifted(q)
25# r"h(q,a,b) = f(q,a) * (1-s(q,b)) + (1-f(1-q,a)) * s(q,b)"
27# A_scaling = lambda q, a, w: w * g_poly(q, a)
28# r"A(q) = w * g(q, a)"
def sigmoid_shifted(x: float) -> float:
    r"""Very steep sigmoid centred at 0.5, $\sigma_s(x) = \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$.

    The slope of 1000 makes this behave like a smooth step from 0 to 1 at `x = 0.5`.

    Args:
        x: input value. The expression is built from numpy operations, so a
            numpy array is evaluated elementwise.

    Returns:
        `1 / (1 + exp(-1000 * (x - 0.5)))`, a value in `[0, 1]`.
    """
    # Because of the factor of 1000 the exponent exceeds the float64 range as
    # soon as `x` drops below roughly -0.21. The overflow to `inf` still gives
    # the correct limit of 0.0, so suppress the spurious overflow signal here.
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-1000 * (x - 0.5)))
def g_poly(q: float, a: float) -> float:
    r"""Bump function $g(q,a) = 1 - (|2q-1|)^{a}$.

    Equals 1 at `q = 0.5` and 0 at `q = 0` and `q = 1` (for `a > 0`).

    Args:
        q: point at which to evaluate the bump.
        a: exponent applied to the distance `|2q - 1|`.

    Returns:
        `1 - |2q - 1| ** a`.
    """
    # distance of `q` from the midpoint 0.5, rescaled so that q=0 and q=1 map to 1
    scaled_distance = np.abs(2 * q - 1)
    return 1 - scaled_distance**a
def f_poly(q: float, a: float) -> float:
    r"""Bump-weighted identity, $f(q,a) = q \cdot g(q,a)$.

    Args:
        q: value to be scaled; also the point at which `g_poly` is evaluated.
        a: exponent forwarded to `g_poly`.

    Returns:
        `q * g_poly(q, a)`.
    """
    weight = g_poly(q, a)
    return q * weight
def h_func(q: float, a: float) -> float:
    r"""Blend of two `f_poly` branches, switched at `q = 0.5`.

    $h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$

    Args:
        q: input value; also drives the switch `sigmoid_shifted(q)`.
        a: exponent forwarded to `f_poly`.

    Returns:
        The `sigmoid_shifted(q)`-weighted combination of `f_poly(q, a)` and
        `1 - f_poly(1 - q, a)`.
    """
    # steep step at q = 0.5: ~0 selects the lower branch, ~1 the upper branch
    switch = sigmoid_shifted(q)
    lower_branch = f_poly(q, a)
    upper_branch = 1 - f_poly(1 - q, a)
    return lower_branch * (1 - switch) + upper_branch * switch
def A_scaling(q: float, a: float, w: float) -> float:
    r"""Scaled bump, $A(q) = w \cdot g(q, a)$.

    Args:
        q: point at which `g_poly` is evaluated.
        a: exponent forwarded to `g_poly`.
        w: multiplicative scale applied to the bump.

    Returns:
        `w * g_poly(q, a)`.
    """
    bump = g_poly(q, a)
    return w * bump
def soft_step(
    x: float | np.floating | Float[np.ndarray, " n"],
    p: float | np.floating,
    alpha: float | np.floating = 5,
    w: float | np.floating = 50,
) -> float:
    """Soft step on `x` whose sharpness is controlled by `p`.

    For `p` near 0.5 this is close to the identity in `x`; for `p` near 0 or 1
    it pushes `x` towards whichever of 0 or 1 is closest.

    Computed as `h_func(x, A_scaling(p, alpha, w))`.

    https://www.desmos.com/calculator/qllvhwftvy

    Args:
        x: value(s) to transform.
        p: controls the exponent handed to `h_func`, via `A_scaling`.
        alpha: exponent forwarded to `A_scaling` (default 5).
        w: scale forwarded to `A_scaling` (default 50).

    Returns:
        The transformed `x`.
    """
    # TYPING: the annotations here are known to be imprecise -- some of these
    # arguments may be arrays, but which ones has not been pinned down
    exponent = A_scaling(p, alpha, w)  # type: ignore[arg-type]
    return h_func(x, exponent)  # type: ignore[arg-type]
75# `cfg: MazeDatasetConfig` but we can't import that because it would create a circular import
def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
    """Evaluate the fitted success-fraction formula for `cfg`.

    The formula was learned by pysr, see `estimate_dataset_fractions.ipynb` and
    `maze_dataset.benchmark.config_fit`. Its numeric constants come from that
    fit and should not be edited by hand.

    Args:
        cfg: object providing `_to_ps_array()`, whose first five entries are
            used as the inputs of the formula.

    Returns:
        The sigmoid of the fitted expression, passed through `soft_step` with
        `p` set to the first entry of the array.
    """
    x = cfg._to_ps_array()
    x0, x1, x2, x3, x4 = x[0], x[1], x[2], x[3], x[4]

    # The grouping below mirrors the fitted expression exactly; keep the
    # parenthesisation as-is so the floating-point evaluation order is unchanged.

    # first factor of the sigmoid argument
    gate_term = (sigmoid((x1 - x3) ** 3) * -4.721228) - (x3 * 1.4636494)
    mix_term = x2 * (x4 + (((x0 + 0.048765484) ** 9.746339) + (0.8998194**x1)))
    first_factor = (gate_term * mix_term) + (2.4524326 ** (2.9501643 - x0))

    # second factor of the sigmoid argument
    blend_term = ((0.9077277 - x0) * ((x4 * 1.0520288) ** x1)) + x0
    second_factor = (blend_term * sigmoid(x1) ** 3) + -0.18268494

    raw_val: float = sigmoid(first_factor * second_factor)

    return soft_step(
        x=raw_val,
        p=x0,
        alpha=5,  # manually tuned
        w=10,  # manually tuned
    )