Coverage for tests/unit/tokenization/test_token_utils.py: 97%

175 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-27 23:43 -0600

1import itertools 

2from typing import Callable 

3 

4import frozendict 

5import numpy as np 

6import pytest 

7from jaxtyping import Int 

8 

9from maze_dataset import LatticeMaze 

10from maze_dataset.constants import VOCAB, Connection, ConnectionArray 

11from maze_dataset.generation import numpy_rng 

12from maze_dataset.testing_utils import GRID_N, MAZE_DATASET 

13from maze_dataset.token_utils import ( 

14 _coord_to_strings_UT, 

15 coords_to_strings, 

16 equal_except_adj_list_sequence, 

17 get_adj_list_tokens, 

18 get_origin_tokens, 

19 get_path_tokens, 

20 get_relative_direction, 

21 get_target_tokens, 

22 is_connection, 

23 strings_to_coords, 

24 tokens_between, 

25) 

26from maze_dataset.tokenization import ( 

27 PathTokenizers, 

28 StepTokenizers, 

29 get_tokens_up_to_path_start, 

30) 

31from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances 

32from maze_dataset.utils import ( 

33 lattice_connection_array, 

34 manhattan_distance, 

35) 

36 

# Canonical fixtures: (token_list, tokenizer_name) pairs describing the SAME
# 2x2 maze under two serializations.  Most tests in this module parametrize
# over these.

# AOTP_UT ("unique token"): each coordinate is a single token, e.g. "(1,0)".
MAZE_TOKENS: tuple[list[str], str] = (
    "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    "AOTP_UT",
)
# AOTP_CTT_indexed: each coordinate is spelled out as separate tokens,
# e.g. "(", "1", ",", "0", ")".
MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = (
    "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    "AOTP_CTT_indexed",
)
# Both fixtures together, for tests that should pass under either tokenization.
TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [
    MAZE_TOKENS,
    MAZE_TOKENS_AOTP_CTT_indexed,
]

49 

50 

51@pytest.mark.parametrize( 

52 ("toks", "tokenizer_name"), 

53 [ 

54 pytest.param( 

55 token_list[0], 

56 token_list[1], 

57 id=f"{token_list[1]}", 

58 ) 

59 for token_list in TEST_TOKEN_LISTS 

60 ], 

61) 

62def test_tokens_between(toks: list[str], tokenizer_name: str): 

63 result = tokens_between(toks, "<PATH_START>", "<PATH_END>") 

64 match tokenizer_name: 

65 case "AOTP_UT": 

66 assert result == ["(1,0)", "(1,1)"] 

67 case "AOTP_CTT_indexed": 

68 assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"] 

69 

70 # Normal case 

71 tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] 

72 start_value = "quick" 

73 end_value = "over" 

74 assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"] 

75 

76 # Including start and end values 

77 assert tokens_between(tokens, start_value, end_value, True, True) == [ 

78 "quick", 

79 "brown", 

80 "fox", 

81 "jumps", 

82 "over", 

83 ] 

84 

85 # When start_value or end_value is not unique and except_when_tokens_not_unique is True 

86 with pytest.raises(ValueError): # noqa: PT011 

87 tokens_between(tokens, "the", "dog", False, False, True) 

88 

89 # When start_value or end_value is not unique and except_when_tokens_not_unique is False 

90 assert tokens_between(tokens, "the", "dog", False, False, False) == [ 

91 "quick", 

92 "brown", 

93 "fox", 

94 "jumps", 

95 "over", 

96 "the", 

97 "lazy", 

98 ] 

99 

100 # Empty tokens list 

101 with pytest.raises(ValueError): # noqa: PT011 

102 tokens_between([], "start", "end") 

103 

104 # start_value and end_value are the same 

105 with pytest.raises(ValueError): # noqa: PT011 

106 tokens_between(tokens, "fox", "fox") 

107 

108 # start_value or end_value not in the tokens list 

109 with pytest.raises(ValueError): # noqa: PT011 

110 tokens_between(tokens, "start", "end") 

111 

112 # start_value comes after end_value in the tokens list 

113 with pytest.raises(AssertionError): 

114 tokens_between(tokens, "over", "quick") 

115 

116 # start_value and end_value are at the beginning and end of the tokens list, respectively 

117 assert tokens_between(tokens, "the", "dog", True, True) == tokens 

118 

119 # Single element in the tokens list, which is the same as start_value and end_value 

120 with pytest.raises(ValueError): # noqa: PT011 

121 tokens_between(["fox"], "fox", "fox", True, True) 

122 

123 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str):
    """Swapping the start/end delimiters must raise an AssertionError."""
    assert tokenizer_name
    with pytest.raises(AssertionError):
        tokens_between(toks, "<PATH_END>", "<PATH_START>")

139 

140 

141@pytest.mark.parametrize( 

142 ("toks", "tokenizer_name"), 

143 [ 

144 pytest.param( 

145 token_list[0], 

146 token_list[1], 

147 id=f"{token_list[1]}", 

148 ) 

149 for token_list in TEST_TOKEN_LISTS 

150 ], 

151) 

152def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str): 

153 result = get_adj_list_tokens(toks) 

154 match tokenizer_name: 

155 case "AOTP_UT": 

156 expected = ( 

157 "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split() 

158 ) 

159 case "AOTP_CTT_indexed": 

160 expected = "( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ;".split() 

161 assert result == expected 

162 

163 

164@pytest.mark.parametrize( 

165 ("toks", "tokenizer_name"), 

166 [ 

167 pytest.param( 

168 token_list[0], 

169 token_list[1], 

170 id=f"{token_list[1]}", 

171 ) 

172 for token_list in TEST_TOKEN_LISTS 

173 ], 

174) 

175def test_get_path_tokens(toks: list[str], tokenizer_name: str): 

176 result_notrim = get_path_tokens(toks) 

177 result_trim = get_path_tokens(toks, trim_end=True) 

178 match tokenizer_name: 

179 case "AOTP_UT": 

180 assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"] 

181 assert result_trim == ["(1,0)", "(1,1)"] 

182 case "AOTP_CTT_indexed": 

183 assert ( 

184 result_notrim == "<PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split() 

185 ) 

186 assert result_trim == "( 1 , 0 ) ( 1 , 1 )".split() 

187 

188 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_get_origin_tokens(toks: list[str], tokenizer_name: str):
    """`get_origin_tokens` returns only the origin coordinate tokens."""
    origin_toks = get_origin_tokens(toks)
    if tokenizer_name == "AOTP_UT":
        assert origin_toks == ["(1,0)"]
    elif tokenizer_name == "AOTP_CTT_indexed":
        assert origin_toks == "( 1 , 0 )".split()

207 

208 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_get_target_tokens(toks: list[str], tokenizer_name: str):
    """`get_target_tokens` returns only the target coordinate tokens."""
    target_toks = get_target_tokens(toks)
    if tokenizer_name == "AOTP_UT":
        assert target_toks == ["(1,1)"]
    elif tokenizer_name == "AOTP_CTT_indexed":
        assert target_toks == "( 1 , 1 )".split()

227 

228 

229@pytest.mark.parametrize( 

230 ("toks", "tokenizer_name"), 

231 [ 

232 pytest.param( 

233 token_list[0], 

234 token_list[1], 

235 id=f"{token_list[1]}", 

236 ) 

237 for token_list in [MAZE_TOKENS] 

238 ], 

239) 

240def test_get_tokens_up_to_path_start_including_start( 

241 toks: list[str], 

242 tokenizer_name: str, 

243): 

244 # Dont test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`. 

245 result = get_tokens_up_to_path_start(toks, include_start_coord=True) 

246 match tokenizer_name: 

247 case "AOTP_UT": 

248 expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0)".split() 

249 case "AOTP_CTT_indexed": 

250 expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 )".split() 

251 assert result == expected 

252 

253 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_get_tokens_up_to_path_start_excluding_start(
    toks: list[str],
    tokenizer_name: str,
):
    """With `include_start_coord=False` the prompt ends right at <PATH_START>."""
    prompt = get_tokens_up_to_path_start(toks, include_start_coord=False)
    if tokenizer_name == "AOTP_UT":
        expected = "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START>".split()
    elif tokenizer_name == "AOTP_CTT_indexed":
        expected = "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START>".split()
    assert prompt == expected

276 

277 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_strings_to_coords(toks: list[str], tokenizer_name: str):
    """`strings_to_coords` parses coordinate tokens, honoring `when_noncoord`."""
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    skipped = strings_to_coords(adj_list, when_noncoord="skip")
    included = strings_to_coords(adj_list, when_noncoord="include")

    # "skip" drops the non-coordinate tokens ("<-->", ";")
    assert skipped == [(0, 1), (1, 1), (1, 0), (1, 1), (0, 1), (0, 0)]

    # "include" keeps them interleaved, unparsed
    assert included == [
        (0, 1), "<-->", (1, 1), ";",
        (1, 0), "<-->", (1, 1), ";",
        (0, 1), "<-->", (0, 0), ";",
    ]

    # "error" raises on the first non-coordinate token
    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords(adj_list, when_noncoord="error")

    # also accepts a whitespace-separated string directly
    text = "(1,2) <ADJLIST_START> (5,6)"
    assert strings_to_coords(text) == [(1, 2), (5, 6)]
    assert strings_to_coords(text, when_noncoord="skip") == [(1, 2), (5, 6)]
    assert strings_to_coords(text, when_noncoord="include") == [
        (1, 2), "<ADJLIST_START>", (5, 6),
    ]
    with pytest.raises(ValueError):  # noqa: PT011
        strings_to_coords(text, when_noncoord="error")

333 

334 

@pytest.mark.parametrize(
    ("toks", "tokenizer_name"),
    [pytest.param(toks, name, id=name) for toks, name in TEST_TOKEN_LISTS],
)
def test_coords_to_strings(toks: list[str], tokenizer_name: str):
    """`coords_to_strings` serializes coordinates back into UT-style tokens."""
    assert tokenizer_name
    adj_list = get_adj_list_tokens(toks)
    coords = strings_to_coords(adj_list, when_noncoord="include")

    # "skip" drops the non-coordinate entries
    assert coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="skip",
    ) == "(0,1) (1,1) (1,0) (1,1) (0,1) (0,0)".split()

    # "include" passes them through in place
    assert coords_to_strings(
        coords,
        coord_to_strings_func=_coord_to_strings_UT,
        when_noncoord="include",
    ) == "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ;".split()

    # "error" raises on the first non-coordinate entry
    with pytest.raises(ValueError):  # noqa: PT011
        coords_to_strings(
            coords,
            coord_to_strings_func=_coord_to_strings_UT,
            when_noncoord="error",
        )

393 

394 

def test_equal_except_adj_list_sequence():
    """`equal_except_adj_list_sequence` compares token sequences modulo adjacency-list ordering.

    The adjacency list is an unordered collection of connections, so two
    sequences whose <ADJLIST> sections list the connections in a different
    order (or with a connection's endpoints swapped) still compare equal.
    Everything outside the adjacency list must match exactly, and both
    sequences must contain the <ADJLIST_START>/<ADJLIST_END> delimiters.
    """
    # identical sequences compare equal
    assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0])
    # different tokenizations (UT vs CTT) of the same maze are NOT equal
    assert not equal_except_adj_list_sequence(
        MAZE_TOKENS[0],
        MAZE_TOKENS_AOTP_CTT_indexed[0],
    )
    # identical literal sequences compare equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # same connections listed in a different order: still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # a connection's endpoints swapped ((1,1)<-->(0,1) vs (0,1)<-->(1,1)): still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,1) <--> (0,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # the path differs (reversed): not equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (0,1) <--> (1,1) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,1) (1,0) <PATH_END>".split(),
    )
    # trailing duplicated <PATH_END> token: not equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END> <PATH_END>".split(),
    )
    # one side missing <ORIGIN_START>: not equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # one side missing <ADJLIST_START>: not equal
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )
    # BOTH sides missing <ADJLIST_START>: the comparison itself raises
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "(0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    # BOTH sides missing <ADJLIST_END>: also raises
    with pytest.raises(ValueError):  # noqa: PT011
        equal_except_adj_list_sequence(
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
            "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        )
    # only ONE side missing <ADJLIST_END>: returns False rather than raising
    assert not equal_except_adj_list_sequence(
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ADJLIST_END> <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
        "<ADJLIST_START> (0,1) <--> (1,1) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; <ORIGIN_START> (1,0) <ORIGIN_END> <TARGET_START> (1,1) <TARGET_END> <PATH_START> (1,0) (1,1) <PATH_END>".split(),
    )

    # CTT
    # identical CTT sequences compare equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # CTT: swapped endpoints of the first connection: still equal
    assert equal_except_adj_list_sequence(
        "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
        "<ADJLIST_START> ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    )
    # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects.
    # See function documentation for details.
    # assert not equal_except_adj_list_sequence(
    #     "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(),
    #     "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split()
    # )

459 

460 

# @mivanit: this was really difficult to understand
# NOTE(review): each param is a (type_, validation_funcs, assertion) triple —
# see the docstring on `test_all_instances2` for what the two cases exercise.
@pytest.mark.parametrize(
    ("type_", "validation_funcs", "assertion"),
    [
        pytest.param(
            type_,
            vfs,
            assertion,
            id=f"{i}-{type_.__name__}",
        )
        for i, (type_, vfs, assertion) in enumerate(
            [
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    dict(),
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    in x,
                ),
                (
                    # type
                    PathTokenizers._PathTokenizer,
                    # validation_funcs
                    {PathTokenizers._PathTokenizer: lambda x: x.is_valid()},
                    # assertion
                    lambda x: PathTokenizers.StepSequence(
                        step_tokenizers=(StepTokenizers.Distance(),),
                    )
                    not in x
                    and PathTokenizers.StepSequence(
                        step_tokenizers=(
                            StepTokenizers.Coord(),
                            StepTokenizers.Coord(),
                        ),
                    )
                    not in x,
                ),
            ],
        )
    ],
)
def test_all_instances2(
    type_: FiniteValued,
    validation_funcs: frozendict.frozendict[
        FiniteValued,
        Callable[[FiniteValued], bool],
    ],
    assertion: Callable[[list[FiniteValued]], bool],
):
    """`all_instances` enumerates instances of a finite type, filtered by validators.

    Parameters (one triple per case):
    - `type_`: the type whose instances are enumerated (here `_PathTokenizer`)
    - `validation_funcs`: maps a type to a predicate its instances must satisfy
    - `assertion`: predicate applied to the full enumeration result

    Case 0: no validators — a `StepSequence` containing only a `Distance` step
    IS present in the enumeration.
    Case 1: with an `is_valid` validator — that same sequence, and a
    `(Coord, Coord)` sequence, are absent (presumably because `is_valid`
    rejects them — confirm against `_PathTokenizer.is_valid`).
    """
    assert assertion(all_instances(type_, validation_funcs))

515 

516 

@pytest.mark.parametrize(
    ("coords", "result"),
    [
        pytest.param(np.array(coords), expected, id=str(coords))
        for coords, expected in (
            [
                # (prev, cur, next) triples and the expected relative direction
                ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT),
                ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT),
                ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT),
                ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT),
                ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD),
                ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD),
                # non-adjacent or degenerate steps raise
                ([[0, 1], [0, 1], [0, 0]], ValueError),
                ([[0, 1], [1, 1], [0, 0]], ValueError),
                ([[1, 0], [1, 1], [0, 0]], ValueError),
                ([[0, 1], [0, 2], [0, 0]], ValueError),
                ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY),
                ([[1, 1], [0, 0], [0, 1]], ValueError),
                ([[1, 1], [0, 0], [1, 0]], ValueError),
                ([[0, 2], [0, 0], [0, 1]], ValueError),
                ([[0, 0], [0, 0], [0, 1]], ValueError),
                ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD),
                # negative coordinates are fine as long as steps are unit steps
                ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD),
                ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT),
                ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD),
                ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT),
                # wrong number of coords, or wrong coord dimension, raise
                ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError),
                ([[-1, 0], [0, 0]], ValueError),
                ([[-1, 0, 0], [0, 0, 0]], ValueError),
            ]
        )
    ],
)
def test_get_relative_direction(
    coords: Int[np.ndarray, "prev_cur_next=3 axis=2"],
    result: str | type[Exception],
):
    """`get_relative_direction` maps a (prev, cur, next) triple to a turn token."""
    expects_raise = isinstance(result, type) and issubclass(result, Exception)
    if expects_raise:
        with pytest.raises(result):
            get_relative_direction(coords)
    else:
        assert get_relative_direction(coords) == result

566 

567 

@pytest.mark.parametrize(
    ("edges", "result"),
    [
        pytest.param(edges, expected, id=f"{edges}")
        for edges, expected in (
            [
                # single edge: scalar distance
                (np.array([[0, 0], [0, 1]]), 1),
                (np.array([[1, 0], [0, 1]]), 2),
                (np.array([[-1, 0], [0, 1]]), 2),
                (np.array([[0, 0], [5, 3]]), 8),
                # batch of edges: vector of distances
                (
                    np.array(
                        [
                            [[0, 0], [0, 1]],
                            [[1, 0], [0, 1]],
                            [[-1, 0], [0, 1]],
                            [[0, 0], [5, 3]],
                        ],
                    ),
                    [1, 2, 2, 8],
                ),
                # batch of one edge: vector of length one
                (np.array([[[0, 0], [5, 3]]]), [8]),
            ]
        )
    ],
)
def test_manhattan_distance(
    edges: ConnectionArray | Connection,
    result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception],
):
    """`manhattan_distance` handles both a single edge and a batch of edges."""
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            manhattan_distance(edges)
        return
    expected = np.array(result, dtype=np.int8)
    assert np.array_equal(manhattan_distance(edges), expected)

607 

608 

@pytest.mark.parametrize(
    "n",
    [pytest.param(n) for n in [2, 3, 5, 20]],
)
def test_lattice_connection_arrray(n):
    """`lattice_connection_array(n)` yields every edge of an n x n lattice.

    (NOTE: the typo in this test's name is kept to preserve the test id.)
    """
    conns = lattice_connection_array(n)
    n_edges = 2 * n * (n - 1)
    # one entry per lattice edge, each a pair of 2d coordinates
    assert tuple(conns.shape) == (n_edges, 2, 2)
    # each edge is oriented: second endpoint has the larger coordinate sum
    assert np.all(np.sum(conns[:, 1], axis=1) > np.sum(conns[:, 0], axis=1))
    # no duplicate edges
    assert tuple(np.unique(conns, axis=0).shape) == (n_edges, 2, 2)

618 

619 

@pytest.mark.parametrize(
    ("edges", "maze"),
    [
        pytest.param(
            edges(),
            maze,
            id=f"edges[{i}]; maze[{j}]",
        )
        # lambdas defer edge construction so each param gets a fresh rng draw
        for (i, edges), (j, maze) in itertools.product(
            enumerate(
                [
                    lambda: lattice_connection_array(GRID_N),
                    lambda: np.flip(lattice_connection_array(GRID_N), axis=1),
                    lambda: lattice_connection_array(GRID_N - 1),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        2 * GRID_N,
                        axis=0,
                    ),
                    lambda: numpy_rng.choice(
                        lattice_connection_array(GRID_N),
                        1,
                        axis=0,
                    ),
                ],
            ),
            enumerate(MAZE_DATASET.mazes),
        )
    ],
)
def test_is_connection(edges: ConnectionArray, maze: LatticeMaze):
    """`is_connection` must agree with a direct lookup into `maze.connection_list`."""
    output = is_connection(edges, maze.connection_list)

    # order each edge's endpoint pair elementwise so the lookup below is
    # independent of how the edge happened to be oriented
    ordered = np.sort(edges, axis=1)
    deltas = ordered[:, 1, :] - ordered[:, 0, :]
    # zero row-delta selects one connection axis, nonzero the other
    # (assumes connection_list is indexed [direction, row, col] — confirm)
    direction = (deltas[:, 0] == 0).astype(np.int8)
    expected = maze.connection_list[
        direction,
        ordered[:, 0, 0],
        ordered[:, 0, 1],
    ]
    assert np.array_equal(output, expected)