maze_dataset.tokenization.modular.hashing
legacy system for checking a `MazeTokenizerModular` is valid -- compare its hash to a table of known hashes

this has been superseded by the fst system
1"""legacy system for checking a `ModularMazeTokenizer` is valid -- compare its hash to a table of known hashes 2 3this has been superseded by the fst system 4 5""" 6 7import hashlib 8from pathlib import Path 9 10import numpy as np 11from jaxtyping import UInt32 12 13# NOTE: these all need to match! 14 15AllTokenizersHashBitLength = 32 16"bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`" 17 18AllTokenizersHashDtype = np.uint32 19"numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`" 20 21AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"] 22"jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`" 23 24 25def _hash_tokenizer_name(s: str) -> int: 26 h64: int = int.from_bytes( 27 hashlib.shake_256(s.encode("utf-8")).digest(64), 28 byteorder="big", 29 ) 30 return (h64 >> 32) ^ (h64 & 0xFFFFFFFF) 31 32 33_ALL_TOKENIZER_HASHES: AllTokenizersHashesArray 34"private array of all tokenizer hashes" 35_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz" 36"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`" 37 38 39def set_tokenizer_hashes_path(path: Path) -> None: 40 """set path to tokenizer hashes, and reload the hashes if needed 41 42 the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`, 43 which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"` or in this file's directory. 44 45 However, this might not always work, so we provide a way to change this. 46 """ 47 global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES # noqa: PLW0603 48 49 path = Path(path) 50 if path.is_dir(): 51 path = path / "MazeTokenizerModular_hashes.npz" 52 53 if not path.is_file(): 54 err_msg: str = f"could not find maze tokenizer hashes file at: {path}" 55 raise FileNotFoundError(err_msg) 56 57 if _TOKENIZER_HASHES_PATH.absolute() != path.absolute(): 58 # reload if they aren't equal 59 _TOKENIZER_HASHES_PATH = path 60 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 61 else: 62 # always set to new path 63 _TOKENIZER_HASHES_PATH = path 64 65 66def _load_tokenizer_hashes() -> AllTokenizersHashesArray: 67 """Loads the sorted list of `all_tokenizers.get_all_tokenizers()` hashes from disk.""" 68 global _TOKENIZER_HASHES_PATH # noqa: PLW0602 69 try: 70 path: Path = _TOKENIZER_HASHES_PATH 71 return np.load(path)["hashes"] 72 except FileNotFoundError as e: 73 err_msg: str = ( 74 "Tokenizers hashes cannot be loaded. To fix this, run" 75 "\n`python -m maze-dataset.tokenization.save_hashes` which will save the hashes to" 76 "\n`data/MazeTokenizerModular_hashes.npz`" 77 "\nrelative to the current working directory -- this is where the code looks for them." 78 ) 79 raise FileNotFoundError(err_msg) from e 80 81 82def get_all_tokenizer_hashes() -> AllTokenizersHashesArray: 83 """returns all the tokenizer hashes in an `AllTokenizersHashesDtype` array, setting global variable if needed""" 84 # naughty use of globals 85 global _ALL_TOKENIZER_HASHES # noqa: PLW0603 86 try: 87 got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0 88 if got_tokenizers: 89 return _ALL_TOKENIZER_HASHES 90 else: 91 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 92 except NameError: 93 _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() 94 95 return _ALL_TOKENIZER_HASHES
`AllTokenizersHashBitLength = 32`

bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`
`AllTokenizersHashDtype = np.uint32`

numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`
`AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"]`

jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`
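Nothing enforces the "must match" constraint at import time, so a quick consistency check can catch drift between these definitions. A sketch:

```python
import numpy as np

from maze_dataset.tokenization.modular.hashing import (
    AllTokenizersHashBitLength,
    AllTokenizersHashDtype,
)

# the dtype's width in bits must equal the declared bit length;
# the jaxtyping annotation (UInt32) agrees with both by construction
assert np.dtype(AllTokenizersHashDtype).itemsize * 8 == AllTokenizersHashBitLength
```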
```python
def set_tokenizer_hashes_path(path: Path) -> None:
    """set path to tokenizer hashes, and reload the hashes if needed

    the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
    which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`,
    i.e. in this file's directory.

    However, this might not always work, so we provide a way to change this.
    """
    global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES  # noqa: PLW0603

    path = Path(path)
    if path.is_dir():
        path = path / "MazeTokenizerModular_hashes.npz"

    if not path.is_file():
        err_msg: str = f"could not find maze tokenizer hashes file at: {path}"
        raise FileNotFoundError(err_msg)

    if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
        # the path changed: store it and reload the hashes from the new location
        _TOKENIZER_HASHES_PATH = path
        _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
    else:
        # the path is unchanged: store it, no reload needed
        _TOKENIZER_HASHES_PATH = path
```
set path to tokenizer hashes, and reload the hashes if needed

the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`, which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`, i.e. in this file's directory.

However, this might not always work, so we provide a way to change this.
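For example, one could point the module at hashes stored in a local `data/` directory (a hypothetical location); passing a directory appends the expected filename automatically:

```python
from pathlib import Path

from maze_dataset.tokenization.modular.hashing import set_tokenizer_hashes_path

# "data" is a hypothetical directory that must already contain
# MazeTokenizerModular_hashes.npz, otherwise FileNotFoundError is raised
set_tokenizer_hashes_path(Path("data"))
```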
```python
def get_all_tokenizer_hashes() -> AllTokenizersHashesArray:
    """returns all the tokenizer hashes in an `AllTokenizersHashDtype` array, setting the global variable if needed"""
    # naughty use of globals
    global _ALL_TOKENIZER_HASHES  # noqa: PLW0603
    try:
        got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0
        if got_tokenizers:
            return _ALL_TOKENIZER_HASHES
        else:
            _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
    except NameError:
        # the global was never assigned: load from disk on first access
        _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()

    return _ALL_TOKENIZER_HASHES
```
returns all the tokenizer hashes in an `AllTokenizersHashDtype` array, setting the global variable if needed
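Putting the pieces together, a sketch of the membership check this table exists for: hash a tokenizer's name and look it up in the sorted array (binary search is valid because `_load_tokenizer_hashes` returns a sorted array; the name below is a hypothetical stand-in):

```python
import numpy as np

from maze_dataset.tokenization.modular.hashing import (
    _hash_tokenizer_name,
    get_all_tokenizer_hashes,
)

hashes = get_all_tokenizer_hashes()  # sorted uint32 array, loaded on first call

tokenizer_name = "some-tokenizer-name"  # hypothetical; real names come from tokenizer objects
h = np.uint32(_hash_tokenizer_name(tokenizer_name))

idx = int(np.searchsorted(hashes, h))  # binary search in the sorted table
is_known = idx < len(hashes) and hashes[idx] == h
```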