Coverage for tests/unit/tokenization/test_token_utils.py: 97%

175 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-03 21:38 -0700

1import itertools 

2from typing import Callable 

3 

4import frozendict 

5import numpy as np 

6import pytest 

7from jaxtyping import Int 

8 

9from maze_dataset import LatticeMaze 

10from maze_dataset.constants import VOCAB, Connection, ConnectionArray 

11from maze_dataset.generation import _NUMPY_RNG 

12from maze_dataset.testing_utils import GRID_N, MAZE_DATASET 

13from maze_dataset.token_utils import ( 

14 _coord_to_strings_UT, 

15 coords_to_strings, 

16 equal_except_adj_list_sequence, 

17 get_adj_list_tokens, 

18 get_origin_tokens, 

19 get_path_tokens, 

20 get_relative_direction, 

21 get_target_tokens, 

22 is_connection, 

23 strings_to_coords, 

24 tokens_between, 

25) 

26from maze_dataset.tokenization import ( 

27 PathTokenizers, 

28 StepTokenizers, 

29 get_tokens_up_to_path_start, 

30) 

31from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances 

32from maze_dataset.utils import ( 

33 lattice_connection_array, 

34 manhattan_distance, 

35) 

36 

37MAZE_TOKENS: tuple[list[str], str] = ( 

38 [ 

39 "<ADJLIST_START>", 

40 "(0,1)", 

41 "<-->", 

42 "(1,1)", 

43 ";", 

44 "(1,0)", 

45 "<-->", 

46 "(1,1)", 

47 ";", 

48 "(0,1)", 

49 "<-->", 

50 "(0,0)", 

51 ";", 

52 "<ADJLIST_END>", 

53 "<ORIGIN_START>", 

54 "(1,0)", 

55 "<ORIGIN_END>", 

56 "<TARGET_START>", 

57 "(1,1)", 

58 "<TARGET_END>", 

59 "<PATH_START>", 

60 "(1,0)", 

61 "(1,1)", 

62 "<PATH_END>", 

63 ], 

64 "AOTP_UT", 

65) 

66MAZE_TOKENS_AOTP_CTT_indexed: tuple[list[str], str] = ( 

67 [ 

68 "<ADJLIST_START>", 

69 "(", 

70 "0", 

71 ",", 

72 "1", 

73 ")", 

74 "<-->", 

75 "(", 

76 "1", 

77 ",", 

78 "1", 

79 ")", 

80 ";", 

81 "(", 

82 "1", 

83 ",", 

84 "0", 

85 ")", 

86 "<-->", 

87 "(", 

88 "1", 

89 ",", 

90 "1", 

91 ")", 

92 ";", 

93 "(", 

94 "0", 

95 ",", 

96 "1", 

97 ")", 

98 "<-->", 

99 "(", 

100 "0", 

101 ",", 

102 "0", 

103 ")", 

104 ";", 

105 "<ADJLIST_END>", 

106 "<ORIGIN_START>", 

107 "(", 

108 "1", 

109 ",", 

110 "0", 

111 ")", 

112 "<ORIGIN_END>", 

113 "<TARGET_START>", 

114 "(", 

115 "1", 

116 ",", 

117 "1", 

118 ")", 

119 "<TARGET_END>", 

120 "<PATH_START>", 

121 "(", 

122 "1", 

123 ",", 

124 "0", 

125 ")", 

126 "(", 

127 "1", 

128 ",", 

129 "1", 

130 ")", 

131 "<PATH_END>", 

132 ], 

133 "AOTP_CTT_indexed", 

134) 

135TEST_TOKEN_LISTS: list[tuple[list[str], str]] = [ 

136 MAZE_TOKENS, 

137 MAZE_TOKENS_AOTP_CTT_indexed, 

138] 

139 

140 

141@pytest.mark.parametrize( 

142 ("toks", "tokenizer_name"), 

143 [ 

144 pytest.param( 

145 token_list[0], 

146 token_list[1], 

147 id=f"{token_list[1]}", 

148 ) 

149 for token_list in TEST_TOKEN_LISTS 

150 ], 

151) 

152def test_tokens_between(toks: list[str], tokenizer_name: str): 

153 result = tokens_between(toks, "<PATH_START>", "<PATH_END>") 

154 match tokenizer_name: 

155 case "AOTP_UT": 

156 assert result == ["(1,0)", "(1,1)"] 

157 case "AOTP_CTT_indexed": 

158 assert result == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"] 

159 

160 # Normal case 

161 tokens = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"] 

162 start_value = "quick" 

163 end_value = "over" 

164 assert tokens_between(tokens, start_value, end_value) == ["brown", "fox", "jumps"] 

165 

166 # Including start and end values 

167 assert tokens_between(tokens, start_value, end_value, True, True) == [ 

168 "quick", 

169 "brown", 

170 "fox", 

171 "jumps", 

172 "over", 

173 ] 

174 

175 # When start_value or end_value is not unique and except_when_tokens_not_unique is True 

176 with pytest.raises(ValueError): # noqa: PT011 

177 tokens_between(tokens, "the", "dog", False, False, True) 

178 

179 # When start_value or end_value is not unique and except_when_tokens_not_unique is False 

180 assert tokens_between(tokens, "the", "dog", False, False, False) == [ 

181 "quick", 

182 "brown", 

183 "fox", 

184 "jumps", 

185 "over", 

186 "the", 

187 "lazy", 

188 ] 

189 

190 # Empty tokens list 

191 with pytest.raises(ValueError): # noqa: PT011 

192 tokens_between([], "start", "end") 

193 

194 # start_value and end_value are the same 

195 with pytest.raises(ValueError): # noqa: PT011 

196 tokens_between(tokens, "fox", "fox") 

197 

198 # start_value or end_value not in the tokens list 

199 with pytest.raises(ValueError): # noqa: PT011 

200 tokens_between(tokens, "start", "end") 

201 

202 # start_value comes after end_value in the tokens list 

203 with pytest.raises(AssertionError): 

204 tokens_between(tokens, "over", "quick") 

205 

206 # start_value and end_value are at the beginning and end of the tokens list, respectively 

207 assert tokens_between(tokens, "the", "dog", True, True) == tokens 

208 

209 # Single element in the tokens list, which is the same as start_value and end_value 

210 with pytest.raises(ValueError): # noqa: PT011 

211 tokens_between(["fox"], "fox", "fox", True, True) 

212 

213 

214@pytest.mark.parametrize( 

215 ("toks", "tokenizer_name"), 

216 [ 

217 pytest.param( 

218 token_list[0], 

219 token_list[1], 

220 id=f"{token_list[1]}", 

221 ) 

222 for token_list in TEST_TOKEN_LISTS 

223 ], 

224) 

225def test_tokens_between_out_of_order(toks: list[str], tokenizer_name: str): 

226 assert tokenizer_name 

227 with pytest.raises(AssertionError): 

228 tokens_between(toks, "<PATH_END>", "<PATH_START>") 

229 

230 

231@pytest.mark.parametrize( 

232 ("toks", "tokenizer_name"), 

233 [ 

234 pytest.param( 

235 token_list[0], 

236 token_list[1], 

237 id=f"{token_list[1]}", 

238 ) 

239 for token_list in TEST_TOKEN_LISTS 

240 ], 

241) 

242def test_get_adj_list_tokens(toks: list[str], tokenizer_name: str): 

243 result = get_adj_list_tokens(toks) 

244 match tokenizer_name: 

245 case "AOTP_UT": 

246 expected = [ 

247 "(0,1)", 

248 "<-->", 

249 "(1,1)", 

250 ";", 

251 "(1,0)", 

252 "<-->", 

253 "(1,1)", 

254 ";", 

255 "(0,1)", 

256 "<-->", 

257 "(0,0)", 

258 ";", 

259 ] 

260 case "AOTP_CTT_indexed": 

261 expected = [ 

262 "(", 

263 "0", 

264 ",", 

265 "1", 

266 ")", 

267 "<-->", 

268 "(", 

269 "1", 

270 ",", 

271 "1", 

272 ")", 

273 ";", 

274 "(", 

275 "1", 

276 ",", 

277 "0", 

278 ")", 

279 "<-->", 

280 "(", 

281 "1", 

282 ",", 

283 "1", 

284 ")", 

285 ";", 

286 "(", 

287 "0", 

288 ",", 

289 "1", 

290 ")", 

291 "<-->", 

292 "(", 

293 "0", 

294 ",", 

295 "0", 

296 ")", 

297 ";", 

298 ] 

299 assert result == expected 

300 

301 

302@pytest.mark.parametrize( 

303 ("toks", "tokenizer_name"), 

304 [ 

305 pytest.param( 

306 token_list[0], 

307 token_list[1], 

308 id=f"{token_list[1]}", 

309 ) 

310 for token_list in TEST_TOKEN_LISTS 

311 ], 

312) 

313def test_get_path_tokens(toks: list[str], tokenizer_name: str): 

314 result_notrim = get_path_tokens(toks) 

315 result_trim = get_path_tokens(toks, trim_end=True) 

316 match tokenizer_name: 

317 case "AOTP_UT": 

318 assert result_notrim == ["<PATH_START>", "(1,0)", "(1,1)", "<PATH_END>"] 

319 assert result_trim == ["(1,0)", "(1,1)"] 

320 case "AOTP_CTT_indexed": 

321 assert result_notrim == [ 

322 "<PATH_START>", 

323 "(", 

324 "1", 

325 ",", 

326 "0", 

327 ")", 

328 "(", 

329 "1", 

330 ",", 

331 "1", 

332 ")", 

333 "<PATH_END>", 

334 ] 

335 assert result_trim == ["(", "1", ",", "0", ")", "(", "1", ",", "1", ")"] 

336 

337 

338@pytest.mark.parametrize( 

339 ("toks", "tokenizer_name"), 

340 [ 

341 pytest.param( 

342 token_list[0], 

343 token_list[1], 

344 id=f"{token_list[1]}", 

345 ) 

346 for token_list in TEST_TOKEN_LISTS 

347 ], 

348) 

349def test_get_origin_tokens(toks: list[str], tokenizer_name: str): 

350 result = get_origin_tokens(toks) 

351 match tokenizer_name: 

352 case "AOTP_UT": 

353 assert result == ["(1,0)"] 

354 case "AOTP_CTT_indexed": 

355 assert result == ["(", "1", ",", "0", ")"] 

356 

357 

358@pytest.mark.parametrize( 

359 ("toks", "tokenizer_name"), 

360 [ 

361 pytest.param( 

362 token_list[0], 

363 token_list[1], 

364 id=f"{token_list[1]}", 

365 ) 

366 for token_list in TEST_TOKEN_LISTS 

367 ], 

368) 

369def test_get_target_tokens(toks: list[str], tokenizer_name: str): 

370 result = get_target_tokens(toks) 

371 match tokenizer_name: 

372 case "AOTP_UT": 

373 assert result == ["(1,1)"] 

374 case "AOTP_CTT_indexed": 

375 assert result == ["(", "1", ",", "1", ")"] 

376 

377 

378@pytest.mark.parametrize( 

379 ("toks", "tokenizer_name"), 

380 [ 

381 pytest.param( 

382 token_list[0], 

383 token_list[1], 

384 id=f"{token_list[1]}", 

385 ) 

386 for token_list in [MAZE_TOKENS] 

387 ], 

388) 

389def test_get_tokens_up_to_path_start_including_start( 

390 toks: list[str], 

391 tokenizer_name: str, 

392): 

393 # Dont test on `MAZE_TOKENS_AOTP_CTT_indexed` because this function doesn't support `AOTP_CTT_indexed` when `include_start_coord=True`. 

394 result = get_tokens_up_to_path_start(toks, include_start_coord=True) 

395 match tokenizer_name: 

396 case "AOTP_UT": 

397 expected = [ 

398 "<ADJLIST_START>", 

399 "(0,1)", 

400 "<-->", 

401 "(1,1)", 

402 ";", 

403 "(1,0)", 

404 "<-->", 

405 "(1,1)", 

406 ";", 

407 "(0,1)", 

408 "<-->", 

409 "(0,0)", 

410 ";", 

411 "<ADJLIST_END>", 

412 "<ORIGIN_START>", 

413 "(1,0)", 

414 "<ORIGIN_END>", 

415 "<TARGET_START>", 

416 "(1,1)", 

417 "<TARGET_END>", 

418 "<PATH_START>", 

419 "(1,0)", 

420 ] 

421 case "AOTP_CTT_indexed": 

422 expected = [ 

423 "<ADJLIST_START>", 

424 "(", 

425 "0", 

426 ",", 

427 "1", 

428 ")", 

429 "<-->", 

430 "(", 

431 "1", 

432 ",", 

433 "1", 

434 ")", 

435 ";", 

436 "(", 

437 "1", 

438 ",", 

439 "0", 

440 ")", 

441 "<-->", 

442 "(", 

443 "1", 

444 ",", 

445 "1", 

446 ")", 

447 ";", 

448 "(", 

449 "0", 

450 ",", 

451 "1", 

452 ")", 

453 "<-->", 

454 "(", 

455 "0", 

456 ",", 

457 "0", 

458 ")", 

459 ";", 

460 "<ADJLIST_END>", 

461 "<ORIGIN_START>", 

462 "(", 

463 "1", 

464 ",", 

465 "0", 

466 ")", 

467 "<ORIGIN_END>", 

468 "<TARGET_START>", 

469 "(", 

470 "1", 

471 ",", 

472 "1", 

473 ")", 

474 "<TARGET_END>", 

475 "<PATH_START>", 

476 "(", 

477 "1", 

478 ",", 

479 "0", 

480 ")", 

481 ] 

482 assert result == expected 

483 

484 

485@pytest.mark.parametrize( 

486 ("toks", "tokenizer_name"), 

487 [ 

488 pytest.param( 

489 token_list[0], 

490 token_list[1], 

491 id=f"{token_list[1]}", 

492 ) 

493 for token_list in TEST_TOKEN_LISTS 

494 ], 

495) 

496def test_get_tokens_up_to_path_start_excluding_start( 

497 toks: list[str], 

498 tokenizer_name: str, 

499): 

500 result = get_tokens_up_to_path_start(toks, include_start_coord=False) 

501 match tokenizer_name: 

502 case "AOTP_UT": 

503 expected = [ 

504 "<ADJLIST_START>", 

505 "(0,1)", 

506 "<-->", 

507 "(1,1)", 

508 ";", 

509 "(1,0)", 

510 "<-->", 

511 "(1,1)", 

512 ";", 

513 "(0,1)", 

514 "<-->", 

515 "(0,0)", 

516 ";", 

517 "<ADJLIST_END>", 

518 "<ORIGIN_START>", 

519 "(1,0)", 

520 "<ORIGIN_END>", 

521 "<TARGET_START>", 

522 "(1,1)", 

523 "<TARGET_END>", 

524 "<PATH_START>", 

525 ] 

526 case "AOTP_CTT_indexed": 

527 expected = [ 

528 "<ADJLIST_START>", 

529 "(", 

530 "0", 

531 ",", 

532 "1", 

533 ")", 

534 "<-->", 

535 "(", 

536 "1", 

537 ",", 

538 "1", 

539 ")", 

540 ";", 

541 "(", 

542 "1", 

543 ",", 

544 "0", 

545 ")", 

546 "<-->", 

547 "(", 

548 "1", 

549 ",", 

550 "1", 

551 ")", 

552 ";", 

553 "(", 

554 "0", 

555 ",", 

556 "1", 

557 ")", 

558 "<-->", 

559 "(", 

560 "0", 

561 ",", 

562 "0", 

563 ")", 

564 ";", 

565 "<ADJLIST_END>", 

566 "<ORIGIN_START>", 

567 "(", 

568 "1", 

569 ",", 

570 "0", 

571 ")", 

572 "<ORIGIN_END>", 

573 "<TARGET_START>", 

574 "(", 

575 "1", 

576 ",", 

577 "1", 

578 ")", 

579 "<TARGET_END>", 

580 "<PATH_START>", 

581 ] 

582 assert result == expected 

583 

584 

585@pytest.mark.parametrize( 

586 ("toks", "tokenizer_name"), 

587 [ 

588 pytest.param( 

589 token_list[0], 

590 token_list[1], 

591 id=f"{token_list[1]}", 

592 ) 

593 for token_list in TEST_TOKEN_LISTS 

594 ], 

595) 

596def test_strings_to_coords(toks: list[str], tokenizer_name: str): 

597 assert tokenizer_name 

598 adj_list = get_adj_list_tokens(toks) 

599 skipped = strings_to_coords(adj_list, when_noncoord="skip") 

600 included = strings_to_coords(adj_list, when_noncoord="include") 

601 

602 assert skipped == [ 

603 (0, 1), 

604 (1, 1), 

605 (1, 0), 

606 (1, 1), 

607 (0, 1), 

608 (0, 0), 

609 ] 

610 

611 assert included == [ 

612 (0, 1), 

613 "<-->", 

614 (1, 1), 

615 ";", 

616 (1, 0), 

617 "<-->", 

618 (1, 1), 

619 ";", 

620 (0, 1), 

621 "<-->", 

622 (0, 0), 

623 ";", 

624 ] 

625 

626 with pytest.raises(ValueError): # noqa: PT011 

627 strings_to_coords(adj_list, when_noncoord="error") 

628 

629 assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)") == [(1, 2), (5, 6)] 

630 assert strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="skip") == [ 

631 (1, 2), 

632 (5, 6), 

633 ] 

634 assert strings_to_coords( 

635 "(1,2) <ADJLIST_START> (5,6)", 

636 when_noncoord="include", 

637 ) == [(1, 2), "<ADJLIST_START>", (5, 6)] 

638 with pytest.raises(ValueError): # noqa: PT011 

639 strings_to_coords("(1,2) <ADJLIST_START> (5,6)", when_noncoord="error") 

640 

641 

642@pytest.mark.parametrize( 

643 ("toks", "tokenizer_name"), 

644 [ 

645 pytest.param( 

646 token_list[0], 

647 token_list[1], 

648 id=f"{token_list[1]}", 

649 ) 

650 for token_list in TEST_TOKEN_LISTS 

651 ], 

652) 

653def test_coords_to_strings(toks: list[str], tokenizer_name: str): 

654 assert tokenizer_name 

655 adj_list = get_adj_list_tokens(toks) 

656 # config = MazeDatasetConfig(name="test", grid_n=2, n_mazes=1) 

657 coords = strings_to_coords(adj_list, when_noncoord="include") 

658 

659 skipped = coords_to_strings( 

660 coords, 

661 coord_to_strings_func=_coord_to_strings_UT, 

662 when_noncoord="skip", 

663 ) 

664 included = coords_to_strings( 

665 coords, 

666 coord_to_strings_func=_coord_to_strings_UT, 

667 when_noncoord="include", 

668 ) 

669 

670 assert skipped == [ 

671 "(0,1)", 

672 "(1,1)", 

673 "(1,0)", 

674 "(1,1)", 

675 "(0,1)", 

676 "(0,0)", 

677 ] 

678 

679 assert included == [ 

680 "(0,1)", 

681 "<-->", 

682 "(1,1)", 

683 ";", 

684 "(1,0)", 

685 "<-->", 

686 "(1,1)", 

687 ";", 

688 "(0,1)", 

689 "<-->", 

690 "(0,0)", 

691 ";", 

692 ] 

693 

694 with pytest.raises(ValueError): # noqa: PT011 

695 coords_to_strings( 

696 coords, 

697 coord_to_strings_func=_coord_to_strings_UT, 

698 when_noncoord="error", 

699 ) 

700 

701 

702def test_equal_except_adj_list_sequence(): 

703 assert equal_except_adj_list_sequence(MAZE_TOKENS[0], MAZE_TOKENS[0]) 

704 assert not equal_except_adj_list_sequence( 

705 MAZE_TOKENS[0], 

706 MAZE_TOKENS_AOTP_CTT_indexed[0], 

707 ) 

708 assert equal_except_adj_list_sequence( 

709 [ 

710 "<ADJLIST_START>", 

711 "(0,1)", 

712 "<-->", 

713 "(1,1)", 

714 ";", 

715 "(1,0)", 

716 "<-->", 

717 "(1,1)", 

718 ";", 

719 "(0,1)", 

720 "<-->", 

721 "(0,0)", 

722 ";", 

723 "<ADJLIST_END>", 

724 "<ORIGIN_START>", 

725 "(1,0)", 

726 "<ORIGIN_END>", 

727 "<TARGET_START>", 

728 "(1,1)", 

729 "<TARGET_END>", 

730 "<PATH_START>", 

731 "(1,0)", 

732 "(1,1)", 

733 "<PATH_END>", 

734 ], 

735 [ 

736 "<ADJLIST_START>", 

737 "(0,1)", 

738 "<-->", 

739 "(1,1)", 

740 ";", 

741 "(1,0)", 

742 "<-->", 

743 "(1,1)", 

744 ";", 

745 "(0,1)", 

746 "<-->", 

747 "(0,0)", 

748 ";", 

749 "<ADJLIST_END>", 

750 "<ORIGIN_START>", 

751 "(1,0)", 

752 "<ORIGIN_END>", 

753 "<TARGET_START>", 

754 "(1,1)", 

755 "<TARGET_END>", 

756 "<PATH_START>", 

757 "(1,0)", 

758 "(1,1)", 

759 "<PATH_END>", 

760 ], 

761 ) 

762 assert equal_except_adj_list_sequence( 

763 [ 

764 "<ADJLIST_START>", 

765 "(0,1)", 

766 "<-->", 

767 "(1,1)", 

768 ";", 

769 "(1,0)", 

770 "<-->", 

771 "(1,1)", 

772 ";", 

773 "(0,1)", 

774 "<-->", 

775 "(0,0)", 

776 ";", 

777 "<ADJLIST_END>", 

778 "<ORIGIN_START>", 

779 "(1,0)", 

780 "<ORIGIN_END>", 

781 "<TARGET_START>", 

782 "(1,1)", 

783 "<TARGET_END>", 

784 "<PATH_START>", 

785 "(1,0)", 

786 "(1,1)", 

787 "<PATH_END>", 

788 ], 

789 [ 

790 "<ADJLIST_START>", 

791 "(1,0)", 

792 "<-->", 

793 "(1,1)", 

794 ";", 

795 "(0,1)", 

796 "<-->", 

797 "(0,0)", 

798 ";", 

799 "(0,1)", 

800 "<-->", 

801 "(1,1)", 

802 ";", 

803 "<ADJLIST_END>", 

804 "<ORIGIN_START>", 

805 "(1,0)", 

806 "<ORIGIN_END>", 

807 "<TARGET_START>", 

808 "(1,1)", 

809 "<TARGET_END>", 

810 "<PATH_START>", 

811 "(1,0)", 

812 "(1,1)", 

813 "<PATH_END>", 

814 ], 

815 ) 

816 assert equal_except_adj_list_sequence( 

817 [ 

818 "<ADJLIST_START>", 

819 "(0,1)", 

820 "<-->", 

821 "(1,1)", 

822 ";", 

823 "(1,0)", 

824 "<-->", 

825 "(1,1)", 

826 ";", 

827 "(0,1)", 

828 "<-->", 

829 "(0,0)", 

830 ";", 

831 "<ADJLIST_END>", 

832 "<ORIGIN_START>", 

833 "(1,0)", 

834 "<ORIGIN_END>", 

835 "<TARGET_START>", 

836 "(1,1)", 

837 "<TARGET_END>", 

838 "<PATH_START>", 

839 "(1,0)", 

840 "(1,1)", 

841 "<PATH_END>", 

842 ], 

843 [ 

844 "<ADJLIST_START>", 

845 "(1,1)", 

846 "<-->", 

847 "(0,1)", 

848 ";", 

849 "(1,0)", 

850 "<-->", 

851 "(1,1)", 

852 ";", 

853 "(0,1)", 

854 "<-->", 

855 "(0,0)", 

856 ";", 

857 "<ADJLIST_END>", 

858 "<ORIGIN_START>", 

859 "(1,0)", 

860 "<ORIGIN_END>", 

861 "<TARGET_START>", 

862 "(1,1)", 

863 "<TARGET_END>", 

864 "<PATH_START>", 

865 "(1,0)", 

866 "(1,1)", 

867 "<PATH_END>", 

868 ], 

869 ) 

870 assert not equal_except_adj_list_sequence( 

871 [ 

872 "<ADJLIST_START>", 

873 "(0,1)", 

874 "<-->", 

875 "(1,1)", 

876 ";", 

877 "(1,0)", 

878 "<-->", 

879 "(1,1)", 

880 ";", 

881 "(0,1)", 

882 "<-->", 

883 "(0,0)", 

884 ";", 

885 "<ADJLIST_END>", 

886 "<ORIGIN_START>", 

887 "(1,0)", 

888 "<ORIGIN_END>", 

889 "<TARGET_START>", 

890 "(1,1)", 

891 "<TARGET_END>", 

892 "<PATH_START>", 

893 "(1,0)", 

894 "(1,1)", 

895 "<PATH_END>", 

896 ], 

897 [ 

898 "<ADJLIST_START>", 

899 "(1,0)", 

900 "<-->", 

901 "(1,1)", 

902 ";", 

903 "(0,1)", 

904 "<-->", 

905 "(0,0)", 

906 ";", 

907 "(0,1)", 

908 "<-->", 

909 "(1,1)", 

910 ";", 

911 "<ADJLIST_END>", 

912 "<ORIGIN_START>", 

913 "(1,0)", 

914 "<ORIGIN_END>", 

915 "<TARGET_START>", 

916 "(1,1)", 

917 "<TARGET_END>", 

918 "<PATH_START>", 

919 "(1,1)", 

920 "(1,0)", 

921 "<PATH_END>", 

922 ], 

923 ) 

924 assert not equal_except_adj_list_sequence( 

925 [ 

926 "<ADJLIST_START>", 

927 "(0,1)", 

928 "<-->", 

929 "(1,1)", 

930 ";", 

931 "(1,0)", 

932 "<-->", 

933 "(1,1)", 

934 ";", 

935 "(0,1)", 

936 "<-->", 

937 "(0,0)", 

938 ";", 

939 "<ADJLIST_END>", 

940 "<ORIGIN_START>", 

941 "(1,0)", 

942 "<ORIGIN_END>", 

943 "<TARGET_START>", 

944 "(1,1)", 

945 "<TARGET_END>", 

946 "<PATH_START>", 

947 "(1,0)", 

948 "(1,1)", 

949 "<PATH_END>", 

950 ], 

951 [ 

952 "<ADJLIST_START>", 

953 "(0,1)", 

954 "<-->", 

955 "(1,1)", 

956 ";", 

957 "(1,0)", 

958 "<-->", 

959 "(1,1)", 

960 ";", 

961 "(0,1)", 

962 "<-->", 

963 "(0,0)", 

964 ";", 

965 "<ADJLIST_END>", 

966 "<ORIGIN_START>", 

967 "(1,0)", 

968 "<ORIGIN_END>", 

969 "<TARGET_START>", 

970 "(1,1)", 

971 "<TARGET_END>", 

972 "<PATH_START>", 

973 "(1,0)", 

974 "(1,1)", 

975 "<PATH_END>", 

976 "<PATH_END>", 

977 ], 

978 ) 

979 assert not equal_except_adj_list_sequence( 

980 [ 

981 "<ADJLIST_START>", 

982 "(0,1)", 

983 "<-->", 

984 "(1,1)", 

985 ";", 

986 "(1,0)", 

987 "<-->", 

988 "(1,1)", 

989 ";", 

990 "(0,1)", 

991 "<-->", 

992 "(0,0)", 

993 ";", 

994 "<ADJLIST_END>", 

995 "(1,0)", 

996 "<ORIGIN_END>", 

997 "<TARGET_START>", 

998 "(1,1)", 

999 "<TARGET_END>", 

1000 "<PATH_START>", 

1001 "(1,0)", 

1002 "(1,1)", 

1003 "<PATH_END>", 

1004 ], 

1005 [ 

1006 "<ADJLIST_START>", 

1007 "(0,1)", 

1008 "<-->", 

1009 "(1,1)", 

1010 ";", 

1011 "(1,0)", 

1012 "<-->", 

1013 "(1,1)", 

1014 ";", 

1015 "(0,1)", 

1016 "<-->", 

1017 "(0,0)", 

1018 ";", 

1019 "<ADJLIST_END>", 

1020 "<ORIGIN_START>", 

1021 "(1,0)", 

1022 "<ORIGIN_END>", 

1023 "<TARGET_START>", 

1024 "(1,1)", 

1025 "<TARGET_END>", 

1026 "<PATH_START>", 

1027 "(1,0)", 

1028 "(1,1)", 

1029 "<PATH_END>", 

1030 ], 

1031 ) 

1032 assert not equal_except_adj_list_sequence( 

1033 [ 

1034 "<ADJLIST_START>", 

1035 "(0,1)", 

1036 "<-->", 

1037 "(1,1)", 

1038 ";", 

1039 "(1,0)", 

1040 "<-->", 

1041 "(1,1)", 

1042 ";", 

1043 "(0,1)", 

1044 "<-->", 

1045 "(0,0)", 

1046 ";", 

1047 "<ADJLIST_END>", 

1048 "<ORIGIN_START>", 

1049 "(1,0)", 

1050 "<ORIGIN_END>", 

1051 "<TARGET_START>", 

1052 "(1,1)", 

1053 "<TARGET_END>", 

1054 "<PATH_START>", 

1055 "(1,0)", 

1056 "(1,1)", 

1057 "<PATH_END>", 

1058 ], 

1059 [ 

1060 "(0,1)", 

1061 "<-->", 

1062 "(1,1)", 

1063 ";", 

1064 "(1,0)", 

1065 "<-->", 

1066 "(1,1)", 

1067 ";", 

1068 "(0,1)", 

1069 "<-->", 

1070 "(0,0)", 

1071 ";", 

1072 "<ADJLIST_END>", 

1073 "<ORIGIN_START>", 

1074 "(1,0)", 

1075 "<ORIGIN_END>", 

1076 "<TARGET_START>", 

1077 "(1,1)", 

1078 "<TARGET_END>", 

1079 "<PATH_START>", 

1080 "(1,0)", 

1081 "(1,1)", 

1082 "<PATH_END>", 

1083 ], 

1084 ) 

1085 with pytest.raises(ValueError): # noqa: PT011 

1086 equal_except_adj_list_sequence( 

1087 [ 

1088 "(0,1)", 

1089 "<-->", 

1090 "(1,1)", 

1091 ";", 

1092 "(1,0)", 

1093 "<-->", 

1094 "(1,1)", 

1095 ";", 

1096 "(0,1)", 

1097 "<-->", 

1098 "(0,0)", 

1099 ";", 

1100 "<ADJLIST_END>", 

1101 "<ORIGIN_START>", 

1102 "(1,0)", 

1103 "<ORIGIN_END>", 

1104 "<TARGET_START>", 

1105 "(1,1)", 

1106 "<TARGET_END>", 

1107 "<PATH_START>", 

1108 "(1,0)", 

1109 "(1,1)", 

1110 "<PATH_END>", 

1111 ], 

1112 [ 

1113 "(0,1)", 

1114 "<-->", 

1115 "(1,1)", 

1116 ";", 

1117 "(1,0)", 

1118 "<-->", 

1119 "(1,1)", 

1120 ";", 

1121 "(0,1)", 

1122 "<-->", 

1123 "(0,0)", 

1124 ";", 

1125 "<ADJLIST_END>", 

1126 "<ORIGIN_START>", 

1127 "(1,0)", 

1128 "<ORIGIN_END>", 

1129 "<TARGET_START>", 

1130 "(1,1)", 

1131 "<TARGET_END>", 

1132 "<PATH_START>", 

1133 "(1,0)", 

1134 "(1,1)", 

1135 "<PATH_END>", 

1136 ], 

1137 ) 

1138 with pytest.raises(ValueError): # noqa: PT011 

1139 equal_except_adj_list_sequence( 

1140 [ 

1141 "<ADJLIST_START>", 

1142 "(0,1)", 

1143 "<-->", 

1144 "(1,1)", 

1145 ";", 

1146 "(1,0)", 

1147 "<-->", 

1148 "(1,1)", 

1149 ";", 

1150 "(0,1)", 

1151 "<-->", 

1152 "(0,0)", 

1153 ";", 

1154 "<ORIGIN_START>", 

1155 "(1,0)", 

1156 "<ORIGIN_END>", 

1157 "<TARGET_START>", 

1158 "(1,1)", 

1159 "<TARGET_END>", 

1160 "<PATH_START>", 

1161 "(1,0)", 

1162 "(1,1)", 

1163 "<PATH_END>", 

1164 ], 

1165 [ 

1166 "<ADJLIST_START>", 

1167 "(0,1)", 

1168 "<-->", 

1169 "(1,1)", 

1170 ";", 

1171 "(1,0)", 

1172 "<-->", 

1173 "(1,1)", 

1174 ";", 

1175 "(0,1)", 

1176 "<-->", 

1177 "(0,0)", 

1178 ";", 

1179 "<ORIGIN_START>", 

1180 "(1,0)", 

1181 "<ORIGIN_END>", 

1182 "<TARGET_START>", 

1183 "(1,1)", 

1184 "<TARGET_END>", 

1185 "<PATH_START>", 

1186 "(1,0)", 

1187 "(1,1)", 

1188 "<PATH_END>", 

1189 ], 

1190 ) 

1191 assert not equal_except_adj_list_sequence( 

1192 [ 

1193 "<ADJLIST_START>", 

1194 "(0,1)", 

1195 "<-->", 

1196 "(1,1)", 

1197 ";", 

1198 "(1,0)", 

1199 "<-->", 

1200 "(1,1)", 

1201 ";", 

1202 "(0,1)", 

1203 "<-->", 

1204 "(0,0)", 

1205 ";", 

1206 "<ADJLIST_END>", 

1207 "<ORIGIN_START>", 

1208 "(1,0)", 

1209 "<ORIGIN_END>", 

1210 "<TARGET_START>", 

1211 "(1,1)", 

1212 "<TARGET_END>", 

1213 "<PATH_START>", 

1214 "(1,0)", 

1215 "(1,1)", 

1216 "<PATH_END>", 

1217 ], 

1218 [ 

1219 "<ADJLIST_START>", 

1220 "(0,1)", 

1221 "<-->", 

1222 "(1,1)", 

1223 ";", 

1224 "(1,0)", 

1225 "<-->", 

1226 "(1,1)", 

1227 ";", 

1228 "(0,1)", 

1229 "<-->", 

1230 "(0,0)", 

1231 ";", 

1232 "<ORIGIN_START>", 

1233 "(1,0)", 

1234 "<ORIGIN_END>", 

1235 "<TARGET_START>", 

1236 "(1,1)", 

1237 "<TARGET_END>", 

1238 "<PATH_START>", 

1239 "(1,0)", 

1240 "(1,1)", 

1241 "<PATH_END>", 

1242 ], 

1243 ) 

1244 

1245 # CTT 

1246 assert equal_except_adj_list_sequence( 

1247 [ 

1248 "<ADJLIST_START>", 

1249 "(", 

1250 "0", 

1251 ",", 

1252 "1", 

1253 ")", 

1254 "<-->", 

1255 "(", 

1256 "1", 

1257 ",", 

1258 "1", 

1259 ")", 

1260 ";", 

1261 "(", 

1262 "1", 

1263 ",", 

1264 "0", 

1265 ")", 

1266 "<-->", 

1267 "(", 

1268 "1", 

1269 ",", 

1270 "1", 

1271 ")", 

1272 ";", 

1273 "(", 

1274 "0", 

1275 ",", 

1276 "1", 

1277 ")", 

1278 "<-->", 

1279 "(", 

1280 "0", 

1281 ",", 

1282 "0", 

1283 ")", 

1284 ";", 

1285 "<ADJLIST_END>", 

1286 "<ORIGIN_START>", 

1287 "(", 

1288 "1", 

1289 ",", 

1290 "0", 

1291 ")", 

1292 "<ORIGIN_END>", 

1293 "<TARGET_START>", 

1294 "(", 

1295 "1", 

1296 ",", 

1297 "1", 

1298 ")", 

1299 "<TARGET_END>", 

1300 "<PATH_START>", 

1301 "(", 

1302 "1", 

1303 ",", 

1304 "0", 

1305 ")", 

1306 "(", 

1307 "1", 

1308 ",", 

1309 "1", 

1310 ")", 

1311 "<PATH_END>", 

1312 ], 

1313 [ 

1314 "<ADJLIST_START>", 

1315 "(", 

1316 "0", 

1317 ",", 

1318 "1", 

1319 ")", 

1320 "<-->", 

1321 "(", 

1322 "1", 

1323 ",", 

1324 "1", 

1325 ")", 

1326 ";", 

1327 "(", 

1328 "1", 

1329 ",", 

1330 "0", 

1331 ")", 

1332 "<-->", 

1333 "(", 

1334 "1", 

1335 ",", 

1336 "1", 

1337 ")", 

1338 ";", 

1339 "(", 

1340 "0", 

1341 ",", 

1342 "1", 

1343 ")", 

1344 "<-->", 

1345 "(", 

1346 "0", 

1347 ",", 

1348 "0", 

1349 ")", 

1350 ";", 

1351 "<ADJLIST_END>", 

1352 "<ORIGIN_START>", 

1353 "(", 

1354 "1", 

1355 ",", 

1356 "0", 

1357 ")", 

1358 "<ORIGIN_END>", 

1359 "<TARGET_START>", 

1360 "(", 

1361 "1", 

1362 ",", 

1363 "1", 

1364 ")", 

1365 "<TARGET_END>", 

1366 "<PATH_START>", 

1367 "(", 

1368 "1", 

1369 ",", 

1370 "0", 

1371 ")", 

1372 "(", 

1373 "1", 

1374 ",", 

1375 "1", 

1376 ")", 

1377 "<PATH_END>", 

1378 ], 

1379 ) 

1380 assert equal_except_adj_list_sequence( 

1381 [ 

1382 "<ADJLIST_START>", 

1383 "(", 

1384 "0", 

1385 ",", 

1386 "1", 

1387 ")", 

1388 "<-->", 

1389 "(", 

1390 "1", 

1391 ",", 

1392 "1", 

1393 ")", 

1394 ";", 

1395 "(", 

1396 "1", 

1397 ",", 

1398 "0", 

1399 ")", 

1400 "<-->", 

1401 "(", 

1402 "1", 

1403 ",", 

1404 "1", 

1405 ")", 

1406 ";", 

1407 "(", 

1408 "0", 

1409 ",", 

1410 "1", 

1411 ")", 

1412 "<-->", 

1413 "(", 

1414 "0", 

1415 ",", 

1416 "0", 

1417 ")", 

1418 ";", 

1419 "<ADJLIST_END>", 

1420 "<ORIGIN_START>", 

1421 "(", 

1422 "1", 

1423 ",", 

1424 "0", 

1425 ")", 

1426 "<ORIGIN_END>", 

1427 "<TARGET_START>", 

1428 "(", 

1429 "1", 

1430 ",", 

1431 "1", 

1432 ")", 

1433 "<TARGET_END>", 

1434 "<PATH_START>", 

1435 "(", 

1436 "1", 

1437 ",", 

1438 "0", 

1439 ")", 

1440 "(", 

1441 "1", 

1442 ",", 

1443 "1", 

1444 ")", 

1445 "<PATH_END>", 

1446 ], 

1447 [ 

1448 "<ADJLIST_START>", 

1449 "(", 

1450 "1", 

1451 ",", 

1452 "1", 

1453 ")", 

1454 "<-->", 

1455 "(", 

1456 "0", 

1457 ",", 

1458 "1", 

1459 ")", 

1460 ";", 

1461 "(", 

1462 "1", 

1463 ",", 

1464 "0", 

1465 ")", 

1466 "<-->", 

1467 "(", 

1468 "1", 

1469 ",", 

1470 "1", 

1471 ")", 

1472 ";", 

1473 "(", 

1474 "0", 

1475 ",", 

1476 "1", 

1477 ")", 

1478 "<-->", 

1479 "(", 

1480 "0", 

1481 ",", 

1482 "0", 

1483 ")", 

1484 ";", 

1485 "<ADJLIST_END>", 

1486 "<ORIGIN_START>", 

1487 "(", 

1488 "1", 

1489 ",", 

1490 "0", 

1491 ")", 

1492 "<ORIGIN_END>", 

1493 "<TARGET_START>", 

1494 "(", 

1495 "1", 

1496 ",", 

1497 "1", 

1498 ")", 

1499 "<TARGET_END>", 

1500 "<PATH_START>", 

1501 "(", 

1502 "1", 

1503 ",", 

1504 "0", 

1505 ")", 

1506 "(", 

1507 "1", 

1508 ",", 

1509 "1", 

1510 ")", 

1511 "<PATH_END>", 

1512 ], 

1513 ) 

1514 # This inactive test demonstrates the lack of robustness of the function for comparing source `LatticeMaze` objects. 

1515 # See function documentation for details. 

1516 # assert not equal_except_adj_list_sequence( 

1517 # "<ADJLIST_START> ( 0 , 1 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split(), 

1518 # "<ADJLIST_START> ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 1 , 0 ) <--> ( 1 , 1 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 0 ) <ORIGIN_END> <TARGET_START> ( 1 , 1 ) <TARGET_END> <PATH_START> ( 1 , 0 ) ( 1 , 1 ) <PATH_END>".split() 

1519 # ) 

1520 

1521 

1522# @mivanit: this was really difficult to understand 

1523@pytest.mark.parametrize( 

1524 ("type_", "validation_funcs", "assertion"), 

1525 [ 

1526 pytest.param( 

1527 type_, 

1528 vfs, 

1529 assertion, 

1530 id=f"{i}-{type_.__name__}", 

1531 ) 

1532 for i, (type_, vfs, assertion) in enumerate( 

1533 [ 

1534 ( 

1535 # type 

1536 PathTokenizers._PathTokenizer, 

1537 # validation_funcs 

1538 dict(), 

1539 # assertion 

1540 lambda x: PathTokenizers.StepSequence( 

1541 step_tokenizers=(StepTokenizers.Distance(),), 

1542 ) 

1543 in x, 

1544 ), 

1545 ( 

1546 # type 

1547 PathTokenizers._PathTokenizer, 

1548 # validation_funcs 

1549 {PathTokenizers._PathTokenizer: lambda x: x.is_valid()}, 

1550 # assertion 

1551 lambda x: PathTokenizers.StepSequence( 

1552 step_tokenizers=(StepTokenizers.Distance(),), 

1553 ) 

1554 not in x 

1555 and PathTokenizers.StepSequence( 

1556 step_tokenizers=( 

1557 StepTokenizers.Coord(), 

1558 StepTokenizers.Coord(), 

1559 ), 

1560 ) 

1561 not in x, 

1562 ), 

1563 ], 

1564 ) 

1565 ], 

1566) 

1567def test_all_instances2( 

1568 type_: FiniteValued, 

1569 validation_funcs: frozendict.frozendict[ 

1570 FiniteValued, 

1571 Callable[[FiniteValued], bool], 

1572 ], 

1573 assertion: Callable[[list[FiniteValued]], bool], 

1574): 

1575 assert assertion(all_instances(type_, validation_funcs)) 

1576 

1577 

1578@pytest.mark.parametrize( 

1579 ("coords", "result"), 

1580 [ 

1581 pytest.param( 

1582 np.array(coords), 

1583 res, 

1584 id=f"{coords}", 

1585 ) 

1586 for coords, res in ( 

1587 [ 

1588 ([[0, 0], [0, 1], [1, 1]], VOCAB.PATH_RIGHT), 

1589 ([[0, 0], [1, 0], [1, 1]], VOCAB.PATH_LEFT), 

1590 ([[0, 0], [0, 1], [0, 2]], VOCAB.PATH_FORWARD), 

1591 ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD), 

1592 ([[0, 0], [0, 1], [0, 1]], VOCAB.PATH_STAY), 

1593 ([[1, 1], [0, 1], [0, 0]], VOCAB.PATH_LEFT), 

1594 ([[1, 1], [1, 0], [0, 0]], VOCAB.PATH_RIGHT), 

1595 ([[0, 2], [0, 1], [0, 0]], VOCAB.PATH_FORWARD), 

1596 ([[0, 0], [0, 1], [0, 0]], VOCAB.PATH_BACKWARD), 

1597 ([[0, 1], [0, 1], [0, 0]], ValueError), 

1598 ([[0, 1], [1, 1], [0, 0]], ValueError), 

1599 ([[1, 0], [1, 1], [0, 0]], ValueError), 

1600 ([[0, 1], [0, 2], [0, 0]], ValueError), 

1601 ([[0, 1], [0, 0], [0, 0]], VOCAB.PATH_STAY), 

1602 ([[1, 1], [0, 0], [0, 1]], ValueError), 

1603 ([[1, 1], [0, 0], [1, 0]], ValueError), 

1604 ([[0, 2], [0, 0], [0, 1]], ValueError), 

1605 ([[0, 0], [0, 0], [0, 1]], ValueError), 

1606 ([[0, 1], [0, 0], [0, 1]], VOCAB.PATH_BACKWARD), 

1607 ([[-1, 0], [0, 0], [1, 0]], VOCAB.PATH_FORWARD), 

1608 ([[-1, 0], [0, 0], [0, 1]], VOCAB.PATH_LEFT), 

1609 ([[-1, 0], [0, 0], [-1, 0]], VOCAB.PATH_BACKWARD), 

1610 ([[-1, 0], [0, 0], [0, -1]], VOCAB.PATH_RIGHT), 

1611 ([[-1, 0], [0, 0], [1, 0], [2, 0]], ValueError), 

1612 ([[-1, 0], [0, 0]], ValueError), 

1613 ([[-1, 0, 0], [0, 0, 0]], ValueError), 

1614 ] 

1615 ) 

1616 ], 

1617) 

1618def test_get_relative_direction( 

1619 coords: Int[np.ndarray, "prev_cur_next=3 axis=2"], 

1620 result: str | type[Exception], 

1621): 

1622 if isinstance(result, type) and issubclass(result, Exception): 

1623 with pytest.raises(result): 

1624 get_relative_direction(coords) 

1625 return 

1626 assert get_relative_direction(coords) == result 

1627 

1628 

1629@pytest.mark.parametrize( 

1630 ("edges", "result"), 

1631 [ 

1632 pytest.param( 

1633 edges, 

1634 res, 

1635 id=f"{edges}", 

1636 ) 

1637 for edges, res in ( 

1638 [ 

1639 (np.array([[0, 0], [0, 1]]), 1), 

1640 (np.array([[1, 0], [0, 1]]), 2), 

1641 (np.array([[-1, 0], [0, 1]]), 2), 

1642 (np.array([[0, 0], [5, 3]]), 8), 

1643 ( 

1644 np.array( 

1645 [ 

1646 [[0, 0], [0, 1]], 

1647 [[1, 0], [0, 1]], 

1648 [[-1, 0], [0, 1]], 

1649 [[0, 0], [5, 3]], 

1650 ], 

1651 ), 

1652 [1, 2, 2, 8], 

1653 ), 

1654 (np.array([[[0, 0], [5, 3]]]), [8]), 

1655 ] 

1656 ) 

1657 ], 

1658) 

1659def test_manhattan_distance( 

1660 edges: ConnectionArray | Connection, 

1661 result: Int[np.ndarray, " edges"] | Int[np.ndarray, ""] | type[Exception], 

1662): 

1663 if isinstance(result, type) and issubclass(result, Exception): 

1664 with pytest.raises(result): 

1665 manhattan_distance(edges) 

1666 return 

1667 assert np.array_equal(manhattan_distance(edges), np.array(result, dtype=np.int8)) 

1668 

1669 

1670@pytest.mark.parametrize( 

1671 "n", 

1672 [pytest.param(n) for n in [2, 3, 5, 20]], 

1673) 

1674def test_lattice_connection_arrray(n): 

1675 edges = lattice_connection_array(n) 

1676 assert tuple(edges.shape) == (2 * n * (n - 1), 2, 2) 

1677 assert np.all(np.sum(edges[:, 1], axis=1) > np.sum(edges[:, 0], axis=1)) 

1678 assert tuple(np.unique(edges, axis=0).shape) == (2 * n * (n - 1), 2, 2) 

1679 

1680 

1681@pytest.mark.parametrize( 

1682 ("edges", "maze"), 

1683 [ 

1684 pytest.param( 

1685 edges(), 

1686 maze, 

1687 id=f"edges[{i}]; maze[{j}]", 

1688 ) 

1689 for (i, edges), (j, maze) in itertools.product( 

1690 enumerate( 

1691 [ 

1692 lambda: lattice_connection_array(GRID_N), 

1693 lambda: np.flip(lattice_connection_array(GRID_N), axis=1), 

1694 lambda: lattice_connection_array(GRID_N - 1), 

1695 lambda: _NUMPY_RNG.choice( 

1696 lattice_connection_array(GRID_N), 

1697 2 * GRID_N, 

1698 axis=0, 

1699 ), 

1700 lambda: _NUMPY_RNG.choice( 

1701 lattice_connection_array(GRID_N), 

1702 1, 

1703 axis=0, 

1704 ), 

1705 ], 

1706 ), 

1707 enumerate(MAZE_DATASET.mazes), 

1708 ) 

1709 ], 

1710) 

1711def test_is_connection(edges: ConnectionArray, maze: LatticeMaze): 

1712 output = is_connection(edges, maze.connection_list) 

1713 sorted_edges = np.sort(edges, axis=1) 

1714 edge_direction = ( 

1715 (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0 

1716 ).astype(np.int8) 

1717 assert np.array_equal( 

1718 output, 

1719 maze.connection_list[ 

1720 edge_direction, 

1721 sorted_edges[:, 0, 0], 

1722 sorted_edges[:, 0, 1], 

1723 ], 

1724 )