docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.rasterized

a special RasterizedMazeDataset that returns 2 images, one for input and one for target, for each maze

this lets you match the input and target format of the easy_2_hard dataset

see their paper:

@misc{schwarzschild2021learn,
        title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
        author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
        year={2021},
        eprint={2106.04537},
        archivePrefix={arXiv},
        primaryClass={cs.LG}
}

  1"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze
  2
  3this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset
  4
  5
  6see their paper:
  7
  8```bibtex
  9@misc{schwarzschild2021learn,
 10	title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
 11	author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
 12	year={2021},
 13	eprint={2106.04537},
 14	archivePrefix={arXiv},
 15	primaryClass={cs.LG}
 16}
 17```
 18"""
 19
 20import typing
 21from pathlib import Path
 22
 23import numpy as np
 24from jaxtyping import Float, Int
 25from muutils.json_serialize import serializable_dataclass, serializable_field
 26from zanj import ZANJ
 27
 28from maze_dataset import MazeDataset, MazeDatasetConfig
 29from maze_dataset.maze import PixelColors, SolvedMaze
 30from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells
 31
 32
 33def _extend_pixels(
 34	image: Int[np.ndarray, "x y rgb"],
 35	n_mult: int = 2,
 36	n_bdry: int = 1,
 37) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
 38	wall_fill: int = PixelColors.WALL[0]
 39	assert all(x == wall_fill for x in PixelColors.WALL), (
 40		"PixelColors.WALL must be a single value"
 41	)
 42
 43	output: np.ndarray = np.repeat(
 44		np.repeat(
 45			image,
 46			n_mult,
 47			axis=0,
 48		),
 49		n_mult,
 50		axis=1,
 51	)
 52
 53	# pad on all sides by n_bdry
 54	return np.pad(
 55		output,
 56		pad_width=((n_bdry, n_bdry), (n_bdry, n_bdry), (0, 0)),
 57		mode="constant",
 58		constant_values=wall_fill,
 59	)
 60
 61
 62_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [
 63	"remove_isolated_cells",
 64	"extend_pixels",
 65	"endpoints_as_open",
 66]
 67
 68
 69def process_maze_rasterized_input_target(
 70	maze: SolvedMaze,
 71	remove_isolated_cells: bool = True,
 72	extend_pixels: bool = True,
 73	endpoints_as_open: bool = False,
 74) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
 75	"""turn a single `SolvedMaze` into an array representation
 76
 77	has extra options for matching the format in https://github.com/aks2203/easy-to-hard
 78
 79	# Parameters:
 80	- `maze: SolvedMaze`
 81		the maze to process
 82	- `remove_isolated_cells: bool`
 83		whether to set isolated cells (no connections) to walls
 84		(default: `True`)
 85	- `extend_pixels: bool`
 86		whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
 87		(default: `True`)
 88	- `endpoints_as_open: bool`
 89		whether to set endpoints to open
 90		(default: `False`)
 91	"""
 92	# problem and solution mazes
 93	maze_pixels: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
 94	problem_maze: PixelGrid = maze_pixels.copy()
 95	solution_maze: PixelGrid = maze_pixels.copy()
 96
 97	# in problem maze, set path to open
 98	problem_maze[(problem_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN
 99
100	# wherever solution maze is PixelColors.OPEN, set it to PixelColors.WALL
101	solution_maze[(solution_maze == PixelColors.OPEN).all(axis=-1)] = PixelColors.WALL
102	# wherever it is solution, set it to PixelColors.OPEN
103	solution_maze[(solution_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN
104	if endpoints_as_open:
105		for color in (PixelColors.START, PixelColors.END):
106			solution_maze[(solution_maze == color).all(axis=-1)] = PixelColors.OPEN
107
108	# postprocess to match original easy_2_hard dataset
109	if remove_isolated_cells:
110		problem_maze = _remove_isolated_cells(problem_maze)
111		solution_maze = _remove_isolated_cells(solution_maze)
112
113	if extend_pixels:
114		problem_maze = _extend_pixels(problem_maze)
115		solution_maze = _extend_pixels(solution_maze)
116
117	return np.array([problem_maze, solution_maze])
118
119
# TYPING: error: Attributes without a default cannot follow attributes with one  [misc]
# NOTE(review): mypy reports this for the subclass (presumably because inherited fields such as
# `name`/`grid_n` have no default); the `type: ignore[misc]` on the class line silences it
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
	"""a `MazeDatasetConfig` with extra options which we then pass to `process_maze_rasterized_input_target`

	- `remove_isolated_cells: bool` whether to set isolated cells to walls (default: `True`)
	- `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) (default: `True`)
	- `endpoints_as_open: bool` whether to set the start/end markers to open in the target image (default: `False`)
	"""

	# these three fields are read by `RasterizedMazeDataset.__getitem__`
	remove_isolated_cells: bool = serializable_field(default=True)
	extend_pixels: bool = serializable_field(default=True)
	endpoints_as_open: bool = serializable_field(default=False)
133
134
class RasterizedMazeDataset(MazeDataset):
	"""subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

	indexing returns a stacked `(input, target)` image pair for a maze, built by
	`process_maze_rasterized_input_target` with the options stored in `self.cfg`
	"""

	cfg: RasterizedMazeDatasetConfig

	# this override here is intentional: items are image pairs, not `SolvedMaze` objects
	def __getitem__(self, idx: int) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:  # type: ignore[override]
		"""get a single maze as a stacked `(input, target)` image array

		# Parameters:
		- `idx: int`
			index into `self.mazes`

		# Returns:
		- `Float[np.ndarray, "in/tgt=2 x y rgb=3"]`
			index `0` along the first axis is the input image, index `1` the target image
		"""
		solved_maze: SolvedMaze = self.mazes[idx]

		return process_maze_rasterized_input_target(
			maze=solved_maze,
			remove_isolated_cells=self.cfg.remove_isolated_cells,
			extend_pixels=self.cfg.extend_pixels,
			endpoints_as_open=self.cfg.endpoints_as_open,
		)

	def get_batch(
		self,
		idxs: list[int] | None,
	) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
		"""get a batch of mazes as a single array, from a list of indices

		# Parameters:
		- `idxs: list[int] | None`
			indices of the mazes to include; `None` means every maze in the dataset

		# Returns:
		- `Float[np.ndarray, "in/tgt=2 item x y rgb=3"]`
			axis 0 selects the input (`0`) or target (`1`) images, axis 1 is the position in `idxs`

		# Raises:
		- `ValueError` if there are no indices, since the image shape cannot be inferred
		"""
		if idxs is None:
			idxs = list(range(len(self)))

		if len(idxs) == 0:
			# fail with a clear message instead of an opaque tuple-unpacking error
			raise ValueError("cannot build a batch from an empty list of indices")

		# each item has shape (in/tgt, x, y, rgb); stacking along axis 1 puts the
		# input/target axis first, giving (in/tgt, item, x, y, rgb) as one contiguous array
		return np.stack([self[i] for i in idxs], axis=1)

	# override here is intentional
	@classmethod
	def from_config(
		cls,
		cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
		do_generate: bool = True,
		load_local: bool = True,
		save_local: bool = True,
		zanj: ZANJ | None = None,
		do_download: bool = True,
		local_base_path: Path = Path("data/maze_dataset"),
		except_on_config_mismatch: bool = True,
		allow_generation_metadata_filter_mismatch: bool = True,
		verbose: bool = False,
		**kwargs,
	) -> "RasterizedMazeDataset":
		"""create a rasterized maze dataset from a config

		priority of loading:
		1. load from local
		2. download
		3. generate

		all arguments are forwarded to `MazeDataset.from_config`; the result is only re-typed via `typing.cast`
		"""
		return typing.cast(
			RasterizedMazeDataset,
			super().from_config(
				cfg=cfg,
				do_generate=do_generate,
				load_local=load_local,
				save_local=save_local,
				zanj=zanj,
				do_download=do_download,
				local_base_path=local_base_path,
				except_on_config_mismatch=except_on_config_mismatch,
				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
				verbose=verbose,
				**kwargs,
			),
		)

	@classmethod
	def from_config_augmented(
		cls,
		cfg: RasterizedMazeDatasetConfig,
		**kwargs,
	) -> "RasterizedMazeDataset":
		"""load a base dataset from the `MazeDatasetConfig` fields of `cfg`, then convert it

		the base dataset comes from `from_config` called with a plain `MazeDatasetConfig`;
		the rasterization options named in `_RASTERIZED_CFG_ADDED_PARAMS` are then copied
		from `cfg` onto the result via `from_base_MazeDataset`.
		`**kwargs` are forwarded to `from_config`.
		"""
		_cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
		return cls.from_base_MazeDataset(
			cls.from_config(cfg=_cfg_temp, **kwargs),
			added_params={
				k: v
				for k, v in cfg.serialize().items()
				if k in _RASTERIZED_CFG_ADDED_PARAMS
			},
		)

	@classmethod
	def from_base_MazeDataset(
		cls,
		base_dataset: MazeDataset,
		added_params: dict | None = None,
	) -> "RasterizedMazeDataset":
		"""convert an existing `MazeDataset` into a `RasterizedMazeDataset` with the same mazes

		# Parameters:
		- `base_dataset: MazeDataset`
			dataset whose config and `mazes` are reused (`mazes` is passed through as-is)
		- `added_params: dict | None`
			rasterization options merged over the serialized base config (these keys win on conflict).
			if `None`, uses `remove_isolated_cells=True, extend_pixels=True`; any rasterization
			field missing from both takes its `RasterizedMazeDatasetConfig` default
			(default: `None`)
		"""
		if added_params is None:
			added_params = dict(
				remove_isolated_cells=True,
				extend_pixels=True,
			)
		cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
			{
				**base_dataset.cfg.serialize(),
				**added_params,
			},
		)
		output: RasterizedMazeDataset = cls(
			cfg=cfg,
			mazes=base_dataset.mazes,
		)
		return output

	def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
		"""plot the first `count` mazes in the dataset: inputs on the top row, targets on the bottom row

		# Parameters:
		- `count: int | None`
			number of mazes to plot; `None` or `0` means all of them, and values
			larger than the dataset size are clamped to it
			(default: `None`)
		- `show: bool`
			whether to call `plt.show()` after plotting
			(default: `True`)

		# Returns:
		- `(fig, axes)` where `axes` is always a 2D array of shape `(2, count)`,
			or `None` if there is nothing to plot
		"""
		import matplotlib.pyplot as plt

		# `or` keeps treating `0` as "all"; `min` avoids indexing past the end
		count = min(count or len(self), len(self))
		if count == 0:
			print("No mazes to plot for dataset")
			return None

		# only index into the dataset once it is known to be non-empty
		print(f"{self[0][0].shape = }, {self[0][1].shape = }")

		# `squeeze=False` keeps `axes` 2D even when `count == 1`
		fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
		for i in range(count):
			pair = self[i]
			for row in range(2):
				ax = axes[row, i]
				ax.imshow(pair[row])
				# remove ticks
				ax.set_xticks([])
				ax.set_yticks([])

		if show:
			plt.show()

		return fig, axes
274
275
def make_numpy_collection(
	base_cfg: RasterizedMazeDatasetConfig,
	grid_sizes: list[int],
	from_config_kwargs: dict | None = None,
	verbose: bool = True,
	key_fmt: str = "{size}x{size}",
) -> dict[
	typing.Literal["configs", "arrays"],
	dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
	"""create a collection of configs and arrays for different grid sizes, in plain tensor form

	# Parameters:
	- `base_cfg: RasterizedMazeDatasetConfig`
		config copied for every grid size; only `grid_n` is overridden
	- `grid_sizes: list[int]`
		values of `grid_n` to build datasets for
	- `from_config_kwargs: dict | None`
		extra keyword arguments forwarded to `RasterizedMazeDataset.from_config_augmented`
		(default: `None`, meaning none)
	- `verbose: bool`
		print a message before building each dataset
		(default: `True`)
	- `key_fmt: str`
		format string for the output keys, formatted with `size=<grid size>`
		(default: `"{size}x{size}"`)

	# Returns:
	output is of structure:
	```
	{
		"configs": {
			"<n>x<n>": RasterizedMazeDatasetConfig,
			...
		},
		"arrays": {
			"<n>x<n>": np.ndarray,  # shape (2, n_mazes, x, y, 3)
			...
		},
	}
	```
	"""
	if from_config_kwargs is None:
		from_config_kwargs = {}

	datasets: dict[int, RasterizedMazeDataset] = {}

	for size in grid_sizes:
		if verbose:
			print(f"Generating dataset for maze size {size}...")

		# round-trip through serialization to get a new config object, so setting
		# `grid_n` below does not modify `base_cfg`
		cfg_temp: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
			base_cfg.serialize(),
		)
		cfg_temp.grid_n = size

		datasets[size] = RasterizedMazeDataset.from_config_augmented(
			cfg=cfg_temp,
			**from_config_kwargs,
		)

	return dict(
		configs={
			key_fmt.format(size=size): dataset.cfg for size, dataset in datasets.items()
		},
		arrays={
			# get_batch(None) returns a single array of shape (2, n_mazes, x, y, 3):
			# axis 0 is input/target, axis 1 is the maze index
			key_fmt.format(size=size): dataset.get_batch(None)
			for size, dataset in datasets.items()
		},
	)

def process_maze_rasterized_input_target( maze: maze_dataset.SolvedMaze, remove_isolated_cells: bool = True, extend_pixels: bool = True, endpoints_as_open: bool = False) -> jaxtyping.Float[ndarray, 'in/tgt=2 x y rgb=3']:
 70def process_maze_rasterized_input_target(
 71	maze: SolvedMaze,
 72	remove_isolated_cells: bool = True,
 73	extend_pixels: bool = True,
 74	endpoints_as_open: bool = False,
 75) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
 76	"""turn a single `SolvedMaze` into an array representation
 77
 78	has extra options for matching the format in https://github.com/aks2203/easy-to-hard
 79
 80	# Parameters:
 81	- `maze: SolvedMaze`
 82		the maze to process
 83	- `remove_isolated_cells: bool`
 84		whether to set isolated cells (no connections) to walls
 85		(default: `True`)
 86	- `extend_pixels: bool`
 87		whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
 88		(default: `True`)
 89	- `endpoints_as_open: bool`
 90		whether to set endpoints to open
 91		(default: `False`)
 92	"""
 93	# problem and solution mazes
 94	maze_pixels: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
 95	problem_maze: PixelGrid = maze_pixels.copy()
 96	solution_maze: PixelGrid = maze_pixels.copy()
 97
 98	# in problem maze, set path to open
 99	problem_maze[(problem_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN
100
101	# wherever solution maze is PixelColors.OPEN, set it to PixelColors.WALL
102	solution_maze[(solution_maze == PixelColors.OPEN).all(axis=-1)] = PixelColors.WALL
103	# wherever it is solution, set it to PixelColors.OPEN
104	solution_maze[(solution_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN
105	if endpoints_as_open:
106		for color in (PixelColors.START, PixelColors.END):
107			solution_maze[(solution_maze == color).all(axis=-1)] = PixelColors.OPEN
108
109	# postprocess to match original easy_2_hard dataset
110	if remove_isolated_cells:
111		problem_maze = _remove_isolated_cells(problem_maze)
112		solution_maze = _remove_isolated_cells(solution_maze)
113
114	if extend_pixels:
115		problem_maze = _extend_pixels(problem_maze)
116		solution_maze = _extend_pixels(solution_maze)
117
118	return np.array([problem_maze, solution_maze])

turn a single SolvedMaze into an array representation

has extra options for matching the format in https://github.com/aks2203/easy-to-hard

Parameters:

  • maze: SolvedMaze the maze to process
  • remove_isolated_cells: bool whether to set isolated cells (no connections) to walls (default: True)
  • extend_pixels: bool whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) (default: True)
  • endpoints_as_open: bool whether to set the start and end markers to open in the target (solution) image (default: False)
@serializable_dataclass
class RasterizedMazeDatasetConfig(maze_dataset.dataset.maze_dataset_config.MazeDatasetConfig):
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
	"""a `MazeDatasetConfig` with extra options which we then pass to `process_maze_rasterized_input_target`

	- `remove_isolated_cells: bool` whether to set isolated cells to walls (default: `True`)
	- `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) (default: `True`)
	- `endpoints_as_open: bool` whether to set the start/end markers to open in the target image (default: `False`)
	"""

	# these three fields are read by `RasterizedMazeDataset.__getitem__`;
	# the `type: ignore[misc]` above silences a mypy complaint about field defaults ordering
	remove_isolated_cells: bool = serializable_field(default=True)
	extend_pixels: bool = serializable_field(default=True)
	endpoints_as_open: bool = serializable_field(default=False)

adds options which we then pass to process_maze_rasterized_input_target

  • remove_isolated_cells: bool whether to set isolated cells to walls
  • extend_pixels: bool whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
  • endpoints_as_open: bool whether to set endpoints to open
RasterizedMazeDatasetConfig( remove_isolated_cells: bool = True, extend_pixels: bool = True, endpoints_as_open: bool = False, *, name: str, seq_len_min: int = 1, seq_len_max: int = 512, seed: int | None = 42, applied_filters: list[dict[typing.Literal['name', 'args', 'kwargs'], str | list | tuple | dict]] = <factory>, grid_n: int, n_mazes: int, maze_ctor: Callable = <function LatticeMazeGenerators.gen_dfs>, maze_ctor_kwargs: dict = <factory>, endpoint_kwargs: dict[typing.Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | None | list[tuple[int, int]]] = <factory>, _fname_loaded: str | None = None)
remove_isolated_cells: bool = True
extend_pixels: bool = True
endpoints_as_open: bool = False
def serialize(self) -> dict[str, typing.Any]:
        def serialize(self) -> dict[str, Any]:
            """serialize the dataclass to a dict

            the dict contains `_FORMAT_KEY` (the class name), one entry per field whose
            `field.serialize` is truthy, and one entry per name in `self._properties_to_serialize`.

            # Raises:
            - `NotSerializableFieldException` if a field is not a `SerializableField`
            - `FieldSerializationError` wrapping any error raised while serializing a field
            - `AttributeError` if a listed property does not exist on the class
            """
            result: dict[str, Any] = {
                _FORMAT_KEY: f"{self.__class__.__name__}(SerializableDataclass)"
            }
            # for each field in the class
            for field in dataclasses.fields(self):  # type: ignore[arg-type]
                # need it to be our special SerializableField
                if not isinstance(field, SerializableField):
                    raise NotSerializableFieldException(
                        f"Field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__} is not a `SerializableField`, "
                        f"but a {type(field)} "
                        "this state should be inaccessible, please report this bug!"
                    )

                # try to save it
                if field.serialize:
                    try:
                        # get the val
                        value = getattr(self, field.name)
                        # if it is a serializable dataclass, serialize it
                        if isinstance(value, SerializableDataclass):
                            value = value.serialize()
                        # if the value has a serialization function, use that
                        if hasattr(value, "serialize") and callable(value.serialize):
                            value = value.serialize()
                        # if the field has a serialization function, use that
                        # it would be nice to be able to override a class's `.serialize()`, but that could lead to some inconsistencies!
                        elif field.serialization_fn:
                            value = field.serialization_fn(value)

                        # store the value in the result
                        result[field.name] = value
                    except Exception as e:
                        # NOTE(review): if `getattr` above raised, `value` is either unbound
                        # (first field) or left over from a previous field, so the message
                        # built below can itself fail or show the wrong value
                        raise FieldSerializationError(
                            "\n".join(
                                [
                                    f"Error serializing field '{field.name}' on class {self.__class__.__module__}.{self.__class__.__name__}",
                                    f"{field = }",
                                    f"{value = }",
                                    f"{self = }",
                                ]
                            )
                        ) from e

            # store each property if we can get it
            for prop in self._properties_to_serialize:
                # `cls` is not a parameter of this function -- presumably the class
                # captured from the enclosing decorator; confirm against the full source
                if hasattr(cls, prop):
                    value = getattr(self, prop)
                    result[prop] = value
                else:
                    # NOTE(review): the first two message parts are joined without a space before "but"
                    raise AttributeError(
                        f"Cannot serialize property '{prop}' on class {self.__class__.__module__}.{self.__class__.__name__}"
                        + f"but it is in {self._properties_to_serialize = }"
                        + f"\n{self = }"
                    )

            return result

serialize the MazeDatasetConfig with all fields and fname

@classmethod
def load(cls, data: Union[dict[str, Any], ~T]) -> Type[~T]:
        @classmethod  # type: ignore[misc]
        def load(cls, data: dict[str, Any] | T) -> Type[T]:
            """build an instance of `cls` from a mapping (or return `data` if it already is one)

            for each init field present in `data`, the value is produced by, in order of preference:
            `field.deserialize_fn(value)`, `field.loading_fn(data)` (which receives the whole mapping),
            the field type hint's own `load(value)` method (only if `value` is a dict), or the raw value.
            keys in `data` that are not fields are ignored. after construction, field types are
            validated unless `on_typecheck_mismatch` is `ErrorMode.IGNORE`.

            NOTE(review): `on_typecheck_mismatch` and `on_typecheck_error` are not defined in this
            function -- presumably closure variables of the enclosing decorator; confirm.
            NOTE(review): the input checks use `assert`, which is stripped under `python -O`.
            """
            # HACK: this is kind of ugly, but it fixes a lot of issues for when we do recursive loading with ZANJ
            if isinstance(data, cls):
                return data

            assert isinstance(
                data, typing.Mapping
            ), f"When loading {cls.__name__ = } expected a Mapping, but got {type(data) = }:\n{data = }"

            cls_type_hints: dict[str, Any] = get_cls_type_hints(cls)

            # initialize dict for keeping what we will pass to the constructor
            ctor_kwargs: dict[str, Any] = dict()

            # iterate over the fields of the class
            for field in dataclasses.fields(cls):
                # check if the field is a SerializableField
                assert isinstance(
                    field, SerializableField
                ), f"Field '{field.name}' on class {cls.__name__} is not a SerializableField, but a {type(field)}. this state should be inaccessible, please report this bug!\nhttps://github.com/mivanit/muutils/issues/new"

                # check if the field is in the data and if it should be initialized
                # (fields absent from `data` fall back to the dataclass defaults)
                if (field.name in data) and field.init:
                    # get the value, we will be processing it
                    value: Any = data[field.name]

                    # get the type hint for the field
                    field_type_hint: Any = cls_type_hints.get(field.name, None)

                    # we rely on the init of `SerializableField` to check that only one of `loading_fn` and `deserialize_fn` is set
                    if field.deserialize_fn:
                        # if it has a deserialization function, use that
                        value = field.deserialize_fn(value)
                    elif field.loading_fn:
                        # if it has a loading function, use that
                        value = field.loading_fn(data)
                    elif (
                        field_type_hint is not None
                        and hasattr(field_type_hint, "load")
                        and callable(field_type_hint.load)
                    ):
                        # if no loading function but has a type hint with a load method, use that
                        if isinstance(value, dict):
                            value = field_type_hint.load(value)
                        else:
                            raise FieldLoadingError(
                                f"Cannot load value into {field_type_hint}, expected {type(value) = } to be a dict\n{value = }"
                            )
                    else:
                        # assume no loading needs to happen, keep `value` as-is
                        pass

                    # store the value in the constructor kwargs
                    ctor_kwargs[field.name] = value

            # create a new instance of the class with the constructor kwargs
            output: cls = cls(**ctor_kwargs)

            # validate the types of the fields if needed
            if on_typecheck_mismatch != ErrorMode.IGNORE:
                fields_valid: dict[str, bool] = (
                    SerializableDataclass__validate_fields_types__dict(
                        output,
                        on_typecheck_error=on_typecheck_error,
                    )
                )

                # if there are any fields that are not valid, raise an error
                if not all(fields_valid.values()):
                    msg: str = (
                        f"Type mismatch in fields of {cls.__name__}:\n"
                        + "\n".join(
                            [
                                f"{k}:\texpected {cls_type_hints[k] = }, but got value {getattr(output, k) = }, {type(getattr(output, k)) = }"
                                for k, v in fields_valid.items()
                                if not v
                            ]
                        )
                    )

                    # the error mode decides whether this raises, warns, or is ignored
                    on_typecheck_mismatch.process(
                        msg, except_cls=FieldTypeMismatchError
                    )

            # return the new instance
            return output

takes in an appropriately structured dict and returns an instance of the class, implemented by using @serializable_dataclass decorator

def validate_fields_types( self: muutils.json_serialize.serializable_dataclass.SerializableDataclass, on_typecheck_error: muutils.errormode.ErrorMode = ErrorMode.Except) -> bool:
def SerializableDataclass__validate_fields_types(
    self: SerializableDataclass,
    on_typecheck_error: ErrorMode = _DEFAULT_ON_TYPECHECK_ERROR,
) -> bool:
    """validate the types of all fields on a `SerializableDataclass`, returning `True` only if every field passes

    the per-field results come from `SerializableDataclass__validate_fields_types__dict`,
    which receives `on_typecheck_error` unchanged
    """
    per_field_results = SerializableDataclass__validate_fields_types__dict(
        self, on_typecheck_error=on_typecheck_error
    )
    return all(per_field_results.values())

validate the types of all the fields on a SerializableDataclass. calls SerializableDataclass__validate_field_type for each field

class RasterizedMazeDataset(MazeDataset):
	"""subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

	indexing returns a stacked `(input, target)` image pair for a maze, built by
	`process_maze_rasterized_input_target` with the options stored in `self.cfg`
	"""

	cfg: RasterizedMazeDatasetConfig

	# this override here is intentional: items are image pairs, not `SolvedMaze` objects
	def __getitem__(self, idx: int) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:  # type: ignore[override]
		"""get a single maze as a stacked `(input, target)` image array

		# Parameters:
		- `idx: int`
			index into `self.mazes`

		# Returns:
		- `Float[np.ndarray, "in/tgt=2 x y rgb=3"]`
			index `0` along the first axis is the input image, index `1` the target image
		"""
		solved_maze: SolvedMaze = self.mazes[idx]

		return process_maze_rasterized_input_target(
			maze=solved_maze,
			remove_isolated_cells=self.cfg.remove_isolated_cells,
			extend_pixels=self.cfg.extend_pixels,
			endpoints_as_open=self.cfg.endpoints_as_open,
		)

	def get_batch(
		self,
		idxs: list[int] | None,
	) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
		"""get a batch of mazes as a single array, from a list of indices

		# Parameters:
		- `idxs: list[int] | None`
			indices of the mazes to include; `None` means every maze in the dataset

		# Returns:
		- `Float[np.ndarray, "in/tgt=2 item x y rgb=3"]`
			axis 0 selects the input (`0`) or target (`1`) images, axis 1 is the position in `idxs`

		# Raises:
		- `ValueError` if there are no indices, since the image shape cannot be inferred
		"""
		if idxs is None:
			idxs = list(range(len(self)))

		if len(idxs) == 0:
			# fail with a clear message instead of an opaque tuple-unpacking error
			raise ValueError("cannot build a batch from an empty list of indices")

		# each item has shape (in/tgt, x, y, rgb); stacking along axis 1 puts the
		# input/target axis first, giving (in/tgt, item, x, y, rgb) as one contiguous array
		return np.stack([self[i] for i in idxs], axis=1)

	# override here is intentional
	@classmethod
	def from_config(
		cls,
		cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
		do_generate: bool = True,
		load_local: bool = True,
		save_local: bool = True,
		zanj: ZANJ | None = None,
		do_download: bool = True,
		local_base_path: Path = Path("data/maze_dataset"),
		except_on_config_mismatch: bool = True,
		allow_generation_metadata_filter_mismatch: bool = True,
		verbose: bool = False,
		**kwargs,
	) -> "RasterizedMazeDataset":
		"""create a rasterized maze dataset from a config

		priority of loading:
		1. load from local
		2. download
		3. generate

		all arguments are forwarded to `MazeDataset.from_config`; the result is only re-typed via `typing.cast`
		"""
		return typing.cast(
			RasterizedMazeDataset,
			super().from_config(
				cfg=cfg,
				do_generate=do_generate,
				load_local=load_local,
				save_local=save_local,
				zanj=zanj,
				do_download=do_download,
				local_base_path=local_base_path,
				except_on_config_mismatch=except_on_config_mismatch,
				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
				verbose=verbose,
				**kwargs,
			),
		)

	@classmethod
	def from_config_augmented(
		cls,
		cfg: RasterizedMazeDatasetConfig,
		**kwargs,
	) -> "RasterizedMazeDataset":
		"""load a base dataset from the `MazeDatasetConfig` fields of `cfg`, then convert it

		the base dataset comes from `from_config` called with a plain `MazeDatasetConfig`;
		the rasterization options named in `_RASTERIZED_CFG_ADDED_PARAMS` are then copied
		from `cfg` onto the result via `from_base_MazeDataset`.
		`**kwargs` are forwarded to `from_config`.
		"""
		_cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
		return cls.from_base_MazeDataset(
			cls.from_config(cfg=_cfg_temp, **kwargs),
			added_params={
				k: v
				for k, v in cfg.serialize().items()
				if k in _RASTERIZED_CFG_ADDED_PARAMS
			},
		)

	@classmethod
	def from_base_MazeDataset(
		cls,
		base_dataset: MazeDataset,
		added_params: dict | None = None,
	) -> "RasterizedMazeDataset":
		"""convert an existing `MazeDataset` into a `RasterizedMazeDataset` with the same mazes

		# Parameters:
		- `base_dataset: MazeDataset`
			dataset whose config and `mazes` are reused (`mazes` is passed through as-is)
		- `added_params: dict | None`
			rasterization options merged over the serialized base config (these keys win on conflict).
			if `None`, uses `remove_isolated_cells=True, extend_pixels=True`; any rasterization
			field missing from both takes its `RasterizedMazeDatasetConfig` default
			(default: `None`)
		"""
		if added_params is None:
			added_params = dict(
				remove_isolated_cells=True,
				extend_pixels=True,
			)
		cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
			{
				**base_dataset.cfg.serialize(),
				**added_params,
			},
		)
		output: RasterizedMazeDataset = cls(
			cfg=cfg,
			mazes=base_dataset.mazes,
		)
		return output

	def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
		"""plot the first `count` mazes in the dataset: inputs on the top row, targets on the bottom row

		# Parameters:
		- `count: int | None`
			number of mazes to plot; `None` or `0` means all of them, and values
			larger than the dataset size are clamped to it
			(default: `None`)
		- `show: bool`
			whether to call `plt.show()` after plotting
			(default: `True`)

		# Returns:
		- `(fig, axes)` where `axes` is always a 2D array of shape `(2, count)`,
			or `None` if there is nothing to plot
		"""
		import matplotlib.pyplot as plt

		# `or` keeps treating `0` as "all"; `min` avoids indexing past the end
		count = min(count or len(self), len(self))
		if count == 0:
			print("No mazes to plot for dataset")
			return None

		# only index into the dataset once it is known to be non-empty
		print(f"{self[0][0].shape = }, {self[0][1].shape = }")

		# `squeeze=False` keeps `axes` 2D even when `count == 1`
		fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
		for i in range(count):
			pair = self[i]
			for row in range(2):
				ax = axes[row, i]
				ax.imshow(pair[row])
				# remove ticks
				ax.set_xticks([])
				ax.set_yticks([])

		if show:
			plt.show()

		return fig, axes

subclass of MazeDataset that uses a RasterizedMazeDatasetConfig

def get_batch( self, idxs: list[int] | None) -> jaxtyping.Float[ndarray, 'in/tgt=2 item x y rgb=3']:
154	def get_batch(
155		self,
156		idxs: list[int] | None,
157	) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
158		"""get a batch of mazes as a tensor, from a list of indices"""
159		if idxs is None:
160			idxs = list(range(len(self)))
161
162		inputs: list[Float[np.ndarray, "x y rgb=3"]]
163		targets: list[Float[np.ndarray, "x y rgb=3"]]
164		inputs, targets = zip(*[self[i] for i in idxs], strict=False)  # type: ignore[assignment]
165
166		return np.array([inputs, targets])

get a batch of mazes as a tensor, from a list of indices

@classmethod
def from_config( cls, cfg: RasterizedMazeDatasetConfig | maze_dataset.MazeDatasetConfig, do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: zanj.zanj.ZANJ | None = None, do_download: bool = True, local_base_path: pathlib.Path = PosixPath('data/maze_dataset'), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs) -> RasterizedMazeDataset:
169	@classmethod
170	def from_config(
171		cls,
172		cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
173		do_generate: bool = True,
174		load_local: bool = True,
175		save_local: bool = True,
176		zanj: ZANJ | None = None,
177		do_download: bool = True,
178		local_base_path: Path = Path("data/maze_dataset"),
179		except_on_config_mismatch: bool = True,
180		allow_generation_metadata_filter_mismatch: bool = True,
181		verbose: bool = False,
182		**kwargs,
183	) -> "RasterizedMazeDataset":
184		"""create a rasterized maze dataset from a config
185
186		priority of loading:
187		1. load from local
188		2. download
189		3. generate
190
191		"""
192		return typing.cast(
193			RasterizedMazeDataset,
194			super().from_config(
195				cfg=cfg,
196				do_generate=do_generate,
197				load_local=load_local,
198				save_local=save_local,
199				zanj=zanj,
200				do_download=do_download,
201				local_base_path=local_base_path,
202				except_on_config_mismatch=except_on_config_mismatch,
203				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
204				verbose=verbose,
205				**kwargs,
206			),
207		)

create a rasterized maze dataset from a config

priority of loading:

  1. load from local
  2. download
  3. generate
@classmethod
def from_config_augmented( cls, cfg: RasterizedMazeDatasetConfig, **kwargs) -> RasterizedMazeDataset:
209	@classmethod
210	def from_config_augmented(
211		cls,
212		cfg: RasterizedMazeDatasetConfig,
213		**kwargs,
214	) -> "RasterizedMazeDataset":
215		"""loads either a maze transformer dataset or an easy_2_hard dataset"""
216		_cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
217		return cls.from_base_MazeDataset(
218			cls.from_config(cfg=_cfg_temp, **kwargs),
219			added_params={
220				k: v
221				for k, v in cfg.serialize().items()
222				if k in _RASTERIZED_CFG_ADDED_PARAMS
223			},
224		)

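loads a base MazeDataset using only the MazeDatasetConfig fields of cfg, then converts it into a RasterizedMazeDataset carrying the rasterization options (remove_isolated_cells, extend_pixels, endpoints_as_open) from cfg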

@classmethod
def from_base_MazeDataset( cls, base_dataset: maze_dataset.MazeDataset, added_params: dict | None = None) -> RasterizedMazeDataset:
226	@classmethod
227	def from_base_MazeDataset(
228		cls,
229		base_dataset: MazeDataset,
230		added_params: dict | None = None,
231	) -> "RasterizedMazeDataset":
232		"""loads either a maze transformer dataset or an easy_2_hard dataset"""
233		if added_params is None:
234			added_params = dict(
235				remove_isolated_cells=True,
236				extend_pixels=True,
237			)
238		cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
239			{
240				**base_dataset.cfg.serialize(),
241				**added_params,
242			},
243		)
244		output: RasterizedMazeDataset = cls(
245			cfg=cfg,
246			mazes=base_dataset.mazes,
247		)
248		return output

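converts an existing MazeDataset into a RasterizedMazeDataset with the same mazes, merging added_params (default: remove_isolated_cells=True, extend_pixels=True) into its config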

def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
250	def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
251		"""plot the first `count` mazes in the dataset"""
252		import matplotlib.pyplot as plt
253
254		print(f"{self[0][0].shape = }, {self[0][1].shape = }")
255		count = count or len(self)
256		if count == 0:
257			print("No mazes to plot for dataset")
258			return None
259		fig, axes = plt.subplots(2, count, figsize=(15, 5))
260		if count == 1:
261			axes = [axes]
262		for i in range(count):
263			axes[0, i].imshow(self[i][0])
264			axes[1, i].imshow(self[i][1])
265			# remove ticks
266			axes[0, i].set_xticks([])
267			axes[0, i].set_yticks([])
268			axes[1, i].set_xticks([])
269			axes[1, i].set_yticks([])
270
271		if show:
272			plt.show()
273
274		return fig, axes

plot the first count mazes in the dataset

def make_numpy_collection( base_cfg: RasterizedMazeDatasetConfig, grid_sizes: list[int], from_config_kwargs: dict | None = None, verbose: bool = True, key_fmt: str = '{size}x{size}') -> dict[typing.Literal['configs', 'arrays'], dict[str, RasterizedMazeDatasetConfig | numpy.ndarray]]:
def make_numpy_collection(
	base_cfg: RasterizedMazeDatasetConfig,
	grid_sizes: list[int],
	from_config_kwargs: dict | None = None,
	verbose: bool = True,
	key_fmt: str = "{size}x{size}",
) -> dict[
	typing.Literal["configs", "arrays"],
	dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
	"""create a collection of configs and arrays for different grid sizes, in plain tensor form

	# Parameters:
	- `base_cfg: RasterizedMazeDatasetConfig`
		config copied for every grid size; only `grid_n` is overridden
	- `grid_sizes: list[int]`
		values of `grid_n` to build datasets for
	- `from_config_kwargs: dict | None`
		extra keyword arguments forwarded to `RasterizedMazeDataset.from_config_augmented`
		(default: `None`, meaning none)
	- `verbose: bool`
		print a message before building each dataset
		(default: `True`)
	- `key_fmt: str`
		format string for the output keys, formatted with `size=<grid size>`
		(default: `"{size}x{size}"`)

	# Returns:
	output is of structure:
	```
	{
		"configs": {
			"<n>x<n>": RasterizedMazeDatasetConfig,
			...
		},
		"arrays": {
			"<n>x<n>": np.ndarray,  # shape (2, n_mazes, x, y, 3)
			...
		},
	}
	```
	"""
	if from_config_kwargs is None:
		from_config_kwargs = {}

	datasets: dict[int, RasterizedMazeDataset] = {}

	for size in grid_sizes:
		if verbose:
			print(f"Generating dataset for maze size {size}...")

		# round-trip through serialization to get a new config object, so setting
		# `grid_n` below does not modify `base_cfg`
		cfg_temp: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
			base_cfg.serialize(),
		)
		cfg_temp.grid_n = size

		datasets[size] = RasterizedMazeDataset.from_config_augmented(
			cfg=cfg_temp,
			**from_config_kwargs,
		)

	return dict(
		configs={
			key_fmt.format(size=size): dataset.cfg for size, dataset in datasets.items()
		},
		arrays={
			# get_batch(None) returns a single array of shape (2, n_mazes, x, y, 3):
			# axis 0 is input/target, axis 1 is the maze index
			key_fmt.format(size=size): dataset.get_batch(None)
			for size, dataset in datasets.items()
		},
	)

create a collection of configs and arrays for different grid sizes, in plain tensor form

output is of structure:

{
        "configs": {
                "<n>x<n>": RasterizedMazeDatasetConfig,
                ...
        },
        "arrays": {
                "<n>x<n>": np.ndarray,
                ...
        },
}