Coverage for tests/all_tokenizers/test_all_tokenizers.py: 99%

91 statements  

coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
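"""Tests for `MazeTokenizerModular` tokenizers sampled from `all_tokenizers.get_all_tokenizers()`."""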

import itertools
import os
from collections import Counter
from typing import Callable, Iterable

import pytest
from zanj import ZANJ

from maze_dataset import VOCAB, VOCAB_LIST, LatticeMaze
from maze_dataset.maze.lattice_maze import SolvedMaze
from maze_dataset.testing_utils import MIXED_MAZES
from maze_dataset.token_utils import equal_except_adj_list_sequence
from maze_dataset.tokenization import (
    AdjListTokenizers,
    CoordTokenizers,
    EdgeGroupings,
    EdgePermuters,
    MazeTokenizerModular,
    PathTokenizers,
    PromptSequencers,
    StepSizes,
    StepTokenizers,
    _TokenizerElement,
)
from maze_dataset.tokenization.modular.all_instances import all_instances
from maze_dataset.tokenization.modular.all_tokenizers import (
    EVERY_TEST_TOKENIZERS,
    MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
    sample_tokenizers_for_test,
    save_hashes,
)

# Size of the sample from `all_tokenizers.get_all_tokenizers()` to test
# get from env, or set to default value of 100
_os_env_num_tokenizers: str = os.getenv("NUM_TOKENIZERS_TO_TEST", "100")
NUM_TOKENIZERS_TO_TEST: int | None = (
    int(_os_env_num_tokenizers) if _os_env_num_tokenizers.isdigit() else None
)
print(f"{NUM_TOKENIZERS_TO_TEST = }")

SAMPLED_TOKENIZERS: list[MazeTokenizerModular] = sample_tokenizers_for_test(
    NUM_TOKENIZERS_TO_TEST,
)

SAMPLED_MAZES: list[SolvedMaze] = MIXED_MAZES[:6]


@pytest.fixture(scope="session")
def save_tokenizer_hashes():
    save_hashes()


@pytest.mark.parametrize(
    "class_",
    [pytest.param(c, id=c.__name__) for c in _TokenizerElement.__subclasses__()],
)
def test_all_instances_tokenizerelement(class_: type):
    """`all_instances` should yield no duplicate instances for each `_TokenizerElement` subclass."""
    all_vals = list(
        all_instances(
            class_,
            validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
        ),
    )
    assert len({hash(elem) for elem in all_vals}) == len(all_vals)


SAMPLE_MIN: int = len(EVERY_TEST_TOKENIZERS)


@pytest.mark.parametrize(
    ("n", "result"),
    [
        pytest.param(i, result)
        for i, result in [
            (SAMPLE_MIN - 1, ValueError),
            (SAMPLE_MIN, None),
            (SAMPLE_MIN + 5, None),
            (SAMPLE_MIN + 200, None),
        ]
    ],
)
def test_sample_tokenizers_for_test(n: int, result: type[Exception] | None):
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            sample_tokenizers_for_test(n)
        return
    mts: list[MazeTokenizerModular] = sample_tokenizers_for_test(n)
    mts_set: set[MazeTokenizerModular] = set(mts)
    assert len(mts) == len(mts_set)
    assert set(EVERY_TEST_TOKENIZERS).issubset(mts_set)
    if n > SAMPLE_MIN + 1:
        mts2: list[MazeTokenizerModular] = sample_tokenizers_for_test(n)
        assert set(mts2) != mts_set  # Check that successive samples are different


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_token_region_delimiters(tokenizer: MazeTokenizerModular):
    """<PATH_START> and similar token region delimiters should appear at most once, regardless of tokenizer."""
    for maze in SAMPLED_MAZES:
        counts: Counter = Counter(maze.as_tokens(tokenizer))
        assert all(counts[tok] < 2 for tok in VOCAB_LIST[:8])


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_token_stability(tokenizer: MazeTokenizerModular):
    """Tests consistency of tokenizations over multiple method calls."""
    for maze in SAMPLED_MAZES:
        tokens1: list[str] = maze.as_tokens(tokenizer)
        tokens2: list[str] = maze.as_tokens(tokenizer)
        if tokenizer.has_element(
            EdgeGroupings.ByLeadingCoord,
            EdgePermuters.RandomCoords,
        ) or tokenizer.has_element(
            AdjListTokenizers.AdjListCardinal,
            EdgePermuters.RandomCoords,
        ):
            # In this case, the adjlist is expected to have different token counts over multiple calls
            # Exclude that region from the test
            non_adjlist1 = tokens1[: tokens1.index(VOCAB.ADJLIST_START)]
            non_adjlist1.extend(tokens1[tokens1.index(VOCAB.ADJLIST_END) :])
            non_adjlist2 = tokens2[: tokens2.index(VOCAB.ADJLIST_START)]
            non_adjlist2.extend(tokens2[tokens2.index(VOCAB.ADJLIST_END) :])
            assert non_adjlist1 == non_adjlist2
        else:
            assert equal_except_adj_list_sequence(tokens1, tokens2)


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_tokenizer_properties(tokenizer: MazeTokenizerModular):
    # Just make sure the call doesn't raise exception
    assert len(tokenizer.name) > 5

    assert tokenizer.vocab_size == 4096
    assert isinstance(tokenizer.token_arr, Iterable)
    assert all(isinstance(token, str) for token in tokenizer.token_arr)
    assert tokenizer.token_arr[tokenizer.padding_token_index] == VOCAB.PADDING

    # Just make sure the call doesn't raise exception
    print(tokenizer.summary())


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_encode_decode(tokenizer: MazeTokenizerModular):
    """Encoding tokens to ids and decoding them back should reproduce the original token sequence."""
    for maze in SAMPLED_MAZES:
        maze_tok: list[str] = maze.as_tokens(maze_tokenizer=tokenizer)
        maze_encoded: list[int] = tokenizer.encode(maze_tok)
        maze_decoded: list[str] = tokenizer.decode(maze_encoded)
        assert maze_tok == maze_decoded


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_zanj_save_read(tokenizer: MazeTokenizerModular):
    """A tokenizer saved to a ZANJ file and read back should equal the original."""
    path = os.path.abspath(
        os.path.join(
            os.path.curdir,
            "data",
            f"mmt.{tokenizer.hash_b64()}.zanj",
        ),
    )
    zanj = ZANJ()
    zanj.save(tokenizer, path)
    assert zanj.read(path) == tokenizer


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_is_AOTP(tokenizer: MazeTokenizerModular):
    if isinstance(tokenizer.prompt_sequencer, PromptSequencers.AOTP):
        assert tokenizer.is_AOTP()
    else:
        assert not tokenizer.is_AOTP()


@pytest.mark.parametrize(
    "tokenizer",
    [pytest.param(tokenizer, id=tokenizer.name) for tokenizer in SAMPLED_TOKENIZERS],
)
def test_is_UT(tokenizer: MazeTokenizerModular):
    if isinstance(tokenizer.prompt_sequencer.coord_tokenizer, CoordTokenizers.UT):
        assert tokenizer.is_UT()
    else:
        assert not tokenizer.is_UT()


_has_elems_type = (
    type[_TokenizerElement]
    | _TokenizerElement
    | Iterable[type[_TokenizerElement] | _TokenizerElement]
)


@pytest.mark.parametrize(
    ("tokenizer", "elems", "result_func"),
    [
        pytest.param(
            tokenizer,
            elems_tuple[0],
            elems_tuple[1],
            id=f"{tokenizer.name}-{elems_tuple[0]}",
        )
        for tokenizer, elems_tuple in itertools.product(
            SAMPLED_TOKENIZERS,
            [
                (
                    [PromptSequencers.AOTP()],
                    lambda mt, els: mt.prompt_sequencer == els[0],
                ),
                (PromptSequencers.AOTP(), lambda mt, els: mt.prompt_sequencer == els),
                (
                    [CoordTokenizers.CTT()],
                    lambda mt, els: mt.prompt_sequencer.coord_tokenizer == els[0],
                ),
                (
                    CoordTokenizers.CTT(intra=False),
                    lambda mt, els: mt.prompt_sequencer.coord_tokenizer == els,
                ),
                (
                    [CoordTokenizers.CTT],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    ),
                ),
                (
                    CoordTokenizers._CoordTokenizer,
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els,
                    ),
                ),
                (
                    StepSizes.Singles,
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.path_tokenizer.step_size,
                        els,
                    ),
                ),
                (
                    StepTokenizers.Coord,
                    lambda mt, els: any(
                        isinstance(step_tok, els)
                        for step_tok in mt.prompt_sequencer.path_tokenizer.step_tokenizers
                    ),
                ),
                (
                    [CoordTokenizers.CTT()],
                    lambda mt, els: mt.prompt_sequencer.coord_tokenizer == els[0],
                ),
                (
                    [CoordTokenizers.CTT, PathTokenizers.StepSequence],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and isinstance(mt.prompt_sequencer.path_tokenizer, els[1]),
                ),
                # ((a for a in [CoordTokenizers.CTT, PathTokenizers.Coords]),
                # lambda mt, els: isinstance(mt.coord_tokenizer, list(els)[0]) and isinstance(mt.path_tokenizer, list(els)[1])
                # ),
                (
                    [CoordTokenizers.CTT, PathTokenizers.StepSequence(post=False)],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and mt.prompt_sequencer.path_tokenizer == els[1],
                ),
                (
                    [
                        CoordTokenizers.CTT,
                        PathTokenizers.StepSequence,
                        PromptSequencers.AOP(),
                    ],
                    lambda mt, els: isinstance(
                        mt.prompt_sequencer.coord_tokenizer,
                        els[0],
                    )
                    and isinstance(mt.prompt_sequencer.path_tokenizer, els[1])
                    and mt.prompt_sequencer == els[2],
                ),
            ],
        )
    ],
)
def test_has_element(
    tokenizer: MazeTokenizerModular,
    elems: _has_elems_type,
    result_func: Callable[[MazeTokenizerModular, _has_elems_type], bool],
):
    assert tokenizer.has_element(elems) == result_func(tokenizer, elems)