1"""generate and save the hashes of all supported tokenizers
3> [!CAUTION]
4> using hashes to validate validate a `MazeTokenizerModular` is deprecated in favor of using fst
6calls `maze_dataset.tokenization.all_tokenizers.save_hashes()`
8Usage:
10To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`):
11```bash
12python -m maze_dataset.tokenization.save_hashes
13```
15to save to a custom location:
16```bash
17python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy
18```
20to check hashes shipped with the package:
21```bash
22python -m maze_dataset.tokenization.save_hashes --check
23```
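
the same `all_tokenizers.save_hashes()` call this script wraps can also be used
directly from Python; a minimal sketch (the keyword arguments and the
`np.ndarray` return value match the `--check` branch below; the output path is
hypothetical):
```python
from pathlib import Path

import numpy as np

from maze_dataset.tokenization.modular import all_tokenizers

# compute hashes for all supported tokenizers and write them to a custom path
hashes: np.ndarray = all_tokenizers.save_hashes(
	path=Path("my_tokenizer_hashes.npz"),  # hypothetical output path
	verbose=True,
	parallelize=True,
)
```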
25"""
from pathlib import Path

import numpy as np
from muutils.spinner import SpinnerContext

from maze_dataset.tokenization.modular import all_tokenizers
from maze_dataset.tokenization.modular.hashing import (
	_load_tokenizer_hashes,
	get_all_tokenizer_hashes,
)
if __name__ == "__main__":
	# parse args
	# ==================================================
	import argparse

	parser: argparse.ArgumentParser = argparse.ArgumentParser(
		description="generate and save (or download) the hashes of all supported tokenizers",
	)
47 parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to")
48 parser.add_argument(
49 "--quiet",
50 "-q",
51 action="store_true",
52 help="disable progress bar and spinner",
53 )
54 parser.add_argument(
55 "--parallelize",
56 "-p",
57 action="store_true",
58 help="parallelize the computation",
59 )
60 parser.add_argument(
61 "--check",
62 "-c",
63 action="store_true",
64 help="save to temp location, then compare to existing",
65 )
66 parser.add_argument(
67 "--download",
68 "-d",
69 action="store_true",
70 help=f"download the hashes from github: {all_tokenizers.DOWNLOAD_URL}",
71 )
	args: argparse.Namespace = parser.parse_args()
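	# NOTE: as written, --download is parsed but never consumed below;
	# only --check changes which branch runs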
	if not args.check:
		# write new hashes
		# ==================================================
		all_tokenizers.save_hashes(
			path=args.path,
			verbose=not args.quiet,
			parallelize=args.parallelize,
		)
	else:
		# check hashes only
		# ==================================================

		# set up path
		if args.path is not None:
			raise ValueError("cannot use --check with a custom path")
		temp_path: Path = Path("tests/_temp/tok_hashes.npz")
		temp_path.parent.mkdir(parents=True, exist_ok=True)
		# generate and save to temp location
		returned_hashes: np.ndarray = all_tokenizers.save_hashes(
			path=temp_path,
			verbose=not args.quiet,
			parallelize=args.parallelize,
		)
		# load saved hashes: from the temp file just written, from the copy
		# shipped inside the package, and via the public accessor
		with SpinnerContext(
			spinner_chars="square_dot",
			update_interval=0.5,
			message="loading saved hashes...",
		):
			read_hashes: np.ndarray = np.load(temp_path)["hashes"]
			read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()
			read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()
		# compare: all four hash arrays should be identical
		with SpinnerContext(
			spinner_chars="square_dot",
			update_interval=0.01,
			message="checking hashes: ",
			format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ",
			format_string_when_updated=True,
		) as sp:
			sp.update_value("returned vs read")
			assert np.array_equal(returned_hashes, read_hashes)
			sp.update_value("returned vs _load_tokenizer_hashes")
			assert np.array_equal(returned_hashes, read_hashes_pkg)
			sp.update_value("read vs get_all_tokenizer_hashes()")
			assert np.array_equal(read_hashes, read_hashes_wrapped)