Coverage for tests/unit/dataset/test_collected_dataset.py: 100%
46 statements
coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
1from functools import cached_property
2from pathlib import Path
4import numpy as np
6from maze_dataset.dataset.collected_dataset import (
7 MazeDatasetCollection,
8 MazeDatasetCollectionConfig,
9 MazeDatasetConfig,
10)
# Sub-dataset specs as (n_mazes, grid_n) pairs. Keeping each pair together
# guarantees the two derived lists below stay aligned index-for-index.
# Note: the second entry has n_mazes == 0, i.e. an empty sub-dataset.
_DATASET_SPECS: list[tuple[int, int]] = [
    (1, 5),
    (0, 1),
    (3, 3),
    (2, 3),
    (1, 4),
]
DATASET_LENGTHS: list[int] = [n_mazes for n_mazes, _ in _DATASET_SPECS]
DATASET_GRID_SIZES: list[int] = [grid_n for _, grid_n in _DATASET_SPECS]
class TestMazeDatasetCollection:
    """Tests for `MazeDatasetCollection` built from the module-level constants.

    The collection is generated from `DATASET_LENGTHS` / `DATASET_GRID_SIZES`.
    One sub-dataset has zero mazes, so the indexing tests also check that an
    empty member is skipped when mapping a flat index to a maze.
    """

    # Shared across instances. pytest creates a fresh instance of the test
    # class for every test method, so caching only on the instance (which is
    # all `cached_property` does) would regenerate -- and re-save to `data/` --
    # the collection once per test.
    _shared_collection: "MazeDatasetCollection | None" = None

    # Target file for the save/read round-trip test.
    _SAVE_READ_PATH = Path("tests/_temp/collected_dataset_test_save_read.zanj")

    @staticmethod
    def _build_collection() -> MazeDatasetCollection:
        """Generate (and save locally) the collection described by the constants."""
        config = MazeDatasetCollectionConfig(
            name="test_collection",
            maze_dataset_configs=[
                MazeDatasetConfig(
                    n_mazes=n_mazes,
                    grid_n=grid_n,
                    name=f"test_dataset_{n_mazes}_{grid_n}",
                )
                # strict=True: fail loudly if the two constant lists drift
                # apart instead of silently dropping trailing datasets.
                for n_mazes, grid_n in zip(
                    DATASET_LENGTHS,
                    DATASET_GRID_SIZES,
                    strict=True,
                )
            ],
        )
        return MazeDatasetCollection.from_config(
            config,
            do_generate=True,
            load_local=False,
            do_download=False,
            save_local=True,
            local_base_path=Path("data/"),
        )

    @cached_property
    def test_collection(self) -> MazeDatasetCollection:
        """Return the collection under test, building it once per test class."""
        cls = type(self)
        if cls._shared_collection is None:
            cls._shared_collection = cls._build_collection()
        return cls._shared_collection

    def test_dataset_lengths(self):
        # assert_array_equal also checks shape and prints a readable diff,
        # unlike `np.all(a == b)`, which errors or returns a bare False
        # when the lengths differ.
        np.testing.assert_array_equal(
            np.array(self.test_collection.dataset_lengths),
            np.array(DATASET_LENGTHS),
        )

    def test_dataset_cum_lengths(self):
        # Expected value is spelled out (cumsum of DATASET_LENGTHS) so the
        # test does not just re-implement the code under test.
        np.testing.assert_array_equal(
            self.test_collection.dataset_cum_lengths,
            np.array([1, 1, 4, 6, 7]),
        )

    def test_mazes(self):
        mazes = self.test_collection.mazes
        assert len(mazes) == sum(DATASET_LENGTHS)
        assert mazes[0].connection_list.shape == (2, 5, 5)
        assert mazes[-1].connection_list.shape == (2, 4, 4)

    def test_len(self):
        assert len(self.test_collection) == sum(DATASET_LENGTHS)

    def test_getitem(self):
        collection = self.test_collection
        # Grid size of each maze by flat index. Index 1 falls in the third
        # sub-dataset because the second sub-dataset is empty.
        expected_grid_sizes = [5, 3, 3, 3, 3, 3, 4]
        for i, grid_n in enumerate(expected_grid_sizes):
            assert collection[i].connection_list.shape == (2, grid_n, grid_n), (
                f"unexpected shape at index {i}"
            )

        # __getitem__ must agree with the flattened `mazes` list everywhere.
        mazes = collection.mazes
        for i in range(sum(DATASET_LENGTHS)):
            np.testing.assert_array_equal(
                collection[i].connection_list,
                mazes[i].connection_list,
                err_msg=f"mismatch at index {i}",
            )

    def test_download(self):
        # TODO: test downloading after we implement downloading datasets
        pass

    def test_serialize_and_load(self):
        serialized = self.test_collection.serialize()
        loaded = MazeDatasetCollection.load(serialized)
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg

    def test_save_read(self):
        path = self._SAVE_READ_PATH
        # Ensure the temp directory exists before writing into it.
        path.parent.mkdir(parents=True, exist_ok=True)
        self.test_collection.save(str(path))
        loaded = MazeDatasetCollection.read(str(path))
        assert loaded.mazes == self.test_collection.mazes
        assert loaded.cfg.diff(self.test_collection.cfg) == {}
        assert loaded.cfg == self.test_collection.cfg