Coverage for tests/unit/dataset/test_collected_dataset_2.py: 100% (111 statements)
coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
from pathlib import Path

import numpy as np
import pytest
from muutils.json_serialize.util import _FORMAT_KEY

from maze_dataset import (
    MazeDataset,
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
    MazeDatasetConfig,
)
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.maze import SolvedMaze

# Define a temp path for file operations
TEMP_PATH: Path = Path("tests/_temp/maze_dataset_collection/")


@pytest.fixture(scope="module", autouse=True)
def setup_temp_dir():
    """Create temporary directory for tests."""
    TEMP_PATH.mkdir(parents=True, exist_ok=True)
    # No cleanup as requested


@pytest.fixture
def small_configs():
    """Create a list of small MazeDatasetConfig objects for testing."""
    return [
        MazeDatasetConfig(
            name=f"test_{i}",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        )
        for i in range(2)
    ]


@pytest.fixture
def small_datasets(small_configs):
    """Create a list of small MazeDataset objects for testing."""
    return [
        MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False
        )
        for cfg in small_configs
    ]


@pytest.fixture
def collection_config(small_configs):
    """Create a MazeDatasetCollectionConfig for testing."""
    return MazeDatasetCollectionConfig(
        name="test_collection",
        maze_dataset_configs=small_configs,
    )


@pytest.fixture
def collection(small_datasets, collection_config):
    """Create a MazeDatasetCollection for testing."""
    return MazeDatasetCollection(
        cfg=collection_config,
        maze_datasets=small_datasets,
    )


def test_dataset_lengths(collection, small_datasets):
    """Test that dataset_lengths returns the correct length for each dataset."""
    expected_lengths = [len(ds) for ds in small_datasets]
    assert collection.dataset_lengths == expected_lengths


def test_dataset_cum_lengths(collection):
    """Test that dataset_cum_lengths returns the correct cumulative lengths."""
    expected_cum_lengths = np.array([2, 4])  # [2, 2+2]
    assert np.array_equal(collection.dataset_cum_lengths, expected_cum_lengths)
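

# Illustrative sketch (added, not part of the original test suite): assuming
# dataset_cum_lengths is simply the cumulative sum of dataset_lengths, the two
# properties should stay consistent with each other. Uses only APIs already
# exercised above.
def test_cum_lengths_are_cumsum_of_lengths_sketch(collection):
    """Sketch: cumulative lengths should equal np.cumsum of the per-dataset lengths."""
    assert np.array_equal(
        collection.dataset_cum_lengths,
        np.cumsum(collection.dataset_lengths),
    )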


def test_mazes_cached_property(collection, small_datasets):
    """Test that the mazes cached_property correctly flattens all mazes."""
    expected_mazes = []
    for ds in small_datasets:
        expected_mazes.extend(ds.mazes)

    # Access property
    assert hasattr(collection, "mazes")
    mazes = collection.mazes

    # Check results
    assert len(mazes) == len(expected_mazes)
    assert mazes == expected_mazes


def test_getitem_across_datasets(collection, small_datasets):
    """Test that __getitem__ correctly accesses mazes across dataset boundaries."""
    # First dataset
    assert collection[0] == small_datasets[0][0]
    assert collection[1] == small_datasets[0][1]

    # Second dataset
    assert collection[2] == small_datasets[1][0]
    assert collection[3] == small_datasets[1][1]
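

# Illustrative sketch (added): one plausible way a flat index could be resolved
# to a (dataset index, local index) pair via np.searchsorted over the cumulative
# lengths. This is not necessarily how the library implements __getitem__; it
# only checks that such a mapping is consistent with the indexing tested above.
def test_index_mapping_via_cum_lengths_sketch(collection):
    """Sketch: map flat indices to (dataset, local) positions using cumulative lengths."""
    cum_lengths = collection.dataset_cum_lengths
    for flat_idx in range(len(collection)):
        ds_idx = int(np.searchsorted(cum_lengths, flat_idx, side="right"))
        local_idx = flat_idx - (int(cum_lengths[ds_idx - 1]) if ds_idx > 0 else 0)
        assert collection[flat_idx] == collection.maze_datasets[ds_idx][local_idx]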


def test_iteration(collection):
    """Test that the collection is iterable and returns all mazes."""
    mazes = list(collection)
    assert len(mazes) == 4
    assert all(isinstance(maze, SolvedMaze) for maze in mazes)


def test_generate_classmethod(collection_config):
    """Test that the generate class method creates a collection from a config."""
    collection = MazeDatasetCollection.generate(
        collection_config, do_download=False, load_local=False, save_local=False
    )

    assert isinstance(collection, MazeDatasetCollection)
    assert len(collection) == 4
    assert collection.cfg == collection_config


def test_serialization_deserialization(collection):
    """Test serialization and deserialization of the collection."""
    # Serialize
    serialized = collection.serialize()

    # Check keys
    assert _FORMAT_KEY in serialized
    assert serialized[_FORMAT_KEY] == "MazeDatasetCollection"
    assert "cfg" in serialized
    assert "maze_datasets" in serialized

    # Deserialize
    deserialized = MazeDatasetCollection.load(serialized)

    # Check properties
    assert deserialized.cfg.name == collection.cfg.name
    assert len(deserialized) == len(collection)


def test_save_and_read(collection):
    """Test saving and reading a collection to/from a file."""
    file_path = TEMP_PATH / "test_collection.zanj"

    # Save
    collection.save(file_path)
    assert file_path.exists()

    # Read
    loaded = MazeDatasetCollection.read(file_path)
    assert len(loaded) == len(collection)
    assert loaded.cfg.name == collection.cfg.name


def test_as_tokens(collection):
    """Test the as_tokens method with different parameters."""
    # Create a simple tokenizer for testing
    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()

    # Test with join_tokens_individual_maze=False
    tokens = collection.as_tokens(tokenizer, limit=2, join_tokens_individual_maze=False)
    assert len(tokens) == 2
    assert all(isinstance(t, list) for t in tokens)

    # Test with join_tokens_individual_maze=True
    tokens_joined = collection.as_tokens(
        tokenizer, limit=2, join_tokens_individual_maze=True
    )
    assert len(tokens_joined) == 2
    assert all(isinstance(t, str) for t in tokens_joined)
    assert all(" " in t for t in tokens_joined)
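

# Illustrative sketch (added), under the assumption that the joined form is the
# space-joined version of the per-maze token lists (consistent with the
# assertions above). If the delimiter differs, this check would not hold.
def test_as_tokens_join_consistency_sketch(collection):
    """Sketch: joined tokens should match " ".join of the unjoined token lists."""
    from maze_dataset.tokenization import MazeTokenizerModular

    tokenizer = MazeTokenizerModular()
    unjoined = collection.as_tokens(tokenizer, limit=2, join_tokens_individual_maze=False)
    joined = collection.as_tokens(tokenizer, limit=2, join_tokens_individual_maze=True)
    assert joined == [" ".join(tokens) for tokens in unjoined]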


def test_update_self_config(collection):
    """Test that update_self_config correctly updates the config."""
    original_n_mazes = collection.cfg.n_mazes

    # Change the dataset size by removing a maze
    collection.maze_datasets[0].mazes.pop()

    # Update config
    collection.update_self_config()

    # Check the config is updated
    assert collection.cfg.n_mazes == original_n_mazes - 1
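

# Illustrative sketch (added), assuming update_self_config recomputes n_mazes
# from the underlying datasets: after an update, the config total should equal
# the sum of the per-dataset lengths.
def test_update_self_config_matches_dataset_lengths_sketch(collection):
    """Sketch: cfg.n_mazes should equal the summed dataset lengths after an update."""
    collection.maze_datasets[0].mazes.pop()
    collection.update_self_config()
    assert collection.cfg.n_mazes == sum(len(ds) for ds in collection.maze_datasets)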


def test_max_grid_properties(collection_config):
    """Test max_grid properties are calculated correctly."""
    assert collection_config.max_grid_n == 3
    assert collection_config.max_grid_shape == (3, 3)
    assert np.array_equal(collection_config.max_grid_shape_np, np.array([3, 3]))
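

# Illustrative sketch (added), assuming max_grid_n is the maximum grid_n across
# the member configs and max_grid_shape is the corresponding square shape
# (consistent with the fixed values asserted above).
def test_max_grid_matches_member_configs_sketch(collection_config):
    """Sketch: max_grid_n should equal the largest grid_n among the member configs."""
    largest = max(cfg.grid_n for cfg in collection_config.maze_dataset_configs)
    assert collection_config.max_grid_n == largest
    assert collection_config.max_grid_shape == (largest, largest)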


def test_config_serialization(collection_config):
    """Test that the collection config serializes and deserializes correctly."""
    serialized = collection_config.serialize()
    deserialized = MazeDatasetCollectionConfig.load(serialized)

    assert deserialized.name == collection_config.name
    assert len(deserialized.maze_dataset_configs) == len(
        collection_config.maze_dataset_configs
    )

    # Test the summary method
    summary = collection_config.summary()
    assert "n_mazes" in summary
    assert "max_grid_n" in summary
    assert summary["n_mazes"] == 4


def test_mixed_grid_sizes():
    """Test a collection with different grid sizes."""
    configs = [
        MazeDatasetConfig(
            name=f"test_grid_{i}",
            grid_n=i + 3,  # 3, 4
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        )
        for i in range(2)
    ]

    datasets = [
        MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False
        )
        for cfg in configs
    ]

    collection_config = MazeDatasetCollectionConfig(
        name="mixed_grid_collection",
        maze_dataset_configs=configs,
    )

    collection = MazeDatasetCollection(
        cfg=collection_config,
        maze_datasets=datasets,
    )

    # The max grid size should be the largest one
    assert collection.cfg.max_grid_n == 4
    assert collection.cfg.max_grid_shape == (4, 4)


def test_different_generation_methods():
    """Test a collection with different generation methods."""
    configs = [
        MazeDatasetConfig(
            name="dfs_test",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        MazeDatasetConfig(
            name="percolation_test",
            grid_n=3,
            n_mazes=2,
            maze_ctor=LatticeMazeGenerators.gen_percolation,
            maze_ctor_kwargs={"p": 0.7},
        ),
    ]

    datasets = [
        MazeDataset.from_config(
            cfg, do_download=False, load_local=False, save_local=False
        )
        for cfg in configs
    ]

    collection_config = MazeDatasetCollectionConfig(
        name="mixed_gen_collection",
        maze_dataset_configs=configs,
    )

    collection = MazeDatasetCollection(
        cfg=collection_config,
        maze_datasets=datasets,
    )

    # Check that the collection has all mazes
    assert len(collection) == 4

    # Check that the mazes carry the expected generation metadata.
    # The type: ignore comments are needed because generation_meta may be None,
    # but if it is None these assertions will fail anyway.
    # For DFS
    assert collection[0].generation_meta.get("func_name") == "gen_dfs"  # type: ignore[union-attr]
    # For percolation
    assert collection[2].generation_meta.get("func_name") == "gen_percolation"  # type: ignore[union-attr]
    assert collection[2].generation_meta.get("percolation_p") == 0.7  # type: ignore[union-attr]