Coverage for maze_dataset/token_utils.py: 64%
202 statements
1"""a whole bunch of utilities for tokenization"""
3import re
4import typing
5import warnings
6from collections import Counter
7from typing import Callable, Literal, overload
9import numpy as np
10from jaxtyping import Bool, Float, Int, Int8
11from muutils.errormode import ErrorMode
12from muutils.misc import list_join
13from muutils.misc.sequence import WhenMissing
15from maze_dataset.constants import (
16 CARDINAL_MAP,
17 SPECIAL_TOKENS,
18 VOCAB,
19 ConnectionArray,
20 ConnectionList,
21 CoordTup,
22)
24# filtering things from a prompt or generated text
25# ==================================================


def remove_padding_from_token_str(token_str: str) -> str:
    """remove padding tokens from a joined token string"""
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "")
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "")
    return token_str  # noqa: RET504
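

# Illustrative sketch (not from the source): assuming `SPECIAL_TOKENS.PADDING`
# is the literal "<PADDING>", padding is stripped whether or not it is followed
# by a space:
#
#   >>> remove_padding_from_token_str("<PADDING> <PADDING> (0,0) (0,1)")
#   '(0,0) (0,1)'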


def tokens_between(
    tokens: list[str],
    start_value: str,
    end_value: str,
    include_start: bool = False,
    include_end: bool = False,
    except_when_tokens_not_unique: bool = False,
) -> list[str]:
    """given a list `tokens`, get the tokens between `start_value` and `end_value`

    by default, returns the tokens strictly between the first occurrence of
    `start_value` and the first occurrence of `end_value`; the endpoints can be
    included via `include_start` and `include_end`

    # Parameters:
    - `tokens : list[str]`
    - `start_value : str`
    - `end_value : str`
    - `include_start : bool`
        (defaults to `False`)
    - `include_end : bool`
        (defaults to `False`)
    - `except_when_tokens_not_unique : bool`
        when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
        (defaults to `False`)

    # Returns:
    - `list[str]`

    # Raises:
    - `ValueError` : if `start_value` and `end_value` are the same
    - `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
    - `ValueError` : if `start_value` or `end_value` are not present in the input tokens
    """
    if start_value == end_value:
        err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }"
        raise ValueError(
            err_msg,
        )
    if except_when_tokens_not_unique:
        if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1):
            err_msg: str = (
                "start_value or end_value is not unique in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)
    else:
        if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1):
            err_msg: str = (
                "start_value or end_value is not present in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)

    start_idx: int = tokens.index(start_value) + int(not include_start)
    end_idx: int = tokens.index(end_value) + int(include_end)

    assert start_idx < end_idx, "Start must come before end"

    return tokens[start_idx:end_idx]
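

# Illustrative sketch (not from the source; token values are made up):
#
#   >>> toks = ["<A>", "x", "y", "<B>"]
#   >>> tokens_between(toks, "<A>", "<B>")
#   ['x', 'y']
#   >>> tokens_between(toks, "<A>", "<B>", include_start=True, include_end=True)
#   ['<A>', 'x', 'y', '<B>']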


def get_adj_list_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.ADJLIST_END,
    )


def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
    """get the path tokens from a list of tokens

    if `trim_end` is `False`, returns everything from `PATH_START` (inclusive) to the end of `tokens`.
    if `trim_end` is `True`, the `PATH_START` token is dropped and the output stops before `PATH_END`, if it exists.
    """
    if SPECIAL_TOKENS.PATH_START not in tokens:
        err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}"
        raise ValueError(
            err_msg,
        )
    start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + int(trim_end)
    end_idx: int | None = None
    if trim_end and (SPECIAL_TOKENS.PATH_END in tokens):
        end_idx = tokens.index(SPECIAL_TOKENS.PATH_END)
    return tokens[start_idx:end_idx]
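

# Illustrative sketch (not from the source; assumes the usual "<PATH_START>"
# and "<PATH_END>" literals for the special tokens):
#
#   >>> toks = ["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]
#   >>> get_path_tokens(toks)
#   ['<PATH_START>', '(0,0)', '(0,1)', '<PATH_END>']
#   >>> get_path_tokens(toks, trim_end=True)
#   ['(0,0)', '(0,1)']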


def get_context_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and PATH_START, including the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.PATH_START,
        include_start=True,
        include_end=True,
    )


def get_origin_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ORIGIN_START and ORIGIN_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ORIGIN_START,
        SPECIAL_TOKENS.ORIGIN_END,
        include_start=False,
        include_end=False,
    )


def get_target_tokens(tokens: list[str]) -> list[str]:
    "get tokens between TARGET_START and TARGET_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.TARGET_START,
        SPECIAL_TOKENS.TARGET_END,
        include_start=False,
        include_end=False,
    )


def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str:
    """Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`."""
    return CARDINAL_MAP[tuple(coords[1] - coords[0])]
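

# Illustrative sketch (not from the source): `CARDINAL_MAP` is keyed by the
# `(row, col)` delta `coords[1] - coords[0]`; assuming the usual convention
# that rows grow downward, a delta of `(-1, 0)` maps to the "north" token:
#
#   >>> get_cardinal_direction(np.array([[1, 0], [0, 0]]))  # delta is (-1, 0)
#   ... # CARDINAL_MAP[(-1, 0)], i.e. the "north" token under this assumption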


def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str:
    """Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.

    # Parameters
    - `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
        - `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
        - `coords[1]`: The current location
        - `coords[2]`: The next location. May be equal to the current location.
    """
    if coords.shape != (3, 2):
        err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead."
        raise ValueError(err_msg)
    directions = coords[1:] - coords[:-1]
    if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])):
        # Use floats as constant since `np.linalg.norm` returns float array
        err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(coords[1], coords[2]):
        return VOCAB.PATH_STAY
    if np.array_equal(coords[0], coords[2]):
        return VOCAB.PATH_BACKWARD
    if np.array_equal(coords[0], coords[1]):
        err_msg: str = f"Previous first-person direction indeterminate from {coords=}."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(directions[0], directions[1]):
        return VOCAB.PATH_FORWARD
    directions = np.append(
        directions,
        [[0], [0]],
        axis=1,
    )  # Augment to represent unit basis vectors in 3D
    match np.cross(directions[0], directions[1])[-1]:
        case 1:
            return VOCAB.PATH_LEFT
        case -1:
            return VOCAB.PATH_RIGHT
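

# The left/right decision above embeds the two 2D step vectors in 3D and reads
# the z-component of their cross product: for lattice steps it is +1 for a left
# turn and -1 for a right turn (collinear cases were already handled earlier).
# Illustrative sketch (not from the source):
#
#   >>> get_relative_direction(np.array([[0, 0], [0, 1], [1, 1]]))
#   ... # VOCAB.PATH_RIGHT: the cross product z-component is -1
#   >>> get_relative_direction(np.array([[0, 0], [0, 1], [0, 2]]))
#   ... # VOCAB.PATH_FORWARD: both steps are equal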


class TokenizerPendingDeprecationWarning(PendingDeprecationWarning):
    """Pending deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass


def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
    """return True if the string represents a coordinate, False otherwise"""
    warnings.warn(
        "`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731

    coord_str = strip_func(coord_str)

    return all(
        [
            coord_str.startswith("("),
            coord_str.endswith(")"),
            "," in coord_str,
            all(
                strip_func(x).isdigit()
                for x in strip_func(coord_str.lstrip("(").rstrip(")")).split(",")
            ),
        ],
    )
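

# Illustrative sketch (not from the source):
#
#   >>> str_is_coord("(1,2)")
#   True
#   >>> str_is_coord("( 1 , 2 )")
#   True
#   >>> str_is_coord("( 1 , 2 )", allow_whitespace=False)
#   False
#   >>> str_is_coord("1,2")
#   False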


class TokenizerDeprecationWarning(DeprecationWarning):
    """Deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass


# coordinate to strings
# ==================================================


def _coord_to_strings_UT(coord: typing.Sequence[int]) -> list[str]:
    """convert a coordinate to a string: `(i,j)`->"(i,j)"

    always returns a list of length 1
    """
    return [f"({','.join(str(c) for c in coord)})"]


def _coord_to_strings_indexed(coord: typing.Sequence[int]) -> list[str]:
    """convert a coordinate to a list of indexed strings: `(i,j)`->"(", "i", ",", "j", ")"

    always returns a list of length 5
    """
    return [
        "(",
        *list_join([str(c) for c in coord], lambda: ","),
        ")",
    ]
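

# Illustrative sketch (not from the source): the two formats differ only in
# whether the coordinate is one "unique token" (UT) or several indexed tokens:
#
#   >>> _coord_to_strings_UT((1, 2))
#   ['(1,2)']
#   >>> _coord_to_strings_indexed((1, 2))
#   ['(', '1', ',', '2', ')']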


def coord_str_to_tuple(
    coord_str: str,
    allow_whitespace: bool = True,
) -> tuple[int, ...]:
    """convert a coordinate string to a tuple"""
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731
    coord_str = strip_func(coord_str)
    stripped: str = strip_func(coord_str.lstrip("(").rstrip(")"))
    return tuple(int(strip_func(x)) for x in stripped.split(","))


def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray:
    """convert a coordinate string to a numpy array"""
    return np.array(coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace))


def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None:
    """convert a coordinate string to a tuple, or None if the string is not a coordinate string"""
    if not str_is_coord(coord_str):
        return None
    return coord_str_to_tuple(coord_str)
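

# Illustrative sketch (not from the source): parsing round-trips with
# `_coord_to_strings_UT`, and the noneable variant absorbs non-coordinates:
#
#   >>> coord_str_to_tuple("(1, 2)")
#   (1, 2)
#   >>> coord_str_to_tuple_noneable("<PATH_START>") is None
#   True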


def coords_string_split_UT(coords: str) -> list[str]:
    """Splits a string of tokens into a list containing the UT tokens for each coordinate.

    Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)").
    Non-whitespace portions of the input string not matched are preserved in the same list:
    "(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
    """
    # ty gpt4
    return re.findall(r"\([^)]*\)|\S+", coords)


# back and forth in wrapped form
# ==================================================
@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["skip"] = "skip",
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["error"] = "error",
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["include"] = "include",
) -> list[str | CoordTup]: ...
def strings_to_coords(
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    """converts a list of tokens to a list of coordinates

    returns `list[CoordTup]` if `when_noncoord` is "skip" or "error"
    returns `list[str | CoordTup]` if `when_noncoord` is "include"
    """
    warnings.warn(
        "`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    tokens_joined: str = text if isinstance(text, str) else " ".join(text)
    tokens_processed: list[str] = coords_string_split_UT(tokens_joined)
    result: list[str | CoordTup] = list()
    for token in tokens_processed:
        coord: CoordTup | None = coord_str_to_tuple_noneable(token)
        if coord is None:
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate token '{token}' in text: '{text}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(token)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.append(coord)
    return result
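

# Illustrative sketch (not from the source):
#
#   >>> strings_to_coords("(1,2) <PATH_START> (5,6)")
#   [(1, 2), (5, 6)]
#   >>> strings_to_coords("(1,2) <PATH_START> (5,6)", when_noncoord="include")
#   [(1, 2), '<PATH_START>', (5, 6)]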


@overload
def coords_to_strings(
    coords: list[str | CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: Literal["include", "skip"] = "skip",
) -> list[str]: ...
@overload
def coords_to_strings(
    coords: list[CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: Literal["error"] = "error",
) -> list[str]: ...
def coords_to_strings(
    coords: list[str | CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: WhenMissing = "skip",
) -> list[str]:
    """converts a list of coordinates to a list of strings (tokens)

    expects `list[CoordTup]` if `when_noncoord` is "error"
    expects `list[str | CoordTup]` if `when_noncoord` is "include" or "skip"
    """
    result: list[str] = list()
    for coord in coords:
        if isinstance(coord, str):
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(coord)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.extend(coord_to_strings_func(coord))
    return result
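

# Illustrative sketch (not from the source): the inverse of `strings_to_coords`,
# parameterized by the target coordinate format:
#
#   >>> coords_to_strings(
#   ...     [(1, 2), "<PATH_START>", (5, 6)],
#   ...     _coord_to_strings_UT,
#   ...     when_noncoord="include",
#   ... )
#   ['(1,2)', '<PATH_START>', '(5,6)']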


def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
    """Splits a list of tokens into adjacency list tokens and non-adjacency list tokens."""
    adj_list_start, adj_list_end = (
        toks.index("<ADJLIST_START>") + 1,
        toks.index("<ADJLIST_END>"),
    )
    adj_list = toks[adj_list_start:adj_list_end]
    non_adj_list = toks[:adj_list_start] + toks[adj_list_end:]
    return adj_list, non_adj_list
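

# Illustrative sketch (not from the source): the delimiter tokens themselves
# stay in the non-adjacency-list region:
#
#   >>> get_token_regions(["<ADJLIST_START>", "a", "b", "<ADJLIST_END>", "c"])
#   (['a', 'b'], ['<ADJLIST_START>', '<ADJLIST_END>', 'c'])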


def equal_except_adj_list_sequence(  # noqa: C901
    rollout1: list[str],
    rollout2: list[str],
    do_except: bool = False,
    when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT,
    when_len_mismatch: ErrorMode = ErrorMode.EXCEPT,
) -> bool:
    """Returns whether the rollout strings are equal, allowing for differently sequenced adjacency lists.

    <ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts.
    Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
    This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

    # Warning: CTT False Positives
    This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
    If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
    More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
    """
    if len(rollout1) != len(rollout2):
        if do_except:
            when_len_mismatch.process(
                f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}",
            )
        return False
    if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False
    if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False

    adj_list1, non_adj_list1 = get_token_regions(rollout1)
    adj_list2, non_adj_list2 = get_token_regions(rollout2)
    if non_adj_list1 != non_adj_list2:
        if do_except:
            when_len_mismatch.process(
                f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}",
            )
        return False
    counter1: Counter = Counter(adj_list1)
    counter2: Counter = Counter(adj_list2)
    counters_eq: bool = counter1 == counter2
    if not counters_eq:
        if do_except:
            when_counter_mismatch.process(
                f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }",
            )
        return False

    return True
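

# Illustrative sketch (not from the source; uses made-up single-token entries
# in place of real connection triples): reordering tokens inside the adjacency
# list region does not affect equality, but any other change does:
#
#   >>> r1 = ["<ADJLIST_START>", "a", "b", "<ADJLIST_END>", "c"]
#   >>> r2 = ["<ADJLIST_START>", "b", "a", "<ADJLIST_END>", "c"]
#   >>> equal_except_adj_list_sequence(r1, r2)
#   True
#   >>> equal_except_adj_list_sequence(r1, ["<ADJLIST_START>", "a", "a", "<ADJLIST_END>", "c"])
#   False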


def connection_list_to_adj_list(
    conn_list: ConnectionList,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end=2 coord=2"]:
    """converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

    # Parameters:
    - `conn_list: ConnectionList`
        special internal format for graphs which are subgraphs of a lattice
    - `shuffle_d0: bool`
        shuffle the adjacency list along the 0th axis (order of pairs)
    - `shuffle_d1: bool`
        shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
        If `False`, all pairs have the smaller coord first.

    # Returns:
    - `Int8[np.ndarray, "conn start_end=2 coord=2"]`
        adjacency list in the shape `(n_connections, 2, 2)`
    """
    n_connections: int = conn_list.sum()
    adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.full(
        (n_connections, 2, 2),
        -1,
        dtype=np.int8,
    )

    if shuffle_d1:
        flip_d1: Float[np.ndarray, " conn"] = np.random.rand(n_connections)

    # loop over all nonzero elements of the connection list
    i: int = 0
    for d, x, y in np.ndindex(conn_list.shape):
        if conn_list[d, x, y]:
            c_start: CoordTup = (x, y)
            c_end: CoordTup = (
                x + (1 if d == 0 else 0),
                y + (1 if d == 1 else 0),
            )
            adj_list[i, 0] = np.array(c_start, dtype=np.int8)
            adj_list[i, 1] = np.array(c_end, dtype=np.int8)

            # flip if shuffling
            # magic value is fine here
            if shuffle_d1 and (flip_d1[i] > 0.5):  # noqa: PLR2004
                c_s, c_e = adj_list[i, 0].copy(), adj_list[i, 1].copy()
                adj_list[i, 0] = c_e
                adj_list[i, 1] = c_s

            i += 1

    if shuffle_d0:
        np.random.shuffle(adj_list)

    return adj_list
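

# Illustrative sketch (not from the source): a `ConnectionList` has shape
# `(2, rows, cols)`; per the loop above, layer 0 marks downward connections
# (row + 1) and layer 1 rightward ones (col + 1). With shuffling disabled the
# output is deterministic:
#
#   >>> conn_list = np.zeros((2, 2, 2), dtype=bool)
#   >>> conn_list[0, 0, 0] = True  # (0,0) connected downward to (1,0)
#   >>> conn_list[1, 0, 0] = True  # (0,0) connected rightward to (0,1)
#   >>> connection_list_to_adj_list(conn_list, shuffle_d0=False, shuffle_d1=False).tolist()
#   [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]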


def is_connection(
    edges: ConnectionArray,
    connection_list: ConnectionList,
) -> Bool[np.ndarray, "is_connection=edges"]:
    """Returns whether each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`."""
    sorted_edges = np.sort(edges, axis=1)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    return connection_list[edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]]
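

# Illustrative sketch (not from the source): sorting each edge puts the smaller
# coordinate first, so a zero row-difference means a rightward (layer 1) edge
# and a nonzero one a downward (layer 0) edge; indexing the `ConnectionList` at
# the smaller coordinate then reads off connection status for all edges at once:
#
#   >>> conn_list = np.zeros((2, 2, 2), dtype=bool)
#   >>> conn_list[0, 0, 0] = True  # (0,0) connected downward to (1,0)
#   >>> edges = np.array([[[1, 0], [0, 0]], [[0, 1], [1, 1]]])
#   >>> is_connection(edges, conn_list).tolist()
#   [True, False]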


# string to coordinate representation
# ==================================================