maze_dataset.utils
Miscellaneous utilities for the `maze_dataset` package.
"misc utilities for the `maze_dataset` package"

import math
from typing import (
    overload,
)

import numpy as np
from jaxtyping import Bool, Int, Int8


def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol:
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace symbols does not equal `math.prod(shape)`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    # `str.split()` with no arguments drops every run of whitespace
    symbols = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(symbols)
    if symbol_count != expected_symbol_count:
        err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {symbols}."
        raise ValueError(err_msg)

    bools = [symbol == true_symbol for symbol in symbols]
    # explicit dtype: an empty list would otherwise be inferred as float64.
    # the shape is passed as one tuple so that an empty shape is also valid.
    return np.array(bools, dtype=bool).reshape(tuple(shape))


def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """returns a list of indices, sorted by distance from the corner

    indices are ordered by their largest coordinate first. this gives the
    property that `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    within one "shell" (same largest coordinate), an index whose first
    coordinate is even is compared as-is and one whose first coordinate is odd
    is compared reversed; remaining ties keep the `np.ndindex` order because
    the sort is stable.

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """

    def _shell_key(index: tuple) -> tuple:
        # primary key: which shell the index is on; secondary key: tie-break
        return (max(index), index if index[0] % 2 == 0 else index[::-1])

    return sorted(np.ndindex((n,) * ndim), key=_shell_key)


# alternate numpy version from GPT-4:
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""


@overload
def manhattan_distance(
    edges: Int[np.ndarray, "edges coord=2 row_col=2"],
) -> Int8[np.ndarray, " edges"]: ...
# TYPING: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader  [overload-cannot-match]
# this is because mypy doesn't play nice with jaxtyping
@overload
def manhattan_distance(  # type: ignore[overload-cannot-match]
    edges: Int[np.ndarray, "coord=2 row_col=2"],
) -> int: ...
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between two coords.

    # Parameters
    - `edges`: either a single pair of coords (2D array, rows 0 and 1 are the
      two coords) or a batch of such pairs (3D array, one pair per leading index).

    # Returns
    - for a 2D input: the distance as a Python `int`
    - for a 3D input: an int8 array with one distance per pair
      (distances above 127 do not fit into int8)

    # Raises
    - `ValueError`: if `edges` is neither 2D nor 3D
    """
    # magic values for dims fine here
    if edges.ndim not in (2, 3):  # noqa: PLR2004
        err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
        raise ValueError(err_msg)

    # widen to a signed 64-bit type before subtracting, so that unsigned or
    # narrow integer inputs cannot wrap around
    widened = edges.astype(np.int64)

    if edges.ndim == 2:  # noqa: PLR2004
        # single pair: the result is a Python int, so no int8 round trip
        return int(np.abs(widened[0, :] - widened[1, :]).sum())

    deltas = widened[:, 0, :] - widened[:, 1, :]
    return np.abs(deltas).sum(axis=1).astype(np.int8)


def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord.

    For an `n x n` lattice with `n > 1`, corners get 2, other border coords 3
    and interior coords 4. A `1 x 1` lattice has a single coord with no
    neighbours, so its degree is 0.
    """
    # every coord of a lattice with more than one row/column has at least
    # one neighbour along each axis; a single-cell lattice has none
    base_degree = 2 if n > 1 else 0
    out = np.full((n, n), base_degree, dtype=np.int8)
    # coords not on the first/last row gain a second vertical neighbour
    out[1:-1, :] += 1
    # coords not on the first/last column gain a second horizontal neighbour
    out[:, 1:-1] += 1
    return out


def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: A 3D int8 NumPy array of shape `(2*n*(n-1), 2, 2)` containing the coordinates of the edges in the 2D square lattice.
    Horizontal edges come first, followed by vertical edges.
    In each pair, the coord with the smaller sum always comes first.
    """
    edge_count = n**2 - n

    # coords[i, j] == (i, j)
    coords = np.moveaxis(np.indices((n, n), dtype=np.int8), 0, -1)

    # Horizontal edges: (i, j) -> (i, j + 1)
    horiz_edges = np.stack((coords[:, :-1], coords[:, 1:]), axis=2)

    # Vertical edges: (i, j) -> (i + 1, j)
    vert_edges = np.stack((coords[:-1, :], coords[1:, :]), axis=2)

    return np.concatenate(
        (
            horiz_edges.reshape(edge_count, 2, 2),
            vert_edges.reshape(edge_count, 2, 2),
        ),
        axis=0,
    )


def adj_list_to_nested_set(adj_list: list) -> set:
    """Used for comparison of adj_lists

    Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]
    We don't care about order of coordinate pairs within
    the adj_list or coordinates within each coordinate pair.
    """
    nested: set = set()
    for first_coord, second_coord in adj_list:
        # a frozenset makes the pair hashable and ignores coord order
        nested.add(frozenset((tuple(first_coord), tuple(second_coord))))
    return nested
def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol:
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace symbols does not equal `math.prod(shape)`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    # `str.split()` with no arguments drops every run of whitespace
    symbols = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(symbols)
    if symbol_count != expected_symbol_count:
        err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {symbols}."
        raise ValueError(err_msg)

    bools = [symbol == true_symbol for symbol in symbols]
    # explicit dtype: an empty list would otherwise be inferred as float64.
    # the shape is passed as one tuple so that an empty shape is also valid.
    return np.array(bools, dtype=bool).reshape(tuple(shape))
Transform a string into an ndarray of bools.
Parameters
string: str The string representation of the array shape: list[int] The shape of the resulting array true_symbol: The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.
Returns
np.ndarray
An ndarray with dtype bool and shape `shape`.
Examples
>>> bool_array_from_string(
... "TT TF", shape=[2,2]
... )
array([[ True, True],
[ True, False]])
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """returns a list of indices, sorted by distance from the corner

    indices are ordered by their largest coordinate first. this gives the
    property that `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    within one "shell" (same largest coordinate), an index whose first
    coordinate is even is compared as-is and one whose first coordinate is odd
    is compared reversed; remaining ties keep the `np.ndindex` order because
    the sort is stable.

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """

    def _shell_key(index: tuple) -> tuple:
        # primary key: which shell the index is on; secondary key: tie-break
        return (max(index), index if index[0] % 2 == 0 else index[::-1])

    return sorted(np.ndindex((n,) * ndim), key=_shell_key)
Returns a list of index tuples, sorted by distance from the corner (measured as the largest coordinate).
This gives the property that corner_first_ndindex(n) is equal to the first n^2 elements of corner_first_ndindex(n+1).
>>> corner_first_ndindex(1)
[(0, 0)]
>>> corner_first_ndindex(2)
[(0, 0), (0, 1), (1, 0), (1, 1)]
>>> corner_first_ndindex(3)
[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between two coords.

    # Parameters
    - `edges`: either a single pair of coords (2D array, rows 0 and 1 are the
      two coords) or a batch of such pairs (3D array, one pair per leading index).

    # Returns
    - for a 2D input: the distance as a Python `int`
    - for a 3D input: an int8 array with one distance per pair
      (distances above 127 do not fit into int8)

    # Raises
    - `ValueError`: if `edges` is neither 2D nor 3D
    """
    # magic values for dims fine here
    if edges.ndim not in (2, 3):  # noqa: PLR2004
        err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
        raise ValueError(err_msg)

    # widen to a signed 64-bit type before subtracting, so that unsigned or
    # narrow integer inputs cannot wrap around
    widened = edges.astype(np.int64)

    if edges.ndim == 2:  # noqa: PLR2004
        # single pair: the result is a Python int, so no int8 round trip
        return int(np.abs(widened[0, :] - widened[1, :]).sum())

    deltas = widened[:, 0, :] - widened[:, 1, :]
    return np.abs(deltas).sum(axis=1).astype(np.int8)
Returns the Manhattan distance between two coords.
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord.

    For an `n x n` lattice with `n > 1`, corners get 2, other border coords 3
    and interior coords 4. A `1 x 1` lattice has a single coord with no
    neighbours, so its degree is 0.
    """
    # every coord of a lattice with more than one row/column has at least
    # one neighbour along each axis; a single-cell lattice has none
    base_degree = 2 if n > 1 else 0
    # int8 matches the declared return type; values never exceed 4
    out = np.full((n, n), base_degree, dtype=np.int8)
    # coords not on the first/last row gain a second vertical neighbour
    out[1:-1, :] += 1
    # coords not on the first/last column gain a second horizontal neighbour
    out[:, 1:-1] += 1
    return out
Returns an array with the maximum possible degree for each coord.
def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: A 3D int8 NumPy array of shape `(2*n*(n-1), 2, 2)` containing the coordinates of the edges in the 2D square lattice.
    Horizontal edges come first, followed by vertical edges.
    In each pair, the coord with the smaller sum always comes first.
    """
    edge_count = n**2 - n

    # coords[i, j] == (i, j)
    coords = np.moveaxis(np.indices((n, n), dtype=np.int8), 0, -1)

    # Horizontal edges: (i, j) -> (i, j + 1)
    horiz_edges = np.stack((coords[:, :-1], coords[:, 1:]), axis=2)

    # Vertical edges: (i, j) -> (i + 1, j)
    vert_edges = np.stack((coords[:-1, :], coords[1:, :]), axis=2)

    return np.concatenate(
        (
            horiz_edges.reshape(edge_count, 2, 2),
            vert_edges.reshape(edge_count, 2, 2),
        ),
        axis=0,
    )
Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.
Thanks Claude.
Parameters
`n`: The size of the square lattice.
Returns
np.ndarray: A 3D NumPy array of shape (2*n*(n-1), 2, 2) containing the coordinates of the edges in the 2D square lattice. In each pair, the coord with the smaller sum always comes first.
def adj_list_to_nested_set(adj_list: list) -> set:
    """Convert an adj_list into an order-insensitive form for comparison.

    Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]
    The result is a set of frozensets of coordinate tuples, so neither the
    order of coordinate pairs within the adj_list nor the order of the two
    coordinates within each pair affects equality.
    """
    nested: set = set()
    for first_coord, second_coord in adj_list:
        # a frozenset makes the pair hashable and ignores coord order
        nested.add(frozenset((tuple(first_coord), tuple(second_coord))))
    return nested
Used for comparison of adj_lists
Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...] We don't care about order of coordinate pairs within the adj_list or coordinates within each coordinate pair.