docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.token_utils

A collection of utilities for tokenization.


  1"""a whole bunch of utilities for tokenization"""
  2
  3import re
  4import typing
  5import warnings
  6from collections import Counter
  7from typing import Callable, Literal, overload
  8
  9import numpy as np
 10from jaxtyping import Bool, Float, Int, Int8
 11from muutils.errormode import ErrorMode
 12from muutils.misc import list_join
 13from muutils.misc.sequence import WhenMissing
 14
 15from maze_dataset.constants import (
 16	CARDINAL_MAP,
 17	SPECIAL_TOKENS,
 18	VOCAB,
 19	ConnectionArray,
 20	ConnectionList,
 21	CoordTup,
 22)
 23
 24# filtering things from a prompt or generated text
 25# ==================================================
 26
 27
 28def remove_padding_from_token_str(token_str: str) -> str:
 29	"""remove padding tokens from a joined token string"""
 30	token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "")
 31	token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "")
 32	return token_str  # noqa: RET504
 33
 34
 35def tokens_between(
 36	tokens: list[str],
 37	start_value: str,
 38	end_value: str,
 39	include_start: bool = False,
 40	include_end: bool = False,
 41	except_when_tokens_not_unique: bool = False,
 42) -> list[str]:
 43	"""given a list `tokens`, get the tokens between `start_value` and `end_value`
 44
 45	_extended_summary_
 46
 47	# Parameters:
 48	- `tokens : list[str]`
 49	- `start_value : str`
 50	- `end_value : str`
 51	- `include_start : bool`
 52		(defaults to `False`)
 53	- `include_end : bool`
 54		(defaults to `False`)
 55	- `except_when_tokens_not_unique : bool`
 56		when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
 57		(defaults to `False`)
 58
 59	# Returns:
 60	- `list[str]`
 61
 62	# Raises:
 63	- `ValueError` : if `start_value` and `end_value` are the same
 64	- `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
 65	- `ValueError` : if `start_value` or `end_value` are not present in the input tokens
 66	"""
 67	if start_value == end_value:
 68		err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }"
 69		raise ValueError(
 70			err_msg,
 71		)
 72	if except_when_tokens_not_unique:
 73		if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1):
 74			err_msg: str = (
 75				"start_value or end_value is not unique in the input tokens:"
 76				f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
 77				f"\n{start_value = } {end_value = }"
 78				f"\n{tokens = }"
 79			)
 80			raise ValueError(err_msg)
 81	else:
 82		if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1):
 83			err_msg: str = (
 84				"start_value or end_value is not present in the input tokens:"
 85				f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
 86				f"\n{start_value = } {end_value = }"
 87				f"\n{tokens = }"
 88			)
 89			raise ValueError(err_msg)
 90
 91	start_idx: int = tokens.index(start_value) + int(not include_start)
 92	end_idx: int = tokens.index(end_value) + int(include_end)
 93
 94	assert start_idx < end_idx, "Start must come before end"
 95
 96	return tokens[start_idx:end_idx]
 97
 98
 99def get_adj_list_tokens(tokens: list[str]) -> list[str]:
100	"get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves"
101	return tokens_between(
102		tokens,
103		SPECIAL_TOKENS.ADJLIST_START,
104		SPECIAL_TOKENS.ADJLIST_END,
105	)
106
107
def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
	"""return the path region of `tokens`, anchored at the first PATH_START token

	- `trim_end=False`: everything from PATH_START (inclusive) to the end of `tokens`
	- `trim_end=True`: PATH_START itself is dropped, and the slice stops just before
	  the first PATH_END token when one is present (otherwise it runs to the end)

	Raises `ValueError` if PATH_START does not occur in `tokens`.
	"""
	path_start = SPECIAL_TOKENS.PATH_START
	if path_start not in tokens:
		err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}"
		raise ValueError(err_msg)

	first: int = tokens.index(path_start)
	if not trim_end:
		return tokens[first:]

	path_end = SPECIAL_TOKENS.PATH_END
	stop: int | None = tokens.index(path_end) if path_end in tokens else None
	return tokens[first + 1 : stop]
120
121
def get_context_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens from ADJLIST_START through PATH_START, keeping both delimiters"""
	return tokens_between(
		tokens,
		start_value=SPECIAL_TOKENS.ADJLIST_START,
		end_value=SPECIAL_TOKENS.PATH_START,
		include_end=True,
		include_start=True,
	)
131
132
def get_origin_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens strictly between ORIGIN_START and ORIGIN_END (delimiters excluded)"""
	origin_open = SPECIAL_TOKENS.ORIGIN_START
	origin_close = SPECIAL_TOKENS.ORIGIN_END
	return tokens_between(
		tokens,
		origin_open,
		origin_close,
		include_end=False,
		include_start=False,
	)
142
143
def get_target_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens strictly between TARGET_START and TARGET_END (delimiters excluded)"""
	target_open = SPECIAL_TOKENS.TARGET_START
	target_close = SPECIAL_TOKENS.TARGET_END
	return tokens_between(
		tokens,
		target_open,
		target_close,
		include_end=False,
		include_start=False,
	)
153
154
def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str:
	"""Look up the cardinal direction token for the single step from `coords[0]` to `coords[1]`.

	The step vector `coords[1] - coords[0]` is converted to a tuple and used as the
	key into `CARDINAL_MAP`; a step with no entry raises `KeyError`.
	"""
	origin, destination = coords[0], coords[1]
	step = tuple(destination - origin)
	return CARDINAL_MAP[step]
158
159
160def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str:
161	"""Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.
162
163	# Parameters
164	- `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
165		- `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
166		- `coords[1]`: The current location
167		- `coords[2]`: The next location. May be equal to the current location.
168	"""
169	if coords.shape != (3, 2):
170		err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead."
171		raise ValueError(err_msg)
172	directions = coords[1:] - coords[:-1]
173	if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])):
174		# Use floats as constant since `np.linalg.norm` returns float array
175		err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead."
176		raise ValueError(
177			err_msg,
178		)
179	if np.array_equal(coords[1], coords[2]):
180		return VOCAB.PATH_STAY
181	if np.array_equal(coords[0], coords[2]):
182		return VOCAB.PATH_BACKWARD
183	if np.array_equal(coords[0], coords[1]):
184		err_msg: str = f"Previous first-person direction indeterminate from {coords=}."
185		raise ValueError(
186			err_msg,
187		)
188	if np.array_equal(directions[0], directions[1]):
189		return VOCAB.PATH_FORWARD
190	directions = np.append(
191		directions,
192		[[0], [0]],
193		axis=1,
194	)  # Augment to represent unit basis vectors in 3D
195	match np.cross(directions[0], directions[1])[-1]:
196		case 1:
197			return VOCAB.PATH_LEFT
198		case -1:
199			return VOCAB.PATH_RIGHT
200
201
class TokenizerPendingDeprecationWarning(PendingDeprecationWarning):
	"""Pending deprecation warnings related to the `MazeTokenizerModular` upgrade.

	Used as the warning category by helpers in this module (such as `str_is_coord`)
	that only support legacy UT strings.
	"""
206
207
def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
	"""return True if `coord_str` looks like a legacy UT coordinate such as "(1,2)"

	The string must start with "(", end with ")", contain a comma, and every
	comma-separated part between the parentheses must consist only of digit
	characters. With `allow_whitespace`, surrounding whitespace is stripped at each
	stage first. Negative values are therefore rejected.
	Always emits a `TokenizerPendingDeprecationWarning`.
	"""
	warnings.warn(
		"`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
		TokenizerPendingDeprecationWarning,
	)

	def _strip(text: str) -> str:
		return text.strip() if allow_whitespace else text

	candidate: str = _strip(coord_str)
	has_shape: bool = candidate.startswith("(") and candidate.endswith(")") and "," in candidate
	if not has_shape:
		return False
	inner: str = _strip(candidate.lstrip("(").rstrip(")"))
	return all(_strip(part).isdigit() for part in inner.split(","))
229
230
class TokenizerDeprecationWarning(DeprecationWarning):
	"""Deprecation warnings related to the `MazeTokenizerModular` upgrade.

	A dedicated `DeprecationWarning` subclass, so these warnings can be filtered by
	category independently of other deprecation warnings.
	"""
235
236
237# coordinate to strings
238# ==================================================
239
240
241def _coord_to_strings_UT(coord: typing.Sequence[int]) -> list[str]:
242	"""convert a coordinate to a string: `(i,j)`->"(i,j)"
243
244	always returns a list of length 1
245	"""
246	return [f"({','.join(str(c) for c in coord)})"]
247
248
def _coord_to_strings_indexed(coord: typing.Sequence[int]) -> list[str]:
	"""render a coordinate as separate indexed tokens: `(i,j)` -> "(", "i", ",", "j", ")"

	for a 2D coordinate the result always has length 5
	"""
	components: list[str] = [str(value) for value in coord]
	# `list_join` interleaves a fresh "," between consecutive components
	separated: list[str] = list_join(components, lambda: ",")
	return ["(", *separated, ")"]
259
260
def coord_str_to_tuple(
	coord_str: str,
	allow_whitespace: bool = True,
) -> tuple[int, ...]:
	"""parse a coordinate string such as "(1,2)" into a tuple of ints

	leading "(" and trailing ")" characters are stripped, the remainder is split on
	",", and each part is converted with `int`. With `allow_whitespace`, whitespace
	is stripped at each stage first.
	"""

	def _strip(text: str) -> str:
		return text.strip() if allow_whitespace else text

	body: str = _strip(_strip(coord_str).lstrip("(").rstrip(")"))
	parts: list[str] = body.split(",")
	return tuple(int(_strip(part)) for part in parts)
270
271
def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray:
	"""parse a coordinate string (see `coord_str_to_tuple`) into a numpy array"""
	as_tuple = coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace)
	return np.array(as_tuple)
275
276
def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None:
	"""parse `coord_str` as a coordinate tuple, or return None when `str_is_coord` rejects it"""
	return coord_str_to_tuple(coord_str) if str_is_coord(coord_str) else None
282
283
def coords_string_split_UT(coords: str) -> list[str]:
	"""Splits a string of tokens into a list containing the UT tokens for each coordinate.

	Only unique tokens ("(1,2)") are produced, never indexed ones ("(", "1", ",", "2", ")").
	Any other whitespace-delimited chunk is kept as its own item, in order:
	"(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
	"""
	# a parenthesised group (up to the first ")") is one token; otherwise a run of non-whitespace
	pattern: str = r"\([^)]*\)|\S+"
	return [match.group(0) for match in re.finditer(pattern, coords)]
293
294
295# back and forth in wrapped form
296# ==================================================
@overload
def strings_to_coords(
	text: str | list[str],
	when_noncoord: Literal["skip"] = "skip",
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
	text: str | list[str],
	when_noncoord: Literal["error", "except"],
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
	text: str | list[str],
	when_noncoord: Literal["include"],
) -> list[str | CoordTup]: ...
def strings_to_coords(
	text: str | list[str],
	when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
	"""converts a list of tokens to a list of coordinates

	returns list[CoordTup] if `when_noncoord` is "skip" or "error"/"except";
	returns list[str | CoordTup] if `when_noncoord` is "include".

	# Parameters:
	- `text : str | list[str]`
		a single string, or a list of tokens which is joined with spaces; the result is
		split with `coords_string_split_UT`
	- `when_noncoord : WhenMissing`
		what to do with a token that is not a legacy UT coordinate string:
		- "skip": drop it (default)
		- "error" or "except": raise `ValueError` ("except" is how the `WhenMissing`
		  literal spells this mode; "error" is kept for existing callers)
		- "include": keep the raw token string in the output

	# Raises:
	- `ValueError` : on a non-coordinate token when `when_noncoord` is "error"/"except"
	- `ValueError` : on an unrecognized `when_noncoord` value, once a non-coordinate token is reached
	"""
	warnings.warn(
		"`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
		TokenizerPendingDeprecationWarning,
	)
	# accept both spellings of the raising mode
	mode = "error" if when_noncoord == "except" else when_noncoord

	tokens_joined: str = text if isinstance(text, str) else " ".join(text)
	result: list[str | CoordTup] = []
	for token in coords_string_split_UT(tokens_joined):
		coord: CoordTup | None = coord_str_to_tuple_noneable(token)
		if coord is not None:
			result.append(coord)
		elif mode == "skip":
			continue
		elif mode == "error":
			err_msg: str = f"Invalid non-coordinate token '{token}' in text: '{text}'"
			raise ValueError(err_msg)
		elif mode == "include":
			result.append(token)
		else:
			err_msg = f"Invalid when_noncoord value '{when_noncoord}'"
			raise ValueError(err_msg)
	return result
348
349
@overload
def coords_to_strings(
	coords: list[str | CoordTup],
	coord_to_strings_func: Callable[[CoordTup], list[str]],
	when_noncoord: Literal["include", "skip"] = "skip",
) -> list[str]: ...
@overload
def coords_to_strings(
	coords: list[CoordTup],
	coord_to_strings_func: Callable[[CoordTup], list[str]],
	when_noncoord: Literal["error", "except"],
) -> list[str]: ...
def coords_to_strings(
	coords: list[str | CoordTup],
	coord_to_strings_func: Callable[[CoordTup], list[str]],
	when_noncoord: WhenMissing = "skip",
) -> list[str]:
	"""converts a list of coordinates to a list of strings (tokens)

	expects list[CoordTup] if `when_noncoord` is "error"/"except";
	expects list[str | CoordTup] if `when_noncoord` is "include" or "skip".

	# Parameters:
	- `coords : list[str | CoordTup]`
		items to convert; every non-`str` item is passed to `coord_to_strings_func`
	- `coord_to_strings_func : Callable[[CoordTup], list[str]]`
		converts one coordinate into its token(s); the outputs are concatenated
	- `when_noncoord : WhenMissing`
		what to do with a `str` item:
		- "skip": drop it (default)
		- "error" or "except": raise `ValueError` ("except" is how the `WhenMissing`
		  literal spells this mode; "error" is kept for existing callers)
		- "include": copy it to the output unchanged

	# Raises:
	- `ValueError` : on a `str` item when `when_noncoord` is "error"/"except"
	- `ValueError` : on an unrecognized `when_noncoord` value, once a `str` item is reached
	"""
	# accept both spellings of the raising mode
	mode = "error" if when_noncoord == "except" else when_noncoord

	result: list[str] = []
	for coord in coords:
		if not isinstance(coord, str):
			result.extend(coord_to_strings_func(coord))
		elif mode == "skip":
			continue
		elif mode == "error":
			err_msg: str = f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'"
			raise ValueError(err_msg)
		elif mode == "include":
			result.append(coord)
		else:
			err_msg = f"Invalid when_noncoord value '{when_noncoord}'"
			raise ValueError(err_msg)
	return result
392
393
def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
	"""Partition `toks` into (adjacency-list tokens, everything else).

	The first element holds the tokens strictly between the first "<ADJLIST_START>"
	and the first "<ADJLIST_END>". The second element holds all remaining tokens,
	including both delimiters. A missing delimiter raises `ValueError` from `list.index`.
	"""
	body_begin: int = toks.index("<ADJLIST_START>") + 1
	body_end: int = toks.index("<ADJLIST_END>")
	inside: list[str] = toks[body_begin:body_end]
	outside: list[str] = [*toks[:body_begin], *toks[body_end:]]
	return inside, outside
403
404
def equal_except_adj_list_sequence(  # noqa: C901
	rollout1: list[str],
	rollout2: list[str],
	do_except: bool = False,
	when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT,
	when_len_mismatch: ErrorMode = ErrorMode.EXCEPT,
) -> bool:
	"""Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists.

	<ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts.
	Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
	This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

	# Parameters:
	- `rollout1 : list[str]`, `rollout2 : list[str]`
		the token sequences to compare
	- `do_except : bool`
		when `False`, any mismatch just returns `False`; when `True`, mismatches are
		also reported as described under Raises (defaults to `False`)
	- `when_counter_mismatch : ErrorMode`
		how differing adjacency-list token counts are reported when `do_except` is set
	- `when_len_mismatch : ErrorMode`
		how differing rollout lengths are reported when `do_except` is set

	# Returns:
	- `bool` : `True` if the rollouts have equal length and identical tokens outside the
	  adjacency list, and their adjacency lists contain the same tokens with the same multiplicities

	# Raises:
	- `ValueError` : if `do_except` and only one rollout contains `<ADJLIST_START>` or `<ADJLIST_END>`
	- `ValueError` : if `do_except` and the tokens outside the adjacency lists differ
	- `ValueError` : (from `list.index`) if neither rollout contains the adjacency-list delimiters
	- whatever `ErrorMode.process` does for the configured mode on a length or counter mismatch

	# Warning: CTT False Positives
	This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
	If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
	More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
	"""
	if len(rollout1) != len(rollout2):
		if do_except:
			when_len_mismatch.process(
				f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}",
			)
		return False
	if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2):
		if do_except:
			err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`"
			raise ValueError(err_msg)
		return False
	if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2):
		if do_except:
			err_msg = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`"
			raise ValueError(err_msg)
		return False

	adj_list1, non_adj_list1 = get_token_regions(rollout1)
	adj_list2, non_adj_list2 = get_token_regions(rollout2)
	if non_adj_list1 != non_adj_list2:
		if do_except:
			# tokens outside the adjacency list must match exactly; report the mismatch once,
			# as a hard error like the delimiter checks above, independent of `when_len_mismatch`
			err_msg = f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}"
			raise ValueError(err_msg)
		return False

	# adjacency lists may be ordered differently, so compare them as multisets
	counter1: Counter = Counter(adj_list1)
	counter2: Counter = Counter(adj_list2)
	if counter1 != counter2:
		if do_except:
			when_counter_mismatch.process(
				f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }",
			)
		return False

	return True
467
468
def connection_list_to_adj_list(
	conn_list: ConnectionList,
	shuffle_d0: bool = True,
	shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end=2 coord=2"]:
	"""converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

	Every truthy entry `conn_list[d, x, y]` becomes the pair `((x, y), (x + 1, y))`
	when `d == 0`, or `((x, y), (x, y + 1))` when `d == 1`. The pairs are written in
	`np.ndindex` order and optionally shuffled afterwards.

	# Parameters:
	- `conn_list: ConnectionList`
		special internal format for graphs which are subgraphs of a lattice
	- `shuffle_d0: bool`
		shuffle the adjacency list along the 0th axis (order of pairs)
	- `shuffle_d1: bool`
		shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
		If `False`, all pairs have the smaller coord first.

	# Returns:
	- `Int8[np.ndarray, "conn start_end=2 coord=2"]`
		adjacency list in the shape `(n_connections, 2, 2)`, where `n_connections = conn_list.sum()`
	"""
	n_connections: int = conn_list.sum()
	adj_list = np.full((n_connections, 2, 2), -1, dtype=np.int8)

	# one uniform draw per connection decides whether its pair is reversed;
	# all draws happen in a single call before any pair is built
	if shuffle_d1:
		flip_d1 = np.random.rand(n_connections)

	present = (index for index in np.ndindex(conn_list.shape) if conn_list[index])
	for i, (d, x, y) in enumerate(present):
		c_start = (x, y)
		c_end = (x + int(d == 0), y + int(d == 1))
		reverse = shuffle_d1 and flip_d1[i] > 0.5  # noqa: PLR2004
		pair = (c_end, c_start) if reverse else (c_start, c_end)
		adj_list[i] = np.array(pair, dtype=np.int8)

	if shuffle_d0:
		np.random.shuffle(adj_list)

	return adj_list
525
526
def is_connection(
	edges: ConnectionArray,
	connection_list: ConnectionList,
) -> Bool[np.ndarray, "is_connection=edges"]:
	"""Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`.

	The lookup is `connection_list[axis, lower_row, lower_col]`. `axis` is 1 when the
	edge's first coordinate does not change between its endpoints, and 0 otherwise.
	"""
	# np.sort along axis 1 orders each coordinate column across the two endpoints;
	# for lattice-adjacent edges this puts the lower endpoint first
	ordered = np.sort(edges, axis=1)
	lower = ordered[:, 0, :]
	upper = ordered[:, 1, :]
	axis = ((upper - lower)[:, 0] == 0).astype(np.int8)
	return connection_list[axis, lower[:, 0], lower[:, 1]]
537
538
539# string to coordinate representation
540# ==================================================

def remove_padding_from_token_str(token_str: str) -> str:
def remove_padding_from_token_str(token_str: str) -> str:
	"""remove padding tokens from a joined token string

	A padding token followed by a single space is removed together with that
	space first; any padding tokens still left afterwards are removed on their own.
	"""
	padding: str = f"{SPECIAL_TOKENS.PADDING}"
	without_spaced_padding: str = token_str.replace(f"{padding} ", "")
	return without_spaced_padding.replace(padding, "")

remove padding tokens from a joined token string

def tokens_between( tokens: list[str], start_value: str, end_value: str, include_start: bool = False, include_end: bool = False, except_when_tokens_not_unique: bool = False) -> list[str]:
def tokens_between(
	tokens: list[str],
	start_value: str,
	end_value: str,
	include_start: bool = False,
	include_end: bool = False,
	except_when_tokens_not_unique: bool = False,
) -> list[str]:
	"""given a list `tokens`, get the tokens between `start_value` and `end_value`

	Only the first occurrence of each delimiter is used. The delimiters themselves
	are left out unless `include_start` / `include_end` are set. When the two
	delimiters are adjacent and neither is included, the result is an empty list.

	# Parameters:
	- `tokens : list[str]`
	- `start_value : str`
	- `end_value : str`
	- `include_start : bool`
		(defaults to `False`)
	- `include_end : bool`
		(defaults to `False`)
	- `except_when_tokens_not_unique : bool`
		when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
		(defaults to `False`)

	# Returns:
	- `list[str]`

	# Raises:
	- `ValueError` : if `start_value` and `end_value` are the same
	- `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
	- `ValueError` : if `start_value` or `end_value` are not present in the input tokens
	- `ValueError` : if the first `end_value` comes before the first `start_value`
	"""
	if start_value == end_value:
		err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }"
		raise ValueError(err_msg)

	# count each delimiter once instead of rescanning `tokens` for every check
	n_start: int = tokens.count(start_value)
	n_end: int = tokens.count(end_value)

	if except_when_tokens_not_unique:
		if (n_start != 1) or (n_end != 1):
			err_msg = (
				"start_value or end_value is not unique in the input tokens:"
				f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
				f"\n{start_value = } {end_value = }"
				f"\n{tokens = }"
			)
			raise ValueError(err_msg)
	elif (n_start < 1) or (n_end < 1):
		err_msg = (
			"start_value or end_value is not present in the input tokens:"
			f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
			f"\n{start_value = } {end_value = }"
			f"\n{tokens = }"
		)
		raise ValueError(err_msg)

	start_pos: int = tokens.index(start_value)
	end_pos: int = tokens.index(end_value)

	# compare the delimiter positions themselves, not the adjusted slice bounds:
	# adjacent delimiters are correctly ordered and simply enclose an empty region.
	# an explicit raise (not `assert`) keeps the check active under `python -O`.
	if start_pos > end_pos:
		err_msg = f"Start must come before end: {start_value = } at index {start_pos}, {end_value = } at index {end_pos}"
		raise ValueError(err_msg)

	return tokens[start_pos + int(not include_start) : end_pos + int(include_end)]

given a list tokens, get the tokens between start_value and end_value

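Only the first occurrence of each delimiter is used; the delimiters themselves are excluded unless include_start / include_end are set.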

Parameters:

  • tokens : list[str]
  • start_value : str
  • end_value : str
  • include_start : bool (defaults to False)
  • include_end : bool (defaults to False)
  • except_when_tokens_not_unique : bool when True, raise an error if start_value or end_value are not unique in the input tokens (defaults to False)

Returns:

  • list[str]

Raises:

  • ValueError : if start_value and end_value are the same
  • ValueError : if except_when_tokens_not_unique is True and start_value or end_value are not unique in the input tokens
  • ValueError : if start_value or end_value are not present in the input tokens
def get_adj_list_tokens(tokens: list[str]) -> list[str]:
def get_adj_list_tokens(tokens: list[str]) -> list[str]:
	"""return the adjacency-list region of `tokens`, with the ADJLIST_START / ADJLIST_END delimiters dropped

	Errors raised by `tokens_between` (e.g. a missing delimiter) propagate unchanged.
	"""
	region_open = SPECIAL_TOKENS.ADJLIST_START
	region_close = SPECIAL_TOKENS.ADJLIST_END
	return tokens_between(tokens, region_open, region_close)

get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves

def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
	"""return the path region of `tokens`, anchored at the first PATH_START token

	- `trim_end=False`: everything from PATH_START (inclusive) to the end of `tokens`
	- `trim_end=True`: PATH_START itself is dropped, and the slice stops just before
	  the first PATH_END token when one is present (otherwise it runs to the end)

	Raises `ValueError` if PATH_START does not occur in `tokens`.
	"""
	path_start = SPECIAL_TOKENS.PATH_START
	if path_start not in tokens:
		err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}"
		raise ValueError(err_msg)

	first: int = tokens.index(path_start)
	if not trim_end:
		return tokens[first:]

	path_end = SPECIAL_TOKENS.PATH_END
	stop: int | None = tokens.index(path_end) if path_end in tokens else None
	return tokens[first + 1 : stop]

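Returns the path tokens. With trim_end=False, this is everything from the PATH_START token to the end of the sequence. With trim_end=True, PATH_START is dropped and the result stops just before the first PATH_END token, if one is present.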

def get_context_tokens(tokens: list[str]) -> list[str]:
def get_context_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens from ADJLIST_START through PATH_START, keeping both delimiters"""
	return tokens_between(
		tokens,
		start_value=SPECIAL_TOKENS.ADJLIST_START,
		end_value=SPECIAL_TOKENS.PATH_START,
		include_end=True,
		include_start=True,
	)

get tokens from ADJLIST_START through PATH_START, including both delimiters

def get_origin_tokens(tokens: list[str]) -> list[str]:
def get_origin_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens strictly between ORIGIN_START and ORIGIN_END (delimiters excluded)"""
	origin_open = SPECIAL_TOKENS.ORIGIN_START
	origin_close = SPECIAL_TOKENS.ORIGIN_END
	return tokens_between(
		tokens,
		origin_open,
		origin_close,
		include_end=False,
		include_start=False,
	)

get the tokens between ORIGIN_START and ORIGIN_END, excluding both delimiters

def get_target_tokens(tokens: list[str]) -> list[str]:
def get_target_tokens(tokens: list[str]) -> list[str]:
	"""return the tokens strictly between TARGET_START and TARGET_END (delimiters excluded)"""
	target_open = SPECIAL_TOKENS.TARGET_START
	target_close = SPECIAL_TOKENS.TARGET_END
	return tokens_between(
		tokens,
		target_open,
		target_close,
		include_end=False,
		include_start=False,
	)

get the tokens between TARGET_START and TARGET_END, excluding both delimiters

def get_cardinal_direction(coords: jaxtyping.Int[ndarray, 'start_end=2 row_col=2']) -> str:
def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str:
	"""Look up the cardinal direction token for the single step from `coords[0]` to `coords[1]`.

	The step vector `coords[1] - coords[0]` is converted to a tuple and used as the
	key into `CARDINAL_MAP`; a step with no entry raises `KeyError`.
	"""
	origin, destination = coords[0], coords[1]
	step = tuple(destination - origin)
	return CARDINAL_MAP[step]

Returns the cardinal direction token corresponding to traveling from coords[0] to coords[1].

def get_relative_direction(coords: jaxtyping.Int[ndarray, 'prev_cur_next=3 row_col=2']) -> str:
161def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str:
162	"""Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.
163
164	# Parameters
165	- `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
166		- `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
167		- `coords[1]`: The current location
168		- `coords[2]`: The next location. May be equal to the current location.
169	"""
170	if coords.shape != (3, 2):
171		err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead."
172		raise ValueError(err_msg)
173	directions = coords[1:] - coords[:-1]
174	if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])):
175		# Use floats as constant since `np.linalg.norm` returns float array
176		err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead."
177		raise ValueError(
178			err_msg,
179		)
180	if np.array_equal(coords[1], coords[2]):
181		return VOCAB.PATH_STAY
182	if np.array_equal(coords[0], coords[2]):
183		return VOCAB.PATH_BACKWARD
184	if np.array_equal(coords[0], coords[1]):
185		err_msg: str = f"Previous first-person direction indeterminate from {coords=}."
186		raise ValueError(
187			err_msg,
188		)
189	if np.array_equal(directions[0], directions[1]):
190		return VOCAB.PATH_FORWARD
191	directions = np.append(
192		directions,
193		[[0], [0]],
194		axis=1,
195	)  # Augment to represent unit basis vectors in 3D
196	match np.cross(directions[0], directions[1])[-1]:
197		case 1:
198			return VOCAB.PATH_LEFT
199		case -1:
200			return VOCAB.PATH_RIGHT

Returns the relative first-person direction token corresponding to traveling from coords[1] to coords[2].

Parameters

  • coords: Contains 3 Coords, each of which must neighbor the previous Coord.
    • coords[0]: The previous location, used to determine the current absolute direction that the "agent" is facing.
    • coords[1]: The current location
    • coords[2]: The next location. May be equal to the current location.
class TokenizerPendingDeprecationWarning(builtins.PendingDeprecationWarning):
class TokenizerPendingDeprecationWarning(PendingDeprecationWarning):
	"""Pending deprecation warnings related to the `MazeTokenizerModular` upgrade.

	Used as the warning category by helpers in this module (such as `str_is_coord`)
	that only support legacy UT strings.
	"""

Pending deprecation warnings related to the MazeTokenizerModular upgrade.

Inherited Members
builtins.PendingDeprecationWarning
PendingDeprecationWarning
builtins.BaseException
with_traceback
add_note
args
def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
	"""return True if `coord_str` looks like a legacy UT coordinate such as "(1,2)"

	The string must start with "(", end with ")", contain a comma, and every
	comma-separated part between the parentheses must consist only of digit
	characters. With `allow_whitespace`, surrounding whitespace is stripped at each
	stage first. Negative values are therefore rejected.
	Always emits a `TokenizerPendingDeprecationWarning`.
	"""
	warnings.warn(
		"`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
		TokenizerPendingDeprecationWarning,
	)

	def _strip(text: str) -> str:
		return text.strip() if allow_whitespace else text

	candidate: str = _strip(coord_str)
	has_shape: bool = candidate.startswith("(") and candidate.endswith(")") and "," in candidate
	if not has_shape:
		return False
	inner: str = _strip(candidate.lstrip("(").rstrip(")"))
	return all(_strip(part).isdigit() for part in inner.split(","))

return True if the string represents a coordinate, False otherwise

class TokenizerDeprecationWarning(builtins.DeprecationWarning):
class TokenizerDeprecationWarning(DeprecationWarning):
	"""Deprecation warnings related to the `MazeTokenizerModular` upgrade.

	A dedicated `DeprecationWarning` subclass, so these warnings can be filtered by
	category independently of other deprecation warnings.
	"""

Deprecation warnings related to the MazeTokenizerModular upgrade.

Inherited Members
builtins.DeprecationWarning
DeprecationWarning
builtins.BaseException
with_traceback
add_note
args
def coord_str_to_tuple(coord_str: str, allow_whitespace: bool = True) -> tuple[int, ...]:
def coord_str_to_tuple(
	coord_str: str,
	allow_whitespace: bool = True,
) -> tuple[int, ...]:
	"""parse a coordinate string such as "(1,2)" into a tuple of ints

	leading "(" and trailing ")" characters are stripped, the remainder is split on
	",", and each part is converted with `int`. With `allow_whitespace`, whitespace
	is stripped at each stage first.
	"""

	def _strip(text: str) -> str:
		return text.strip() if allow_whitespace else text

	body: str = _strip(_strip(coord_str).lstrip("(").rstrip(")"))
	parts: list[str] = body.split(",")
	return tuple(int(_strip(part)) for part in parts)

convert a coordinate string to a tuple

def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> numpy.ndarray:
def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray:
	"""parse a coordinate string (see `coord_str_to_tuple`) into a numpy array"""
	as_tuple = coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace)
	return np.array(as_tuple)

convert a coordinate string to a numpy array

def coord_str_to_tuple_noneable(coord_str: str) -> tuple[int, int] | None:
def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None:
	"""parse `coord_str` as a coordinate tuple, or return None when `str_is_coord` rejects it"""
	return coord_str_to_tuple(coord_str) if str_is_coord(coord_str) else None

convert a coordinate string to a tuple, or None if the string is not a coordinate string

def coords_string_split_UT(coords: str) -> list[str]:
def coords_string_split_UT(coords: str) -> list[str]:
	"""Splits a string of tokens into a list containing the UT tokens for each coordinate.

	Only unique tokens ("(1,2)") are produced, never indexed ones ("(", "1", ",", "2", ")").
	Any other whitespace-delimited chunk is kept as its own item, in order:
	"(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
	"""
	# a parenthesised group (up to the first ")") is one token; otherwise a run of non-whitespace
	pattern: str = r"\([^)]*\)|\S+"
	return [match.group(0) for match in re.finditer(pattern, coords)]

Splits a string of tokens into a list containing the UT tokens for each coordinate.

Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)"). Non-whitespace portions of the input string not matched are preserved in the same list: "(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]

def strings_to_coords( text: str | list[str], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str | tuple[int, int]]:
def strings_to_coords(
	text: str | list[str],
	when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
	"""converts a list of tokens to a list of coordinates

	only legacy UT (unique-token) coordinate strings such as `"(1,2)"` are recognized,
	and a `TokenizerPendingDeprecationWarning` is emitted on every call.

	# Parameters:
	- `text : str | list[str]`
		either one whitespace-joined string of tokens, or a list of tokens
		(a list is joined with single spaces before being split)
	- `when_noncoord : WhenMissing`
		what to do with a token that is not a coordinate:
		- `"skip"` : drop it (default)
		- `"except"` (or the legacy alias `"error"`) : raise a `ValueError`
		- `"include"` : keep the raw token string in the output

	# Returns:
	- `list[CoordTup]` if `when_noncoord` is `"skip"`, `"except"`, or `"error"`
	- `list[str | CoordTup]` if `when_noncoord` is `"include"`

	# Raises:
	- `ValueError` : if `when_noncoord` is not one of the values listed above
	- `ValueError` : if a non-coordinate token is found and `when_noncoord` is `"except"`/`"error"`
	"""
	warnings.warn(
		"`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
		TokenizerPendingDeprecationWarning,
		stacklevel=2,
	)
	# `WhenMissing` spells the raising mode "except"; "error" is still accepted for backward compatibility.
	# validate up front so a bad value is reported even when every token happens to be a coordinate
	if when_noncoord not in ("skip", "except", "error", "include"):
		err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
		raise ValueError(err_msg)

	tokens_joined: str = text if isinstance(text, str) else " ".join(text)
	tokens_processed: list[str] = coords_string_split_UT(tokens_joined)
	result: list[str | CoordTup] = []
	for token in tokens_processed:
		coord: CoordTup | None = coord_str_to_tuple_noneable(token)
		if coord is not None:
			result.append(coord)
		elif when_noncoord == "include":
			result.append(token)
		elif when_noncoord in ("except", "error"):
			err_msg = f"Invalid non-coordinate token '{token}' in text: '{text}'"
			raise ValueError(err_msg)
		# remaining case is "skip": the non-coordinate token is dropped
	return result

converts a list of tokens to a list of coordinates

Returns list[CoordTup] if when_noncoord is "skip" or "error"; returns list[str | CoordTup] if when_noncoord is "include".

def coords_to_strings( coords: list[str | tuple[int, int]], coord_to_strings_func: Callable[[tuple[int, int]], list[str]], when_noncoord: Literal['except', 'skip', 'include'] = 'skip') -> list[str]:
def coords_to_strings(
	coords: list[str | CoordTup],
	coord_to_strings_func: Callable[[CoordTup], list[str]],
	when_noncoord: WhenMissing = "skip",
) -> list[str]:
	"""converts a list of coordinates to a list of strings (tokens)

	# Parameters:
	- `coords : list[str | CoordTup]`
		coordinates to convert; any `str` element is treated as a non-coordinate
	- `coord_to_strings_func : Callable[[CoordTup], list[str]]`
		maps a single coordinate to its token(s); the returned tokens are appended in order
	- `when_noncoord : WhenMissing`
		what to do with a `str` element:
		- `"skip"` : drop it (default)
		- `"except"` (or the legacy alias `"error"`) : raise a `ValueError`
		- `"include"` : copy the string into the output unchanged

	# Returns:
	- `list[str]` : the concatenated tokens of every element that is kept

	# Raises:
	- `ValueError` : if `when_noncoord` is not one of the values listed above
	- `ValueError` : if a `str` element is found and `when_noncoord` is `"except"`/`"error"`
	"""
	# `WhenMissing` spells the raising mode "except"; "error" is still accepted for backward compatibility.
	# validate up front so a bad value is reported even when `coords` contains no strings
	if when_noncoord not in ("skip", "except", "error", "include"):
		err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
		raise ValueError(err_msg)

	result: list[str] = []
	for coord in coords:
		if not isinstance(coord, str):
			result.extend(coord_to_strings_func(coord))
		elif when_noncoord == "include":
			result.append(coord)
		elif when_noncoord in ("except", "error"):
			err_msg = f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'"
			raise ValueError(err_msg)
		# remaining case is "skip": the string element is dropped
	return result

converts a list of coordinates to a list of strings (tokens)

Expects list[CoordTup] if when_noncoord is "error"; expects list[str | CoordTup] if when_noncoord is "include" or "skip".

def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
	"""Splits a list of tokens into adjacency list tokens and non-adjacency list tokens.

	The adjacency-list region is everything strictly between the first `<ADJLIST_START>`
	and the first `<ADJLIST_END>`; both delimiter tokens end up in the non-adjacency part.

	# Returns:
	- `tuple[list[str], list[str]]` : `(adj_list, non_adj_list)`

	# Raises:
	- `ValueError` : raised by `list.index` if either delimiter token is missing
	"""
	first_inside: int = toks.index("<ADJLIST_START>") + 1
	end_marker: int = toks.index("<ADJLIST_END>")
	inside: list[str] = toks[first_inside:end_marker]
	outside: list[str] = [*toks[:first_inside], *toks[end_marker:]]
	return inside, outside

Splits a list of tokens into adjacency list tokens and non-adjacency list tokens.

def equal_except_adj_list_sequence( rollout1: list[str], rollout2: list[str], do_except: bool = False, when_counter_mismatch: muutils.errormode.ErrorMode = ErrorMode.Except, when_len_mismatch: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def equal_except_adj_list_sequence(  # noqa: C901
	rollout1: list[str],
	rollout2: list[str],
	do_except: bool = False,
	when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT,
	when_len_mismatch: ErrorMode = ErrorMode.EXCEPT,
) -> bool:
	"""Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists.

	<ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts.
	Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
	This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

	# Parameters:
	- `rollout1 : list[str]`
	- `rollout2 : list[str]`
	- `do_except : bool`
		when `False`, every mismatch simply returns `False`; when `True`, mismatches are
		additionally reported as described below
		(defaults to `False`)
	- `when_counter_mismatch : ErrorMode`
		how to report adjacency lists whose token multisets differ (only used when `do_except`)
		(defaults to `ErrorMode.EXCEPT`)
	- `when_len_mismatch : ErrorMode`
		how to report rollouts of different length AND differing non-adjacency-list tokens
		(only used when `do_except`)
		(defaults to `ErrorMode.EXCEPT`)

	# Returns:
	- `bool` : `True` if the rollouts are equal up to the ordering of adjacency-list tokens

	# Raises:
	- `ValueError` : if `do_except` and exactly one rollout contains `<ADJLIST_START>` (or `<ADJLIST_END>`)
	- `ValueError` : from `get_token_regions`, regardless of `do_except`, if neither rollout contains
		one of the delimiter tokens
	- otherwise, mismatches are handed to the corresponding `ErrorMode.process`, which decides whether
		to raise, warn, or ignore (see `muutils.errormode`)

	# Warning: CTT False Positives
	This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
	If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
	More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
	"""
	if len(rollout1) != len(rollout2):
		if do_except:
			when_len_mismatch.process(
				f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}",
			)
		return False
	# a delimiter present in only one of the rollouts means the structures cannot match
	if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2):
		if do_except:
			err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`"
			raise ValueError(
				err_msg,
			)
		return False
	if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2):
		if do_except:
			err_msg: str = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`"
			raise ValueError(
				err_msg,
			)
		return False

	adj_list1, non_adj_list1 = get_token_regions(rollout1)
	adj_list2, non_adj_list2 = get_token_regions(rollout2)
	if non_adj_list1 != non_adj_list2:
		if do_except:
			# there is no dedicated ErrorMode for this case, so `when_len_mismatch` is reused.
			# The mode alone decides whether to raise; an unconditional raise here would
			# override a WARN/IGNORE choice by the caller.
			when_len_mismatch.process(
				f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}",
			)
		return False
	# compare adjacency lists as multisets, so only their token order is allowed to differ
	counter1: Counter = Counter(adj_list1)
	counter2: Counter = Counter(adj_list2)
	counters_eq: bool = counter1 == counter2
	if not counters_eq:
		if do_except:
			when_counter_mismatch.process(
				f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }",
			)
		return False

	return True

Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists.

<ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts. Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze. This function should NOT be used to determine if two rollouts encode the same LatticeMaze object.

Warning: CTT False Positives

This function is not robustly correct for some corner cases using CoordTokenizers.CTT. If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible. More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.

def connection_list_to_adj_list( conn_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col'], shuffle_d0: bool = True, shuffle_d1: bool = True) -> jaxtyping.Int8[ndarray, 'conn start_end=2 coord=2']:
def connection_list_to_adj_list(
	conn_list: ConnectionList,
	shuffle_d0: bool = True,
	shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end=2 coord=2"]:
	"""converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

	# Parameters:
	- `conn_list: ConnectionList`
		special internal format for graphs which are subgraphs of a lattice.
		a truthy `conn_list[d, x, y]` is turned into the pair `(x, y)` -> `(x + 1, y)` when `d == 0`
		and `(x, y)` -> `(x, y + 1)` when `d == 1`
	- `shuffle_d0: bool`
		shuffle the adjacency list along the 0th axis (order of pairs)
	- `shuffle_d1: bool`
		shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
		If `False`, all pairs have the smaller coord first.

	# Returns:
	- `Int8[np.ndarray, "conn start_end=2 coord=2"]`
		adjacency list in the shape `(n_connections, 2, 2)`. Before any `shuffle_d0` shuffle,
		pairs appear in row-major `(d, x, y)` order of `conn_list`.

	# Raises:
	- `OverflowError` : if any coordinate does not fit in `int8`

	randomness comes from the global `np.random` state: `n_connections` uniform draws when
	`shuffle_d1`, followed by one `np.random.shuffle` call when `shuffle_d0`.
	"""
	# indices of every connection, in the same row-major (d, x, y) order `np.ndindex` visits them
	d_idx, x_idx, y_idx = np.nonzero(conn_list)
	n_connections: int = len(d_idx)

	# start coord is (x, y); the end coord steps +1 along the axis selected by `d`
	c_start: Int[np.ndarray, "conn coord=2"] = np.stack([x_idx, y_idx], axis=-1)
	c_end: Int[np.ndarray, "conn coord=2"] = c_start + np.stack(
		[d_idx == 0, d_idx == 1],
		axis=-1,
	)

	# `astype(np.int8)` would silently wrap out-of-range values, so check explicitly.
	# coords are non-negative and c_end >= c_start elementwise, so c_end.max() bounds everything
	int8_max: int = int(np.iinfo(np.int8).max)
	if n_connections > 0 and int(c_end.max()) > int8_max:
		err_msg: str = f"coordinate {int(c_end.max())} does not fit in int8 (max {int8_max})"
		raise OverflowError(err_msg)

	adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.stack(
		[c_start, c_end],
		axis=1,
	).astype(np.int8)

	if shuffle_d1:
		# one uniform draw per connection, consumed in connection order as in a per-pair loop
		# magic value is fine here
		flip_d1: Bool[np.ndarray, " conn"] = np.random.rand(n_connections) > 0.5  # noqa: PLR2004
		# advanced indexing on the right-hand side copies, so swapping in place is safe
		adj_list[flip_d1] = adj_list[flip_d1, ::-1]

	if shuffle_d0:
		np.random.shuffle(adj_list)

	return adj_list

converts a ConnectionList (special lattice format) to a shuffled adjacency list

Parameters:

  • conn_list: ConnectionList special internal format for graphs which are subgraphs of a lattice
  • shuffle_d0: bool shuffle the adjacency list along the 0th axis (order of pairs)
  • shuffle_d1: bool shuffle the adjacency list along the 1st axis (order of coordinates in each pair). If False, all pairs have the smaller coord first.

Returns:

  • Int8[np.ndarray, "conn start_end=2 coord=2"] adjacency list in the shape (n_connections, 2, 2)
def is_connection( edges: jaxtyping.Int8[ndarray, 'edges leading_trailing_coord=2 row_col=2'], connection_list: jaxtyping.Bool[ndarray, 'lattice_dim=2 row col']) -> jaxtyping.Bool[ndarray, 'is_connection=edges']:
def is_connection(
	edges: ConnectionArray,
	connection_list: ConnectionList,
) -> Bool[np.ndarray, "is_connection=edges"]:
	"""Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`.

	NOTE(review): assumes every edge joins two lattice-adjacent coordinates (they differ by
	exactly 1 in exactly one of the two components) -- TODO confirm callers guarantee this;
	for other edges the looked-up value is not meaningful.
	"""
	# sort each component independently along the endpoint axis; for adjacent cells only one
	# component differs, so this simply puts the smaller endpoint first
	ordered = np.sort(edges, axis=1)
	lower = ordered[:, 0, :]
	upper = ordered[:, 1, :]
	# equal first components -> the edge runs along the second axis (lattice_dim 1), else lattice_dim 0
	lattice_dim = (upper[:, 0] == lower[:, 0]).astype(np.int8)
	return connection_list[lattice_dim, lower[:, 0], lower[:, 1]]

Returns if each edge in edges is a connection (True) or wall (False) in connection_list.