Coverage for maze_dataset/plotting/plot_maze.py: 81%
222 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:35 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-24 14:35 -0600
1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""
3from __future__ import annotations # for type hinting self as return value
5import warnings
6from copy import deepcopy
7from dataclasses import dataclass
8from typing import Sequence
10import matplotlib as mpl
11import matplotlib.pyplot as plt
12import numpy as np
13from jaxtyping import Bool, Float
15from maze_dataset.constants import Coord, CoordArray, CoordList
16from maze_dataset.maze import (
17 LatticeMaze,
18 SolvedMaze,
19 TargetedLatticeMaze,
20)
# A very large negative float sentinel (-1e10).
LARGE_NEGATIVE_NUMBER: float = -1.0e10
25@dataclass(kw_only=True)
26class PathFormat:
27 """formatting options for path plot"""
29 label: str | None = None
30 fmt: str = "o"
31 color: str | None = None
32 cmap: str | None = None
33 line_width: float | None = None
34 quiver_kwargs: dict | None = None
36 def combine(self, other: PathFormat) -> PathFormat:
37 """combine with other PathFormat object, overwriting attributes with non-None values.
39 returns a modified copy of self.
40 """
41 output: PathFormat = deepcopy(self)
42 for key, value in other.__dict__.items():
43 if key == "path":
44 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
45 raise ValueError(
46 err_msg,
47 )
48 if value is not None:
49 setattr(output, key, value)
51 return output
# styled path: formatting options bundled together with the coordinates they apply to
@dataclass
class StyledPath(PathFormat):
    """A `PathFormat` that also carries the path it formats.

    `path` has no default, so it must always be supplied.
    """

    path: CoordArray
# Default formatting applied by `process_path_input` to plain (unstyled) paths,
# looked up by its `_default_key` argument.
DEFAULT_FORMATS: dict[str, PathFormat] = dict(
    # ground-truth solution: labelled, red, "--" line style
    true=PathFormat(
        label="true path",
        fmt="--",
        color="red",
        line_width=2.5,
        quiver_kwargs=None,
    ),
    # predicted paths: drawn with quiver arrows; label and color are left
    # unset here and filled in by `MazePlot.add_predicted_path`
    predicted=PathFormat(
        label=None,
        fmt=":",
        color=None,
        line_width=2,
        quiver_kwargs={"width": 0.015},
    ),
)
def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """Convert a path, which might be a list, tuple, or array of coords, into a new `StyledPath`.

    # Parameters:
    - `path : CoordList | CoordArray | StyledPath`
        The path to convert. A `StyledPath` keeps its own formatting. A plain
        list, tuple, or array is wrapped and given `DEFAULT_FORMATS[_default_key]`.
    - `_default_key : str`
        Key into `DEFAULT_FORMATS`, used only for plain (unstyled) paths.
    - `path_fmt : PathFormat | None`
        Extra formatting. Its non-`None` attributes override the above.
        (defaults to `None`)
    - `**kwargs`
        Attributes set directly on the result, overriding everything else.

    # Returns:
    - `StyledPath`
        Always a fresh object. A `StyledPath` passed in is copied, never
        modified in place.

    # Raises:
    - `TypeError` if `path` is not one of the supported types.
    - `ValueError` (from `PathFormat.combine`) if `path_fmt` has a `path` attribute.
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # Copy, so that applying `path_fmt`/`kwargs` below, or later edits by
        # `MazePlot.add_predicted_path`, never alter the caller's object.
        styled_path = deepcopy(path)
    elif isinstance(path, np.ndarray):
        styled_path = StyledPath(path=path)
        # add default formatting (`combine` returns a deep copy)
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    elif isinstance(path, (list, tuple)):
        styled_path = StyledPath(path=np.array(path))
        # add default formatting
        styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(
            err_msg,
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs (highest priority)
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path
# Colors handed out, in order and cycling, to predicted paths that do not
# specify their own color (see `MazePlot.add_predicted_path`).
DEFAULT_PREDICTED_PATH_COLORS: list[str] = list(
    (
        "tab:orange",
        "tab:olive",
        "sienna",
        "mediumseagreen",
        "tab:purple",
        "slategrey",
    )
)
class MazePlot:
    """Class for displaying mazes and paths.

    The `add_*` methods and `mark_coords` return `self`, so calls can be chained
    before calling `plot()`.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px.
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

        If `maze` is a `SolvedMaze`, its solution is added as the true path.
        Otherwise, if it is a `TargetedLatticeMaze`, it is solved via
        `SolvedMaze.from_targeted_lattice_maze` and that solution is added.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        # set by `add_node_values`; selects the heatmap branch in `_plot_maze`
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # also set by `add_node_values`; defaulted here so that these
        # attributes exist on every instance
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # colorbar axes, reused by later `plot()` calls
        self.cbar_ax = None
        # (coord, plot kwargs) pairs, drawn as markers by `plot()`
        self.marked_coords: list[tuple[Coord, dict]] = list()

        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        # a `SolvedMaze` already carries its solution
        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """Get the underlying `SolvedMaze` object, built from the maze and the true path.

        # Raises:
        - `ValueError` if no true path has been added.
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Set the true path, replacing any previous one, with optional formatting.

        Plain paths get the `"true"` default formatting. `path_fmt` and `kwargs`
        are applied on top, as in `process_path_input`.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive a predicted path and formatting preferences, and append it to `predicted_paths`.

        Plain paths get the `"predicted"` default formatting. Default label and
        color depend on the number of paths already saved:
        - missing label -> `"predicted path {i}"` (1-based)
        - missing color -> taken from `DEFAULT_PREDICTED_PATH_COLORS`, cycling
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths at once.

        Each element of `path_list` is passed to `add_predicted_path` without
        extra formatting. An element can therefore be either:
        1. a plain `CoordList` / `CoordArray`, which gets the default
           "predicted" formatting plus an automatic label and color, or
        2. a `StyledPath`, which keeps its own formatting (a missing label or
           color is still filled in automatically).
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray | None = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """Add node values to the maze for visualization as a heatmap.

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            One value per node. The shape must equal `self.maze.grid_shape`, and
            the values must not contain NaN (checked when plotting).
        - `color_map : str`
            Name of a matplotlib colormap, looked up in `mpl.colormaps`.
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            If given, marked with `self.marker_kwargs_next`.
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray | None`
            If given, each coord is marked with `self.marker_kwargs_current`.
            (defaults to `None`)
        - `colormap_center : float | None`
            If given, a `TwoSlopeNorm` centered on this value is used.
            (defaults to `None`)
        - `colormap_max : float | None`
            If given, overrides the upper end of the color scale, e.g. to keep
            colorbars consistent across plots.
            (defaults to `None`)
        - `hide_colorbar : bool`
            Skip drawing the colorbar.
            (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining.
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar
        self.target_token_coord = target_token_coord
        self.preceding_tokens_coords = preceeding_tokens_coords

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze, paths, and markers.

        If `fig_ax` is `None`, a new figure is created with the given `dpi`.
        Otherwise the `(fig, ax)` pair is drawn into. With `plain=True`, ticks,
        axis labels, the title, and the axis frame are omitted.
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            # x runs over columns and y over rows (see `_rowcol_to_coord`),
            # so they need separate tick arrays for non-square mazes
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)
        else:
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.set_xlabel("")
            self.ax.set_ylabel("")
            self.ax.axis("off")

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform a point from (row, column) notation to matplotlib (x, y), where x is the horizontal axis.

        The result is the pixel position of the node's center in the maze image.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        The default marker is a blue "+":
        `dict(marker="+", color="blue")`
        `kwargs` are passed to `ax.plot` and override these defaults.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """Draw the given (row, col) coords on `self.ax` with `ax.plot(**kwargs)`."""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define the colormap and plot the maze image.

        Without node values, the image is drawn in grayscale over [-1, 1].
        With node values (see `add_node_values`):
        - nodes are colored with `self.node_color_map`
        - NaN pixels (the walls between nodes) are drawn black
        - a colorbar is added unless `self.hide_colorbar` is set
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if not self.custom_node_value_flag:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, extend the range to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # optional fixed upper bound, for a consistent colorbar across
            # multiple plots; compared against None so that 0.0 is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                # TwoSlopeNorm needs vmin < vcenter < vmax; nudge an endpoint
                # that coincides with the center
                if not (vals_min < self.colormap_center < vals_max):
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # always show 0 and the colormap center, if inside the range
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                # use this plot's own figure rather than pyplot's current figure,
                # which may differ when `fig_ax` was passed to `plot()`
                cbar = self.fig.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build a float image to visualise the maze.

        Each "unit" consists of a node and its right and lower adjacent
        wall/connection. Its area is ul * ul. The background, and so every
        pixel not overwritten below, is -1.

        Without node values:
        - nodes have area (ul-1) * (ul-1) and value 1
        - connections have area 1 * (ul-1) and value `connection_val_scale`
          (0.93 by default)
        - walls stay at -1

        With node values (set via `.add_node_values()`):
        - each node is filled with its value and extends one extra pixel
          right/down over its connection slots
        - slots with no connection (walls) are set to NaN, which `_plot_maze`
          renders black

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns an array of shape (n_rows * ul + 1, n_cols * ul + 1).
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so "draw where not set" below means "draw connections"
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # NaN marks the walls; the connections are covered by the node spill-over
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: StyledPath) -> None:
        """Draw one styled path on `self.ax`.

        - If `quiver_kwargs` is set, the path is drawn as arrows between
          consecutive points, colored per arrow from `cmap` when that is set,
          otherwise with `color`.
        - Otherwise it is drawn as a line with `fmt`, followed by an "o" marker
          at the start and an "x" marker at the end.
        - Empty paths are skipped with a warning.
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """Wrapper for `self.solved_maze.as_ascii()`; shows the path if we have `self.true_path`.

        Without a true path, falls back to `self.maze.as_ascii(show_endpoints=...)`
        and `show_solution` is ignored.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)