# maze_dataset/tokenization/modular/fst.py

1"""to check if a tokenizer is one of our "approved" ones, we store this in a fst set using `rust_fst` 

2 

3this file handles the creation of this fst file, which we ship to the user 

4 

5this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`. 

6as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load` 

7module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and 

8we cannot circularly import 

9 

10""" 

import functools
import random

import tqdm
from muutils.misc.numerical import shorten_numerical_to_str
from muutils.parallel import run_maybe_parallel
from muutils.spinner import NoOpContextManager, SpinnerContext
from rust_fst import Set as FstSet  # type: ignore[import-untyped]

from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers
from maze_dataset.tokenization.modular.fst_load import (
    MMT_FST_PATH,
    check_tokenizer_in_fst,
    get_tokenizers_fst,
)


def _get_tokenizer_name(tokenizer) -> str:  # noqa: ANN001
    # defined at module level (not as a lambda) so it can be pickled by parallel worker processes
    return tokenizer.name


def save_all_tokenizers_fst(
    verbose: bool = True,
    parallel: bool | int = False,
) -> FstSet:
    """get all the tokenizers, save an fst file at `MMT_FST_PATH`, and return the set"""
    # TYPING: add a protocol or ABC for both of these -- a context manager that takes the args we care about
    # (probably do this in muutils)
    sp: type[SpinnerContext | NoOpContextManager] = (
        SpinnerContext if verbose else NoOpContextManager
    )

    with sp(message="getting all tokenizers"):
        all_tokenizers: list = get_all_tokenizers()

    n_tokenizers: int = len(all_tokenizers)

    all_tokenizers_names: list[str] = run_maybe_parallel(
        func=_get_tokenizer_name,
        iterable=all_tokenizers,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose
        ),
    )

    assert n_tokenizers == len(all_tokenizers_names)
    print(
        f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizer names"
    )

    with sp(message="sorting tokenizer names"):
        all_tokenizers_names_sorted: list[str] = sorted(all_tokenizers_names)

    # construct an fst set and save it to disk; we expect it to be ~1.6 kB
    with sp(message="constructing and saving tokenizers fst set"):
        tok_set: FstSet = FstSet.from_iter(
            all_tokenizers_names_sorted,
            path=MMT_FST_PATH.as_posix(),
        )

    print(
        f"# tokenizers fst set saved to {MMT_FST_PATH}, size: {MMT_FST_PATH.stat().st_size} bytes"
    )

    return tok_set
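
# Usage sketch (hedged: membership via `in` is assumed to be supported by
# `rust_fst.Set`, and `some_name` is just whatever the first generated
# tokenizer happens to be called):
#
#     tok_set = save_all_tokenizers_fst()
#     some_name = get_all_tokenizers()[0].name
#     assert some_name in tok_set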


def check_tokenizers_fst(
    verbose: bool = True,
    parallel: bool | int = False,
    n_check: int | None = None,
) -> None:
    "regenerate all tokenizers and check that they are in the pre-existing fst set"
    sp: type[SpinnerContext | NoOpContextManager] = (
        SpinnerContext if verbose else NoOpContextManager
    )

    with sp(message="getting all tokenizers from scratch"):
        all_tokenizers: list = get_all_tokenizers()

    with sp(message="load the pre-existing tokenizers fst set"):
        get_tokenizers_fst()

    n_tokenizers: int = len(all_tokenizers)

    # check either a random sample of `n_check` tokenizers, or all of them
    selected_tokenizers: list
    if n_check is not None:
        selected_tokenizers = random.sample(all_tokenizers, n_check)
    else:
        selected_tokenizers = all_tokenizers

    tokenizers_names: list[str] = run_maybe_parallel(
        func=_get_tokenizer_name,
        iterable=selected_tokenizers,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            total=len(selected_tokenizers),
            desc="get name of each tokenizer",
            disable=not verbose,
        ),
    )

    if n_check is None:
        assert n_tokenizers == len(tokenizers_names)
        print(
            f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizer names"
        )
    else:
        assert n_check == len(tokenizers_names)
        print(
            f"# selected {n_check} tokenizers to check out of {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) total"
        )

    check_tokenizer_in_fst__do_except = functools.partial(
        check_tokenizer_in_fst, do_except=True
    )

    run_maybe_parallel(
        func=check_tokenizer_in_fst__do_except,
        iterable=tokenizers_names,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            total=len(selected_tokenizers),
            desc="checking tokenizers in fst",
            disable=not verbose,
        ),
    )

    if n_check is None:
        print("# all tokenizers are in the pre-existing fst set!")
    else:
        print(f"# all {n_check} selected tokenizers are in the pre-existing fst set!")
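
# Usage sketch (parameters grounded in the signature above; the values are illustrative):
#
#     # spot-check 1000 randomly sampled tokenizers using 4 worker processes
#     check_tokenizers_fst(parallel=4, n_check=1000)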


if __name__ == "__main__":
    import argparse

    arg_parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="save the tokenizers fst set"
    )
    arg_parser.add_argument(
        "-c",
        "--check",
        action="store_true",
        help="check that all tokenizers are in the pre-existing fst set",
    )
    arg_parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="don't show spinners and progress bars",
    )
    arg_parser.add_argument(
        "-p",
        "--parallel",
        action="store",
        nargs="?",
        type=int,
        const=True,
        default=False,
        help="control parallelization: runs in serial if the flag is absent, uses all CPUs if the flag is passed without a value, or that many CPUs if an int is passed",
    )
    arg_parser.add_argument(
        "-n",
        "--n-check",
        action="store",
        default=None,
        help="if passed, check n random tokenizers: pass an int to check that many, or 'none' or -1 to check all",
    )
    args: argparse.Namespace = arg_parser.parse_args()

    # parse `--n-check`: absent, 'none', or a negative value all mean "check everything"
    n_check: int | None = (
        int(args.n_check)
        if (args.n_check is not None and args.n_check.lower() != "none")
        else None
    )
    if n_check is not None and n_check < 0:
        n_check = None

    if args.check:
        check_tokenizers_fst(
            verbose=not args.quiet,
            parallel=args.parallel,
            n_check=n_check,
        )
    else:
        save_all_tokenizers_fst(verbose=not args.quiet, parallel=args.parallel)
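
# Example invocations (hedged: the module path is assumed from the package layout):
#
#     # regenerate all tokenizers and save the fst set to MMT_FST_PATH
#     python -m maze_dataset.tokenization.modular.fst
#
#     # spot-check 1000 random tokenizers against the shipped set, in parallel
#     python -m maze_dataset.tokenization.modular.fst --check --parallel -n 1000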