Coverage for maze_dataset/tokenization/modular/save_hashes.py: 0%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"""generate and save the hashes of all supported tokenizers 

2 

3> [!CAUTION] 

4> using hashes to validate a `MazeTokenizerModular` is deprecated in favor of using fst 

5 

6calls `maze_dataset.tokenization.all_tokenizers.save_hashes()` 

7 

8Usage: 

9 

10To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`): 

11```bash 

12python -m maze_dataset.tokenization.modular.save_hashes 

13``` 

14 

15to save to a custom location: 

16```bash 

17python -m maze_dataset.tokenization.modular.save_hashes /path/to/save/to.npy 

18``` 

19 

20to check hashes shipped with the package: 

21```bash 

22python -m maze_dataset.tokenization.modular.save_hashes --check 

23``` 

24 

25""" 

26 

27from pathlib import Path 

28 

29import numpy as np 

30from muutils.spinner import SpinnerContext 

31 

32from maze_dataset.tokenization.modular import all_tokenizers 

33from maze_dataset.tokenization.modular.hashing import ( 

34 _load_tokenizer_hashes, 

35 get_all_tokenizer_hashes, 

36) 

37 

if __name__ == "__main__":
    # parse args
    # ==================================================
    import argparse

    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="generate and save (or download) the hashes of all supported tokenizers",
    )

    # positional path is optional -- `all_tokenizers.save_hashes` picks its
    # own default location when `path` is None
    parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
    parser.add_argument(
        "--quiet",
        "-q",
        action="store_true",
        help="disable progress bar and spinner",
    )
    parser.add_argument(
        "--parallelize",
        "-p",
        action="store_true",
        help="parallelize the computation",
    )
    parser.add_argument(
        "--check",
        "-c",
        action="store_true",
        help="save to temp location, then compare to existing",
    )
    parser.add_argument(
        "--download",
        "-d",
        action="store_true",
        help=f"download the hashes from github: {all_tokenizers.DOWNLOAD_URL}",
    )

    args: argparse.Namespace = parser.parse_args()

    # `--download` is advertised in the help text but nothing in this module
    # ever reads `args.download` -- previously the flag was silently ignored
    # and the hashes were recomputed instead. fail loudly until a download
    # implementation exists.
    if args.download:
        raise NotImplementedError(
            f"--download is not implemented; fetch the hashes manually from {all_tokenizers.DOWNLOAD_URL}",
        )

    if not args.check:
        # write new hashes
        # ==================================================
        all_tokenizers.save_hashes(
            path=args.path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

    else:
        # check hashes only
        # ==================================================

        # set up path -- `--check` always regenerates into a fixed temp
        # location, so a user-supplied path would be ignored; reject it
        if args.path is not None:
            raise ValueError("cannot use --check with a custom path")
        temp_path: Path = Path("tests/_temp/tok_hashes.npz")
        temp_path.parent.mkdir(parents=True, exist_ok=True)

        # generate and save to temp location
        returned_hashes: np.ndarray = all_tokenizers.save_hashes(
            path=temp_path,
            verbose=not args.quiet,
            parallelize=args.parallelize,
        )

        # load saved hashes: the freshly written temp file, the hashes
        # shipped inside the package, and the wrapped accessor
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.5,
            message="loading saved hashes...",
        ):
            read_hashes: np.ndarray = np.load(temp_path)["hashes"]
            read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
            read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()

        # compare -- all four arrays must be identical
        with SpinnerContext(
            spinner_chars="square_dot",
            update_interval=0.01,
            message="checking hashes: ",
            format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ",
            format_string_when_updated=True,
        ) as sp:
            sp.update_value("returned vs read")
            assert np.array_equal(returned_hashes, read_hashes)
            sp.update_value("returned vs _load_tokenizer_hashes")
            assert np.array_equal(returned_hashes, read_hashes_pkg)
            # message previously said "returned vs ..." but the comparison is
            # between the re-read temp file and the wrapped accessor
            sp.update_value("read vs get_all_tokenizer_hashes()")
            assert np.array_equal(read_hashes, read_hashes_wrapped)