Coverage for tests/unit/generation/test_generators.py: 84%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1import warnings 

2 

3import numpy as np 

4import pytest 

5 

6from maze_dataset.generation.generators import ( 

7 GENERATORS_MAP, 

8 LatticeMazeGenerators, 

9 get_maze_with_solution, 

10) 

11from maze_dataset.maze import Coord, SolvedMaze 

12 

13 

def test_gen_dfs_square():
    """DFS generation on a square 3x3 grid yields a (2, 3, 3) connection list."""
    grid_shape: Coord = np.array([3, 3])
    expected_shape = (2, 3, 3)

    generated = LatticeMazeGenerators.gen_dfs(grid_shape)

    assert generated.connection_list.shape == expected_shape

19 

20 

def test_gen_dfs_oblong():
    """DFS generation on a non-square 3x4 grid yields a (2, 3, 4) connection list."""
    grid_shape: Coord = np.array([3, 4])
    expected_shape = (2, 3, 4)

    generated = LatticeMazeGenerators.gen_dfs(grid_shape)

    assert generated.connection_list.shape == expected_shape

26 

27 

@pytest.mark.parametrize("gfunc_name", GENERATORS_MAP.keys())
def test_get_maze_with_solution(gfunc_name):
    """Check `get_maze_with_solution` on a 5x5 grid for each key of `GENERATORS_MAP`.

    The resulting maze must have a `(2, 5, 5)` connection list, and the first
    and last entries of its solution must each have length 2.

    A `ValueError` is tolerated only for `"gen_percolation"`: in that case a
    warning is emitted and the test ends early. A `ValueError` from any other
    generator is re-raised and fails the test.
    """
    five_by_five: Coord = np.array([5, 5])

    try:
        maze: SolvedMaze = get_maze_with_solution(gfunc_name, five_by_five)
    except ValueError:
        if gfunc_name != "gen_percolation":
            # Bare `raise` keeps the original traceback intact.
            raise
        warnings.warn(
            f"Skipping test for {gfunc_name} because percolation is stochastic, and a connected component might not be found",
        )
        # No maze was produced, so there is nothing left to check. Returning
        # here prevents the assertions below from hitting an unbound `maze`
        # (UnboundLocalError), which would fail the test we meant to skip.
        return

    assert maze.connection_list.shape == (2, 5, 5)
    assert len(maze.solution[0]) == 2
    assert len(maze.solution[-1]) == 2