docs for maze-dataset v1.3.2

maze_dataset.tokenization.modular.hashing

legacy system for checking a MazeTokenizerModular is valid -- compare its hash to a table of known hashes

this has been superseded by the fst system
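For reference, a minimal sketch of how the legacy check works, assuming the hashes file has been generated and is readable; the `my_tokenizer_name` string is a hypothetical stand-in for a real `MazeTokenizerModular` name:

import numpy as np

from maze_dataset.tokenization.modular.hashing import (
	_hash_tokenizer_name,
	get_all_tokenizer_hashes,
)

# hypothetical tokenizer name string -- stand-in for a real one
my_tokenizer_name: str = "MazeTokenizerModular(...)"

# mask the name's hash down to the table's 32-bit width before comparing
h = np.uint32(_hash_tokenizer_name(my_tokenizer_name) & 0xFFFFFFFF)
is_valid: bool = bool(np.isin(h, get_all_tokenizer_hashes()))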


 1"""legacy system for checking a `ModularMazeTokenizer` is valid -- compare its hash to a table of known hashes
 2
 3this has been superseded by the fst system
 4
 5"""
 6
 7import hashlib
 8from pathlib import Path
 9
10import numpy as np
11from jaxtyping import UInt32
12
13# NOTE: these all need to match!
14
15AllTokenizersHashBitLength = 32
16"bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`"
17
18AllTokenizersHashDtype = np.uint32
19"numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`"
20
21AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"]
22"jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`"
23
24
25def _hash_tokenizer_name(s: str) -> int:
26	h64: int = int.from_bytes(
27		hashlib.shake_256(s.encode("utf-8")).digest(64),
28		byteorder="big",
29	)
30	return (h64 >> 32) ^ (h64 & 0xFFFFFFFF)
31
32
33_ALL_TOKENIZER_HASHES: AllTokenizersHashesArray
34"private array of all tokenizer hashes"
35_TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz"
36"path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`"
37
38
39def set_tokenizer_hashes_path(path: Path) -> None:
40	"""set path to tokenizer hashes, and reload the hashes if needed
41
42	the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`,
43	which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"` or in this file's directory.
44
45	However, this might not always work, so we provide a way to change this.
46	"""
47	global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES  # noqa: PLW0603
48
49	path = Path(path)
50	if path.is_dir():
51		path = path / "MazeTokenizerModular_hashes.npz"
52
53	if not path.is_file():
54		err_msg: str = f"could not find maze tokenizer hashes file at: {path}"
55		raise FileNotFoundError(err_msg)
56
57	if _TOKENIZER_HASHES_PATH.absolute() != path.absolute():
58		# reload if they aren't equal
59		_TOKENIZER_HASHES_PATH = path
60		_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
61	else:
62		# always set to new path
63		_TOKENIZER_HASHES_PATH = path
64
65
66def _load_tokenizer_hashes() -> AllTokenizersHashesArray:
67	"""Loads the sorted list of `all_tokenizers.get_all_tokenizers()` hashes from disk."""
68	global _TOKENIZER_HASHES_PATH  # noqa: PLW0602
69	try:
70		path: Path = _TOKENIZER_HASHES_PATH
71		return np.load(path)["hashes"]
72	except FileNotFoundError as e:
73		err_msg: str = (
74			"Tokenizers hashes cannot be loaded. To fix this, run"
75			"\n`python -m maze-dataset.tokenization.save_hashes` which will save the hashes to"
76			"\n`data/MazeTokenizerModular_hashes.npz`"
77			"\nrelative to the current working directory -- this is where the code looks for them."
78		)
79		raise FileNotFoundError(err_msg) from e
80
81
82def get_all_tokenizer_hashes() -> AllTokenizersHashesArray:
83	"""returns all the tokenizer hashes in an `AllTokenizersHashesDtype` array, setting global variable if needed"""
84	# naughty use of globals
85	global _ALL_TOKENIZER_HASHES  # noqa: PLW0603
86	try:
87		got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0
88		if got_tokenizers:
89			return _ALL_TOKENIZER_HASHES
90		else:
91			_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
92	except NameError:
93		_ALL_TOKENIZER_HASHES = _load_tokenizer_hashes()
94
95	return _ALL_TOKENIZER_HASHES

AllTokenizersHashBitLength = 32

bit length of the hashes of all tokenizers, must match AllTokenizersHashDtype and AllTokenizersHashesArray

AllTokenizersHashDtype = <class 'numpy.uint32'>

numpy data type of the hashes of all tokenizers, must match AllTokenizersHashBitLength and AllTokenizersHashesArray

AllTokenizersHashesArray = <class 'jaxtyping.UInt32[ndarray, 'n_tokens']'>

jaxtyping type of the hashes of all tokenizers, must match AllTokenizersHashBitLength and AllTokenizersHashDtype
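Since these three definitions must stay in sync, one way to sanity-check the invariant they encode (a sketch, not part of the module):

import numpy as np

from maze_dataset.tokenization.modular.hashing import (
	AllTokenizersHashBitLength,
	AllTokenizersHashDtype,
)

# itemsize is in bytes; the dtype's width in bits must equal the declared bit length
assert np.dtype(AllTokenizersHashDtype).itemsize * 8 == AllTokenizersHashBitLength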

def set_tokenizer_hashes_path(path: pathlib.Path) -> None:

set path to tokenizer hashes, and reload the hashes if needed

the hashes are expected to be stored in and read from _TOKENIZER_HASHES_PATH, which by default is Path(__file__).parent / "MazeTokenizerModular_hashes.npz", i.e. in this module's directory.

However, this might not always work, so we provide a way to change this.
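A hypothetical usage sketch, assuming the hashes file has already been saved under a local data/ directory:

from pathlib import Path

from maze_dataset.tokenization.modular.hashing import set_tokenizer_hashes_path

# passing a directory appends the default filename automatically
set_tokenizer_hashes_path(Path("data"))
# equivalent to passing the file directly:
set_tokenizer_hashes_path(Path("data") / "MazeTokenizerModular_hashes.npz")
# raises FileNotFoundError if the file does not exist at the given path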

def get_all_tokenizer_hashes() -> jaxtyping.UInt32[ndarray, 'n_tokens']:

returns all the tokenizer hashes in an AllTokenizersHashesArray, setting the global variable if needed
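Because _load_tokenizer_hashes stores the hashes sorted, membership can also be checked with a binary search rather than a linear np.isin scan -- a sketch, where some_hash is a placeholder 32-bit value:

import numpy as np

from maze_dataset.tokenization.modular.hashing import get_all_tokenizer_hashes

hashes = get_all_tokenizer_hashes()
some_hash = np.uint32(0xDEADBEEF)  # placeholder value for illustration

# np.searchsorted finds the insertion index; a match there means the hash is known
idx = int(np.searchsorted(hashes, some_hash))
found: bool = idx < len(hashes) and bool(hashes[idx] == some_hash)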