maze_dataset.tokenization.modular.save_hashes
generate and save the hashes of all supported tokenizers
Caution
using hashes to validate a `MazeTokenizerModular` is deprecated in favor of using fst
calls `maze_dataset.tokenization.all_tokenizers.save_hashes()`
Usage:

To save to the default location (inside the package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`):

```bash
python -m maze_dataset.tokenization.save_hashes
```

To save to a custom location:

```bash
python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy
```

To check the hashes shipped with the package:

```bash
python -m maze_dataset.tokenization.save_hashes --check
```
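The same operations are available from Python. The following is a minimal sketch based on the calls made in the source below; the output path is an arbitrary example, and the `.npz` layout with a `"hashes"` key mirrors what the `--check` branch reads back:

```python
from pathlib import Path

import numpy as np

from maze_dataset.tokenization.modular import all_tokenizers
from maze_dataset.tokenization.modular.hashing import get_all_tokenizer_hashes

# example output path (hypothetical; any writable location works)
out_path: Path = Path("tok_hashes.npz")

# generate hashes for all supported tokenizers and save them;
# the array that was written is also returned
hashes: np.ndarray = all_tokenizers.save_hashes(
	path=out_path,
	verbose=True,
	parallelize=False,
)

# the saved file stores the array under the "hashes" key
loaded: np.ndarray = np.load(out_path)["hashes"]
assert np.array_equal(hashes, loaded)

# hashes loaded from (or shipped with) the installed package
shipped: np.ndarray = get_all_tokenizer_hashes()
```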
1"""generate and save the hashes of all supported tokenizers 2 3> [!CAUTION] 4> using hashes to validate validate a `MazeTokenizerModular` is deprecated in favor of using fst 5 6calls `maze_dataset.tokenization.all_tokenizers.save_hashes()` 7 8Usage: 9 10To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`): 11```bash 12python -m maze_dataset.tokenization.save_hashes 13``` 14 15to save to a custom location: 16```bash 17python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy 18``` 19 20to check hashes shipped with the package: 21```bash 22python -m maze_dataset.tokenization.save_hashes --check 23``` 24 25""" 26 27from pathlib import Path 28 29import numpy as np 30from muutils.spinner import SpinnerContext 31 32from maze_dataset.tokenization.modular import all_tokenizers 33from maze_dataset.tokenization.modular.hashing import ( 34 _load_tokenizer_hashes, 35 get_all_tokenizer_hashes, 36) 37 38if __name__ == "__main__": 39 # parse args 40 # ================================================== 41 import argparse 42 43 parser: argparse.ArgumentParser = argparse.ArgumentParser( 44 description="generate and save (or download) the hashes of all supported tokenizers", 45 ) 46 47 parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to") 48 parser.add_argument( 49 "--quiet", 50 "-q", 51 action="store_true", 52 help="disable progress bar and spinner", 53 ) 54 parser.add_argument( 55 "--parallelize", 56 "-p", 57 action="store_true", 58 help="parallelize the computation", 59 ) 60 parser.add_argument( 61 "--check", 62 "-c", 63 action="store_true", 64 help="save to temp location, then compare to existing", 65 ) 66 parser.add_argument( 67 "--download", 68 "-d", 69 action="store_true", 70 help=f"download the hashes from github: {all_tokenizers.DOWNLOAD_URL}", 71 ) 72 73 args: argparse.Namespace = parser.parse_args() 74 75 if not args.check: 76 # write new hashes 77 # ================================================== 78 all_tokenizers.save_hashes( 79 path=args.path, 80 verbose=not args.quiet, 81 parallelize=args.parallelize, 82 ) 83 84 else: 85 # check hashes only 86 # ================================================== 87 88 # set up path 89 if args.path is not None: 90 raise ValueError("cannot use --check with a custom path") 91 temp_path: Path = Path("tests/_temp/tok_hashes.npz") 92 temp_path.parent.mkdir(parents=True, exist_ok=True) 93 94 # generate and save to temp location 95 returned_hashes: np.ndarray = all_tokenizers.save_hashes( 96 path=temp_path, 97 verbose=not args.quiet, 98 parallelize=args.parallelize, 99 ) 100 101 # load saved hashes 102 with SpinnerContext( 103 spinner_chars="square_dot", 104 update_interval=0.5, 105 message="loading saved hashes...", 106 ): 107 read_hashes: np.ndarray = np.load(temp_path)["hashes"] 108 read_hashes_pkg: np.ndarray = _load_tokenizer_hashes() 109 read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes() 110 111 # compare 112 with SpinnerContext( 113 spinner_chars="square_dot", 114 update_interval=0.01, 115 message="checking hashes: ", 116 format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ", 117 format_string_when_updated=True, 118 ) as sp: 119 sp.update_value("returned vs read") 120 assert np.array_equal(returned_hashes, read_hashes) 121 sp.update_value("returned vs _load_tokenizer_hashes") 122 assert np.array_equal(returned_hashes, read_hashes_pkg) 123 sp.update_value("returned vs get_all_tokenizer_hashes()") 124 assert np.array_equal(read_hashes, 
read_hashes_wrapped)
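For reference, the `--check` branch boils down to the comparison below. This is a condensed sketch using the same helpers the script imports; the temp path is the one hard-coded above, and the script must have been run with `--check` first so that the file exists:

```python
from pathlib import Path

import numpy as np

from maze_dataset.tokenization.modular.hashing import (
	_load_tokenizer_hashes,
	get_all_tokenizer_hashes,
)

# same temp location the script writes to under --check
temp_path: Path = Path("tests/_temp/tok_hashes.npz")

# hashes just generated and saved by the script
read_hashes: np.ndarray = np.load(temp_path)["hashes"]

# hashes shipped inside the installed package
read_hashes_pkg: np.ndarray = _load_tokenizer_hashes()

# public wrapper around loading the shipped hashes
read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes()

# all three arrays must match for the check to pass
assert np.array_equal(read_hashes, read_hashes_pkg)
assert np.array_equal(read_hashes, read_hashes_wrapped)
```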