Coverage for maze_dataset/utils.py: 58%
40 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-27 23:43 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-27 23:43 -0600
1"misc utilities for the `maze_dataset` package"
3import math
4from typing import (
5 overload,
6)
8import numpy as np
9from jaxtyping import Bool, Int, Int8
def bool_array_from_string(
    string: str,
    shape: list[int],
    true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
    """Transform a string into an ndarray of bools.

    Parameters
    ----------
    string: str
        The string representation of the array
    shape: list[int]
        The shape of the resulting array
    true_symbol:
        The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.

    Returns
    -------
    np.ndarray
        A ndarray with dtype bool of shape `shape`

    Raises
    ------
    ValueError
        If the number of non-whitespace characters in `string` differs from `math.prod(shape)`.

    Examples
    --------
    >>> bool_array_from_string(
    ...     "TT TF", shape=[2,2]
    ... )
    array([[ True,  True],
           [ True, False]])

    """
    # `str.split()` with no arguments splits on any run of whitespace
    stripped = "".join(string.split())

    expected_symbol_count = math.prod(shape)
    symbol_count = len(stripped)
    if symbol_count != expected_symbol_count:
        err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
        raise ValueError(err_msg)

    bools = [(symbol == true_symbol) for symbol in stripped]
    # explicit dtype: an empty list would otherwise be inferred as float64.
    # the shape is passed as one tuple so that an empty `shape` (0-d result)
    # does not turn into a `reshape()` call with no arguments.
    return np.array(bools, dtype=bool).reshape(tuple(shape))
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
    """returns a list of indices, sorted by distance from the corner

    indices are ordered primarily by their largest coordinate, which
    gives the property that `corner_first_ndindex(n)` is equal to
    the first n^2 elements of `corner_first_ndindex(n+1)`

    # Parameters
    - `n`: side length of the grid along every axis
    - `ndim`: number of axes (defaults to 2)

    # Returns
    list of index tuples covering the whole `n`-sided grid

    ```
    >>> corner_first_ndindex(1)
    [(0, 0)]
    >>> corner_first_ndindex(2)
    [(0, 0), (0, 1), (1, 0), (1, 1)]
    >>> corner_first_ndindex(3)
    [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
    ```
    """
    # primary key: largest coordinate; secondary key: the index itself, or
    # its reverse when the first coordinate is odd. `sorted` is stable, so
    # indices with equal keys keep the order `np.ndindex` yields them in.
    return sorted(
        np.ndindex((n,) * ndim),
        key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]),
    )
# alternate numpy version from GPT-4 (unused, kept for reference only).
# NOTE: this snippet reverses the coordinates when the *max* coordinate is odd,
# whereas `corner_first_ndindex` reverses them when the *first* coordinate is
# odd, so it should not be assumed to produce the same ordering.
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""
@overload
def manhattan_distance(
    edges: Int[np.ndarray, "edges coord=2 row_col=2"],
) -> Int8[np.ndarray, " edges"]: ...
# TYPING: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader  [overload-cannot-match]
# this is because mypy doesn't play nice with jaxtyping
@overload
def manhattan_distance(  # type: ignore[overload-cannot-match]
    edges: Int[np.ndarray, "coord=2 row_col=2"],
) -> int: ...
def manhattan_distance(
    edges: (
        Int[np.ndarray, "edges coord=2 row_col=2"]
        | Int[np.ndarray, "coord=2 row_col=2"]
    ),
) -> Int8[np.ndarray, " edges"] | int:
    """Returns the Manhattan distance between the two coords of each edge.

    # Parameters
    - `edges`: either a single edge (2D array, two coords stacked along the
      first axis) or a batch of edges (3D array, one edge per leading index)

    # Returns
    - for a single edge: the distance as a python `int`
    - for a batch: an int8 array with one distance per edge
      (distances outside the int8 range cannot be represented in that array)

    # Raises
    - `ValueError`: if `edges` is neither 2- nor 3-dimensional
    """
    n_dims: int = edges.ndim
    # magic values for dims fine here
    if n_dims == 3:  # noqa: PLR2004
        # L1 norm of the coordinate difference, taken per edge
        return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
            np.int8,
        )
    if n_dims == 2:  # noqa: PLR2004
        # convert straight to a python int: the result is not an array, so
        # narrowing through int8 first would only risk corrupting large values
        return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1))

    err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
    raise ValueError(err_msg)
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
    """Returns an array with the maximum possible degree for each coord.

    # Parameters
    - `n`: side length of the square lattice

    # Returns
    int8 array of shape `(n, n)`; each entry is the number of lattice
    neighbors of that coord (for `n >= 2`: 2 at corners, 3 along edges,
    4 in the interior; a 1x1 lattice has no neighbors, so its entry is 0)
    """
    # neighbors available along a single axis: 2, minus one at each end.
    # slices instead of plain indices so that n == 0 stays valid and n == 1
    # (where both ends are the same cell) correctly drops to 0.
    per_axis = np.full(n, 2, dtype=np.int8)
    per_axis[:1] -= 1
    per_axis[-1:] -= 1
    # degree of (row, col) = neighbors along rows + neighbors along columns
    return per_axis[:, None] + per_axis[None, :]
def lattice_connection_array(
    n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
    """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

    Thanks Claude.

    # Parameters
    - `n`: The size of the square lattice.

    # Returns
    np.ndarray: A 3D int8 NumPy array of shape `(2*n*(n-1), 2, 2)` containing the
    coordinates of the edges in the 2D square lattice. The first `n*(n-1)` entries
    are the horizontal edges (same row, adjacent columns), followed by the
    `n*(n-1)` vertical edges (same column, adjacent rows).
    In each pair, the coord with the smaller sum always comes first.
    """
    axis_values = np.arange(n, dtype=np.int8)
    # with "ij" indexing, row_coords[r, c] == r and col_coords[r, c] == c
    row_coords, col_coords = np.meshgrid(axis_values, axis_values, indexing="ij")

    edges_per_direction: int = n**2 - n

    # Horizontal edges: (r, c) -> (r, c + 1)
    horiz_edges = np.column_stack(
        (
            row_coords[:, :-1].ravel(),
            col_coords[:, :-1].ravel(),
            row_coords[:, 1:].ravel(),
            col_coords[:, 1:].ravel(),
        ),
    )

    # Vertical edges: (r, c) -> (r + 1, c)
    vert_edges = np.column_stack(
        (
            row_coords[:-1, :].ravel(),
            col_coords[:-1, :].ravel(),
            row_coords[1:, :].ravel(),
            col_coords[1:, :].ravel(),
        ),
    )

    # each flat row (r0, c0, r1, c1) becomes the pair [[r0, c0], [r1, c1]]
    return np.concatenate(
        (
            horiz_edges.reshape(edges_per_direction, 2, 2),
            vert_edges.reshape(edges_per_direction, 2, 2),
        ),
        axis=0,
    )
def adj_list_to_nested_set(adj_list: list) -> set:
    """Used for comparison of adj_lists

    Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]
    We don't care about order of coordinate pairs within
    the adj_list or coordinates within each coordinate pair.

    # Parameters
    - `adj_list`: iterable of edges, each edge being a pair of coordinates

    # Returns
    a set with one `frozenset` per edge, holding the edge's two coordinates
    as tuples; two adj_lists that differ only in ordering give equal sets
    """
    return {
        # frozenset: hashable, and ignores which coordinate came first
        frozenset([tuple(start_coord), tuple(end_coord)])
        for start_coord, end_coord in adj_list
    }