docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.tokenization.modular.all_tokenizers

Contains get_all_tokenizers() and supporting limited-use functions.

get_all_tokenizers()

returns a comprehensive collection of all valid MazeTokenizerModular objects. This collection covers the overwhelming majority of all possible MazeTokenizerModular objects. Other tokenizers not contained in get_all_tokenizers() may be possible to construct, but they are untested and not guaranteed to work. This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to MazeTokenizerModular.

Use Cases

In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors.

  • Unit testing:
    • Tokenizers to use in unit tests are sampled from get_all_tokenizers()
  • Large-scale tokenizer research:
    • Specific research training models on many tokenization behaviors can use get_all_tokenizers() as the maximally inclusive collection
    • get_all_tokenizers() may be subsequently filtered using MazeTokenizerModular.has_element (see the sketch after this list)

For other uses, it's likely that the computational expense can be avoided by using:

  • maze_tokenizer.get_all_tokenizer_hashes() for membership checks
  • utils.all_instances for generating smaller subsets of MazeTokenizerModular or _TokenizerElement objects
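
As a sketch of the large-scale research workflow, the example below enumerates the full collection once and then filters it. It assumes that MazeTokenizerModular.has_element accepts a _TokenizerElement instance; check that method's documentation for the exact signature.

```python
# Sketch: enumerate every valid tokenizer (expensive, but cached), then
# filter down to those using a particular element.
from maze_dataset.tokenization import CoordTokenizers, MazeTokenizerModular
from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers

all_tokenizers: list[MazeTokenizerModular] = get_all_tokenizers()

# keep only tokenizers that use the CTT coordinate tokenizer
# (the `has_element(...)` call pattern here is an assumption)
ctt_tokenizers: list[MazeTokenizerModular] = [
	t for t in all_tokenizers if t.has_element(CoordTokenizers.CTT())
]
```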

EVERY_TEST_TOKENIZERS

A collection of the tokenizers which should always be included in unit tests when test fuzzing is used. This collection should be expanded as specific tokenizers become canonical or popular.
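
For test fuzzing, sample_tokenizers_for_test() (defined below) draws a random subset that always contains these tokenizers. A minimal sketch:

```python
# Sketch: a fuzzing sample of 100 tokenizers for a test suite; the sample
# always contains every member of EVERY_TEST_TOKENIZERS.
from maze_dataset.tokenization.modular.all_tokenizers import (
	EVERY_TEST_TOKENIZERS,
	sample_tokenizers_for_test,
)

tokenizers = sample_tokenizers_for_test(100)
assert all(t in tokenizers for t in EVERY_TEST_TOKENIZERS)
```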


  1"""Contains `get_all_tokenizers()` and supporting limited-use functions.
  2
  3# `get_all_tokenizers()`
  4returns a comprehensive collection of all valid `MazeTokenizerModular` objects.
  5This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects.
  6Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work.
  7This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`.
  8
  9## Use Cases
 10In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors.
 11- Unit testing:
 12  - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()`
 13- Large-scale tokenizer research:
 14  - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection
 15  - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element`
 16For other uses, it's likely that the computational expense can be avoided by using
 17- `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks
 18- `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects
 19
 20# `EVERY_TEST_TOKENIZERS`
 21A collection of the tokenizers which should always be included in unit tests when test fuzzing is used.
 22This collection should be expanded as specific tokenizers become canonical or popular.
 23"""
 24
 25import functools
 26import multiprocessing
 27import random
 28from functools import cache
 29from pathlib import Path
 30from typing import Callable
 31
 32import frozendict
 33import numpy as np
 34from muutils.spinner import NoOpContextManager, SpinnerContext
 35from tqdm import tqdm
 36
 37from maze_dataset.tokenization import (
 38	CoordTokenizers,
 39	MazeTokenizerModular,
 40	PromptSequencers,
 41	StepTokenizers,
 42	_TokenizerElement,
 43)
 44from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances
 45from maze_dataset.tokenization.modular.hashing import (
 46	AllTokenizersHashBitLength,
 47	AllTokenizersHashDtype,
 48	AllTokenizersHashesArray,
 49)
 50
 51# Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular`
 52# TYPING: error: Type variable "maze_dataset.utils.FiniteValued" is unbound  [valid-type]
 53#   note: (Hint: Use "Generic[FiniteValued]" or "Protocol[FiniteValued]" base class to bind "FiniteValued" inside a class)
 54#   note: (Hint: Use "FiniteValued" in function signature to bind "FiniteValued" inside a function)
 55MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[
 56	type[FiniteValued],
 57	Callable[[FiniteValued], bool],
 58] = frozendict.frozendict(
 59	{
 60		# TYPING: Item "bool" of the upper bound "bool | IsDataclass | Enum" of type variable "FiniteValued" has no attribute "is_valid"  [union-attr]
 61		_TokenizerElement: lambda x: x.is_valid(),
 62		# Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid`
 63		# MazeTokenizerModular: lambda x: x.is_valid(),
 64		# TYPING: error: No overload variant of "set" matches argument type "FiniteValued"  [call-overload]
 65		#   note: Possible overload variants:
 66		#   note:     def [_T] set(self) -> set[_T]
 67		#   note:     def [_T] set(self, Iterable[_T], /) -> set[_T]
 68		# TYPING: error: Argument 1 to "len" has incompatible type "FiniteValued"; expected "Sized"  [arg-type]
 69		StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x)
 70		and x != (StepTokenizers.Distance(),),
 71	},
 72)
 73
 74DOWNLOAD_URL: str = "https://raw.githubusercontent.com/understanding-search/maze-dataset/main/maze_dataset/tokenization/MazeTokenizerModular_hashes.npz"
 75
 76
 77@cache
 78def get_all_tokenizers() -> list[MazeTokenizerModular]:
 79	"""Computes a complete list of all valid tokenizers.
 80
 81	Warning: This is an expensive function.
 82	"""
 83	return list(
 84		all_instances(
 85			MazeTokenizerModular,
 86			validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
 87		),
 88	)
 89
 90
 91@cache
 92def get_all_tokenizers_names() -> list[str]:
 93	"""computes the sorted list of names of all tokenizers"""
 94	return sorted([tokenizer.name for tokenizer in get_all_tokenizers()])
 95
 96
 97EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [
 98	MazeTokenizerModular(),
 99	MazeTokenizerModular(
100		prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()),
101	),
102	# TODO: add more here as specific tokenizers become canonical and frequently used
103]
104
105
106@cache
107def all_tokenizers_set() -> set[MazeTokenizerModular]:
108	"""Casts `get_all_tokenizers()` to a set."""
109	return set(get_all_tokenizers())
110
111
112@cache
113def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]:
114	"""Returns"""
115	return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS))
116
117
118def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]:
119	"""Samples `n` tokenizers from `get_all_tokenizers()`."""
120	return random.sample(get_all_tokenizers(), n)
121
122
123def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]:
124	"""Returns a sample of size `n` of unique elements from `get_all_tokenizers()`,
125
126	always including every element in `EVERY_TEST_TOKENIZERS`.
127	"""
128	if n is None:
129		return get_all_tokenizers()
130
131	if n < len(EVERY_TEST_TOKENIZERS):
132		err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`."
133		raise ValueError(
134			err_msg,
135		)
136	sample: list[MazeTokenizerModular] = random.sample(
137		_all_tokenizers_except_every_test_tokenizers(),
138		n - len(EVERY_TEST_TOKENIZERS),
139	)
140	sample.extend(EVERY_TEST_TOKENIZERS)
141	return sample
142
143
144def save_hashes(
145	path: Path | None = None,
146	verbose: bool = False,
147	parallelize: bool | int = False,
148) -> AllTokenizersHashesArray:
149	"""Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`."""
150	spinner = (
151		functools.partial(SpinnerContext, spinner_chars="square_dot")
152		if verbose
153		else NoOpContextManager
154	)
155
156	# get all tokenizers
157	with spinner(initial_value="getting all tokenizers...", update_interval=2.0):
158		all_tokenizers = get_all_tokenizers()
159
160	# compute hashes
161	hashes_array_np64: AllTokenizersHashesArray
162	if parallelize:
163		n_cpus: int = (
164			parallelize if int(parallelize) > 1 else multiprocessing.cpu_count()
165		)
166		with spinner(  # noqa: SIM117
167			initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...",
168			update_interval=2.0,
169		):
170			with multiprocessing.Pool(processes=n_cpus) as pool:
171				hashes_list: list[int] = list(pool.map(hash, all_tokenizers))
172
173		with spinner(initial_value="converting hashes to numpy array..."):
174			hashes_array_np64 = np.array(hashes_list, dtype=np.int64)
175	else:
176		with spinner(
177			initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...",
178		):
179			hashes_array_np64 = np.array(
180				[
181					hash(obj)  # uses stable hash
182					for obj in tqdm(all_tokenizers, disable=not verbose)
183				],
184				dtype=np.int64,
185			)
186
187	# convert to correct dtype
188	hashes_array: AllTokenizersHashesArray = (
189		hashes_array_np64 % (1 << AllTokenizersHashBitLength)
190		if AllTokenizersHashBitLength < 64  # noqa: PLR2004
191		else hashes_array_np64
192	).astype(AllTokenizersHashDtype)
193
194	# make sure there are no dupes
195	with spinner(initial_value="sorting and checking for hash collisions..."):
196		sorted_hashes, counts = np.unique(hashes_array, return_counts=True)
197		if sorted_hashes.shape[0] != hashes_array.shape[0]:
198			collisions: np.array = sorted_hashes[counts > 1]
199			n_collisions: int = hashes_array.shape[0] - sorted_hashes.shape[0]
200			err_msg: str = (
201				f"{n_collisions} tokenizer hash collisions: {collisions}\n"
202				"Report error to the developer to increase the hash size or otherwise update the tokenizer hashing size:\n"
203				f"https://github.com/understanding-search/maze-dataset/issues/new?labels=bug,tokenization&title=Tokenizer+hash+collision+error&body={n_collisions}+collisions+out+of+{hashes_array.shape[0]}+total+hashes",
204			)
205
206			raise ValueError(
207				err_msg,
208			)
209
210	# save and return
211	with spinner(initial_value="saving hashes...", update_interval=0.5):
212		if path is None:
213			path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
214		np.savez_compressed(
215			path,
216			hashes=sorted_hashes,
217		)
218
219	return sorted_hashes

MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[type[~FiniteValued], typing.Callable[[~FiniteValued], bool]] = frozendict.frozendict({<class 'maze_dataset.tokenization.modular.element_base._TokenizerElement'>: <function <lambda>>, tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer] | tuple[maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer, maze_dataset.tokenization.modular.elements.StepTokenizers._StepTokenizer]: <function <lambda>>})
DOWNLOAD_URL: str = 'https://raw.githubusercontent.com/understanding-search/maze-dataset/main/maze_dataset/tokenization/MazeTokenizerModular_hashes.npz'
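
One plausible way to use DOWNLOAD_URL (this is a sketch, not necessarily how maze-dataset itself fetches the file) is to download the archive and read the "hashes" array that save_hashes() writes:

```python
# Sketch: download the precomputed hash archive and load the sorted array.
import urllib.request

import numpy as np

from maze_dataset.tokenization.modular.all_tokenizers import DOWNLOAD_URL

urllib.request.urlretrieve(DOWNLOAD_URL, "MazeTokenizerModular_hashes.npz")
hashes = np.load("MazeTokenizerModular_hashes.npz")["hashes"]  # key written by save_hashes()
```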
@cache
def get_all_tokenizers() -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Computes a complete list of all valid tokenizers.

Warning: This is an expensive function.
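
A short usage sketch; because the function is decorated with @cache, repeated calls return the same list object:

```python
from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers

tokenizers = get_all_tokenizers()  # slow on the first call, cached afterwards
print(f"{len(tokenizers)} valid tokenizers")
assert get_all_tokenizers() is tokenizers  # functools.cache returns the same object
```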

@cache
def get_all_tokenizers_names() -> list[str]:

Computes the sorted list of names of all tokenizers.

EVERY_TEST_TOKENIZERS: list[maze_dataset.tokenization.MazeTokenizerModular] = [MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.UT(), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False))), MazeTokenizerModular(prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT(pre=True, intra=True, post=True), adj_list_tokenizer=AdjListTokenizers.AdjListCoord(pre=False, post=True, shuffle_d0=True, edge_grouping=EdgeGroupings.Ungrouped(connection_token_ordinal=1), edge_subset=EdgeSubsets.ConnectionEdges(walls=False), edge_permuter=EdgePermuters.RandomCoords()), target_tokenizer=TargetTokenizers.Unlabeled(post=False), path_tokenizer=PathTokenizers.StepSequence(step_size=StepSizes.Singles(), step_tokenizers=(StepTokenizers.Coord(),), pre=False, intra=False, post=False)))]
@cache
def all_tokenizers_set() -> set[maze_dataset.tokenization.MazeTokenizerModular]:

Casts get_all_tokenizers() to a set.
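
A sketch of using the cached set for fast membership checks:

```python
from maze_dataset.tokenization import MazeTokenizerModular
from maze_dataset.tokenization.modular.all_tokenizers import all_tokenizers_set

# the default tokenizer is expected to be part of the collection
assert MazeTokenizerModular() in all_tokenizers_set()
```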

def sample_all_tokenizers( n: int) -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Samples n tokenizers from get_all_tokenizers().

def sample_tokenizers_for_test( n: int | None) -> list[maze_dataset.tokenization.MazeTokenizerModular]:

Returns a sample of size n of unique elements from get_all_tokenizers(), always including every element in EVERY_TEST_TOKENIZERS.
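
A sketch of the edge cases: passing None returns the full collection, and n smaller than len(EVERY_TEST_TOKENIZERS) raises ValueError.

```python
from maze_dataset.tokenization.modular.all_tokenizers import (
	EVERY_TEST_TOKENIZERS,
	get_all_tokenizers,
	sample_tokenizers_for_test,
)

# n=None returns the full (unsampled) collection
assert len(sample_tokenizers_for_test(None)) == len(get_all_tokenizers())

# n must be at least len(EVERY_TEST_TOKENIZERS)
try:
	sample_tokenizers_for_test(len(EVERY_TEST_TOKENIZERS) - 1)
except ValueError:
	pass  # expected
```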

def save_hashes( path: pathlib.Path | None = None, verbose: bool = False, parallelize: bool | int = False) -> jaxtyping.UInt32[ndarray, 'n_tokens']:

Computes, sorts, and saves the hashes of every member of get_all_tokenizers().
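
A usage sketch with an explicit output path and parallel hashing:

```python
from pathlib import Path

from maze_dataset.tokenization.modular.all_tokenizers import save_hashes

# write the sorted hash array to a chosen location, hashing in parallel
hashes = save_hashes(
	path=Path("MazeTokenizerModular_hashes.npz"),
	verbose=True,      # show progress spinners
	parallelize=True,  # True uses all CPUs; an int pins the number of worker processes
)
print(hashes.shape, hashes.dtype)
```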