Coverage for maze_dataset/utils.py: 58%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-27 23:43 -0600

1"misc utilities for the `maze_dataset` package" 

2 

3import math 

4from typing import ( 

5 overload, 

6) 

7 

8import numpy as np 

9from jaxtyping import Bool, Int, Int8 

10 

11 

def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    All whitespace in `string` is removed first; every remaining character
    becomes one element of the output, filled in row-major (C) order.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array. An empty shape (`[]`) yields a
        0-d array and requires exactly one symbol.
    true_symbol: str
        The character to parse as True. All other (non-whitespace) characters
        are parsed as False. Characters are compared one at a time, so a
        multi-character `true_symbol` never matches.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape` (also for empty input)

    Raises
    ------
    ValueError
        If the number of non-whitespace characters differs from the product
        of `shape`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    # explicit dtype: `np.array([])` would otherwise be float64, breaking the
    # "always bool" contract for empty shapes such as `[0]`
    return np.array(bools, dtype=bool).reshape(tuple(shape))

52 

53 

def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """Return every index of an `n`^`ndim` grid, ordered shell-by-shell from the corner.

    Indices are grouped by their largest coordinate (`max(x)`), so all indices
    of the `k`^`ndim` sub-grid come before any index outside it. This gives the
    property that `corner_first_ndindex(n, ndim)` is equal to the first
    `n**ndim` elements of `corner_first_ndindex(n + 1, ndim)`.

    Within one shell, an index whose first coordinate is even is compared
    as-is, and one whose first coordinate is odd is compared reversed. Ties are
    left in `np.ndindex` (row-major) order, because `sorted` is stable.

    # Parameters
    - `n`: side length of the grid
    - `ndim`: number of dimensions (default 2)

    # Returns
    - `list[tuple]`: a list (not an ndarray) of `n**ndim` index tuples

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """
    grid_shape: tuple[int, ...] = (n,) * ndim
    unsorted: list = list(np.ndindex(grid_shape))

    def _shell_key(idx: tuple) -> tuple:
        # primary: which shell (largest coordinate); secondary: the index,
        # reversed when its first coordinate is odd
        return (max(idx), idx if idx[0] % 2 == 0 else idx[::-1])

    return sorted(unsorted, key=_shell_key)

71 

72 

# alternate numpy version from GPT-4:
# NOTE(review): the string below is a bare module-level string literal and is
# never executed; it is kept only as a reference. It uses `n` and `ndim`, which
# are not defined at module scope, and its `return` is only valid inside a function.
# NOTE(review): it does not appear to reproduce `corner_first_ndindex`'s ordering.
# It reverses coordinates when max(x) is odd, rather than when x[0] is odd. Also,
# `np.lexsort` treats its *last* key as primary, so ties within a shell are broken
# by the last coordinate first. By hand-trace for n=3, shell 2 comes out as
# (2,0),(2,1),(0,2),(1,2),(2,2) instead of (0,2),(2,0),(1,2),(2,1),(2,2).
# It also returns an ndarray rather than a list of tuples. Verify before swapping it in.
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""

89 

90 

91@overload 

92def manhattan_distance( 

93 edges: Int[np.ndarray, "edges coord=2 row_col=2"], 

94) -> Int8[np.ndarray, " edges"]: ... 

95# TYPING: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader [overload-cannot-match] 

96# this is because mypy doesn't play nice with jaxtyping 

97@overload 

98def manhattan_distance( # type: ignore[overload-cannot-match] 

99 edges: Int[np.ndarray, "coord=2 row_col=2"], 

100) -> int: ... 

101def manhattan_distance( 

102 edges: ( 

103 Int[np.ndarray, "edges coord=2 row_col=2"] 

104 | Int[np.ndarray, "coord=2 row_col=2"] 

105 ), 

106) -> Int8[np.ndarray, " edges"] | int: 

107 """Returns the Manhattan distance between two coords.""" 

108 # magic values for dims fine here 

109 if len(edges.shape) == 3: # noqa: PLR2004 

110 return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype( 

111 np.int8, 

112 ) 

113 elif len(edges.shape) == 2: # noqa: PLR2004 

114 return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8)) 

115 else: 

116 err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints." 

117 raise ValueError(err_msg) 

118 

119 

def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Return an `n` x `n` array with the maximum possible degree of each coord.

    In a square lattice with `n >= 2`, corners have at most 2 neighbours, other
    border coords 3, and interior coords 4. The single coord of a 1 x 1 lattice
    has no neighbours, so its maximum degree is 0.

    # Parameters
    - `n`: side length of the lattice

    # Returns
    - `Int8[np.ndarray, "row col"]`: array of shape `(n, n)` with dtype `int8`,
      as the annotation promises
    """
    # one neighbour along each axis for every coord (true for n >= 2) ...
    out = np.full((n, n), 2, dtype=np.int8)
    # ... plus one more along the row axis for coords not on the first/last row
    out[1:-1, :] += 1
    # ... and one more along the column axis for coords not on the first/last column
    out[:, 1:-1] += 1
    if n == 1:
        # a lone coord has no neighbours at all
        out[...] = 0
    return out

126 

127 

def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Return every edge of a 2D square lattice of size `n` x `n`.

    # Parameters
    - `n`: the side length of the square lattice

    # Returns
    - `np.ndarray` of shape `(2 * n * (n - 1), 2, 2)` and dtype `int8`. Entry
      `[e, k]` is the `(row, col)` of endpoint `k` of edge `e`.
      - The first `n * (n - 1)` edges join `(r, c)` to `(r, c + 1)`, ordered
        by row, then by column.
      - The remaining `n * (n - 1)` edges join `(r, c)` to `(r + 1, c)`, in
        the same order.

      In each pair, the coord with the smaller sum always comes first.
    """
    axis_values = np.arange(n, dtype=np.int8)
    grid_rows, grid_cols = np.meshgrid(axis_values, axis_values, indexing="ij")
    # coords[r, c] == (r, c); shape (n, n, 2)
    coords = np.stack((grid_rows, grid_cols), axis=-1)

    edges_per_direction = n * (n - 1)

    # pair each coord with its right-hand neighbour, then flatten row-major
    right_pairs = np.stack((coords[:, :-1], coords[:, 1:]), axis=-2)
    # pair each coord with the neighbour directly below it, then flatten row-major
    down_pairs = np.stack((coords[:-1, :], coords[1:, :]), axis=-2)

    return np.concatenate(
        (
            right_pairs.reshape(edges_per_direction, 2, 2),
            down_pairs.reshape(edges_per_direction, 2, 2),
        ),
        axis=0,
    )

172 

173 

def adj_list_to_nested_set(adj_list: list) -> set:
    """Convert an adjacency list into an order-insensitive set for comparison.

    The input looks like `[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`, i.e. a
    sequence of coordinate pairs. Each pair becomes a `frozenset` of two coord
    tuples, and all pairs are collected into a `set`. Neither the order of the
    pairs nor the order of the two coords within a pair affects the result.
    """
    nested: set = set()
    for pair in adj_list:
        first, second = pair
        nested.add(frozenset([tuple(first), tuple(second)]))
    return nested