Coverage for maze_dataset/dataset/collected_dataset.py: 44% (87 statements)
1"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set
3> [!CAUTION]
4> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work.
6"""
8import itertools
9import json
10import typing
11from functools import cached_property
13import numpy as np
14from jaxtyping import Int
15from muutils.json_serialize import (
16 json_serialize,
17 serializable_dataclass,
18 serializable_field,
19)
20from muutils.json_serialize.util import _FORMAT_KEY, JSONdict
21from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
22from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler
24from maze_dataset.constants import Coord, CoordTup
25from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig
26from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig
27from maze_dataset.maze import LatticeMaze
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """maze dataset collection configuration, including tokenizers and shuffle"""

    # Attributes without a default cannot follow attributes with one [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda configs: [config.serialize() for config in configs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """return a summary of the config"""
        return dict(
            n_mazes=self.n_mazes,
            max_grid_n=self.max_grid_n,
            max_grid_shape=self.max_grid_shape,
            fname=self.to_fname(),
            cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
        )
    @property
    def n_mazes(self) -> int:
        """return the total number of mazes in the collection across all datasets"""
        return sum(config.n_mazes for config in self.maze_dataset_configs)

    @property
    def max_grid_n(self) -> int:
        """return the maximum grid size of the mazes in the collection"""
        return max(config.grid_n for config in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """return the maximum grid shape of the mazes in the collection"""
        return (self.max_grid_n, self.max_grid_n)

    @property
    def max_grid_shape_np(self) -> Coord:
        """return the maximum grid shape of the mazes in the collection as a numpy array"""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """convert config to a filename"""
        return sanitize_fname(
            f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
        )
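For illustration, a collection config can be assembled from several `MazeDatasetConfig`s, and the derived properties above are then computed across all of them. This is only a sketch: the `MazeDatasetConfig` arguments (`grid_n`, `n_mazes`, `maze_ctor`) and the `LatticeMazeGenerators` import come from elsewhere in the `maze_dataset` package, and it assumes the inherited `GPTDatasetConfig` fields all have defaults.

from maze_dataset import MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

collection_cfg = MazeDatasetCollectionConfig(
    name="small-and-large",
    maze_dataset_configs=[
        MazeDatasetConfig(name="small", grid_n=4, n_mazes=100, maze_ctor=LatticeMazeGenerators.gen_dfs),
        MazeDatasetConfig(name="large", grid_n=8, n_mazes=50, maze_ctor=LatticeMazeGenerators.gen_dfs),
    ],
)
assert collection_cfg.n_mazes == 150            # summed over sub-configs
assert collection_cfg.max_grid_shape == (8, 8)  # largest grid_n in both dimensions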
class MazeDatasetCollection(GPTDataset):
    """a collection of maze datasets"""

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        for c, ds in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert c.name == ds.cfg.name
            assert c == ds.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """return the lengths of each dataset in the collection"""
        return [len(dataset) for dataset in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """return the cumulative lengths of each dataset in the collection"""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "single list of all mazes in the collection"
        return list(
            itertools.chain.from_iterable(
                dataset.mazes for dataset in self.maze_datasets
            ),
        )
    def __len__(self) -> int:
        """return the total number of mazes in the collection"""
        return sum(len(dataset) for dataset in self.maze_datasets)

    def __getitem__(self, index: int) -> LatticeMaze:
        "get a maze by index"
        # find which dataset the index belongs to:
        # np.searchsorted returns the insertion point, i.e. the index of the first
        # cumulative length that is >= the search value. Searching for `index + 1`
        # therefore gives the first dataset whose cumulative length exceeds `index`,
        # which is the dataset containing `index`.
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        index_adjusted: int = index
        if dataset_idx > 0:
            # for the first dataset (`dataset_idx == 0`) the raw index is already correct;
            # for later datasets, subtract the combined length of all earlier datasets
            index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][index_adjusted]
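A small, self-contained sketch of the index mapping used above, with hypothetical sub-dataset lengths and plain numpy:

import itertools
import numpy as np

# hypothetical sub-dataset lengths: 3, 5, and 2 mazes
lengths = [3, 5, 2]
cum_lengths = np.array(list(itertools.accumulate(lengths)))  # [3, 8, 10]

def locate(index: int) -> tuple[int, int]:
    """map a global index to (dataset index, index within that dataset)"""
    dataset_idx = int(np.searchsorted(cum_lengths, index + 1))
    index_adjusted = index if dataset_idx == 0 else index - cum_lengths[dataset_idx - 1]
    return dataset_idx, int(index_adjusted)

assert locate(0) == (0, 0)  # first maze of the first dataset
assert locate(2) == (0, 2)  # last maze of the first dataset
assert locate(3) == (1, 0)  # first maze of the second dataset
assert locate(9) == (2, 1)  # last maze of the last dataset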
    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        datasets = [
            MazeDataset.generate(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
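Generation simply delegates to `MazeDataset.generate` for each sub-config. A minimal usage sketch, reusing the hypothetical `collection_cfg` from earlier and assuming each sub-dataset yields exactly its configured number of mazes:

collection = MazeDatasetCollection.generate(collection_cfg)
assert len(collection) == collection.cfg.n_mazes
first_maze = collection[0]  # indexes across all sub-datasets in order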
    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        datasets = [
            MazeDataset.download(config, **kwargs)
            for config in cfg.maze_dataset_configs
        ]
        return cls(cfg, datasets)
    def serialize(self) -> JSONdict:
        """serialize the dataset collection"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        return cls(
            **{
                key: load_item_recursive(data[key], tuple())
                for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
            },
        )
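`serialize` produces a plain JSON-serializable dict and `load` is meant to invert it. A minimal round-trip sketch, reusing `collection` from above; in practice this is usually driven by `zanj` through the loader handler registered at the bottom of the module, and (per the module caution) the direct round trip is not thoroughly tested:

data = collection.serialize()                  # plain dict, suitable for JSON / ZANJ
restored = MazeDatasetCollection.load(data)
assert len(restored) == len(collection)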
    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        output: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if join_tokens_individual_maze:
            return [" ".join(tokens) for tokens in output]
        else:
            return output
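With a tokenizer from `maze_dataset.tokenization`, the whole collection can be flattened into training strings. The constructor arguments below (`tokenization_mode`, `max_grid_size`) are assumptions about that API, not something this module guarantees:

from maze_dataset.tokenization import MazeTokenizer, TokenizationMode  # assumed import path

tokenizer = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_uniform,  # assumed mode name
    max_grid_size=collection.cfg.max_grid_n,
)
maze_strings: list[str] = collection.as_tokens(
    tokenizer,
    limit=10,                          # only tokenize the first 10 mazes
    join_tokens_individual_maze=True,  # one space-joined string per maze
)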
    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for dataset in self.maze_datasets:
            dataset.update_self_config()

        self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection  # type: ignore[method-assign, assignment]

register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection")
        ),
        load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item),  # type: ignore[misc] # noqa: ARG005
        uid="MazeDatasetCollection",
        source_pckg="maze_dataset.generation.maze_dataset_collection",
        desc="MazeDatasetCollection",
    ),
)