Coverage for tests/unit/dataset/test_collected_dataset_2.py: 100%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1from pathlib import Path 

2 

3import numpy as np 

4import pytest 

5from muutils.json_serialize.util import _FORMAT_KEY 

6 

7from maze_dataset import ( 

8 MazeDataset, 

9 MazeDatasetCollection, 

10 MazeDatasetCollectionConfig, 

11 MazeDatasetConfig, 

12) 

13from maze_dataset.generation import LatticeMazeGenerators 

14from maze_dataset.maze import SolvedMaze 

15 

# Scratch directory for file-based tests; created once by the module fixture.
TEMP_PATH: Path = Path("tests/_temp/maze_dataset_collection/")

18 

19 

@pytest.fixture(scope="module", autouse=True)
def setup_temp_dir():
    """Ensure the temp directory exists before any test in this module runs."""
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    # Deliberately no teardown: artifacts are kept on disk for inspection.

25 

26 

@pytest.fixture
def small_configs():
    """Two tiny 3x3 DFS configs (2 mazes each) used to build the collection."""

    def _make_config(idx: int) -> MazeDatasetConfig:
        # One small, fast-to-generate config per index.
        return MazeDatasetConfig(
            name=f"test_{idx}",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        )

    return [_make_config(idx) for idx in range(2)]

39 

40 

@pytest.fixture
def small_datasets(small_configs):
    """Generate one in-memory MazeDataset per small config (no disk I/O)."""
    # All local/remote persistence is disabled so generation is the only path.
    no_io = dict(do_download=False, load_local=False, save_local=False)
    return [MazeDataset.from_config(cfg, **no_io) for cfg in small_configs]

50 

51 

@pytest.fixture
def collection_config(small_configs):
    """Wrap the two small dataset configs in a single collection config."""
    cfg = MazeDatasetCollectionConfig(
        name="test_collection",
        maze_dataset_configs=small_configs,
    )
    return cfg

59 

60 

@pytest.fixture
def collection(small_datasets, collection_config):
    """Assemble the MazeDatasetCollection under test from the other fixtures."""
    built = MazeDatasetCollection(
        cfg=collection_config,
        maze_datasets=small_datasets,
    )
    return built

68 

69 

def test_dataset_lengths(collection, small_datasets):
    """dataset_lengths reports each member dataset's length, in order."""
    assert collection.dataset_lengths == [len(ds) for ds in small_datasets]

74 

75 

def test_dataset_cum_lengths(collection):
    """Cumulative lengths accumulate across datasets: [2, 4] for two 2-maze sets."""
    expected = np.cumsum([2, 2])  # -> array([2, 4])
    assert np.array_equal(collection.dataset_cum_lengths, expected)

80 

81 

def test_mazes_cached_property(collection, small_datasets):
    """The `mazes` cached property flattens every dataset's mazes in order."""
    expected = [maze for ds in small_datasets for maze in ds.mazes]

    # The property must exist and be accessible.
    assert hasattr(collection, "mazes")
    flattened = collection.mazes

    # Same count and same ordering as concatenating the member datasets.
    assert len(flattened) == len(expected)
    assert flattened == expected

95 

96 

def test_getitem_across_datasets(collection, small_datasets):
    """__getitem__ spans dataset boundaries: flat index maps to (dataset, offset)."""
    # Flat indices 0..3 map onto the two 2-maze member datasets.
    index_map = [(0, 0), (0, 1), (1, 0), (1, 1)]
    for flat_idx, (ds_idx, inner_idx) in enumerate(index_map):
        assert collection[flat_idx] == small_datasets[ds_idx][inner_idx]

106 

107 

def test_iteration(collection):
    """Iterating the collection yields all 4 mazes, each a SolvedMaze."""
    seen = [maze for maze in collection]
    assert len(seen) == 4
    for maze in seen:
        assert isinstance(maze, SolvedMaze)

113 

114 

def test_generate_classmethod(collection_config):
    """MazeDatasetCollection.generate builds a full collection from its config."""
    generated = MazeDatasetCollection.generate(
        collection_config,
        do_download=False,
        load_local=False,
        save_local=False,
    )

    assert isinstance(generated, MazeDatasetCollection)
    assert len(generated) == 4
    assert generated.cfg == collection_config

124 

125 

def test_serialization_deserialization(collection):
    """serialize() emits the expected keys and load() round-trips the collection."""
    serialized = collection.serialize()

    # Format marker plus the two payload keys must all be present.
    for key in (_FORMAT_KEY, "cfg", "maze_datasets"):
        assert key in serialized
    assert serialized[_FORMAT_KEY] == "MazeDatasetCollection"

    roundtripped = MazeDatasetCollection.load(serialized)

    # Round trip preserves identity-relevant properties.
    assert roundtripped.cfg.name == collection.cfg.name
    assert len(roundtripped) == len(collection)

143 

144 

def test_save_and_read(collection):
    """save() writes a .zanj file that read() loads back equivalently."""
    target = TEMP_PATH / "test_collection.zanj"

    collection.save(target)
    assert target.exists()

    reloaded = MazeDatasetCollection.read(target)
    assert len(reloaded) == len(collection)
    assert reloaded.cfg.name == collection.cfg.name

157 

158 

def test_as_tokens(collection):
    """as_tokens yields per-maze token lists, or joined strings when requested."""
    # Imported here (not at module level) to keep tokenizer machinery optional.
    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()

    # Unjoined: each maze is a list of token strings.
    unjoined = collection.as_tokens(
        tokenizer, limit=2, join_tokens_individual_maze=False
    )
    assert len(unjoined) == 2
    for tokens in unjoined:
        assert isinstance(tokens, list)

    # Joined: each maze is a single space-delimited string.
    joined = collection.as_tokens(
        tokenizer, limit=2, join_tokens_individual_maze=True
    )
    assert len(joined) == 2
    for text in joined:
        assert isinstance(text, str)
        assert " " in text

178 

179 

def test_update_self_config(collection):
    """update_self_config re-derives n_mazes after member datasets change."""
    n_before = collection.cfg.n_mazes

    # Remove one maze from the first member dataset so the config goes stale.
    collection.maze_datasets[0].mazes.pop()

    collection.update_self_config()

    # Config must now reflect the reduced total.
    assert collection.cfg.n_mazes == n_before - 1

192 

193 

def test_max_grid_properties(collection_config):
    """All max_grid_* views agree on the 3x3 grid of the member configs."""
    assert collection_config.max_grid_n == 3
    assert collection_config.max_grid_shape == (3, 3)
    assert list(collection_config.max_grid_shape_np) == [3, 3]

199 

200 

def test_config_serialization(collection_config):
    """Config round-trips through serialize()/load(); summary() is populated."""
    restored = MazeDatasetCollectionConfig.load(collection_config.serialize())

    assert restored.name == collection_config.name
    assert len(restored.maze_dataset_configs) == len(
        collection_config.maze_dataset_configs
    )

    # summary() must expose the aggregate counts.
    summary = collection_config.summary()
    assert "n_mazes" in summary
    assert "max_grid_n" in summary
    assert summary["n_mazes"] == 4

216 

217 

def test_mixed_grid_sizes():
    """A collection spanning 3x3 and 4x4 datasets reports the larger max grid."""
    # enumerate keeps the original names test_grid_0 / test_grid_1.
    configs = [
        MazeDatasetConfig(
            name=f"test_grid_{idx}",
            grid_n=grid_size,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        )
        for idx, grid_size in enumerate((3, 4))
    ]

    datasets = [
        MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False
        )
        for cfg in configs
    ]

    mixed = MazeDatasetCollection(
        cfg=MazeDatasetCollectionConfig(
            name="mixed_grid_collection",
            maze_dataset_configs=configs,
        ),
        maze_datasets=datasets,
    )

    # The largest member grid determines the collection-wide maximum.
    assert mixed.cfg.max_grid_n == 4
    assert mixed.cfg.max_grid_shape == (4, 4)

250 

251 

def test_different_generation_methods():
    """Mazes from DFS and percolation generators coexist in one collection."""
    configs = [
        MazeDatasetConfig(
            name="dfs_test",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        MazeDatasetConfig(
            name="percolation_test",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_percolation,
            maze_ctor_kwargs={"p": 0.7},
        ),
    ]

    datasets = [
        MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False
        )
        for cfg in configs
    ]

    mixed = MazeDatasetCollection(
        cfg=MazeDatasetCollectionConfig(
            name="mixed_gen_collection",
            maze_dataset_configs=configs,
        ),
        maze_datasets=datasets,
    )

    # Both 2-maze datasets are present.
    assert len(mixed) == 4

    # generation_meta records which generator produced each maze.
    # type: ignore since generation_meta may be None; a None here would
    # raise AttributeError and fail the test anyway.
    assert mixed[0].generation_meta.get("func_name") == "gen_dfs"  # type: ignore[union-attr]
    assert mixed[2].generation_meta.get("func_name") == "gen_percolation"  # type: ignore[union-attr]
    assert mixed[2].generation_meta.get("percolation_p") == 0.7  # type: ignore[union-attr]