Coverage for maze_dataset/dataset/dataset.py: 40%

1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets 

2 

3they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering 

4 

5> [!NOTE] 

6> these should probably be moved into a different package, so don't rely on them being here 

7""" 

8 

9import functools 

10import json 

11import random 

12import typing 

13import warnings 

14from pathlib import Path 

15from typing import Callable, Type, TypeVar 

16 

17import numpy as np 

18from muutils.json_serialize import ( 

19 JSONitem, 

20 SerializableDataclass, 

21 serializable_dataclass, 

22 serializable_field, 

23) 

24from muutils.json_serialize.util import ( 

25 JSONdict, 

26) 

27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 

28from typing_extensions import Self 

29from zanj import ZANJ 

30 

31from maze_dataset.generation.seed import GLOBAL_SEED 

32 

33 

34def set_reproducibility(seed: int) -> None: 

35 "set reproducibility in stdlib random and numpy (but not torch)" 

36 random.seed(seed) 

37 np.random.seed(seed) 

38 

39 

40class FilterInfoMismatchError(ValueError): 

41 """raised when the filter info in a dataset config does not match the filter info in the dataset""" 

42 

43 pass 

44 

45 

46def _load_applied_filters( 

47 filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]], 

48) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]: 

49 try: 

50 return [ 

51 dict( 

52 name=filter_info["name"], 

53 args=tuple( 

54 filter_info["args"], 

55 ), # muutils/zanj save tuples as lists, and this causes problems 

56 kwargs=dict(filter_info["kwargs"]), # type: ignore[arg-type] 

57 ) 

58 for filter_info in filters 

59 ] 

60 except Exception as e: 

61 err_msg: str = f"failed to load applied filters:\n{filters}" 

62 raise ValueError(err_msg) from e 

63 
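
# a hedged illustration of the normalization above (the filter name and values
# are made up, not from a real dataset): zanj/muutils round-trip tuples as lists,
# so a record saved as
#     {"name": "some_filter", "args": [3], "kwargs": {"strict": True}}
# comes back from `_load_applied_filters` with `args` restored to a tuple:
#     {"name": "some_filter", "args": (3,), "kwargs": {"strict": True}}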

@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=GLOBAL_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self) -> None:
        "post init, where we set a random seed if none is set"
        assert self.seq_len_min <= self.seq_len_max
        # if seed is set to None, then generate a new random seed
        if self.seed is None:
            self.seed = np.random.randint(2**31)

        # TODO: something here is broken
        if self.seed != GLOBAL_SEED:
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @property
    def _dataset_class(self) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)

        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )

        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )
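
# a minimal sketch of a config subclass (hypothetical names, not part of this
# module); subclasses are expected to provide `_dataset_class`, and usually a
# more informative `to_fname` than the fallback above:
#
#     @serializable_dataclass(kw_only=True)
#     class MyDatasetConfig(GPTDatasetConfig):
#         n_items: int = serializable_field(default=1000)
#
#         @property
#         def _dataset_class(self) -> type:
#             return MyDataset
#
#         def to_fname(self) -> str:
#             return sanitize_fname(f"{self.name}-n{self.n_items}")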

def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


# abstract function, hence we don't care that `self` is unused
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
    err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)


class GPTDataset(typing.Generic[T_DatasetConfig]):
    """wrapper for torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses:
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. kwargs are not passed through; they should contain the actual generated or loaded data (a list of objects or sequences, probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
        get all items in the dataset, in the specified format
    - `filter_by(self)`
        returns a namespace class
    - `_filter_namespace(self) -> Class`
        returns a namespace class for filtering the dataset, checking that method
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
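
    # Example:
    a minimal sketch of a subclass, using hypothetical names (`MyDataset`, its config `MyDatasetConfig` with an `n_items` field, and an `items` attribute) that are not part of this module:

    ```python
    class MyDataset(GPTDataset[MyDatasetConfig]):
        def __init__(self, cfg: MyDatasetConfig, items: list[list[str]]) -> None:
            self.cfg = cfg
            self.items = items  # each item is a list of tokens

        @classmethod
        def generate(cls, cfg: MyDatasetConfig, **kwargs) -> "MyDataset":
            # placeholder generation -- real subclasses build actual data here
            return cls(cfg, items=[["<token>"] for _ in range(cfg.n_items)])

        def serialize(self) -> JSONdict:
            return {"cfg": self.cfg.serialize(), "items": self.items}

        @classmethod
        def load(cls, data: JSONdict) -> "MyDataset":
            return cls(MyDatasetConfig.load(data["cfg"]), items=data["items"])

        def __len__(self) -> int:
            return len(self.items)

        def __getitem__(self, i: int) -> list[str]:
            return self.items[i]

        def update_self_config(self) -> None:
            self.cfg.n_items = len(self.items)
    ```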

217 """ 

218 

219 _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`" # type: ignore 

220 

221 cfg: "T_DatasetConfig" 

222 

223 @classmethod 

224 def from_config( # noqa: C901, PLR0912 

225 cls, 

226 cfg: "T_DatasetConfig", 

227 do_generate: bool = True, 

228 load_local: bool = True, 

229 save_local: bool = True, 

230 zanj: ZANJ | None = None, 

231 do_download: bool = True, 

232 local_base_path: Path = Path("data/maze_dataset"), 

233 except_on_config_mismatch: bool = True, 

234 allow_generation_metadata_filter_mismatch: bool = True, 

235 verbose: bool = False, 

236 **kwargs, 

237 ) -> "Self": 

238 """base class for gpt datasets 

        priority of loading:
        1. load from local
        2. download
        3. generate
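
        a hedged usage sketch (the `MyDataset` / `MyDatasetConfig` names are hypothetical, not part of this module):

        ```python
        cfg = MyDatasetConfig(name="demo", n_items=100)
        dataset = MyDataset.from_config(
            cfg,
            load_local=True,    # reuse a saved copy if one exists
            do_download=False,  # skip the download step
            do_generate=True,   # otherwise generate from the config
        )
        ```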

245 """ 

246 print_log: Callable = print if verbose else lambda *_a, **_kw: None 

247 

248 local_base_path = Path(local_base_path) 

249 fname: Path = Path(f"{cfg.to_fname()}.zanj") 

250 output: Self | None = None 

251 did_load_local: bool = False 

252 if zanj is None: 

253 zanj = ZANJ() 

254 

255 print_log(f"trying to get the dataset '{cfg.to_fname()}'") 

256 

257 if not (load_local or do_download or do_generate): 

258 raise ValueError( 

259 "no way to load dataset! you said not to load local, not to download, and not to generate", 

260 ) 

261 

262 dataset_path: Path = local_base_path / fname 

263 

264 # try loading 

265 if load_local: # noqa: SIM102 

266 if dataset_path.exists(): 

267 print_log(f"loading dataset from {dataset_path.as_posix()}") 

268 try: 

269 output = cls.read(dataset_path, zanj=zanj) 

270 did_load_local = True 

271 print_log("load successful!") 

272 except Exception as e: # noqa: BLE001 

273 print_log(f"failed to load dataset: {e}") 

274 

275 if do_download and output is None: 

276 print_log("seeing if we can download the dataset...") 

277 try: 

278 output = cls.download(cfg, **kwargs) 

279 print_log("download successful!") 

280 except NotImplementedError: 

281 print_log("no download found, or download failed") 

282 

283 if do_generate and output is None: 

284 print_log("generating dataset...") 

285 output = cls.generate(cfg, verbose=verbose, **kwargs) 

286 # only if we generated it, apply filters 

287 output = output._apply_filters_from_config() 

288 

289 # check and save 

290 if output is None: 

291 raise ValueError("failed to load dataset!") 

292 

293 cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True) 

294 if cfg_diff: 

295 if except_on_config_mismatch: 

296 if allow_generation_metadata_filter_mismatch and ( 

297 cfg_diff 

298 == { 

299 "applied_filters": { 

300 "self": [], 

301 "other": [ 

302 { 

303 "name": "collect_generation_meta", 

304 "args": (), 

305 "kwargs": {}, 

306 }, 

307 ], 

308 }, 

309 } 

310 ): 

311 pass 

312 else: 

313 err_msg: str = f"config mismatch: {cfg_diff = }" 

314 raise ValueError(err_msg) 

315 else: 

316 warnings.warn(f"config mismatch: {cfg_diff = }") 

317 

318 if save_local and not did_load_local: 

319 print_log(f"saving dataset to {dataset_path}") 

320 output.save(dataset_path, zanj=zanj) 

321 

322 print_log( 

323 f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }", 

324 ) 

325 return output 

326 

327 def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None: 

328 "save dataset to a file with zanj" 

329 if zanj is None: 

330 zanj = ZANJ() 

331 zanj.save(self.serialize(), file_path) 

332 

333 # serialization & loading 

334 @classmethod 

335 def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self": 

336 "read dataset from a file with zanj" 

337 if zanj is None: 

338 zanj = ZANJ() 

339 return zanj.read(file_path) 

340 

341 def serialize(self) -> JSONdict: 

342 "(implement in subclass!) serialize to something we can save with zanj" 

343 raise NotImplementedError 

344 

345 def data_hash(self) -> int: 

346 "(implement in subclass!) return a hash of the data" 

347 raise NotImplementedError 

348 

349 @classmethod 

350 def load(cls, data: JSONdict) -> "Self": 

351 "(implement in subclass!) load a dataset from what we made with `.serialize()`" 

352 raise NotImplementedError 

353 

354 # generating & downloading 

355 @classmethod 

356 def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self": 

357 "(implement in subclass!) generative given the config" 

        raise NotImplementedError

    @classmethod
    def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
        "(implement in subclass!) download the dataset given the config"
        raise NotImplementedError

    # filtering
    def update_self_config(self) -> None:
        """(implement in subclass!) update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    def __len__(self) -> int:
        "return the length of the dataset"
        raise NotImplementedError("implement in subclass!")

    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "T_Dataset") -> None:
            "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
            self.dataset: T_Dataset = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
            "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
            filter_func: DatasetFilterFunc = getattr(
                self.dataset._FILTER_NAMESPACE,
                name,
            )

            def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
        return self.FilterBy(self)
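
    # a hedged usage sketch: any filter registered on this dataset's filter
    # namespace can be called through `filter_by`; `collect_generation_meta` is
    # the filter name referenced in `from_config` above, and it takes no arguments:
    #
    #     filtered = dataset.filter_by.collect_generation_meta()
    #     print(filtered.cfg.applied_filters)  # records the filter that was applied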

    def _apply_filters_from_config(self) -> "Self":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: Self = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    raise ValueError(
                        err_msg,
                    )
                err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
                raise ValueError(
                    err_msg,
                )
            filter_args: list = filter_info.get("args", list())
            filter_kwargs: dict = filter_info.get("kwargs", dict())
            output = getattr(output.filter_by, filter_name)(
                *filter_args,
                **filter_kwargs,
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(
            filters_old=applied_filters_old,
            filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
        )
        return output


def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_old, filterinfo_new in zip(
            filters_old,
            filters_new,
            strict=False,
        ):

            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(
                filterinfo_new["args"],
                filterinfo_old["args"],
                strict=False,
            ):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        err_msg: str = (
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        )
        raise FilterInfoMismatchError(
            err_msg,
        ) from e


def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator


T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]


def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?
    # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(
        # TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
        dataset: T_Dataset,
        *args: P_FilterKwargs.args,  # type: ignore[valid-type]
        **kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
    ) -> T_Dataset:
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
        )
        new_dataset.update_self_config()
        return new_dataset

    # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
    return wrapper  # type: ignore[return-value]
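

# a hedged sketch of how the two registration helpers fit together (`MyDataset`
# and the `keep_first_n` filter are hypothetical, not part of this module):
#
#     @register_filter_namespace_for_dataset(MyDataset)
#     class MyDatasetFilters:
#         @register_dataset_filter
#         @staticmethod
#         def keep_first_n(dataset: MyDataset, n: int) -> MyDataset:
#             # return a new dataset object rather than mutating the original
#             return MyDataset(cfg=dataset.cfg, items=dataset.items[:n])
#
# after registration, the filter can be called either directly
# (`MyDatasetFilters.keep_first_n(dataset, 10)`) or through the `filter_by`
# namespace (`dataset.filter_by.keep_first_n(10)`); the `register_dataset_filter`
# wrapper appends an entry to `cfg.applied_filters` on each call.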