Coverage for maze_dataset/dataset/dataset.py: 40%
202 statements
coverage.py v7.10.1, created at 2025-08-03 21:38 -0700
1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets
3they implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering
5> [!NOTE]
6> these should probably be moved into a different package, so don't rely on them being here
7"""
9import functools
10import json
11import random
12import typing
13import warnings
14from pathlib import Path
15from typing import Callable, Type, TypeVar
17import numpy as np
18from muutils.json_serialize import (
19 JSONitem,
20 SerializableDataclass,
21 serializable_dataclass,
22 serializable_field,
23)
24from muutils.json_serialize.util import (
25 JSONdict,
26)
27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
28from typing_extensions import Self
29from zanj import ZANJ
31from maze_dataset.generation.seed import GLOBAL_SEED
34def set_reproducibility(seed: int) -> None:
35 "set reproducibility in stdlib random and numpy (but not torch)"
36 random.seed(seed)
37 np.random.seed(seed)
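# illustrative sketch (added for documentation, not part of the original module):
# seeding twice with the same value makes both stdlib `random` and numpy draws
# repeat exactly; torch is deliberately left alone, as the docstring above notes.
def _demo_set_reproducibility() -> None:
    set_reproducibility(42)
    first: tuple[float, float] = (random.random(), float(np.random.rand()))
    set_reproducibility(42)
    assert first == (random.random(), float(np.random.rand()))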
40class FilterInfoMismatchError(ValueError):
41 """raised when the filter info in a dataset config does not match the filter info in the dataset"""
43 pass
46def _load_applied_filters(
47 filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]],
48) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]:
49 try:
50 return [
51 dict(
52 name=filter_info["name"],
53 args=tuple(
54 filter_info["args"],
55 ), # muutils/zanj save tuples as lists, and this causes problems
56 kwargs=dict(filter_info["kwargs"]), # type: ignore[arg-type]
57 )
58 for filter_info in filters
59 ]
60 except Exception as e:
61 err_msg: str = f"failed to load applied filters:\n{filters}"
62 raise ValueError(err_msg) from e
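# illustrative sketch (added for documentation): muutils/zanj serialize tuples as
# lists, so the loader above coerces `args` back into a tuple; the filter name
# below is made up purely for the example.
def _demo_load_applied_filters() -> None:
    loaded = _load_applied_filters(
        [dict(name="some_filter", args=[10.0], kwargs=dict(strict=True))],
    )
    assert loaded[0]["args"] == (10.0,)
    assert loaded[0]["kwargs"] == dict(strict=True)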
65@serializable_dataclass(kw_only=True)
66class GPTDatasetConfig(SerializableDataclass):
67 """base GPTDatasetConfig class"""
69 name: str
71 # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
72 # --------------------------------------------------
73 seq_len_min: int = serializable_field(default=1)
74 seq_len_max: int = serializable_field(default=512)
75 # --------------------------------------------------
77 seed: int | None = serializable_field(default=GLOBAL_SEED)
78 applied_filters: list[
79 dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
80 ] = serializable_field(
81 default_factory=list,
82 deserialize_fn=_load_applied_filters,
83 assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures
84 )
86 def __post_init__(self) -> None:
87 "post init, where we set a random seed if none is set"
88 assert self.seq_len_min <= self.seq_len_max
89 # if seed set to None, then generate a new random seed
90 if self.seed is None:
91 self.seed = np.random.randint(2**31)
93 # TODO: something here is broken
94 if self.seed != GLOBAL_SEED:
95 warnings.warn(
96 f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
97 )
99 set_reproducibility(self.seed)
101 def summary(self) -> dict:
102 """return a summary of the config"""
103 # run serialization here to make sure it doesn't error
104 self_ser: dict = self.serialize()
105 assert self_ser
106 return dict(
107 name=self.name,
108 seq_len_min=self.seq_len_min,
109 seq_len_max=self.seq_len_max,
110 seed=self.seed,
111 applied_filters=self.applied_filters,
112 )
114 @property
115 def _dataset_class(self) -> type:
116 raise NotImplementedError("this should be implemented by subclasses!")
118 def to_fname(self) -> str:
119 """convert config to a filename"""
120 self_json_str: str = json.dumps(self.serialize())
121 self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
122 warnings.warn(
123 f"using fallblack to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
124 )
125 return sanitize_fname(
126 # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
127 f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}", # type: ignore[arg-type]
128 )
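# a hedged sketch of what a concrete config subclass might look like; `_ToyDatasetConfig`
# and its `n_items` field are hypothetical, not part of maze_dataset. the key points are
# overriding `_dataset_class` and `to_fname` so the fallback warning above never fires.
@serializable_dataclass(kw_only=True)
class _ToyDatasetConfig(GPTDatasetConfig):
    n_items: int = serializable_field(default=100)

    @property
    def _dataset_class(self) -> type:
        # a real subclass returns its own dataset class here (see `_ToyDataset` further below);
        # the base `GPTDataset` is used only as a placeholder in this sketch
        return GPTDataset

    def to_fname(self) -> str:
        self_json_hash: int = int(abs(stable_hash(json.dumps(self.serialize()))) % 1e10)
        return sanitize_fname(
            f"{self.name}-n{shorten_numerical_to_str(self.n_items)}-h{self_json_hash}",
        )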
131def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
132 err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
133 raise NotImplementedError(
134 err_msg,
135 )
138# abstract function, hence we don't care that `self` is unused
139def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem: # noqa: ANN001, ARG001
140 err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
141 raise NotImplementedError(
142 err_msg,
143 )
146GPTDatasetConfig.load = _dataset_config_load # type: ignore[method-assign]
147GPTDatasetConfig.serialize = _dataset_config_serialize # type: ignore[method-assign,assignment]
148T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)
151class GPTDataset(typing.Generic[T_DatasetConfig]):
152 """wrapper for torch dataset with some extra functionality
154 (meaning the functionality should be inherited in downstream classes)
156 > [!NOTE]
157 > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
159 # Requires:
160 the following methods should be implemented in subclasses (a minimal subclass sketch is given after this class definition):
161 - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
162 initialize the dataset from a given config. kwargs are not passed through from `from_config`; they should carry the actual generated or loaded data (probably a list of objects or sequences)
163 - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
164 generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
165 - `serialize(self) -> JSONitem`
166 serialize the dataset to a ZANJ-serializable object, including:
167 - config
168 - data in formats specified by `self.save_formats`
169 - `load(cls, data: JSONitem) -> GPTDataset`
170 load the dataset from a ZANJ-serializable object
171 - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
172 given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
173 - `__len__(self) -> int`
174 return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
175 - `__getitem__(self, i: int) -> list[str]`
176 return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
177 - `update_self_config(self) -> None`
178 update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
179 - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters
181 # Parameters:
182 - `cfg : GPTDatasetConfig`
183 config for the dataset, used to generate the dataset
184 - `do_generate : bool`
185 whether to generate the dataset if it isn't found
186 (defaults to `True`)
187 - `load_local : bool`
188 whether to try finding the dataset locally
189 (defaults to `True`)
190 - `save_local : bool`
191 whether to save the dataset locally if it is generated or downloaded
192 (defaults to `True`)
193 - `do_download : bool`
194 whether to try downloading the dataset
195 (defaults to `True`)
196 - `local_base_path : Path`
197 where to save the dataset
198 (defaults to `Path("data/maze_dataset")`)
200 # Returns:
201 - `GPTDataset`
202 the dataset, as you wanted it
204 # Implements:
205 - `save(self, file_path: str) -> None`
206 save the dataset to a file, using ZANJ
207 - `read(cls, file_path: str) -> GPTDataset`
208 read the dataset from a file, using ZANJ
209 (get all items in the dataset, in the specified format)
210 - `filter_by(self)`
211 returns a namespace class
212 - `_filter_namespace(self) -> Class`
213 returns a namespace class for filtering the dataset, checking that the requested filter method is registered
214 - `_apply_filters_from_config(self) -> None`
215 apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
217 """
219 _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`" # type: ignore
221 cfg: "T_DatasetConfig"
223 @classmethod
224 def from_config( # noqa: C901, PLR0912
225 cls,
226 cfg: "T_DatasetConfig",
227 do_generate: bool = True,
228 load_local: bool = True,
229 save_local: bool = True,
230 zanj: ZANJ | None = None,
231 do_download: bool = True,
232 local_base_path: Path = Path("data/maze_dataset"),
233 except_on_config_mismatch: bool = True,
234 allow_generation_metadata_filter_mismatch: bool = True,
235 verbose: bool = False,
236 **kwargs,
237 ) -> "Self":
238 """base class for gpt datasets
240 priority of loading:
241 1. load from local
242 2. download
243 3. generate
245 """
246 print_log: Callable = print if verbose else lambda *_a, **_kw: None
248 local_base_path = Path(local_base_path)
249 fname: Path = Path(f"{cfg.to_fname()}.zanj")
250 output: Self | None = None
251 did_load_local: bool = False
252 if zanj is None:
253 zanj = ZANJ()
255 print_log(f"trying to get the dataset '{cfg.to_fname()}'")
257 if not (load_local or do_download or do_generate):
258 raise ValueError(
259 "no way to load dataset! you said not to load local, not to download, and not to generate",
260 )
262 dataset_path: Path = local_base_path / fname
264 # try loading
265 if load_local: # noqa: SIM102
266 if dataset_path.exists():
267 print_log(f"loading dataset from {dataset_path.as_posix()}")
268 try:
269 output = cls.read(dataset_path, zanj=zanj)
270 did_load_local = True
271 print_log("load successful!")
272 except Exception as e: # noqa: BLE001
273 print_log(f"failed to load dataset: {e}")
275 if do_download and output is None:
276 print_log("seeing if we can download the dataset...")
277 try:
278 output = cls.download(cfg, **kwargs)
279 print_log("download successful!")
280 except NotImplementedError:
281 print_log("no download found, or download failed")
283 if do_generate and output is None:
284 print_log("generating dataset...")
285 output = cls.generate(cfg, verbose=verbose, **kwargs)
286 # only if we generated it, apply filters
287 output = output._apply_filters_from_config()
289 # check and save
290 if output is None:
291 raise ValueError("failed to load dataset!")
293 cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
294 if cfg_diff:
295 if except_on_config_mismatch:
296 if allow_generation_metadata_filter_mismatch and (
297 cfg_diff
298 == {
299 "applied_filters": {
300 "self": [],
301 "other": [
302 {
303 "name": "collect_generation_meta",
304 "args": (),
305 "kwargs": {},
306 },
307 ],
308 },
309 }
310 ):
311 pass
312 else:
313 err_msg: str = f"config mismatch: {cfg_diff = }"
314 raise ValueError(err_msg)
315 else:
316 warnings.warn(f"config mismatch: {cfg_diff = }")
318 if save_local and not did_load_local:
319 print_log(f"saving dataset to {dataset_path}")
320 output.save(dataset_path, zanj=zanj)
322 print_log(
323 f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
324 )
325 return output
327 def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
328 "save dataset to a file with zanj"
329 if zanj is None:
330 zanj = ZANJ()
331 zanj.save(self.serialize(), file_path)
333 # serialization & loading
334 @classmethod
335 def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
336 "read dataset from a file with zanj"
337 if zanj is None:
338 zanj = ZANJ()
339 return zanj.read(file_path)
341 def serialize(self) -> JSONdict:
342 "(implement in subclass!) serialize to something we can save with zanj"
343 raise NotImplementedError
345 def data_hash(self) -> int:
346 "(implement in subclass!) return a hash of the data"
347 raise NotImplementedError
349 @classmethod
350 def load(cls, data: JSONdict) -> "Self":
351 "(implement in subclass!) load a dataset from what we made with `.serialize()`"
352 raise NotImplementedError
354 # generating & downloading
355 @classmethod
356 def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
357 "(implement in subclass!) generative given the config"
358 raise NotImplementedError
360 @classmethod
361 def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
362 "(implement in subclass!) download the dataset given the config"
363 raise NotImplementedError
365 # filtering
366 def update_self_config(self) -> None:
367 """(implement in subclass!) update the config of the dataset to match the actual data, if needed
369 for example, adjust number of mazes after filtering
370 """
371 pass
373 def __len__(self) -> int:
374 "return the length of the dataset"
375 raise NotImplementedError("implement in subclass!")
377 class FilterBy:
378 """thanks GPT-4"""
380 def __init__(self, dataset: "T_Dataset") -> None:
381 "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
382 self.dataset: T_Dataset = dataset
384 def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
385 "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
386 filter_func: DatasetFilterFunc = getattr(
387 self.dataset._FILTER_NAMESPACE,
388 name,
389 )
391 def wrapped_filter_func(*args, **kwargs): # noqa: ANN202
392 return filter_func(self.dataset, *args, **kwargs)
394 return wrapped_filter_func
396 @property
397 def filter_by(self) -> "FilterBy":
398 "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
399 return self.FilterBy(self)
401 def _apply_filters_from_config(self) -> "Self":
402 """apply filters to the dataset, as specified in the config. used in `from_config()`"""
403 output: Self = self
404 # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
405 applied_filters_old: list[
406 dict[typing.Literal["name", "args", "kwargs"], typing.Any]
407 ] = self.cfg.applied_filters
408 output.cfg.applied_filters = list()
409 # apply the filters
410 for filter_info in applied_filters_old:
411 filter_name: str = filter_info["name"]
412 if filter_name not in output._FILTER_NAMESPACE.__dict__:
413 if filter_name.startswith("__custom__:"):
414 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
415 raise ValueError(
416 err_msg,
417 )
418 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
419 raise ValueError(
420 err_msg,
421 )
422 filter_args: list = filter_info.get("args", list())
423 filter_kwargs: dict = filter_info.get("kwargs", dict())
424 output = getattr(output.filter_by, filter_name)(
425 *filter_args,
426 **filter_kwargs,
427 )
429 # update the config, perform checks
430 # TODO: some funny business with manually specified filters here?
431 output.update_self_config()
432 _check_filter_equality(
433 filters_old=applied_filters_old,
434 filters_new=output.cfg.applied_filters, # type: ignore[arg-type]
435 )
436 return output
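# a hedged, minimal sketch of the subclass contract described in the `GPTDataset`
# docstring above, using the hypothetical `_ToyDatasetConfig` defined earlier;
# `_ToyDataset` is illustrative only, and real datasets (like `MazeDataset`) do much more.
class _ToyDataset(GPTDataset[_ToyDatasetConfig]):
    def __init__(self, cfg: _ToyDatasetConfig, items: list[str]) -> None:
        self.cfg: _ToyDatasetConfig = cfg
        self.items: list[str] = items

    @classmethod
    def generate(cls, cfg: _ToyDatasetConfig, **kwargs) -> "_ToyDataset":
        # kwargs (e.g. `verbose`) are passed through from `from_config` and ignored here
        return cls(cfg=cfg, items=[f"item_{i}" for i in range(cfg.n_items)])

    def serialize(self) -> JSONdict:
        return dict(cfg=self.cfg.serialize(), items=self.items)

    @classmethod
    def load(cls, data: JSONdict) -> "_ToyDataset":
        return cls(cfg=_ToyDatasetConfig.load(data["cfg"]), items=list(data["items"]))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> str:
        return self.items[i]

    def update_self_config(self) -> None:
        # keep the config in sync with the data, e.g. after filtering
        self.cfg.n_items = len(self.items)

# typical `from_config` usage (commented out so the module stays import-safe);
# priority is load-from-local, then download, then generate, as documented above:
# toy_data = _ToyDataset.from_config(
#     cfg=_ToyDatasetConfig(name="toy", n_items=10),
#     load_local=False,
#     do_download=False,
#     do_generate=True,
#     save_local=False,
# )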
439def _check_filter_equality(
440 filters_old: list[
441 dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
442 ],
443 filters_new: list[
444 dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
445 ],
446) -> None:
447 try:
448 assert len(filters_old) == len(filters_new)
450 for filterinfo_old, filterinfo_new in zip(
451 filters_old,
452 filters_new,
453 strict=False,
454 ):
455 # basic checks
456 assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
457 assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
458 assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
459 "missing keys in filterinfo_new"
460 )
461 assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
462 "missing keys in filterinfo_old"
463 )
465 # name
466 assert filterinfo_new["name"] == filterinfo_old["name"], (
467 "filter names don't match"
468 )
470 # args
471 assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
472 "filter args of different lengths"
473 )
474 for arg_new, arg_old in zip(
475 filterinfo_new["args"],
476 filterinfo_old["args"],
477 strict=False,
478 ):
479 assert arg_new == arg_old, "filter args don't match"
481 # kwargs
482 assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
483 "filter kwargs of different lengths"
484 )
485 for key in filterinfo_old["kwargs"]:
486 assert key in filterinfo_new["kwargs"], (
487 f"filter kwargs don't match: missing key '{key}'"
488 )
489 assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], ( # type: ignore[index]
490 f"filter kwargs don't match: values for key '{key}' don't match"
491 )
493 except AssertionError as e:
494 err_msg: str = (
495 f"config mismatch in applied filters: {filters_new} != {filters_old}"
496 )
497 raise FilterInfoMismatchError(
498 err_msg,
499 ) from e
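# illustrative only: two filter-info lists that agree on names but disagree on a
# kwarg value fail the checks above and surface as a `FilterInfoMismatchError`.
def _demo_check_filter_equality() -> None:
    filters_a = [dict(name="some_filter", args=(), kwargs=dict(n=1))]
    filters_b = [dict(name="some_filter", args=(), kwargs=dict(n=2))]
    try:
        _check_filter_equality(filters_old=filters_a, filters_new=filters_b)
    except FilterInfoMismatchError:
        pass  # expected: kwargs values for key 'n' don't match
    else:
        raise AssertionError("expected a FilterInfoMismatchError")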
502def register_filter_namespace_for_dataset(
503 dataset_cls: Type[GPTDataset],
504) -> Callable[[Type], Type]:
505 """register the namespace class with the given dataset class"""
507 def decorator(filter_namespace_cls: Type) -> Type:
508 dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
509 filter_namespace_cls._BASE_DATASET = dataset_cls
511 return filter_namespace_cls
513 return decorator
516T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
517P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
518DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]
521def register_dataset_filter(
522 method: DatasetFilterFunc,
523) -> DatasetFilterFunc:
524 """register a dataset filter, copying the underlying dataset and updating the config
526 be sure to return a COPY, not the original
527 # TODO: this is not enforced; the wrapper below trusts the filter to return a copy rather than mutating the dataset in place
529 method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
530 """
532 @functools.wraps(method)
533 def wrapper(
534 # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type]
535 dataset: T_Dataset,
536 *args: P_FilterKwargs.args, # type: ignore[valid-type]
537 **kwargs: P_FilterKwargs.kwargs, # type: ignore[valid-type]
538 ) -> T_Dataset:
539 new_dataset = method(dataset, *args, **kwargs)
540 # update the config
541 new_dataset.cfg.applied_filters.append(
542 dict(name=method.__name__, args=args, kwargs=kwargs), # type: ignore[attr-defined]
543 )
544 new_dataset.update_self_config()
545 return new_dataset
547 # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value]
548 return wrapper # type: ignore[return-value]
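# a hedged end-to-end sketch tying the two registration helpers together, using the
# hypothetical `_ToyDataset` defined above; `keep_short_items` is made up. note the
# decorator order used here: `register_dataset_filter` wraps the plain function first,
# and `staticmethod` sits on the outside.
@register_filter_namespace_for_dataset(_ToyDataset)
class _ToyDatasetFilters:
    @staticmethod
    @register_dataset_filter
    def keep_short_items(dataset: _ToyDataset, max_len: int = 6) -> _ToyDataset:
        # per the note above, return a new dataset rather than mutating in place
        # (a fully careful filter would also copy the config, not just the item list)
        return _ToyDataset(
            cfg=dataset.cfg,
            items=[item for item in dataset.items if len(item) <= max_len],
        )

# usage, via the `filter_by` namespace (commented out to keep the module import-safe);
# the call is recorded in `cfg.applied_filters`, which is what
# `_apply_filters_from_config` replays when a dataset is regenerated:
# filtered = toy_data.filter_by.keep_short_items(max_len=6)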