Coverage for maze_dataset/dataset/dataset.py: 40%
201 statements
coverage.py v7.6.12, created at 2025-03-24 00:33 -0600
1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets
3they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering
5> [!NOTE]
6> these should probably be moved into a different package, so don't rely on them being here
7"""
9import functools
10import json
11import random
12import typing
13import warnings
14from pathlib import Path
15from typing import Callable, Type, TypeVar
17import numpy as np
18from muutils.json_serialize import (
19 JSONitem,
20 SerializableDataclass,
21 serializable_dataclass,
22 serializable_field,
23)
24from muutils.json_serialize.util import (
25 JSONdict,
26)
27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
28from zanj import ZANJ
30from maze_dataset.generation.seed import GLOBAL_SEED
def set_reproducibility(seed: int) -> None:
    "set reproducibility in stdlib random and numpy (but not torch)"
    random.seed(seed)
    np.random.seed(seed)


class FilterInfoMismatchError(ValueError):
    """raised when the filter info in a dataset config does not match the filter info in the dataset"""

    pass
def _load_applied_filters(
    filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]],
) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]:
    try:
        return [
            dict(
                name=filter_info["name"],
                args=tuple(
                    filter_info["args"],
                ),  # muutils/zanj save tuples as lists, and this causes problems
                kwargs=dict(filter_info["kwargs"]),  # type: ignore[arg-type]
            )
            for filter_info in filters
        ]
    except Exception as e:
        err_msg: str = f"failed to load applied filters:\n{filters}"
        raise ValueError(err_msg) from e
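
# Hedged example (not from the original source): `zanj` round-trips tuples as
# lists, so the loader above converts `args` back to a tuple. The filter name
# below is purely illustrative.
#
#   _load_applied_filters(
#       [{"name": "cut_percentile", "args": [50], "kwargs": {}}]
#   )
#   # -> [{"name": "cut_percentile", "args": (50,), "kwargs": {}}]
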
@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=GLOBAL_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )
    def __post_init__(self) -> None:
        "post init, where we set a random seed if none is set"
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = np.random.randint(2**31)

        # TODO: something here is broken
        if self.seed != GLOBAL_SEED:
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
            )

        set_reproducibility(self.seed)
    def summary(self) -> dict:
        """return a summary of the config"""
        # serialize here to make sure it doesn't error and returns something truthy
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )
    @property
    def _dataset_class(self) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )
        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )
def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


# abstract function, hence we don't care that `self` is unused
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
    err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)
class GPTDataset(typing.Generic[T_DatasetConfig]):
    """wrapper for torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses (see the sketch after this class definition):
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. kwargs are not passed through; they should carry the actual generated or loaded data (a list of objects or sequences, probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required to match the interface of `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required to match the interface of `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
    - `filter_by(self)`
        returns a namespace class
    - `_filter_namespace(self) -> Class`
        returns the namespace class used for filtering the dataset
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()`, but only when generating
    """

    _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore

    cfg: "T_DatasetConfig"
    @classmethod
    def from_config(  # noqa: C901, PLR0912
        cls: "type[T_Dataset]",
        cfg: "T_DatasetConfig",
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "T_Dataset":
        """get a dataset according to the given config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        print_log: Callable = print if verbose else lambda *_a, **_kw: None

        local_base_path = Path(local_base_path)
        fname: Path = Path(f"{cfg.to_fname()}.zanj")
        output: T_Dataset | None = None
        did_load_local: bool = False
        if zanj is None:
            zanj = ZANJ()

        print_log(f"trying to get the dataset '{cfg.to_fname()}'")

        if not (load_local or do_download or do_generate):
            raise ValueError(
                "no way to load dataset! you said not to load local, not to download, and not to generate",
            )

        dataset_path: Path = local_base_path / fname

        # try loading
        if load_local:  # noqa: SIM102
            if dataset_path.exists():
                print_log(f"loading dataset from {dataset_path.as_posix()}")
                try:
                    output = cls.read(dataset_path, zanj=zanj)
                    did_load_local = True
                    print_log("load successful!")
                except Exception as e:  # noqa: BLE001
                    print_log(f"failed to load dataset: {e}")

        if do_download and output is None:
            print_log("seeing if we can download the dataset...")
            try:
                output = cls.download(cfg, **kwargs)
                print_log("download successful!")
            except NotImplementedError:
                print_log("no download found, or download failed")

        if do_generate and output is None:
            print_log("generating dataset...")
            output = cls.generate(cfg, verbose=verbose, **kwargs)
            # only if we generated it, apply filters
            output = output._apply_filters_from_config()

        # check and save
        if output is None:
            raise ValueError("failed to load dataset!")
        cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
        if cfg_diff:
            if except_on_config_mismatch:
                if allow_generation_metadata_filter_mismatch and (
                    cfg_diff
                    == {
                        "applied_filters": {
                            "self": [],
                            "other": [
                                {
                                    "name": "collect_generation_meta",
                                    "args": (),
                                    "kwargs": {},
                                },
                            ],
                        },
                    }
                ):
                    pass
                else:
                    err_msg: str = f"config mismatch: {cfg_diff = }"
                    raise ValueError(err_msg)
            else:
                warnings.warn(f"config mismatch: {cfg_diff = }")

        if save_local and not did_load_local:
            print_log(f"saving dataset to {dataset_path}")
            output.save(dataset_path, zanj=zanj)

        print_log(
            f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
        )
        return output
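
    # Hedged usage sketch (not part of the original source): the typical call
    # pattern for a concrete subclass. `MazeDataset` / `MazeDatasetConfig` from
    # the top-level `maze_dataset` package are assumed here; the parameters
    # shown are illustrative.
    #
    #   cfg = MazeDatasetConfig(name="demo", grid_n=5, n_mazes=4, maze_ctor=...)
    #   dataset = MazeDataset.from_config(
    #       cfg,
    #       load_local=True,   # 1. try data/maze_dataset/<cfg.to_fname()>.zanj
    #       do_download=True,  # 2. then try `download()`
    #       do_generate=True,  # 3. finally fall back to `generate()`
    #   )
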
    def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
        "save dataset to a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    # serialization & loading
    @classmethod
    def read(
        cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
    ) -> "T_Dataset":
        "read dataset from a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)
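
    # Hedged example (not from the original source): an explicit save/read
    # round trip with a shared ZANJ instance; `dataset` is any concrete
    # subclass instance and the path is illustrative.
    #
    #   zanj = ZANJ()
    #   dataset.save("data/maze_dataset/demo.zanj", zanj=zanj)
    #   dataset_2 = type(dataset).read("data/maze_dataset/demo.zanj", zanj=zanj)
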
    def serialize(self: "T_Dataset") -> JSONdict:
        "(implement in subclass!) serialize to something we can save with zanj"
        raise NotImplementedError

    def data_hash(self: "T_Dataset") -> int:
        "(implement in subclass!) return a hash of the data"
        raise NotImplementedError

    @classmethod
    def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
        "(implement in subclass!) load a dataset from what we made with `.serialize()`"
        raise NotImplementedError

    # generating & downloading
    @classmethod
    def generate(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) generate the dataset given the config"
        raise NotImplementedError

    @classmethod
    def download(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) download the dataset given the config"
        raise NotImplementedError
    # filtering
    def update_self_config(self) -> None:
        """(implement in subclass!) update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    def __len__(self) -> int:
        "return the length of the dataset"
        raise NotImplementedError("implement in subclass!")

    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "T_Dataset") -> None:
            "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
            self.dataset: T_Dataset = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
            "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
            filter_func: DatasetFilterFunc = getattr(
                self.dataset._FILTER_NAMESPACE,
                name,
            )

            def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
        return self.FilterBy(self)
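
    # Hedged example (not from the original source): once a filter namespace is
    # registered for the subclass, any registered filter is reachable through
    # this property. `cut_percentile` is an illustrative filter name.
    #
    #   filtered_dataset = dataset.filter_by.cut_percentile(50)
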
    def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: T_Dataset = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    err_msg = f"the dataset {output.cfg.to_fname()} was filtered using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    raise ValueError(
                        err_msg,
                    )
                err_msg = f"the dataset {output.cfg.to_fname()} was filtered using an unknown filter: '{filter_name}'"
                raise ValueError(
                    err_msg,
                )
            filter_args: list = filter_info.get("args", list())
            filter_kwargs: dict = filter_info.get("kwargs", dict())
            output = getattr(output.filter_by, filter_name)(
                *filter_args,
                **filter_kwargs,
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(
            filters_old=applied_filters_old,
            filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
        )
        return output
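

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module): a minimal subclass pair
# illustrating the contract described in the `GPTDataset` docstring above.
# `_ExampleDatasetConfig` and `_ExampleDataset` are hypothetical names, the
# "data" is just a list of strings, and `download` is deliberately left to the
# base class so `from_config` falls through to `generate`.
@serializable_dataclass(kw_only=True)
class _ExampleDatasetConfig(GPTDatasetConfig):
    n_items: int = serializable_field(default=4)

    @property
    def _dataset_class(self) -> type:
        return _ExampleDataset

    def to_fname(self) -> str:
        # avoid the fallback `to_fname`, which calls `len()` on the config
        return sanitize_fname(f"example-n{self.n_items}-seed{self.seed}")


class _ExampleDataset(GPTDataset[_ExampleDatasetConfig]):
    def __init__(self, cfg: _ExampleDatasetConfig, items: list[str]) -> None:
        self.cfg = cfg
        self.items = items

    @classmethod
    def generate(cls, cfg: _ExampleDatasetConfig, **kwargs) -> "_ExampleDataset":
        return cls(cfg, items=[f"item_{i}" for i in range(cfg.n_items)])

    def serialize(self) -> JSONdict:
        return dict(cfg=self.cfg.serialize(), items=self.items)

    @classmethod
    def load(cls, data: JSONdict) -> "_ExampleDataset":
        # assumes the `serializable_dataclass` machinery provides `load` on the
        # config subclass (the base-class `load` stub is overridden there)
        return cls(
            cfg=_ExampleDatasetConfig.load(data["cfg"]),
            items=list(data["items"]),
        )

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> str:
        return self.items[i]

    def update_self_config(self) -> None:
        self.cfg.n_items = len(self)

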
def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_new, filterinfo_old in zip(
            filters_new,
            filters_old,
            strict=False,
        ):
            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(
                filterinfo_new["args"],
                filterinfo_old["args"],
                strict=False,
            ):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        err_msg: str = (
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        )
        raise FilterInfoMismatchError(
            err_msg,
        ) from e
def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator


T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]
def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?
    # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(
        # TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
        dataset: T_Dataset,
        *args: P_FilterKwargs.args,  # type: ignore[valid-type]
        **kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
    ) -> T_Dataset:
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
        )
        new_dataset.update_self_config()
        return new_dataset

    # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
    return wrapper  # type: ignore[return-value]
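

# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module): how the two registration
# helpers fit together, using the hypothetical `_ExampleDataset` defined in
# the sketch above. `limit_items` is an illustrative filter; applying
# `@staticmethod` outermost is one workable ordering for this sketch.
@register_filter_namespace_for_dataset(_ExampleDataset)
class _ExampleDatasetFilters:
    @staticmethod
    @register_dataset_filter
    def limit_items(dataset: _ExampleDataset, n: int) -> _ExampleDataset:
        import copy

        # return a copy so the registered wrapper's config update does not
        # mutate the config of the dataset that was passed in
        return _ExampleDataset(cfg=copy.deepcopy(dataset.cfg), items=dataset.items[:n])


# with the namespace registered, the filter is reachable via `filter_by`, and
# `applied_filters` in the new config is updated automatically:
#
#   dataset = _ExampleDataset.generate(_ExampleDatasetConfig(name="demo"))
#   smaller = dataset.filter_by.limit_items(2)
#   assert len(smaller) == 2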