Coverage for maze_dataset/tokenization/modular/fst.py: 0% (60 statements)
1"""to check if a tokenizer is one of our "approved" ones, we store this in a fst set using `rust_fst`
3this file handles the creation of this fst file, which we ship to the user
5this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`.
6as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load`
7module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and
8we cannot circularly import
10"""

import functools
import random

import tqdm
from muutils.misc.numerical import shorten_numerical_to_str
from muutils.parallel import run_maybe_parallel
from muutils.spinner import NoOpContextManager, SpinnerContext
from rust_fst import Set as FstSet  # type: ignore[import-untyped]

from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers
from maze_dataset.tokenization.modular.fst_load import (
    MMT_FST_PATH,
    check_tokenizer_in_fst,
    get_tokenizers_fst,
)


def _get_tokenizer_name(tokenizer) -> str:  # noqa: ANN001
    return tokenizer.name


def save_all_tokenizers_fst(
    verbose: bool = True, parallel: bool | int = False
) -> FstSet:
    """get all the tokenizers, save an fst file at `MMT_FST_PATH`, and return the set"""
    # TYPING: add a protocol or abc for both of these which is a context manager that takes the args we care about
    # probably do this in muutils
    sp: type[SpinnerContext | NoOpContextManager] = (
        SpinnerContext if verbose else NoOpContextManager
    )

    with sp(message="getting all tokenizers"):
        all_tokenizers: list = get_all_tokenizers()

    n_tokenizers: int = len(all_tokenizers)

    all_tokenizers_names: list[str] = run_maybe_parallel(
        func=_get_tokenizer_name,
        iterable=all_tokenizers,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose
        ),
    )

    assert n_tokenizers == len(all_tokenizers_names)
    print(
        f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizer names"
    )

    # fst sets must be built from lexicographically sorted keys, so sort first
    with sp(message="sorting tokenizer names"):
        all_tokenizers_names_sorted: list[str] = sorted(all_tokenizers_names)

    # construct an fst set and save it
    # we expect the file to be around 1.6kb
    with sp(message="constructing and saving tokenizers fst set"):
        tok_set: FstSet = FstSet.from_iter(
            all_tokenizers_names_sorted,
            path=MMT_FST_PATH.as_posix(),
        )

    print(
        f"# tokenizers fst set saved to {MMT_FST_PATH}, size: {MMT_FST_PATH.stat().st_size} bytes"
    )

    return tok_set
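
# illustrative sketch: once saved, the fst file can be reloaded and queried
# directly with `rust_fst` (assuming its `Set` constructor accepts a path and
# supports membership tests; the shipped loading code lives in `fst_load`, and
# the tokenizer name below is hypothetical):
#
#     loaded: FstSet = FstSet(MMT_FST_PATH.as_posix())
#     "MazeTokenizerModular-..." in loaded  # membership test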


def check_tokenizers_fst(
    verbose: bool = True,
    parallel: bool | int = False,
    n_check: int | None = None,
) -> None:
    "regenerate all tokenizers and check that they are in the pre-existing fst set"
    sp: type[SpinnerContext | NoOpContextManager] = (
        SpinnerContext if verbose else NoOpContextManager
    )

    with sp(message="getting all tokenizers from scratch"):
        all_tokenizers: list = get_all_tokenizers()

    with sp(message="load the pre-existing tokenizers fst set"):
        # result discarded; this just verifies that the fst file loads
        get_tokenizers_fst()

    n_tokenizers: int = len(all_tokenizers)

    selected_tokenizers: list
    if n_check is not None:
        selected_tokenizers = random.sample(all_tokenizers, n_check)
    else:
        selected_tokenizers = all_tokenizers

    tokenizers_names: list[str] = run_maybe_parallel(
        func=_get_tokenizer_name,
        iterable=selected_tokenizers,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            # total must match the (possibly sampled) iterable, not all tokenizers
            total=len(selected_tokenizers),
            desc="get name of each tokenizer",
            disable=not verbose,
        ),
    )

    if n_check is None:
        assert n_tokenizers == len(tokenizers_names)
        print(
            f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizer names"
        )
    else:
        assert n_check == len(tokenizers_names)
        print(
            f"# selected {n_check} tokenizers to check out of {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) total"
        )

    check_tokenizer_in_fst__do_except = functools.partial(
        check_tokenizer_in_fst, do_except=True
    )

    run_maybe_parallel(
        func=check_tokenizer_in_fst__do_except,
        iterable=tokenizers_names,
        parallel=parallel,
        pbar=tqdm.tqdm,
        pbar_kwargs=dict(
            total=len(selected_tokenizers),
            desc="checking tokenizers in fst",
            disable=not verbose,
        ),
    )

    if n_check is None:
        print("# all tokenizers are in the pre-existing fst set!")
    else:
        print(f"# all {n_check} selected tokenizers are in the pre-existing fst set!")


if __name__ == "__main__":
    import argparse

    arg_parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="save the tokenizers fst set"
    )
    arg_parser.add_argument(
        "-c",
        "--check",
        action="store_true",
        help="check that all tokenizers are in the pre-existing fst set",
    )
    arg_parser.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="don't show spinners and progress bars",
    )
    arg_parser.add_argument(
        "-p",
        "--parallel",
        action="store",
        nargs="?",
        type=int,
        const=True,
        default=False,
        help="control parallelization: run in serial if not passed, use all cpus if passed without a value, or use the given number of cpus if an int is passed",
    )
    arg_parser.add_argument(
        "-n",
        "--n-check",
        action="store",
        default=None,
        help="if passed, check n random tokenizers: pass an int to check that many, or pass 'none' or a negative int to check all",
    )
    args: argparse.Namespace = arg_parser.parse_args()

    n_check: int | None = (
        int(args.n_check)
        if (args.n_check is not None and args.n_check.lower() != "none")
        else None
    )
    if n_check is not None and n_check < 0:
        n_check = None

    if args.check:
        check_tokenizers_fst(
            verbose=not args.quiet,
            parallel=args.parallel,
            n_check=n_check,
        )
    else:
        save_all_tokenizers_fst(verbose=not args.quiet, parallel=args.parallel)
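
# example invocations (illustrative; this assumes the module is importable as
# `maze_dataset.tokenization.modular.fst`, it can also be run by file path):
#
#     python -m maze_dataset.tokenization.modular.fst                 # regenerate and save the fst set
#     python -m maze_dataset.tokenization.modular.fst --check -p      # check every tokenizer, using all cpus
#     python -m maze_dataset.tokenization.modular.fst --check -n 100  # spot-check 100 random tokenizers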