maze_dataset.token_utils
a whole bunch of utilities for tokenization
1"""a whole bunch of utilities for tokenization""" 2 3import re 4import typing 5import warnings 6from collections import Counter 7from typing import Callable, Literal, overload 8 9import numpy as np 10from jaxtyping import Bool, Float, Int, Int8 11from muutils.errormode import ErrorMode 12from muutils.misc import list_join 13from muutils.misc.sequence import WhenMissing 14 15from maze_dataset.constants import ( 16 CARDINAL_MAP, 17 SPECIAL_TOKENS, 18 VOCAB, 19 ConnectionArray, 20 ConnectionList, 21 CoordTup, 22) 23 24# filtering things from a prompt or generated text 25# ================================================== 26 27 28def remove_padding_from_token_str(token_str: str) -> str: 29 """remove padding tokens from a joined token string""" 30 token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "") 31 token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "") 32 return token_str # noqa: RET504 33 34 35def tokens_between( 36 tokens: list[str], 37 start_value: str, 38 end_value: str, 39 include_start: bool = False, 40 include_end: bool = False, 41 except_when_tokens_not_unique: bool = False, 42) -> list[str]: 43 """given a list `tokens`, get the tokens between `start_value` and `end_value` 44 45 _extended_summary_ 46 47 # Parameters: 48 - `tokens : list[str]` 49 - `start_value : str` 50 - `end_value : str` 51 - `include_start : bool` 52 (defaults to `False`) 53 - `include_end : bool` 54 (defaults to `False`) 55 - `except_when_tokens_not_unique : bool` 56 when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens 57 (defaults to `False`) 58 59 # Returns: 60 - `list[str]` 61 62 # Raises: 63 - `ValueError` : if `start_value` and `end_value` are the same 64 - `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens 65 - `ValueError` : if `start_value` or `end_value` are not present in the input tokens 66 """ 67 if start_value == end_value: 68 err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }" 69 raise ValueError( 70 err_msg, 71 ) 72 if except_when_tokens_not_unique: 73 if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1): 74 err_msg: str = ( 75 "start_value or end_value is not unique in the input tokens:" 76 f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }" 77 f"\n{start_value = } {end_value = }" 78 f"\n{tokens = }" 79 ) 80 raise ValueError(err_msg) 81 else: 82 if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1): 83 err_msg: str = ( 84 "start_value or end_value is not present in the input tokens:" 85 f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }" 86 f"\n{start_value = } {end_value = }" 87 f"\n{tokens = }" 88 ) 89 raise ValueError(err_msg) 90 91 start_idx: int = tokens.index(start_value) + int(not include_start) 92 end_idx: int = tokens.index(end_value) + int(include_end) 93 94 assert start_idx < end_idx, "Start must come before end" 95 96 return tokens[start_idx:end_idx] 97 98 99def get_adj_list_tokens(tokens: list[str]) -> list[str]: 100 "get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves" 101 return tokens_between( 102 tokens, 103 SPECIAL_TOKENS.ADJLIST_START, 104 SPECIAL_TOKENS.ADJLIST_END, 105 ) 106 107 108def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]: 109 """The path is considered everything from the first path coord to the path_end token, if it exists.""" 110 if SPECIAL_TOKENS.PATH_START not in 
tokens: 111 err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}" 112 raise ValueError( 113 err_msg, 114 ) 115 start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + int(trim_end) 116 end_idx: int | None = None 117 if trim_end and (SPECIAL_TOKENS.PATH_END in tokens): 118 end_idx = tokens.index(SPECIAL_TOKENS.PATH_END) 119 return tokens[start_idx:end_idx] 120 121 122def get_context_tokens(tokens: list[str]) -> list[str]: 123 "get tokens between ADJLIST_START and PATH_START" 124 return tokens_between( 125 tokens, 126 SPECIAL_TOKENS.ADJLIST_START, 127 SPECIAL_TOKENS.PATH_START, 128 include_start=True, 129 include_end=True, 130 ) 131 132 133def get_origin_tokens(tokens: list[str]) -> list[str]: 134 "get tokens_between ORIGIN_START and ORIGIN_END" 135 return tokens_between( 136 tokens, 137 SPECIAL_TOKENS.ORIGIN_START, 138 SPECIAL_TOKENS.ORIGIN_END, 139 include_start=False, 140 include_end=False, 141 ) 142 143 144def get_target_tokens(tokens: list[str]) -> list[str]: 145 "get tokens_between TARGET_START and TARGET_END" 146 return tokens_between( 147 tokens, 148 SPECIAL_TOKENS.TARGET_START, 149 SPECIAL_TOKENS.TARGET_END, 150 include_start=False, 151 include_end=False, 152 ) 153 154 155def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str: 156 """Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`.""" 157 return CARDINAL_MAP[tuple(coords[1] - coords[0])] 158 159 160def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str: 161 """Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`. 162 163 # Parameters 164 - `coords`: Contains 3 Coords, each of which must neighbor the previous Coord. 165 - `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing. 166 - `coords[1]`: The current location 167 - `coords[2]`: The next location. May be equal to the current location. 168 """ 169 if coords.shape != (3, 2): 170 err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead." 171 raise ValueError(err_msg) 172 directions = coords[1:] - coords[:-1] 173 if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])): 174 # Use floats as constant since `np.linalg.norm` returns float array 175 err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead." 176 raise ValueError( 177 err_msg, 178 ) 179 if np.array_equal(coords[1], coords[2]): 180 return VOCAB.PATH_STAY 181 if np.array_equal(coords[0], coords[2]): 182 return VOCAB.PATH_BACKWARD 183 if np.array_equal(coords[0], coords[1]): 184 err_msg: str = f"Previous first-person direction indeterminate from {coords=}." 
185 raise ValueError( 186 err_msg, 187 ) 188 if np.array_equal(directions[0], directions[1]): 189 return VOCAB.PATH_FORWARD 190 directions = np.append( 191 directions, 192 [[0], [0]], 193 axis=1, 194 ) # Augment to represent unit basis vectors in 3D 195 match np.cross(directions[0], directions[1])[-1]: 196 case 1: 197 return VOCAB.PATH_LEFT 198 case -1: 199 return VOCAB.PATH_RIGHT 200 201 202class TokenizerPendingDeprecationWarning(PendingDeprecationWarning): 203 """Pending deprecation warnings related to the `MazeTokenizerModular` upgrade.""" 204 205 pass 206 207 208def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool: 209 """return True if the string represents a coordinate, False otherwise""" 210 warnings.warn( 211 "`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.", 212 TokenizerPendingDeprecationWarning, 213 ) 214 strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x # noqa: E731 215 216 coord_str = strip_func(coord_str) 217 218 return all( 219 [ 220 coord_str.startswith("("), 221 coord_str.endswith(")"), 222 "," in coord_str, 223 all( 224 strip_func(x).isdigit() 225 for x in strip_func(coord_str.lstrip("(").rstrip(")")).split(",") 226 ), 227 ], 228 ) 229 230 231class TokenizerDeprecationWarning(DeprecationWarning): 232 """Deprecation warnings related to the `MazeTokenizerModular` upgrade.""" 233 234 pass 235 236 237# coordinate to strings 238# ================================================== 239 240 241def _coord_to_strings_UT(coord: typing.Sequence[int]) -> list[str]: 242 """convert a coordinate to a string: `(i,j)`->"(i,j)" 243 244 always returns a list of length 1 245 """ 246 return [f"({','.join(str(c) for c in coord)})"] 247 248 249def _coord_to_strings_indexed(coord: typing.Sequence[int]) -> list[str]: 250 """convert a coordinate to a list of indexed strings: `(i,j)`->"(", "i", ",", "j", ")" 251 252 always returns a list of length 5 253 """ 254 return [ 255 "(", 256 *list_join([str(c) for c in coord], lambda: ","), 257 ")", 258 ] 259 260 261def coord_str_to_tuple( 262 coord_str: str, 263 allow_whitespace: bool = True, 264) -> tuple[int, ...]: 265 """convert a coordinate string to a tuple""" 266 strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x # noqa: E731 267 coord_str = strip_func(coord_str) 268 stripped: str = strip_func(coord_str.lstrip("(").rstrip(")")) 269 return tuple(int(strip_func(x)) for x in stripped.split(",")) 270 271 272def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray: 273 """convert a coordinate string to a numpy array""" 274 return np.array(coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace)) 275 276 277def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None: 278 """convert a coordinate string to a tuple, or None if the string is not a coordinate string""" 279 if not str_is_coord(coord_str): 280 return None 281 return coord_str_to_tuple(coord_str) 282 283 284def coords_string_split_UT(coords: str) -> list[str]: 285 """Splits a string of tokens into a list containing the UT tokens for each coordinate. 286 287 Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)"). 
288 Non-whitespace portions of the input string not matched are preserved in the same list: 289 "(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"] 290 """ 291 # ty gpt4 292 return re.findall(r"\([^)]*\)|\S+", coords) 293 294 295# back and forth in wrapped form 296# ================================================== 297@overload 298def strings_to_coords( 299 text: str | list[str], 300 when_noncoord: Literal["skip"] = "skip", 301) -> list[CoordTup]: ... 302@overload 303def strings_to_coords( 304 text: str | list[str], 305 when_noncoord: Literal["error"] = "error", 306) -> list[CoordTup]: ... 307@overload 308def strings_to_coords( 309 text: str | list[str], 310 when_noncoord: Literal["include"] = "include", 311) -> list[str | CoordTup]: ... 312def strings_to_coords( 313 text: str | list[str], 314 when_noncoord: WhenMissing = "skip", 315) -> list[str | CoordTup]: 316 """converts a list of tokens to a list of coordinates 317 318 returns list[CoordTup] if `when_noncoord` is "skip" or "error" 319 returns list[str | CoordTup] if `when_noncoord` is "include" 320 """ 321 warnings.warn( 322 "`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.", 323 TokenizerPendingDeprecationWarning, 324 ) 325 tokens_joined: str = text if isinstance(text, str) else " ".join(text) 326 tokens_processed: list[str] = coords_string_split_UT(tokens_joined) 327 result: list[str] = list() 328 for token in tokens_processed: 329 coord: CoordTup | None = coord_str_to_tuple_noneable(token) 330 if coord is None: 331 if when_noncoord == "skip": 332 continue 333 if when_noncoord == "error": 334 err_msg: str = ( 335 f"Invalid non-coordinate token '{token}' in text: '{text}'" 336 ) 337 raise ValueError( 338 err_msg, 339 ) 340 if when_noncoord == "include": 341 result.append(token) 342 else: 343 err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'" 344 raise ValueError(err_msg) 345 else: 346 result.append(coord) 347 return result 348 349 350@overload 351def coords_to_strings( 352 coords: list[str | CoordTup], 353 coord_to_strings_func: Callable[[CoordTup], list[str]], 354 when_noncoord: Literal["include", "skip"] = "skip", 355) -> list[str]: ... 356@overload 357def coords_to_strings( 358 coords: list[CoordTup], 359 coord_to_strings_func: Callable[[CoordTup], list[str]], 360 when_noncoord: Literal["error"] = "error", 361) -> list[str]: ... 
362def coords_to_strings( 363 coords: list[str | CoordTup], 364 coord_to_strings_func: Callable[[CoordTup], list[str]], 365 when_noncoord: WhenMissing = "skip", 366) -> list[str]: 367 """converts a list of coordinates to a list of strings (tokens) 368 369 expects list[CoordTup] if `when_noncoord` is "error" 370 expects list[str | CoordTup] if `when_noncoord` is "include" or "skip" 371 """ 372 result: list[str] = list() 373 for coord in coords: 374 if isinstance(coord, str): 375 if when_noncoord == "skip": 376 continue 377 if when_noncoord == "error": 378 err_msg: str = ( 379 f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'" 380 ) 381 raise ValueError( 382 err_msg, 383 ) 384 if when_noncoord == "include": 385 result.append(coord) 386 else: 387 err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'" 388 raise ValueError(err_msg) 389 else: 390 result.extend(coord_to_strings_func(coord)) 391 return result 392 393 394def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]: 395 """Splits a list of tokens into adjacency list tokens and non-adjacency list tokens.""" 396 adj_list_start, adj_list_end = ( 397 toks.index("<ADJLIST_START>") + 1, 398 toks.index("<ADJLIST_END>"), 399 ) 400 adj_list = toks[adj_list_start:adj_list_end] 401 non_adj_list = toks[:adj_list_start] + toks[adj_list_end:] 402 return adj_list, non_adj_list 403 404 405def equal_except_adj_list_sequence( # noqa: C901 406 rollout1: list[str], 407 rollout2: list[str], 408 do_except: bool = False, 409 when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT, 410 when_len_mismatch: ErrorMode = ErrorMode.EXCEPT, 411) -> bool: 412 """Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists. 413 414 <ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts. 415 Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze. 416 This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object. 417 418 # Warning: CTT False Positives 419 This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`. 420 If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible. 421 More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive. 
422 """ 423 if len(rollout1) != len(rollout2): 424 if do_except: 425 when_len_mismatch.process( 426 f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}", 427 ) 428 return False 429 if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2): 430 if do_except: 431 err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`" 432 raise ValueError( 433 err_msg, 434 ) 435 return False 436 if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2): 437 if do_except: 438 err_msg: str = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`" 439 raise ValueError( 440 err_msg, 441 ) 442 return False 443 444 adj_list1, non_adj_list1 = get_token_regions(rollout1) 445 adj_list2, non_adj_list2 = get_token_regions(rollout2) 446 if non_adj_list1 != non_adj_list2: 447 if do_except: 448 when_len_mismatch.process( 449 f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}", 450 ) 451 err_msg: str = f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}" 452 raise ValueError( 453 err_msg, 454 ) 455 return False 456 counter1: Counter = Counter(adj_list1) 457 counter2: Counter = Counter(adj_list2) 458 counters_eq: bool = counter1 == counter2 459 if not counters_eq: 460 if do_except: 461 when_counter_mismatch.process( 462 f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }", 463 ) 464 return False 465 466 return True 467 468 469def connection_list_to_adj_list( 470 conn_list: ConnectionList, 471 shuffle_d0: bool = True, 472 shuffle_d1: bool = True, 473) -> Int8[np.ndarray, "conn start_end=2 coord=2"]: 474 """converts a `ConnectionList` (special lattice format) to a shuffled adjacency list 475 476 # Parameters: 477 - `conn_list: ConnectionList` 478 special internal format for graphs which are subgraphs of a lattice 479 - `shuffle_d0: bool` 480 shuffle the adjacency list along the 0th axis (order of pairs) 481 - `shuffle_d1: bool` 482 shuffle the adjacency list along the 1st axis (order of coordinates in each pair). 483 If `False`, all pairs have the smaller coord first. 
484 485 486 # Returns: 487 - `Int8[np.ndarray, "conn start_end=2 coord=2"]` 488 adjacency list in the shape `(n_connections, 2, 2)` 489 """ 490 n_connections: int = conn_list.sum() 491 adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.full( 492 (n_connections, 2, 2), 493 -1, 494 dtype=np.int8, 495 ) 496 497 if shuffle_d1: 498 flip_d1: Float[np.ndarray, " conn"] = np.random.rand(n_connections) 499 500 # loop over all nonzero elements of the connection list 501 i: int = 0 502 for d, x, y in np.ndindex(conn_list.shape): 503 if conn_list[d, x, y]: 504 c_start: CoordTup = (x, y) 505 c_end: CoordTup = ( 506 x + (1 if d == 0 else 0), 507 y + (1 if d == 1 else 0), 508 ) 509 adj_list[i, 0] = np.array(c_start, dtype=np.int8) 510 adj_list[i, 1] = np.array(c_end, dtype=np.int8) 511 512 # flip if shuffling 513 # magic value is fine here 514 if shuffle_d1 and (flip_d1[i] > 0.5): # noqa: PLR2004 515 c_s, c_e = adj_list[i, 0].copy(), adj_list[i, 1].copy() 516 adj_list[i, 0] = c_e 517 adj_list[i, 1] = c_s 518 519 i += 1 520 521 if shuffle_d0: 522 np.random.shuffle(adj_list) 523 524 return adj_list 525 526 527def is_connection( 528 edges: ConnectionArray, 529 connection_list: ConnectionList, 530) -> Bool[np.ndarray, "is_connection=edges"]: 531 """Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`.""" 532 sorted_edges = np.sort(edges, axis=1) 533 edge_direction = ( 534 (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0 535 ).astype(np.int8) 536 return connection_list[edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]] 537 538 539# string to coordinate representation 540# ==================================================
def remove_padding_from_token_str(token_str: str) -> str:
    """remove padding tokens from a joined token string"""
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "")
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "")
    return token_str  # noqa: RET504
remove padding tokens from a joined token string
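A minimal usage sketch (the non-padding tokens here are an arbitrary example):

from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.token_utils import remove_padding_from_token_str

padded: str = f"{SPECIAL_TOKENS.PADDING} {SPECIAL_TOKENS.PADDING} <PATH_START> (0,0) <PATH_END>"
print(remove_padding_from_token_str(padded))  # -> "<PATH_START> (0,0) <PATH_END>"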
def tokens_between(
    tokens: list[str],
    start_value: str,
    end_value: str,
    include_start: bool = False,
    include_end: bool = False,
    except_when_tokens_not_unique: bool = False,
) -> list[str]:
    """given a list `tokens`, get the tokens between `start_value` and `end_value`

    # Parameters:
    - `tokens : list[str]`
    - `start_value : str`
    - `end_value : str`
    - `include_start : bool`
        (defaults to `False`)
    - `include_end : bool`
        (defaults to `False`)
    - `except_when_tokens_not_unique : bool`
        when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
        (defaults to `False`)

    # Returns:
    - `list[str]`

    # Raises:
    - `ValueError` : if `start_value` and `end_value` are the same
    - `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
    - `ValueError` : if `start_value` or `end_value` are not present in the input tokens
    """
    if start_value == end_value:
        err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }"
        raise ValueError(
            err_msg,
        )
    if except_when_tokens_not_unique:
        if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1):
            err_msg: str = (
                "start_value or end_value is not unique in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)
    else:
        if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1):
            err_msg: str = (
                "start_value or end_value is not present in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)

    start_idx: int = tokens.index(start_value) + int(not include_start)
    end_idx: int = tokens.index(end_value) + int(include_end)

    assert start_idx < end_idx, "Start must come before end"

    return tokens[start_idx:end_idx]
given a list `tokens`, get the tokens between `start_value` and `end_value`

# Parameters:
- `tokens : list[str]`
- `start_value : str`
- `end_value : str`
- `include_start : bool`
  (defaults to `False`)
- `include_end : bool`
  (defaults to `False`)
- `except_when_tokens_not_unique : bool`
  when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
  (defaults to `False`)

# Returns:
- `list[str]`

# Raises:
- `ValueError` : if `start_value` and `end_value` are the same
- `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
- `ValueError` : if `start_value` or `end_value` are not present in the input tokens
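A quick sketch of the inclusion flags on a hand-written token list:

from maze_dataset.token_utils import tokens_between

tokens: list[str] = ["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"]

print(tokens_between(tokens, "<PATH_START>", "<PATH_END>"))
# -> ['(0,0)', '(0,1)']
print(tokens_between(tokens, "<PATH_START>", "<PATH_END>", include_start=True, include_end=True))
# -> ['<PATH_START>', '(0,0)', '(0,1)', '<PATH_END>']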
def get_adj_list_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.ADJLIST_END,
    )
get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves
def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
    """The path is considered everything from the first path coord to the path_end token, if it exists."""
    if SPECIAL_TOKENS.PATH_START not in tokens:
        err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}"
        raise ValueError(
            err_msg,
        )
    start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + int(trim_end)
    end_idx: int | None = None
    if trim_end and (SPECIAL_TOKENS.PATH_END in tokens):
        end_idx = tokens.index(SPECIAL_TOKENS.PATH_END)
    return tokens[start_idx:end_idx]
The path is considered everything from the first path coord to the path_end token, if it exists.
def get_context_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and PATH_START"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.PATH_START,
        include_start=True,
        include_end=True,
    )
get tokens between ADJLIST_START and PATH_START
def get_origin_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ORIGIN_START and ORIGIN_END"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ORIGIN_START,
        SPECIAL_TOKENS.ORIGIN_END,
        include_start=False,
        include_end=False,
    )
get tokens between ORIGIN_START and ORIGIN_END
def get_target_tokens(tokens: list[str]) -> list[str]:
    "get tokens between TARGET_START and TARGET_END"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.TARGET_START,
        SPECIAL_TOKENS.TARGET_END,
        include_start=False,
        include_end=False,
    )
get tokens between TARGET_START and TARGET_END
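To see how these getters slice a rollout, here is a sketch on a hand-written token sequence (the literal token strings follow the `SPECIAL_TOKENS` defaults; the maze itself is a toy example):

from maze_dataset.token_utils import (
    get_adj_list_tokens,
    get_origin_tokens,
    get_path_tokens,
    get_target_tokens,
)

rollout: list[str] = (
    "<ADJLIST_START> (0,0) <--> (0,1) ; <ADJLIST_END> "
    "<ORIGIN_START> (0,0) <ORIGIN_END> "
    "<TARGET_START> (0,1) <TARGET_END> "
    "<PATH_START> (0,0) (0,1) <PATH_END>"
).split()

print(get_adj_list_tokens(rollout))  # ['(0,0)', '<-->', '(0,1)', ';']
print(get_origin_tokens(rollout))  # ['(0,0)']
print(get_target_tokens(rollout))  # ['(0,1)']
print(get_path_tokens(rollout))  # ['<PATH_START>', '(0,0)', '(0,1)', '<PATH_END>']
print(get_path_tokens(rollout, trim_end=True))  # ['(0,0)', '(0,1)']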
def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str:
    """Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`."""
    return CARDINAL_MAP[tuple(coords[1] - coords[0])]
Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`.
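For instance (the offset-to-token mapping lives in `CARDINAL_MAP`; which token string comes back depends on `VOCAB`):

import numpy as np

from maze_dataset.token_utils import get_cardinal_direction

step = np.array([[1, 1], [1, 2]])  # one step along the column axis
print(get_cardinal_direction(step))  # the token stored at CARDINAL_MAP[(0, 1)]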
def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str:
    """Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.

    # Parameters
    - `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
    - `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
    - `coords[1]`: The current location
    - `coords[2]`: The next location. May be equal to the current location.
    """
    if coords.shape != (3, 2):
        err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead."
        raise ValueError(err_msg)
    directions = coords[1:] - coords[:-1]
    if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])):
        # Use floats as constant since `np.linalg.norm` returns float array
        err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(coords[1], coords[2]):
        return VOCAB.PATH_STAY
    if np.array_equal(coords[0], coords[2]):
        return VOCAB.PATH_BACKWARD
    if np.array_equal(coords[0], coords[1]):
        err_msg: str = f"Previous first-person direction indeterminate from {coords=}."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(directions[0], directions[1]):
        return VOCAB.PATH_FORWARD
    directions = np.append(
        directions,
        [[0], [0]],
        axis=1,
    )  # Augment to represent unit basis vectors in 3D
    match np.cross(directions[0], directions[1])[-1]:
        case 1:
            return VOCAB.PATH_LEFT
        case -1:
            return VOCAB.PATH_RIGHT
Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.

# Parameters
- `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
- `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
- `coords[1]`: The current location
- `coords[2]`: The next location. May be equal to the current location.
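A sketch of the four outcomes. Coordinates are `(row, col)`, so with rows increasing downward, an east-then-south walk is a right turn (the cross product of the augmented direction vectors is what decides left vs. right):

import numpy as np

from maze_dataset.token_utils import get_relative_direction

print(get_relative_direction(np.array([[0, 0], [0, 1], [1, 1]])))  # VOCAB.PATH_RIGHT
print(get_relative_direction(np.array([[0, 0], [0, 1], [0, 2]])))  # VOCAB.PATH_FORWARD
print(get_relative_direction(np.array([[0, 0], [0, 1], [0, 0]])))  # VOCAB.PATH_BACKWARD
print(get_relative_direction(np.array([[0, 0], [0, 1], [0, 1]])))  # VOCAB.PATH_STAY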
class TokenizerPendingDeprecationWarning(PendingDeprecationWarning):
    """Pending deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass
Pending deprecation warnings related to the `MazeTokenizerModular` upgrade.
def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
    """return True if the string represents a coordinate, False otherwise"""
    warnings.warn(
        "`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731

    coord_str = strip_func(coord_str)

    return all(
        [
            coord_str.startswith("("),
            coord_str.endswith(")"),
            "," in coord_str,
            all(
                strip_func(x).isdigit()
                for x in strip_func(coord_str.lstrip("(").rstrip(")")).split(",")
            ),
        ],
    )
return True if the string represents a coordinate, False otherwise
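Some illustrative inputs (only legacy unique-token strings are recognized, and the pending-deprecation warning fires on every call):

from maze_dataset.token_utils import str_is_coord

print(str_is_coord("(1,2)"))  # True
print(str_is_coord("( 1 , 2 )"))  # True, whitespace stripped by default
print(str_is_coord("( 1 , 2 )", allow_whitespace=False))  # False
print(str_is_coord("1,2"))  # False, missing parentheses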
class TokenizerDeprecationWarning(DeprecationWarning):
    """Deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass
Deprecation warnings related to the `MazeTokenizerModular` upgrade.
def coord_str_to_tuple(
    coord_str: str,
    allow_whitespace: bool = True,
) -> tuple[int, ...]:
    """convert a coordinate string to a tuple"""
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731
    coord_str = strip_func(coord_str)
    stripped: str = strip_func(coord_str.lstrip("(").rstrip(")"))
    return tuple(int(strip_func(x)) for x in stripped.split(","))
convert a coordinate string to a tuple
def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray:
    """convert a coordinate string to a numpy array"""
    return np.array(coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace))
convert a coordinate string to a numpy array
def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None:
    """convert a coordinate string to a tuple, or None if the string is not a coordinate string"""
    if not str_is_coord(coord_str):
        return None
    return coord_str_to_tuple(coord_str)
convert a coordinate string to a tuple, or None if the string is not a coordinate string
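The three converters in action (a minimal sketch):

from maze_dataset.token_utils import (
    coord_str_to_coord_np,
    coord_str_to_tuple,
    coord_str_to_tuple_noneable,
)

print(coord_str_to_tuple("(1,2)"))  # (1, 2)
print(coord_str_to_coord_np("( 1 , 2 )"))  # array([1, 2])
print(coord_str_to_tuple_noneable("(1,2)"))  # (1, 2)
print(coord_str_to_tuple_noneable("<PATH_START>"))  # None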
def coords_string_split_UT(coords: str) -> list[str]:
    """Splits a string of tokens into a list containing the UT tokens for each coordinate.

    Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)").
    Non-whitespace portions of the input string not matched are preserved in the same list:
    "(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
    """
    # ty gpt4
    return re.findall(r"\([^)]*\)|\S+", coords)
Splits a string of tokens into a list containing the UT tokens for each coordinate.

Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)").
Non-whitespace portions of the input string not matched are preserved in the same list:
"(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
def strings_to_coords(
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    """converts a list of tokens to a list of coordinates

    returns list[CoordTup] if `when_noncoord` is "skip" or "error"
    returns list[str | CoordTup] if `when_noncoord` is "include"
    """
    warnings.warn(
        "`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    tokens_joined: str = text if isinstance(text, str) else " ".join(text)
    tokens_processed: list[str] = coords_string_split_UT(tokens_joined)
    result: list[str] = list()
    for token in tokens_processed:
        coord: CoordTup | None = coord_str_to_tuple_noneable(token)
        if coord is None:
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate token '{token}' in text: '{text}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(token)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.append(coord)
    return result
converts a list of tokens to a list of coordinates

returns `list[CoordTup]` if `when_noncoord` is "skip" or "error"
returns `list[str | CoordTup]` if `when_noncoord` is "include"
def coords_to_strings(
    coords: list[str | CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: WhenMissing = "skip",
) -> list[str]:
    """converts a list of coordinates to a list of strings (tokens)

    expects list[CoordTup] if `when_noncoord` is "error"
    expects list[str | CoordTup] if `when_noncoord` is "include" or "skip"
    """
    result: list[str] = list()
    for coord in coords:
        if isinstance(coord, str):
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(coord)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.extend(coord_to_strings_func(coord))
    return result
converts a list of coordinates to a list of strings (tokens)

expects `list[CoordTup]` if `when_noncoord` is "error"
expects `list[str | CoordTup]` if `when_noncoord` is "include" or "skip"
def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
    """Splits a list of tokens into adjacency list tokens and non-adjacency list tokens."""
    adj_list_start, adj_list_end = (
        toks.index("<ADJLIST_START>") + 1,
        toks.index("<ADJLIST_END>"),
    )
    adj_list = toks[adj_list_start:adj_list_end]
    non_adj_list = toks[:adj_list_start] + toks[adj_list_end:]
    return adj_list, non_adj_list
Splits a list of tokens into adjacency list tokens and non-adjacency list tokens.
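For example (note that the `<ADJLIST_START>` and `<ADJLIST_END>` delimiters themselves both land in the non-adjacency-list half):

from maze_dataset.token_utils import get_token_regions

toks = "<ADJLIST_START> (0,0) <--> (0,1) ; <ADJLIST_END> <PATH_START> (0,0) <PATH_END>".split()
adj, rest = get_token_regions(toks)
print(adj)  # ['(0,0)', '<-->', '(0,1)', ';']
print(rest)  # ['<ADJLIST_START>', '<ADJLIST_END>', '<PATH_START>', '(0,0)', '<PATH_END>']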
def equal_except_adj_list_sequence(  # noqa: C901
    rollout1: list[str],
    rollout2: list[str],
    do_except: bool = False,
    when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT,
    when_len_mismatch: ErrorMode = ErrorMode.EXCEPT,
) -> bool:
    """Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists.

    <ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts.
    Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
    This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

    # Warning: CTT False Positives
    This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
    If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
    More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
    """
    if len(rollout1) != len(rollout2):
        if do_except:
            when_len_mismatch.process(
                f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}",
            )
        return False
    if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False
    if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False

    adj_list1, non_adj_list1 = get_token_regions(rollout1)
    adj_list2, non_adj_list2 = get_token_regions(rollout2)
    if non_adj_list1 != non_adj_list2:
        if do_except:
            when_len_mismatch.process(
                f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}",
            )
            err_msg: str = f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}"
            raise ValueError(
                err_msg,
            )
        return False
    counter1: Counter = Counter(adj_list1)
    counter2: Counter = Counter(adj_list2)
    counters_eq: bool = counter1 == counter2
    if not counters_eq:
        if do_except:
            when_counter_mismatch.process(
                f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }",
            )
        return False

    return True
Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists.

`<ADJLIST_START>` and `<ADJLIST_END>` tokens must be in the rollouts.
Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

# Warning: CTT False Positives
This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
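A sketch: the same maze with its adjacency-list pairs emitted in a different order still compares equal, while any difference outside the adjacency list does not:

from maze_dataset.token_utils import equal_except_adj_list_sequence

r1 = "<ADJLIST_START> (0,0) <--> (0,1) ; (0,0) <--> (1,0) ; <ADJLIST_END> <PATH_START> (0,0) <PATH_END>".split()
r2 = "<ADJLIST_START> (0,0) <--> (1,0) ; (0,0) <--> (0,1) ; <ADJLIST_END> <PATH_START> (0,0) <PATH_END>".split()
r3 = "<ADJLIST_START> (0,0) <--> (0,1) ; (0,0) <--> (1,0) ; <ADJLIST_END> <PATH_START> (0,1) <PATH_END>".split()

print(equal_except_adj_list_sequence(r1, r2))  # True: same multiset of adjacency tokens
print(equal_except_adj_list_sequence(r1, r3))  # False: path region differs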
def connection_list_to_adj_list(
    conn_list: ConnectionList,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end=2 coord=2"]:
    """converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

    # Parameters:
    - `conn_list: ConnectionList`
        special internal format for graphs which are subgraphs of a lattice
    - `shuffle_d0: bool`
        shuffle the adjacency list along the 0th axis (order of pairs)
    - `shuffle_d1: bool`
        shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
        If `False`, all pairs have the smaller coord first.

    # Returns:
    - `Int8[np.ndarray, "conn start_end=2 coord=2"]`
        adjacency list in the shape `(n_connections, 2, 2)`
    """
    n_connections: int = conn_list.sum()
    adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.full(
        (n_connections, 2, 2),
        -1,
        dtype=np.int8,
    )

    if shuffle_d1:
        flip_d1: Float[np.ndarray, " conn"] = np.random.rand(n_connections)

    # loop over all nonzero elements of the connection list
    i: int = 0
    for d, x, y in np.ndindex(conn_list.shape):
        if conn_list[d, x, y]:
            c_start: CoordTup = (x, y)
            c_end: CoordTup = (
                x + (1 if d == 0 else 0),
                y + (1 if d == 1 else 0),
            )
            adj_list[i, 0] = np.array(c_start, dtype=np.int8)
            adj_list[i, 1] = np.array(c_end, dtype=np.int8)

            # flip if shuffling
            # magic value is fine here
            if shuffle_d1 and (flip_d1[i] > 0.5):  # noqa: PLR2004
                c_s, c_e = adj_list[i, 0].copy(), adj_list[i, 1].copy()
                adj_list[i, 0] = c_e
                adj_list[i, 1] = c_s

            i += 1

    if shuffle_d0:
        np.random.shuffle(adj_list)

    return adj_list
converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

# Parameters:
- `conn_list: ConnectionList`
  special internal format for graphs which are subgraphs of a lattice
- `shuffle_d0: bool`
  shuffle the adjacency list along the 0th axis (order of pairs)
- `shuffle_d1: bool`
  shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
  If `False`, all pairs have the smaller coord first.

# Returns:
- `Int8[np.ndarray, "conn start_end=2 coord=2"]`
  adjacency list in the shape `(n_connections, 2, 2)`
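A minimal sketch of the format, following how the loop above builds `c_end`: `conn_list[0, x, y]` marks a connection from `(x, y)` down to `(x+1, y)`, and `conn_list[1, x, y]` a connection right to `(x, y+1)`. Shuffling is disabled here so the output is deterministic:

import numpy as np

from maze_dataset.token_utils import connection_list_to_adj_list

# 2x2 lattice with two connections out of (0,0)
conn_list = np.zeros((2, 2, 2), dtype=bool)
conn_list[0, 0, 0] = True  # down: (0,0) <--> (1,0)
conn_list[1, 0, 0] = True  # right: (0,0) <--> (0,1)

adj = connection_list_to_adj_list(conn_list, shuffle_d0=False, shuffle_d1=False)
print(adj.tolist())  # [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]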
def is_connection(
    edges: ConnectionArray,
    connection_list: ConnectionList,
) -> Bool[np.ndarray, "is_connection=edges"]:
    """Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`."""
    sorted_edges = np.sort(edges, axis=1)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    return connection_list[edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]]
Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`.
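Continuing the `conn_list` sketch from above (a row-delta of zero marks a horizontal edge, which selects the "right" layer of the connection list; otherwise the "down" layer is used):

import numpy as np

from maze_dataset.token_utils import is_connection

conn_list = np.zeros((2, 2, 2), dtype=bool)
conn_list[0, 0, 0] = True  # down: (0,0) <--> (1,0)

edges = np.array(
    [
        [[0, 0], [1, 0]],  # the connected down edge
        [[0, 1], [1, 1]],  # a down edge with no connection
    ],
)
print(is_connection(edges, conn_list))  # [ True False]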