Coverage for maze_dataset/plotting/plot_maze.py: 81%

222 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-24 14:35 -0600

1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more""" 

2 

3from __future__ import annotations # for type hinting self as return value 

4 

5import warnings 

6from copy import deepcopy 

7from dataclasses import dataclass 

8from typing import Sequence 

9 

10import matplotlib as mpl 

11import matplotlib.pyplot as plt 

12import numpy as np 

13from jaxtyping import Bool, Float 

14 

15from maze_dataset.constants import Coord, CoordArray, CoordList 

16from maze_dataset.maze import ( 

17 LatticeMaze, 

18 SolvedMaze, 

19 TargetedLatticeMaze, 

20) 

21 

# a large-magnitude negative float constant
# NOTE(review): not referenced by the code visible in this module -- confirm it is used elsewhere
LARGE_NEGATIVE_NUMBER: float = -1.0e10

23 

24 

25@dataclass(kw_only=True) 

26class PathFormat: 

27 """formatting options for path plot""" 

28 

29 label: str | None = None 

30 fmt: str = "o" 

31 color: str | None = None 

32 cmap: str | None = None 

33 line_width: float | None = None 

34 quiver_kwargs: dict | None = None 

35 

36 def combine(self, other: PathFormat) -> PathFormat: 

37 """combine with other PathFormat object, overwriting attributes with non-None values. 

38 

39 returns a modified copy of self. 

40 """ 

41 output: PathFormat = deepcopy(self) 

42 for key, value in other.__dict__.items(): 

43 if key == "path": 

44 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 

45 raise ValueError( 

46 err_msg, 

47 ) 

48 if value is not None: 

49 setattr(output, key, value) 

50 

51 return output 

52 

53 

@dataclass(kw_only=False)
class StyledPath(PathFormat):
    """A `PathFormat` bundled with the path it formats.

    All formatting fields inherited from `PathFormat` are keyword-only, so
    `path` is the only field that may also be passed positionally.
    """

    # the path's coordinates, each in (row, col) order
    path: CoordArray

60 

61 

# default formatting that `process_path_input` applies to paths given as plain
# lists/arrays, looked up by its `_default_key` argument
DEFAULT_FORMATS: dict[str, PathFormat] = dict(
    true=PathFormat(label="true path", fmt="--", color="red", line_width=2.5),
    predicted=PathFormat(fmt=":", line_width=2, quiver_kwargs={"width": 0.015}),
)

78 

79 

def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    Formatting is layered, later sources overriding earlier ones:
    1. for a plain list/tuple/array `path`: `DEFAULT_FORMATS[_default_key]`
       (a `StyledPath` instead keeps its own formatting)
    2. the non-`None` attributes of `path_fmt`, if given
    3. every entry of `kwargs`, set directly as an attribute (even `None`)

    The result is always a new object; a `StyledPath` passed in is copied, so
    neither this function nor later changes to the result (such as the default
    label/color set by `MazePlot.add_predicted_path`) modify the caller's object.

    # Parameters:
    - `path : CoordList | CoordArray | StyledPath`
        the path, as a list/tuple of coords, an array of coords, or a `StyledPath`
    - `_default_key : str`
        key into `DEFAULT_FORMATS` used for plain (non-`StyledPath`) paths
    - `path_fmt : PathFormat | None`
        additional formatting to combine in
        (defaults to `None`)
    - `**kwargs`
        attributes set on the result as-is; keys are not validated

    # Returns:
    - `StyledPath`

    # Raises:
    - `TypeError` if `path` is not a `StyledPath`, `np.ndarray`, list or tuple
    - `KeyError` if `path` is not a `StyledPath` and `_default_key` is not in `DEFAULT_FORMATS`
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy so nothing below (or done later by callers) mutates the caller's path
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list, tuple)):
        # plain coordinates: wrap, then add default formatting
        styled_path = StyledPath(path=np.asarray(path)).combine(
            DEFAULT_FORMATS[_default_key],
        )
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(err_msg)

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path

115 

116 

# colors given, in order and wrapping around, to predicted paths that have no explicit color
DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
    "tab:orange", "tab:olive", "sienna",
    "mediumseagreen", "tab:purple", "slategrey",
]  # fmt: skip

125 

126 

class MazePlot:
    """Class for displaying mazes and paths

    Construct from a maze, optionally add paths (`add_true_path`,
    `add_predicted_path`, `add_multiple_paths`), a node-value heatmap
    (`add_node_values`) and markers (`mark_coords`), then call `plot()`.
    All of these return `self`, so calls can be chained.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """Create a plotter for `maze`.

        # Parameters:
        - `maze : LatticeMaze`
            the maze to draw. For a `SolvedMaze`, its solution is added as the
            true path; for a `TargetedLatticeMaze`, it is solved via
            `SolvedMaze.from_targeted_lattice_maze` and that solution is added.
        - `unit_length : int`
            side length in pixels of one "unit" of the image. A unit consists of a
            single node and its right and lower connection/wall. Walls are 1px
            thick, so e.g. `unit_length=14` yields a 13:1 ratio between node size
            and wall thickness.
            (defaults to `14`)
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        # set by `add_node_values`; selects heatmap rendering in `_plot_maze`
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # initialized here so these attributes always exist; `add_node_values` overwrites them
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # colorbar axis, reused when a colorbar is drawn again
        self.cbar_ax = None
        # (coord, kwargs for `Axes.plot`) pairs drawn at the end of `plot()`
        self.marked_coords: list[tuple[Coord, dict]] = []

        self.marker_kwargs_current: dict = dict(marker="s", color="green", ms=12)
        self.marker_kwargs_next: dict = dict(marker="P", color="green", ms=12)

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object, built from `self.maze` and the true path

        # Raises:
        - `ValueError` if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        Replaces any previously added true path. `path`, `path_fmt` and `kwargs`
        are handled by `process_path_input` with the `"true"` default format.
        Returns `self`.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive a predicted path plus formatting preferences and append it to `self.predicted_paths`.

        `path`, `path_fmt` and `kwargs` are handled by `process_path_input` with
        the `"predicted"` default format. If the result has no label or color,
        defaults depend on the number of predicted paths already stored: the
        label becomes `"predicted path {k}"` (1-based) and the color cycles
        through `DEFAULT_PREDICTED_PATH_COLORS`. Returns `self`.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths at once.

        Each entry of `path_list` may be a list of coords, an array of coords, or
        a `StyledPath`, and is passed to `add_predicted_path` with no extra
        formatting, so the default label/color scheme applies. Returns `self`.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray | None = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; shape must equal `self.maze.grid_shape`. NaN
            values are rejected when plotting.
        - `color_map : str`
            name of the matplotlib colormap used for the heatmap
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next` (a green "P")
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray | None`
            if given, each coord is marked with `self.marker_kwargs_current`
            (a green square)
            (defaults to `None`)
        - `colormap_center : float | None`
            if given, colors use a two-slope norm centered on this value, which
            must lie within the plotted value range
            (defaults to `None`)
        - `colormap_max : float | None`
            if given, overrides the upper limit of the color scale, e.g. to keep
            colorbars consistent across multiple plots
            (defaults to `None`)
        - `hide_colorbar : bool`
            if `True`, no colorbar is drawn
            (defaults to `False`)

        # Returns:
        - `MazePlot` — `self`
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # switch `_plot_maze` from grayscale rendering to the heatmap
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
        - `dpi : int`
            resolution of the new figure; only used when `fig_ax` is `None`
            (defaults to `100`)
        - `title : str`
            axis title; not set when `plain` is `True`
            (defaults to `""`)
        - `fig_ax : tuple | None`
            existing `(figure, axis)` pair to draw on; if `None`, a new figure
            with a single subplot is created
            (defaults to `None`)
        - `plain : bool`
            if `True`, remove ticks and labels and turn the axis off
            (defaults to `False`)

        # Returns:
        - `MazePlot` — `self`, with `self.fig` and `self.ax` set
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x is the column axis and y the row axis, so each gets its own
            # tick count -- they differ for non-square mazes
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)
        else:
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.set_xlabel("")
            self.ax.set_ylabel("")
            self.ax.axis("off")

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        The result is the pixel position of the node's center in the image.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`

        `kwargs` override these defaults and are passed to `Axes.plot` when
        `plot()` runs. Returns `self`.
        """
        kwargs = {"marker": "+", "color": "blue", **kwargs}
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """Draw `coords` (in (row, col) notation) onto `self.ax` via `Axes.plot(**kwargs)`."""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Render the maze image onto `self.ax`.

        - without node values: grayscale with `vmin=-1, vmax=1`, so walls (-1)
          are black and nodes (1) white
        - with node values: uses the `self.node_color_map` colormap, with NaN
          pixels (walls between nodes in this mode) drawn black. The color range
          always includes 0: if all values are positive the lower limit is 0, and
          if all are negative the upper limit is 0. `self.colormap_max`, if set,
          overrides the upper limit. If `self.colormap_center` is set, a
          `TwoSlopeNorm` centered on it is used. A colorbar is added unless
          `self.hide_colorbar`.

        # Raises:
        - `ValueError` if `self.colormap_center` lies outside the color limits
          (a center exactly on a limit is nudged by 1e-10 instead)
        """
        img = self._lattice_maze_to_img()

        if not self.custom_node_value_flag:
            # no node values have been passed: plain grayscale maze
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # explicit upper limit, e.g. for a consistent colorbar across plots.
            # compared against None so that `colormap_max=0.0` is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                # keep vmin < center < vmax for the two-slope norm below:
                # nudge a limit that coincides with the center, reject anything else
                if not (vals_min < self.colormap_center < vals_max):
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(err_msg)

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            if not self.hide_colorbar:
                self._add_colorbar(plotted, vals_min, vals_max)

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _add_colorbar(self, plotted, vals_min: float, vals_max: float) -> None:
        """Attach a colorbar for the image `plotted`, reusing `self.cbar_ax` if set.

        Ticks are 5 evenly spaced values from `vals_min` to `vals_max`, plus 0 and
        `self.colormap_center` whenever they lie strictly inside that range.
        """
        ticks = np.linspace(vals_min, vals_max, 5)

        if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
            ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

        if (
            self.colormap_center is not None
            and self.colormap_center not in ticks
            and vals_min < self.colormap_center < vals_max
        ):
            ticks = np.insert(
                ticks,
                np.searchsorted(ticks, self.colormap_center),
                self.colormap_center,
            )

        cbar = plt.colorbar(
            plotted,
            ticks=ticks,
            ax=self.ax,
            cax=self.cbar_ax,
        )
        self.cbar_ax = cbar.ax

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - with node values, nodes instead extend 1px over their right/lower
              boundary, and boundaries *without* a connection are set to NaN

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float matrix of shape (ul * n_rows + 1, ul * n_cols + 1).
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so connections (not walls) get drawn below
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # NaN is rendered black via `cmap.set_bad` in `_plot_maze`
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """Draw one styled path onto `self.ax`.

        - empty paths are skipped with a warning
        - if `quiver_kwargs` is set, arrows are drawn between consecutive
          coordinates, colored along `cmap` if given, else with `color`
        - otherwise a line is drawn using `fmt`, `line_width`, `color`, `label`

        The first coordinate is then marked with "o" and the last with "x".
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(err_msg) from e

            # Generate colors from the colormap, one per arrow
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`

        Without a true path, falls back to `self.maze.as_ascii(show_endpoints=...)`
        and `show_solution` is ignored.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)