Coverage for tests/unit/dataset/test_rasterized.py: 100%
34 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
1from itertools import product
3import numpy as np
4import pytest
6from maze_dataset import LatticeMazeGenerators, MazeDatasetConfig
7from maze_dataset.dataset.maze_dataset import MazeDataset
8from maze_dataset.dataset.rasterized import (
9 RasterizedMazeDataset,
10 RasterizedMazeDatasetConfig,
11 make_numpy_collection,
12)
14_PARAMTETRIZATION = (
15 "remove_isolated_cells, extend_pixels, endpoints_as_open",
16 list(product([True, False], repeat=3)),
17)
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_new(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Build a ``RasterizedMazeDataset`` directly from a rasterized config.

    Percolation (``p=0.4``) is used as the generator so that isolated cells
    are likely to appear, exercising ``remove_isolated_cells``. Every
    combination of the three boolean flags is covered via the shared
    parametrization.
    """
    cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs=dict(p=0.4),
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )
    dataset: RasterizedMazeDataset = RasterizedMazeDataset.from_config_augmented(
        cfg,
        load_local=False,
    )

    # Previously this test only printed samples and could never fail on bad
    # output. Check that the dataset is non-empty and that the two images in
    # each sample (index 0 and 1) agree in shape.
    # NOTE(review): assumes each sample stacks problem/solution images of equal
    # shape -- confirm against RasterizedMazeDataset.__getitem__.
    assert dataset
    for sample in (dataset[0], dataset[1]):
        assert sample[0].shape == sample[1].shape

    # keep the diagnostic output for `pytest -s` debugging
    print(f"{dataset[0][0].shape = }, {dataset[0][1].shape = }")
    print(f"{dataset[0][1] = }\n{dataset[1][1] = }")
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_rasterized_from_mazedataset(
    remove_isolated_cells,
    extend_pixels,
    endpoints_as_open,
):
    """Convert a plain ``MazeDataset`` into a ``RasterizedMazeDataset``.

    The base dataset is built from an ordinary ``MazeDatasetConfig``. The
    rasterization flags are supplied only through ``added_params`` of
    ``RasterizedMazeDataset.from_base_MazeDataset``.
    """
    cfg: MazeDatasetConfig = MazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs=dict(p=0.4),
    )
    dataset_m: MazeDataset = MazeDataset.from_config(cfg, load_local=False)
    dataset_r: RasterizedMazeDataset = RasterizedMazeDataset.from_base_MazeDataset(
        dataset_m,
        added_params=dict(
            remove_isolated_cells=remove_isolated_cells,
            extend_pixels=extend_pixels,
            endpoints_as_open=endpoints_as_open,
        ),
    )

    assert dataset_r
    # Conversion presumably reuses the base dataset's mazes, so the sizes
    # should match. NOTE(review): confirm from_base_MazeDataset does not filter.
    assert len(dataset_r) == len(dataset_m)
    # The flags passed via ``added_params`` should be reflected on the new
    # config; a bare truthiness check would not catch them being dropped.
    assert dataset_r.cfg.remove_isolated_cells == remove_isolated_cells
    assert dataset_r.cfg.extend_pixels == extend_pixels
    assert dataset_r.cfg.endpoints_as_open == endpoints_as_open
@pytest.mark.parametrize(*_PARAMTETRIZATION)
def test_make_numpy_collection(remove_isolated_cells, extend_pixels, endpoints_as_open):
    """Check the structure returned by ``make_numpy_collection``.

    One rasterized dataset is requested per entry of ``grid_sizes``. The
    result must be a dict with ``"configs"`` and ``"arrays"`` sub-dicts, each
    holding one string-keyed entry per grid size.
    """
    grid_sizes = [2, 3]
    base_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig(
        name="test",
        grid_n=5,
        n_mazes=2,
        maze_ctor=LatticeMazeGenerators.gen_percolation,  # use percolation here to get some isolated cells
        maze_ctor_kwargs=dict(p=0.4),
        remove_isolated_cells=remove_isolated_cells,
        extend_pixels=extend_pixels,
        endpoints_as_open=endpoints_as_open,
    )

    collection = make_numpy_collection(
        base_cfg=base_cfg,
        grid_sizes=grid_sizes,
        from_config_kwargs=dict(load_local=False),
        verbose=True,
    )

    assert isinstance(collection, dict)
    configs = collection["configs"]
    assert isinstance(configs, dict)
    arrays = collection["arrays"]
    assert isinstance(arrays, dict)

    # one entry per requested grid size
    assert len(configs) == len(grid_sizes)
    assert len(arrays) == len(grid_sizes)

    for key, sub_cfg in configs.items():
        assert isinstance(key, str)
        assert isinstance(sub_cfg, RasterizedMazeDatasetConfig)

    for key, array in arrays.items():
        assert isinstance(key, str)
        assert isinstance(array, np.ndarray)