docs for maze-dataset v1.4.0
View Source on GitHub

maze_dataset.dataset.dataset

GPTDatasetConfig and GPTDataset are base classes for datasets

they implement some basic functionality: saving/loading, the from_config pipeline, and filtering

Note

these should probably be moved into a different package, so don't rely on them being here

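For orientation, here is a sketch of the typical pipeline, assuming the concrete MazeDataset / MazeDatasetConfig subclasses that maze_dataset provides (the parameter values are illustrative only):

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

# describe the dataset we want -- a GPTDatasetConfig subclass
cfg = MazeDatasetConfig(
	name="example",
	grid_n=5,
	n_mazes=100,
	maze_ctor=LatticeMazeGenerators.gen_dfs,
)

# from_config tries a local load, then a download, then generation,
# and saves the result locally unless told otherwise
dataset = MazeDataset.from_config(
	cfg,
	load_local=True,
	do_download=False,
	do_generate=True,
	save_local=True,
	verbose=True,
)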

  1"""`GPTDatasetConfig` and `GPTDataset` are base classes for datasets
  2
  3they implement some basic functionality: saving/loading, the `from_config` pipeline, and filtering
  4
  5> [!NOTE]
  6> these should probably be moved into a different package, so don't rely on them being here
  7"""
  8
  9import functools
 10import json
 11import random
 12import typing
 13import warnings
 14from pathlib import Path
 15from typing import Callable, Type, TypeVar
 16
 17import numpy as np
 18from muutils.json_serialize import (
 19	JSONitem,
 20	SerializableDataclass,
 21	serializable_dataclass,
 22	serializable_field,
 23)
 24from muutils.json_serialize.util import (
 25	JSONdict,
 26)
 27from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
 28from typing_extensions import Self
 29from zanj import ZANJ
 30
 31from maze_dataset.generation.seed import GLOBAL_SEED
 32
 33
 34def set_reproducibility(seed: int) -> None:
 35	"set reproducibility in stdlib random and numpy (but not torch)"
 36	random.seed(seed)
 37	np.random.seed(seed)
 38
 39
 40class FilterInfoMismatchError(ValueError):
 41	"""raised when the filter info in a dataset config does not match the filter info in the dataset"""
 42
 43	pass
 44
 45
 46def _load_applied_filters(
 47	filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]],
 48) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]:
 49	try:
 50		return [
 51			dict(
 52				name=filter_info["name"],
 53				args=tuple(
 54					filter_info["args"],
 55				),  # muutils/zanj save tuples as lists, and this causes problems
 56				kwargs=dict(filter_info["kwargs"]),  # type: ignore[arg-type]
 57			)
 58			for filter_info in filters
 59		]
 60	except Exception as e:
 61		err_msg: str = f"failed to load applied filters:\n{filters}"
 62		raise ValueError(err_msg) from e
 63
 64
 65@serializable_dataclass(kw_only=True)
 66class GPTDatasetConfig(SerializableDataclass):
 67	"""base GPTDatasetConfig class"""
 68
 69	name: str
 70
 71	# TODO: get rid of all these things as part of migration to tokenizer-free dataset config
 72	# --------------------------------------------------
 73	seq_len_min: int = serializable_field(default=1)
 74	seq_len_max: int = serializable_field(default=512)
 75	# --------------------------------------------------
 76
 77	seed: int | None = serializable_field(default=GLOBAL_SEED)
 78	applied_filters: list[
 79		dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
 80	] = serializable_field(
 81		default_factory=list,
 82		deserialize_fn=_load_applied_filters,
 83		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
 84	)
 85
 86	def __post_init__(self) -> None:
 87		"post init, where we set a random seed if none is set"
 88		assert self.seq_len_min <= self.seq_len_max
 89		# if seed set to None, then generate a new random seed
 90		if self.seed is None:
 91			self.seed = np.random.randint(2**31)
 92
 93		# TODO: something here is broken
 94		if self.seed != GLOBAL_SEED:
 95			warnings.warn(
 96				f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
 97			)
 98
 99		set_reproducibility(self.seed)
100
101	def summary(self) -> dict:
102		"""return a summary of the config"""
103		# do we run this to make sure it doesn't error?
104		self_ser: dict = self.serialize()
105		assert self_ser
106		return dict(
107			name=self.name,
108			seq_len_min=self.seq_len_min,
109			seq_len_max=self.seq_len_max,
110			seed=self.seed,
111			applied_filters=self.applied_filters,
112		)
113
114	@property
115	def _dataset_class(self) -> type:
116		raise NotImplementedError("this should be implemented by subclasses!")
117
118	def to_fname(self) -> str:
119		"""convert config to a filename"""
120		self_json_str: str = json.dumps(self.serialize())
121		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
122		warnings.warn(
123			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
124		)
125		return sanitize_fname(
126			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
127			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
128		)
129
130
131def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
132	err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
133	raise NotImplementedError(
134		err_msg,
135	)
136
137
138# abstract function, hence we don't care that `self` is unused
139def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
140	err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
141	raise NotImplementedError(
142		err_msg,
143	)
144
145
146GPTDatasetConfig.load = _dataset_config_load  # type: ignore[method-assign]
147GPTDatasetConfig.serialize = _dataset_config_serialize  # type: ignore[method-assign,assignment]
148T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig)
149
150
151class GPTDataset(typing.Generic[T_DatasetConfig]):
152	"""wrapper for torch dataset with some extra functionality
153
154	(meaning the functionality should be inherited in downstream classes)
155
156	> [!NOTE]
157	> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
158
159	# Requires:
160	the following methods should be implemented in subclasses:
161	- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
162		initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
163	- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
164		generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
165	- `serialize(self) -> JSONitem`
166		serialize the dataset to a ZANJ-serializable object, including:
167		- config
168		- data in formats specified by `self.save_formats`
169	- `load(cls, data: JSONitem) -> GPTDataset`
170		load the dataset from a ZANJ-serializable object
171	- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
172		given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
173	- `__len__(self) -> int`
174		return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
175	- `__getitem__(self, i: int) -> list[str]`
176		return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
177	-  `update_self_config(self) -> None`
178		update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
179	-  decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters
180
181	# Parameters:
182	- `cfg : GPTDatasetConfig`
183		config for the dataset, used to generate the dataset
184	- `do_generate : bool`
185		whether to generate the dataset if it isn't found
186		(defaults to `True`)
187	- `load_local : bool`
188		whether to try finding the dataset locally
189		(defaults to `True`)
190	- `save_local : bool`
191		whether to save the dataset locally if it is generated or downloaded
192		(defaults to `True`)
193	- `do_download : bool`
194		whether to try downloading the dataset
195		(defaults to `True`)
196	- `local_base_path : Path`
197		where to save the dataset
198		(defaults to `Path("data/maze_dataset")`)
199
200	# Returns:
201	- `GPTDataset`
202		the dataset, as you wanted it
203
204	# Implements:
205	- `save(self, file_path: str) -> None`
206		save the dataset to a file, using ZANJ
207	- `read(cls, file_path: str) -> GPTDataset`
208		read the dataset from a file, using ZANJ
210	- `filter_by(self)`
211		returns a namespace class
212	-  `_filter_namespace(self) -> Class`
213		returns a namespace class for filtering the dataset, checking that method
214	- `_apply_filters_from_config(self) -> None`
215		apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
216
217	"""
218
219	_FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore
220
221	cfg: "T_DatasetConfig"
222
223	@classmethod
224	def from_config(  # noqa: C901, PLR0912
225		cls,
226		cfg: "T_DatasetConfig",
227		do_generate: bool = True,
228		load_local: bool = True,
229		save_local: bool = True,
230		zanj: ZANJ | None = None,
231		do_download: bool = True,
232		local_base_path: Path = Path("data/maze_dataset"),
233		except_on_config_mismatch: bool = True,
234		allow_generation_metadata_filter_mismatch: bool = True,
235		verbose: bool = False,
236		**kwargs,
237	) -> "Self":
238		"""create the dataset described by the config, loading, downloading, or generating it as needed
239
240		priority of loading:
241		1. load from local
242		2. download
243		3. generate
244
245		"""
246		print_log: Callable = print if verbose else lambda *_a, **_kw: None
247
248		local_base_path = Path(local_base_path)
249		fname: Path = Path(f"{cfg.to_fname()}.zanj")
250		output: Self | None = None
251		did_load_local: bool = False
252		if zanj is None:
253			zanj = ZANJ()
254
255		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
256
257		if not (load_local or do_download or do_generate):
258			raise ValueError(
259				"no way to load dataset! you said not to load local, not to download, and not to generate",
260			)
261
262		dataset_path: Path = local_base_path / fname
263
264		# try loading
265		if load_local:  # noqa: SIM102
266			if dataset_path.exists():
267				print_log(f"loading dataset from {dataset_path.as_posix()}")
268				try:
269					output = cls.read(dataset_path, zanj=zanj)
270					did_load_local = True
271					print_log("load successful!")
272				except Exception as e:  # noqa: BLE001
273					print_log(f"failed to load dataset: {e}")
274
275		if do_download and output is None:
276			print_log("seeing if we can download the dataset...")
277			try:
278				output = cls.download(cfg, **kwargs)
279				print_log("download successful!")
280			except NotImplementedError:
281				print_log("no download found, or download failed")
282
283		if do_generate and output is None:
284			print_log("generating dataset...")
285			output = cls.generate(cfg, verbose=verbose, **kwargs)
286			# only if we generated it, apply filters
287			output = output._apply_filters_from_config()
288
289		# check and save
290		if output is None:
291			raise ValueError("failed to load dataset!")
292
293		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
294		if cfg_diff:
295			if except_on_config_mismatch:
296				if allow_generation_metadata_filter_mismatch and (
297					cfg_diff
298					== {
299						"applied_filters": {
300							"self": [],
301							"other": [
302								{
303									"name": "collect_generation_meta",
304									"args": (),
305									"kwargs": {},
306								},
307							],
308						},
309					}
310				):
311					pass
312				else:
313					err_msg: str = f"config mismatch: {cfg_diff = }"
314					raise ValueError(err_msg)
315			else:
316				warnings.warn(f"config mismatch: {cfg_diff = }")
317
318		if save_local and not did_load_local:
319			print_log(f"saving dataset to {dataset_path}")
320			output.save(dataset_path, zanj=zanj)
321
322		print_log(
323			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
324		)
325		return output
326
327	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
328		"save dataset to a file with zanj"
329		if zanj is None:
330			zanj = ZANJ()
331		zanj.save(self.serialize(), file_path)
332
333	# serialization & loading
334	@classmethod
335	def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
336		"read dataset from a file with zanj"
337		if zanj is None:
338			zanj = ZANJ()
339		return zanj.read(file_path)
340
341	def serialize(self) -> JSONdict:
342		"(implement in subclass!) serialize to something we can save with zanj"
343		raise NotImplementedError
344
345	def data_hash(self) -> int:
346		"(implement in subclass!) return a hash of the data"
347		raise NotImplementedError
348
349	@classmethod
350	def load(cls, data: JSONdict) -> "Self":
351		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
352		raise NotImplementedError
353
354	# generating & downloading
355	@classmethod
356	def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
357		"(implement in subclass!) generate the dataset given the config"
358		raise NotImplementedError
359
360	@classmethod
361	def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
362		"(implement in subclass!) download the dataset given the config"
363		raise NotImplementedError
364
365	# filtering
366	def update_self_config(self) -> None:
367		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
368
369		for example, adjust number of mazes after filtering
370		"""
371		pass
372
373	def __len__(self) -> int:
374		"return the length of the dataset"
375		raise NotImplementedError("implement in subclass!")
376
377	class FilterBy:
378		"""thanks GPT-4"""
379
380		def __init__(self, dataset: "T_Dataset") -> None:
381			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
382			self.dataset: T_Dataset = dataset
383
384		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
385			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
386			filter_func: DatasetFilterFunc = getattr(
387				self.dataset._FILTER_NAMESPACE,
388				name,
389			)
390
391			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
392				return filter_func(self.dataset, *args, **kwargs)
393
394			return wrapped_filter_func
395
396	@property
397	def filter_by(self) -> "FilterBy":
398		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
399		return self.FilterBy(self)
400
401	def _apply_filters_from_config(self) -> "Self":
402		"""apply filters to the dataset, as specified in the config. used in `from_config()`"""
403		output: Self = self
404		# copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
405		applied_filters_old: list[
406			dict[typing.Literal["name", "args", "kwargs"], typing.Any]
407		] = self.cfg.applied_filters
408		output.cfg.applied_filters = list()
409		# apply the filters
410		for filter_info in applied_filters_old:
411			filter_name: str = filter_info["name"]
412			if filter_name not in output._FILTER_NAMESPACE.__dict__:
413				if filter_name.startswith("__custom__:"):
414					err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
415					raise ValueError(
416						err_msg,
417					)
418				err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
419				raise ValueError(
420					err_msg,
421				)
422			filter_args: list = filter_info.get("args", list())
423			filter_kwargs: dict = filter_info.get("kwargs", dict())
424			output = getattr(output.filter_by, filter_name)(
425				*filter_args,
426				**filter_kwargs,
427			)
428
429		# update the config, perform checks
430		# TODO: some funny business with manually specified filters here?
431		output.update_self_config()
432		_check_filter_equality(
433			filters_old=applied_filters_old,
434			filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
435		)
436		return output
437
438
439def _check_filter_equality(
440	filters_old: list[
441		dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
442	],
443	filters_new: list[
444		dict[typing.Literal["name", "args", "kwargs"], str | list | dict]
445	],
446) -> None:
447	try:
448		assert len(filters_old) == len(filters_new)
449
450		for filterinfo_old, filterinfo_new in zip(
451			filters_old,
452			filters_new,
453			strict=False,
454		):
455			# basic checks
456			assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict"
457			assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict"
458			assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), (
459				"missing keys in filterinfo_new"
460			)
461			assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), (
462				"missing keys in filterinfo_old"
463			)
464
465			# name
466			assert filterinfo_new["name"] == filterinfo_old["name"], (
467				"filter names don't match"
468			)
469
470			# args
471			assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), (
472				"filter args of different lengths"
473			)
474			for arg_new, arg_old in zip(
475				filterinfo_new["args"],
476				filterinfo_old["args"],
477				strict=False,
478			):
479				assert arg_new == arg_old, "filter args don't match"
480
481			# kwargs
482			assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), (
483				"filter kwargs of different lengths"
484			)
485			for key in filterinfo_old["kwargs"]:
486				assert key in filterinfo_new["kwargs"], (
487					f"filter kwargs don't match: missing key '{key}'"
488				)
489				assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], (  # type: ignore[index]
490					f"filter kwargs don't match: values for key '{key}' don't match"
491				)
492
493	except AssertionError as e:
494		err_msg: str = (
495			f"config mismatch in applied filters: {filters_new} != {filters_old}"
496		)
497		raise FilterInfoMismatchError(
498			err_msg,
499		) from e
500
501
502def register_filter_namespace_for_dataset(
503	dataset_cls: Type[GPTDataset],
504) -> Callable[[Type], Type]:
505	"""register the namespace class with the given dataset class"""
506
507	def decorator(filter_namespace_cls: Type) -> Type:
508		dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
509		filter_namespace_cls._BASE_DATASET = dataset_cls
510
511		return filter_namespace_cls
512
513	return decorator
514
515
516T_Dataset = TypeVar("T_Dataset", bound=GPTDataset)
517P_FilterKwargs = typing.ParamSpec("P_FilterKwargs")
518DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset]
519
520
521def register_dataset_filter(
522	method: DatasetFilterFunc,
523) -> DatasetFilterFunc:
524	"""register a dataset filter, copying the underlying dataset and updating the config
525
526	be sure to return a COPY, not the original?
527	# TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
528
529	method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
530	"""
531
532	@functools.wraps(method)
533	def wrapper(
534		# TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
535		dataset: T_Dataset,
536		*args: P_FilterKwargs.args,  # type: ignore[valid-type]
537		**kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
538	) -> T_Dataset:
539		new_dataset = method(dataset, *args, **kwargs)
540		# update the config
541		new_dataset.cfg.applied_filters.append(
542			dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
543		)
544		new_dataset.update_self_config()
545		return new_dataset
546
547	# TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
548	return wrapper  # type: ignore[return-value]

def set_reproducibility(seed: int) -> None:
35def set_reproducibility(seed: int) -> None:
36	"set reproducibility in stdlib random and numpy (but not torch)"
37	random.seed(seed)
38	np.random.seed(seed)

set reproducibility in stdlib random and numpy (but not torch)

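a minimal usage sketch (note that torch, if you use it, has to be seeded separately):

import random

import numpy as np

from maze_dataset.dataset.dataset import set_reproducibility

set_reproducibility(42)  # seeds stdlib random and numpy only
a, b = random.random(), np.random.rand()

set_reproducibility(42)  # re-seeding reproduces the same draws
assert a == random.random()
assert b == np.random.rand()
# torch is NOT seeded here; call torch.manual_seed(42) yourself if needed
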
class FilterInfoMismatchError(builtins.ValueError):
41class FilterInfoMismatchError(ValueError):
42	"""raised when the filter info in a dataset config does not match the filter info in the dataset"""
43
44	pass

raised when the filter info in a dataset config does not match the filter info in the dataset

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
@serializable_dataclass(kw_only=True)
class GPTDatasetConfig(muutils.json_serialize.serializable_dataclass.SerializableDataclass):
 66@serializable_dataclass(kw_only=True)
 67class GPTDatasetConfig(SerializableDataclass):
 68	"""base GPTDatasetConfig class"""
 69
 70	name: str
 71
 72	# TODO: get rid of all these things as part of migration to tokenizer-free dataset config
 73	# --------------------------------------------------
 74	seq_len_min: int = serializable_field(default=1)
 75	seq_len_max: int = serializable_field(default=512)
 76	# --------------------------------------------------
 77
 78	seed: int | None = serializable_field(default=GLOBAL_SEED)
 79	applied_filters: list[
 80		dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict]
 81	] = serializable_field(
 82		default_factory=list,
 83		deserialize_fn=_load_applied_filters,
 84		assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
 85	)
 86
 87	def __post_init__(self) -> None:
 88		"post init, where we set a random seed if none is set"
 89		assert self.seq_len_min <= self.seq_len_max
 90		# if seed set to None, then generate a new random seed
 91		if self.seed is None:
 92			self.seed = np.random.randint(2**31)
 93
 94		# TODO: something here is broken
 95		if self.seed != GLOBAL_SEED:
 96			warnings.warn(
 97				f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }",
 98			)
 99
100		set_reproducibility(self.seed)
101
102	def summary(self) -> dict:
103		"""return a summary of the config"""
104		# do we run this to make sure it doesn't error?
105		self_ser: dict = self.serialize()
106		assert self_ser
107		return dict(
108			name=self.name,
109			seq_len_min=self.seq_len_min,
110			seq_len_max=self.seq_len_max,
111			seed=self.seed,
112			applied_filters=self.applied_filters,
113		)
114
115	@property
116	def _dataset_class(self) -> type:
117		raise NotImplementedError("this should be implemented by subclasses!")
118
119	def to_fname(self) -> str:
120		"""convert config to a filename"""
121		self_json_str: str = json.dumps(self.serialize())
122		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
123		warnings.warn(
124			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
125		)
126		return sanitize_fname(
127			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
128			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
129		)

base GPTDatasetConfig class

GPTDatasetConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>)
name: str
seq_len_min: int = 1
seq_len_max: int = 512
seed: int | None = 42
applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]]
def summary(self) -> dict:
102	def summary(self) -> dict:
103		"""return a summary of the config"""
104		# do we run this to make sure it doesn't error?
105		self_ser: dict = self.serialize()
106		assert self_ser
107		return dict(
108			name=self.name,
109			seq_len_min=self.seq_len_min,
110			seq_len_max=self.seq_len_max,
111			seed=self.seed,
112			applied_filters=self.applied_filters,
113		)

return a summary of the config

def to_fname(self) -> str:
119	def to_fname(self) -> str:
120		"""convert config to a filename"""
121		self_json_str: str = json.dumps(self.serialize())
122		self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10)
123		warnings.warn(
124			f"using fallback to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!",
125		)
126		return sanitize_fname(
127			# TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized"  [arg-type]
128			f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}",  # type: ignore[arg-type]
129		)

convert config to a filename

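the fallback produces a name like f{name}-n{count}-h{hash} and emits a warning; subclasses are expected to override it. a hypothetical override might look like this (the field names are illustrative, not part of the library):

from muutils.json_serialize import serializable_dataclass, serializable_field
from muutils.misc import sanitize_fname

from maze_dataset.dataset.dataset import GPTDatasetConfig

@serializable_dataclass(kw_only=True)
class MyDatasetConfig(GPTDatasetConfig):  # hypothetical subclass
	n_items: int = serializable_field(default=100)

	def to_fname(self) -> str:
		# stable, human-readable filename; avoids the fallback warning
		return sanitize_fname(f"{self.name}-n{self.n_items}")

cfg = MyDatasetConfig(name="demo", n_items=500)
print(cfg.to_fname())  # "demo-n500", assuming sanitize_fname leaves it unchanged
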
def serialize( self, *args, **kwargs) -> Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]:
140def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem:  # noqa: ANN001, ARG001
141	err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}"
142	raise NotImplementedError(
143		err_msg,
144	)

placeholder `serialize` assigned to `GPTDatasetConfig`; it always raises `NotImplementedError` and must be overridden by subclasses.

def load(*args, **kwargs) -> GPTDatasetConfig:
132def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig":
133	err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}"
134	raise NotImplementedError(
135		err_msg,
136	)

placeholder `load` assigned to `GPTDatasetConfig`; it always raises `NotImplementedError` and must be overridden by subclasses.

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class GPTDataset(typing.Generic[~T_DatasetConfig]):
152class GPTDataset(typing.Generic[T_DatasetConfig]):
153	"""wrapper for torch dataset with some extra functionality
154
155	(meaning the functionality should be inherited in downstream classes)
156
157	> [!NOTE]
158	> `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config
159
160	# Requires:
161	the following methods should be implemented in subclasses:
162	- `__init__(self, cfg: GPTDatasetConfig, **kwargs)`
163		initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
164	- `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
165		generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
166	- `serialize(self) -> JSONitem`
167		serialize the dataset to a ZANJ-serializable object, including:
168		- config
169		- data in formats specified by `self.save_formats`
170	- `load(cls, data: JSONitem) -> GPTDataset`
171		load the dataset from a ZANJ-serializable object
172	- `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset`
173		given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
174	- `__len__(self) -> int`
175		return the length of the dataset, required to match interface of `torch.utils.data.Dataset`
176	- `__getitem__(self, i: int) -> list[str]`
177		return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset`
178	-  `update_self_config(self) -> None`
179		update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
180	-  decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters
181
182	# Parameters:
183	- `cfg : GPTDatasetConfig`
184		config for the dataset, used to generate the dataset
185	- `do_generate : bool`
186		whether to generate the dataset if it isn't found
187		(defaults to `True`)
188	- `load_local : bool`
189		whether to try finding the dataset locally
190		(defaults to `True`)
191	- `save_local : bool`
192		whether to save the dataset locally if it is generated or downloaded
193		(defaults to `True`)
194	- `do_download : bool`
195		whether to try downloading the dataset
196		(defaults to `True`)
197	- `local_base_path : Path`
198		where to save the dataset
199		(defaults to `Path("data/maze_dataset")`)
200
201	# Returns:
202	- `GPTDataset`
203		the dataset, as you wanted it
204
205	# Implements:
206	- `save(self, file_path: str) -> None`
207		save the dataset to a file, using ZANJ
208	- `read(cls, file_path: str) -> GPTDataset`
209		read the dataset from a file, using ZANJ
211	- `filter_by(self)`
212		returns a namespace class
213	-  `_filter_namespace(self) -> Class`
214		returns a namespace class for filtering the dataset, checking that method
215	- `_apply_filters_from_config(self) -> None`
216		apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating
217
218	"""
219
220	_FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`"  # type: ignore
221
222	cfg: "T_DatasetConfig"
223
224	@classmethod
225	def from_config(  # noqa: C901, PLR0912
226		cls,
227		cfg: "T_DatasetConfig",
228		do_generate: bool = True,
229		load_local: bool = True,
230		save_local: bool = True,
231		zanj: ZANJ | None = None,
232		do_download: bool = True,
233		local_base_path: Path = Path("data/maze_dataset"),
234		except_on_config_mismatch: bool = True,
235		allow_generation_metadata_filter_mismatch: bool = True,
236		verbose: bool = False,
237		**kwargs,
238	) -> "Self":
239		"""create the dataset described by the config, loading, downloading, or generating it as needed
240
241		priority of loading:
242		1. load from local
243		2. download
244		3. generate
245
246		"""
247		print_log: Callable = print if verbose else lambda *_a, **_kw: None
248
249		local_base_path = Path(local_base_path)
250		fname: Path = Path(f"{cfg.to_fname()}.zanj")
251		output: Self | None = None
252		did_load_local: bool = False
253		if zanj is None:
254			zanj = ZANJ()
255
256		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
257
258		if not (load_local or do_download or do_generate):
259			raise ValueError(
260				"no way to load dataset! you said not to load local, not to download, and not to generate",
261			)
262
263		dataset_path: Path = local_base_path / fname
264
265		# try loading
266		if load_local:  # noqa: SIM102
267			if dataset_path.exists():
268				print_log(f"loading dataset from {dataset_path.as_posix()}")
269				try:
270					output = cls.read(dataset_path, zanj=zanj)
271					did_load_local = True
272					print_log("load successful!")
273				except Exception as e:  # noqa: BLE001
274					print_log(f"failed to load dataset: {e}")
275
276		if do_download and output is None:
277			print_log("seeing if we can download the dataset...")
278			try:
279				output = cls.download(cfg, **kwargs)
280				print_log("download successful!")
281			except NotImplementedError:
282				print_log("no download found, or download failed")
283
284		if do_generate and output is None:
285			print_log("generating dataset...")
286			output = cls.generate(cfg, verbose=verbose, **kwargs)
287			# only if we generated it, apply filters
288			output = output._apply_filters_from_config()
289
290		# check and save
291		if output is None:
292			raise ValueError("failed to load dataset!")
293
294		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
295		if cfg_diff:
296			if except_on_config_mismatch:
297				if allow_generation_metadata_filter_mismatch and (
298					cfg_diff
299					== {
300						"applied_filters": {
301							"self": [],
302							"other": [
303								{
304									"name": "collect_generation_meta",
305									"args": (),
306									"kwargs": {},
307								},
308							],
309						},
310					}
311				):
312					pass
313				else:
314					err_msg: str = f"config mismatch: {cfg_diff = }"
315					raise ValueError(err_msg)
316			else:
317				warnings.warn(f"config mismatch: {cfg_diff = }")
318
319		if save_local and not did_load_local:
320			print_log(f"saving dataset to {dataset_path}")
321			output.save(dataset_path, zanj=zanj)
322
323		print_log(
324			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
325		)
326		return output
327
328	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
329		"save dataset to a file with zanj"
330		if zanj is None:
331			zanj = ZANJ()
332		zanj.save(self.serialize(), file_path)
333
334	# serialization & loading
335	@classmethod
336	def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
337		"read dataset from a file with zanj"
338		if zanj is None:
339			zanj = ZANJ()
340		return zanj.read(file_path)
341
342	def serialize(self) -> JSONdict:
343		"(implement in subclass!) serialize to something we can save with zanj"
344		raise NotImplementedError
345
346	def data_hash(self) -> int:
347		"(implement in subclass!) return a hash of the data"
348		raise NotImplementedError
349
350	@classmethod
351	def load(cls, data: JSONdict) -> "Self":
352		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
353		raise NotImplementedError
354
355	# generating & downloading
356	@classmethod
357	def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
358		"(implement in subclass!) generate the dataset given the config"
359		raise NotImplementedError
360
361	@classmethod
362	def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
363		"(implement in subclass!) download the dataset given the config"
364		raise NotImplementedError
365
366	# filtering
367	def update_self_config(self) -> None:
368		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
369
370		for example, adjust number of mazes after filtering
371		"""
372		pass
373
374	def __len__(self) -> int:
375		"return the length of the dataset"
376		raise NotImplementedError("implement in subclass!")
377
378	class FilterBy:
379		"""thanks GPT-4"""
380
381		def __init__(self, dataset: "T_Dataset") -> None:
382			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
383			self.dataset: T_Dataset = dataset
384
385		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
386			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
387			filter_func: DatasetFilterFunc = getattr(
388				self.dataset._FILTER_NAMESPACE,
389				name,
390			)
391
392			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
393				return filter_func(self.dataset, *args, **kwargs)
394
395			return wrapped_filter_func
396
397	@property
398	def filter_by(self) -> "FilterBy":
399		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
400		return self.FilterBy(self)
401
402	def _apply_filters_from_config(self) -> "Self":
403		"""apply filters to the dataset, as specified in the config. used in `from_config()`"""
404		output: Self = self
405		# copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters
406		applied_filters_old: list[
407			dict[typing.Literal["name", "args", "kwargs"], typing.Any]
408		] = self.cfg.applied_filters
409		output.cfg.applied_filters = list()
410		# apply the filters
411		for filter_info in applied_filters_old:
412			filter_name: str = filter_info["name"]
413			if filter_name not in output._FILTER_NAMESPACE.__dict__:
414				if filter_name.startswith("__custom__:"):
415					err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!"
416					raise ValueError(
417						err_msg,
418					)
419				err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'"
420				raise ValueError(
421					err_msg,
422				)
423			filter_args: list = filter_info.get("args", list())
424			filter_kwargs: dict = filter_info.get("kwargs", dict())
425			output = getattr(output.filter_by, filter_name)(
426				*filter_args,
427				**filter_kwargs,
428			)
429
430		# update the config, perform checks
431		# TODO: some funny business with manually specified filters here?
432		output.update_self_config()
433		_check_filter_equality(
434			filters_old=applied_filters_old,
435			filters_new=output.cfg.applied_filters,  # type: ignore[arg-type]
436		)
437		return output

wrapper for torch dataset with some extra functionality

(meaning the functionality should be inherited in downstream classes)

Note

GPTDatasetConfig should implement a to_fname method that returns a unique filename for the config

Requires:

the following methods should be implemented in subclasses:

  • __init__(self, cfg: GPTDatasetConfig, **kwargs) initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably)
  • generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset generate the dataset from a given config. kwargs are passed through from from_config, and should only contain things that don't belong in the config (e.g. how many threads to use for generation)
  • serialize(self) -> JSONitem serialize the dataset to a ZANJ-serializable object, including:
    • config
    • data in formats specified by self.save_formats
  • load(cls, data: JSONitem) -> GPTDataset load the dataset from a ZANJ-serializable object
  • download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset given a config, try to download a dataset from some source. kwargs are passed through from from_config, and should only contain things that don't belong in the config (e.g. some kind of auth token or source url)
  • __len__(self) -> int return the length of the dataset, required to match interface of torch.utils.data.Dataset
  • __getitem__(self, i: int) -> list[str] return the ith item in the dataset, required to match interface of torch.utils.data.Dataset
  • update_self_config(self) -> None update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation
  • decorating the appropriate filter namespace with register_filter_namespace_for_dataset(your_dataset_class) if you want to use filters

Parameters:

  • cfg : GPTDatasetConfig config for the dataset, used to generate the dataset
  • do_generate : bool whether to generate the dataset if it isn't found (defaults to True)
  • load_local : bool whether to try finding the dataset locally (defaults to True)
  • save_local : bool whether to save the dataset locally if it is generated or downloaded (defaults to True)
  • do_download : bool whether to try downloading the dataset (defaults to True)
  • local_base_path : Path where to save the dataset (defaults to Path("data/maze_dataset"))

Returns:

  • GPTDataset the dataset, as you wanted it

Implements:

  • save(self, file_path: str) -> None save the dataset to a file, using ZANJ
  • read(cls, file_path: str) -> GPTDataset read the dataset from a file, using ZANJ
  • filter_by(self) returns a namespace class
  • _filter_namespace(self) -> Class returns a namespace class for filtering the dataset, checking that method
  • _apply_filters_from_config(self) -> None apply filters to the dataset, as specified in the config. used in from_config() but only when generating
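
a minimal, hypothetical subclass pair satisfying the interface above (the config methods and the token-list data layout are illustrative assumptions, not library API):

from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig
from muutils.json_serialize.util import JSONdict

class MyConfig(GPTDatasetConfig):  # hypothetical
	# a real config would usually be a @serializable_dataclass with its own fields
	def serialize(self) -> dict:
		return {"name": self.name}

	@classmethod
	def load(cls, data: dict) -> "MyConfig":
		return cls(name=data["name"])

	def to_fname(self) -> str:
		return f"myconfig-{self.name}"

class MyDataset(GPTDataset[MyConfig]):  # hypothetical
	def __init__(self, cfg: MyConfig, sequences: list[list[str]]) -> None:
		self.cfg = cfg
		self.sequences = sequences

	@classmethod
	def generate(cls, cfg: MyConfig, **kwargs) -> "MyDataset":
		# produce whatever data the config describes (stubbed here)
		return cls(cfg, sequences=[["<start>", "<end>"]])

	def serialize(self) -> JSONdict:
		return {"cfg": self.cfg.serialize(), "sequences": self.sequences}

	@classmethod
	def load(cls, data: JSONdict) -> "MyDataset":
		return cls(MyConfig.load(data["cfg"]), data["sequences"])

	def __len__(self) -> int:
		return len(self.sequences)

	def __getitem__(self, i: int) -> list[str]:
		return self.sequences[i]

	def update_self_config(self) -> None:
		# keep cfg in sync with the data, e.g. after filtering
		pass
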
cfg: ~T_DatasetConfig
@classmethod
def from_config( cls, cfg: ~T_DatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> Self:
224	@classmethod
225	def from_config(  # noqa: C901, PLR0912
226		cls,
227		cfg: "T_DatasetConfig",
228		do_generate: bool = True,
229		load_local: bool = True,
230		save_local: bool = True,
231		zanj: ZANJ | None = None,
232		do_download: bool = True,
233		local_base_path: Path = Path("data/maze_dataset"),
234		except_on_config_mismatch: bool = True,
235		allow_generation_metadata_filter_mismatch: bool = True,
236		verbose: bool = False,
237		**kwargs,
238	) -> "Self":
239		"""create the dataset described by the config, loading, downloading, or generating it as needed
240
241		priority of loading:
242		1. load from local
243		2. download
244		3. generate
245
246		"""
247		print_log: Callable = print if verbose else lambda *_a, **_kw: None
248
249		local_base_path = Path(local_base_path)
250		fname: Path = Path(f"{cfg.to_fname()}.zanj")
251		output: Self | None = None
252		did_load_local: bool = False
253		if zanj is None:
254			zanj = ZANJ()
255
256		print_log(f"trying to get the dataset '{cfg.to_fname()}'")
257
258		if not (load_local or do_download or do_generate):
259			raise ValueError(
260				"no way to load dataset! you said not to load local, not to download, and not to generate",
261			)
262
263		dataset_path: Path = local_base_path / fname
264
265		# try loading
266		if load_local:  # noqa: SIM102
267			if dataset_path.exists():
268				print_log(f"loading dataset from {dataset_path.as_posix()}")
269				try:
270					output = cls.read(dataset_path, zanj=zanj)
271					did_load_local = True
272					print_log("load successful!")
273				except Exception as e:  # noqa: BLE001
274					print_log(f"failed to load dataset: {e}")
275
276		if do_download and output is None:
277			print_log("seeing if we can download the dataset...")
278			try:
279				output = cls.download(cfg, **kwargs)
280				print_log("download successful!")
281			except NotImplementedError:
282				print_log("no download found, or download failed")
283
284		if do_generate and output is None:
285			print_log("generating dataset...")
286			output = cls.generate(cfg, verbose=verbose, **kwargs)
287			# only if we generated it, apply filters
288			output = output._apply_filters_from_config()
289
290		# check and save
291		if output is None:
292			raise ValueError("failed to load dataset!")
293
294		cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True)
295		if cfg_diff:
296			if except_on_config_mismatch:
297				if allow_generation_metadata_filter_mismatch and (
298					cfg_diff
299					== {
300						"applied_filters": {
301							"self": [],
302							"other": [
303								{
304									"name": "collect_generation_meta",
305									"args": (),
306									"kwargs": {},
307								},
308							],
309						},
310					}
311				):
312					pass
313				else:
314					err_msg: str = f"config mismatch: {cfg_diff = }"
315					raise ValueError(err_msg)
316			else:
317				warnings.warn(f"config mismatch: {cfg_diff = }")
318
319		if save_local and not did_load_local:
320			print_log(f"saving dataset to {dataset_path}")
321			output.save(dataset_path, zanj=zanj)
322
323		print_log(
324			f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }",
325		)
326		return output

create the dataset described by the config, loading, downloading, or generating it as needed

priority of loading:

  1. load from local
  2. download
  3. generate
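
continuing the hypothetical MyDataset sketch from above, the flags let you pin down exactly which of those three steps are allowed:

from pathlib import Path

# strictly load from disk; never download and never regenerate
ds = MyDataset.from_config(
	MyConfig(name="demo"),
	load_local=True,
	do_download=False,
	do_generate=False,
	local_base_path=Path("data/my_datasets"),
	verbose=True,
)
# if data/my_datasets/myconfig-demo.zanj does not exist, this raises ValueError
# instead of silently regenerating the dataset
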
def save( self, file_path: pathlib.Path | str, zanj: zanj.zanj.ZANJ | None = None) -> None:
328	def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None:
329		"save dataset to a file with zanj"
330		if zanj is None:
331			zanj = ZANJ()
332		zanj.save(self.serialize(), file_path)

save dataset to a file with zanj

@classmethod
def read( cls, file_path: str | pathlib.Path, zanj: zanj.zanj.ZANJ | None = None) -> Self:
335	@classmethod
336	def read(cls, file_path: str | Path, zanj: ZANJ | None = None) -> "Self":
337		"read dataset from a file with zanj"
338		if zanj is None:
339			zanj = ZANJ()
340		return zanj.read(file_path)

read dataset from a file with zanj

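a round-trip sketch, assuming `dataset` is a MazeDataset obtained earlier via from_config (the path is illustrative):

from maze_dataset import MazeDataset

dataset.save("data/maze_dataset/demo.zanj")
same_dataset = MazeDataset.read("data/maze_dataset/demo.zanj")
assert len(same_dataset) == len(dataset)
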
def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
342	def serialize(self) -> JSONdict:
343		"(implement in subclass!) serialize to something we can save with zanj"
344		raise NotImplementedError

(implement in subclass!) serialize to something we can save with zanj

def data_hash(self) -> int:
346	def data_hash(self) -> int:
347		"(implement in subclass!) return a hash of the data"
348		raise NotImplementedError

(implement in subclass!) return a hash of the data

@classmethod
def load( cls, data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> Self:
350	@classmethod
351	def load(cls, data: JSONdict) -> "Self":
352		"(implement in subclass!) load a dataset from what we made with `.serialize()`"
353		raise NotImplementedError

(implement in subclass!) load a dataset from what we made with .serialize()

@classmethod
def generate(cls, cfg: ~T_DatasetConfig, **kwargs) -> Self:
356	@classmethod
357	def generate(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
358		"(implement in subclass!) generate the dataset given the config"
359		raise NotImplementedError

(implement in subclass!) generate the dataset given the config

@classmethod
def download(cls, cfg: ~T_DatasetConfig, **kwargs) -> Self:
361	@classmethod
362	def download(cls, cfg: "T_DatasetConfig", **kwargs) -> "Self":
363		"(implement in subclass!) download the dataset given the config"
364		raise NotImplementedError

(implement in subclass!) download the dataset given the config

def update_self_config(self) -> None:
367	def update_self_config(self) -> None:
368		"""(implement in subclass!) update the config of the dataset to match the actual data, if needed
369
370		for example, adjust number of mazes after filtering
371		"""
372		pass

(implement in subclass!) update the config of the dataset to match the actual data, if needed

for example, adjust number of mazes after filtering

filter_by: GPTDataset.FilterBy
397	@property
398	def filter_by(self) -> "FilterBy":
399		"can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset"
400		return self.FilterBy(self)

can call my_dataset.filter_by.some_registered_filter() to filter the dataset

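for example, with the filters registered for MazeDataset (path_length is assumed to be one of them), a call looks like this and gets recorded in the config:

filtered = dataset.filter_by.path_length(min_length=3)
print(filtered.cfg.applied_filters)
# e.g. [{'name': 'path_length', 'args': (), 'kwargs': {'min_length': 3}}]
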
class GPTDataset.FilterBy:
378	class FilterBy:
379		"""thanks GPT-4"""
380
381		def __init__(self, dataset: "T_Dataset") -> None:
382			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
383			self.dataset: T_Dataset = dataset
384
385		def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]:
386			"override getattr so we can call `my_dataset.filter_by.some_registered_filter()`"
387			filter_func: DatasetFilterFunc = getattr(
388				self.dataset._FILTER_NAMESPACE,
389				name,
390			)
391
392			def wrapped_filter_func(*args, **kwargs):  # noqa: ANN202
393				return filter_func(self.dataset, *args, **kwargs)
394
395			return wrapped_filter_func

thanks GPT-4

GPTDataset.FilterBy(dataset: ~T_Dataset)
381		def __init__(self, dataset: "T_Dataset") -> None:
382			"mock class so we can call `my_dataset.filter_by.some_registered_filter()`"
383			self.dataset: T_Dataset = dataset

mock class so we can call my_dataset.filter_by.some_registered_filter()

dataset: ~T_Dataset
def register_filter_namespace_for_dataset( dataset_cls: Type[GPTDataset]) -> Callable[[Type], Type]:
503def register_filter_namespace_for_dataset(
504	dataset_cls: Type[GPTDataset],
505) -> Callable[[Type], Type]:
506	"""register the namespace class with the given dataset class"""
507
508	def decorator(filter_namespace_cls: Type) -> Type:
509		dataset_cls._FILTER_NAMESPACE = filter_namespace_cls
510		filter_namespace_cls._BASE_DATASET = dataset_cls
511
512		return filter_namespace_cls
513
514	return decorator

register the namespace class with the given dataset class

P_FilterKwargs = ~P_FilterKwargs
DatasetFilterFunc = typing.Callable[typing.Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]
def register_dataset_filter( method: Callable[Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]) -> Callable[Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]:
522def register_dataset_filter(
523	method: DatasetFilterFunc,
524) -> DatasetFilterFunc:
525	"""register a dataset filter, copying the underlying dataset and updating the config
526
527	be sure to return a COPY, not the original?
528	# TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?
529
530	method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
531	"""
532
533	@functools.wraps(method)
534	def wrapper(
535		# TYPING: error: ParamSpec "P_FilterKwargs" is unbound  [valid-type]
536		dataset: T_Dataset,
537		*args: P_FilterKwargs.args,  # type: ignore[valid-type]
538		**kwargs: P_FilterKwargs.kwargs,  # type: ignore[valid-type]
539	) -> T_Dataset:
540		new_dataset = method(dataset, *args, **kwargs)
541		# update the config
542		new_dataset.cfg.applied_filters.append(
543			dict(name=method.__name__, args=args, kwargs=kwargs),  # type: ignore[attr-defined]
544		)
545		new_dataset.update_self_config()
546		return new_dataset
547
548	# TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]")  [return-value]
549	return wrapper  # type: ignore[return-value]

register a dataset filter, copying the underlying dataset and updating the config

be sure to return a COPY, not the original?

TODO: what the heck do we mean by the above? why the question mark? it should be a copy right?

method should be a staticmethod of a namespace class registered with register_filter_namespace_for_dataset
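
putting the two registration helpers together for the hypothetical MyDataset from earlier (the filter itself and the decorator ordering are illustrative; the filters shipped with maze_dataset may be organized differently):

import copy

from maze_dataset.dataset.dataset import (
	register_dataset_filter,
	register_filter_namespace_for_dataset,
)

@register_filter_namespace_for_dataset(MyDataset)
class MyDatasetFilters:
	@staticmethod
	@register_dataset_filter
	def max_length(dataset: "MyDataset", n: int) -> "MyDataset":
		# return a copy rather than mutating the original dataset
		new_dataset = copy.deepcopy(dataset)
		new_dataset.sequences = [s for s in new_dataset.sequences if len(s) <= n]
		return new_dataset

my_dataset = MyDataset.generate(MyConfig(name="demo"))
short_only = my_dataset.filter_by.max_length(n=8)
print(short_only.cfg.applied_filters)
# [{'name': 'max_length', 'args': (), 'kwargs': {'n': 8}}]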