Coverage for maze_dataset/tokenization/modular/hashing.py: 45%

44 statements  

coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"""legacy system for checking a `ModularMazeTokenizer` is valid -- compare its hash to a table of known hashes 

2 

3this has been superseded by the fst system 

4 

5""" 

6 

import hashlib
from pathlib import Path

import numpy as np
from jaxtyping import UInt32

# NOTE: these all need to match!

AllTokenizersHashBitLength = 32
"bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`"

AllTokenizersHashDtype = np.uint32
"numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`"

AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"]
"jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`"
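
# illustrative sketch, not part of the original module: the "these all need
# to match" invariant above can be made machine-checked by tying the dtype's
# width to the declared bit length:
assert np.dtype(AllTokenizersHashDtype).itemsize * 8 == AllTokenizersHashBitLength, (
	"hash dtype width must match AllTokenizersHashBitLength"
)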

def _hash_tokenizer_name(s: str) -> int:
	# take a 64-bit SHAKE-256 digest of the name, then fold the high and low
	# 32-bit halves together to match `AllTokenizersHashBitLength`
	h64: int = int.from_bytes(
		hashlib.shake_256(s.encode("utf-8")).digest(8),  # 8 bytes = 64 bits
		byteorder="big",
	)
	return (h64 >> 32) ^ (h64 & 0xFFFFFFFF)
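
# illustrative sketch, not part of the original module: the fold above keeps
# every result in the 32-bit range, so it round-trips through
# `AllTokenizersHashDtype` losslessly. the tokenizer name used here is hypothetical.
def _example_hash_fits_dtype() -> None:
	h: int = _hash_tokenizer_name("MazeTokenizerModular-demo")  # hypothetical name
	assert 0 <= h < 2**AllTokenizersHashBitLength
	assert int(AllTokenizersHashDtype(h)) == h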

_ALL_TOKENIZER_HASHES: AllTokenizersHashesArray
"private array of all tokenizer hashes"
_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`"

def set_tokenizer_hashes_path(path: Path) -> None:
	"""set path to tokenizer hashes, and reload the hashes if needed

	the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
	which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"`,
	i.e. in this file's directory.

	However, this might not always work, so we provide a way to change this.
	"""
	global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES  # noqa: PLW0603

	path = Path(path)
	if path.is_dir():
		path = path / "MazeTokenizerModular_hashes.npz"

	if not path.is_file():
		err_msg: str = f"could not find maze tokenizer hashes file at: {path}"
		raise FileNotFoundError(err_msg)

	if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
		# reload the hashes if the path changed
		_TOKENIZER_HASHES_PATH = path
		_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
	else:
		# always set to the new path, even if it points at the same file
		_TOKENIZER_HASHES_PATH = path
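
# illustrative usage sketch, not part of the original module: redirecting the
# lookup to a hypothetical cache directory. passing a directory appends the
# default filename, and a missing file fails fast with FileNotFoundError.
def _example_redirect_hashes_path() -> None:
	cache_dir: Path = Path("~/.cache/maze_dataset").expanduser()  # hypothetical location
	set_tokenizer_hashes_path(cache_dir)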

def _load_tokenizer_hashes() -> AllTokenizersHashesArray:
	"""Loads the sorted array of `all_tokenizers.get_all_tokenizers()` hashes from disk."""
	global _TOKENIZER_HASHES_PATH  # noqa: PLW0602
	try:
		path: Path = _TOKENIZER_HASHES_PATH
		return np.load(path)["hashes"]
	except FileNotFoundError as e:
		err_msg: str = (
			"Tokenizers hashes cannot be loaded. To fix this, run"
			"\n`python -m maze_dataset.tokenization.save_hashes` which will save the hashes to"
			"\n`data/MazeTokenizerModular_hashes.npz`"
			"\nrelative to the current working directory -- this is where the code looks for them."
		)
		raise FileNotFoundError(err_msg) from e
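
# illustrative sketch, not part of the original module: `np.load(path)["hashes"]`
# above implies an `.npz` archive with a single "hashes" key. a compatible file
# could be written like this (the helper name is hypothetical):
def _example_write_hashes_file(path: Path, hashes: AllTokenizersHashesArray) -> None:
	# sorted, 32-bit, stored under the "hashes" key -- matching what the loader reads
	np.savez_compressed(path, hashes=np.sort(hashes.astype(AllTokenizersHashDtype)))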

def get_all_tokenizer_hashes() -> AllTokenizersHashesArray:
	"""returns all the tokenizer hashes in an `AllTokenizersHashesArray`, setting the global variable if needed"""
	# naughty use of globals
	global _ALL_TOKENIZER_HASHES  # noqa: PLW0603
	try:
		# NameError here means the annotated global was never assigned
		got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0
		if got_tokenizers:
			return _ALL_TOKENIZER_HASHES
		else:
			_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
	except NameError:
		_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()

	return _ALL_TOKENIZER_HASHES
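
# illustrative sketch, not part of the original module: how the pieces above
# combine into the legacy validity check. the hashes are stored sorted (see
# `_load_tokenizer_hashes`), so a binary search via `np.searchsorted` suffices;
# the actual library's lookup may differ. `name` would be a tokenizer's
# serialized name in practice.
def _example_is_known_tokenizer(name: str) -> bool:
	hashes: AllTokenizersHashesArray = get_all_tokenizer_hashes()
	h = AllTokenizersHashDtype(_hash_tokenizer_name(name))
	idx: int = int(np.searchsorted(hashes, h))
	return bool(idx < len(hashes) and hashes[idx] == h)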