Coverage for tests/unit/dataset/test_collected_dataset.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1from functools import cached_property 

2from pathlib import Path 

3 

4import numpy as np 

5 

6from maze_dataset.dataset.collected_dataset import ( 

7 MazeDatasetCollection, 

8 MazeDatasetCollectionConfig, 

9 MazeDatasetConfig, 

10) 

11 

# Sub-dataset i of the test collection holds DATASET_LENGTHS[i] mazes on a
# DATASET_GRID_SIZES[i] x DATASET_GRID_SIZES[i] grid. The second entry is an
# empty (zero-length) dataset, so the collection spans an empty member.
DATASET_LENGTHS: list[int] = [
    1,
    0,
    3,
    2,
    1,
]
DATASET_GRID_SIZES: list[int] = [
    5,
    1,
    3,
    3,
    4,
]

14 

15 

class TestMazeDatasetCollection:
    """Tests for `MazeDatasetCollection` built from several small `MazeDatasetConfig`s.

    One sub-dataset is created per ``(n_mazes, grid_n)`` pair taken from
    ``DATASET_LENGTHS`` / ``DATASET_GRID_SIZES`` (including a zero-length one),
    so indexing is exercised across sub-dataset boundaries.
    """

    @cached_property
    def test_collection(self) -> MazeDatasetCollection:
        """Build the collection under test (generated, with ``save_local=True`` under ``data/``).

        NOTE: pytest creates a fresh instance of the test class for every test
        method, so this cache only lives for a single test; the collection is
        rebuilt per test.
        """
        config = MazeDatasetCollectionConfig(
            name="test_collection",
            maze_dataset_configs=[
                MazeDatasetConfig(
                    n_mazes=n_mazes,
                    grid_n=grid_n,
                    name=f"test_dataset_{n_mazes}_{grid_n}",
                )
                # strict=True: if the two constant lists ever differ in length,
                # fail loudly instead of silently dropping sub-datasets.
                for n_mazes, grid_n in zip(
                    DATASET_LENGTHS,
                    DATASET_GRID_SIZES,
                    strict=True,
                )
            ],
        )
        return MazeDatasetCollection.from_config(
            config,
            do_generate=True,
            load_local=False,
            do_download=False,
            save_local=True,
            local_base_path=Path("data/"),
        )

    @staticmethod
    def _expected_grid_sizes() -> list[int]:
        """Return the expected grid size of every maze, in flat collection order."""
        return [
            grid_n
            for n_mazes, grid_n in zip(DATASET_LENGTHS, DATASET_GRID_SIZES, strict=True)
            for _ in range(n_mazes)
        ]

    def test_dataset_lengths(self):
        # assert_array_equal also checks shape, so a wrong-length result cannot
        # pass via broadcasting.
        np.testing.assert_array_equal(
            np.array(self.test_collection.dataset_lengths),
            np.array(DATASET_LENGTHS),
        )

    def test_dataset_cum_lengths(self):
        # for the current constants this is [1, 1, 4, 6, 7]
        np.testing.assert_array_equal(
            self.test_collection.dataset_cum_lengths,
            np.cumsum(DATASET_LENGTHS),
        )

    def test_mazes(self):
        mazes = self.test_collection.mazes
        expected_grid_sizes = self._expected_grid_sizes()
        assert len(mazes) == sum(DATASET_LENGTHS)
        # every maze must carry the grid size of the sub-dataset it came from
        for maze, grid_n in zip(mazes, expected_grid_sizes, strict=True):
            assert maze.connection_list.shape == (2, grid_n, grid_n)

    def test_len(self):
        assert len(self.test_collection) == sum(DATASET_LENGTHS)

    def test_getitem(self):
        collection = self.test_collection
        expected_grid_sizes = self._expected_grid_sizes()
        assert len(collection) == len(expected_grid_sizes)
        for i, grid_n in enumerate(expected_grid_sizes):
            item = collection[i]
            assert item.connection_list.shape == (2, grid_n, grid_n)
            # indexing the collection must agree with the flat `mazes` list
            np.testing.assert_array_equal(
                item.connection_list,
                collection.mazes[i].connection_list,
            )

    def test_download(self):
        # TODO: test downloading after we implement downloading datasets
        pass

    def test_serialize_and_load(self):
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg

    def test_save_read(self):
        save_path: str = "tests/_temp/collected_dataset_test_save_read.zanj"
        # ensure the temp directory exists so this test does not depend on
        # another test (or a previous run) having created it
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        self.test_collection.save(save_path)
        loaded = MazeDatasetCollection.read(save_path)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg