docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.utils

misc utilities for the maze_dataset package


  1"misc utilities for the `maze_dataset` package"
  2
  3import math
  4from typing import (
  5	overload,
  6)
  7
  8import numpy as np
  9from jaxtyping import Bool, Int, Int8
 10
 11
 12def bool_array_from_string(
 13	string: str,
 14	shape: list[int],
 15	true_symbol: str = "T",
 16) -> Bool[np.ndarray, "*shape"]:
 17	"""Transform a string into an ndarray of bools.
 18
 19	Parameters
 20	----------
 21	string: str
 22		The string representation of the array
 23	shape: list[int]
 24		The shape of the resulting array
 25	true_symbol:
 26		The character to parse as True. Whitespace will be removed. All other characters will be parsed as False.
 27
 28	Returns
 29	-------
 30	np.ndarray
 31		A ndarray with dtype bool of shape `shape`
 32
 33	Examples
 34	--------
 35	>>> bool_array_from_string(
 36	...	 "TT TF", shape=[2,2]
 37	... )
 38	array([[ True,  True],
 39		[ True, False]])
 40
 41	"""
 42	stripped = "".join(string.split())
 43
 44	expected_symbol_count = math.prod(shape)
 45	symbol_count = len(stripped)
 46	if len(stripped) != expected_symbol_count:
 47		err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
 48		raise ValueError(err_msg)
 49
 50	bools = [(symbol == true_symbol) for symbol in stripped]
 51	return np.array(bools).reshape(*shape)
 52
 53
 54def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
 55	"""returns an array of indices, sorted by distance from the corner
 56
 57	this gives the property that `np.ndindex((n,n))` is equal to
 58	the first n^2 elements of `np.ndindex((n+1, n+1))`
 59
 60	```
 61	>>> corner_first_ndindex(1)
 62	[(0, 0)]
 63	>>> corner_first_ndindex(2)
 64	[(0, 0), (0, 1), (1, 0), (1, 1)]
 65	>>> corner_first_ndindex(3)
 66	[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
 67	```
 68	"""
 69	unsorted: list = list(np.ndindex(tuple([n for _ in range(ndim)])))
 70	return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1]))
 71
 72
# alternate numpy version from GPT-4:
# NOTE(review): this is an inert module-level string kept only for reference;
# it is never executed, and it refers to `n`/`ndim` that are not defined here.
# It is NOT equivalent to `corner_first_ndindex`. It reverses coordinates based
# on the parity of the max coordinate rather than the first coordinate. Also,
# `np.lexsort` treats the LAST key as most significant, so the last coordinate
# outranks the first. For example, for n=3 the max==2 shell comes out as
# (2,0),(2,1),(0,2),(1,2),(2,2) instead of (0,2),(2,0),(1,2),(2,1),(2,2).
# It also returns an ndarray, not a list of tuples. Verify before using.
"""
# Create all index combinations
indices = np.indices([n]*ndim).reshape(ndim, -1).T
# Find the max value for each index
max_indices = np.max(indices, axis=1)
# Identify the odd max values
odd_mask = max_indices % 2 != 0
# Make a copy of indices to avoid changing the original one
indices_copy = indices.copy()
# Reverse the order of the coordinates for indices with odd max value
indices_copy[odd_mask] = indices_copy[odd_mask, ::-1]
# Sort by max index value, then by coordinates
sorted_order = np.lexsort((*indices_copy.T, max_indices))
return indices[sorted_order]
"""
 89
 90
 91@overload
 92def manhattan_distance(
 93	edges: Int[np.ndarray, "edges coord=2 row_col=2"],
 94) -> Int8[np.ndarray, " edges"]: ...
 95# TYPING: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader  [overload-cannot-match]
 96# this is because mypy doesn't play nice with jaxtyping
 97@overload
 98def manhattan_distance(  # type: ignore[overload-cannot-match]
 99	edges: Int[np.ndarray, "coord=2 row_col=2"],
100) -> int: ...
101def manhattan_distance(
102	edges: (
103		Int[np.ndarray, "edges coord=2 row_col=2"]
104		| Int[np.ndarray, "coord=2 row_col=2"]
105	),
106) -> Int8[np.ndarray, " edges"] | int:
107	"""Returns the Manhattan distance between two coords."""
108	# magic values for dims fine here
109	if len(edges.shape) == 3:  # noqa: PLR2004
110		return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
111			np.int8,
112		)
113	elif len(edges.shape) == 2:  # noqa: PLR2004
114		return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8))
115	else:
116		err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
117		raise ValueError(err_msg)
118
119
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
	"""Returns an array with the maximum possible degree for each coord.

	The maximum degree of a coord is its number of in-bounds 4-neighbours in an
	`n x n` lattice: 4 in the interior, 3 on a border, 2 in a corner, and 0 when
	`n == 1` (a single cell has no neighbours).

	# Parameters
	- `n`: side length of the square lattice

	# Returns
	- `int8` array of shape `(n, n)`

	# Raises
	- `ValueError`: if `n` is negative
	"""
	if n < 0:
		err_msg: str = f"lattice size must be non-negative, got {n}"
		raise ValueError(err_msg)
	idx = np.arange(n)
	# neighbours along one axis: one "before" (idx > 0) plus one "after" (idx < n - 1)
	per_axis = (idx > 0).astype(np.int8) + (idx < n - 1).astype(np.int8)
	# total degree = row-axis neighbours + col-axis neighbours, broadcast to (n, n)
	return per_axis[:, None] + per_axis[None, :]
126
127
def lattice_connection_array(
	n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
	"""Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

	Thanks Claude.

	# Parameters
	- `n`: The size of the square lattice.

	# Returns
	np.ndarray: an `int8` array of shape `(2*n*(n-1), 2, 2)`. Entry `[e, k]` is the
	`(row, col)` of endpoint `k` of edge `e`.
	- The first `n*(n-1)` edges are horizontal, `(r, c) -> (r, c+1)`.
	- The remaining `n*(n-1)` edges are vertical, `(r, c) -> (r+1, c)`.
	- Each group is in row-major order of its first coord.
	In each pair, the coord with the smaller sum always comes first.

	NOTE(review): coords are built with `dtype=np.int8`, so this presumably
	assumes `n <= 128`; confirm before using larger lattices.
	"""
	axis_coords = np.arange(n, dtype=np.int8)
	rows, cols = np.meshgrid(axis_coords, axis_coords, indexing="ij")
	# grid[r, c] == (r, c); shape (n, n, 2)
	grid = np.stack((rows, cols), axis=-1)

	# pair every coord with its right-hand neighbour, then with the one below it
	horizontal = np.stack((grid[:, :-1], grid[:, 1:]), axis=-2).reshape(-1, 2, 2)
	vertical = np.stack((grid[:-1, :], grid[1:, :]), axis=-2).reshape(-1, 2, 2)

	return np.concatenate((horizontal, vertical), axis=0)
172
173
def adj_list_to_nested_set(adj_list: list) -> set:
	"""Convert an adjacency list into an order-insensitive set, for comparing adj_lists.

	Each entry of `adj_list` is a pair of coordinates, e.g.
	`[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`. Every pair becomes a `frozenset`
	of two coordinate tuples. As a result, neither the order of pairs within the
	list nor the order of the two coordinates within a pair affects the result.
	"""
	nested: set = set()
	for pair in adj_list:
		first, second = pair
		nested.add(frozenset((tuple(first), tuple(second))))
	return nested

def bool_array_from_string( string: str, shape: list[int], true_symbol: str = 'T') -> jaxtyping.Bool[ndarray, '*shape']:
def bool_array_from_string(
	string: str,
	shape: list[int],
	true_symbol: str = "T",
) -> Bool[np.ndarray, "*shape"]:
	"""Transform a string into an ndarray of bools.

	All whitespace in `string` is removed before parsing; every remaining
	character becomes one element of the output, filled in row-major order.

	Parameters
	----------
	string: str
		The string representation of the array
	shape: list[int]
		The shape of the resulting array (any sequence of ints; an empty
		sequence produces a 0-d array from a single symbol)
	true_symbol: str
		The character to parse as True. All other characters will be parsed as False.

	Returns
	-------
	np.ndarray
		An ndarray with dtype bool of shape `shape`

	Raises
	------
	ValueError
		If the number of non-whitespace characters differs from the product of `shape`.

	Examples
	--------
	>>> bool_array_from_string(
	...     "TT TF", shape=[2, 2]
	... )
	array([[ True,  True],
	       [ True, False]])

	"""
	stripped = "".join(string.split())

	expected_symbol_count = math.prod(shape)
	symbol_count = len(stripped)
	if symbol_count != expected_symbol_count:
		err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}."
		raise ValueError(err_msg)

	# explicit dtype: without it an empty input (e.g. shape=[0]) would become float64
	bools = np.array([symbol == true_symbol for symbol in stripped], dtype=bool)
	# pass the shape as one tuple so that shape=[] (a 0-d array) also works
	return bools.reshape(tuple(shape))

Transform a string into an ndarray of bools.

Parameters

string: str — the string representation of the array. shape: list[int] — the shape of the resulting array. true_symbol: the character to parse as True; all other characters will be parsed as False. Whitespace in string will be removed before parsing.

Returns

np.ndarray — an ndarray with dtype bool of shape `shape`

Examples

>>> bool_array_from_string(
...      "TT TF", shape=[2,2]
... )
array([[ True,  True],
        [ True, False]])
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]:
	"""returns a list of index tuples, sorted by distance from the corner

	Indices are ordered primarily by their largest coordinate (Chebyshev
	distance from the all-zeros corner). Every index of the `n^ndim` hypercube
	therefore comes before any index that exists only in the `(n+1)^ndim`
	hypercube. This gives the property that `corner_first_ndindex(n)` is equal
	to the first `n^ndim` elements of `corner_first_ndindex(n+1)`. Plain
	`np.ndindex` does NOT have this property.

	Within one shell of equal largest coordinate, indices are compared as
	tuples, reversed when their first coordinate is odd. Ties keep
	`np.ndindex` (row-major) order because `sorted` is stable.

	```
	>>> corner_first_ndindex(1)
	[(0, 0)]
	>>> corner_first_ndindex(2)
	[(0, 0), (0, 1), (1, 0), (1, 1)]
	>>> corner_first_ndindex(3)
	[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
	```
	"""
	unsorted: list = list(np.ndindex((n,) * ndim))

	def _shell_key(idx: tuple) -> tuple:
		# primary: which shell the index lies in; secondary: (possibly reversed) tuple order
		return (max(idx), idx if idx[0] % 2 == 0 else idx[::-1])

	return sorted(unsorted, key=_shell_key)

returns an array of indices, sorted by distance from the corner

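this gives the property that corner_first_ndindex(n) is equal to the first n^2 elements of corner_first_ndindex(n+1) (plain np.ndindex((n,n)) does not have this property, since row-major order visits (0, n) before (1, 0))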

>>> corner_first_ndindex(1)
[(0, 0)]
>>> corner_first_ndindex(2)
[(0, 0), (0, 1), (1, 0), (1, 1)]
>>> corner_first_ndindex(3)
[(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)]
def manhattan_distance( edges: jaxtyping.Int[ndarray, 'edges coord=2 row_col=2'] | jaxtyping.Int[ndarray, 'coord=2 row_col=2']) -> jaxtyping.Int8[ndarray, 'edges'] | int:
102def manhattan_distance(
103	edges: (
104		Int[np.ndarray, "edges coord=2 row_col=2"]
105		| Int[np.ndarray, "coord=2 row_col=2"]
106	),
107) -> Int8[np.ndarray, " edges"] | int:
108	"""Returns the Manhattan distance between two coords."""
109	# magic values for dims fine here
110	if len(edges.shape) == 3:  # noqa: PLR2004
111		return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype(
112			np.int8,
113		)
114	elif len(edges.shape) == 2:  # noqa: PLR2004
115		return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8))
116	else:
117		err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints."
118		raise ValueError(err_msg)

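Returns the Manhattan distance between the two coords of an edge: a plain int for a single edge of shape (2, 2), or an int8 array with one distance per edge for a batch of shape (n_edges, 2, 2).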

def lattice_max_degrees(n: int) -> jaxtyping.Int8[ndarray, 'row col']:
def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]:
	"""Returns an array with the maximum possible degree for each coord.

	The maximum degree of a coord is its number of in-bounds 4-neighbours in an
	`n x n` lattice: 4 in the interior, 3 on a border, 2 in a corner, and 0 when
	`n == 1` (a single cell has no neighbours).

	# Parameters
	- `n`: side length of the square lattice

	# Returns
	- `int8` array of shape `(n, n)`

	# Raises
	- `ValueError`: if `n` is negative
	"""
	if n < 0:
		err_msg: str = f"lattice size must be non-negative, got {n}"
		raise ValueError(err_msg)
	idx = np.arange(n)
	# neighbours along one axis: one "before" (idx > 0) plus one "after" (idx < n - 1)
	per_axis = (idx > 0).astype(np.int8) + (idx < n - 1).astype(np.int8)
	# total degree = row-axis neighbours + col-axis neighbours, broadcast to (n, n)
	return per_axis[:, None] + per_axis[None, :]

Returns an array with the maximum possible degree for each coord.

def lattice_connection_array( n: int) -> jaxtyping.Int8[ndarray, 'edges=2*n*(n-1) leading_trailing_coord=2 row_col=2']:
def lattice_connection_array(
	n: int,
) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]:
	"""Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

	Thanks Claude.

	# Parameters
	- `n`: The size of the square lattice.

	# Returns
	np.ndarray: an `int8` array of shape `(2*n*(n-1), 2, 2)`. Entry `[e, k]` is the
	`(row, col)` of endpoint `k` of edge `e`.
	- The first `n*(n-1)` edges are horizontal, `(r, c) -> (r, c+1)`.
	- The remaining `n*(n-1)` edges are vertical, `(r, c) -> (r+1, c)`.
	- Each group is in row-major order of its first coord.
	In each pair, the coord with the smaller sum always comes first.

	NOTE(review): coords are built with `dtype=np.int8`, so this presumably
	assumes `n <= 128`; confirm before using larger lattices.
	"""
	axis_coords = np.arange(n, dtype=np.int8)
	rows, cols = np.meshgrid(axis_coords, axis_coords, indexing="ij")
	# grid[r, c] == (r, c); shape (n, n, 2)
	grid = np.stack((rows, cols), axis=-1)

	# pair every coord with its right-hand neighbour, then with the one below it
	horizontal = np.stack((grid[:, :-1], grid[:, 1:]), axis=-2).reshape(-1, 2, 2)
	vertical = np.stack((grid[:-1, :], grid[1:, :]), axis=-2).reshape(-1, 2, 2)

	return np.concatenate((horizontal, vertical), axis=0)

Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n.

Thanks Claude.

Parameters

  • n: The size of the square lattice.

Returns

np.ndarray: A 3D NumPy array of shape (2*n*(n-1), 2, 2) containing the coordinates of the edges in the 2D square lattice. In each pair, the coord with the smaller sum always comes first.

def adj_list_to_nested_set(adj_list: list) -> set:
def adj_list_to_nested_set(adj_list: list) -> set:
	"""Convert an adjacency list into an order-insensitive set, for comparing adj_lists.

	Each entry of `adj_list` is a pair of coordinates, e.g.
	`[[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...]`. Every pair becomes a `frozenset`
	of two coordinate tuples. As a result, neither the order of pairs within the
	list nor the order of the two coordinates within a pair affects the result.
	"""
	nested: set = set()
	for pair in adj_list:
		first, second = pair
		nested.add(frozenset((tuple(first), tuple(second))))
	return nested

Used for comparison of adj_lists

Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...] We don't care about order of coordinate pairs within the adj_list or coordinates within each coordinate pair.