maze_dataset.dataset.dataset
`GPTDatasetConfig` and `GPTDataset` are base classes for datasets.

They implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering.

> [!NOTE]
> these should probably be moved into a different package, so don't rely on them being here
"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets

they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering

> [!NOTE]
> these should probably be moved into a different package, so don't rely on them being here
"""

import functools
import json
import random
import typing
import warnings
from pathlib import Path
from typing import Callable, Type, TypeVar

import numpy as np
from muutils.json_serialize import (
    JSONitem,
    SerializableDataclass,
    serializable_dataclass,
    serializable_field,
)
from muutils.json_serialize.util import (
    JSONdict,
)
from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
from typing_extensions import Self
from zanj import ZANJ

from maze_dataset.generation.seed import GLOBAL_SEED


def set_reproducibility(seed: int) -> None:
    "set reproducibility in stdlib random and numpy (but not torch)"
    random.seed(seed)
    np.random.seed(seed)


class FilterInfoMismatchError(ValueError):
    """raised when the filter info in a dataset config does not match the filter info in the dataset"""

    pass


def _load_applied_filters(
    filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]],
) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]:
    try:
        return [
            dict(
                name=filter_info["name"],
                args=tuple(
                    filter_info["args"],
                ),  # muutils/zanj save tuples as lists, and this causes problems
                kwargs=dict(filter_info["kwargs"]),  # type: ignore[arg-type]
            )
            for filter_info in filters
        ]
    except Exception as e:
        err_msg: str = f"failed to load applied filters:\n{filters}"
        raise ValueError(err_msg) from e


@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
    """base GPTDatasetConfig class"""

    name: str

    # TODO: get rid of all these things as part of migration to tokenizer-free dataset config
    # --------------------------------------------------
    seq_len_min: int = serializable_field(default=1)
    seq_len_max: int = serializable_field(default=512)
    # --------------------------------------------------

    seed: int | None = serializable_field(default=GLOBAL_SEED)
    applied_filters: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
    ] = serializable_field(
        default_factory=list,
        deserialize_fn=_load_applied_filters,
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    def __post_init__(self) -> None:
        "post init, where we set a random seed if none is set"
        assert self.seq_len_min <= self.seq_len_max
        # if seed set to None, then generate a new random seed
        if self.seed is None:
            self.seed = np.random.randint(2**31)

        # TODO: something here is broken
        if self.seed != GLOBAL_SEED:
            warnings.warn(
                f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
            )

        set_reproducibility(self.seed)

    def summary(self) -> dict:
        """return a summary of the config"""
        # do we run this to make sure it doesn't error?
        self_ser: dict = self.serialize()
        assert self_ser
        return dict(
            name=self.name,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            seed=self.seed,
            applied_filters=self.applied_filters,
        )

    @property
    def _dataset_class(self) -> type:
        raise NotImplementedError("this should be implemented by subclasses!")

    def to_fname(self) -> str:
        """convert config to a filename"""
        self_json_str: str = json.dumps(self.serialize())
        self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
        warnings.warn(
            f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
        )
        return sanitize_fname(
            # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type]
            f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
        )


def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
    err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


# abstract function, hence we dont care that `self` is unused
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
    err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
    raise NotImplementedError(
        err_msg,
    )


GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)


class GPTDataset(typing.Generic[T_DatasetConfig]):
    """wrapper for torch dataset with some extra functionality

    (meaning the functionality should be inherited in downstream classes)

    > [!NOTE]
    > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

    # Requires:
    the following methods should be implemented in subclasses:
    - `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
        initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
    - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
    - `serialize(self) -> JSONitem`
        serialize the dataset to a ZANJ-serializable object, including:
        - config
        - data in formats specified by `self.save_formats`
    - `load(cls, data: JSONitem) -> GPTDataset`
        load the dataset from a ZANJ-serializable object
    - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
        given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
    - `__len__(self) -> int`
        return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
    - `__getitem__(self, i: int) -> list[str]`
        return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
    - `update_self_config(self) -> None`
        update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
    - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

    # Parameters:
    - `cfg : GPTDatasetConfig`
        config for the dataset, used to generate the dataset
    - `do_generate : bool`
        whether to generate the dataset if it isn't found
        (defaults to `True`)
    - `load_local : bool`
        whether to try finding the dataset locally
        (defaults to `True`)
    - `save_local : bool`
        whether to save the dataset locally if it is generated or downloaded
        (defaults to `True`)
    - `do_download : bool`
        whether to try downloading the dataset
        (defaults to `True`)
    - `local_base_path : Path`
        where to save the dataset
        (defaults to `Path("data/maze_dataset")`)

    # Returns:
    - `GPTDataset`
        the dataset, as you wanted it

    # Implements:
    - `save(self, file_path: str) -> None`
        save the dataset to a file, using ZANJ
    - `read(cls, file_path: str) -> GPTDataset`
        read the dataset from a file, using ZANJ
        get all items in the dataset, in the specified format
    - `filter_by(self)`
        returns a namespace class
    - `_filter_namespace(self) -> Class`
        returns a namespace class for filtering the dataset, checking that method
    - `_apply_filters_from_config(self) -> None`
        apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating

    """

    _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore

    cfg: "T_DatasetConfig"

    @classmethod
    def from_config(  # noqa: C901, PLR0912
        cls,
        cfg: "T_DatasetConfig",
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "Self":
        """load, download, or generate a dataset, given a config

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        print_log: Callable = print if verbose else lambda *_a, **_kw: None

        local_base_path = Path(local_base_path)
        fname: Path = Path(f"{cfg.to_fname()}.zanj")
        output: Self | None = None
        did_load_local: bool = False
        if zanj is None:
            zanj = ZANJ()

        print_log(f"trying to get the dataset '{cfg.to_fname()}'")

        if not (load_local or do_download or do_generate):
            raise ValueError(
                "no way to load dataset! you said not to load local, not to download, and not to generate",
            )

        dataset_path: Path = local_base_path / fname

        # try loading
        if load_local:  # noqa: SIM102
            if dataset_path.exists():
                print_log(f"loading dataset from {dataset_path.as_posix()}")
                try:
                    output = cls.read(dataset_path, zanj=zanj)
                    did_load_local = True
                    print_log("load successful!")
                except Exception as e:  # noqa: BLE001
                    print_log(f"failed to load dataset: {e}")

        if do_download and output is None:
            print_log("seeing if we can download the dataset...")
            try:
                output = cls.download(cfg, **kwargs)
                print_log("download successful!")
            except NotImplementedError:
                print_log("no download found, or download failed")

        if do_generate and output is None:
            print_log("generating dataset...")
            output = cls.generate(cfg, verbose=verbose, **kwargs)
            # only if we generated it, apply filters
            output = output._apply_filters_from_config()

        # check and save
        if output is None:
            raise ValueError("failed to load dataset!")

        cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
        if cfg_diff:
            if except_on_config_mismatch:
                if allow_generation_metadata_filter_mismatch and (
                    cfg_diff
                    == {
                        "applied_filters": {
                            "self": [],
                            "other": [
                                {
                                    "name": "collect_generation_meta",
                                    "args": (),
                                    "kwargs": {},
                                },
                            ],
                        },
                    }
                ):
                    pass
                else:
                    err_msg: str = f"config mismatch: {cfg_diff = }"
                    raise ValueError(err_msg)
            else:
                warnings.warn(f"config mismatch: {cfg_diff = }")

        if save_local and not did_load_local:
            print_log(f"saving dataset to {dataset_path}")
            output.save(dataset_path, zanj=zanj)

        print_log(
            f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
        )
        return output

    def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
        "save dataset to a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        zanj.save(self.serialize(), file_path)

    # serialization & loading
    @classmethod
    def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
        "read dataset from a file with zanj"
        if zanj is None:
            zanj = ZANJ()
        return zanj.read(file_path)

    def serialize(self) -> JSONdict:
        "(implement in subclass!) serialize to something we can save with zanj"
        raise NotImplementedError

    def data_hash(self) -> int:
        "(implement in subclass!) return a hash of the data"
        raise NotImplementedError

    @classmethod
    def load(cls, data: JSONdict) -> "Self":
        "(implement in subclass!) load a dataset from what we made with `.serialize()`"
        raise NotImplementedError

    # generating & downloading
    @classmethod
    def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
        "(implement in subclass!) generate the dataset given the config"
        raise NotImplementedError

    @classmethod
    def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
        "(implement in subclass!) download the dataset given the config"
        raise NotImplementedError

    # filtering
    def update_self_config(self) -> None:
        """(implement in subclass!) update the config of the dataset to match the actual data, if needed

        for example, adjust number of mazes after filtering
        """
        pass

    def __len__(self) -> int:
        "return the length of the dataset"
        raise NotImplementedError("implement in subclass!")

    class FilterBy:
        """thanks GPT-4"""

        def __init__(self, dataset: "T_Dataset") -> None:
            "mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
            self.dataset: T_Dataset = dataset

        def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
            "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
            filter_func: DatasetFilterFunc = getattr(
                self.dataset._FILTER_NAMESPACE,
                name,
            )

            def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
                return filter_func(self.dataset, *args, **kwargs)

            return wrapped_filter_func

    @property
    def filter_by(self) -> "FilterBy":
        "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
        return self.FilterBy(self)

    def _apply_filters_from_config(self) -> "Self":
        """apply filters to the dataset, as specified in the config. used in `from_config()`"""
        output: Self = self
        # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
        applied_filters_old: list[
            dict[typing.Literal["name", "args", "kwargs"], typing.Any]
        ] = self.cfg.applied_filters
        output.cfg.applied_filters = list()
        # apply the filters
        for filter_info in applied_filters_old:
            filter_name: str = filter_info["name"]
            if filter_name not in output._FILTER_NAMESPACE.__dict__:
                if filter_name.startswith("__custom__:"):
                    err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
                    raise ValueError(
                        err_msg,
                    )
                err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
                raise ValueError(
                    err_msg,
                )
            filter_args: list = filter_info.get("args", list())
            filter_kwargs: dict = filter_info.get("kwargs", dict())
            output = getattr(output.filter_by, filter_name)(
                *filter_args,
                **filter_kwargs,
            )

        # update the config, perform checks
        # TODO: some funny business with manually specified filters here?
        output.update_self_config()
        _check_filter_equality(
            filters_old=applied_filters_old,
            filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
        )
        return output


def _check_filter_equality(
    filters_old: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
    filters_new: list[
        dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
    ],
) -> None:
    try:
        assert len(filters_old) == len(filters_new)

        for filterinfo_new, filterinfo_old in zip(
            filters_old,
            filters_new,
            strict=False,
        ):
            # basic checks
            assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
            assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
            assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_new"
            )
            assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
                "missing keys in filterinfo_old"
            )

            # name
            assert filterinfo_new["name"] == filterinfo_old["name"], (
                "filter names don't match"
            )

            # args
            assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
                "filter args of different lengths"
            )
            for arg_new, arg_old in zip(
                filterinfo_new["args"],
                filterinfo_old["args"],
                strict=False,
            ):
                assert arg_new == arg_old, "filter args don't match"

            # kwargs
            assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
                "filter kwargs of different lengths"
            )
            for key in filterinfo_old["kwargs"]:
                assert key in filterinfo_new["kwargs"], (
                    f"filter kwargs don't match: missing key '{key}'"
                )
                assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
                    f"filter kwargs don't match: values for key '{key}' don't match"
                )

    except AssertionError as e:
        err_msg: str = (
            f"config mismatch in applied filters: {filters_new} != {filters_old}"
        )
        raise FilterInfoMismatchError(
            err_msg,
        ) from e


def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
    """register the namespace class with the given dataset class"""

    def decorator(filter_namespace_cls: Type) -> Type:
        dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
        filter_namespace_cls._BASE_DATASET = dataset_cls

        return filter_namespace_cls

    return decorator


T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]


def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
    """register a dataset filter, copying the underlying dataset and updating the config

    be sure to return a COPY, not the original?
    # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
    """

    @functools.wraps(method)
    def wrapper(
        # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type]
        dataset: T_Dataset,
        *args: P_FilterKwargs.args,  # type: ignore[valid-type]
        **kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
    ) -> T_Dataset:
        new_dataset = method(dataset, *args, **kwargs)
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
        )
        new_dataset.update_self_config()
        return new_dataset

    # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value]
    return wrapper  # type: ignore[return-value]
def set_reproducibility(seed: int) -> None:
set reproducibility in stdlib random and numpy (but not torch)
class FilterInfoMismatchError(ValueError):
raised when the filter info in a dataset config does not match the filter info in the dataset
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(SerializableDataclass):
base GPTDatasetConfig class
def summary(self) -> dict:
return a summary of the config
def to_fname(self) -> str:
convert config to a filename
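Since the fallback `to_fname` emits a warning and relies on `len(self)`, downstream configs are expected to override it. Below is a minimal sketch of such a subclass; `MyDatasetConfig` and its `n_items` field are hypothetical, not part of the library (the real `MazeDatasetConfig` does something more elaborate):

```python
from muutils.json_serialize import serializable_dataclass, serializable_field
from muutils.misc import sanitize_fname

from maze_dataset.dataset.dataset import GPTDatasetConfig


@serializable_dataclass(kw_only=True)
class MyDatasetConfig(GPTDatasetConfig):  # hypothetical subclass, for illustration only
    n_items: int = serializable_field(default=100)

    def to_fname(self) -> str:
        # stable, human-readable filename instead of the warning fallback
        return sanitize_fname(f"{self.name}-n{self.n_items}")
```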
def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:
placeholder for `GPTDatasetConfig.serialize`, assigned at module level from `_dataset_config_serialize`; it always raises `NotImplementedError`, so subclasses must implement their own `serialize`
def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
placeholder for `GPTDatasetConfig.load`, assigned at module level from `_dataset_config_load`; it always raises `NotImplementedError`, so subclasses must implement their own `load`
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )

validate the types of all the fields on a `SerializableDataclass`; calls `SerializableDataclass__validate_field_type` for each field
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class GPTDataset(typing.Generic[T_DatasetConfig]):
wrapper for torch dataset with some extra functionality

(meaning the functionality should be inherited in downstream classes)

> [!NOTE]
> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config

Requires:
the following methods should be implemented in subclasses:

- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
  initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
  generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
- `serialize(self) -> JSONitem`
  serialize the dataset to a ZANJ-serializable object, including:
  - config
  - data in formats specified by `self.save_formats`
- `load(cls, data: JSONitem) -> GPTDataset`
  load the dataset from a ZANJ-serializable object
- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
  given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
- `__len__(self) -> int`
  return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
- `__getitem__(self, i: int) -> list[str]`
  return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
- `update_self_config(self) -> None`
  update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
- decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters

(a minimal subclass sketch following these requirements is given below)

Parameters:
- `cfg : GPTDatasetConfig`
  config for the dataset, used to generate the dataset
- `do_generate : bool`
  whether to generate the dataset if it isn't found
  (defaults to `True`)
- `load_local : bool`
  whether to try finding the dataset locally
  (defaults to `True`)
- `save_local : bool`
  whether to save the dataset locally if it is generated or downloaded
  (defaults to `True`)
- `do_download : bool`
  whether to try downloading the dataset
  (defaults to `True`)
- `local_base_path : Path`
  where to save the dataset
  (defaults to `Path("data/maze_dataset")`)

Returns:
- `GPTDataset`
  the dataset, as you wanted it

Implements:
- `save(self, file_path: str) -> None`
  save the dataset to a file, using ZANJ
- `read(cls, file_path: str) -> GPTDataset`
  read the dataset from a file, using ZANJ
- `filter_by(self)`
  returns a namespace class
- `_filter_namespace(self) -> Class`
  returns a namespace class for filtering the dataset, checking that method
- `_apply_filters_from_config(self) -> None`
  apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
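To make the requirements above concrete, here is a minimal, hypothetical subclass pair. Neither `DemoConfig` nor `DemoDataset` exists in the library (the real implementations are `MazeDatasetConfig` and `MazeDataset`); this is only a sketch, and it assumes the `@serializable_dataclass` decorator generates `serialize`/`load` for the decorated config class:

```python
from muutils.json_serialize import serializable_dataclass, serializable_field
from muutils.misc import stable_hash

from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig


@serializable_dataclass(kw_only=True)
class DemoConfig(GPTDatasetConfig):  # hypothetical config
    n_items: int = serializable_field(default=10)


class DemoDataset(GPTDataset[DemoConfig]):  # hypothetical dataset
    def __init__(self, cfg: DemoConfig, items: list[str]) -> None:
        self.cfg = cfg
        self.items = items

    @classmethod
    def generate(cls, cfg: DemoConfig, **kwargs) -> "DemoDataset":
        # kwargs (e.g. verbose) come through from `from_config`
        return cls(cfg, items=[f"item_{i}" for i in range(cfg.n_items)])

    def serialize(self) -> dict:
        return {"cfg": self.cfg.serialize(), "items": self.items}

    @classmethod
    def load(cls, data: dict) -> "DemoDataset":
        return cls(DemoConfig.load(data["cfg"]), items=list(data["items"]))

    def data_hash(self) -> int:
        # any deterministic hash of the underlying data works
        return stable_hash(str(tuple(self.items)))

    def update_self_config(self) -> None:
        # keep the config in sync with the data, e.g. after filtering
        self.cfg.n_items = len(self.items)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, i: int) -> str:
        return self.items[i]
```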
@classmethod
def from_config(
    cls,
    cfg: "T_DatasetConfig",
    do_generate: bool = True,
    load_local: bool = True,
    save_local: bool = True,
    zanj: ZANJ | None = None,
    do_download: bool = True,
    local_base_path: Path = Path("data/maze_dataset"),
    except_on_config_mismatch: bool = True,
    allow_generation_metadata_filter_mismatch: bool = True,
    verbose: bool = False,
    **kwargs,
) -> "Self":
load, download, or generate a dataset, given a config
priority of loading:
- load from local
- download
- generate
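For the downstream `MazeDataset` shipped with this package, the pipeline is typically driven like this. The parameter values are illustrative, and the comments mark assumptions about how the maze implementation behaves:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="example",
    grid_n=5,
    n_mazes=4,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)

# priority as listed above: local .zanj file -> download -> generate
dataset: MazeDataset = MazeDataset.from_config(
    cfg,
    load_local=True,    # look for an existing file under local_base_path first
    do_download=False,  # skip the download step (assuming no remote source is configured)
    do_generate=True,   # fall back to generating the mazes
    save_local=True,    # cache the generated dataset for next time
)
```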
def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
save dataset to a file with zanj
@classmethod
def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
read dataset from a file with zanj
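A save/read round trip, continuing from the `dataset` built in the `from_config` sketch above (the path is illustrative):

```python
from pathlib import Path

path: Path = Path("data/maze_dataset/example.zanj")
path.parent.mkdir(parents=True, exist_ok=True)

dataset.save(path)                   # writes the output of dataset.serialize() via ZANJ
reloaded = type(dataset).read(path)  # ZANJ reconstructs the dataset via its load()
assert len(reloaded) == len(dataset)
```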
def serialize(self) -> JSONdict:
(implement in subclass!) serialize to something we can save with zanj
def data_hash(self) -> int:
(implement in subclass!) return a hash of the data
@classmethod
def load(cls, data: JSONdict) -> "Self":
(implement in subclass!) load a dataset from what we made with .serialize()
@classmethod
def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
(implement in subclass!) generate the dataset given the config
@classmethod
def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
(implement in subclass!) download the dataset given the config
def update_self_config(self) -> None:
(implement in subclass!) update the config of the dataset to match the actual data, if needed
for example, adjust number of mazes after filtering
@property
def filter_by(self) -> "FilterBy":
can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset
class FilterBy:
thanks GPT-4
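Calling a registered filter through this namespace looks like the following. The filter name `some_registered_filter` is the same placeholder used in the docstrings above; substitute whatever filters are actually registered for your dataset class:

```python
# `some_registered_filter` stands in for any filter registered for this dataset class
filtered = dataset.filter_by.some_registered_filter(min_length=3)

# each application is recorded on the returned dataset's config:
print(filtered.cfg.applied_filters)
# e.g. [{'name': 'some_registered_filter', 'args': (), 'kwargs': {'min_length': 3}}]
```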
def register_filter_namespace_for_dataset(
    dataset_cls: Type[GPTDataset],
) -> Callable[[Type], Type]:
register the namespace class with the given dataset class
def register_dataset_filter(
    method: DatasetFilterFunc,
) -> DatasetFilterFunc:
register a dataset filter, copying the underlying dataset and updating the config
be sure to return a COPY, not the original?
TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
method should be a staticmethod of a namespace class registered with register_filter_namespace_for_dataset
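Putting the two registration helpers together for the hypothetical `DemoDataset` from the sketch earlier (the real maze filters live in `MazeDatasetFilters`; the filter name and decorator order here are illustrative, not the library's own):

```python
import copy

from maze_dataset.dataset.dataset import (
    register_dataset_filter,
    register_filter_namespace_for_dataset,
)


@register_filter_namespace_for_dataset(DemoDataset)  # hypothetical dataset from the sketch above
class DemoDatasetFilters:
    @staticmethod
    @register_dataset_filter
    def keep_first(dataset: DemoDataset, n: int = 5) -> DemoDataset:
        # return a copy: the wrapper records the filter on the returned dataset's config
        return DemoDataset(cfg=copy.deepcopy(dataset.cfg), items=dataset.items[:n])


# afterwards the filter is reachable via the namespace:
# smaller = demo_dataset.filter_by.keep_first(n=3)
```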