maze_dataset.tokenization.modular.fst
To check whether a tokenizer is one of our "approved" ones, we store the names of all approved tokenizers in an FST set built with `rust_fst`. This module handles the creation of that fst file, which we ship to the user.

This module relies on importing `get_all_tokenizers`, and thus `MazeTokenizerModular`. As such, loading the fst file to validate a tokenizer lives in the separate `maze_dataset.tokenization.modular.fst_load` module, since that module must be importable from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and we cannot have a circular import.
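For orientation, a minimal sketch of how the two halves fit together: this module writes the fst file, and `maze_dataset.tokenization.modular.fst_load` reads it back to validate a tokenizer name. The tokenizer name below is hypothetical; `save_all_tokenizers_fst`, `check_tokenizer_in_fst`, and `do_except` are the real names used in the source below.

    # build the fst once (done by the maintainers, shipped to the user),
    # then validate tokenizer names against it
    from maze_dataset.tokenization.modular.fst import save_all_tokenizers_fst
    from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst

    save_all_tokenizers_fst(verbose=True)  # writes the set to MMT_FST_PATH

    # do_except=True raises if the name is not in the shipped set
    check_tokenizer_in_fst("SomeTokenizerName", do_except=True)  # hypothetical name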
1"""to check if a tokenizer is one of our "approved" ones, we store this in a fst set using `rust_fst` 2 3this file handles the creation of this fst file, which we ship to the user 4 5this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`. 6as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load` 7module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and 8we cannot circularly import 9 10""" 11 12import functools 13import random 14 15import tqdm 16from muutils.misc.numerical import shorten_numerical_to_str 17from muutils.parallel import run_maybe_parallel 18from muutils.spinner import NoOpContextManager, SpinnerContext 19from rust_fst import Set as FstSet # type: ignore[import-untyped] 20 21from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers 22from maze_dataset.tokenization.modular.fst_load import ( 23 MMT_FST_PATH, 24 check_tokenizer_in_fst, 25 get_tokenizers_fst, 26) 27 28 29def _get_tokenizer_name(tokenizer) -> str: # noqa: ANN001 30 return tokenizer.name 31 32 33def save_all_tokenizers_fst( 34 verbose: bool = True, parallel: bool | int = False 35) -> FstSet: 36 """get all the tokenizers, save an fst file at `MMT_FST_PATH` and return the set""" 37 # TYPING: add a protocol or abc for both of these which is a context manager that takes the args we care about 38 # probably do this in muutils 39 sp: type[SpinnerContext | NoOpContextManager] = ( 40 SpinnerContext if verbose else NoOpContextManager 41 ) 42 43 with sp(message="getting all tokenizers"): 44 all_tokenizers: list = get_all_tokenizers() 45 46 n_tokenizers: int = len(all_tokenizers) 47 48 all_tokenizers_names: list[str] = run_maybe_parallel( 49 func=_get_tokenizer_name, 50 iterable=all_tokenizers, 51 parallel=parallel, 52 pbar=tqdm.tqdm, 53 pbar_kwargs=dict( 54 total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose 55 ), 56 ) 57 58 assert n_tokenizers == len(all_tokenizers_names) 59 print( 60 f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names" 61 ) 62 63 with sp(message="sorting tokenizer names"): 64 all_tokenizers_names_sorted: list[str] = sorted(all_tokenizers_names) 65 66 # construct an fst set and save it 67 # we expect it to be 1.6kb or so 68 with sp(message="constructing and saving tokenizers fst set"): 69 tok_set: FstSet = FstSet.from_iter( 70 all_tokenizers_names_sorted, 71 path=MMT_FST_PATH.as_posix(), 72 ) 73 74 print( 75 f"# tokenizers fst set saved to {MMT_FST_PATH}, size: {MMT_FST_PATH.stat().st_size} bytes" 76 ) 77 78 return tok_set 79 80 81def check_tokenizers_fst( 82 verbose: bool = True, 83 parallel: bool | int = False, 84 n_check: int | None = None, 85) -> FstSet: 86 "regen all tokenizers, check they are in the pre-existing fst set" 87 sp: type[SpinnerContext | NoOpContextManager] = ( 88 SpinnerContext if verbose else NoOpContextManager 89 ) 90 91 with sp(message="getting all tokenizers from scratch"): 92 all_tokenizers: list = get_all_tokenizers() 93 94 with sp(message="load the pre-existing tokenizers fst set"): 95 get_tokenizers_fst() 96 97 n_tokenizers: int = len(all_tokenizers) 98 99 selected_tokenizers: list 100 if n_check is not None: 101 selected_tokenizers = random.sample(all_tokenizers, n_check) 102 else: 103 selected_tokenizers = all_tokenizers 104 105 tokenizers_names: list[str] = run_maybe_parallel( 106 func=_get_tokenizer_name, 107 iterable=selected_tokenizers, 108 
parallel=parallel, 109 pbar=tqdm.tqdm, 110 pbar_kwargs=dict( 111 total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose 112 ), 113 ) 114 115 if n_check is None: 116 assert n_tokenizers == len(tokenizers_names) 117 print( 118 f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names" 119 ) 120 else: 121 assert n_check == len(tokenizers_names) 122 print( 123 f"# selected {n_check} tokenizers to check out of {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) total" 124 ) 125 126 check_tokenizer_in_fst__do_except = functools.partial( 127 check_tokenizer_in_fst, do_except=True 128 ) 129 130 run_maybe_parallel( 131 func=check_tokenizer_in_fst__do_except, 132 iterable=tokenizers_names, 133 parallel=parallel, 134 pbar=tqdm.tqdm, 135 pbar_kwargs=dict( 136 total=len(selected_tokenizers), 137 desc="checking tokenizers in fst", 138 disable=not verbose, 139 ), 140 ) 141 142 if n_check is None: 143 print("# all tokenizers are in the pre-existing fst set!") 144 else: 145 print(f"# all {n_check} selected tokenizers are in the pre-existing fst set!") 146 147 148if __name__ == "__main__": 149 import argparse 150 151 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser( 152 description="save the tokenizers fst set" 153 ) 154 arg_parser.add_argument( 155 "-c", 156 "--check", 157 action="store_true", 158 help="check that all tokenizers are in the pre-existing fst set", 159 ) 160 arg_parser.add_argument( 161 "-q", 162 "--quiet", 163 action="store_true", 164 help="don't show spinners and progress bars", 165 ) 166 arg_parser.add_argument( 167 "-p", 168 "--parallel", 169 action="store", 170 nargs="?", 171 type=int, 172 const=True, 173 default=False, 174 help="Control parallelization. will run in serial if nothing specified, use all cpus if flag passed without args, or number of cpus if int passed.", 175 ) 176 arg_parser.add_argument( 177 "-n", 178 "--n-check", 179 action="store", 180 default=None, 181 help="if passed, check n random tokenizers. pass an int to check that many. pass 'none' or a -1 to check all", 182 ) 183 args: argparse.Namespace = arg_parser.parse_args() 184 185 n_check: int | None = ( 186 int(args.n_check) 187 if (args.n_check is not None and args.n_check.lower() != "none") 188 else None 189 ) 190 if n_check is not None and n_check < 0: 191 n_check = None 192 193 if args.check: 194 check_tokenizers_fst( 195 verbose=not args.quiet, 196 parallel=args.parallel, 197 n_check=n_check, 198 ) 199 else: 200 save_all_tokenizers_fst(verbose=not args.quiet, parallel=args.parallel)
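The `__main__` block makes this runnable as a script, e.g. `python -m maze_dataset.tokenization.modular.fst` to (re)build the fst, or with `--check -p -n 1000` to spot-check a random sample against the shipped set. Separately, here is a minimal sketch of the `rust_fst` primitives the module leans on; the keys and path below are made up, and it assumes `rust_fst.Set` supports `in` membership tests. FST construction needs keys in sorted order, which is why the module sorts the names before calling `from_iter`.

    # illustrative rust_fst usage; keys and path are made up
    from rust_fst import Set as FstSet

    names = sorted(["tok-a", "tok-b", "tok-c"])  # fst construction needs sorted keys
    s = FstSet.from_iter(names, path="/tmp/example.fst")  # builds and writes the set

    print("tok-b" in s)  # True: membership test against the on-disk set
    print("tok-z" in s)  # False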
def save_all_tokenizers_fst(verbose: bool = True, parallel: bool | int = False) -> rust_fst.set.Set:
get all the tokenizers, save an fst file at `MMT_FST_PATH`, and return the set
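A usage sketch (the `parallel=4` value is arbitrary; per the CLI help, `False` means serial, `True` means all CPUs, and an int means that many processes):

    from maze_dataset.tokenization.modular.fst import save_all_tokenizers_fst
    from maze_dataset.tokenization.modular.fst_load import MMT_FST_PATH

    tok_set = save_all_tokenizers_fst(verbose=True, parallel=4)  # arbitrary worker count
    print("saved to", MMT_FST_PATH, "-", MMT_FST_PATH.stat().st_size, "bytes")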
def check_tokenizers_fst(verbose: bool = True, parallel: bool | int = False, n_check: int | None = None) -> None:
regenerate all tokenizers and check that they are present in the pre-existing fst set; if `n_check` is given, only that many randomly sampled tokenizers are checked
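A usage sketch mirroring the `--check` CLI path. Note that `get_all_tokenizers()` still runs in full; `n_check` only limits how many of the regenerated tokenizers have their names extracted and checked against the fst (`n_check=1000` is an arbitrary illustrative value):

    from maze_dataset.tokenization.modular.fst import check_tokenizers_fst

    # spot-check 1000 randomly sampled tokenizers against the shipped fst;
    # raises (via check_tokenizer_in_fst(..., do_except=True)) if any name is missing
    check_tokenizers_fst(verbose=True, parallel=True, n_check=1000)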