Coverage for maze_dataset/dataset/dataset.py: 40%

201 statements  

coverage.py v7.6.12, created at 2025-03-24 00:33 -0600

1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets 

2 

3they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering 

4 

5> [!NOTE] 

6> these should probably be moved into a different package, so don't rely on them being here 

7""" 

8 

9import functools 

10import json 

11import random 

12import typing 

13import warnings 

14from pathlib import Path 

15from typing import Callable, Type, TypeVar 

16 

17import numpy as np 

18from muutils.json_serialize import ( 

19 JSONitem, 

20 SerializableDataclass, 

21 serializable_dataclass, 

22 serializable_field, 

23) 

24from muutils.json_serialize.util import ( 

25 JSONdict, 

26) 

27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 

28from zanj import ZANJ 

29 

30from maze_dataset.generation.seed import GLOBAL_SEED 

31 

32 

33def set_reproducibility(seed: int) -> None: 

34 "set reproducibility in stdlib random and numpy (but not torch)" 

35 random.seed(seed) 

36 np.random.seed(seed) 

37 

38 

39class FilterInfoMismatchError(ValueError): 

40 """raised when the filter info in a dataset config does not match the filter info in the dataset""" 

41 

42 pass 

43 

44 

45def _load_applied_filters( 

46 filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]], 

47) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]: 

48 try: 

49 return [ 

50 dict( 

51 name=filter_info["name"], 

52 args=tuple( 

53 filter_info["args"], 

54 ), # muutils/zanj save tuples as lists, and this causes problems 

55 kwargs=dict(filter_info["kwargs"]), # type: ignore[arg-type] 

56 ) 

57 for filter_info in filters 

58 ] 

59 except Exception as e: 

60 err_msg: str = f"failed to load applied filters:\n{filters}" 

61 raise ValueError(err_msg) from e 

62 

63 
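
# Illustrative sketch (not part of the original module): what the `tuple()`
# conversion above guards against. zanj/muutils round-trip tuples as JSON
# lists, so a reloaded config would otherwise record `args` as `[5.0]` while a
# freshly applied filter records `(5.0,)`, and the equality check in
# `_check_filter_equality` below would fail. The filter name here is made up.
def _example_load_applied_filters() -> None:
    "tiny usage sketch for `_load_applied_filters`"
    loaded = _load_applied_filters(
        [{"name": "some_filter", "args": [5.0], "kwargs": {"percentile": 10}}],
    )
    assert loaded[0]["args"] == (5.0,)  # converted back to a tuple
    assert loaded[0]["kwargs"] == {"percentile": 10}
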

@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=GLOBAL_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self) -> None:
        "post init, where we set a random seed if none is set"
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = np.random.randint(2**31)

        # TODO: something here is broken
        if self.seed != GLOBAL_SEED:
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @property
    def _dataset_class(self) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )
        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )


def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


# abstract function, hence we don't care that `self` is unused
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
    err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)
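
# Illustrative sketch (not part of the original module): a minimal
# `GPTDatasetConfig` subclass. `MyDatasetConfig`, `n_items`, and `MyDataset`
# (sketched after the `GPTDataset` class below) are hypothetical names; this
# assumes `@serializable_dataclass` provides working `serialize`/`load` for
# the subclass, as it does for the configs elsewhere in this package.
@serializable_dataclass(kw_only=True)
class MyDatasetConfig(GPTDatasetConfig):
    "hypothetical config: one extra generation parameter"

    n_items: int = serializable_field(default=100)

    @property
    def _dataset_class(self) -> type:
        return MyDataset  # the GPTDataset subclass this config belongs to

    def to_fname(self) -> str:
        # unique, filesystem-safe filename so `from_config` can find saved copies
        return sanitize_fname(f"{self.name}-n{self.n_items}")
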

class GPTDataset(typing.Generic[T_DatasetConfig]):
    """wrapper for torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses (a minimal illustrative sketch follows this class definition):
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. the kwargs from `from_config` are not passed through; these kwargs should take the actual generated or loaded data (a list of objects or sequences, probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
    - `filter_by(self)`
        returns a `FilterBy` namespace object, so registered filters can be called as `dataset.filter_by.some_filter(...)`
    - `_FILTER_NAMESPACE: type`
        the filter namespace class registered for this dataset class via `register_filter_namespace_for_dataset`
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating

    """

    _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore

    cfg: "T_DatasetConfig"

    @classmethod
    def from_config(  # noqa: C901, PLR0912
        cls: "type[T_Dataset]",
        cfg: "T_DatasetConfig",
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "T_Dataset":
        """load, download, or generate a dataset according to the config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        print_log: Callable = print if verbose else lambda *_a, **_kw: None

        local_base_path = Path(local_base_path)
        fname: Path = Path(f"{cfg.to_fname()}.zanj")
        output: T_Dataset | None = None
        did_load_local: bool = False
        if zanj is None:
            zanj = ZANJ()

        print_log(f"trying to get the dataset '{cfg.to_fname()}'")

        if not (load_local or do_download or do_generate):
            raise ValueError(
                "no way to load dataset! you said not to load local, not to download, and not to generate",
            )

        dataset_path: Path = local_base_path / fname

        # try loading
        if load_local:  # noqa: SIM102
            if dataset_path.exists():
                print_log(f"loading dataset from {dataset_path.as_posix()}")
                try:
                    output = cls.read(dataset_path, zanj=zanj)
                    did_load_local = True
                    print_log("load successful!")
                except Exception as e:  # noqa: BLE001
                    print_log(f"failed to load dataset: {e}")

        if do_download and output is None:
            print_log("seeing if we can download the dataset...")
            try:
                output = cls.download(cfg, **kwargs)
                print_log("download successful!")
            except NotImplementedError:
                print_log("no download found, or download failed")

        if do_generate and output is None:
            print_log("generating dataset...")
            output = cls.generate(cfg, verbose=verbose, **kwargs)
            # only if we generated it, apply filters
            output = output._apply_filters_from_config()

        # check and save
        if output is None:
            raise ValueError("failed to load dataset!")

        cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
        if cfg_diff:
            if except_on_config_mismatch:
                if allow_generation_metadata_filter_mismatch and (
                    cfg_diff
                    == {
                        "applied_filters": {
                            "self": [],
                            "other": [
                                {
                                    "name": "collect_generation_meta",
                                    "args": (),
                                    "kwargs": {},
                                },
                            ],
                        },
                    }
                ):
                    pass
                else:
                    err_msg: str = f"config mismatch: {cfg_diff = }"
                    raise ValueError(err_msg)
            else:
                warnings.warn(f"config mismatch: {cfg_diff = }")

        if save_local and not did_load_local:
            print_log(f"saving dataset to {dataset_path}")
            output.save(dataset_path, zanj=zanj)

        print_log(
            f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
        )
        return output

    def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
        "save dataset to a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    # serialization & loading
    @classmethod
    def read(
        cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
    ) -> "T_Dataset":
        "read dataset from a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)

    def serialize(self: "T_Dataset") -> JSONdict:
        "(implement in subclass!) serialize to something we can save with zanj"
        raise NotImplementedError

    def data_hash(self: "T_Dataset") -> int:
        "(implement in subclass!) return a hash of the data"
        raise NotImplementedError

    @classmethod
    def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
        "(implement in subclass!) load a dataset from what we made with `.serialize()`"
        raise NotImplementedError

    # generating & downloading
    @classmethod
    def generate(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) generate the dataset given the config"
        raise NotImplementedError

    @classmethod
    def download(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) download the dataset given the config"
        raise NotImplementedError

    # filtering
    def update_self_config(self) -> None:
        """(implement in subclass!) update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    def __len__(self) -> int:
        "return the length of the dataset"
        raise NotImplementedError("implement in subclass!")

    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "T_Dataset") -> None:
            "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
            self.dataset: T_Dataset = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
            "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
            filter_func: DatasetFilterFunc = getattr(
                self.dataset._FILTER_NAMESPACE,
                name,
            )

            def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
        return self.FilterBy(self)

    def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: T_Dataset = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    raise ValueError(
                        err_msg,
                    )
                err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
                raise ValueError(
                    err_msg,
                )
            filter_args: list = filter_info.get("args", list())
            filter_kwargs: dict = filter_info.get("kwargs", dict())
            output = getattr(output.filter_by, filter_name)(
                *filter_args,
                **filter_kwargs,
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(
            filters_old=applied_filters_old,
            filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
        )
        return output
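
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal `GPTDataset`
# subclass plus a `from_config` call that exercises only the generation path.
# `MyDataset` and the `MyDatasetConfig` sketched earlier are hypothetical;
# `serialize`, `load`, and `download` are left unimplemented, so local saving
# and downloading are disabled in the call below.
# --------------------------------------------------------------------------
class MyDataset(GPTDataset[MyDatasetConfig]):
    "hypothetical dataset: each item is a short token sequence"

    def __init__(self, cfg: MyDatasetConfig, items: list[list[str]]) -> None:
        self.cfg: MyDatasetConfig = cfg
        self.items: list[list[str]] = items

    @classmethod
    def generate(cls, cfg: MyDatasetConfig, **kwargs) -> "MyDataset":
        # kwargs (e.g. `verbose`) are forwarded by `from_config` and ignored here
        return cls(cfg, items=[["item", str(i)] for i in range(cfg.n_items)])

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> list[str]:
        return self.items[i]

    def update_self_config(self) -> None:
        # keep the config consistent with the data, e.g. after filtering
        self.cfg.n_items = len(self.items)


def _example_from_config_usage() -> "MyDataset":
    "sketch of the generate-only path through `from_config`"
    cfg = MyDatasetConfig(name="demo", n_items=10)
    return MyDataset.from_config(
        cfg,
        load_local=False,  # nothing has been saved locally in this sketch
        do_download=False,  # `download` is not implemented here
        save_local=False,  # `serialize` is not implemented here
    )
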

def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_new, filterinfo_old in zip(
            filters_new,
            filters_old,
            strict=False,
        ):
            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(
                filterinfo_new["args"],
                filterinfo_old["args"],
                strict=False,
            ):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        err_msg: str = (
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        )
        raise FilterInfoMismatchError(
            err_msg,
        ) from e


def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator


T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]


def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?
    # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(
        # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type]
        dataset: T_Dataset,
        *args: P_FilterKwargs.args,  # type: ignore[valid-type]
        **kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
    ) -> T_Dataset:
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
        )
        new_dataset.update_self_config()
        return new_dataset

    # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value]
    return wrapper  # type: ignore[return-value]
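
# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): registering a filter
# namespace for the hypothetical `MyDataset` above and calling it through
# `dataset.filter_by`. `MyFilters` and `drop_empty` are made-up names; the
# decorator order follows the docstring of `register_dataset_filter` (the
# filter is a staticmethod, wrapped so each application is recorded in
# `cfg.applied_filters`).
# --------------------------------------------------------------------------
@register_filter_namespace_for_dataset(MyDataset)
class MyFilters:
    @register_dataset_filter
    @staticmethod
    def drop_empty(dataset: MyDataset) -> MyDataset:
        "hypothetical filter: drop empty token sequences (a real filter should also copy the config)"
        return MyDataset(
            cfg=dataset.cfg,
            items=[item for item in dataset.items if len(item) > 0],
        )


def _example_filter_usage(dataset: MyDataset) -> MyDataset:
    # `filter_by` resolves `drop_empty` on the registered namespace; the wrapper
    # added by `register_dataset_filter` appends the call to `cfg.applied_filters`
    # and then calls `update_self_config`
    return dataset.filter_by.drop_empty()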