Coverage for maze_dataset/dataset/rasterized.py: 77%
92 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-20 17:51 -0600
1"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze
3this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset
6see their paper:
8```bibtex
9@misc{schwarzschild2021learn,
10 title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks},
11 author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein},
12 year={2021},
13 eprint={2106.04537},
14 archivePrefix={arXiv},
15 primaryClass={cs.LG}
16}
17```
18"""
20import typing
21from pathlib import Path
23import numpy as np
24from jaxtyping import Float, Int
25from muutils.json_serialize import serializable_dataclass, serializable_field
26from zanj import ZANJ
28from maze_dataset import MazeDataset, MazeDatasetConfig
29from maze_dataset.maze import PixelColors, SolvedMaze
30from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells
def _extend_pixels(
    image: Int[np.ndarray, "x y rgb"],
    n_mult: int = 2,
    n_bdry: int = 1,
) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
    """upscale `image` by an integer factor and surround it with a wall-colored border

    # Parameters:
    - `image`
        pixel array; the first two axes are scaled and padded, the last axis is left untouched
    - `n_mult: int`
        how many times each pixel is repeated along each of the first two axes
        (default: `2`)
    - `n_bdry: int`
        border thickness, in pixels, added on every side of the first two axes
        (default: `1`)

    # Returns:
    a new array whose first two axes have length `n_mult * original + 2 * n_bdry`
    """
    # `np.pad` takes a single scalar fill value here, so the wall color
    # must have the same value in every channel
    wall_value: int = PixelColors.WALL[0]
    assert not any(channel != wall_value for channel in PixelColors.WALL), (
        "PixelColors.WALL must be a single value"
    )

    # blow every pixel up into an `n_mult x n_mult` square: rows first, then columns
    rows_repeated: np.ndarray = np.repeat(image, n_mult, axis=0)
    scaled: np.ndarray = np.repeat(rows_repeated, n_mult, axis=1)

    # same border on both spatial axes, no padding along the color axis
    border: tuple[int, int] = (n_bdry, n_bdry)
    return np.pad(
        scaled,
        pad_width=(border, border, (0, 0)),
        mode="constant",
        constant_values=wall_value,
    )
62_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [
63 "remove_isolated_cells",
64 "extend_pixels",
65 "endpoints_as_open",
66]
def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
    """rasterize one `SolvedMaze` into a stacked (input, target) pair of images

    the options exist to match the format used by https://github.com/aks2203/easy-to-hard

    the input image is the maze rendering with the solution path repainted as open space.
    the target image keeps only the solution: open pixels become wall, path pixels become open.

    # Parameters:
    - `maze: SolvedMaze`
        the maze to rasterize
    - `remove_isolated_cells: bool`
        whether to pass both images through `_remove_isolated_cells`
        (sets isolated cells, i.e. cells with no connections, to walls)
        (default: `True`)
    - `extend_pixels: bool`
        whether to pass both images through `_extend_pixels`
        (2x2 pixels per cell, extra 1 pixel row of wall around the maze, as in easy_2_hard)
        (default: `True`)
    - `endpoints_as_open: bool`
        whether to repaint the start and end pixels of the *target* image as open;
        the input image is not affected by this flag
        (default: `False`)

    # Returns:
    `np.array([input_image, target_image])`
    """

    def _recolor(grid: PixelGrid, old_color: tuple, new_color: tuple) -> None:
        "in-place: paint every pixel of `grid` that equals `old_color` with `new_color`"
        grid[(grid == old_color).all(axis=-1)] = new_color

    # one rendering with endpoints and solution, copied for independent editing
    rendered: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
    input_img: PixelGrid = rendered.copy()
    target_img: PixelGrid = rendered.copy()

    # input: hide the solution path
    _recolor(input_img, PixelColors.PATH, PixelColors.OPEN)

    # target: open -> wall must happen *before* path -> open,
    # otherwise the freshly opened path would be turned into wall as well
    _recolor(target_img, PixelColors.OPEN, PixelColors.WALL)
    _recolor(target_img, PixelColors.PATH, PixelColors.OPEN)
    if endpoints_as_open:
        _recolor(target_img, PixelColors.START, PixelColors.OPEN)
        _recolor(target_img, PixelColors.END, PixelColors.OPEN)

    # postprocess to match original easy_2_hard dataset;
    # each enabled step is applied to the input first, then to the target
    postprocess_steps: list = []
    if remove_isolated_cells:
        postprocess_steps.append(_remove_isolated_cells)
    if extend_pixels:
        postprocess_steps.append(_extend_pixels)
    for step in postprocess_steps:
        input_img = step(input_img)
        target_img = step(target_img)

    return np.array([input_img, target_img])
# TYPING: error: Attributes without a default cannot follow attributes with one  [misc]
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
    """`MazeDatasetConfig` plus the rasterization options consumed by `process_maze_rasterized_input_target`

    # Added fields:
    - `remove_isolated_cells: bool`
        whether to set isolated cells to walls
        (default: `True`)
    - `extend_pixels: bool`
        whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
        (default: `True`)
    - `endpoints_as_open: bool`
        whether to set endpoints to open
        (default: `False`)
    """

    # isolated cells -> walls
    remove_isolated_cells: bool = serializable_field(default=True)
    # upscale cells and add a wall border
    extend_pixels: bool = serializable_field(default=True)
    # repaint endpoints as open
    endpoints_as_open: bool = serializable_field(default=False)
class RasterizedMazeDataset(MazeDataset):
    """subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

    indexing returns a stacked (input, target) image pair instead of a `SolvedMaze`,
    see `process_maze_rasterized_input_target`
    """

    cfg: RasterizedMazeDatasetConfig

    # this override here is intentional
    def __getitem__(self, idx: int) -> Float[np.ndarray, "item in/tgt=2 x y rgb=3"]:  # type: ignore[override]
        """get a single maze, rasterized according to the options in `self.cfg`

        # Returns:
        the array produced by `process_maze_rasterized_input_target`:
        index `0` along the first axis is the input image, index `1` the target image
        """
        # get the solved maze
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self,
        idxs: list[int] | None,
    ) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
        """get a batch of mazes as a tensor, from a list of indices

        # Parameters:
        - `idxs: list[int] | None`
            indices of the mazes to include; `None` selects every maze in the dataset

        # Returns:
        a single array with the input/target axis first and the item axis second

        # Raises:
        - `ValueError` : if there is nothing to batch (empty `idxs`, or `None` on an empty dataset)
        """
        if idxs is None:
            idxs = list(range(len(self)))

        # unpacking an empty `zip` would fail with a cryptic
        # "not enough values to unpack" `ValueError`; raise the same type with a clear message
        if len(idxs) == 0:
            err_msg: str = "cannot build a batch from an empty list of indices"
            raise ValueError(err_msg)

        inputs: list[Float[np.ndarray, "x y rgb=3"]]
        targets: list[Float[np.ndarray, "x y rgb=3"]]
        # every item is an (input, target) pair; transpose the list of pairs into two sequences
        inputs, targets = zip(*[self[i] for i in idxs], strict=False)  # type: ignore[assignment]

        return np.array([inputs, targets])

    # override here is intentional
    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        every argument is forwarded unchanged to the parent `from_config`;
        this override only narrows the declared return type

        priority of loading:
        1. load from local
        2. download
        3. generate
        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """build a rasterized dataset from a `RasterizedMazeDatasetConfig`

        the config is re-loaded as a plain `MazeDatasetConfig` and handed to `from_config`;
        the rasterization options (the keys in `_RASTERIZED_CFG_ADDED_PARAMS`)
        are then re-attached via `from_base_MazeDataset`

        # Parameters:
        - `cfg: RasterizedMazeDatasetConfig`
            config including the rasterization options
        - `**kwargs`
            passed through to `from_config`
        """
        # serialize once and reuse the result for both the base config and the added options
        serialized_cfg: dict = cfg.serialize()
        added_params: dict = {
            k: v for k, v in serialized_cfg.items() if k in _RASTERIZED_CFG_ADDED_PARAMS
        }
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(serialized_cfg)
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params=added_params,
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> "RasterizedMazeDataset":
        """wrap an existing `MazeDataset` as a rasterized dataset

        # Parameters:
        - `base_dataset: MazeDataset`
            dataset whose config and mazes are reused; `base_dataset.mazes` is passed on as-is
        - `added_params: dict | None`
            rasterization options merged over the serialized base config
            (default: `None`, meaning `remove_isolated_cells=True, extend_pixels=True`)
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        # entries of `added_params` take precedence over same-named keys of the base config
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            },
        )
        output: RasterizedMazeDataset = cls(
            cfg=cfg,
            mazes=base_dataset.mazes,
        )
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """plot the first `count` mazes in the dataset

        inputs are drawn in the top row, targets in the bottom row, one column per maze

        # Parameters:
        - `count: int | None`
            number of mazes to plot; `None` or `0` means all of them
            (default: `None`)
        - `show: bool`
            whether to call `plt.show()` before returning
            (default: `True`)

        # Returns:
        `(fig, axes)` where `axes` is always a 2D array of shape `(2, count)`,
        or `None` if the dataset is empty
        """
        import matplotlib.pyplot as plt

        count = count or len(self)
        # check for emptiness *before* touching `self[0]`, otherwise an empty dataset
        # fails on indexing and this branch can never be reached
        if count == 0:
            print("No mazes to plot for dataset")
            return None

        first_pair = self[0]
        print(
            f"self[0][0].shape = {first_pair[0].shape!r}, "
            f"self[0][1].shape = {first_pair[1].shape!r}",
        )

        # `squeeze=False` keeps `axes` two-dimensional even when `count == 1`,
        # so `axes[row, col]` indexing below is always valid
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            # rasterize each maze once and reuse both images
            input_img, target_img = self[i]
            for ax, img in ((axes[0, i], input_img), (axes[1, i], target_img)):
                ax.imshow(img)
                # remove ticks
                ax.set_xticks([])
                ax.set_yticks([])

        if show:
            plt.show()

        return fig, axes
def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """create a collection of configs and arrays for different grid sizes, in plain tensor form

    for every entry of `grid_sizes`, a copy of `base_cfg` with `grid_n` set to that size
    is turned into a dataset via `RasterizedMazeDataset.from_config_augmented`

    # Parameters:
    - `base_cfg: RasterizedMazeDatasetConfig`
        template config; it is copied via serialize/load and not modified
    - `grid_sizes: list[int]`
        grid sizes to build datasets for
    - `from_config_kwargs: dict | None`
        extra keyword arguments for `from_config_augmented`
        (default: `None`, meaning no extra arguments)
    - `verbose: bool`
        whether to print a progress line per grid size
        (default: `True`)
    - `key_fmt: str`
        format string for the output keys, formatted with `size=<grid size>`
        (default: `"{size}x{size}"`)

    # Returns:
    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    extra_kwargs: dict = {} if from_config_kwargs is None else from_config_kwargs

    # first pass: build every dataset
    datasets_by_size: dict[int, RasterizedMazeDataset] = {}
    for size in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {size}...")

        # round-trip through serialization to get an independent copy of the base config
        sized_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize(),
        )
        sized_cfg.grid_n = size

        datasets_by_size[size] = RasterizedMazeDataset.from_config_augmented(
            cfg=sized_cfg,
            **extra_kwargs,
        )

    # second pass: collect all configs, then all arrays
    configs: dict = {}
    for size, dataset in datasets_by_size.items():
        configs[key_fmt.format(size=size)] = dataset.cfg

    arrays: dict = {}
    for size, dataset in datasets_by_size.items():
        # get_batch(None) batches the whole dataset into a single tensor of shape (2, n, x, y, 3)
        arrays[key_fmt.format(size=size)] = dataset.get_batch(None)

    return {"configs": configs, "arrays": arrays}