docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.collected_dataset

collecting different maze datasets into a single dataset, for greater variety in a training or validation set

Caution

MazeDatasetCollection is not thoroughly tested and is not guaranteed to work.
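
The typical workflow is to define several `MazeDatasetConfig`s, wrap them in a `MazeDatasetCollectionConfig`, and call `MazeDatasetCollection.generate`. A minimal sketch, subject to the caution above; the names, sizes, and generator choice are illustrative, not prescribed:

from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.generation import LatticeMazeGenerators

# two sub-dataset configs with different grid sizes (illustrative values)
configs = [
    MazeDatasetConfig(
        name="small", grid_n=5, n_mazes=100,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
    MazeDatasetConfig(
        name="large", grid_n=9, n_mazes=50,
        maze_ctor=LatticeMazeGenerators.gen_dfs,
    ),
]

collection_cfg = MazeDatasetCollectionConfig(
    name="mixed",
    maze_dataset_configs=configs,
)

collection = MazeDatasetCollection.generate(collection_cfg)
print(len(collection))  # 150: mazes are pooled across both sub-datasets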


  1"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set
  2
  3> [!CAUTION]
  4> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work.
  5
  6"""
  7
  8import itertools
  9import json
 10import typing
 11from functools import cached_property
 12
 13import numpy as np
 14from jaxtyping import Int
 15from muutils.json_serialize import (
 16	json_serialize,
 17	serializable_dataclass,
 18	serializable_field,
 19)
 20from muutils.json_serialize.util import _FORMAT_KEY, JSONdict
 21from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash
 22from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler
 23
 24from maze_dataset.constants import Coord, CoordTup
 25from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig
 26from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig
 27from maze_dataset.maze import LatticeMaze
 28
 29
 30@serializable_dataclass(kw_only=True)
 31class MazeDatasetCollectionConfig(GPTDatasetConfig):
 32	"""maze dataset collection configuration, including tokenizers and shuffle"""
 33
 34	# Attributes without a default cannot follow attributes with one  [misc]
 35	maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
 36		serialization_fn=lambda configs: [config.serialize() for config in configs],
 37		loading_fn=lambda data: [
 38			MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
 39		],
 40	)
 41
 42	def summary(self) -> dict:
 43		"""return a summary of the config"""
 44		return dict(
 45			n_mazes=self.n_mazes,
 46			max_grid_n=self.max_grid_n,
 47			max_grid_shape=self.max_grid_shape,
 48			fname=self.to_fname(),
 49			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
 50		)
 51
 52	@property
 53	def n_mazes(self) -> int:
 54		"""return the total number of mazes in the collection across all datasets"""
 55		return sum(config.n_mazes for config in self.maze_dataset_configs)
 56
 57	@property
 58	def max_grid_n(self) -> int:
 59		"""return the maximum grid size of the mazes in the collection"""
 60		return max(config.grid_n for config in self.maze_dataset_configs)
 61
 62	@property
 63	def max_grid_shape(self) -> CoordTup:
 64		"""return the maximum grid shape of the mazes in the collection"""
 65		return (self.max_grid_n, self.max_grid_n)
 66
 67	@property
 68	def max_grid_shape_np(self) -> Coord:
 69		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
 70		return np.array(self.max_grid_shape, dtype=np.int32)
 71
 72	def stable_hash_cfg(self) -> int:
 73		"""return a stable hash of the config"""
 74		return stable_hash(json.dumps(self.serialize()))
 75
 76	def to_fname(self) -> str:
 77		"""convert config to a filename"""
 78		return sanitize_fname(
 79			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
 80		)
 81
 82
 83class MazeDatasetCollection(GPTDataset):
 84	"""a collection of maze datasets"""
 85
 86	def __init__(
 87		self,
 88		cfg: MazeDatasetCollectionConfig,
 89		maze_datasets: list[MazeDataset],
 90		generation_metadata_collected: dict | None = None,
 91	) -> None:
 92		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 93		super().__init__()
 94		self.cfg: MazeDatasetCollectionConfig = cfg
 95		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 96		for c, ds in zip(
 97			self.cfg.maze_dataset_configs,
 98			self.maze_datasets,
 99			strict=False,
100		):
101			assert c.name == ds.cfg.name
102			assert c == ds.cfg
103
104		self.generation_metadata_collected: dict | None = generation_metadata_collected
105
106	@property
107	def dataset_lengths(self) -> list[int]:
108		"""return the lengths of each dataset in the collection"""
109		return [len(dataset) for dataset in self.maze_datasets]
110
111	@property
112	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
113		"""return the cumulative lengths of each dataset in the collection"""
114		return np.array(list(itertools.accumulate(self.dataset_lengths)))
115
116	@cached_property
117	def mazes(self) -> list[LatticeMaze]:
118		"single list of all mazes in the collection"
119		return list(
120			itertools.chain.from_iterable(
121				dataset.mazes for dataset in self.maze_datasets
122			),
123		)
124
125	def __len__(self) -> int:
126		"""return the total number of mazes in the collection"""
127		return sum(len(dataset) for dataset in self.maze_datasets)
128
129	def __getitem__(self, index: int) -> LatticeMaze:
130		"get a maze by index"
131		# find which dataset the index belongs to:
132		# np.searchsorted returns the count of cumulative lengths strictly
133		# less than its target, so searching for index + 1 yields the number
134		# of datasets that end at or before index -- i.e. the dataset index
135		dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
136		index_adjusted: int = index
137		if dataset_idx > 0:
138			# for the first dataset (dataset_idx == 0) the global index is
139			# already the local index, and `dataset_idx - 1` would be -1
140			index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
141		return self.maze_datasets[dataset_idx][index_adjusted]
142
143	@classmethod
144	def generate(
145		cls,
146		cfg: MazeDatasetCollectionConfig,
147		**kwargs,
148	) -> "MazeDatasetCollection":
149		"""generate a dataset collection from a config"""
150		datasets = [
151			MazeDataset.generate(config, **kwargs)
152			for config in cfg.maze_dataset_configs
153		]
154		return cls(cfg, datasets)
155
156	@classmethod
157	def download(
158		cls,
159		cfg: MazeDatasetCollectionConfig,
160		**kwargs,
161	) -> "MazeDatasetCollection":
162		"(not implemented!) download a dataset collection from a config"
163		datasets = [
164			MazeDataset.download(config, **kwargs)
165			for config in cfg.maze_dataset_configs
166		]
167		return cls(cfg, datasets)
168
169	def serialize(self) -> JSONdict:
170		"""serialize the dataset collection"""
171		return {
172			_FORMAT_KEY: "MazeDatasetCollection",
173			"cfg": self.cfg.serialize(),
174			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
175			"generation_metadata_collected": json_serialize(
176				self.generation_metadata_collected,
177			),
178		}
179
180	@classmethod
181	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
182		"""load the dataset collection from the representation created by `serialize`"""
183		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
184		return cls(
185			**{
186				key: load_item_recursive(data[key], tuple())
187				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
188			},
189		)
190
191	# TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
192	def as_tokens(
193		self,
194		# TODO: MazeTokenizer
195		maze_tokenizer,  # noqa: ANN001
196		limit: int | None = None,
197		join_tokens_individual_maze: bool = False,
198	) -> list[list[str]] | list[str]:
199		"""return the dataset as tokens
200
201		if join_tokens_individual_maze is True, then the tokens of each maze are
202		joined with a space and the result is a list of strings;
203		otherwise the result is a list of token lists. e.g.:
204		>>> dataset.as_tokens(join_tokens_individual_maze=False)
205		[["a", "b", "c"], ["d", "e", "f"]]
206		>>> dataset.as_tokens(join_tokens_individual_maze=True)
207		["a b c", "d e f"]
208		"""
209		output: list[list[str]] = [
210			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
211		]
212		if join_tokens_individual_maze:
213			return [" ".join(tokens) for tokens in output]
214		else:
215			return output
216
217	def update_self_config(self) -> None:
218		"update the config to match the number of mazes, and update the underlying configs of each dataset"
219		# TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
220		self.cfg.__dict__["n_mazes"] = len(self)
221		for dataset in self.maze_datasets:
222			dataset.update_self_config()
223
224		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]
225
226
227MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection  # type: ignore[method-assign, assignment]
228
229register_loader_handler(
230	LoaderHandler(
231		check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
232			isinstance(json_item, typing.Mapping)
233			and _FORMAT_KEY in json_item
234			and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection")
235		),
236		load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item),  # type: ignore[misc] # noqa: ARG005
237		uid="MazeDatasetCollection",
238		source_pckg="maze_dataset.generation.maze_dataset_collection",
239		desc="MazeDatasetCollection",
240	),
241)

@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(maze_dataset.dataset.dataset.GPTDatasetConfig):
31@serializable_dataclass(kw_only=True)
32class MazeDatasetCollectionConfig(GPTDatasetConfig):
33	"""maze dataset collection configuration, including tokenizers and shuffle"""
34
35	# Attributes without a default cannot follow attributes with one  [misc]
36	maze_dataset_configs: list[MazeDatasetConfig] = serializable_field(  # type: ignore[misc]
37		serialization_fn=lambda configs: [config.serialize() for config in configs],
38		loading_fn=lambda data: [
39			MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"]
40		],
41	)
42
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)
52
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)
57
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)
62
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)
67
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)
72
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))
76
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

maze dataset collection configuration, including tokenizers and shuffle

MazeDatasetCollectionConfig( *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, maze_dataset_configs: list[maze_dataset.MazeDatasetConfig])
maze_dataset_configs: list[maze_dataset.MazeDatasetConfig]
def summary(self) -> dict:
43	def summary(self) -> dict:
44		"""return a summary of the config"""
45		return dict(
46			n_mazes=self.n_mazes,
47			max_grid_n=self.max_grid_n,
48			max_grid_shape=self.max_grid_shape,
49			fname=self.to_fname(),
50			cfg_summaries=[c.summary() for c in self.maze_dataset_configs],
51		)

return a summary of the config
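
For a collection like the sketch at the top of this page, the summary aggregates over the sub-configs; the keys follow the source above, while the values shown here are hypothetical:

collection_cfg.summary()
# {
#     "n_mazes": 150,          # sum of n_mazes over sub-configs
#     "max_grid_n": 9,         # largest grid_n among sub-configs
#     "max_grid_shape": (9, 9),
#     "fname": "collected-mixed-n150-h...",  # see to_fname() below
#     "cfg_summaries": [...],  # one summary() dict per sub-config
# }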

n_mazes: int
53	@property
54	def n_mazes(self) -> int:
 55		"""return the total number of mazes in the collection across all datasets"""
56		return sum(config.n_mazes for config in self.maze_dataset_configs)

return the total number of mazes in the collection across all datasets

max_grid_n: int
58	@property
59	def max_grid_n(self) -> int:
60		"""return the maximum grid size of the mazes in the collection"""
61		return max(config.grid_n for config in self.maze_dataset_configs)

return the maximum grid size of the mazes in the collection

max_grid_shape: tuple[int, int]
63	@property
64	def max_grid_shape(self) -> CoordTup:
65		"""return the maximum grid shape of the mazes in the collection"""
66		return (self.max_grid_n, self.max_grid_n)

return the maximum grid shape of the mazes in the collection

max_grid_shape_np: jaxtyping.Int8[ndarray, 'row_col=2']
68	@property
69	def max_grid_shape_np(self) -> Coord:
70		"""return the maximum grid shape of the mazes in the collection as a numpy array"""
71		return np.array(self.max_grid_shape, dtype=np.int32)

return the maximum grid shape of the mazes in the collection as a numpy array

def stable_hash_cfg(self) -> int:
73	def stable_hash_cfg(self) -> int:
74		"""return a stable hash of the config"""
75		return stable_hash(json.dumps(self.serialize()))

return a stable hash of the config

def to_fname(self) -> str:
77	def to_fname(self) -> str:
78		"""convert config to a filename"""
79		return sanitize_fname(
80			f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}",
81		)

convert config to a filename
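
The pattern is fixed by the f-string above: `collected-`, the collection name, an `n`-prefixed shortened maze count, and the config hash modulo 10**5. A hypothetical result for the config sketched at the top of this page:

collection_cfg.to_fname()
# e.g. "collected-mixed-n150-h48213" -- the hash suffix varies with the
# serialized config, and larger counts are shortened by
# shorten_numerical_to_str (e.g. 1500 -> "1.5K", assuming muutils' default format)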

def serialize(self) -> dict[str, typing.Any]:
714        def serialize(self) -> dict[str, Any]:
715            result: dict[str, Any] = {
716                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
717            }
718            # for each field in the class
719            for field in dataclasses.fields(self):  # type: ignore[arg-type]
720                # need it to be our special SerializableField
721                if not isinstance(field, SerializableField):
722                    raise NotSerializableFieldException(
723                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
724                        f"but a {type(field)} "
725                        "this state should be inaccessible, please report this bug!"
726                    )
727
728                # try to save it
729                if field.serialize:
730                    try:
731                        # get the val
732                        value = getattr(self, field.name)
733                        # if it is a serializable dataclass, serialize it
734                        if isinstance(value, SerializableDataclass):
735                            value = value.serialize()
736                        # if the value has a serialization function, use that
737                        if hasattr(value, "serialize") and callable(value.serialize):
738                            value = value.serialize()
739                        # if the field has a serialization function, use that
740                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
741                        elif field.serialization_fn:
742                            value = field.serialization_fn(value)
743
744                        # store the value in the result
745                        result[field.name] = value
746                    except Exception as e:
747                        raise FieldSerializationError(
748                            "\n".join(
749                                [
750                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
751                                    f"{field = }",
752                                    f"{value = }",
753                                    f"{self = }",
754                                ]
755                            )
756                        ) from e
757
758            # store each property if we can get it
759            for prop in self._properties_to_serialize:
760                if hasattr(cls, prop):
761                    value = getattr(self, prop)
762                    result[prop] = value
763                else:
764                    raise AttributeError(
765                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
766                        + f"but it is in {self._properties_to_serialize = }"
767                        + f"\n{self = }"
768                    )
769
770            return result

returns the class as a dict, implemented by using @serializable_dataclass decorator

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
777        @classmethod  # type: ignore[misc]
778        def load(cls, data: dict[str, Any] | T) -> Type[T]:
779            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
780            if isinstance(data, cls):
781                return data
782
783            assert isinstance(
784                data, typing.Mapping
785            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"
786
787            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)
788
789            # initialize dict for keeping what we will pass to the constructor
790            ctor_kwargs: dict[str, Any] = dict()
791
792            # iterate over the fields of the class
793            for field in dataclasses.fields(cls):
794                # check if the field is a SerializableField
795                assert isinstance(
796                    field, SerializableField
797                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"
798
799                # check if the field is in the data and if it should be initialized
800                if (field.name in data) and field.init:
801                    # get the value, we will be processing it
802                    value: Any = data[field.name]
803
804                    # get the type hint for the field
805                    field_type_hint: Any = cls_type_hints.get(field.name, None)
806
807                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
808                    if field.deserialize_fn:
809                        # if it has a deserialization function, use that
810                        value = field.deserialize_fn(value)
811                    elif field.loading_fn:
812                        # if it has a loading function, use that
813                        value = field.loading_fn(data)
814                    elif (
815                        field_type_hint is not None
816                        and hasattr(field_type_hint, "load")
817                        and callable(field_type_hint.load)
818                    ):
819                        # if no loading function but has a type hint with a load method, use that
820                        if isinstance(value, dict):
821                            value = field_type_hint.load(value)
822                        else:
823                            raise FieldLoadingError(
824                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
825                            )
826                    else:
827                        # assume no loading needs to happen, keep `value` as-is
828                        pass
829
830                    # store the value in the constructor kwargs
831                    ctor_kwargs[field.name] = value
832
833            # create a new instance of the class with the constructor kwargs
834            output: cls = cls(**ctor_kwargs)
835
836            # validate the types of the fields if needed
837            if on_typecheck_mismatch != ErrorMode.IGNORE:
838                fields_valid: dict[str, bool] = (
839                    SerializableDataclass__validate_fields_types__dict(
840                        output,
841                        on_typecheck_error=on_typecheck_error,
842                    )
843                )
844
845                # if there are any fields that are not valid, raise an error
846                if not all(fields_valid.values()):
847                    msg: str = (
848                        f"Type mismatch in fields of {cls.__name__}:\n"
849                        + "\n".join(
850                            [
851                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
852                                for k, v in fields_valid.items()
853                                if not v
854                            ]
855                        )
856                    )
857
858                    on_typecheck_mismatch.process(
859                        msg, except_cls=FieldTypeMismatchError
860                    )
861
862            # return the new instance
863            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator
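
Since `MazeDatasetCollectionConfig` is a `SerializableDataclass`, `serialize` and `load` round-trip through plain JSON-compatible dicts; a quick sketch:

data = collection_cfg.serialize()  # plain dict, including the format key
restored = MazeDatasetCollectionConfig.load(data)
assert restored == collection_cfg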

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
283def SerializableDataclass__validate_fields_types(
284    self: SerializableDataclass,
285    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
286) -> bool:
287    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
288    return all(
289        SerializableDataclass__validate_fields_types__dict(
290            self, on_typecheck_error=on_typecheck_error
291        ).values()
292    )

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

Inherited Members
maze_dataset.dataset.dataset.GPTDatasetConfig
name
seq_len_min
seq_len_max
seed
applied_filters
muutils.json_serialize.serializable_dataclass.SerializableDataclass
validate_field_type
diff
update_from_nested_dict
class MazeDatasetCollection(typing.Generic[~T_DatasetConfig]):
 84class MazeDatasetCollection(GPTDataset):
 85	"""a collection of maze datasets"""
 86
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected
106
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]
111
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))
116
117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)
125
126	def __len__(self) -> int:
127		"""return the total number of mazes in the collection"""
128		return sum(len(dataset) for dataset in self.maze_datasets)
129
130	def __getitem__(self, index: int) -> LatticeMaze:
131		"get a maze by index"
132		# find which dataset the index belongs to:
133		# np.searchsorted returns the count of cumulative lengths strictly
134		# less than its target, so searching for index + 1 yields the number
135		# of datasets that end at or before index -- i.e. the dataset index
136		dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1))
137		index_adjusted: int = index
138		if dataset_idx > 0:
139			# for the first dataset (dataset_idx == 0) the global index is
140			# already the local index, and `dataset_idx - 1` would be -1
141			index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1]
142		return self.maze_datasets[dataset_idx][index_adjusted]
143
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)
156
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)
169
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}
180
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)
191
192	# TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow?
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space and the result is a list of strings;
204		otherwise the result is a list of token lists. e.g.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output
217
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

a collection of maze datasets

MazeDatasetCollection( cfg: MazeDatasetCollectionConfig, maze_datasets: list[maze_dataset.MazeDataset], generation_metadata_collected: dict | None = None)
 87	def __init__(
 88		self,
 89		cfg: MazeDatasetCollectionConfig,
 90		maze_datasets: list[MazeDataset],
 91		generation_metadata_collected: dict | None = None,
 92	) -> None:
 93		"initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s"
 94		super().__init__()
 95		self.cfg: MazeDatasetCollectionConfig = cfg
 96		self.maze_datasets: list[MazeDataset] = list(maze_datasets)
 97		for c, ds in zip(
 98			self.cfg.maze_dataset_configs,
 99			self.maze_datasets,
100			strict=False,
101		):
102			assert c.name == ds.cfg.name
103			assert c == ds.cfg
104
105		self.generation_metadata_collected: dict | None = generation_metadata_collected

initialize the dataset collection from a MazeDatasetCollectionConfig and a list of MazeDatasets

maze_datasets: list[maze_dataset.MazeDataset]
generation_metadata_collected: dict | None
dataset_lengths: list[int]
107	@property
108	def dataset_lengths(self) -> list[int]:
109		"""return the lengths of each dataset in the collection"""
110		return [len(dataset) for dataset in self.maze_datasets]

return the lengths of each dataset in the collection

dataset_cum_lengths: jaxtyping.Int[ndarray, 'indices']
112	@property
113	def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
114		"""return the cumulative lengths of each dataset in the collection"""
115		return np.array(list(itertools.accumulate(self.dataset_lengths)))

return the cumulative lengths of each dataset in the collection
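
These cumulative lengths drive `__getitem__`: a global index is mapped to a (dataset, local index) pair with `np.searchsorted`. A worked example with hypothetical sub-dataset sizes:

import numpy as np

cum = np.array([3, 8])  # dataset_cum_lengths for sub-dataset sizes [3, 5]
index = 4               # global index into the collection
dataset_idx = int(np.searchsorted(cum, index + 1))  # -> 1 (second dataset)
local = index - (cum[dataset_idx - 1] if dataset_idx > 0 else 0)  # -> 1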

117	@cached_property
118	def mazes(self) -> list[LatticeMaze]:
119		"single list of all mazes in the collection"
120		return list(
121			itertools.chain.from_iterable(
122				dataset.mazes for dataset in self.maze_datasets
123			),
124		)

single list of all mazes in the collection

@classmethod
def generate( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
144	@classmethod
145	def generate(
146		cls,
147		cfg: MazeDatasetCollectionConfig,
148		**kwargs,
149	) -> "MazeDatasetCollection":
150		"""generate a dataset collection from a config"""
151		datasets = [
152			MazeDataset.generate(config, **kwargs)
153			for config in cfg.maze_dataset_configs
154		]
155		return cls(cfg, datasets)

generate a dataset collection from a config
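
Keyword arguments are forwarded unchanged to every `MazeDataset.generate` call, so any option accepted there applies to each sub-dataset. For example (assuming your version of `MazeDataset.generate` accepts `gen_parallel`):

collection = MazeDatasetCollection.generate(collection_cfg, gen_parallel=True)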

@classmethod
def download( cls, cfg: MazeDatasetCollectionConfig, **kwargs) -> MazeDatasetCollection:
157	@classmethod
158	def download(
159		cls,
160		cfg: MazeDatasetCollectionConfig,
161		**kwargs,
162	) -> "MazeDatasetCollection":
163		"(not implemented!) download a dataset collection from a config"
164		datasets = [
165			MazeDataset.download(config, **kwargs)
166			for config in cfg.maze_dataset_configs
167		]
168		return cls(cfg, datasets)

(not implemented!) download a dataset collection from a config

def serialize( self) -> Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]:
170	def serialize(self) -> JSONdict:
171		"""serialize the dataset collection"""
172		return {
173			_FORMAT_KEY: "MazeDatasetCollection",
174			"cfg": self.cfg.serialize(),
175			"maze_datasets": [dataset.serialize() for dataset in self.maze_datasets],
176			"generation_metadata_collected": json_serialize(
177				self.generation_metadata_collected,
178			),
179		}

serialize the dataset collection

@classmethod
def load( cls, data: Dict[str, Union[bool, int, float, str, NoneType, List[Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]], Dict[str, Union[bool, int, float, str, NoneType, List[Any], Dict[str, Any]]]]]) -> MazeDatasetCollection:
181	@classmethod
182	def load(cls, data: JSONdict) -> "MazeDatasetCollection":
183		"""load the dataset collection from the representation created by `serialize`"""
184		assert data[_FORMAT_KEY] == "MazeDatasetCollection"
185		return cls(
186			**{
187				key: load_item_recursive(data[key], tuple())
188				for key in ["cfg", "maze_datasets", "generation_metadata_collected"]
189			},
190		)

load the dataset collection from the representation created by serialize
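
In practice this usually runs through ZANJ: the `register_loader_handler` call at the bottom of the module source lets ZANJ dispatch saved items back to `MazeDatasetCollection.load`. A sketch, assuming the standard `zanj` save/read API:

from zanj import ZANJ

zanj = ZANJ()
zanj.save(collection, "collection.zanj")
loaded = zanj.read("collection.zanj")  # dispatched via the registered handler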

def as_tokens( self, maze_tokenizer, limit: int | None = None, join_tokens_individual_maze: bool = False) -> list[list[str]] | list[str]:
193	def as_tokens(
194		self,
195		# TODO: MazeTokenizer
196		maze_tokenizer,  # noqa: ANN001
197		limit: int | None = None,
198		join_tokens_individual_maze: bool = False,
199	) -> list[list[str]] | list[str]:
200		"""return the dataset as tokens
201
202		if join_tokens_individual_maze is True, then the tokens of each maze are
203		joined with a space and the result is a list of strings;
204		otherwise the result is a list of token lists. e.g.:
205		>>> dataset.as_tokens(join_tokens_individual_maze=False)
206		[["a", "b", "c"], ["d", "e", "f"]]
207		>>> dataset.as_tokens(join_tokens_individual_maze=True)
208		["a b c", "d e f"]
209		"""
210		output: list[list[str]] = [
211			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
212		]
213		if join_tokens_individual_maze:
214			return [" ".join(tokens) for tokens in output]
215		else:
216			return output

return the dataset as tokens

if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space and the result is a list of strings; otherwise the result is a list of token lists. e.g.:

>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
def update_self_config(self) -> None:
218	def update_self_config(self) -> None:
219		"update the config to match the number of mazes, and update the underlying configs of each dataset"
220		# TODO: why can't we set this directly? it's not frozen, and it seems to work in a regular MazeDataset
221		self.cfg.__dict__["n_mazes"] = len(self)
222		for dataset in self.maze_datasets:
223			dataset.update_self_config()
224
225		self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets]

update the config to match the number of mazes, and update the underlying configs of each dataset
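
A typical use is re-syncing the config after the sub-datasets have changed (e.g. after filtering); a minimal sketch:

collection.update_self_config()
assert collection.cfg.n_mazes == len(collection)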