docs for maze-dataset v1.3.2

maze_dataset.tokenization.modular.fst

to check whether a tokenizer is one of our "approved" ones, we store the names of all approved tokenizers in an fst set built with rust_fst

this module handles the creation of that fst file, which we ship with the package

this module relies on importing get_all_tokenizers and thus MazeTokenizerModular. as such, the code that loads the fst file to validate a tokenizer lives in the separate maze_dataset.tokenization.modular.fst_load module: that module must be importable from maze_dataset.tokenization.modular.maze_tokenizer_modular, and importing this module from there would create a circular import
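
for a user validating a tokenizer, only fst_load needs to be imported. a minimal sketch of that path, based on how the helpers are used in this module (the example name is hypothetical; real names come from MazeTokenizerModular.name):

# validate a tokenizer name against the shipped fst set,
# without importing this module (and thus without regenerating all tokenizers)
from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst

name: str = "example-tokenizer-name"  # hypothetical name, for illustration only
# raises if the name is not in the fst set shipped at MMT_FST_PATH
check_tokenizer_in_fst(name, do_except=True)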


  1"""to check if a tokenizer is one of our "approved" ones, we store this in a fst set using `rust_fst`
  2
  3this file handles the creation of this fst file, which we ship to the user
  4
  5this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`.
  6as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load`
  7module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and
  8we cannot circularly import
  9
 10"""
 11
 12import functools
 13import random
 14
 15import tqdm
 16from muutils.misc.numerical import shorten_numerical_to_str
 17from muutils.parallel import run_maybe_parallel
 18from muutils.spinner import NoOpContextManager, SpinnerContext
 19from rust_fst import Set as FstSet  # type: ignore[import-untyped]
 20
 21from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers
 22from maze_dataset.tokenization.modular.fst_load import (
 23	MMT_FST_PATH,
 24	check_tokenizer_in_fst,
 25	get_tokenizers_fst,
 26)
 27
 28
 29def _get_tokenizer_name(tokenizer) -> str:  # noqa: ANN001
 30	return tokenizer.name
 31
 32
 33def save_all_tokenizers_fst(
 34	verbose: bool = True, parallel: bool | int = False
 35) -> FstSet:
 36	"""get all the tokenizers, save an fst file at `MMT_FST_PATH` and return the set"""
 37	# TYPING: add a protocol or abc for both of these which is a context manager that takes the args we care about
 38	# probably do this in muutils
 39	sp: type[SpinnerContext | NoOpContextManager] = (
 40		SpinnerContext if verbose else NoOpContextManager
 41	)
 42
 43	with sp(message="getting all tokenizers"):
 44		all_tokenizers: list = get_all_tokenizers()
 45
 46	n_tokenizers: int = len(all_tokenizers)
 47
 48	all_tokenizers_names: list[str] = run_maybe_parallel(
 49		func=_get_tokenizer_name,
 50		iterable=all_tokenizers,
 51		parallel=parallel,
 52		pbar=tqdm.tqdm,
 53		pbar_kwargs=dict(
 54			total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose
 55		),
 56	)
 57
 58	assert n_tokenizers == len(all_tokenizers_names)
 59	print(
 60		f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names"
 61	)
 62
 63	with sp(message="sorting tokenizer names"):
 64		all_tokenizers_names_sorted: list[str] = sorted(all_tokenizers_names)
 65
 66	# construct an fst set and save it
 67	# we expect it to be 1.6kb or so
 68	with sp(message="constructing and saving tokenizers fst set"):
 69		tok_set: FstSet = FstSet.from_iter(
 70			all_tokenizers_names_sorted,
 71			path=MMT_FST_PATH.as_posix(),
 72		)
 73
 74	print(
 75		f"# tokenizers fst set saved to {MMT_FST_PATH}, size: {MMT_FST_PATH.stat().st_size} bytes"
 76	)
 77
 78	return tok_set
 79
 80
 81def check_tokenizers_fst(
 82	verbose: bool = True,
 83	parallel: bool | int = False,
 84	n_check: int | None = None,
 85) -> FstSet:
 86	"regen all tokenizers, check they are in the pre-existing fst set"
 87	sp: type[SpinnerContext | NoOpContextManager] = (
 88		SpinnerContext if verbose else NoOpContextManager
 89	)
 90
 91	with sp(message="getting all tokenizers from scratch"):
 92		all_tokenizers: list = get_all_tokenizers()
 93
 94	with sp(message="load the pre-existing tokenizers fst set"):
 95		get_tokenizers_fst()
 96
 97	n_tokenizers: int = len(all_tokenizers)
 98
 99	selected_tokenizers: list
100	if n_check is not None:
101		selected_tokenizers = random.sample(all_tokenizers, n_check)
102	else:
103		selected_tokenizers = all_tokenizers
104
105	tokenizers_names: list[str] = run_maybe_parallel(
106		func=_get_tokenizer_name,
107		iterable=selected_tokenizers,
108		parallel=parallel,
109		pbar=tqdm.tqdm,
110		pbar_kwargs=dict(
111			total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose
112		),
113	)
114
115	if n_check is None:
116		assert n_tokenizers == len(tokenizers_names)
117		print(
118			f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names"
119		)
120	else:
121		assert n_check == len(tokenizers_names)
122		print(
123			f"# selected {n_check} tokenizers to check out of {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) total"
124		)
125
126	check_tokenizer_in_fst__do_except = functools.partial(
127		check_tokenizer_in_fst, do_except=True
128	)
129
130	run_maybe_parallel(
131		func=check_tokenizer_in_fst__do_except,
132		iterable=tokenizers_names,
133		parallel=parallel,
134		pbar=tqdm.tqdm,
135		pbar_kwargs=dict(
136			total=len(selected_tokenizers),
137			desc="checking tokenizers in fst",
138			disable=not verbose,
139		),
140	)
141
142	if n_check is None:
143		print("# all tokenizers are in the pre-existing fst set!")
144	else:
145		print(f"# all {n_check} selected tokenizers are in the pre-existing fst set!")
146
147
148if __name__ == "__main__":
149	import argparse
150
151	arg_parser: argparse.ArgumentParser = argparse.ArgumentParser(
152		description="save the tokenizers fst set"
153	)
154	arg_parser.add_argument(
155		"-c",
156		"--check",
157		action="store_true",
158		help="check that all tokenizers are in the pre-existing fst set",
159	)
160	arg_parser.add_argument(
161		"-q",
162		"--quiet",
163		action="store_true",
164		help="don't show spinners and progress bars",
165	)
166	arg_parser.add_argument(
167		"-p",
168		"--parallel",
169		action="store",
170		nargs="?",
171		type=int,
172		const=True,
173		default=False,
174		help="Control parallelization. will run in serial if nothing specified, use all cpus if flag passed without args, or number of cpus if int passed.",
175	)
176	arg_parser.add_argument(
177		"-n",
178		"--n-check",
179		action="store",
180		default=None,
181		help="if passed, check n random tokenizers. pass an int to check that many. pass 'none' or a -1 to check all",
182	)
183	args: argparse.Namespace = arg_parser.parse_args()
184
185	n_check: int | None = (
186		int(args.n_check)
187		if (args.n_check is not None and args.n_check.lower() != "none")
188		else None
189	)
190	if n_check is not None and n_check < 0:
191		n_check = None
192
193	if args.check:
194		check_tokenizers_fst(
195			verbose=not args.quiet,
196			parallel=args.parallel,
197			n_check=n_check,
198		)
199	else:
200		save_all_tokenizers_fst(verbose=not args.quiet, parallel=args.parallel)
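
a note on the rust_fst dependency: fst sets must be built from keys in sorted order, which is presumably why the names are sorted before FstSet.from_iter is called above. a minimal standalone sketch, assuming rust_fst's Set supports membership tests via the in operator:

# build a tiny fst set the same way save_all_tokenizers_fst does
from rust_fst import Set as FstSet

demo: FstSet = FstSet.from_iter(
	sorted(["beta", "alpha", "gamma"]),  # keys must be lexicographically sorted
	path="/tmp/demo.fst",  # hypothetical path, for illustration only
)
print("alpha" in demo)  # assumed membership-test support -> True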

def save_all_tokenizers_fst(verbose: bool = True, parallel: bool | int = False) -> rust_fst.set.Set:

get all the tokenizers, save an fst file at MMT_FST_PATH and return the set
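
a usage sketch; this is a maintenance entry point (it regenerates every tokenizer), and running the module without flags, e.g. python -m maze_dataset.tokenization.modular.fst, does the same thing via the __main__ block:

from maze_dataset.tokenization.modular.fst import save_all_tokenizers_fst

# regenerate all tokenizers and write their sorted names to MMT_FST_PATH;
# parallel=4 is forwarded to run_maybe_parallel as a worker count
tok_set = save_all_tokenizers_fst(verbose=True, parallel=4)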

def check_tokenizers_fst(verbose: bool = True, parallel: bool | int = False, n_check: int | None = None) -> None:

regenerate all tokenizers and check that they are in the pre-existing fst set
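
a usage sketch; the CLI equivalent via the __main__ block would be python -m maze_dataset.tokenization.modular.fst --check --parallel -n 100:

from maze_dataset.tokenization.modular.fst import check_tokenizers_fst

# regenerate all tokenizers, then verify a random sample of 100 of them
# against the fst set already saved at MMT_FST_PATH; any miss raises,
# since the check uses check_tokenizer_in_fst with do_except=True
check_tokenizers_fst(verbose=True, parallel=True, n_check=100)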