Coverage for tests/unit/dataset/test_rasterized.py: 100%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1from itertools import product 

2 

3import numpy as np 

4import pytest 

5 

6from maze_dataset import LatticeMazeGenerators, MazeDatasetConfig 

7from maze_dataset.dataset.maze_dataset import MazeDataset 

8from maze_dataset.dataset.rasterized import ( 

9 RasterizedMazeDataset, 

10 RasterizedMazeDatasetConfig, 

11 make_numpy_collection, 

12) 

13 

14_PARAMTETRIZATION = ( 

15 "remove_isolated_cells, extend_pixels, endpoints_as_open", 

16 list(product([True, False], repeat=3)), 

17) 

18 

19 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_new(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Build a ``RasterizedMazeDataset`` directly from a rasterized config.

    Runs once per combination of the three boolean rasterization flags. The
    percolation generator is used so that some generated mazes contain
    isolated cells, which exercises ``remove_isolated_cells``.
    """
    cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs={"p": 0.4},
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )
    dataset: RasterizedMazeDataset = RasterizedMazeDataset.from_config_augmented(
        cfg,
        load_local=False,
    )

    # kept for debugging output when the test fails
    print(f"{dataset[0][0].shape = }, {dataset[0][1].shape = }")
    print(f"{dataset[0][1] = }\n{dataset[1][1] = }")

    # items 0 and 1 are read below, so the dataset must hold at least two
    assert len(dataset) >= 2
    for idx in (0, 1):
        item = dataset[idx]
        # NOTE(review): assumes each item is an (input, target) pair rendered
        # on the same pixel grid -- confirm against RasterizedMazeDataset.__getitem__
        assert item[0].shape == item[1].shape

39 

40 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_from_mazedataset(
    remove_isolated_cells,
    extend_pixels,
    endpoints_as_open,
):
    """Convert a plain ``MazeDataset`` into a ``RasterizedMazeDataset``.

    The base dataset is generated with the percolation generator (so some
    mazes contain isolated cells). The three rasterization flags are passed
    through ``added_params``. The test checks that the conversion keeps every
    maze and records the requested flags on the resulting config.
    """
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs={"p": 0.4},
    )
    dataset_m: MazeDataset = MazeDataset.from_config(cfg, load_local=False)
    dataset_r: RasterizedMazeDataset = RasterizedMazeDataset.from_base_MazeDataset(
        dataset_m,
        added_params={
            "remove_isolated_cells": remove_isolated_cells,
            "extend_pixels": extend_pixels,
            "endpoints_as_open": endpoints_as_open,
        },
    )

    assert dataset_r
    # conversion should not drop or duplicate mazes
    assert len(dataset_r) == len(dataset_m)
    # NOTE(review): assumes added_params are stored on the rasterized config
    # under the same names used as RasterizedMazeDatasetConfig kwargs -- confirm
    assert dataset_r.cfg.remove_isolated_cells == remove_isolated_cells
    assert dataset_r.cfg.extend_pixels == extend_pixels
    assert dataset_r.cfg.endpoints_as_open == endpoints_as_open

65 

66 

@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_make_numpy_collection(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Check the structure returned by ``make_numpy_collection``.

    One rasterized dataset is requested per entry of ``grid_sizes``. The result
    must be a dict whose ``"configs"`` sub-dict maps str keys to
    ``RasterizedMazeDatasetConfig`` and whose ``"arrays"`` sub-dict maps str
    keys to ``np.ndarray``, with one entry per requested grid size.
    """
    cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs={"p": 0.4},
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )
    # expected counts below are derived from this list, so they cannot drift
    grid_sizes = [2, 3]

    output = make_numpy_collection(
        base_cfg=cfg,
        grid_sizes=grid_sizes,
        from_config_kwargs={"load_local": False},
        verbose=True,
    )

    assert isinstance(output, dict)
    configs = output["configs"]
    arrays = output["arrays"]
    assert isinstance(configs, dict)
    assert isinstance(arrays, dict)

    # one config and one array per requested grid size
    assert len(configs) == len(grid_sizes)
    assert len(arrays) == len(grid_sizes)
    # NOTE(review): assumes both sub-dicts share one key per grid size -- confirm
    # against make_numpy_collection's key format
    assert configs.keys() == arrays.keys()

    for key, sub_cfg in configs.items():
        assert isinstance(key, str)
        assert isinstance(sub_cfg, RasterizedMazeDatasetConfig)

    for key, array in arrays.items():
        assert isinstance(key, str)
        assert isinstance(array, np.ndarray)