maze_dataset.tokenization.modular.all_tokenizers
Contains `get_all_tokenizers()` and supporting limited-use functions.

get_all_tokenizers()

Returns a comprehensive collection of all valid `MazeTokenizerModular` objects. This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects. Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work. This collection is kept in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`.
Use Cases

In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors.

- Unit testing:
  - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()`
- Large-scale tokenizer research:
  - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection
  - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element` (see the sketch after this list)

For other uses, it's likely that the computational expense can be avoided by using:

- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks
- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects
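As an example of the filtering workflow, the sketch below narrows the full collection to tokenizers containing a particular coordinate tokenizer. It is illustrative only: it assumes `MazeTokenizerModular.has_element` accepts a `_TokenizerElement` instance such as `CoordTokenizers.CTT()`; check the library documentation for the exact argument conventions.

```python
# Illustrative sketch: filter the full collection down to tokenizers that use
# the CTT coordinate tokenizer. Assumes `has_element` accepts a
# `_TokenizerElement` instance; not the library's canonical recipe.
from maze_dataset.tokenization import CoordTokenizers, MazeTokenizerModular
from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers

ctt_tokenizers: list[MazeTokenizerModular] = [
	tokenizer
	for tokenizer in get_all_tokenizers()  # expensive: computes the full collection
	if tokenizer.has_element(CoordTokenizers.CTT())
]
```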
EVERY_TEST_TOKENIZERS
A collection of the tokenizers which should always be included in unit tests when test fuzzing is used. This collection should be expanded as specific tokenizers become canonical or popular.
1"""Contains `get_all_tokenizers()` and supporting limited-use functions. 2 3# `get_all_tokenizers()` 4returns a comprehensive collection of all valid `MazeTokenizerModular` objects. 5This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects. 6Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work. 7This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`. 8 9## Use Cases 10In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors. 11- Unit testing: 12 - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()` 13- Large-scale tokenizer research: 14 - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection 15 - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element` 16For other uses, it's likely that the computational expense can be avoided by using 17- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks 18- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects 19 20# `EVERY_TEST_TOKENIZERS` 21A collection of the tokenizers which should always be included in unit tests when test fuzzing is used. 22This collection should be expanded as specific tokenizers become canonical or popular. 23""" 24 25import functools 26import multiprocessing 27import random 28from functools import cache 29from pathlib import Path 30from typing import Callable 31 32import frozendict 33import numpy as np 34from muutils.spinner import NoOpContextManager, SpinnerContext 35from tqdm import tqdm 36 37from maze_dataset.tokenization import ( 38 CoordTokenizers, 39 MazeTokenizerModular, 40 PromptSequencers, 41 StepTokenizers, 42 _TokenizerElement, 43) 44from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances 45from maze_dataset.tokenization.modular.hashing import ( 46 AllTokenizersHashBitLength, 47 AllTokenizersHashDtype, 48 AllTokenizersHashesArray, 49) 50 51# Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular` 52# TYPING: error: Type variable "maze_dataset.utils.FiniteValued" is unbound [valid-type] 53# note: (Hint: Use "Generic[FiniteValued]" or "Protocol[FiniteValued]" base class to bind "FiniteValued" inside a class) 54# note: (Hint: Use "FiniteValued" in function signature to bind "FiniteValued" inside a function) 55MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[ 56 type[FiniteValued], 57 Callable[[FiniteValued], bool], 58] = frozendict.frozendict( 59 { 60 # TYPING: Item "bool" of the upper bound "bool | IsDataclass | Enum" of type variable "FiniteValued" has no attribute "is_valid" [union-attr] 61 _TokenizerElement: lambda x: x.is_valid(), 62 # Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid` 63 # MazeTokenizerModular: lambda x: x.is_valid(), 64 # TYPING: error: No overload variant of "set" matches argument type "FiniteValued" [call-overload] 65 # note: Possible overload variants: 66 # note: def [_T] set(self) -> set[_T] 67 # note: def [_T] set(self, Iterable[_T], /) -> set[_T] 68 # TYPING: error: Argument 
1 to "len" has incompatible type "FiniteValued"; expected "Sized" [arg-type] 69 StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x) 70 and x != (StepTokenizers.Distance(),), 71 }, 72) 73 74DOWNLOAD_URL: str = "https://raw.githubusercontent.com/understanding-search/maze-dataset/main/maze_dataset/tokenization/MazeTokenizerModular_hashes.npz" 75 76 77@cache 78def get_all_tokenizers() -> list[MazeTokenizerModular]: 79 """Computes a complete list of all valid tokenizers. 80 81 Warning: This is an expensive function. 82 """ 83 return list( 84 all_instances( 85 MazeTokenizerModular, 86 validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, 87 ), 88 ) 89 90 91@cache 92def get_all_tokenizers_names() -> list[str]: 93 """computes the sorted list of names of all tokenizers""" 94 return sorted([tokenizer.name for tokenizer in get_all_tokenizers()]) 95 96 97EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [ 98 MazeTokenizerModular(), 99 MazeTokenizerModular( 100 prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()), 101 ), 102 # TODO: add more here as specific tokenizers become canonical and frequently used 103] 104 105 106@cache 107def all_tokenizers_set() -> set[MazeTokenizerModular]: 108 """Casts `get_all_tokenizers()` to a set.""" 109 return set(get_all_tokenizers()) 110 111 112@cache 113def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]: 114 """Returns""" 115 return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS)) 116 117 118def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]: 119 """Samples `n` tokenizers from `get_all_tokenizers()`.""" 120 return random.sample(get_all_tokenizers(), n) 121 122 123def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]: 124 """Returns a sample of size `n` of unique elements from `get_all_tokenizers()`, 125 126 always including every element in `EVERY_TEST_TOKENIZERS`. 127 """ 128 if n is None: 129 return get_all_tokenizers() 130 131 if n < len(EVERY_TEST_TOKENIZERS): 132 err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`." 
133 raise ValueError( 134 err_msg, 135 ) 136 sample: list[MazeTokenizerModular] = random.sample( 137 _all_tokenizers_except_every_test_tokenizers(), 138 n - len(EVERY_TEST_TOKENIZERS), 139 ) 140 sample.extend(EVERY_TEST_TOKENIZERS) 141 return sample 142 143 144def save_hashes( 145 path: Path | None = None, 146 verbose: bool = False, 147 parallelize: bool | int = False, 148) -> AllTokenizersHashesArray: 149 """Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`.""" 150 spinner = ( 151 functools.partial(SpinnerContext, spinner_chars="square_dot") 152 if verbose 153 else NoOpContextManager 154 ) 155 156 # get all tokenizers 157 with spinner(initial_value="getting all tokenizers...", update_interval=2.0): 158 all_tokenizers = get_all_tokenizers() 159 160 # compute hashes 161 hashes_array_np64: AllTokenizersHashesArray 162 if parallelize: 163 n_cpus: int = ( 164 parallelize if int(parallelize) > 1 else multiprocessing.cpu_count() 165 ) 166 with spinner( # noqa: SIM117 167 initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...", 168 update_interval=2.0, 169 ): 170 with multiprocessing.Pool(processes=n_cpus) as pool: 171 hashes_list: list[int] = list(pool.map(hash, all_tokenizers)) 172 173 with spinner(initial_value="converting hashes to numpy array..."): 174 hashes_array_np64 = np.array(hashes_list, dtype=np.int64) 175 else: 176 with spinner( 177 initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...", 178 ): 179 hashes_array_np64 = np.array( 180 [ 181 hash(obj) # uses stable hash 182 for obj in tqdm(all_tokenizers, disable=not verbose) 183 ], 184 dtype=np.int64, 185 ) 186 187 # convert to correct dtype 188 hashes_array: AllTokenizersHashesArray = ( 189 hashes_array_np64 % (1 << AllTokenizersHashBitLength) 190 if AllTokenizersHashBitLength < 64 # noqa: PLR2004 191 else hashes_array_np64 192 ).astype(AllTokenizersHashDtype) 193 194 # make sure there are no dupes 195 with spinner(initial_value="sorting and checking for hash collisions..."): 196 sorted_hashes, counts = np.unique(hashes_array, return_counts=True) 197 if sorted_hashes.shape[0] != hashes_array.shape[0]: 198 collisions: np.array = sorted_hashes[counts > 1] 199 n_collisions: int = hashes_array.shape[0] - sorted_hashes.shape[0] 200 err_msg: str = ( 201 f"{n_collisions} tokenizer hash collisions: {collisions}\n" 202 "Report error to the developer to increase the hash size or otherwise update the tokenizer hashing size:\n" 203 f"https://github.com/understanding-search/maze-dataset/issues/new?labels=bug,tokenization&title=Tokenizer+hash+collision+error&body={n_collisions}+collisions+out+of+{hashes_array.shape[0]}+total+hashes", 204 ) 205 206 raise ValueError( 207 err_msg, 208 ) 209 210 # save and return 211 with spinner(initial_value="saving hashes...", update_interval=0.5): 212 if path is None: 213 path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz" 214 np.savez_compressed( 215 path, 216 hashes=sorted_hashes, 217 ) 218 219 return sorted_hashes
get_all_tokenizers()

```python
@cache
def get_all_tokenizers() -> list[MazeTokenizerModular]:
	"""Computes a complete list of all valid tokenizers.

	Warning: This is an expensive function.
	"""
	return list(
		all_instances(
			MazeTokenizerModular,
			validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
		),
	)
```
Computes a complete list of all valid tokenizers.
Warning: This is an expensive function.
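Since the function is wrapped in `functools.cache`, only the first call pays the cost; later calls return the cached list. A minimal usage sketch:

```python
# Minimal usage sketch: the first call is expensive; later calls reuse the cache.
from maze_dataset.tokenization.modular.all_tokenizers import (
	get_all_tokenizers,
	get_all_tokenizers_names,
)

all_tokenizers = get_all_tokenizers()  # expensive on the first call, cached afterwards
print(len(all_tokenizers))  # size of the full collection
print(get_all_tokenizers_names()[:5])  # sorted names, reuses the cached collection
```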
get_all_tokenizers_names()

```python
@cache
def get_all_tokenizers_names() -> list[str]:
	"""computes the sorted list of names of all tokenizers"""
	return sorted([tokenizer.name for tokenizer in get_all_tokenizers()])
```
Computes the sorted list of names of all tokenizers.
all_tokenizers_set()

```python
@cache
def all_tokenizers_set() -> set[MazeTokenizerModular]:
	"""Casts `get_all_tokenizers()` to a set."""
	return set(get_all_tokenizers())
```
Casts `get_all_tokenizers()` to a set.
sample_all_tokenizers()

```python
def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]:
	"""Samples `n` tokenizers from `get_all_tokenizers()`."""
	return random.sample(get_all_tokenizers(), n)
```
Samples `n` tokenizers from `get_all_tokenizers()`.
sample_tokenizers_for_test()

```python
def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]:
	"""Returns a sample of size `n` of unique elements from `get_all_tokenizers()`,

	always including every element in `EVERY_TEST_TOKENIZERS`.
	"""
	if n is None:
		return get_all_tokenizers()

	if n < len(EVERY_TEST_TOKENIZERS):
		err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`."
		raise ValueError(
			err_msg,
		)
	sample: list[MazeTokenizerModular] = random.sample(
		_all_tokenizers_except_every_test_tokenizers(),
		n - len(EVERY_TEST_TOKENIZERS),
	)
	sample.extend(EVERY_TEST_TOKENIZERS)
	return sample
```
Returns a sample of size `n` of unique elements from `get_all_tokenizers()`, always including every element in `EVERY_TEST_TOKENIZERS`. If `n` is `None`, the full collection is returned.
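A hedged sketch of the intended test-fuzzing usage; `pytest`, the sample size of 100, and the test body are assumptions for illustration, not part of this module:

```python
# Illustrative test-fuzzing sketch (pytest assumed). Every tokenizer in
# EVERY_TEST_TOKENIZERS is guaranteed to be in the parametrized sample.
import pytest

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.modular.all_tokenizers import sample_tokenizers_for_test

NUM_TOKENIZERS_TO_TEST: int = 100  # hypothetical knob; must be >= len(EVERY_TEST_TOKENIZERS)


@pytest.mark.parametrize(
	"tokenizer",
	sample_tokenizers_for_test(NUM_TOKENIZERS_TO_TEST),
	ids=lambda t: t.name,
)
def test_tokenizer_is_valid(tokenizer: MazeTokenizerModular) -> None:
	# placeholder assertion; a real test would exercise maze encoding/decoding here
	assert tokenizer.is_valid()
```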
save_hashes()

```python
def save_hashes(
	path: Path | None = None,
	verbose: bool = False,
	parallelize: bool | int = False,
) -> AllTokenizersHashesArray:
	"""Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`."""
	spinner = (
		functools.partial(SpinnerContext, spinner_chars="square_dot")
		if verbose
		else NoOpContextManager
	)

	# get all tokenizers
	with spinner(initial_value="getting all tokenizers...", update_interval=2.0):
		all_tokenizers = get_all_tokenizers()

	# compute hashes
	hashes_array_np64: AllTokenizersHashesArray
	if parallelize:
		n_cpus: int = (
			parallelize if int(parallelize) > 1 else multiprocessing.cpu_count()
		)
		with spinner(  # noqa: SIM117
			initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...",
			update_interval=2.0,
		):
			with multiprocessing.Pool(processes=n_cpus) as pool:
				hashes_list: list[int] = list(pool.map(hash, all_tokenizers))

		with spinner(initial_value="converting hashes to numpy array..."):
			hashes_array_np64 = np.array(hashes_list, dtype=np.int64)
	else:
		with spinner(
			initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...",
		):
			hashes_array_np64 = np.array(
				[
					hash(obj)  # uses stable hash
					for obj in tqdm(all_tokenizers, disable=not verbose)
				],
				dtype=np.int64,
			)

	# convert to correct dtype
	hashes_array: AllTokenizersHashesArray = (
		hashes_array_np64 % (1 << AllTokenizersHashBitLength)
		if AllTokenizersHashBitLength < 64  # noqa: PLR2004
		else hashes_array_np64
	).astype(AllTokenizersHashDtype)

	# make sure there are no dupes
	with spinner(initial_value="sorting and checking for hash collisions..."):
		sorted_hashes, counts = np.unique(hashes_array, return_counts=True)
		if sorted_hashes.shape[0] != hashes_array.shape[0]:
			collisions: np.array = sorted_hashes[counts > 1]
			n_collisions: int = hashes_array.shape[0] - sorted_hashes.shape[0]
			err_msg: str = (
				f"{n_collisions} tokenizer hash collisions: {collisions}\n"
				"Report error to the developer to increase the hash size or otherwise update the tokenizer hashing size:\n"
				f"https://github.com/understanding-search/maze-dataset/issues/new?labels=bug,tokenization&title=Tokenizer+hash+collision+error&body={n_collisions}+collisions+out+of+{hashes_array.shape[0]}+total+hashes",
			)

			raise ValueError(
				err_msg,
			)

	# save and return
	with spinner(initial_value="saving hashes...", update_interval=0.5):
		if path is None:
			path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
		np.savez_compressed(
			path,
			hashes=sorted_hashes,
		)

	return sorted_hashes
```
Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`.
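Given the sorted array that `save_hashes` writes, membership can be checked without recomputing the full collection. The sketch below mirrors the truncation and dtype logic above; it assumes `hash()` on a `MazeTokenizerModular` is the stable hash noted in the source comment and uses the default file name from `save_hashes`. The library's own `maze_tokenizer.get_all_tokenizer_hashes()` is the supported route for membership checks.

```python
# Hedged sketch: look up a tokenizer's hash in the sorted array written by
# `save_hashes`. Illustrative only; mirrors the truncation/dtype logic above.
import numpy as np

from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.modular.hashing import (
	AllTokenizersHashBitLength,
	AllTokenizersHashDtype,
)


def hash_is_saved(tokenizer: MazeTokenizerModular, sorted_hashes: np.ndarray) -> bool:
	"""True if the tokenizer's (truncated) stable hash appears in `sorted_hashes`."""
	h64 = np.int64(hash(tokenizer))  # relies on the stable hash noted in `save_hashes`
	if AllTokenizersHashBitLength < 64:
		h64 = h64 % (1 << AllTokenizersHashBitLength)  # same truncation as `save_hashes`
	h = np.array([h64]).astype(AllTokenizersHashDtype)[0]
	idx = int(np.searchsorted(sorted_hashes, h))
	return bool(idx < sorted_hashes.shape[0] and sorted_hashes[idx] == h)


# example usage, assuming the default output file of `save_hashes`:
# hashes = np.load("MazeTokenizerModular_hashes.npz")["hashes"]
# hash_is_saved(MazeTokenizerModular(), hashes)
```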