Coverage for maze_dataset/dataset/collected_dataset.py: 44% (87 statements) — report generated by coverage.py v7.6.12 at 2025-03-24 00:33 -0600

1"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set 

2 

3> [!CAUTION] 

4> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work. 

5 

6""" 

7 

8import itertools 

9import json 

10import typing 

11from functools import cached_property 

12 

13import numpy as np 

14from jaxtyping import Int 

15from muutils.json_serialize import ( 

16 json_serialize, 

17 serializable_dataclass, 

18 serializable_field, 

19) 

20from muutils.json_serialize.util import _FORMAT_KEY, JSONdict 

21from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 

22from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler 

23 

24from maze_dataset.constants import Coord, CoordTup 

25from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig 

26from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig 

27from maze_dataset.maze import LatticeMaze 

28 

29 

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
    """Configuration for a collection of maze datasets.

    Wraps a list of `MazeDatasetConfig`s so that several datasets can be
    combined into a single collection, including tokenizers and shuffle.
    """

    # Attributes without a default cannot follow attributes with one [misc]
    maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
        serialization_fn=lambda cfgs: [c.serialize() for c in cfgs],
        loading_fn=lambda data: [
            MazeDatasetConfig.load(c) for c in data["maze_dataset_configs"]
        ],
    )

    def summary(self) -> dict:
        """Return a summary of the config, including summaries of each sub-config."""
        sub_summaries: list[dict] = [
            cfg.summary() for cfg in self.maze_dataset_configs
        ]
        return {
            "n_mazes": self.n_mazes,
            "max_grid_n": self.max_grid_n,
            "max_grid_shape": self.max_grid_shape,
            "fname": self.to_fname(),
            "cfg_summaries": sub_summaries,
        }

    @property
    def n_mazes(self) -> int:
        """Total number of mazes in the collection, summed across all datasets."""
        total: int = 0
        for cfg in self.maze_dataset_configs:
            total += cfg.n_mazes
        return total

    @property
    def max_grid_n(self) -> int:
        """Largest grid size among the datasets in the collection."""
        return max(cfg.grid_n for cfg in self.maze_dataset_configs)

    @property
    def max_grid_shape(self) -> CoordTup:
        """Maximum grid shape of the mazes in the collection, as a (row, col) tuple."""
        side: int = self.max_grid_n
        return (side, side)

    @property
    def max_grid_shape_np(self) -> Coord:
        """Maximum grid shape of the mazes in the collection, as a numpy array."""
        return np.array(self.max_grid_shape, dtype=np.int32)

    def stable_hash_cfg(self) -> int:
        """Return a stable hash of the serialized config."""
        return stable_hash(json.dumps(self.serialize()))

    def to_fname(self) -> str:
        """Convert the config to a filesystem-safe filename."""
        # shortened maze count and a 5-digit config hash keep the name unique but readable
        count_str: str = shorten_numerical_to_str(self.n_mazes)
        hash_part: int = self.stable_hash_cfg() % 10**5
        return sanitize_fname(
            f"collected-{self.name}-n{count_str}-h{hash_part}",
        )

81 

82 

class MazeDatasetCollection(GPTDataset):
    """A collection of maze datasets, presented as a single flat dataset."""

    def __init__(
        self,
        cfg: MazeDatasetCollectionConfig,
        maze_datasets: list[MazeDataset],
        generation_metadata_collected: dict | None = None,
    ) -> None:
        "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
        super().__init__()
        self.cfg: MazeDatasetCollectionConfig = cfg
        self.maze_datasets: list[MazeDataset] = list(maze_datasets)
        # sanity-check that the sub-configs stored on the collection config
        # line up with the configs of the datasets themselves
        for sub_cfg, dataset in zip(
            self.cfg.maze_dataset_configs,
            self.maze_datasets,
            strict=False,
        ):
            assert sub_cfg.name == dataset.cfg.name
            assert sub_cfg == dataset.cfg

        self.generation_metadata_collected: dict | None = generation_metadata_collected

    @property
    def dataset_lengths(self) -> list[int]:
        """Length of each constituent dataset, in order."""
        return [len(ds) for ds in self.maze_datasets]

    @property
    def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
        """Cumulative lengths of the constituent datasets."""
        return np.array(list(itertools.accumulate(self.dataset_lengths)))

    @cached_property
    def mazes(self) -> list[LatticeMaze]:
        "flat list of every maze in the collection"
        all_mazes: list[LatticeMaze] = []
        for ds in self.maze_datasets:
            all_mazes.extend(ds.mazes)
        return all_mazes

    def __len__(self) -> int:
        """Total number of mazes across all datasets in the collection."""
        return sum(self.dataset_lengths)

    def __getitem__(self, index: int) -> LatticeMaze:
        "get a maze by its index into the flattened collection"
        # locate the dataset containing `index`: searching for `index + 1`
        # in the cumulative lengths yields the first dataset whose
        # cumulative length covers the requested index
        dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
        # convert the global index into an index local to that dataset;
        # for the first dataset the global index is already local
        local_index: int = index
        if dataset_idx > 0:
            local_index -= self.dataset_cum_lengths[dataset_idx - 1]
        return self.maze_datasets[dataset_idx][local_index]

    @classmethod
    def generate(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        """generate a dataset collection from a config"""
        # generate each sub-dataset from its own config, forwarding kwargs
        sub_datasets: list[MazeDataset] = []
        for sub_cfg in cfg.maze_dataset_configs:
            sub_datasets.append(MazeDataset.generate(sub_cfg, **kwargs))
        return cls(cfg, sub_datasets)

    @classmethod
    def download(
        cls,
        cfg: MazeDatasetCollectionConfig,
        **kwargs,
    ) -> "MazeDatasetCollection":
        "(not implemented!) download a dataset collection from a config"
        sub_datasets: list[MazeDataset] = [
            MazeDataset.download(sub_cfg, **kwargs)
            for sub_cfg in cfg.maze_dataset_configs
        ]
        return cls(cfg, sub_datasets)

    def serialize(self) -> JSONdict:
        """serialize the dataset collection (config, datasets, collected metadata)"""
        return {
            _FORMAT_KEY: "MazeDatasetCollection",
            "cfg": self.cfg.serialize(),
            "maze_datasets": [ds.serialize() for ds in self.maze_datasets],
            "generation_metadata_collected": json_serialize(
                self.generation_metadata_collected,
            ),
        }

    @classmethod
    def load(cls, data: JSONdict) -> "MazeDatasetCollection":
        """load the dataset collection from the representation created by `serialize`"""
        assert data[_FORMAT_KEY] == "MazeDatasetCollection"
        loaded: dict = {
            key: load_item_recursive(data[key], tuple())
            for key in ("cfg", "maze_datasets", "generation_metadata_collected")
        }
        return cls(**loaded)

    # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
    def as_tokens(
        self,
        # TODO: MazeTokenizer
        maze_tokenizer,  # noqa: ANN001
        limit: int | None = None,
        join_tokens_individual_maze: bool = False,
    ) -> list[list[str]] | list[str]:
        """return the dataset as tokens

        if join_tokens_individual_maze is True, then the tokens of each maze are
        joined with a space, and the result is a list of strings.
        i.e.:
        >>> dataset.as_tokens(join_tokens_individual_maze=False)
        [["a", "b", "c"], ["d", "e", "f"]]
        >>> dataset.as_tokens(join_tokens_individual_maze=True)
        ["a b c", "d e f"]
        """
        token_lists: list[list[str]] = [
            maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
        ]
        if not join_tokens_individual_maze:
            return token_lists
        return [" ".join(tokens) for tokens in token_lists]

    def update_self_config(self) -> None:
        "update the config to match the number of mazes, and update the underlying configs of each dataset"
        # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset
        self.cfg.__dict__["n_mazes"] = len(self)
        for ds in self.maze_datasets:
            ds.update_self_config()

        self.cfg.maze_dataset_configs = [ds.cfg for ds in self.maze_datasets]

225 

226 

# link the config class back to the dataset class it configures,
# so GPTDatasetConfig machinery can construct the right dataset type
MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection  # type: ignore[method-assign, assignment]

# register with ZANJ so serialized collections can be loaded back;
# the check matches any mapping whose format key starts with "MazeDatasetCollection"
register_loader_handler(
    LoaderHandler(
        check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
            isinstance(json_item, typing.Mapping)
            and _FORMAT_KEY in json_item
            and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection")
        ),
        load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item),  # type: ignore[misc] # noqa: ARG005
        uid="MazeDatasetCollection",
        source_pckg="maze_dataset.generation.maze_dataset_collection",
        desc="MazeDatasetCollection",
    ),
)