maze_dataset.dataset.rasterized
a special RasterizedMazeDataset
that returns 2 images, one for input and one for target, for each maze
this lets you match the input and target format of the easy_2_hard
dataset
see their paper:
@misc{schwarzschild2021learn,
title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
year={2021},
eprint={2106.04537},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze

this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset


see their paper:

```bibtex
@misc{schwarzschild2021learn,
    title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
    author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
    year={2021},
    eprint={2106.04537},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
```
"""

import typing
from pathlib import Path

import numpy as np
from jaxtyping import Float, Int
from muutils.json_serialize import serializable_dataclass, serializable_field
from zanj import ZANJ

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.maze import PixelColors, SolvedMaze
from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells


def _extend_pixels(
    image: Int[np.ndarray, "x y rgb"],
    n_mult: int = 2,
    n_bdry: int = 1,
) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
    """scale an image up by `n_mult` along both spatial axes and add a wall border

    # Parameters:
    - `image`
        pixel grid with the color channel as the last axis
    - `n_mult: int`
        every pixel becomes an `n_mult x n_mult` block
        (default: `2`)
    - `n_bdry: int`
        width, in pixels, of the border added on every side after scaling
        (default: `1`)

    # Returns:
    the scaled image, padded with the wall color; the channel axis is not padded
    """
    # `np.pad` takes a single scalar fill value here, so the wall color must be
    # identical in every channel
    wall_fill: int = PixelColors.WALL[0]
    assert all(x == wall_fill for x in PixelColors.WALL), (
        "PixelColors.WALL must be a single value"
    )

    output: np.ndarray = np.repeat(
        np.repeat(
            image,
            n_mult,
            axis=0,
        ),
        n_mult,
        axis=1,
    )

    # pad on all sides by n_bdry
    return np.pad(
        output,
        pad_width=((n_bdry, n_bdry), (n_bdry, n_bdry), (0, 0)),
        mode="constant",
        constant_values=wall_fill,
    )


# config fields that `RasterizedMazeDatasetConfig` adds on top of `MazeDatasetConfig`
_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [
    "remove_isolated_cells",
    "extend_pixels",
    "endpoints_as_open",
]


def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
    """turn a single `SolvedMaze` into an array representation

    has extra options for matching the format in https://github.com/aks2203/easy-to-hard

    # Parameters:
    - `maze: SolvedMaze`
        the maze to process
    - `remove_isolated_cells: bool`
        whether to set isolated cells (no connections) to walls
        (default: `True`)
    - `extend_pixels: bool`
        whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
        (default: `True`)
    - `endpoints_as_open: bool`
        whether to set endpoints to open
        (default: `False`)

    # Returns:
    array stacking the problem image (index 0) and the solution image (index 1)
    """
    # problem and solution mazes
    maze_pixels: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
    problem_maze: PixelGrid = maze_pixels.copy()
    solution_maze: PixelGrid = maze_pixels.copy()

    # in problem maze, set path to open
    problem_maze[(problem_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN

    # order matters: open cells must become walls *before* the path is recolored
    # to open, otherwise the path would be wiped out as well
    # wherever solution maze is PixelColors.OPEN, set it to PixelColors.WALL
    solution_maze[(solution_maze == PixelColors.OPEN).all(axis=-1)] = PixelColors.WALL
    # wherever it is solution, set it to PixelColors.OPEN
    solution_maze[(solution_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN
    if endpoints_as_open:
        for color in (PixelColors.START, PixelColors.END):
            solution_maze[(solution_maze == color).all(axis=-1)] = PixelColors.OPEN

    # postprocess to match original easy_2_hard dataset
    if remove_isolated_cells:
        problem_maze = _remove_isolated_cells(problem_maze)
        solution_maze = _remove_isolated_cells(solution_maze)

    if extend_pixels:
        problem_maze = _extend_pixels(problem_maze)
        solution_maze = _extend_pixels(solution_maze)

    return np.array([problem_maze, solution_maze])


# TYPING: error: Attributes without a default cannot follow attributes with one  [misc]
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
    """adds options which we then pass to `process_maze_rasterized_input_target`

    - `remove_isolated_cells: bool` whether to set isolated cells to walls
    - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
    - `endpoints_as_open: bool` whether to set endpoints to open
    """

    remove_isolated_cells: bool = serializable_field(default=True)
    extend_pixels: bool = serializable_field(default=True)
    endpoints_as_open: bool = serializable_field(default=False)


class RasterizedMazeDataset(MazeDataset):
    """subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

    indexing returns the stacked (input, target) images produced by
    `process_maze_rasterized_input_target` rather than a `SolvedMaze`
    """

    cfg: RasterizedMazeDatasetConfig

    # this override here is intentional
    def __getitem__(self, idx: int) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:  # type: ignore[override]
        """get a single maze as stacked input (index 0) and target (index 1) images"""
        # get the solved maze
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self,
        idxs: list[int] | None = None,
    ) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
        """get a batch of mazes as a tensor, from a list of indices

        # Parameters:
        - `idxs: list[int] | None`
            indices of the mazes to include; `None` selects the whole dataset
            (default: `None`)

        # Returns:
        array of shape `(2, len(idxs), x, y, 3)`: all inputs at index 0, all targets at index 1

        # Raises:
        - `ValueError` if there are no indices to batch (the image shape would be unknown)
        """
        idxs = list(range(len(self))) if idxs is None else list(idxs)
        if not idxs:
            raise ValueError("cannot build a batch from an empty list of indices")

        # each item is (2, x, y, 3); stacking on axis 1 puts the in/tgt axis first
        return np.stack([self[i] for i in idxs], axis=1)

    # override here is intentional
    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        all arguments are forwarded unchanged to the parent `from_config`

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """build the dataset from a plain `MazeDatasetConfig` view of `cfg`, then re-attach the rasterization options

        `kwargs` are forwarded to `from_config`
        """
        serialized_cfg: dict = cfg.serialize()
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(serialized_cfg)
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params={
                k: v
                for k, v in serialized_cfg.items()
                if k in _RASTERIZED_CFG_ADDED_PARAMS
            },
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> "RasterizedMazeDataset":
        """wrap the mazes of an existing `MazeDataset` in a `RasterizedMazeDataset`

        # Parameters:
        - `base_dataset: MazeDataset`
            dataset whose config and mazes are reused (the mazes are not copied)
        - `added_params: dict | None`
            config entries laid over the serialized base config;
            `None` means `remove_isolated_cells=True, extend_pixels=True`
            (default: `None`)
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            },
        )
        output: RasterizedMazeDataset = cls(
            cfg=cfg,
            mazes=base_dataset.mazes,
        )
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """plot the first `count` mazes in the dataset

        inputs go in the top row, targets in the bottom row

        # Parameters:
        - `count: int | None`
            number of mazes to plot; `None` or `0` plots all of them,
            larger values are clamped to the dataset size
            (default: `None`)
        - `show: bool`
            whether to call `plt.show()`
            (default: `True`)

        # Returns:
        `(fig, axes)` with `axes` always a 2D array of shape `(2, count)`,
        or `None` if the dataset is empty
        """
        import matplotlib.pyplot as plt

        n_total: int = len(self)
        # check for emptiness *before* touching `self[0]`
        if n_total == 0:
            print("No mazes to plot for dataset")
            return None

        print(f"{self[0][0].shape = }, {self[0][1].shape = }")
        count = min(count or n_total, n_total)

        # `squeeze=False` keeps `axes` 2D even for a single column, so the
        # `axes[row, col]` indexing below works for every `count`
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            # rasterize each maze once and reuse both images
            input_img, target_img = self[i]
            axes[0, i].imshow(input_img)
            axes[1, i].imshow(target_img)
            # remove ticks
            axes[0, i].set_xticks([])
            axes[0, i].set_yticks([])
            axes[1, i].set_xticks([])
            axes[1, i].set_yticks([])

        if show:
            plt.show()

        return fig, axes


def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """create a collection of configs and arrays for different grid sizes, in plain tensor form

    # Parameters:
    - `base_cfg: RasterizedMazeDatasetConfig`
        template config; it is copied for each size and only `grid_n` is changed
    - `grid_sizes: list[int]`
        values of `grid_n` to build datasets for
    - `from_config_kwargs: dict | None`
        extra keyword arguments for `RasterizedMazeDataset.from_config_augmented`
        (default: `None`)
    - `verbose: bool`
        whether to print a progress line per size
        (default: `True`)
    - `key_fmt: str`
        format string for the output keys, formatted with `size`
        (default: `"{size}x{size}"`)

    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    if from_config_kwargs is None:
        from_config_kwargs = {}

    datasets: dict[int, RasterizedMazeDataset] = {}

    for size in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {size}...")

        # round-trip through serialization so `base_cfg` itself is never mutated
        cfg_temp: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize(),
        )
        cfg_temp.grid_n = size

        datasets[size] = RasterizedMazeDataset.from_config_augmented(
            cfg=cfg_temp,
            **from_config_kwargs,
        )

    return dict(
        configs={
            key_fmt.format(size=size): dataset.cfg for size, dataset in datasets.items()
        },
        arrays={
            # get_batch(None) returns a single tensor of shape (2, n, x, y, 3)
            key_fmt.format(size=size): dataset.get_batch(None)
            for size, dataset in datasets.items()
        },
    )
def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
    """rasterize a `SolvedMaze` into a stacked (problem, solution) image pair

    the options exist to reproduce the format of https://github.com/aks2203/easy-to-hard

    # Parameters:
    - `maze: SolvedMaze`
        the maze to rasterize
    - `remove_isolated_cells: bool`
        pass both images through `_remove_isolated_cells`
        (default: `True`)
    - `extend_pixels: bool`
        pass both images through `_extend_pixels`
        (default: `True`)
    - `endpoints_as_open: bool`
        recolor the start and end pixels of the solution image as open
        (default: `False`)

    # Returns:
    array with the problem image at index 0 and the solution image at index 1
    """

    def _recolor(grid: PixelGrid, old_color, new_color) -> None:
        # in-place: every pixel whose channels all equal `old_color` becomes `new_color`
        grid[(grid == old_color).all(axis=-1)] = new_color

    rendered: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
    input_img: PixelGrid = rendered.copy()
    target_img: PixelGrid = rendered.copy()

    # the problem image must not reveal the solution path
    _recolor(input_img, PixelColors.PATH, PixelColors.OPEN)

    # the solution image keeps only the path: open cells are turned into walls
    # first, and only afterwards is the path painted as open
    _recolor(target_img, PixelColors.OPEN, PixelColors.WALL)
    _recolor(target_img, PixelColors.PATH, PixelColors.OPEN)
    if endpoints_as_open:
        _recolor(target_img, PixelColors.START, PixelColors.OPEN)
        _recolor(target_img, PixelColors.END, PixelColors.OPEN)

    # postprocessing to match the original easy_2_hard dataset
    images: list = [input_img, target_img]
    if remove_isolated_cells:
        images = [_remove_isolated_cells(img) for img in images]
    if extend_pixels:
        images = [_extend_pixels(img) for img in images]

    return np.array(images)
turn a single SolvedMaze
into an array representation
has extra options for matching the format in https://github.com/aks2203/easy-to-hard
Parameters:
- maze: SolvedMaze
  the maze to process
- remove_isolated_cells: bool
  whether to set isolated cells (no connections) to walls (default: True)
- extend_pixels: bool
  whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) (default: True)
- endpoints_as_open: bool
  whether to set endpoints to open (default: False)
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
    """adds options which we then pass to `process_maze_rasterized_input_target`

    all three fields have defaults, so any `MazeDatasetConfig` can be upgraded
    to this class without supplying them

    - `remove_isolated_cells: bool` whether to set isolated cells to walls
    - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
    - `endpoints_as_open: bool` whether to set endpoints to open
    """

    # set isolated cells to walls (on by default)
    remove_isolated_cells: bool = serializable_field(default=True)
    # extend pixels to match the easy_2_hard format (on by default)
    extend_pixels: bool = serializable_field(default=True)
    # set the endpoints to open (off by default)
    endpoints_as_open: bool = serializable_field(default=False)
adds options which we then pass to process_maze_rasterized_input_target
- remove_isolated_cells: bool
  whether to set isolated cells to walls
- extend_pixels: bool
  whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
- endpoints_as_open: bool
  whether to set endpoints to open
def serialize(self) -> dict[str, Any]:
    """serialize this dataclass instance to a plain dict

    the result holds a format marker under `_FORMAT_KEY`, one entry per field
    whose `serialize` flag is set, and one entry per name listed in
    `self._properties_to_serialize`

    # Raises:
    - `NotSerializableFieldException` if a dataclass field is not a `SerializableField`
    - `FieldSerializationError` if reading or serializing a field value fails
    - `AttributeError` if a listed property does not exist on the class
    """
    result: dict[str, Any] = {
        _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
    }
    # for each field in the class
    for field in dataclasses.fields(self):  # type: ignore[arg-type]
        # need it to be our special SerializableField
        if not isinstance(field, SerializableField):
            raise NotSerializableFieldException(
                f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                f"but a {type(field)} "
                "this state should be inaccessible, please report this bug!"
            )

        # try to save it
        if field.serialize:
            try:
                # get the val
                value = getattr(self, field.name)
                # if it is a serializable dataclass, serialize it
                if isinstance(value, SerializableDataclass):
                    value = value.serialize()
                # if the value has a serialization function, use that
                # NOTE: this is a separate `if`, so the `elif` below is also
                # reached for a value that was already serialized above
                if hasattr(value, "serialize") and callable(value.serialize):
                    value = value.serialize()
                # if the field has a serialization function, use that
                # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                elif field.serialization_fn:
                    value = field.serialization_fn(value)

                # store the value in the result
                result[field.name] = value
            except Exception as e:
                # NOTE(review): if `getattr` itself raised, `value` is unbound
                # (or stale from an earlier field) when formatted here -- confirm
                raise FieldSerializationError(
                    "\n".join(
                        [
                            f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                            f"{field = }",
                            f"{value = }",
                            f"{self = }",
                        ]
                    )
                ) from e

    # store each property if we can get it
    for prop in self._properties_to_serialize:
        # NOTE(review): `cls` is not defined in this listing; presumably it is
        # captured from an enclosing scope -- confirm
        if hasattr(cls, prop):
            value = getattr(self, prop)
            result[prop] = value
        else:
            raise AttributeError(
                f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                + f"but it is in {self._properties_to_serialize = }"
                + f"\n{self = }"
            )

    return result
serialize the MazeDatasetConfig with all fields and fname
@classmethod  # type: ignore[misc]
def load(cls, data: dict[str, Any] | T) -> Type[T]:
    """build an instance of `cls` from a mapping of field names to values

    an object that is already an instance of `cls` is returned unchanged.
    keys of `data` that do not match an init-able field are ignored.

    # Raises:
    - `AssertionError` if `data` is not a mapping, or a field is not a `SerializableField`
    - `FieldLoadingError` if a field's type hint has a `load` method but the value is not a dict

    NOTE(review): the return annotation says `Type[T]`, but the code returns an instance.
    NOTE(review): `on_typecheck_mismatch` and `on_typecheck_error` are not defined
    in this listing; presumably they come from an enclosing scope -- confirm.
    """
    # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
    if isinstance(data, cls):
        return data

    assert isinstance(
        data, typing.Mapping
    ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

    cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

    # initialize dict for keeping what we will pass to the constructor
    ctor_kwargs: dict[str, Any] = dict()

    # iterate over the fields of the class
    for field in dataclasses.fields(cls):
        # check if the field is a SerializableField
        assert isinstance(
            field, SerializableField
        ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

        # check if the field is in the data and if it should be initialized
        if (field.name in data) and field.init:
            # get the value, we will be processing it
            value: Any = data[field.name]

            # get the type hint for the field
            field_type_hint: Any = cls_type_hints.get(field.name, None)

            # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
            if field.deserialize_fn:
                # if it has a deserialization function, use that
                # (it receives only this field's value)
                value = field.deserialize_fn(value)
            elif field.loading_fn:
                # if it has a loading function, use that
                # (unlike `deserialize_fn`, it receives the whole `data` mapping)
                value = field.loading_fn(data)
            elif (
                field_type_hint is not None
                and hasattr(field_type_hint, "load")
                and callable(field_type_hint.load)
            ):
                # if no loading function but has a type hint with a load method, use that
                if isinstance(value, dict):
                    value = field_type_hint.load(value)
                else:
                    raise FieldLoadingError(
                        f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                    )
            else:
                # assume no loading needs to happen, keep `value` as-is
                pass

            # store the value in the constructor kwargs
            ctor_kwargs[field.name] = value

    # create a new instance of the class with the constructor kwargs
    output: cls = cls(**ctor_kwargs)

    # validate the types of the fields if needed
    if on_typecheck_mismatch != ErrorMode.IGNORE:
        fields_valid: dict[str, bool] = (
            SerializableDataclass__validate_fields_types__dict(
                output,
                on_typecheck_error=on_typecheck_error,
            )
        )

        # if there are any fields that are not valid, raise an error
        if not all(fields_valid.values()):
            msg: str = (
                f"Type mismatch in fields of {cls.__name__}:\n"
                + "\n".join(
                    [
                        f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                        for k, v in fields_valid.items()
                        if not v
                    ]
                )
            )

            # what happens on a mismatch is decided by the `ErrorMode`'s `process`
            on_typecheck_mismatch.process(
                msg, except_cls=FieldTypeMismatchError
            )

    # return the new instance
    return output
takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass
decorator
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """check every field of a `SerializableDataclass` against its type

    delegates to `SerializableDataclass__validate_fields_types__dict` and
    returns `True` only if every per-field result is truthy
    """
    per_field_ok = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field_ok.values())
validate the types of all the fields on a SerializableDataclass
. calls SerializableDataclass__validate_field_type
for each field
Inherited Members
- maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig
- config_version
- versions
- summary
- success_fraction_estimate
- success_fraction_compensate
- maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig_base
- grid_n
- n_mazes
- maze_ctor
- maze_ctor_kwargs
- endpoint_kwargs
- grid_shape
- grid_shape_np
- max_grid_n
- stable_hash_cfg
- to_fname
- muutils.json_serialize.serializable_dataclass.SerializableDataclass
- validate_field_type
- diff
- update_from_nested_dict
class RasterizedMazeDataset(MazeDataset):
    """subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

    indexing returns the stacked (input, target) images produced by
    `process_maze_rasterized_input_target` rather than a `SolvedMaze`
    """

    cfg: RasterizedMazeDatasetConfig

    # this override here is intentional
    def __getitem__(self, idx: int) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:  # type: ignore[override]
        """get a single maze as stacked input (index 0) and target (index 1) images"""
        # get the solved maze
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self,
        idxs: list[int] | None = None,
    ) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
        """get a batch of mazes as a tensor, from a list of indices

        # Parameters:
        - `idxs: list[int] | None`
            indices of the mazes to include; `None` selects the whole dataset
            (default: `None`)

        # Returns:
        array of shape `(2, len(idxs), x, y, 3)`: all inputs at index 0, all targets at index 1

        # Raises:
        - `ValueError` if there are no indices to batch (the image shape would be unknown)
        """
        idxs = list(range(len(self))) if idxs is None else list(idxs)
        if not idxs:
            raise ValueError("cannot build a batch from an empty list of indices")

        # each item is (2, x, y, 3); stacking on axis 1 puts the in/tgt axis first
        return np.stack([self[i] for i in idxs], axis=1)

    # override here is intentional
    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        all arguments are forwarded unchanged to the parent `from_config`

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """build the dataset from a plain `MazeDatasetConfig` view of `cfg`, then re-attach the rasterization options

        `kwargs` are forwarded to `from_config`
        """
        serialized_cfg: dict = cfg.serialize()
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(serialized_cfg)
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params={
                k: v
                for k, v in serialized_cfg.items()
                if k in _RASTERIZED_CFG_ADDED_PARAMS
            },
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> "RasterizedMazeDataset":
        """wrap the mazes of an existing `MazeDataset` in a `RasterizedMazeDataset`

        # Parameters:
        - `base_dataset: MazeDataset`
            dataset whose config and mazes are reused (the mazes are not copied)
        - `added_params: dict | None`
            config entries laid over the serialized base config;
            `None` means `remove_isolated_cells=True, extend_pixels=True`
            (default: `None`)
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            },
        )
        output: RasterizedMazeDataset = cls(
            cfg=cfg,
            mazes=base_dataset.mazes,
        )
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """plot the first `count` mazes in the dataset

        inputs go in the top row, targets in the bottom row

        # Parameters:
        - `count: int | None`
            number of mazes to plot; `None` or `0` plots all of them,
            larger values are clamped to the dataset size
            (default: `None`)
        - `show: bool`
            whether to call `plt.show()`
            (default: `True`)

        # Returns:
        `(fig, axes)` with `axes` always a 2D array of shape `(2, count)`,
        or `None` if the dataset is empty
        """
        import matplotlib.pyplot as plt

        n_total: int = len(self)
        # check for emptiness *before* touching `self[0]`
        if n_total == 0:
            print("No mazes to plot for dataset")
            return None

        print(f"{self[0][0].shape = }, {self[0][1].shape = }")
        count = min(count or n_total, n_total)

        # `squeeze=False` keeps `axes` 2D even for a single column, so the
        # `axes[row, col]` indexing below works for every `count`
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            # rasterize each maze once and reuse both images
            input_img, target_img = self[i]
            axes[0, i].imshow(input_img)
            axes[1, i].imshow(target_img)
            # remove ticks
            axes[0, i].set_xticks([])
            axes[0, i].set_yticks([])
            axes[1, i].set_xticks([])
            axes[1, i].set_yticks([])

        if show:
            plt.show()

        return fig, axes
subclass of MazeDataset
that uses a RasterizedMazeDatasetConfig
def get_batch(
    self,
    idxs: list[int] | None = None,
) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
    """get a batch of mazes as a tensor, from a list of indices

    # Parameters:
    - `idxs: list[int] | None`
        indices of the mazes to include; `None` selects the whole dataset
        (default: `None`)

    # Returns:
    array of shape `(2, len(idxs), x, y, 3)`: all inputs at index 0, all targets at index 1

    # Raises:
    - `ValueError` if there are no indices to batch (the image shape would be unknown)
    """
    idxs = list(range(len(self))) if idxs is None else list(idxs)
    if not idxs:
        # previously this surfaced as a cryptic tuple-unpacking `ValueError`
        raise ValueError("cannot build a batch from an empty list of indices")

    # each item is (2, x, y, 3); stacking on axis 1 puts the in/tgt axis first
    return np.stack([self[i] for i in idxs], axis=1)
get a batch of mazes as a tensor, from a list of indices
@classmethod
def from_config(
    cls,
    cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
    do_generate: bool = True,
    load_local: bool = True,
    save_local: bool = True,
    zanj: ZANJ | None = None,
    do_download: bool = True,
    local_base_path: Path = Path("data/maze_dataset"),
    except_on_config_mismatch: bool = True,
    allow_generation_metadata_filter_mismatch: bool = True,
    verbose: bool = False,
    **kwargs,
) -> "RasterizedMazeDataset":
    """create a rasterized maze dataset from a config

    every argument is handed on unchanged to the parent class's `from_config`;
    this override only narrows the declared return type

    priority of loading:
    1. load from local
    2. download
    3. generate

    """
    loaded = super().from_config(
        cfg=cfg,
        do_generate=do_generate,
        load_local=load_local,
        save_local=save_local,
        zanj=zanj,
        do_download=do_download,
        local_base_path=local_base_path,
        except_on_config_mismatch=except_on_config_mismatch,
        allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
        verbose=verbose,
        **kwargs,
    )
    # `cast` is a no-op at runtime; it only informs the type checker
    return typing.cast(RasterizedMazeDataset, loaded)
create a rasterized maze dataset from a config
priority of loading:
- load from local
- download
- generate
@classmethod
def from_config_augmented(
    cls,
    cfg: RasterizedMazeDatasetConfig,
    **kwargs,
) -> "RasterizedMazeDataset":
    """build the dataset from a plain `MazeDatasetConfig` view of `cfg`, then re-attach the rasterization options

    `cfg` is reloaded as a `MazeDatasetConfig` and passed to `from_config`
    (together with `kwargs`). the result is wrapped via `from_base_MazeDataset`,
    with the values of the keys in `_RASTERIZED_CFG_ADDED_PARAMS` taken from `cfg`
    """
    cfg_data = cfg.serialize()
    plain_cfg: MazeDatasetConfig = MazeDatasetConfig.load(cfg_data)
    base_dataset = cls.from_config(cfg=plain_cfg, **kwargs)

    raster_options: dict = {}
    for key, val in cfg_data.items():
        if key in _RASTERIZED_CFG_ADDED_PARAMS:
            raster_options[key] = val

    return cls.from_base_MazeDataset(base_dataset, added_params=raster_options)
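builds the dataset from a plain MazeDatasetConfig copy of the given config, then wraps it as a RasterizedMazeDataset with the rasterization options (remove_isolated_cells, extend_pixels, endpoints_as_open) re-applied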
@classmethod
def from_base_MazeDataset(
    cls,
    base_dataset: MazeDataset,
    added_params: dict | None = None,
) -> "RasterizedMazeDataset":
    """wrap the mazes of an existing `MazeDataset` in a `RasterizedMazeDataset`

    # Parameters:
    - `base_dataset: MazeDataset`
        dataset whose config and mazes are reused (the mazes are not copied)
    - `added_params: dict | None`
        config entries laid over the serialized base config;
        `None` means `remove_isolated_cells=True, extend_pixels=True`
        (default: `None`)
    """
    if added_params is None:
        added_params = {"remove_isolated_cells": True, "extend_pixels": True}

    # entries in `added_params` win over same-named entries of the base config
    cfg_data: dict = dict(base_dataset.cfg.serialize())
    cfg_data.update(added_params)

    raster_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(cfg_data)
    return cls(cfg=raster_cfg, mazes=base_dataset.mazes)
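wraps the mazes of an existing MazeDataset in a RasterizedMazeDataset; added_params are laid over the base config (default: remove_isolated_cells=True, extend_pixels=True)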
def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
    """plot the first `count` mazes in the dataset

    inputs go in the top row, targets in the bottom row

    # Parameters:
    - `count: int | None`
        number of mazes to plot; `None` or `0` plots all of them,
        larger values are clamped to the dataset size
        (default: `None`)
    - `show: bool`
        whether to call `plt.show()`
        (default: `True`)

    # Returns:
    `(fig, axes)` with `axes` always a 2D array of shape `(2, count)`,
    or `None` if the dataset is empty
    """
    import matplotlib.pyplot as plt

    n_total: int = len(self)
    # check for emptiness *before* touching `self[0]`, which would raise on an
    # empty dataset
    if n_total == 0:
        print("No mazes to plot for dataset")
        return None

    print(f"{self[0][0].shape = }, {self[0][1].shape = }")
    count = min(count or n_total, n_total)

    # `squeeze=False` keeps `axes` 2D even for a single column; the previous
    # `axes = [axes]` wrapper failed on the `axes[row, col]` indexing below
    fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
    for i in range(count):
        # rasterize each maze once and reuse both images
        input_img, target_img = self[i]
        axes[0, i].imshow(input_img)
        axes[1, i].imshow(target_img)
        # remove ticks
        axes[0, i].set_xticks([])
        axes[0, i].set_yticks([])
        axes[1, i].set_xticks([])
        axes[1, i].set_yticks([])

    if show:
        plt.show()

    return fig, axes
plot the first count
mazes in the dataset
def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """build one rasterized dataset per grid size and return configs and arrays side by side

    `base_cfg` is copied for each entry of `grid_sizes` and only `grid_n` is
    changed on the copy. `from_config_kwargs` is forwarded to
    `RasterizedMazeDataset.from_config_augmented`, and keys are produced by
    `key_fmt.format(size=...)`

    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    extra_kwargs: dict = {} if from_config_kwargs is None else from_config_kwargs

    # phase 1: build every dataset (a repeated size overwrites its earlier entry)
    by_size: dict[int, RasterizedMazeDataset] = {}
    for n in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {n}...")

        # round-trip through serialization so `base_cfg` itself is never mutated
        sized_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize(),
        )
        sized_cfg.grid_n = n

        by_size[n] = RasterizedMazeDataset.from_config_augmented(
            cfg=sized_cfg,
            **extra_kwargs,
        )

    # phase 2: collect configs first, then rasterize every dataset to one array
    configs: dict = {}
    for n, ds in by_size.items():
        configs[key_fmt.format(size=n)] = ds.cfg

    arrays: dict = {}
    for n, ds in by_size.items():
        # get_batch(None) returns a single tensor of shape (2, n_mazes, x, y, 3)
        arrays[key_fmt.format(size=n)] = ds.get_batch(None)

    return {"configs": configs, "arrays": arrays}
create a collection of configs and arrays for different grid sizes, in plain tensor form
output is of structure:
{
"configs": {
"<n>x<n>": RasterizedMazeDatasetConfig,
...
},
"arrays": {
"<n>x<n>": np.ndarray,
...
},
}