maze_dataset.dataset.dataset
`GPTDatasetConfig` and `GPTDataset` are base classes for datasets.

They implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering.

> [!NOTE]
> These should probably be moved into a different package, so don't rely on them being here.
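For orientation, the sketch below shows how the `from_config` pipeline is typically driven from downstream code. It assumes the `MazeDataset` and `MazeDatasetConfig` subclasses provided elsewhere in this package; the specific constructor arguments are illustrative.

```python
# minimal sketch, assuming the MazeDataset / MazeDatasetConfig subclasses
# from this package; constructor arguments are illustrative
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=4,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)

# from_config tries, in order: local load, download, generation
dataset = MazeDataset.from_config(
    cfg,
    load_local=True,
    do_download=False,
    do_generate=True,
    save_local=True,
)
```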
1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets 2 3they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering 4 5> [!NOTE] 6> these should probably be moved into a different package, so don't rely on them being here 7""" 8 9import functools 10import json 11import random 12import typing 13import warnings 14from pathlib import Path 15from typing import Callable, Type, TypeVar 16 17import numpy as np 18from muutils.json_serialize import ( 19 JSONitem, 20 SerializableDataclass, 21 serializable_dataclass, 22 serializable_field, 23) 24from muutils.json_serialize.util import ( 25 JSONdict, 26) 27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 28from zanj import ZANJ 29 30from maze_dataset.generation.seed import GLOBAL_SEED 31 32 33def set_reproducibility(seed: int) -> None: 34 "set reproducibility in stdlib random and numpy (but not torch)" 35 random.seed(seed) 36 np.random.seed(seed) 37 38 39class FilterInfoMismatchError(ValueError): 40 """raised when the filter info in a dataset config does not match the filter info in the dataset""" 41 42 pass 43 44 45def _load_applied_filters( 46 filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]], 47) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]: 48 try: 49 return [ 50 dict( 51 name=filter_info["name"], 52 args=tuple( 53 filter_info["args"], 54 ), # muutils/zanj save tuples as lists, and this causes problems 55 kwargs=dict(filter_info["kwargs"]), # type: ignore[arg-type] 56 ) 57 for filter_info in filters 58 ] 59 except Exception as e: 60 err_msg: str = f"failed to load applied filters:\n{filters}" 61 raise ValueError(err_msg) from e 62 63 64@serializable_dataclass(kw_only=True) 65class GPTDatasetConfig(SerializableDataclass): 66 """base GPTDatasetConfig class""" 67 68 name: str 69 70 # TODO: get rid of all these things as part of migration to tokenizer-free dataset config 71 # -------------------------------------------------- 72 seq_len_min: int = serializable_field(default=1) 73 seq_len_max: int = serializable_field(default=512) 74 # -------------------------------------------------- 75 76 seed: int | None = serializable_field(default=GLOBAL_SEED) 77 applied_filters: list[ 78 dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict] 79 ] = serializable_field( 80 default_factory=list, 81 deserialize_fn=_load_applied_filters, 82 assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures 83 ) 84 85 def __post_init__(self) -> None: 86 "post init, where we set a random seed if none is set" 87 assert self.seq_len_min <= self.seq_len_max 88 # if seed set to None, then generate a new random seed 89 if self.seed is None: 90 self.seed = np.random.randint(2**31) 91 92 # TODO: something here is broken 93 if self.seed != GLOBAL_SEED: 94 warnings.warn( 95 f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }", 96 ) 97 98 set_reproducibility(self.seed) 99 100 def summary(self) -> dict: 101 """return a summary of the config""" 102 # do we run this to make sure it doesn't error? 
103 self_ser: dict = self.serialize() 104 assert self_ser 105 return dict( 106 name=self.name, 107 seq_len_min=self.seq_len_min, 108 seq_len_max=self.seq_len_max, 109 seed=self.seed, 110 applied_filters=self.applied_filters, 111 ) 112 113 @property 114 def _dataset_class(self) -> type: 115 raise NotImplementedError("this should be implemented by subclasses!") 116 117 def to_fname(self) -> str: 118 """convert config to a filename""" 119 self_json_str: str = json.dumps(self.serialize()) 120 self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10) 121 warnings.warn( 122 f"using fallblack to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!", 123 ) 124 return sanitize_fname( 125 # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type] 126 f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}", # type: ignore[arg-type] 127 ) 128 129 130def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig": 131 err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}" 132 raise NotImplementedError( 133 err_msg, 134 ) 135 136 137# abstract function, hence we dont care that `self` is unused 138def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem: # noqa: ANN001, ARG001 139 err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}" 140 raise NotImplementedError( 141 err_msg, 142 ) 143 144 145GPTDatasetConfig.load = _dataset_config_load # type: ignore[method-assign] 146GPTDatasetConfig.serialize = _dataset_config_serialize # type: ignore[method-assign,assignment] 147T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig) 148 149 150class GPTDataset(typing.Generic[T_DatasetConfig]): 151 """wrapper for torch dataset with some extra functionality 152 153 (meaning the functionality should be inherited in downstream classes) 154 155 > [!NOTE] 156 > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config 157 158 # Requires: 159 the following methods should be implemented in subclasses: 160 - `__init__(self, cfg: GPTDatasetConfig, **kwargs)` 161 initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably) 162 - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` 163 generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. how many threads to use for generation) 164 - `serialize(self) -> JSONitem` 165 serialize the dataset to a ZANJ-serializable object, including: 166 - config 167 - data in formats specified by `self.save_formats` 168 - `load(cls, data: JSONitem) -> GPTDataset` 169 load the dataset from a ZANJ-serializable object 170 - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` 171 given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. 
some kind of auth token or source url) 172 - `__len__(self) -> int` 173 return the length of the dataset, required to match interface of `torch.utils.data.Dataset` 174 - `__getitem__(self, i: int) -> list[str]` 175 return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset` 176 - `update_self_config(self) -> None` 177 update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation 178 - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters 179 180 # Parameters: 181 - `cfg : GPTDatasetConfig` 182 config for the dataset, used to generate the dataset 183 - `do_generate : bool` 184 whether to generate the dataset if it isn't found 185 (defaults to `True`) 186 - `load_local : bool` 187 whether to try finding the dataset locally 188 (defaults to `True`) 189 - `save_local : bool` 190 whether to save the dataset locally if it is generated or downloaded 191 (defaults to `True`) 192 - `do_download : bool` 193 whether to try downloading the dataset 194 (defaults to `True`) 195 - `local_base_path : Path` 196 where to save the dataset 197 (defaults to `Path("data/maze_dataset")`) 198 199 # Returns: 200 - `GPTDataset` 201 the dataset, as you wanted it 202 203 # Implements: 204 - `save(self, file_path: str) -> None` 205 save the dataset to a file, using ZANJ 206 - `read(cls, file_path: str) -> GPTDataset` 207 read the dataset from a file, using ZANJ 208 get all items in the dataset, in the specified format 209 - `filter_by(self)` 210 returns a namespace class 211 - `_filter_namespace(self) -> Class` 212 returns a namespace class for filtering the dataset, checking that method 213 - `_apply_filters_from_config(self) -> None` 214 apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating 215 216 """ 217 218 _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`" # type: ignore 219 220 cfg: "T_DatasetConfig" 221 222 @classmethod 223 def from_config( # noqa: C901, PLR0912 224 cls: "type[T_Dataset]", 225 cfg: "T_DatasetConfig", 226 do_generate: bool = True, 227 load_local: bool = True, 228 save_local: bool = True, 229 zanj: ZANJ | None = None, 230 do_download: bool = True, 231 local_base_path: Path = Path("data/maze_dataset"), 232 except_on_config_mismatch: bool = True, 233 allow_generation_metadata_filter_mismatch: bool = True, 234 verbose: bool = False, 235 **kwargs, 236 ) -> "T_Dataset": 237 """base class for gpt datasets 238 239 priority of loading: 240 1. load from local 241 2. download 242 3. generate 243 244 """ 245 print_log: Callable = print if verbose else lambda *_a, **_kw: None 246 247 local_base_path = Path(local_base_path) 248 fname: Path = Path(f"{cfg.to_fname()}.zanj") 249 output: T_Dataset | None = None 250 did_load_local: bool = False 251 if zanj is None: 252 zanj = ZANJ() 253 254 print_log(f"trying to get the dataset '{cfg.to_fname()}'") 255 256 if not (load_local or do_download or do_generate): 257 raise ValueError( 258 "no way to load dataset! 
you said not to load local, not to download, and not to generate", 259 ) 260 261 dataset_path: Path = local_base_path / fname 262 263 # try loading 264 if load_local: # noqa: SIM102 265 if dataset_path.exists(): 266 print_log(f"loading dataset from {dataset_path.as_posix()}") 267 try: 268 output = cls.read(dataset_path, zanj=zanj) 269 did_load_local = True 270 print_log("load successful!") 271 except Exception as e: # noqa: BLE001 272 print_log(f"failed to load dataset: {e}") 273 274 if do_download and output is None: 275 print_log("seeing if we can download the dataset...") 276 try: 277 output = cls.download(cfg, **kwargs) 278 print_log("download successful!") 279 except NotImplementedError: 280 print_log("no download found, or download failed") 281 282 if do_generate and output is None: 283 print_log("generating dataset...") 284 output = cls.generate(cfg, verbose=verbose, **kwargs) 285 # only if we generated it, apply filters 286 output = output._apply_filters_from_config() 287 288 # check and save 289 if output is None: 290 raise ValueError("failed to load dataset!") 291 292 cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True) 293 if cfg_diff: 294 if except_on_config_mismatch: 295 if allow_generation_metadata_filter_mismatch and ( 296 cfg_diff 297 == { 298 "applied_filters": { 299 "self": [], 300 "other": [ 301 { 302 "name": "collect_generation_meta", 303 "args": (), 304 "kwargs": {}, 305 }, 306 ], 307 }, 308 } 309 ): 310 pass 311 else: 312 err_msg: str = f"config mismatch: {cfg_diff = }" 313 raise ValueError(err_msg) 314 else: 315 warnings.warn(f"config mismatch: {cfg_diff = }") 316 317 if save_local and not did_load_local: 318 print_log(f"saving dataset to {dataset_path}") 319 output.save(dataset_path, zanj=zanj) 320 321 print_log( 322 f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }", 323 ) 324 return output 325 326 def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None: 327 "save dataset to a file with zanj" 328 if zanj is None: 329 zanj = ZANJ() 330 zanj.save(self.serialize(), file_path) 331 332 # serialization & loading 333 @classmethod 334 def read( 335 cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None 336 ) -> "T_Dataset": 337 "read dataset from a file with zanj" 338 if zanj is None: 339 zanj = ZANJ() 340 return zanj.read(file_path) 341 342 def serialize(self: "T_Dataset") -> JSONdict: 343 "(implement in subclass!) serialize to something we can save with zanj" 344 raise NotImplementedError 345 346 def data_hash(self: "T_Dataset") -> int: 347 "(implement in subclass!) return a hash of the data" 348 raise NotImplementedError 349 350 @classmethod 351 def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset": 352 "(implement in subclass!) load a dataset from what we made with `.serialize()`" 353 raise NotImplementedError 354 355 # generating & downloading 356 @classmethod 357 def generate( 358 cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs 359 ) -> "T_Dataset": 360 "(implement in subclass!) generative given the config" 361 raise NotImplementedError 362 363 @classmethod 364 def download( 365 cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs 366 ) -> "T_Dataset": 367 "(implement in subclass!) download the dataset given the config" 368 raise NotImplementedError 369 370 # filtering 371 def update_self_config(self) -> None: 372 """(implement in subclass!) 
update the config of the dataset to match the actual data, if needed 373 374 for example, adjust number of mazes after filtering 375 """ 376 pass 377 378 def __len__(self) -> int: 379 "return the length of the dataset" 380 raise NotImplementedError("implement in subclass!") 381 382 class FilterBy: 383 """thanks GPT-4""" 384 385 def __init__(self, dataset: "T_Dataset") -> None: 386 "mock class so we can call `my_dataset.filter_by.some_registered_filter()`" 387 self.dataset: T_Dataset = dataset 388 389 def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]: 390 "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`" 391 filter_func: DatasetFilterFunc = getattr( 392 self.dataset._FILTER_NAMESPACE, 393 name, 394 ) 395 396 def wrapped_filter_func(*args, **kwargs): # noqa: ANN202 397 return filter_func(self.dataset, *args, **kwargs) 398 399 return wrapped_filter_func 400 401 @property 402 def filter_by(self) -> "FilterBy": 403 "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset" 404 return self.FilterBy(self) 405 406 def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset": 407 """apply filters to the dataset, as specified in the config. used in `from_config()`""" 408 output: T_Dataset = self 409 # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters 410 applied_filters_old: list[ 411 dict[typing.Literal["name", "args", "kwargs"], typing.Any] 412 ] = self.cfg.applied_filters 413 output.cfg.applied_filters = list() 414 # apply the filters 415 for filter_info in applied_filters_old: 416 filter_name: str = filter_info["name"] 417 if filter_name not in output._FILTER_NAMESPACE.__dict__: 418 if filter_name.startswith("__custom__:"): 419 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!" 420 raise ValueError( 421 err_msg, 422 ) 423 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'" 424 raise ValueError( 425 err_msg, 426 ) 427 filter_args: list = filter_info.get("args", list()) 428 filter_kwargs: dict = filter_info.get("kwargs", dict()) 429 output = getattr(output.filter_by, filter_name)( 430 *filter_args, 431 **filter_kwargs, 432 ) 433 434 # update the config, perform checks 435 # TODO: some funny business with manually specified filters here? 
436 output.update_self_config() 437 _check_filter_equality( 438 filters_old=applied_filters_old, 439 filters_new=output.cfg.applied_filters, # type: ignore[arg-type] 440 ) 441 return output 442 443 444def _check_filter_equality( 445 filters_old: list[ 446 dict[typing.Literal["name", "args", "kwargs"], str | list | dict] 447 ], 448 filters_new: list[ 449 dict[typing.Literal["name", "args", "kwargs"], str | list | dict] 450 ], 451) -> None: 452 try: 453 assert len(filters_old) == len(filters_new) 454 455 for filterinfo_new, filterinfo_old in zip( 456 filters_old, 457 filters_new, 458 strict=False, 459 ): 460 # basic checks 461 assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict" 462 assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict" 463 assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), ( 464 "missing keys in filterinfo_new" 465 ) 466 assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), ( 467 "missing keys in filterinfo_old" 468 ) 469 470 # name 471 assert filterinfo_new["name"] == filterinfo_old["name"], ( 472 "filter names don't match" 473 ) 474 475 # args 476 assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), ( 477 "filter args of different lengths" 478 ) 479 for arg_new, arg_old in zip( 480 filterinfo_new["args"], 481 filterinfo_old["args"], 482 strict=False, 483 ): 484 assert arg_new == arg_old, "filter args don't match" 485 486 # kwargs 487 assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), ( 488 "filter kwargs of different lengths" 489 ) 490 for key in filterinfo_old["kwargs"]: 491 assert key in filterinfo_new["kwargs"], ( 492 f"filter kwargs don't match: missing key '{key}'" 493 ) 494 assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], ( # type: ignore[index] 495 f"filter kwargs don't match: values for key '{key}' don't match" 496 ) 497 498 except AssertionError as e: 499 err_msg: str = ( 500 f"config mismatch in applied filters: {filters_new} != {filters_old}" 501 ) 502 raise FilterInfoMismatchError( 503 err_msg, 504 ) from e 505 506 507def register_filter_namespace_for_dataset( 508 dataset_cls: Type[GPTDataset], 509) -> Callable[[Type], Type]: 510 """register the namespace class with the given dataset class""" 511 512 def decorator(filter_namespace_cls: Type) -> Type: 513 dataset_cls._FILTER_NAMESPACE = filter_namespace_cls 514 filter_namespace_cls._BASE_DATASET = dataset_cls 515 516 return filter_namespace_cls 517 518 return decorator 519 520 521T_Dataset = TypeVar("T_Dataset", bound=GPTDataset) 522P_FilterKwargs = typing.ParamSpec("P_FilterKwargs") 523DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset] 524 525 526def register_dataset_filter( 527 method: DatasetFilterFunc, 528) -> DatasetFilterFunc: 529 """register a dataset filter, copying the underlying dataset and updating the config 530 531 be sure to return a COPY, not the original? 532 # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right? 
533 534 method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset` 535 """ 536 537 @functools.wraps(method) 538 def wrapper( 539 # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type] 540 dataset: T_Dataset, 541 *args: P_FilterKwargs.args, # type: ignore[valid-type] 542 **kwargs: P_FilterKwargs.kwargs, # type: ignore[valid-type] 543 ) -> T_Dataset: 544 new_dataset = method(dataset, *args, **kwargs) 545 # update the config 546 new_dataset.cfg.applied_filters.append( 547 dict(name=method.__name__, args=args, kwargs=kwargs), # type: ignore[attr-defined] 548 ) 549 new_dataset.update_self_config() 550 return new_dataset 551 552 # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value] 553 return wrapper # type: ignore[return-value]
```python
def set_reproducibility(seed: int) -> None:
    "set reproducibility in stdlib random and numpy (but not torch)"
    random.seed(seed)
    np.random.seed(seed)
```
set reproducibility in stdlib random and numpy (but not torch)
```python
class FilterInfoMismatchError(ValueError):
    """raised when the filter info in a dataset config does not match the filter info in the dataset"""

    pass
```
raised when the filter info in a dataset config does not match the filter info in the dataset
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
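As a rough, hypothetical illustration (`SomeDataset` and `some_cfg` stand in for a concrete subclass and its config), this error surfaces when `from_config` regenerates a dataset and re-applying the filters recorded in the config does not reproduce the same filter info:

```python
# hypothetical names; the error is raised by the filter-equality check
# that runs after filters are re-applied during generation
try:
    ds = SomeDataset.from_config(some_cfg, load_local=False)
except FilterInfoMismatchError as e:
    print(f"filters recorded in the config could not be reproduced: {e}")
```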
```python
@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=GLOBAL_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self) -> None:
        "post init, where we set a random seed if none is set"
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = np.random.randint(2**31)

        # TODO: something here is broken
        if self.seed != GLOBAL_SEED:
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @property
    def _dataset_class(self) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )
        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )
```
base GPTDatasetConfig class
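A minimal sketch of what a downstream config subclass might look like; the `MyMazeConfig` name and the `grid_n` / `n_mazes` fields are illustrative, not part of this module:

```python
# illustrative config subclass; grid_n / n_mazes are example fields
from muutils.json_serialize import serializable_dataclass, serializable_field

from maze_dataset.dataset.dataset import GPTDatasetConfig


@serializable_dataclass(kw_only=True)
class MyMazeConfig(GPTDatasetConfig):
    grid_n: int = serializable_field(default=5)
    n_mazes: int = serializable_field(default=10)

    def to_fname(self) -> str:
        # override the fallback to_fname() so filenames are stable and readable
        return f"{self.name}-g{self.grid_n}-n{self.n_mazes}"
```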
```python
    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )
```
return a summary of the config
```python
    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )
        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )
```
convert config to a filename
```python
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
    err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )
```
returns the class as a dict, implemented by the `@serializable_dataclass` decorator
```python
def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )
```
takes in an appropriately structured dict and returns an instance of the class, implemented by the `@serializable_dataclass` decorator
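Since the concrete implementations come from the `@serializable_dataclass` decorator, a subclass config (here the hypothetical `MyMazeConfig` sketched above) round-trips through plain JSON-compatible dicts:

```python
# round-trip sketch using the hypothetical MyMazeConfig from above
cfg = MyMazeConfig(name="demo", grid_n=5, n_mazes=10)
cfg_dict = cfg.serialize()          # plain JSON-compatible dict
cfg2 = MyMazeConfig.load(cfg_dict)  # reconstruct the config from that dict
assert not cfg.diff(cfg2)           # no differences between the two configs
```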
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
151class GPTDataset(typing.Generic[T_DatasetConfig]): 152 """wrapper for torch dataset with some extra functionality 153 154 (meaning the functionality should be inherited in downstream classes) 155 156 > [!NOTE] 157 > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config 158 159 # Requires: 160 the following methods should be implemented in subclasses: 161 - `__init__(self, cfg: GPTDatasetConfig, **kwargs)` 162 initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably) 163 - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` 164 generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. how many threads to use for generation) 165 - `serialize(self) -> JSONitem` 166 serialize the dataset to a ZANJ-serializable object, including: 167 - config 168 - data in formats specified by `self.save_formats` 169 - `load(cls, data: JSONitem) -> GPTDataset` 170 load the dataset from a ZANJ-serializable object 171 - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` 172 given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. some kind of auth token or source url) 173 - `__len__(self) -> int` 174 return the length of the dataset, required to match interface of `torch.utils.data.Dataset` 175 - `__getitem__(self, i: int) -> list[str]` 176 return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset` 177 - `update_self_config(self) -> None` 178 update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation 179 - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters 180 181 # Parameters: 182 - `cfg : GPTDatasetConfig` 183 config for the dataset, used to generate the dataset 184 - `do_generate : bool` 185 whether to generate the dataset if it isn't found 186 (defaults to `True`) 187 - `load_local : bool` 188 whether to try finding the dataset locally 189 (defaults to `True`) 190 - `save_local : bool` 191 whether to save the dataset locally if it is generated or downloaded 192 (defaults to `True`) 193 - `do_download : bool` 194 whether to try downloading the dataset 195 (defaults to `True`) 196 - `local_base_path : Path` 197 where to save the dataset 198 (defaults to `Path("data/maze_dataset")`) 199 200 # Returns: 201 - `GPTDataset` 202 the dataset, as you wanted it 203 204 # Implements: 205 - `save(self, file_path: str) -> None` 206 save the dataset to a file, using ZANJ 207 - `read(cls, file_path: str) -> GPTDataset` 208 read the dataset from a file, using ZANJ 209 get all items in the dataset, in the specified format 210 - `filter_by(self)` 211 returns a namespace class 212 - `_filter_namespace(self) -> Class` 213 returns a namespace class for filtering the dataset, checking that method 214 - `_apply_filters_from_config(self) -> None` 215 apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating 216 217 """ 218 219 _FILTER_NAMESPACE: type = "this isn't a filter namespace! 
you have to initialize this by registering with `register_filter_namespace_for_dataset`" # type: ignore 220 221 cfg: "T_DatasetConfig" 222 223 @classmethod 224 def from_config( # noqa: C901, PLR0912 225 cls: "type[T_Dataset]", 226 cfg: "T_DatasetConfig", 227 do_generate: bool = True, 228 load_local: bool = True, 229 save_local: bool = True, 230 zanj: ZANJ | None = None, 231 do_download: bool = True, 232 local_base_path: Path = Path("data/maze_dataset"), 233 except_on_config_mismatch: bool = True, 234 allow_generation_metadata_filter_mismatch: bool = True, 235 verbose: bool = False, 236 **kwargs, 237 ) -> "T_Dataset": 238 """base class for gpt datasets 239 240 priority of loading: 241 1. load from local 242 2. download 243 3. generate 244 245 """ 246 print_log: Callable = print if verbose else lambda *_a, **_kw: None 247 248 local_base_path = Path(local_base_path) 249 fname: Path = Path(f"{cfg.to_fname()}.zanj") 250 output: T_Dataset | None = None 251 did_load_local: bool = False 252 if zanj is None: 253 zanj = ZANJ() 254 255 print_log(f"trying to get the dataset '{cfg.to_fname()}'") 256 257 if not (load_local or do_download or do_generate): 258 raise ValueError( 259 "no way to load dataset! you said not to load local, not to download, and not to generate", 260 ) 261 262 dataset_path: Path = local_base_path / fname 263 264 # try loading 265 if load_local: # noqa: SIM102 266 if dataset_path.exists(): 267 print_log(f"loading dataset from {dataset_path.as_posix()}") 268 try: 269 output = cls.read(dataset_path, zanj=zanj) 270 did_load_local = True 271 print_log("load successful!") 272 except Exception as e: # noqa: BLE001 273 print_log(f"failed to load dataset: {e}") 274 275 if do_download and output is None: 276 print_log("seeing if we can download the dataset...") 277 try: 278 output = cls.download(cfg, **kwargs) 279 print_log("download successful!") 280 except NotImplementedError: 281 print_log("no download found, or download failed") 282 283 if do_generate and output is None: 284 print_log("generating dataset...") 285 output = cls.generate(cfg, verbose=verbose, **kwargs) 286 # only if we generated it, apply filters 287 output = output._apply_filters_from_config() 288 289 # check and save 290 if output is None: 291 raise ValueError("failed to load dataset!") 292 293 cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True) 294 if cfg_diff: 295 if except_on_config_mismatch: 296 if allow_generation_metadata_filter_mismatch and ( 297 cfg_diff 298 == { 299 "applied_filters": { 300 "self": [], 301 "other": [ 302 { 303 "name": "collect_generation_meta", 304 "args": (), 305 "kwargs": {}, 306 }, 307 ], 308 }, 309 } 310 ): 311 pass 312 else: 313 err_msg: str = f"config mismatch: {cfg_diff = }" 314 raise ValueError(err_msg) 315 else: 316 warnings.warn(f"config mismatch: {cfg_diff = }") 317 318 if save_local and not did_load_local: 319 print_log(f"saving dataset to {dataset_path}") 320 output.save(dataset_path, zanj=zanj) 321 322 print_log( 323 f"Got dataset {output.cfg.name} with {len(output)} items. 
{output.cfg.to_fname() = }", 324 ) 325 return output 326 327 def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None: 328 "save dataset to a file with zanj" 329 if zanj is None: 330 zanj = ZANJ() 331 zanj.save(self.serialize(), file_path) 332 333 # serialization & loading 334 @classmethod 335 def read( 336 cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None 337 ) -> "T_Dataset": 338 "read dataset from a file with zanj" 339 if zanj is None: 340 zanj = ZANJ() 341 return zanj.read(file_path) 342 343 def serialize(self: "T_Dataset") -> JSONdict: 344 "(implement in subclass!) serialize to something we can save with zanj" 345 raise NotImplementedError 346 347 def data_hash(self: "T_Dataset") -> int: 348 "(implement in subclass!) return a hash of the data" 349 raise NotImplementedError 350 351 @classmethod 352 def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset": 353 "(implement in subclass!) load a dataset from what we made with `.serialize()`" 354 raise NotImplementedError 355 356 # generating & downloading 357 @classmethod 358 def generate( 359 cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs 360 ) -> "T_Dataset": 361 "(implement in subclass!) generative given the config" 362 raise NotImplementedError 363 364 @classmethod 365 def download( 366 cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs 367 ) -> "T_Dataset": 368 "(implement in subclass!) download the dataset given the config" 369 raise NotImplementedError 370 371 # filtering 372 def update_self_config(self) -> None: 373 """(implement in subclass!) update the config of the dataset to match the actual data, if needed 374 375 for example, adjust number of mazes after filtering 376 """ 377 pass 378 379 def __len__(self) -> int: 380 "return the length of the dataset" 381 raise NotImplementedError("implement in subclass!") 382 383 class FilterBy: 384 """thanks GPT-4""" 385 386 def __init__(self, dataset: "T_Dataset") -> None: 387 "mock class so we can call `my_dataset.filter_by.some_registered_filter()`" 388 self.dataset: T_Dataset = dataset 389 390 def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]: 391 "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`" 392 filter_func: DatasetFilterFunc = getattr( 393 self.dataset._FILTER_NAMESPACE, 394 name, 395 ) 396 397 def wrapped_filter_func(*args, **kwargs): # noqa: ANN202 398 return filter_func(self.dataset, *args, **kwargs) 399 400 return wrapped_filter_func 401 402 @property 403 def filter_by(self) -> "FilterBy": 404 "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset" 405 return self.FilterBy(self) 406 407 def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset": 408 """apply filters to the dataset, as specified in the config. used in `from_config()`""" 409 output: T_Dataset = self 410 # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters 411 applied_filters_old: list[ 412 dict[typing.Literal["name", "args", "kwargs"], typing.Any] 413 ] = self.cfg.applied_filters 414 output.cfg.applied_filters = list() 415 # apply the filters 416 for filter_info in applied_filters_old: 417 filter_name: str = filter_info["name"] 418 if filter_name not in output._FILTER_NAMESPACE.__dict__: 419 if filter_name.startswith("__custom__:"): 420 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. 
add it to MazeDatasetFilters!" 421 raise ValueError( 422 err_msg, 423 ) 424 err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'" 425 raise ValueError( 426 err_msg, 427 ) 428 filter_args: list = filter_info.get("args", list()) 429 filter_kwargs: dict = filter_info.get("kwargs", dict()) 430 output = getattr(output.filter_by, filter_name)( 431 *filter_args, 432 **filter_kwargs, 433 ) 434 435 # update the config, perform checks 436 # TODO: some funny business with manually specified filters here? 437 output.update_self_config() 438 _check_filter_equality( 439 filters_old=applied_filters_old, 440 filters_new=output.cfg.applied_filters, # type: ignore[arg-type] 441 ) 442 return output
wrapper for a torch dataset with some extra functionality (meaning the functionality should be inherited in downstream classes)

> [!NOTE]
> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
Requires:
the following methods should be implemented in subclasses (see the sketch after this section):

- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
  initialize the dataset from a given config. kwargs are not passed through; they should carry the actual generated or loaded data (a list of objects or sequences, probably)
- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
  generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
- `serialize(self) -> JSONitem`
  serialize the dataset to a ZANJ-serializable object, including:
  - config
  - data in formats specified by `self.save_formats`
- `load(cls, data: JSONitem) -> GPTDataset`
  load the dataset from a ZANJ-serializable object
- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
  given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
- `__len__(self) -> int`
  return the length of the dataset, required to match the interface of `torch.utils.data.Dataset`
- `__getitem__(self, i: int) -> list[str]`
  return the i-th item in the dataset, required to match the interface of `torch.utils.data.Dataset`
- `update_self_config(self) -> None`
  update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
- decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

Parameters:
- `cfg : GPTDatasetConfig`
  config for the dataset, used to generate the dataset
- `do_generate : bool`
  whether to generate the dataset if it isn't found
  (defaults to `True`)
- `load_local : bool`
  whether to try finding the dataset locally
  (defaults to `True`)
- `save_local : bool`
  whether to save the dataset locally if it is generated or downloaded
  (defaults to `True`)
- `do_download : bool`
  whether to try downloading the dataset
  (defaults to `True`)
- `local_base_path : Path`
  where to save the dataset
  (defaults to `Path("data/maze_dataset")`)

Returns:
- `GPTDataset`
  the dataset, as you wanted it

Implements:
- `save(self, file_path: str) -> None`
  save the dataset to a file, using ZANJ
- `read(cls, file_path: str) -> GPTDataset`
  read the dataset from a file, using ZANJ; get all items in the dataset, in the specified format
- `filter_by(self)`
  returns a namespace class
- `_filter_namespace(self) -> Class`
  returns a namespace class for filtering the dataset, checking that method
- `_apply_filters_from_config(self) -> None`
  apply filters to the dataset, as specified in the config. used in `from_config()`, but only when generating
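As referenced in the Requires list above, here is a compressed, illustrative sketch of a toy subclass implementing the required methods; `MyDataset` is hypothetical, `MyMazeConfig` is the config sketched earlier, and a real subclass such as `MazeDataset` does considerably more:

```python
# illustrative toy subclass; MyMazeConfig is the hypothetical config sketched earlier
from maze_dataset.dataset.dataset import GPTDataset


class MyDataset(GPTDataset[MyMazeConfig]):
    def __init__(self, cfg: MyMazeConfig, items: list[list[str]]) -> None:
        # __init__ receives the actual data, not pass-through options
        self.cfg = cfg
        self.items: list[list[str]] = items

    @classmethod
    def generate(cls, cfg: MyMazeConfig, **kwargs) -> "MyDataset":
        # generation-only options (e.g. thread count) would arrive via kwargs
        return cls(cfg, [["<item>", str(i)] for i in range(cfg.n_mazes)])

    @classmethod
    def download(cls, cfg: MyMazeConfig, **kwargs) -> "MyDataset":
        raise NotImplementedError("no remote source for this toy dataset")

    def serialize(self) -> dict:
        return {"cfg": self.cfg.serialize(), "items": self.items}

    @classmethod
    def load(cls, data: dict) -> "MyDataset":
        return cls(MyMazeConfig.load(data["cfg"]), list(data["items"]))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> list[str]:
        return self.items[i]

    def update_self_config(self) -> None:
        # keep the config in sync with the data, e.g. after filtering
        self.cfg.n_mazes = len(self.items)
```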
223 @classmethod 224 def from_config( # noqa: C901, PLR0912 225 cls: "type[T_Dataset]", 226 cfg: "T_DatasetConfig", 227 do_generate: bool = True, 228 load_local: bool = True, 229 save_local: bool = True, 230 zanj: ZANJ | None = None, 231 do_download: bool = True, 232 local_base_path: Path = Path("data/maze_dataset"), 233 except_on_config_mismatch: bool = True, 234 allow_generation_metadata_filter_mismatch: bool = True, 235 verbose: bool = False, 236 **kwargs, 237 ) -> "T_Dataset": 238 """base class for gpt datasets 239 240 priority of loading: 241 1. load from local 242 2. download 243 3. generate 244 245 """ 246 print_log: Callable = print if verbose else lambda *_a, **_kw: None 247 248 local_base_path = Path(local_base_path) 249 fname: Path = Path(f"{cfg.to_fname()}.zanj") 250 output: T_Dataset | None = None 251 did_load_local: bool = False 252 if zanj is None: 253 zanj = ZANJ() 254 255 print_log(f"trying to get the dataset '{cfg.to_fname()}'") 256 257 if not (load_local or do_download or do_generate): 258 raise ValueError( 259 "no way to load dataset! you said not to load local, not to download, and not to generate", 260 ) 261 262 dataset_path: Path = local_base_path / fname 263 264 # try loading 265 if load_local: # noqa: SIM102 266 if dataset_path.exists(): 267 print_log(f"loading dataset from {dataset_path.as_posix()}") 268 try: 269 output = cls.read(dataset_path, zanj=zanj) 270 did_load_local = True 271 print_log("load successful!") 272 except Exception as e: # noqa: BLE001 273 print_log(f"failed to load dataset: {e}") 274 275 if do_download and output is None: 276 print_log("seeing if we can download the dataset...") 277 try: 278 output = cls.download(cfg, **kwargs) 279 print_log("download successful!") 280 except NotImplementedError: 281 print_log("no download found, or download failed") 282 283 if do_generate and output is None: 284 print_log("generating dataset...") 285 output = cls.generate(cfg, verbose=verbose, **kwargs) 286 # only if we generated it, apply filters 287 output = output._apply_filters_from_config() 288 289 # check and save 290 if output is None: 291 raise ValueError("failed to load dataset!") 292 293 cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True) 294 if cfg_diff: 295 if except_on_config_mismatch: 296 if allow_generation_metadata_filter_mismatch and ( 297 cfg_diff 298 == { 299 "applied_filters": { 300 "self": [], 301 "other": [ 302 { 303 "name": "collect_generation_meta", 304 "args": (), 305 "kwargs": {}, 306 }, 307 ], 308 }, 309 } 310 ): 311 pass 312 else: 313 err_msg: str = f"config mismatch: {cfg_diff = }" 314 raise ValueError(err_msg) 315 else: 316 warnings.warn(f"config mismatch: {cfg_diff = }") 317 318 if save_local and not did_load_local: 319 print_log(f"saving dataset to {dataset_path}") 320 output.save(dataset_path, zanj=zanj) 321 322 print_log( 323 f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }", 324 ) 325 return output
get the dataset described by the given config

priority of loading:
1. load from local
2. download
3. generate
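For example, forcing regeneration while still caching the result locally might look like the following, reusing the hypothetical `MyDataset` and `MyMazeConfig` sketches from above:

```python
# hypothetical subclass and config from the earlier sketches
dataset = MyDataset.from_config(
    MyMazeConfig(name="demo", n_mazes=100),
    load_local=False,   # skip the local cache
    do_download=False,  # don't try a remote source
    do_generate=True,   # fall through to generate()
    save_local=True,    # cache the result under data/maze_dataset/
    verbose=True,
)
```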
```python
    def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
        "save dataset to a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)
```
save dataset to a file with zanj
```python
    @classmethod
    def read(
        cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
    ) -> "T_Dataset":
        "read dataset from a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)
```
read dataset from a file with zanj
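Both methods go through ZANJ; a short round-trip sketch, assuming the `MazeDataset` subclass from this package and an illustrative path:

```python
from zanj import ZANJ

from maze_dataset import MazeDataset

# `dataset` is assumed to be a MazeDataset (or other GPTDataset subclass);
# the path is illustrative
dataset.save("data/maze_dataset/demo.zanj")
dataset_again = MazeDataset.read("data/maze_dataset/demo.zanj", zanj=ZANJ())
assert len(dataset_again) == len(dataset)
```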
```python
    def serialize(self: "T_Dataset") -> JSONdict:
        "(implement in subclass!) serialize to something we can save with zanj"
        raise NotImplementedError
```
(implement in subclass!) serialize to something we can save with zanj
```python
    def data_hash(self: "T_Dataset") -> int:
        "(implement in subclass!) return a hash of the data"
        raise NotImplementedError
```
(implement in subclass!) return a hash of the data
```python
    @classmethod
    def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
        "(implement in subclass!) load a dataset from what we made with `.serialize()`"
        raise NotImplementedError
```
(implement in subclass!) load a dataset from what we made with .serialize()
```python
    @classmethod
    def generate(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) generate the dataset given the config"
        raise NotImplementedError
```
(implement in subclass!) generate the dataset given the config
```python
    @classmethod
    def download(
        cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
    ) -> "T_Dataset":
        "(implement in subclass!) download the dataset given the config"
        raise NotImplementedError
```
(implement in subclass!) download the dataset given the config
```python
    def update_self_config(self) -> None:
        """(implement in subclass!) update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass
```
(implement in subclass!) update the config of the dataset to match the actual data, if needed
for example, adjust number of mazes after filtering
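For instance, a subclass whose config tracks how many items it holds might keep that count in sync; `n_mazes` here is an illustrative config field, not something defined on the base class:

```python
    # inside a GPTDataset subclass; n_mazes is an illustrative config field
    def update_self_config(self) -> None:
        self.cfg.n_mazes = len(self)
```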
```python
    @property
    def filter_by(self) -> "FilterBy":
        "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
        return self.FilterBy(self)
```
can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset
```python
    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "T_Dataset") -> None:
            "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
            self.dataset: T_Dataset = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
            "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
            filter_func: DatasetFilterFunc = getattr(
                self.dataset._FILTER_NAMESPACE,
                name,
            )

            def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func
```
thanks GPT-4
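Filters registered for the dataset class become available as attributes of this namespace. For example, with a `MazeDataset` and the `path_length` filter registered by this package (treat the filter name and arguments as an assumption):

```python
# assumes `dataset` is a MazeDataset and that a `path_length` filter is registered
filtered = dataset.filter_by.path_length(min_length=3)
print(len(dataset), "->", len(filtered))
print(filtered.cfg.applied_filters)  # records the filter's name, args, and kwargs
```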
```python
def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator
```
register the namespace class with the given dataset class
```python
def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?
    # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(
        # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type]
        dataset: T_Dataset,
        *args: P_FilterKwargs.args,  # type: ignore[valid-type]
        **kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
    ) -> T_Dataset:
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
        )
        new_dataset.update_self_config()
        return new_dataset

    # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value]
    return wrapper  # type: ignore[return-value]
```
register a dataset filter, copying the underlying dataset and updating the config
be sure to return a COPY, not the original?
TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
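Putting the two decorators together, a filter namespace for the hypothetical `MyDataset` sketched earlier might look like this; the `first_n` filter is illustrative:

```python
import copy

from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)


# hypothetical filter namespace for the MyDataset sketch above
@register_filter_namespace_for_dataset(MyDataset)
class MyDatasetFilters:
    @register_dataset_filter
    @staticmethod
    def first_n(dataset: MyDataset, n: int) -> MyDataset:
        # return a copy rather than mutating the original
        # (decorator-over-staticmethod order requires Python 3.10+)
        return MyDataset(cfg=copy.deepcopy(dataset.cfg), items=dataset.items[:n])


my_dataset = MyDataset.generate(MyMazeConfig(name="demo", n_mazes=100))
small = my_dataset.filter_by.first_n(10)
assert small.cfg.applied_filters[-1]["name"] == "first_n"
```

Because the wrapper in `register_dataset_filter` records the call in `applied_filters` and then calls `update_self_config`, the filter itself only needs to build and return the filtered copy.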