Coverage for maze_dataset/token_utils.py: 64%

202 statements  

coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1"""a whole bunch of utilities for tokenization""" 

2 

3import re 

4import typing 

5import warnings 

6from collections import Counter 

7from typing import Callable, Literal, overload 

8 

9import numpy as np 

10from jaxtyping import Bool, Float, Int, Int8 

11from muutils.errormode import ErrorMode 

12from muutils.misc import list_join 

13from muutils.misc.sequence import WhenMissing 

14 

15from maze_dataset.constants import ( 

16 CARDINAL_MAP, 

17 SPECIAL_TOKENS, 

18 VOCAB, 

19 ConnectionArray, 

20 ConnectionList, 

21 CoordTup, 

22) 

23 

24# filtering things from a prompt or generated text 

25# ================================================== 

26 

27 

def remove_padding_from_token_str(token_str: str) -> str:
    """remove padding tokens from a joined token string"""
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "")
    token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "")
    return token_str  # noqa: RET504


def tokens_between(
    tokens: list[str],
    start_value: str,
    end_value: str,
    include_start: bool = False,
    include_end: bool = False,
    except_when_tokens_not_unique: bool = False,
) -> list[str]:
    """given a list `tokens`, get the tokens between `start_value` and `end_value`

    by default, neither the `start_value` nor the `end_value` token is included in the returned list

    # Parameters:
    - `tokens : list[str]`
    - `start_value : str`
    - `end_value : str`
    - `include_start : bool`
        (defaults to `False`)
    - `include_end : bool`
        (defaults to `False`)
    - `except_when_tokens_not_unique : bool`
        when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens
        (defaults to `False`)

    # Returns:
    - `list[str]`

    # Raises:
    - `ValueError` : if `start_value` and `end_value` are the same
    - `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens
    - `ValueError` : if `start_value` or `end_value` are not present in the input tokens
    """
    if start_value == end_value:
        err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }"
        raise ValueError(
            err_msg,
        )
    if except_when_tokens_not_unique:
        if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1):
            err_msg: str = (
                "start_value or end_value is not unique in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)
    else:
        if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1):
            err_msg: str = (
                "start_value or end_value is not present in the input tokens:"
                f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }"
                f"\n{start_value = } {end_value = }"
                f"\n{tokens = }"
            )
            raise ValueError(err_msg)

    start_idx: int = tokens.index(start_value) + int(not include_start)
    end_idx: int = tokens.index(end_value) + int(include_end)

    assert start_idx < end_idx, "Start must come before end"

    return tokens[start_idx:end_idx]

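# example usage of `tokens_between` (illustrative sketch, not part of the original module):
# >>> tokens_between(["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"], "<PATH_START>", "<PATH_END>")
# ['(0,0)', '(0,1)']
# >>> tokens_between(["<PATH_START>", "(0,0)", "(0,1)", "<PATH_END>"], "<PATH_START>", "<PATH_END>", include_start=True, include_end=True)
# ['<PATH_START>', '(0,0)', '(0,1)', '<PATH_END>']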

def get_adj_list_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.ADJLIST_END,
    )


def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]:
    """get the path tokens from a list of tokens

    when `trim_end` is `True`, the path is everything from the first path coord
    up to (but not including) the PATH_END token, if present; otherwise the
    PATH_START token and everything after it is returned.
    """
    if SPECIAL_TOKENS.PATH_START not in tokens:
        err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}"
        raise ValueError(
            err_msg,
        )
    start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + int(trim_end)
    end_idx: int | None = None
    if trim_end and (SPECIAL_TOKENS.PATH_END in tokens):
        end_idx = tokens.index(SPECIAL_TOKENS.PATH_END)
    return tokens[start_idx:end_idx]

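# example usage of `get_path_tokens` (illustrative, assuming the default
# "<PATH_START>"/"<PATH_END>" special token strings):
# >>> toks = ["<PATH_START>", "(0,0)", "(1,0)", "<PATH_END>", "<PADDING>"]
# >>> get_path_tokens(toks)
# ['<PATH_START>', '(0,0)', '(1,0)', '<PATH_END>', '<PADDING>']
# >>> get_path_tokens(toks, trim_end=True)
# ['(0,0)', '(1,0)']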

def get_context_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ADJLIST_START and PATH_START, inclusive of both special tokens"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ADJLIST_START,
        SPECIAL_TOKENS.PATH_START,
        include_start=True,
        include_end=True,
    )


def get_origin_tokens(tokens: list[str]) -> list[str]:
    "get tokens between ORIGIN_START and ORIGIN_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.ORIGIN_START,
        SPECIAL_TOKENS.ORIGIN_END,
        include_start=False,
        include_end=False,
    )


def get_target_tokens(tokens: list[str]) -> list[str]:
    "get tokens between TARGET_START and TARGET_END, without the special tokens themselves"
    return tokens_between(
        tokens,
        SPECIAL_TOKENS.TARGET_START,
        SPECIAL_TOKENS.TARGET_END,
        include_start=False,
        include_end=False,
    )


def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str:
    """Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`."""
    return CARDINAL_MAP[tuple(coords[1] - coords[0])]

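# example usage of `get_cardinal_direction` (illustrative; assumes `CARDINAL_MAP`
# maps (row, col) deltas to direction tokens, e.g. a delta of (0, 1) to the east
# token -- check `maze_dataset.constants.CARDINAL_MAP` for the exact mapping):
# >>> get_cardinal_direction(np.array([[0, 0], [0, 1]]))  # one step along +col; returns VOCAB.PATH_EAST under that assumption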

def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str:
    """Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`.

    # Parameters
    - `coords`: Contains 3 Coords, each of which must neighbor the previous Coord.
        - `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing.
        - `coords[1]`: The current location
        - `coords[2]`: The next location. May be equal to the current location.
    """
    if coords.shape != (3, 2):
        err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead."
        raise ValueError(err_msg)
    directions = coords[1:] - coords[:-1]
    if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])):
        # Use floats as constant since `np.linalg.norm` returns float array
        err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(coords[1], coords[2]):
        return VOCAB.PATH_STAY
    if np.array_equal(coords[0], coords[2]):
        return VOCAB.PATH_BACKWARD
    if np.array_equal(coords[0], coords[1]):
        err_msg: str = f"Previous first-person direction indeterminate from {coords=}."
        raise ValueError(
            err_msg,
        )
    if np.array_equal(directions[0], directions[1]):
        return VOCAB.PATH_FORWARD
    directions = np.append(
        directions,
        [[0], [0]],
        axis=1,
    )  # Augment to represent unit basis vectors in 3D
    # at this point the two directions are distinct, non-opposite unit vectors,
    # so the z-component of their cross product is exactly +1 or -1
    match np.cross(directions[0], directions[1])[-1]:
        case 1:
            return VOCAB.PATH_LEFT
        case -1:
            return VOCAB.PATH_RIGHT

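# example usage of `get_relative_direction` (illustrative sketch):
# traveling (0,0) -> (0,1) -> (1,1) first moves along +col, then along +row;
# with rows increasing "downward", that is a right turn:
# >>> get_relative_direction(np.array([[0, 0], [0, 1], [1, 1]]))  # returns VOCAB.PATH_RIGHT
# >>> get_relative_direction(np.array([[0, 0], [0, 1], [0, 1]]))  # returns VOCAB.PATH_STAY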

class TokenizerPendingDeprecationWarning(PendingDeprecationWarning):
    """Pending deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass


def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool:
    """return True if the string represents a coordinate, False otherwise"""
    warnings.warn(
        "`util.str_is_coord` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731

    coord_str = strip_func(coord_str)

    return all(
        [
            coord_str.startswith("("),
            coord_str.endswith(")"),
            "," in coord_str,
            all(
                strip_func(x).isdigit()
                for x in strip_func(coord_str.lstrip("(").rstrip(")")).split(",")
            ),
        ],
    )

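# example usage of `str_is_coord` (illustrative; note that each call also emits
# a `TokenizerPendingDeprecationWarning`):
# >>> str_is_coord("(1,2)")
# True
# >>> str_is_coord("( 1 , 2 )")  # whitespace is allowed by default
# True
# >>> str_is_coord("<PATH_START>")
# False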

class TokenizerDeprecationWarning(DeprecationWarning):
    """Deprecation warnings related to the `MazeTokenizerModular` upgrade."""

    pass


# coordinate to strings
# ==================================================


def _coord_to_strings_UT(coord: typing.Sequence[int]) -> list[str]:
    """convert a coordinate to a string: `(i,j)`->"(i,j)"

    always returns a list of length 1
    """
    return [f"({','.join(str(c) for c in coord)})"]


def _coord_to_strings_indexed(coord: typing.Sequence[int]) -> list[str]:
    """convert a coordinate to a list of indexed strings: `(i,j)`->"(", "i", ",", "j", ")"

    always returns a list of length 5 for a 2-element coordinate
    """
    return [
        "(",
        *list_join([str(c) for c in coord], lambda: ","),
        ")",
    ]

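# example usage of the two coordinate-token formats (illustrative):
# >>> _coord_to_strings_UT((1, 2))
# ['(1,2)']
# >>> _coord_to_strings_indexed((1, 2))
# ['(', '1', ',', '2', ')']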

def coord_str_to_tuple(
    coord_str: str,
    allow_whitespace: bool = True,
) -> tuple[int, ...]:
    """convert a coordinate string to a tuple"""
    strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x  # noqa: E731
    coord_str = strip_func(coord_str)
    stripped: str = strip_func(coord_str.lstrip("(").rstrip(")"))
    return tuple(int(strip_func(x)) for x in stripped.split(","))


def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray:
    """convert a coordinate string to a numpy array"""
    return np.array(coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace))


def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None:
    """convert a coordinate string to a tuple, or None if the string is not a coordinate string"""
    if not str_is_coord(coord_str):
        return None
    return coord_str_to_tuple(coord_str)

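# example usage of coordinate-string parsing (illustrative; the noneable variant
# also emits a warning via `str_is_coord`):
# >>> coord_str_to_tuple("(1,2)")
# (1, 2)
# >>> coord_str_to_tuple_noneable("<PATH_START>") is None
# True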

def coords_string_split_UT(coords: str) -> list[str]:
    """Splits a string of tokens into a list containing the UT tokens for each coordinate.

    Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)").
    Non-whitespace portions of the input string not matched are preserved in the same list:
    "(1,2) <SPECIAL_TOKEN> (5,6)" -> ["(1,2)", "<SPECIAL_TOKEN>", "(5,6)"]
    """
    # ty gpt4
    return re.findall(r"\([^)]*\)|\S+", coords)


# back and forth in wrapped form
# ==================================================


@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["skip"] = "skip",
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["error"] = "error",
) -> list[CoordTup]: ...
@overload
def strings_to_coords(
    text: str | list[str],
    when_noncoord: Literal["include"] = "include",
) -> list[str | CoordTup]: ...
def strings_to_coords(
    text: str | list[str],
    when_noncoord: WhenMissing = "skip",
) -> list[str | CoordTup]:
    """converts a list of tokens to a list of coordinates

    returns list[CoordTup] if `when_noncoord` is "skip" or "error"
    returns list[str | CoordTup] if `when_noncoord` is "include"
    """
    warnings.warn(
        "`util.strings_to_coords` only supports legacy UT strings. Function will be replaced with a generalized version in a future release.",
        TokenizerPendingDeprecationWarning,
    )
    tokens_joined: str = text if isinstance(text, str) else " ".join(text)
    tokens_processed: list[str] = coords_string_split_UT(tokens_joined)
    result: list[str | CoordTup] = list()
    for token in tokens_processed:
        coord: CoordTup | None = coord_str_to_tuple_noneable(token)
        if coord is None:
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate token '{token}' in text: '{text}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(token)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.append(coord)
    return result

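# example usage of `strings_to_coords` (illustrative; emits a
# `TokenizerPendingDeprecationWarning`):
# >>> strings_to_coords("<PATH_START> (1,2) (3,4) <PATH_END>")
# [(1, 2), (3, 4)]
# >>> strings_to_coords("<PATH_START> (1,2)", when_noncoord="include")
# ['<PATH_START>', (1, 2)]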

@overload
def coords_to_strings(
    coords: list[str | CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: Literal["include", "skip"] = "skip",
) -> list[str]: ...
@overload
def coords_to_strings(
    coords: list[CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: Literal["error"] = "error",
) -> list[str]: ...
def coords_to_strings(
    coords: list[str | CoordTup],
    coord_to_strings_func: Callable[[CoordTup], list[str]],
    when_noncoord: WhenMissing = "skip",
) -> list[str]:
    """converts a list of coordinates to a list of strings (tokens)

    expects list[CoordTup] if `when_noncoord` is "error"
    expects list[str | CoordTup] if `when_noncoord` is "include" or "skip"
    """
    result: list[str] = list()
    for coord in coords:
        if isinstance(coord, str):
            if when_noncoord == "skip":
                continue
            if when_noncoord == "error":
                err_msg: str = (
                    f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'"
                )
                raise ValueError(
                    err_msg,
                )
            if when_noncoord == "include":
                result.append(coord)
            else:
                err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'"
                raise ValueError(err_msg)
        else:
            result.extend(coord_to_strings_func(coord))
    return result

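# example usage of `coords_to_strings` (illustrative):
# >>> coords_to_strings([(1, 2), "<PATH_END>"], _coord_to_strings_UT, when_noncoord="include")
# ['(1,2)', '<PATH_END>']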

def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]:
    """Splits a list of tokens into adjacency list tokens and non-adjacency list tokens."""
    adj_list_start, adj_list_end = (
        toks.index("<ADJLIST_START>") + 1,
        toks.index("<ADJLIST_END>"),
    )
    adj_list = toks[adj_list_start:adj_list_end]
    non_adj_list = toks[:adj_list_start] + toks[adj_list_end:]
    return adj_list, non_adj_list


def equal_except_adj_list_sequence(  # noqa: C901
    rollout1: list[str],
    rollout2: list[str],
    do_except: bool = False,
    when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT,
    when_len_mismatch: ErrorMode = ErrorMode.EXCEPT,
) -> bool:
    """Returns whether the rollout strings are equal, allowing for differently sequenced adjacency lists.

    <ADJLIST_START> and <ADJLIST_END> tokens must be in the rollouts.
    Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze.
    This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object.

    # Warning: CTT False Positives
    This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`.
    If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible.
    More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive.
    """
    if len(rollout1) != len(rollout2):
        if do_except:
            when_len_mismatch.process(
                f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}",
            )
        return False
    if ("<ADJLIST_START>" in rollout1) ^ ("<ADJLIST_START>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_START> token: `{'<ADJLIST_START>' in rollout1 = }` != `{'<ADJLIST_START>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False
    if ("<ADJLIST_END>" in rollout1) ^ ("<ADJLIST_END>" in rollout2):
        if do_except:
            err_msg: str = f"Rollouts do not have the same <ADJLIST_END> token: `{'<ADJLIST_END>' in rollout1 = }` != `{'<ADJLIST_END>' in rollout2 = }`"
            raise ValueError(
                err_msg,
            )
        return False

    adj_list1, non_adj_list1 = get_token_regions(rollout1)
    adj_list2, non_adj_list2 = get_token_regions(rollout2)
    if non_adj_list1 != non_adj_list2:
        if do_except:
            when_len_mismatch.process(
                f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}",
            )
        return False
    counter1: Counter = Counter(adj_list1)
    counter2: Counter = Counter(adj_list2)
    counters_eq: bool = counter1 == counter2
    if not counters_eq:
        if do_except:
            when_counter_mismatch.process(
                f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }",
            )
        return False

    return True

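# example usage of `equal_except_adj_list_sequence` (illustrative sketch; the
# "<-->" connector and ";" separator are arbitrary placeholder tokens here, since
# only the multiset of adjacency-list tokens is compared):
# >>> r1 = ["<ADJLIST_START>", "(0,0)", "<-->", "(0,1)", ";", "(0,1)", "<-->", "(1,1)", ";", "<ADJLIST_END>"]
# >>> r2 = ["<ADJLIST_START>", "(0,1)", "<-->", "(1,1)", ";", "(0,0)", "<-->", "(0,1)", ";", "<ADJLIST_END>"]
# >>> equal_except_adj_list_sequence(r1, r2)  # same connections, different order
# True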

def connection_list_to_adj_list(
    conn_list: ConnectionList,
    shuffle_d0: bool = True,
    shuffle_d1: bool = True,
) -> Int8[np.ndarray, "conn start_end=2 coord=2"]:
    """converts a `ConnectionList` (special lattice format) to a shuffled adjacency list

    # Parameters:
    - `conn_list: ConnectionList`
        special internal format for graphs which are subgraphs of a lattice
    - `shuffle_d0: bool`
        shuffle the adjacency list along the 0th axis (order of pairs)
    - `shuffle_d1: bool`
        shuffle the adjacency list along the 1st axis (order of coordinates in each pair).
        If `False`, all pairs have the smaller coord first.

    # Returns:
    - `Int8[np.ndarray, "conn start_end=2 coord=2"]`
        adjacency list in the shape `(n_connections, 2, 2)`
    """
    n_connections: int = conn_list.sum()
    adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.full(
        (n_connections, 2, 2),
        -1,
        dtype=np.int8,
    )

    if shuffle_d1:
        flip_d1: Float[np.ndarray, " conn"] = np.random.rand(n_connections)

    # loop over all nonzero elements of the connection list
    i: int = 0
    for d, x, y in np.ndindex(conn_list.shape):
        if conn_list[d, x, y]:
            # d == 0 marks a downward connection (row + 1), d == 1 a rightward one (col + 1)
            c_start: CoordTup = (x, y)
            c_end: CoordTup = (
                x + (1 if d == 0 else 0),
                y + (1 if d == 1 else 0),
            )
            adj_list[i, 0] = np.array(c_start, dtype=np.int8)
            adj_list[i, 1] = np.array(c_end, dtype=np.int8)

            # flip if shuffling
            # magic value is fine here
            if shuffle_d1 and (flip_d1[i] > 0.5):  # noqa: PLR2004
                c_s, c_e = adj_list[i, 0].copy(), adj_list[i, 1].copy()
                adj_list[i, 0] = c_e
                adj_list[i, 1] = c_s

            i += 1

    if shuffle_d0:
        np.random.shuffle(adj_list)

    return adj_list

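# example usage of `connection_list_to_adj_list` (illustrative; shuffling is
# disabled so the output order is deterministic):
# >>> conn_list = np.zeros((2, 2, 2), dtype=bool)
# >>> conn_list[0, 0, 0] = True  # (0,0) connected downward to (1,0)
# >>> conn_list[1, 0, 0] = True  # (0,0) connected rightward to (0,1)
# >>> connection_list_to_adj_list(conn_list, shuffle_d0=False, shuffle_d1=False).tolist()
# [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]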

def is_connection(
    edges: ConnectionArray,
    connection_list: ConnectionList,
) -> Bool[np.ndarray, "is_connection=edges"]:
    """Returns whether each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`."""
    sorted_edges = np.sort(edges, axis=1)
    # a zero row-delta means the edge is horizontal (direction index 1),
    # otherwise it is vertical (direction index 0)
    edge_direction = (
        (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0
    ).astype(np.int8)
    return connection_list[edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]]

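# example usage of `is_connection` (illustrative; reuses the 2x2 `conn_list`
# sketched above, where (0,0)<-->(0,1) is a connection and (1,0)<-->(1,1) is not):
# >>> edges = np.array([[[0, 0], [0, 1]], [[1, 0], [1, 1]]])
# >>> is_connection(edges, conn_list).tolist()
# [True, False]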

# string to coordinate representation
# ==================================================