Coverage for maze_dataset/dataset/success_predict_math.py: 0%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 01:43 -0600

1"""math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work 

2 

3Desmos link: https://www.desmos.com/calculator/qllvhwftvy 

4""" 

5 

6import numpy as np 

7from jaxtyping import Float 

8 

9 

def sigmoid(x: float) -> float:
    r"""Standard logistic function, $\sigma(x) = \frac{1}{1 + e^{-x}}$.

    Only numpy ufuncs and arithmetic are used, so numpy arrays are handled
    elementwise as well as plain scalars.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

13 

14 

15# sigmoid_shifted = lambda x: 1 / (1 + np.exp(-1000 * (x - 0.5))) 

16# r"sigmoid(x)= 1 / (1 + e^{-b(x-0.5)})" 

17 

18# g_poly = lambda q, a: 1 - np.abs(2 * q - 1) ** a 

19# r"g(q,a) = 1 - (|2q-1|)^{a}" 

20 

21# f_poly = lambda q, a: q * g_poly(q, a) 

22# r"f(q,a) = q * g(q,a)" 

23 

24# h_func = lambda q, a: f_poly(q, a) * (1 - sigmoid_shifted(q)) + (1 - f_poly(1 - q, a)) * sigmoid_shifted(q) 

25# r"h(q,a,b) = f(q,a) * (1-s(q,b)) + (1-f(1-q,a)) * s(q,b)" 

26 

27# A_scaling = lambda q, a, w: w * g_poly(q, a) 

28# r"A(q) = w * g(q, a)"

29 

30 

def sigmoid_shifted(x: float) -> float:
    r"""Very steep logistic centered at 0.5: $\sigma_s(x)= \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}$

    With a slope factor of 1000 this is numerically almost a step function that
    switches from 0 to 1 at ``x = 0.5`` (it is exactly 0.5 at ``x = 0.5``).
    """
    exponent = -1000 * (x - 0.5)
    return 1 / (1 + np.exp(exponent))

34 

35 

def g_poly(q: float, a: float) -> float:
    r"""$g(q,a) = 1 - (|2q-1|)^{a}$

    Equals 0 at ``q = 0`` and ``q = 1``, and 1 at ``q = 0.5`` when ``a > 0``.
    For ``|2q-1| < 1``, a larger exponent ``a`` pushes the value closer to 1,
    widening the plateau around ``q = 0.5``.
    """
    distance_from_center = np.abs(2 * q - 1)
    return 1 - distance_from_center**a

39 

40 

def f_poly(q: float, a: float) -> float:
    r"""$f(q,a) = q \cdot g(q,a)$ -- ``q`` damped by the factor from `g_poly`."""
    damping = g_poly(q, a)
    return q * damping

44 

45 

def h_func(q: float, a: float) -> float:
    r"""$h(q,a) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$

    Blends two branches using the steep switch `sigmoid_shifted` (whose slope is
    fixed, so there is no separate steepness parameter):
    - below ``q = 0.5`` the result follows ``f(q, a)``
    - above ``q = 0.5`` it follows the point reflection ``1 - f(1 - q, a)``
    """
    switch = sigmoid_shifted(q)
    lower_branch = f_poly(q, a)
    upper_branch = 1 - f_poly(1 - q, a)
    return lower_branch * (1 - switch) + upper_branch * switch

51 

52 

def A_scaling(q: float, a: float, w: float) -> float:
    r"""$A(q) = w \cdot g(q, a)$ -- the scale ``w`` damped by the factor from `g_poly`."""
    shape_factor = g_poly(q, a)
    return w * shape_factor

56 

57 

def soft_step(
    x: float | np.floating | Float[np.ndarray, " n"],
    p: float | np.floating,
    alpha: float | np.floating = 5,
    w: float | np.floating = 50,
) -> float:
    """Softly snap `x` towards 0 or 1, with a strength controlled by how extreme `p` is.

    When `p` is close to 0.5 this acts roughly like the identity with respect to `x`.
    When `p` is close to 0 or 1 it pushes `x` to 0 or 1, whichever is closest.

    https://www.desmos.com/calculator/qllvhwftvy

    # Parameters:
    - `x`: value(s) to transform. All helper math is elementwise numpy, so arrays work too.
    - `p`: controls the exponent handed to `h_func`. That exponent is `w` at `p = 0.5`
      (near-identity) and falls to 0 at `p` in {0, 1} (hard step at 0.5).
    - `alpha`: exponent used by `g_poly` when turning `p` into that exponent.
    - `w`: the largest exponent, reached at `p = 0.5`.

    # Returns:
    - the transformed value(s). NOTE(review): the `-> float` annotation is
      inaccurate when `x` is an array; the result then has the shape of `x`.
    """
    # the helpers are annotated for scalars but only use elementwise numpy ops,
    # so numpy scalars/arrays are fine here -- hence the type ignores
    strength = A_scaling(p, alpha, w)  # type: ignore[arg-type]
    return h_func(x, strength)  # type: ignore[arg-type]

73 

74 

75# `cfg: MazeDatasetConfig` but we can't import that because it would create a circular import 

def cfg_success_predict_fn(cfg) -> float:  # noqa: ANN001
    """Estimate the success fraction for a dataset config via a fitted closed-form expression.

    learned by pysr, see `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmark.config_fit`

    # Parameters:
    - `cfg`: a `MazeDatasetConfig`. It is left unannotated because importing it here would
      create a circular import. Only `cfg._to_ps_array()` is used, indexed at 0 through 4.

    # Returns:
    - the PySR prediction squashed through `sigmoid`, then sharpened with `soft_step`
      using feature 0 as `p`.
    """
    feats = cfg._to_ps_array()

    # first factor of the PySR expression (grouping kept identical to the fitted formula)
    gate = (sigmoid((feats[1] - feats[3]) ** 3) * -4.721228) - (feats[3] * 1.4636494)
    poly_term = ((feats[0] + 0.048765484) ** 9.746339) + (0.8998194 ** feats[1])
    left = (gate * (feats[2] * (feats[4] + poly_term))) + (
        2.4524326 ** (2.9501643 - feats[0])
    )

    # second factor of the PySR expression
    right = (
        (((0.9077277 - feats[0]) * ((feats[4] * 1.0520288) ** feats[1])) + feats[0])
        * sigmoid(feats[1]) ** 3
    ) + -0.18268494

    raw_val: float = sigmoid(left * right)
    return soft_step(
        x=raw_val,
        p=feats[0],
        alpha=5,  # manually tuned
        w=10,  # manually tuned
    )