Coverage for maze_dataset/tokenization/modular/hashing.py: 45%
44 statements
coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"""legacy system for checking a `ModularMazeTokenizer` is valid -- compare its hash to a table of known hashes
3this has been superseded by the fst system
5"""

import hashlib
from pathlib import Path

import numpy as np
from jaxtyping import UInt32

# NOTE: these all need to match!
AllTokenizersHashBitLength = 32
"bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`"

AllTokenizersHashDtype = np.uint32
"numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`"

AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"]
"jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`"


def _hash_tokenizer_name(s: str) -> int:
    """hash a tokenizer name to a 32-bit int: take a 64-bit shake_256 digest, then XOR-fold the halves"""
    h64: int = int.from_bytes(
        # 64 // 8 = 8 bytes, so `h64` is a 64-bit integer
        hashlib.shake_256(s.encode("utf-8")).digest(64 // 8),
        byteorder="big",
    )
    # fold the top 32 bits into the bottom 32 bits, yielding a 32-bit hash
    return (h64 >> 32) ^ (h64 & 0xFFFFFFFF)
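
# usage sketch (the name below is hypothetical -- any string works):
#
#   h = _hash_tokenizer_name("MazeTokenizerModular-example")
#   assert 0 <= h < 2**AllTokenizersHashBitLength  # result always fits in a uint32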


_ALL_TOKENIZER_HASHES: AllTokenizersHashesArray
"private array of all tokenizer hashes"

_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`"


def set_tokenizer_hashes_path(path: Path) -> None:
    """set the path to the tokenizer hashes, and reload them if needed

    the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
    which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`,
    i.e. in the same directory as this file.

    However, this might not always work, so we provide a way to change it.
    """
    global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES  # noqa: PLW0603

    path = Path(path)
    if path.is_dir():
        # given a directory, append the default filename
        path = path / "MazeTokenizerModular_hashes.npz"

    if not path.is_file():
        err_msg: str = f"could not find maze tokenizer hashes file at: {path}"
        raise FileNotFoundError(err_msg)

    if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
        # the path changed, so reload the hashes from the new location
        _TOKENIZER_HASHES_PATH = path
        _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
    else:
        # same file -- just store the (possibly re-normalized) path
        _TOKENIZER_HASHES_PATH = path
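
# usage sketch (the `data/` location below is hypothetical):
#
#   set_tokenizer_hashes_path(Path("data/"))  # a directory: the default filename is appended
#   set_tokenizer_hashes_path(Path("data/MazeTokenizerModular_hashes.npz"))  # or the file itself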


def _load_tokenizer_hashes() -> AllTokenizersHashesArray:
    """Loads the sorted array of `all_tokenizers.get_all_tokenizers()` hashes from disk."""
    global _TOKENIZER_HASHES_PATH  # noqa: PLW0602
    try:
        path: Path = _TOKENIZER_HASHES_PATH
        return np.load(path)["hashes"]
    except FileNotFoundError as e:
        err_msg: str = (
            "Tokenizer hashes cannot be loaded. To fix this, run"
            "\n`python -m maze_dataset.tokenization.save_hashes` which will save the hashes to"
            "\n`data/MazeTokenizerModular_hashes.npz`"
            "\nrelative to the current working directory -- this is where the code looks for them."
        )
        raise FileNotFoundError(err_msg) from e
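
# a sketch of the expected on-disk format (the actual save script lives in
# `save_hashes`, referenced above): a sorted uint32 array stored under the
# "hashes" key of an .npz archive --
#
#   all_hashes = np.sort(np.array([...], dtype=AllTokenizersHashDtype))  # contents elided
#   np.savez(_TOKENIZER_HASHES_PATH, hashes=all_hashes)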


def get_all_tokenizer_hashes() -> AllTokenizersHashesArray:
    """returns all the tokenizer hashes as an `AllTokenizersHashesArray`, setting the global variable if needed"""
    # naughty use of globals
    global _ALL_TOKENIZER_HASHES  # noqa: PLW0603
    try:
        got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0
        if got_tokenizers:
            return _ALL_TOKENIZER_HASHES
        else:
            _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
    except NameError:
        # `_ALL_TOKENIZER_HASHES` is annotated at module level but unbound until first load
        _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()

    return _ALL_TOKENIZER_HASHES
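
# a sketch of the legacy validity check this table supported (`name` is a
# hypothetical serialized tokenizer name, e.g. from a `MazeTokenizerModular`):
#
#   hashes = get_all_tokenizer_hashes()
#   is_known = bool(np.isin(np.uint32(_hash_tokenizer_name(name)), hashes))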