maze_dataset.dataset.collected_dataset
collecting different maze datasets into a single dataset, for greater variety in a training or validation set
> [!CAUTION]
> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work.
1"""collecting different maze datasets into a single dataset, for greater variety in a training or validation set 2 3> [!CAUTION] 4> `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work. 5 6""" 7 8import itertools 9import json 10import typing 11from functools import cached_property 12 13import numpy as np 14from jaxtyping import Int 15from muutils.json_serialize import ( 16 json_serialize, 17 serializable_dataclass, 18 serializable_field, 19) 20from muutils.json_serialize.util import _FORMAT_KEY, JSONdict 21from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash 22from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler 23 24from maze_dataset.constants import Coord, CoordTup 25from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig 26from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig 27from maze_dataset.maze import LatticeMaze 28 29 30@serializable_dataclass(kw_only=True) 31class MazeDatasetCollectionConfig(GPTDatasetConfig): 32 """maze dataset collection configuration, including tokenizers and shuffle""" 33 34 # Attributes without a default cannot follow attributes with one [misc] 35 maze_dataset_configs: list[MazeDatasetConfig] = serializable_field( # type: ignore[misc] 36 serialization_fn=lambda configs: [config.serialize() for config in configs], 37 loading_fn=lambda data: [ 38 MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"] 39 ], 40 ) 41 42 def summary(self) -> dict: 43 """return a summary of the config""" 44 return dict( 45 n_mazes=self.n_mazes, 46 max_grid_n=self.max_grid_n, 47 max_grid_shape=self.max_grid_shape, 48 fname=self.to_fname(), 49 cfg_summaries=[c.summary() for c in self.maze_dataset_configs], 50 ) 51 52 @property 53 def n_mazes(self) -> int: 54 """return the total number of mazes in the collection across all dataset""" 55 return sum(config.n_mazes for config in self.maze_dataset_configs) 56 57 @property 58 def max_grid_n(self) -> int: 59 """return the maximum grid size of the mazes in the collection""" 60 return max(config.grid_n for config in self.maze_dataset_configs) 61 62 @property 63 def max_grid_shape(self) -> CoordTup: 64 """return the maximum grid shape of the mazes in the collection""" 65 return (self.max_grid_n, self.max_grid_n) 66 67 @property 68 def max_grid_shape_np(self) -> Coord: 69 """return the maximum grid shape of the mazes in the collection as a numpy array""" 70 return np.array(self.max_grid_shape, dtype=np.int32) 71 72 def stable_hash_cfg(self) -> int: 73 """return a stable hash of the config""" 74 return stable_hash(json.dumps(self.serialize())) 75 76 def to_fname(self) -> str: 77 """convert config to a filename""" 78 return sanitize_fname( 79 f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}", 80 ) 81 82 83class MazeDatasetCollection(GPTDataset): 84 """a collection of maze datasets""" 85 86 def __init__( 87 self, 88 cfg: MazeDatasetCollectionConfig, 89 maze_datasets: list[MazeDataset], 90 generation_metadata_collected: dict | None = None, 91 ) -> None: 92 "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s" 93 super().__init__() 94 self.cfg: MazeDatasetCollectionConfig = cfg 95 self.maze_datasets: list[MazeDataset] = list(maze_datasets) 96 for c, ds in zip( 97 self.cfg.maze_dataset_configs, 98 self.maze_datasets, 99 strict=False, 100 ): 101 assert c.name == ds.cfg.name 102 assert c == ds.cfg 103 104 
self.generation_metadata_collected: dict | None = generation_metadata_collected 105 106 @property 107 def dataset_lengths(self) -> list[int]: 108 """return the lengths of each dataset in the collection""" 109 return [len(dataset) for dataset in self.maze_datasets] 110 111 @property 112 def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]: 113 """return the cumulative lengths of each dataset in the collection""" 114 return np.array(list(itertools.accumulate(self.dataset_lengths))) 115 116 @cached_property 117 def mazes(self) -> list[LatticeMaze]: 118 "single list of all mazes in the collection" 119 return list( 120 itertools.chain.from_iterable( 121 dataset.mazes for dataset in self.maze_datasets 122 ), 123 ) 124 125 def __len__(self) -> int: 126 """return the total number of mazes in the collection""" 127 return sum(len(dataset) for dataset in self.maze_datasets) 128 129 def __getitem__(self, index: int) -> LatticeMaze: 130 "get a maze by index" 131 # find which dataset the index belongs to 132 # we add 1, since np.searchsorted returns the 133 # index of the last element that is strictly less than the target 134 # while we want the index of the last element less than or equal to the target 135 dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1)) 136 index_adjusted: int = index 137 if dataset_idx > 0: 138 # if the index is 0, `dataset_idx - 1` will be -1. 139 # We just want to use the base index 140 index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1] 141 return self.maze_datasets[dataset_idx][index_adjusted] 142 143 @classmethod 144 def generate( 145 cls, 146 cfg: MazeDatasetCollectionConfig, 147 **kwargs, 148 ) -> "MazeDatasetCollection": 149 """generate a dataset collection from a config""" 150 datasets = [ 151 MazeDataset.generate(config, **kwargs) 152 for config in cfg.maze_dataset_configs 153 ] 154 return cls(cfg, datasets) 155 156 @classmethod 157 def download( 158 cls, 159 cfg: MazeDatasetCollectionConfig, 160 **kwargs, 161 ) -> "MazeDatasetCollection": 162 "(not implemented!) download a dataset collection from a config" 163 datasets = [ 164 MazeDataset.download(config, **kwargs) 165 for config in cfg.maze_dataset_configs 166 ] 167 return cls(cfg, datasets) 168 169 def serialize(self) -> JSONdict: 170 """serialize the dataset collection""" 171 return { 172 _FORMAT_KEY: "MazeDatasetCollection", 173 "cfg": self.cfg.serialize(), 174 "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets], 175 "generation_metadata_collected": json_serialize( 176 self.generation_metadata_collected, 177 ), 178 } 179 180 @classmethod 181 def load(cls, data: JSONdict) -> "MazeDatasetCollection": 182 """load the dataset collection from the representation created by `serialize`""" 183 assert data[_FORMAT_KEY] == "MazeDatasetCollection" 184 return cls( 185 **{ 186 key: load_item_recursive(data[key], tuple()) 187 for key in ["cfg", "maze_datasets", "generation_metadata_collected"] 188 }, 189 ) 190 191 # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow? 192 def as_tokens( 193 self, 194 # TODO: MazeTokenizer 195 maze_tokenizer, # noqa: ANN001 196 limit: int | None = None, 197 join_tokens_individual_maze: bool = False, 198 ) -> list[list[str]] | list[str]: 199 """return the dataset as tokens 200 201 if join_tokens_individual_maze is True, then the tokens of each maze are 202 joined with a space, and the result is a list of strings. 
203 i.e.: 204 >>> dataset.as_tokens(join_tokens_individual_maze=False) 205 [["a", "b", "c"], ["d", "e", "f"]] 206 >>> dataset.as_tokens(join_tokens_individual_maze=True) 207 ["a b c", "d e f"] 208 """ 209 output: list[list[str]] = [ 210 maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] 211 ] 212 if join_tokens_individual_maze: 213 return [" ".join(tokens) for tokens in output] 214 else: 215 return output 216 217 def update_self_config(self) -> None: 218 "update the config to match the number of mazes, and update the underlying configs of each dataset" 219 # TODO: why cant we set this directly? its not frozen, and it seems to work in a regular MazeDataset 220 self.cfg.__dict__["n_mazes"] = len(self) 221 for dataset in self.maze_datasets: 222 dataset.update_self_config() 223 224 self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets] 225 226 227MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection # type: ignore[method-assign, assignment] 228 229register_loader_handler( 230 LoaderHandler( 231 check=lambda json_item, path=None, z=None: ( # type: ignore[misc] # noqa: ARG005 232 isinstance(json_item, typing.Mapping) 233 and _FORMAT_KEY in json_item 234 and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection") 235 ), 236 load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item), # type: ignore[misc] # noqa: ARG005 237 uid="MazeDatasetCollection", 238 source_pckg="maze_dataset.generation.maze_dataset_collection", 239 desc="MazeDatasetCollection", 240 ), 241)
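A minimal usage sketch of the module. The `MazeDatasetConfig` fields and the `gen_dfs` generator are assumed from the rest of the `maze_dataset` package, not defined in this module:

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.collected_dataset import (
    MazeDatasetCollection,
    MazeDatasetCollectionConfig,
)
from maze_dataset.generation import LatticeMazeGenerators

# two member configs with different grid sizes, combined into one collection
collection_cfg = MazeDatasetCollectionConfig(
    name="demo",
    maze_dataset_configs=[
        MazeDatasetConfig(
            name="small", grid_n=4, n_mazes=10,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
        MazeDatasetConfig(
            name="large", grid_n=6, n_mazes=5,
            maze_ctor=LatticeMazeGenerators.gen_dfs,
        ),
    ],
)

collection = MazeDatasetCollection.generate(collection_cfg)
assert len(collection) == 15
maze = collection[12]  # falls in the second ("large") dataset
```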
```python
@serializable_dataclass(kw_only=True)
class MazeDatasetCollectionConfig(GPTDatasetConfig):
```
maze dataset collection configuration, including tokenizers and shuffle
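Continuing the sketch above, the aggregate properties below are derived from the member configs:

```python
# hypothetical values, following the collection_cfg sketched earlier
assert collection_cfg.n_mazes == 10 + 5        # sum over member configs
assert collection_cfg.max_grid_n == 6          # max over member configs
assert collection_cfg.max_grid_shape == (6, 6)
```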
```python
def summary(self) -> dict:
```
return a summary of the config
```python
@property
def n_mazes(self) -> int:
```
return the total number of mazes in the collection across all datasets
```python
@property
def max_grid_n(self) -> int:
```
return the maximum grid size of the mazes in the collection
```python
@property
def max_grid_shape(self) -> CoordTup:
```
return the maximum grid shape of the mazes in the collection
```python
@property
def max_grid_shape_np(self) -> Coord:
```
return the maximum grid shape of the mazes in the collection as a numpy array
```python
def stable_hash_cfg(self) -> int:
```
return a stable hash of the config
```python
def to_fname(self) -> str:
```
convert config to a filename
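Since the hash is computed from the serialized config, `to_fname` is deterministic; a sketch (the exact hash digits will differ):

```python
fname: str = collection_cfg.to_fname()
# something like "collected-demo-n15-h48213": name, shortened maze count,
# and the stable hash modulo 10**5
assert fname == collection_cfg.to_fname()  # same config, same filename
```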
```python
def serialize(self) -> dict[str, Any]:
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
```
returns the class as a dict, implemented by using the `@serializable_dataclass` decorator
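A sketch of the result for the collection config above (the key names follow the dataclass field names, per the quoted source):

```python
from muutils.json_serialize.util import _FORMAT_KEY

data: dict = collection_cfg.serialize()
assert data[_FORMAT_KEY] == "MazeDatasetCollectionConfig(SerializableDataclass)"
# the maze_dataset_configs field is serialized by its serialization_fn,
# producing a list of serialized MazeDatasetConfig dicts
assert isinstance(data["maze_dataset_configs"], list)
```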
```python
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
```
takes in an appropriately structured dict and returns an instance of the class, implemented by using the `@serializable_dataclass` decorator
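Together with `serialize`, this gives a round trip; a sketch:

```python
restored = MazeDatasetCollectionConfig.load(collection_cfg.serialize())
# serializable dataclasses compare by field values, so the round trip
# should reproduce an equal config
assert restored == collection_cfg
```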
```python
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field"""
    return all(
        SerializableDataclass__validate_fields_types__dict(
            self, on_typecheck_error=on_typecheck_error
        ).values()
    )
```
validate the types of all the fields on a `SerializableDataclass`. calls `SerializableDataclass__validate_field_type` for each field
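A sketch, assuming the `validate_fields_types` instance method that muutils attaches to serializable dataclasses:

```python
# True when every field value matches its annotated type; behaviour on a
# mismatch (warn / raise / ignore) depends on on_typecheck_error
assert collection_cfg.validate_fields_types()
```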
Inherited Members
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
```python
class MazeDatasetCollection(GPTDataset):
```
a collection of maze datasets
```python
def __init__(
    self,
    cfg: MazeDatasetCollectionConfig,
    maze_datasets: list[MazeDataset],
    generation_metadata_collected: dict | None = None,
) -> None:
```
initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s
```python
@property
def dataset_lengths(self) -> list[int]:
```
return the lengths of each dataset in the collection
```python
@property
def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]:
```
return the cumulative lengths of each dataset in the collection
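`__getitem__` combines these cumulative lengths with `np.searchsorted` to map a global index to a (dataset, local index) pair; a standalone sketch of that arithmetic:

```python
import numpy as np

dataset_cum_lengths = np.array([10, 15])  # e.g. datasets of length 10 and 5

def locate(index: int) -> tuple[int, int]:
    """map a global index to (dataset_idx, local_idx)"""
    # search for index + 1: the first cumulative length >= index + 1
    # belongs to the dataset that contains this index
    dataset_idx = int(np.searchsorted(dataset_cum_lengths, index + 1))
    # subtract the lengths of all earlier datasets to get the local index
    offset = dataset_cum_lengths[dataset_idx - 1] if dataset_idx > 0 else 0
    return dataset_idx, int(index - offset)

assert locate(0) == (0, 0)
assert locate(9) == (0, 9)    # last maze of the first dataset
assert locate(10) == (1, 0)   # first maze of the second dataset
assert locate(14) == (1, 4)
```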
```python
@cached_property
def mazes(self) -> list[LatticeMaze]:
```
single list of all mazes in the collection
```python
@classmethod
def generate(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
```
generate a dataset collection from a config
```python
@classmethod
def download(
    cls,
    cfg: MazeDatasetCollectionConfig,
    **kwargs,
) -> "MazeDatasetCollection":
```
(not implemented!) download a dataset collection from a config
```python
def serialize(self) -> JSONdict:
```
serialize the dataset collection
```python
@classmethod
def load(cls, data: JSONdict) -> "MazeDatasetCollection":
```
load the dataset collection from the representation created by serialize
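Because the module registers a loader handler (see the bottom of the source above), collections also round-trip through ZANJ files; a sketch, assuming the `save`/`read` methods of `zanj.ZANJ`:

```python
from zanj import ZANJ

path: str = f"data/{collection.cfg.to_fname()}.zanj"  # hypothetical path
ZANJ().save(collection, path)
loaded: MazeDatasetCollection = ZANJ().read(path)
assert len(loaded) == len(collection)
```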
```python
def as_tokens(
    self,
    maze_tokenizer,
    limit: int | None = None,
    join_tokens_individual_maze: bool = False,
) -> list[list[str]] | list[str]:
```
return the dataset as tokens
if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.:
>>> dataset.as_tokens(join_tokens_individual_maze=False)
[["a", "b", "c"], ["d", "e", "f"]]
>>> dataset.as_tokens(join_tokens_individual_maze=True)
["a b c", "d e f"]
```python
def update_self_config(self) -> None:
```
update the config to match the number of mazes, and update the underlying configs of each dataset