docs for maze-dataset v1.3.2

maze_dataset.dataset.dataset

GPTDatasetConfig and GPTDataset are base classes for datasets

they implement some basic functionality: saving/loading, the from_config pipeline, and filtering

Note

these should probably be moved into a different package, so don't rely on them being here
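As a rough usage sketch of the `from_config` pipeline, using the concrete `MazeDataset` / `MazeDatasetConfig` classes from elsewhere in this package (the specific config parameters shown are illustrative; check `MazeDatasetConfig` for the real signature):

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

# illustrative config -- see MazeDatasetConfig for the actual fields
cfg = MazeDatasetConfig(
	name="example",
	grid_n=5,
	n_mazes=10,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)

# tries loading a local copy, then downloading, then generating (in that order)
dataset = MazeDataset.from_config(cfg, verbose=True)
```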


  1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets
  2
  3they implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering
  4
  5> [!NOTE]
  6> these should probably be moved into a different package, so don't rely on them being here
  7"""
  8
  9import functools
 10import json
 11import random
 12import typing
 13import warnings
 14from pathlib import Path
 15from typing import Callable, Type, TypeVar
 16
 17import numpy as np
 18from muutils.json_serialize import (
 19	JSONitem,
 20	SerializableDataclass,
 21	serializable_dataclass,
 22	serializable_field,
 23)
 24from muutils.json_serialize.util import (
 25	JSONdict,
 26)
 27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
 28from zanj import ZANJ
 29
 30from maze_dataset.generation.seed import GLOBAL_SEED
 31
 32
 33def set_reproducibility(seed: int) -> None:
 34	"set reproducibility in stdlib random and numpy (but not torch)"
 35	random.seed(seed)
 36	np.random.seed(seed)
 37
 38
 39class FilterInfoMismatchError(ValueError):
 40	"""raised when the filter info in a dataset config does not match the filter info in the dataset"""
 41
 42	pass
 43
 44
 45def _load_applied_filters(
 46	filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]],
 47) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]:
 48	try:
 49		return [
 50			dict(
 51				name=filter_info["name"],
 52				args=tuple(
 53					filter_info["args"],
 54				),  # muutils/zanj save tuples as lists, and this causes problems
 55				kwargs=dict(filter_info["kwargs"]),  # type: ignore[arg-type]
 56			)
 57			for filter_info in filters
 58		]
 59	except Exception as e:
 60		err_msg: str = f"failed to load applied filters:\n{filters}"
 61		raise ValueError(err_msg) from e
 62
 63
 64@serializable_dataclass(kw_only=True)
 65class GPTDatasetConfig(SerializableDataclass):
 66	"""base GPTDatasetConfig class"""
 67
 68	name: str
 69
 70	# TODO: get rid of all these things as part of migration to tokenizer-free dataset config
 71	# --------------------------------------------------
 72	seq_len_min: int = serializable_field(default=1)
 73	seq_len_max: int = serializable_field(default=512)
 74	# --------------------------------------------------
 75
 76	seed: int | None = serializable_field(default=GLOBAL_SEED)
 77	applied_filters: list[
 78		dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
 79	] = serializable_field(
 80		default_factory=list,
 81		deserialize_fn=_load_applied_filters,
 82		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
 83	)
 84
 85	def __post_init__(self) -> None:
 86		"post init, where we set a random seed if none is set"
 87		assert self.seq_len_min <= self.seq_len_max
 88		# if seed set to None, then generate a new random seed
 89		if self.seed is None:
 90			self.seed = np.random.randint(2**31)
 91
 92		# TODO: something here is broken
 93		if self.seed != GLOBAL_SEED:
 94			warnings.warn(
 95				f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
 96			)
 97
 98		set_reproducibility(self.seed)
 99
100	def summary(self) -> dict:
101		"""return a summary of the config"""
102		# do we run this to make sure it doesn't error?
103		self_ser: dict = self.serialize()
104		assert self_ser
105		return dict(
106			name=self.name,
107			seq_len_min=self.seq_len_min,
108			seq_len_max=self.seq_len_max,
109			seed=self.seed,
110			applied_filters=self.applied_filters,
111		)
112
113	@property
114	def _dataset_class(self) -> type:
115		raise NotImplementedError("this should be implemented by subclasses!")
116
117	def to_fname(self) -> str:
118		"""convert config to a filename"""
119		self_json_str: str = json.dumps(self.serialize())
120		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
121		warnings.warn(
122			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
123		)
124		return sanitize_fname(
125			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
126			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
127		)
128
129
130def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
131	err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
132	raise NotImplementedError(
133		err_msg,
134	)
135
136
137# abstract function, hence we dont care that `self` is unused
138def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
139	err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
140	raise NotImplementedError(
141		err_msg,
142	)
143
144
145GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
146GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
147T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)
148
149
150class GPTDataset(typing.Generic[T_DatasetConfig]):
151	"""wrapper for torch dataset with some extra functionality
152
153	(meaning the functionality should be inherited in downstream classes)
154
155	> [!NOTE]
156	> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
157
158	# Requires:
159	the following methods should be implemented in subclasses:
160	- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
161		initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
162	- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
163		generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
164	- `serialize(self) -> JSONitem`
165		serialize the dataset to a ZANJ-serializable object, including:
166		- config
167		- data in formats specified by `self.save_formats`
168	- `load(cls, data: JSONitem) -> GPTDataset`
169		load the dataset from a ZANJ-serializable object
170	- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
171		given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
172	- `__len__(self) -> int`
173		return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
174	- `__getitem__(self, i: int) -> list[str]`
175		return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
176	-  `update_self_config(self) -> None`
177		update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
178	-  decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters
179
180	# Parameters:
181	- `cfg : GPTDatasetConfig`
182		config for the dataset, used to generate the dataset
183	- `do_generate : bool`
184		whether to generate the dataset if it isn't found
185		(defaults to `True`)
186	- `load_local : bool`
187		whether to try finding the dataset locally
188		(defaults to `True`)
189	- `save_local : bool`
190		whether to save the dataset locally if it is generated or downloaded
191		(defaults to `True`)
192	- `do_download : bool`
193		whether to try downloading the dataset
194		(defaults to `True`)
195	- `local_base_path : Path`
196		where to save the dataset
197		(defaults to `Path("data/maze_dataset")`)
198
199	# Returns:
200	- `GPTDataset`
201		the dataset, as you wanted it
202
203	# Implements:
204	- `save(self, file_path: str) -> None`
205		save the dataset to a file, using ZANJ
206	- `read(cls, file_path: str) -> GPTDataset`
207		read the dataset from a file, using ZANJ
208		get all items in the dataset, in the specified format
209	- `filter_by(self)`
210		returns a namespace class
211	-  `_filter_namespace(self) -> Class`
212		returns a namespace class for filtering the dataset, checking that method
213	- `_apply_filters_from_config(self) -> None`
214		apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
215
216	"""
217
218	_FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore
219
220	cfg: "T_DatasetConfig"
221
222	@classmethod
223	def from_config(  # noqa: C901, PLR0912
224		cls: "type[T_Dataset]",
225		cfg: "T_DatasetConfig",
226		do_generate: bool = True,
227		load_local: bool = True,
228		save_local: bool = True,
229		zanj: ZANJ | None = None,
230		do_download: bool = True,
231		local_base_path: Path = Path("data/maze_dataset"),
232		except_on_config_mismatch: bool = True,
233		allow_generation_metadata_filter_mismatch: bool = True,
234		verbose: bool = False,
235		**kwargs,
236	) -> "T_Dataset":
237		"""load, download, or generate a dataset from the given config
238
239		priority of loading:
240		1. load from local
241		2. download
242		3. generate
243
244		"""
245		print_log: Callable = print if verbose else lambda *_a, **_kw: None
246
247		local_base_path = Path(local_base_path)
248		fname: Path = Path(f"{cfg.to_fname()}.zanj")
249		output: T_Dataset | None = None
250		did_load_local: bool = False
251		if zanj is None:
252			zanj = ZANJ()
253
254		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
255
256		if not (load_local or do_download or do_generate):
257			raise ValueError(
258				"no way to load dataset! you said not to load local, not to download, and not to generate",
259			)
260
261		dataset_path: Path = local_base_path / fname
262
263		# try loading
264		if load_local:  # noqa: SIM102
265			if dataset_path.exists():
266				print_log(f"loading dataset from {dataset_path.as_posix()}")
267				try:
268					output = cls.read(dataset_path, zanj=zanj)
269					did_load_local = True
270					print_log("load successful!")
271				except Exception as e:  # noqa: BLE001
272					print_log(f"failed to load dataset: {e}")
273
274		if do_download and output is None:
275			print_log("seeing if we can download the dataset...")
276			try:
277				output = cls.download(cfg, **kwargs)
278				print_log("download successful!")
279			except NotImplementedError:
280				print_log("no download found, or download failed")
281
282		if do_generate and output is None:
283			print_log("generating dataset...")
284			output = cls.generate(cfg, verbose=verbose, **kwargs)
285			# only if we generated it, apply filters
286			output = output._apply_filters_from_config()
287
288		# check and save
289		if output is None:
290			raise ValueError("failed to load dataset!")
291
292		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
293		if cfg_diff:
294			if except_on_config_mismatch:
295				if allow_generation_metadata_filter_mismatch and (
296					cfg_diff
297					== {
298						"applied_filters": {
299							"self": [],
300							"other": [
301								{
302									"name": "collect_generation_meta",
303									"args": (),
304									"kwargs": {},
305								},
306							],
307						},
308					}
309				):
310					pass
311				else:
312					err_msg: str = f"config mismatch: {cfg_diff = }"
313					raise ValueError(err_msg)
314			else:
315				warnings.warn(f"config mismatch: {cfg_diff = }")
316
317		if save_local and not did_load_local:
318			print_log(f"saving dataset to {dataset_path}")
319			output.save(dataset_path, zanj=zanj)
320
321		print_log(
322			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
323		)
324		return output
325
326	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
327		"save dataset to a file with zanj"
328		if zanj is None:
329			zanj = ZANJ()
330		zanj.save(self.serialize(), file_path)
331
332	# serialization & loading
333	@classmethod
334	def read(
335		cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
336	) -> "T_Dataset":
337		"read dataset from a file with zanj"
338		if zanj is None:
339			zanj = ZANJ()
340		return zanj.read(file_path)
341
342	def serialize(self: "T_Dataset") -> JSONdict:
343		"(implement in subclass!) serialize to something we can save with zanj"
344		raise NotImplementedError
345
346	def data_hash(self: "T_Dataset") -> int:
347		"(implement in subclass!) return a hash of the data"
348		raise NotImplementedError
349
350	@classmethod
351	def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
352		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
353		raise NotImplementedError
354
355	# generating & downloading
356	@classmethod
357	def generate(
358		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
359	) -> "T_Dataset":
360		"(implement in subclass!) generate the dataset given the config"
361		raise NotImplementedError
362
363	@classmethod
364	def download(
365		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
366	) -> "T_Dataset":
367		"(implement in subclass!) download the dataset given the config"
368		raise NotImplementedError
369
370	# filtering
371	def update_self_config(self) -> None:
372		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
373
374		for example, adjust number of mazes after filtering
375		"""
376		pass
377
378	def __len__(self) -> int:
379		"return the length of the dataset"
380		raise NotImplementedError("implement in subclass!")
381
382	class FilterBy:
383		"""thanks GPT-4"""
384
385		def __init__(self, dataset: "T_Dataset") -> None:
386			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
387			self.dataset: T_Dataset = dataset
388
389		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
390			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
391			filter_func: DatasetFilterFunc = getattr(
392				self.dataset._FILTER_NAMESPACE,
393				name,
394			)
395
396			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
397				return filter_func(self.dataset, *args, **kwargs)
398
399			return wrapped_filter_func
400
401	@property
402	def filter_by(self) -> "FilterBy":
403		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
404		return self.FilterBy(self)
405
406	def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset":
407		"""apply filters to the dataset, as specified in the config. used in `from_config()`"""
408		output: T_Dataset = self
409		# copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
410		applied_filters_old: list[
411			dict[typing.Literal["name", "args", "kwargs"], typing.Any]
412		] = self.cfg.applied_filters
413		output.cfg.applied_filters = list()
414		# apply the filters
415		for filter_info in applied_filters_old:
416			filter_name: str = filter_info["name"]
417			if filter_name not in output._FILTER_NAMESPACE.__dict__:
418				if filter_name.startswith("__custom__:"):
419				err_msg = f"the dataset {output.cfg.to_fname()} was filtered using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
420					raise ValueError(
421						err_msg,
422					)
423			err_msg = f"the dataset {output.cfg.to_fname()} was filtered using an unknown filter: '{filter_name}'"
424				raise ValueError(
425					err_msg,
426				)
427			filter_args: list = filter_info.get("args", list())
428			filter_kwargs: dict = filter_info.get("kwargs", dict())
429			output = getattr(output.filter_by, filter_name)(
430				*filter_args,
431				**filter_kwargs,
432			)
433
434		# update the config, perform checks
435		# TODO: some funny business with manually specified filters here?
436		output.update_self_config()
437		_check_filter_equality(
438			filters_old=applied_filters_old,
439			filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
440		)
441		return output
442
443
444def _check_filter_equality(
445	filters_old: list[
446		dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
447	],
448	filters_new: list[
449		dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
450	],
451) -> None:
452	try:
453		assert len(filters_old) == len(filters_new)
454
455		for filterinfo_old, filterinfo_new in zip(
456			filters_old,
457			filters_new,
458			strict=False,
459		):
460			# basic checks
461			assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
462			assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
463			assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
464				"missing keys in filterinfo_new"
465			)
466			assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
467				"missing keys in filterinfo_old"
468			)
469
470			# name
471			assert filterinfo_new["name"] == filterinfo_old["name"], (
472				"filter names don't match"
473			)
474
475			# args
476			assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
477				"filter args of different lengths"
478			)
479			for arg_new, arg_old in zip(
480				filterinfo_new["args"],
481				filterinfo_old["args"],
482				strict=False,
483			):
484				assert arg_new == arg_old, "filter args don't match"
485
486			# kwargs
487			assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
488				"filter kwargs of different lengths"
489			)
490			for key in filterinfo_old["kwargs"]:
491				assert key in filterinfo_new["kwargs"], (
492					f"filter kwargs don't match: missing key '{key}'"
493				)
494				assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
495					f"filter kwargs don't match: values for key '{key}' don't match"
496				)
497
498	except AssertionError as e:
499		err_msg: str = (
500			f"config mismatch in applied filters: {filters_new} != {filters_old}"
501		)
502		raise FilterInfoMismatchError(
503			err_msg,
504		) from e
505
506
507def register_filter_namespace_for_dataset(
508	dataset_cls: Type[GPTDataset],
509) -> Callable[[Type], Type]:
510	"""register the namespace class with the given dataset class"""
511
512	def decorator(filter_namespace_cls: Type) -> Type:
513		dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
514		filter_namespace_cls._BASE_DATASET = dataset_cls
515
516		return filter_namespace_cls
517
518	return decorator
519
520
521T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
522P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
523DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]
524
525
526def register_dataset_filter(
527	method: DatasetFilterFunc,
528) -> DatasetFilterFunc:
529	"""register a dataset filter, copying the underlying dataset and updating the config
530
531	be sure to return a COPY, not the original?
532	# TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
533
534	method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
535	"""
536
537	@functools.wraps(method)
538	def wrapper(
539		# TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
540		dataset: T_Dataset,
541		*args: P_FilterKwargs.args,  # type: ignore[valid-type]
542		**kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
543	) -> T_Dataset:
544		new_dataset = method(dataset, *args, **kwargs)
545		# update the config
546		new_dataset.cfg.applied_filters.append(
547			dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
548		)
549		new_dataset.update_self_config()
550		return new_dataset
551
552	# TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
553	return wrapper  # type: ignore[return-value]

def set_reproducibility(seed: int) -> None:
34def set_reproducibility(seed: int) -> None:
35	"set reproducibility in stdlib random and numpy (but not torch)"
36	random.seed(seed)
37	np.random.seed(seed)

set reproducibility in stdlib random and numpy (but not torch)
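For example, to make stdlib and numpy random draws repeatable before generating a dataset (torch, if you use it, has to be seeded separately):

```python
from maze_dataset.dataset.dataset import set_reproducibility

set_reproducibility(42)  # calls random.seed(42) and np.random.seed(42)
```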

class FilterInfoMismatchError(builtins.ValueError):
40class FilterInfoMismatchError(ValueError):
41	"""raised when the filter info in a dataset config does not match the filter info in the dataset"""
42
43	pass

raised when the filter info in a dataset config does not match the filter info in the dataset

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 65@serializable_dataclass(kw_only=True)
 66class GPTDatasetConfig(SerializableDataclass):
 67	"""base GPTDatasetConfig class"""
 68
 69	name: str
 70
 71	# TODO: get rid of all these things as part of migration to tokenizer-free dataset config
 72	# --------------------------------------------------
 73	seq_len_min: int = serializable_field(default=1)
 74	seq_len_max: int = serializable_field(default=512)
 75	# --------------------------------------------------
 76
 77	seed: int | None = serializable_field(default=GLOBAL_SEED)
 78	applied_filters: list[
 79		dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
 80	] = serializable_field(
 81		default_factory=list,
 82		deserialize_fn=_load_applied_filters,
 83		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
 84	)
 85
 86	def __post_init__(self) -> None:
 87		"post init, where we set a random seed if none is set"
 88		assert self.seq_len_min <= self.seq_len_max
 89		# if seed set to None, then generate a new random seed
 90		if self.seed is None:
 91			self.seed = np.random.randint(2**31)
 92
 93		# TODO: something here is broken
 94		if self.seed != GLOBAL_SEED:
 95			warnings.warn(
 96				f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
 97			)
 98
 99		set_reproducibility(self.seed)
100
101	def summary(self) -> dict:
102		"""return a summary of the config"""
103		# do we run this to make sure it doesn't error?
104		self_ser: dict = self.serialize()
105		assert self_ser
106		return dict(
107			name=self.name,
108			seq_len_min=self.seq_len_min,
109			seq_len_max=self.seq_len_max,
110			seed=self.seed,
111			applied_filters=self.applied_filters,
112		)
113
114	@property
115	def _dataset_class(self) -> type:
116		raise NotImplementedError("this should be implemented by subclasses!")
117
118	def to_fname(self) -> str:
119		"""convert config to a filename"""
120		self_json_str: str = json.dumps(self.serialize())
121		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
122		warnings.warn(
123			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
124		)
125		return sanitize_fname(
126			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
127			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
128		)

base GPTDatasetConfig class

GPTDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>)
name: str
seq_len_min: int = 1
seq_len_max: int = 512
seed: int | None = 42
applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]]
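A minimal sketch of a config subclass; the class name, the `n_items` field, and the `to_fname` format are hypothetical, shown only to illustrate the pattern:

```python
from muutils.json_serialize import serializable_dataclass, serializable_field

from maze_dataset.dataset.dataset import GPTDatasetConfig


@serializable_dataclass(kw_only=True)
class MyDatasetConfig(GPTDatasetConfig):
	"""hypothetical config with a single extra field"""

	n_items: int = serializable_field(default=100)

	def to_fname(self) -> str:
		# deterministic, filesystem-safe name so from_config can find saved copies
		return f"{self.name}-n{self.n_items}"
```

Overriding `to_fname` avoids the fallback implementation above, which warns and calls `len(self)` on the config.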
def summary(self) -> dict:
101	def summary(self) -> dict:
102		"""return a summary of the config"""
103		# do we run this to make sure it doesn't error?
104		self_ser: dict = self.serialize()
105		assert self_ser
106		return dict(
107			name=self.name,
108			seq_len_min=self.seq_len_min,
109			seq_len_max=self.seq_len_max,
110			seed=self.seed,
111			applied_filters=self.applied_filters,
112		)

return a summary of the config

def to_fname(self) -> str:
118	def to_fname(self) -> str:
119		"""convert config to a filename"""
120		self_json_str: str = json.dumps(self.serialize())
121		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
122		warnings.warn(
123			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
124		)
125		return sanitize_fname(
126			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
127			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
128		)

convert config to a filename

def serialize( self, *args, **kwargs) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
139def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
140	err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
141	raise NotImplementedError(
142		err_msg,
143	)

returns the class as a dict, implemented by using @serializable_dataclass decorator

def load(*args, **kwargs) -> GPTDatasetConfig:
131def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
132	err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
133	raise NotImplementedError(
134		err_msg,
135	)

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class GPTDataset(typing.Generic[~T_DatasetConfig]):
151class GPTDataset(typing.Generic[T_DatasetConfig]):
152	"""wrapper for torch dataset with some extra functionality
153
154	(meaning the functionality should be inherited in downstream classes)
155
156	> [!NOTE]
157	> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
158
159	# Requires:
160	the following methods should be implemented in subclasses:
161	- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
162		initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
163	- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
164		generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
165	- `serialize(self) -> JSONitem`
166		serialize the dataset to a ZANJ-serializable object, including:
167		- config
168		- data in formats specified by `self.save_formats`
169	- `load(cls, data: JSONitem) -> GPTDataset`
170		load the dataset from a ZANJ-serializable object
171	- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
172		given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
173	- `__len__(self) -> int`
174		return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
175	- `__getitem__(self, i: int) -> list[str]`
176		return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
177	-  `update_self_config(self) -> None`
178		update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
179	-  decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters
180
181	# Parameters:
182	- `cfg : GPTDatasetConfig`
183		config for the dataset, used to generate the dataset
184	- `do_generate : bool`
185		whether to generate the dataset if it isn't found
186		(defaults to `True`)
187	- `load_local : bool`
188		whether to try finding the dataset locally
189		(defaults to `True`)
190	- `save_local : bool`
191		whether to save the dataset locally if it is generated or downloaded
192		(defaults to `True`)
193	- `do_download : bool`
194		whether to try downloading the dataset
195		(defaults to `True`)
196	- `local_base_path : Path`
197		where to save the dataset
198		(defaults to `Path("data/maze_dataset")`)
199
200	# Returns:
201	- `GPTDataset`
202		the dataset, as you wanted it
203
204	# Implements:
205	- `save(self, file_path: str) -> None`
206		save the dataset to a file, using ZANJ
207	- `read(cls, file_path: str) -> GPTDataset`
208		read the dataset from a file, using ZANJ
209		get all items in the dataset, in the specified format
210	- `filter_by(self)`
211		returns a namespace class
212	-  `_filter_namespace(self) -> Class`
213		returns a namespace class for filtering the dataset, checking that method
214	- `_apply_filters_from_config(self) -> None`
215		apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
216
217	"""
218
219	_FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore
220
221	cfg: "T_DatasetConfig"
222
223	@classmethod
224	def from_config(  # noqa: C901, PLR0912
225		cls: "type[T_Dataset]",
226		cfg: "T_DatasetConfig",
227		do_generate: bool = True,
228		load_local: bool = True,
229		save_local: bool = True,
230		zanj: ZANJ | None = None,
231		do_download: bool = True,
232		local_base_path: Path = Path("data/maze_dataset"),
233		except_on_config_mismatch: bool = True,
234		allow_generation_metadata_filter_mismatch: bool = True,
235		verbose: bool = False,
236		**kwargs,
237	) -> "T_Dataset":
238		"""load, download, or generate a dataset from the given config
239
240		priority of loading:
241		1. load from local
242		2. download
243		3. generate
244
245		"""
246		print_log: Callable = print if verbose else lambda *_a, **_kw: None
247
248		local_base_path = Path(local_base_path)
249		fname: Path = Path(f"{cfg.to_fname()}.zanj")
250		output: T_Dataset | None = None
251		did_load_local: bool = False
252		if zanj is None:
253			zanj = ZANJ()
254
255		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
256
257		if not (load_local or do_download or do_generate):
258			raise ValueError(
259				"no way to load dataset! you said not to load local, not to download, and not to generate",
260			)
261
262		dataset_path: Path = local_base_path / fname
263
264		# try loading
265		if load_local:  # noqa: SIM102
266			if dataset_path.exists():
267				print_log(f"loading dataset from {dataset_path.as_posix()}")
268				try:
269					output = cls.read(dataset_path, zanj=zanj)
270					did_load_local = True
271					print_log("load successful!")
272				except Exception as e:  # noqa: BLE001
273					print_log(f"failed to load dataset: {e}")
274
275		if do_download and output is None:
276			print_log("seeing if we can download the dataset...")
277			try:
278				output = cls.download(cfg, **kwargs)
279				print_log("download successful!")
280			except NotImplementedError:
281				print_log("no download found, or download failed")
282
283		if do_generate and output is None:
284			print_log("generating dataset...")
285			output = cls.generate(cfg, verbose=verbose, **kwargs)
286			# only if we generated it, apply filters
287			output = output._apply_filters_from_config()
288
289		# check and save
290		if output is None:
291			raise ValueError("failed to load dataset!")
292
293		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
294		if cfg_diff:
295			if except_on_config_mismatch:
296				if allow_generation_metadata_filter_mismatch and (
297					cfg_diff
298					== {
299						"applied_filters": {
300							"self": [],
301							"other": [
302								{
303									"name": "collect_generation_meta",
304									"args": (),
305									"kwargs": {},
306								},
307							],
308						},
309					}
310				):
311					pass
312				else:
313					err_msg: str = f"config mismatch: {cfg_diff = }"
314					raise ValueError(err_msg)
315			else:
316				warnings.warn(f"config mismatch: {cfg_diff = }")
317
318		if save_local and not did_load_local:
319			print_log(f"saving dataset to {dataset_path}")
320			output.save(dataset_path, zanj=zanj)
321
322		print_log(
323			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
324		)
325		return output
326
327	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
328		"save dataset to a file with zanj"
329		if zanj is None:
330			zanj = ZANJ()
331		zanj.save(self.serialize(), file_path)
332
333	# serialization & loading
334	@classmethod
335	def read(
336		cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
337	) -> "T_Dataset":
338		"read dataset from a file with zanj"
339		if zanj is None:
340			zanj = ZANJ()
341		return zanj.read(file_path)
342
343	def serialize(self: "T_Dataset") -> JSONdict:
344		"(implement in subclass!) serialize to something we can save with zanj"
345		raise NotImplementedError
346
347	def data_hash(self: "T_Dataset") -> int:
348		"(implement in subclass!) return a hash of the data"
349		raise NotImplementedError
350
351	@classmethod
352	def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
353		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
354		raise NotImplementedError
355
356	# generating & downloading
357	@classmethod
358	def generate(
359		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
360	) -> "T_Dataset":
361		"(implement in subclass!) generate the dataset given the config"
362		raise NotImplementedError
363
364	@classmethod
365	def download(
366		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
367	) -> "T_Dataset":
368		"(implement in subclass!) download the dataset given the config"
369		raise NotImplementedError
370
371	# filtering
372	def update_self_config(self) -> None:
373		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
374
375		for example, adjust number of mazes after filtering
376		"""
377		pass
378
379	def __len__(self) -> int:
380		"return the length of the dataset"
381		raise NotImplementedError("implement in subclass!")
382
383	class FilterBy:
384		"""thanks GPT-4"""
385
386		def __init__(self, dataset: "T_Dataset") -> None:
387			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
388			self.dataset: T_Dataset = dataset
389
390		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
391			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
392			filter_func: DatasetFilterFunc = getattr(
393				self.dataset._FILTER_NAMESPACE,
394				name,
395			)
396
397			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
398				return filter_func(self.dataset, *args, **kwargs)
399
400			return wrapped_filter_func
401
402	@property
403	def filter_by(self) -> "FilterBy":
404		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
405		return self.FilterBy(self)
406
407	def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset":
408		"""apply filters to the dataset, as specified in the config. used in `from_config()`"""
409		output: T_Dataset = self
410		# copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
411		applied_filters_old: list[
412			dict[typing.Literal["name", "args", "kwargs"], typing.Any]
413		] = self.cfg.applied_filters
414		output.cfg.applied_filters = list()
415		# apply the filters
416		for filter_info in applied_filters_old:
417			filter_name: str = filter_info["name"]
418			if filter_name not in output._FILTER_NAMESPACE.__dict__:
419				if filter_name.startswith("__custom__:"):
420					err_msg = f"the dataset {output.cfg.to_fname()} was filtered using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
421					raise ValueError(
422						err_msg,
423					)
424				err_msg = f"the dataset {output.cfg.to_fname()} was filtered using an unknown filter: '{filter_name}'"
425				raise ValueError(
426					err_msg,
427				)
428			filter_args: list = filter_info.get("args", list())
429			filter_kwargs: dict = filter_info.get("kwargs", dict())
430			output = getattr(output.filter_by, filter_name)(
431				*filter_args,
432				**filter_kwargs,
433			)
434
435		# update the config, perform checks
436		# TODO: some funny business with manually specified filters here?
437		output.update_self_config()
438		_check_filter_equality(
439			filters_old=applied_filters_old,
440			filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
441		)
442		return output

wrapper for torch dataset with some extra functionality

(meaning the functionality should be inherited in downstream classes)

Note

GPTDatasetConfig should implement a to_fname method that returns a unique filename for the config

Requires:

the following methods should be implemented in subclasses:

  • __init__(self, cfg: GPTDatasetConfig, **kwargs) initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
  • generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset generate the dataset from a given config. kwargs are passed through from from_config, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
  • serialize(self) -> JSONitem serialize the dataset to a ZANJ-serializable object, including:
    • config
    • data in formats specified by self.save_formats
  • load(cls, data: JSONitem) -> GPTDataset load the dataset from a ZANJ-serializable object
  • download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset given a config, try to download a dataset from some source. kwargs are passed through from from_config, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
  • __len__(self) -> int return the length of the dataset, required to match interface of torch.utils.data.Dataset
  • __getitem__(self, i: int) -> list[str] return the ith item in the dataset, required to match interface of torch.utils.data.Dataset
  • update_self_config(self) -> None update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
  • decorating the appropriate filter namespace with register_filter_namespace_for_dataset(your_dataset_class) if you want to use filters

Parameters:

  • cfg : GPTDatasetConfig config for the dataset, used to generate the dataset
  • do_generate : bool whether to generate the dataset if it isn't found (defaults to True)
  • load_local : bool whether to try finding the dataset locally (defaults to True)
  • save_local : bool whether to save the dataset locally if it is generated or downloaded (defaults to True)
  • do_download : bool whether to try downloading the dataset (defaults to True)
  • local_base_path : Path where to save the dataset (defaults to Path("data/maze_dataset"))

Returns:

  • GPTDataset the dataset, as you wanted it

Implements:

  • save(self, file_path: str) -> None save the dataset to a file, using ZANJ
  • read(cls, file_path: str) -> GPTDataset read the dataset from a file, using ZANJ get all items in the dataset, in the specified format
  • filter_by(self) returns a namespace class
  • _filter_namespace(self) -> Class returns a namespace class for filtering the dataset, checking that method
  • _apply_filters_from_config(self) -> None apply filters to the dataset, as specified in the config. used in from_config() but only when generating
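Continuing the hypothetical `MyDatasetConfig` sketch from above, a bare-bones subclass satisfying the required interface might look like this (not the real `MazeDataset` implementation; `download` is left unimplemented, so `from_config` falls back to `generate`):

```python
from maze_dataset.dataset.dataset import GPTDataset


class MyDataset(GPTDataset[MyDatasetConfig]):
	def __init__(self, cfg: MyDatasetConfig, items: list[str]) -> None:
		self.cfg = cfg
		self.items = items

	@classmethod
	def generate(cls, cfg: MyDatasetConfig, **kwargs) -> "MyDataset":
		# kwargs from from_config (e.g. verbose) are ignored in this sketch
		return cls(cfg, items=[f"item_{i}" for i in range(cfg.n_items)])

	def serialize(self) -> dict:
		return dict(cfg=self.cfg.serialize(), items=self.items)

	@classmethod
	def load(cls, data: dict) -> "MyDataset":
		return cls(MyDatasetConfig.load(data["cfg"]), items=data["items"])

	def __len__(self) -> int:
		return len(self.items)

	def __getitem__(self, i: int) -> str:
		return self.items[i]

	def update_self_config(self) -> None:
		# keep the config consistent with the data, e.g. after filtering
		self.cfg.n_items = len(self.items)
```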
cfg: ~T_DatasetConfig
@classmethod
def from_config( cls: type[~T_Dataset], cfg: ~T_DatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> ~T_Dataset:
223	@classmethod
224	def from_config(  # noqa: C901, PLR0912
225		cls: "type[T_Dataset]",
226		cfg: "T_DatasetConfig",
227		do_generate: bool = True,
228		load_local: bool = True,
229		save_local: bool = True,
230		zanj: ZANJ | None = None,
231		do_download: bool = True,
232		local_base_path: Path = Path("data/maze_dataset"),
233		except_on_config_mismatch: bool = True,
234		allow_generation_metadata_filter_mismatch: bool = True,
235		verbose: bool = False,
236		**kwargs,
237	) -> "T_Dataset":
238		"""load, download, or generate a dataset from the given config
239
240		priority of loading:
241		1. load from local
242		2. download
243		3. generate
244
245		"""
246		print_log: Callable = print if verbose else lambda *_a, **_kw: None
247
248		local_base_path = Path(local_base_path)
249		fname: Path = Path(f"{cfg.to_fname()}.zanj")
250		output: T_Dataset | None = None
251		did_load_local: bool = False
252		if zanj is None:
253			zanj = ZANJ()
254
255		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
256
257		if not (load_local or do_download or do_generate):
258			raise ValueError(
259				"no way to load dataset! you said not to load local, not to download, and not to generate",
260			)
261
262		dataset_path: Path = local_base_path / fname
263
264		# try loading
265		if load_local:  # noqa: SIM102
266			if dataset_path.exists():
267				print_log(f"loading dataset from {dataset_path.as_posix()}")
268				try:
269					output = cls.read(dataset_path, zanj=zanj)
270					did_load_local = True
271					print_log("load successful!")
272				except Exception as e:  # noqa: BLE001
273					print_log(f"failed to load dataset: {e}")
274
275		if do_download and output is None:
276			print_log("seeing if we can download the dataset...")
277			try:
278				output = cls.download(cfg, **kwargs)
279				print_log("download successful!")
280			except NotImplementedError:
281				print_log("no download found, or download failed")
282
283		if do_generate and output is None:
284			print_log("generating dataset...")
285			output = cls.generate(cfg, verbose=verbose, **kwargs)
286			# only if we generated it, apply filters
287			output = output._apply_filters_from_config()
288
289		# check and save
290		if output is None:
291			raise ValueError("failed to load dataset!")
292
293		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
294		if cfg_diff:
295			if except_on_config_mismatch:
296				if allow_generation_metadata_filter_mismatch and (
297					cfg_diff
298					== {
299						"applied_filters": {
300							"self": [],
301							"other": [
302								{
303									"name": "collect_generation_meta",
304									"args": (),
305									"kwargs": {},
306								},
307							],
308						},
309					}
310				):
311					pass
312				else:
313					err_msg: str = f"config mismatch: {cfg_diff = }"
314					raise ValueError(err_msg)
315			else:
316				warnings.warn(f"config mismatch: {cfg_diff = }")
317
318		if save_local and not did_load_local:
319			print_log(f"saving dataset to {dataset_path}")
320			output.save(dataset_path, zanj=zanj)
321
322		print_log(
323			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
324		)
325		return output

load, download, or generate a dataset from the given config

priority of loading:

  1. load from local
  2. download
  3. generate
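With the hypothetical `MyDataset` / `MyDatasetConfig` sketch above, the flags control that priority; for instance, to force regeneration and skip both the local file and the download attempt:

```python
from pathlib import Path

cfg = MyDatasetConfig(name="demo", n_items=50)
dataset = MyDataset.from_config(
	cfg,
	load_local=False,   # don't look for an existing .zanj file
	do_download=False,  # skip the download attempt
	do_generate=True,   # generate from the config
	save_local=True,    # write the result under local_base_path
	local_base_path=Path("data/my_dataset"),
	verbose=True,
)
```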
def save( self, file_path: pathlib.Path | str, zanj: zanj.zanj.ZANJ | None = None) -> None:
327	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
328		"save dataset to a file with zanj"
329		if zanj is None:
330			zanj = ZANJ()
331		zanj.save(self.serialize(), file_path)

save dataset to a file with zanj

@classmethod
def read( cls: type[~T_Dataset], file_path: str | pathlib.Path, zanj: zanj.zanj.ZANJ | None = None) -> ~T_Dataset:
334	@classmethod
335	def read(
336		cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None
337	) -> "T_Dataset":
338		"read dataset from a file with zanj"
339		if zanj is None:
340			zanj = ZANJ()
341		return zanj.read(file_path)

read dataset from a file with zanj
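A rough save/read round trip for the `MyDataset` sketch above. Note that for `read` to reconstruct the subclass, the serialized object has to be in a format ZANJ knows how to load back; the concrete `MazeDataset` handles this, the sketch here only shows the call pattern:

```python
from zanj import ZANJ

zanj = ZANJ()
# save: writes self.serialize() to a .zanj file
dataset.save("data/my_dataset/demo.zanj", zanj=zanj)
# read: hands the file to zanj.read, which must know how to rebuild the class
reloaded = MyDataset.read("data/my_dataset/demo.zanj", zanj=zanj)
```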

def serialize( self: ~T_Dataset) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
343	def serialize(self: "T_Dataset") -> JSONdict:
344		"(implement in subclass!) serialize to something we can save with zanj"
345		raise NotImplementedError

(implement in subclass!) serialize to something we can save with zanj

def data_hash(self: ~T_Dataset) -> int:
347	def data_hash(self: "T_Dataset") -> int:
348		"(implement in subclass!) return a hash of the data"
349		raise NotImplementedError

(implement in subclass!) return a hash of the data

@classmethod
def load( cls: type[~T_Dataset], data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> ~T_Dataset:
351	@classmethod
352	def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset":
353		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
354		raise NotImplementedError

(implement in subclass!) load a dataset from what we made with .serialize()

@classmethod
def generate(cls: type[~T_Dataset], cfg: ~T_DatasetConfig, **kwargs) -> ~T_Dataset:
357	@classmethod
358	def generate(
359		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
360	) -> "T_Dataset":
361		"(implement in subclass!) generate the dataset given the config"
362		raise NotImplementedError

(implement in subclass!) generate the dataset given the config

@classmethod
def download(cls: type[~T_Dataset], cfg: ~T_DatasetConfig, **kwargs) -> ~T_Dataset:
364	@classmethod
365	def download(
366		cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs
367	) -> "T_Dataset":
368		"(implement in subclass!) download the dataset given the config"
369		raise NotImplementedError

(implement in subclass!) download the dataset given the config

def update_self_config(self) -> None:
372	def update_self_config(self) -> None:
373		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
374
375		for example, adjust number of mazes after filtering
376		"""
377		pass

(implement in subclass!) update the config of the dataset to match the actual data, if needed

for example, adjust number of mazes after filtering

filter_by: GPTDataset.FilterBy
402	@property
403	def filter_by(self) -> "FilterBy":
404		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
405		return self.FilterBy(self)

can call my_dataset.filter_by.some_registered_filter() to filter the dataset
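With a filter namespace registered (see `register_filter_namespace_for_dataset` and `register_dataset_filter` below), a hypothetical filter is applied like this:

```python
# returns a new dataset and records the call in cfg.applied_filters
filtered = dataset.filter_by.keep_short(max_len=10)
print(filtered.cfg.applied_filters)
# [{'name': 'keep_short', 'args': (), 'kwargs': {'max_len': 10}}]
```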

class GPTDataset.FilterBy:
383	class FilterBy:
384		"""thanks GPT-4"""
385
386		def __init__(self, dataset: "T_Dataset") -> None:
387			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
388			self.dataset: T_Dataset = dataset
389
390		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
391			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
392			filter_func: DatasetFilterFunc = getattr(
393				self.dataset._FILTER_NAMESPACE,
394				name,
395			)
396
397			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
398				return filter_func(self.dataset, *args, **kwargs)
399
400			return wrapped_filter_func

thanks GPT-4

GPTDataset.FilterBy(dataset: ~T_Dataset)
386		def __init__(self, dataset: "T_Dataset") -> None:
387			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
388			self.dataset: T_Dataset = dataset

mock class so we can call my_dataset.filter_by.some_registered_filter()

dataset: ~T_Dataset
def register_filter_namespace_for_dataset( dataset_cls: Type[GPTDataset]) -> Callable[[Type], Type]:
508def register_filter_namespace_for_dataset(
509	dataset_cls: Type[GPTDataset],
510) -> Callable[[Type], Type]:
511	"""register the namespace class with the given dataset class"""
512
513	def decorator(filter_namespace_cls: Type) -> Type:
514		dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
515		filter_namespace_cls._BASE_DATASET = dataset_cls
516
517		return filter_namespace_cls
518
519	return decorator

register the namespace class with the given dataset class
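A minimal sketch of registering a filter namespace for the hypothetical `MyDataset` from above (a version with an actual filter follows under `register_dataset_filter` below):

```python
from maze_dataset.dataset.dataset import register_filter_namespace_for_dataset


@register_filter_namespace_for_dataset(MyDataset)
class MyDatasetFilters:
	"""filters for MyDataset go here, wrapped with register_dataset_filter"""
```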

P_FilterKwargs = ~P_FilterKwargs
DatasetFilterFunc = typing.Callable[typing.Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]
def register_dataset_filter( method: Callable[Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]) -> Callable[Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]:
527def register_dataset_filter(
528	method: DatasetFilterFunc,
529) -> DatasetFilterFunc:
530	"""register a dataset filter, copying the underlying dataset and updating the config
531
532	be sure to return a COPY, not the original?
533	# TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
534
535	method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
536	"""
537
538	@functools.wraps(method)
539	def wrapper(
540		# TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
541		dataset: T_Dataset,
542		*args: P_FilterKwargs.args,  # type: ignore[valid-type]
543		**kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
544	) -> T_Dataset:
545		new_dataset = method(dataset, *args, **kwargs)
546		# update the config
547		new_dataset.cfg.applied_filters.append(
548			dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
549		)
550		new_dataset.update_self_config()
551		return new_dataset
552
553	# TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
554	return wrapper  # type: ignore[return-value]

register a dataset filter, copying the underlying dataset and updating the config

be sure to return a COPY, not the original?

TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

method should be a staticmethod of a namespace class registered with register_filter_namespace_for_dataset
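Putting it together for the hypothetical `MyDataset`: one way to define and register a filter (the `keep_short` filter and the decorator ordering shown here are a sketch, not the library's canonical `MazeDatasetFilters`):

```python
from maze_dataset.dataset.dataset import (
	register_dataset_filter,
	register_filter_namespace_for_dataset,
)


@register_filter_namespace_for_dataset(MyDataset)
class MyDatasetFilters:
	@staticmethod
	@register_dataset_filter
	def keep_short(dataset: "MyDataset", max_len: int = 10) -> "MyDataset":
		# build a new dataset rather than mutating the original items
		return MyDataset(
			dataset.cfg,
			items=[item for item in dataset.items if len(item) <= max_len],
		)


# applied through the filter_by property; the wrapper appends
# {"name": "keep_short", ...} to cfg.applied_filters and calls update_self_config()
filtered = dataset.filter_by.keep_short(max_len=10)
```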