Coverage for maze_dataset/dataset/rasterized.py: 77%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-20 17:51 -0600

1"""a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze 

2 

3this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset 

4 

5 

6see their paper: 

7 

8```bibtex 

9@misc{schwarzschild2021learn, 

10 title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks}, 

11 author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein}, 

12 year={2021}, 

13 eprint={2106.04537}, 

14 archivePrefix={arXiv}, 

15 primaryClass={cs.LG} 

16} 

17``` 

18""" 

19 

20import typing 

21from pathlib import Path 

22 

23import numpy as np 

24from jaxtyping import Float, Int 

25from muutils.json_serialize import serializable_dataclass, serializable_field 

26from zanj import ZANJ 

27 

28from maze_dataset import MazeDataset, MazeDatasetConfig 

29from maze_dataset.maze import PixelColors, SolvedMaze 

30from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells 

31 

32 

def _extend_pixels(
    image: Int[np.ndarray, "x y rgb"],
    n_mult: int = 2,
    n_bdry: int = 1,
) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]:
    """upscale `image` and surround it with a border of wall-colored pixels

    # Parameters:
    - `image: Int[np.ndarray, "x y rgb"]`
        the pixel grid to extend
    - `n_mult: int`
        how many times each pixel is repeated along each of the first two axes
        (default: `2`)
    - `n_bdry: int`
        how many wall pixels to add on every side of the first two axes
        (default: `1`)

    # Returns:
    - the upscaled image, padded with the wall color; the last (color) axis is not padded
    """
    # the padding uses a single scalar, so every channel of the wall color must agree
    wall_value: int = PixelColors.WALL[0]
    assert all(channel == wall_value for channel in PixelColors.WALL), (
        "PixelColors.WALL must be a single value"
    )

    # repeat along the first axis, then along the second
    stretched_rows: np.ndarray = np.repeat(image, n_mult, axis=0)
    stretched: np.ndarray = np.repeat(stretched_rows, n_mult, axis=1)

    # same border width before and after, on both spatial axes; no padding on the color axis
    border: tuple[int, int] = (n_bdry, n_bdry)
    return np.pad(
        stretched,
        pad_width=(border, border, (0, 0)),
        mode="constant",
        constant_values=wall_value,
    )

60 

61 

# names of the fields that `RasterizedMazeDatasetConfig` declares itself;
# `RasterizedMazeDataset.from_config_augmented` keeps only these keys of the
# serialized config when building the `added_params` it passes on
_RASTERIZED_CFG_ADDED_PARAMS: list[str] = [
    "remove_isolated_cells",
    "extend_pixels",
    "endpoints_as_open",
]

67 

68 

def process_maze_rasterized_input_target(
    maze: SolvedMaze,
    remove_isolated_cells: bool = True,
    extend_pixels: bool = True,
    endpoints_as_open: bool = False,
) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:
    """rasterize a single `SolvedMaze` into a stacked (input, target) pair of images

    the extra options exist to match the format in https://github.com/aks2203/easy-to-hard

    # Parameters:
    - `maze: SolvedMaze`
        the maze to process
    - `remove_isolated_cells: bool`
        whether to set isolated cells (no connections) to walls
        (default: `True`)
    - `extend_pixels: bool`
        whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
        (default: `True`)
    - `endpoints_as_open: bool`
        whether to set endpoints to open in the target image
        (default: `False`)

    # Returns:
    - array whose index `0` is the input (problem) image and index `1` is the target (solution) image
    """
    # one rendering with endpoints and solution drawn; both images start as copies of it
    rendered: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True)
    problem_maze: PixelGrid = rendered.copy()
    solution_maze: PixelGrid = rendered.copy()

    # input image: hide the solution by painting the path as open space
    path_mask = (problem_maze == PixelColors.PATH).all(axis=-1)
    problem_maze[path_mask] = PixelColors.OPEN

    # target image: recolorings are applied strictly in this order --
    # open space becomes wall first, and only then does the path become open
    recolorings: list[tuple] = [
        (PixelColors.OPEN, PixelColors.WALL),
        (PixelColors.PATH, PixelColors.OPEN),
    ]
    if endpoints_as_open:
        recolorings.extend(
            (endpoint_color, PixelColors.OPEN)
            for endpoint_color in (PixelColors.START, PixelColors.END)
        )
    for old_color, new_color in recolorings:
        solution_maze[(solution_maze == old_color).all(axis=-1)] = new_color

    # postprocess to match original easy_2_hard dataset
    images: list[PixelGrid] = [problem_maze, solution_maze]
    if remove_isolated_cells:
        images = [_remove_isolated_cells(img) for img in images]
    if extend_pixels:
        images = [_extend_pixels(img) for img in images]

    return np.array(images)

118 

119 

# TYPING: error: Attributes without a default cannot follow attributes with one  [misc]
@serializable_dataclass
class RasterizedMazeDatasetConfig(MazeDatasetConfig):  # type: ignore[misc]
    """`MazeDatasetConfig` with extra options which we then pass to `process_maze_rasterized_input_target`

    - `remove_isolated_cells: bool` whether to set isolated cells to walls
        (default: `True`)
    - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze)
        (default: `True`)
    - `endpoints_as_open: bool` whether to set endpoints to open
        (default: `False`)
    """

    # the names of these three fields are also listed in `_RASTERIZED_CFG_ADDED_PARAMS`
    remove_isolated_cells: bool = serializable_field(default=True)
    extend_pixels: bool = serializable_field(default=True)
    endpoints_as_open: bool = serializable_field(default=False)

133 

134 

class RasterizedMazeDataset(MazeDataset):
    """subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`

    indexing returns the stacked (input, target) images produced by
    `process_maze_rasterized_input_target` rather than a `SolvedMaze`
    """

    cfg: RasterizedMazeDatasetConfig

    # this override here is intentional
    def __getitem__(self, idx: int) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]:  # type: ignore[override]
        """get a single maze as a stacked (input, target) pair of images

        the rasterization options are read from `self.cfg`
        """
        # get the solved maze
        solved_maze: SolvedMaze = self.mazes[idx]

        return process_maze_rasterized_input_target(
            maze=solved_maze,
            remove_isolated_cells=self.cfg.remove_isolated_cells,
            extend_pixels=self.cfg.extend_pixels,
            endpoints_as_open=self.cfg.endpoints_as_open,
        )

    def get_batch(
        self,
        idxs: list[int] | None,
    ) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]:
        """get a batch of mazes as a tensor, from a list of indices

        # Parameters:
        - `idxs: list[int] | None`
            indices of the mazes to include; `None` selects every maze in the dataset

        # Returns:
        - a single array whose first axis separates inputs (index `0`) from targets (index `1`),
            and whose second axis runs over the selected mazes

        # Raises:
        - `ValueError` : if no indices are selected (empty `idxs`, or `None` on an empty dataset),
            since there is nothing to unpack into inputs and targets
        """
        if idxs is None:
            idxs = list(range(len(self)))

        inputs: list[Float[np.ndarray, "x y rgb=3"]]
        targets: list[Float[np.ndarray, "x y rgb=3"]]
        # each item is an (input, target) pair; transpose the list of pairs into a pair of sequences
        inputs, targets = zip(*[self[i] for i in idxs], strict=False)  # type: ignore[assignment]

        return np.array([inputs, targets])

    # override here is intentional
    @classmethod
    def from_config(
        cls,
        cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig,  # type: ignore[override]
        do_generate: bool = True,
        load_local: bool = True,
        save_local: bool = True,
        zanj: ZANJ | None = None,
        do_download: bool = True,
        local_base_path: Path = Path("data/maze_dataset"),
        except_on_config_mismatch: bool = True,
        allow_generation_metadata_filter_mismatch: bool = True,
        verbose: bool = False,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """create a rasterized maze dataset from a config

        every argument is forwarded unchanged to the parent class's `from_config`;
        this override only narrows the declared return type

        priority of loading:
        1. load from local
        2. download
        3. generate

        """
        return typing.cast(
            RasterizedMazeDataset,
            super().from_config(
                cfg=cfg,
                do_generate=do_generate,
                load_local=load_local,
                save_local=save_local,
                zanj=zanj,
                do_download=do_download,
                local_base_path=local_base_path,
                except_on_config_mismatch=except_on_config_mismatch,
                allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
                verbose=verbose,
                **kwargs,
            ),
        )

    @classmethod
    def from_config_augmented(
        cls,
        cfg: RasterizedMazeDatasetConfig,
        **kwargs,
    ) -> "RasterizedMazeDataset":
        """build a rasterized dataset from a `RasterizedMazeDatasetConfig`

        the mazes are obtained via `from_config` using a plain `MazeDatasetConfig`
        loaded from the serialized `cfg`; the result is then wrapped by
        `from_base_MazeDataset` together with the rasterization options of `cfg`
        (the keys listed in `_RASTERIZED_CFG_ADDED_PARAMS`)

        # Parameters:
        - `cfg: RasterizedMazeDatasetConfig`
            config to build the dataset from
        - `**kwargs`
            passed on to `from_config`
        """
        _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize())
        return cls.from_base_MazeDataset(
            cls.from_config(cfg=_cfg_temp, **kwargs),
            added_params={
                k: v
                for k, v in cfg.serialize().items()
                if k in _RASTERIZED_CFG_ADDED_PARAMS
            },
        )

    @classmethod
    def from_base_MazeDataset(
        cls,
        base_dataset: MazeDataset,
        added_params: dict | None = None,
    ) -> "RasterizedMazeDataset":
        """wrap the mazes of an existing `MazeDataset` in a rasterized dataset

        # Parameters:
        - `base_dataset: MazeDataset`
            dataset whose `mazes` are reused and whose serialized `cfg` forms the base of the new config
        - `added_params: dict | None`
            entries merged over the serialized base config before loading it as a
            `RasterizedMazeDatasetConfig`; entries here win over the base config
            (default: `None`, meaning `remove_isolated_cells=True, extend_pixels=True`)
        """
        if added_params is None:
            added_params = dict(
                remove_isolated_cells=True,
                extend_pixels=True,
            )
        cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            {
                **base_dataset.cfg.serialize(),
                **added_params,
            },
        )
        output: RasterizedMazeDataset = cls(
            cfg=cfg,
            mazes=base_dataset.mazes,
        )
        return output

    def plot(self, count: int | None = None, show: bool = True) -> tuple | None:
        """plot the first `count` mazes in the dataset

        inputs are drawn in the top row and targets in the bottom row, one column per maze

        # Parameters:
        - `count: int | None`
            number of mazes to plot; `None` (or `0`) plots all of them,
            and values larger than the dataset are clamped to its length
            (default: `None`)
        - `show: bool`
            whether to call `plt.show()` after drawing
            (default: `True`)

        # Returns:
        - `(fig, axes)` where `axes` is always a 2D array of shape `(2, count)`,
            or `None` if the dataset has no mazes to plot
        """
        import matplotlib.pyplot as plt

        n_mazes: int = len(self)
        # a falsy `count` means "everything"; never request more columns than there are mazes
        count = min(count or n_mazes, n_mazes)
        if count == 0:
            # checked before touching `self[0]`, which does not exist for an empty dataset
            print("No mazes to plot for dataset")
            return None

        print(f"{self[0][0].shape = }, {self[0][1].shape = }")

        # `squeeze=False` keeps `axes` two-dimensional even for a single column,
        # so the `axes[row, col]` indexing below works for every `count`
        fig, axes = plt.subplots(2, count, figsize=(15, 5), squeeze=False)
        for i in range(count):
            # rasterize each maze once and reuse both images
            input_img, target_img = self[i]
            axes[0, i].imshow(input_img)
            axes[1, i].imshow(target_img)
            # remove ticks
            for ax in (axes[0, i], axes[1, i]):
                ax.set_xticks([])
                ax.set_yticks([])

        if show:
            plt.show()

        return fig, axes

274 

275 

def make_numpy_collection(
    base_cfg: RasterizedMazeDatasetConfig,
    grid_sizes: list[int],
    from_config_kwargs: dict | None = None,
    verbose: bool = True,
    key_fmt: str = "{size}x{size}",
) -> dict[
    typing.Literal["configs", "arrays"],
    dict[str, RasterizedMazeDatasetConfig | np.ndarray],
]:
    """create a collection of configs and arrays for different grid sizes, in plain tensor form

    # Parameters:
    - `base_cfg: RasterizedMazeDatasetConfig`
        template config; a fresh config is loaded from its serialized form for every size
    - `grid_sizes: list[int]`
        values assigned to `grid_n`, one dataset per size
    - `from_config_kwargs: dict | None`
        extra keyword arguments for `RasterizedMazeDataset.from_config_augmented`
        (default: `None`, meaning no extra arguments)
    - `verbose: bool`
        whether to print a progress line per size
        (default: `True`)
    - `key_fmt: str`
        format string for the output keys, formatted with `size=<grid size>`
        (default: `"{size}x{size}"`)

    # Returns:
    output is of structure:
    ```
    {
        "configs": {
            "<n>x<n>": RasterizedMazeDatasetConfig,
            ...
        },
        "arrays": {
            "<n>x<n>": np.ndarray,
            ...
        },
    }
    ```
    """
    extra_kwargs: dict = {} if from_config_kwargs is None else from_config_kwargs

    # build every dataset first; arrays are only materialized afterwards
    datasets_by_size: dict[int, RasterizedMazeDataset] = {}
    for size in grid_sizes:
        if verbose:
            print(f"Generating dataset for maze size {size}...")

        # load a new config object from the serialized base, so `grid_n` is set on it and not on `base_cfg`
        size_cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load(
            base_cfg.serialize(),
        )
        size_cfg.grid_n = size

        datasets_by_size[size] = RasterizedMazeDataset.from_config_augmented(
            cfg=size_cfg,
            **extra_kwargs,
        )

    configs = {
        key_fmt.format(size=size): dataset.cfg
        for size, dataset in datasets_by_size.items()
    }
    # get_batch(None) returns one array covering the whole dataset,
    # with inputs/targets on the first axis and mazes on the second
    arrays = {
        key_fmt.format(size=size): dataset.get_batch(None)
        for size, dataset in datasets_by_size.items()
    }
    return {"configs": configs, "arrays": arrays}

330 )