maze_dataset.plotting.plot_maze
provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more
"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""

from __future__ import annotations  # for type hinting self as return value

import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Sequence

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Bool, Float

from maze_dataset.constants import Coord, CoordArray, CoordList
from maze_dataset.maze import (
    LatticeMaze,
    SolvedMaze,
    TargetedLatticeMaze,
)

# NOTE(review): not referenced anywhere in this module
LARGE_NEGATIVE_NUMBER: float = -1e10


@dataclass(kw_only=True)
class PathFormat:
    """formatting options for path plot

    every field is keyword-only. `combine` treats `None` as "unset", but note that
    `fmt` defaults to `"o"` rather than `None`, so it is always copied by `combine`.
    """

    # label passed to `ax.plot` when the path is drawn as a line
    label: str | None = None
    # matplotlib format string used when the path is drawn as a line
    fmt: str = "o"
    color: str | None = None
    # colormap name, used to color the arrows when the path is drawn with `quiver`
    cmap: str | None = None
    line_width: float | None = None
    # when not None, the path is drawn as arrows via `ax.quiver(**quiver_kwargs)`
    quiver_kwargs: dict | None = None

    def combine(self, other: PathFormat) -> PathFormat:
        """combine with other PathFormat object, overwriting attributes with non-None values.

        returns a modified copy of self; `self` itself is left untouched.

        # Raises:
        - `ValueError` : if `other` has a `path` attribute (e.g. it is a `StyledPath`)
        """
        output: PathFormat = deepcopy(self)
        for key, value in other.__dict__.items():
            if key == "path":
                err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
                raise ValueError(
                    err_msg,
                )
            if value is not None:
                setattr(output, key, value)

        return output


# styled path
@dataclass
class StyledPath(PathFormat):
    "a `StyledPath` is a `PathFormat` with a specific path"

    # the coordinates of the path, in (row, col) notation
    path: CoordArray


# default formatting applied by `process_path_input` to paths given as plain
# lists/arrays, keyed by its `_default_key` argument
DEFAULT_FORMATS: dict[str, PathFormat] = {
    "true": PathFormat(
        label="true path",
        fmt="--",
        color="red",
        line_width=2.5,
        quiver_kwargs=None,
    ),
    "predicted": PathFormat(
        label=None,
        fmt=":",
        color=None,
        line_width=2,
        quiver_kwargs={"width": 0.015},
    ),
}


def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    formatting is layered in this order, later layers winning:
    1. for a plain list/array: `DEFAULT_FORMATS[_default_key]`
       (a `StyledPath` input instead keeps its own formatting)
    2. the non-`None` attributes of `path_fmt`, if given
    3. every entry of `kwargs`, set as an attribute on the result

    the returned object is always new: a `StyledPath` input is copied, so neither
    this function nor callers that further modify the result (e.g.
    `MazePlot.add_predicted_path` filling in a label and color) change the object
    the caller passed in.

    # Raises:
    - `TypeError` : if `path` is not a `list`, `np.ndarray`, or `StyledPath`
    - `ValueError` : if `path_fmt` carries a `path` attribute (see `PathFormat.combine`)
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy instead of aliasing: the `setattr`s below would otherwise mutate the
        # caller's object whenever `path_fmt` is None
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list)):
        coords = path if isinstance(path, np.ndarray) else np.array(path)
        # add default formatting
        styled_path = StyledPath(path=coords).combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(
            err_msg,
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path


# colors cycled through by `MazePlot.add_predicted_path` for paths without a color
DEFAULT_PREDICTED_PATH_COLORS: list[str] = [
    "tab:orange",
    "tab:olive",
    "sienna",
    "mediumseagreen",
    "tab:purple",
    "slategrey",
]


class MazePlot:
    """Class for displaying mazes and paths

    build up the plot with `add_true_path`, `add_predicted_path`,
    `add_multiple_paths`, `add_node_values` and `mark_coords` (each returns `self`,
    so calls can be chained), then draw everything with `plot`.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        if `maze` is a `SolvedMaze`, its `solution` becomes the true path; otherwise,
        if it is a `TargetedLatticeMaze`, the solution of
        `SolvedMaze.from_targeted_lattice_maze(maze)` becomes the true path.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        # set by `add_node_values`; selects the colormap branch in `_plot_maze`
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # normally set by `add_node_values`; initialized here so that every
        # attribute `_plot_maze` reads exists on a freshly constructed instance
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # colorbar axes; created on the first colorbar draw and reused afterwards
        self.cbar_ax = None
        # (coord, kwargs for `ax.plot`) pairs drawn as markers by `plot`
        self.marked_coords: list[tuple[Coord, dict]] = list()

        # marker style for coords passed as `preceeding_tokens_coords`
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        # marker style for the coord passed as `target_token_coord`
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        # Raises:
        - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        replaces any previously added true path. `path`, `path_fmt` and `kwargs`
        are handled by `process_path_input` using the `"true"` default format.
        returns `self`, for chaining.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive predicted path and formatting preferences from input and save in predicted_path list.

        Default formatting depends on number of paths already saved in predicted path list:
        a path without a label is labelled `"predicted path {k}"` (k counting from 1),
        and a path without a color gets the next entry of `DEFAULT_PREDICTED_PATH_COLORS`,
        wrapping around. returns `self`, for chaining.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths to the plot at once.

        each element of `path_list` may be a `CoordList`, a `CoordArray`, or a
        `StyledPath`, and is passed to `add_predicted_path` with no extra
        formatting, so default labels and colors are assigned in order.
        returns `self`, for chaining.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; must have shape `self.maze.grid_shape`. `plot`
            fails with an `AssertionError` if any value is NaN
        - `color_map : str`
            name of a matplotlib colormap, looked up in `mpl.colormaps`
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next` (a green "P")
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray`
            if given, each coord is marked with `self.marker_kwargs_current`
            (a green square) (defaults to `None`)
        - `colormap_center : float | None`
            if given, colors use a `TwoSlopeNorm` centered on this value; it must
            lie within the color range or `plot` raises a `ValueError`
            (defaults to `None`)
        - `colormap_max : float | None`
            if given, overrides the upper end of the color range, e.g. to keep
            colorbars consistent across several plots (defaults to `None`)
        - `hide_colorbar : bool`
            if `True`, no colorbar is drawn (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining

        # Raises:
        - `AssertionError` : if `node_values` does not have shape `self.maze.grid_shape`
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )
        # assert np.min(node_values) >= 0, "Please pass non-negative node values only."

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
        - `dpi : int`
            resolution of the new figure; unused when `fig_ax` is given
            (defaults to `100`)
        - `title : str`
            axes title; not drawn when `plain` is `True` (defaults to `""`)
        - `fig_ax : tuple | None`
            an existing `(figure, axes)` pair to draw into; if `None`, a new
            figure with a single subplot is created (defaults to `None`)
        - `plain : bool`
            if `True`, hide ticks, axis labels and the axis frame
            (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining; the figure and axes are kept in `self.fig` and `self.ax`
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x runs along columns and y along rows, so each axis gets its own
            # tick count (they differ for non-square mazes)
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)
        else:
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.set_xlabel("")
            self.ax.set_ylabel("")
            self.ax.axis("off")

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        the result is the pixel position of the center of that node's unit.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`

        `kwargs` override those defaults and are passed to `ax.plot` when `plot`
        runs. returns `self`, for chaining.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """draw `coords` (in (row, col) notation) on `self.ax` via `ax.plot(**kwargs)`"""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        without node values, the image is drawn in grayscale over [-1, 1], so walls
        (value -1) are black and nodes (value 1) are white. with node values, the
        chosen colormap is used over a range extended to include 0 (before any
        `colormap_max` override), walls (NaN in the image) are drawn black, and a
        colorbar is added unless `hide_colorbar` is set.
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if not self.custom_node_value_flag:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compared against None (not truthiness) so an explicit 0.0 is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # a center sitting exactly on an end of the range is nudged,
                    # since TwoSlopeNorm requires vmin < vcenter < vmax
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # make sure 0 gets a tick when the range spans it
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                # likewise for the colormap center
                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - take node_value if passed via .add_node_values()

        when node values are set, each node is drawn one pixel larger (covering its
        right/lower boundary, so connections take the node's value) and walls are
        then set to NaN, which `_plot_maze` renders black.

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float array of shape
        `(grid_shape[0] * ul + 1, grid_shape[1] * ul + 1)`.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Bool[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so the `not` checks below draw a connection
            # wherever the maze has one
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # connections are already covered by the enlarged nodes; the `not`
            # checks below paint NaN wherever there is a wall. dtype=float so NaN
            # survives even when `node_values` is an integer array
            connection_values = np.full_like(scaled_node_values, np.nan, dtype=float)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: PathFormat) -> None:
        """draw a single `StyledPath` on `self.ax`

        if `path_format.quiver_kwargs` is set, the path is drawn as arrows between
        consecutive coords (colored along `path_format.cmap` if given, otherwise
        with `path_format.color`); otherwise it is drawn with `ax.plot` using
        `path_format.fmt`. in both cases an "o" marks the start and an "x" the end.
        empty paths are skipped with a warning.
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`

        without a true path this falls back to `self.maze.as_ascii`, and
        `show_solution` is ignored.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)
26@dataclass(kw_only=True) 27class PathFormat: 28 """formatting options for path plot""" 29 30 label: str | None = None 31 fmt: str = "o" 32 color: str | None = None 33 cmap: str | None = None 34 line_width: float | None = None 35 quiver_kwargs: dict | None = None 36 37 def combine(self, other: PathFormat) -> PathFormat: 38 """combine with other PathFormat object, overwriting attributes with non-None values. 39 40 returns a modified copy of self. 41 """ 42 output: PathFormat = deepcopy(self) 43 for key, value in other.__dict__.items(): 44 if key == "path": 45 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 46 raise ValueError( 47 err_msg, 48 ) 49 if value is not None: 50 setattr(output, key, value) 51 52 return output
formatting options for path plot
def combine(self, other: PathFormat) -> PathFormat:
    """Merge `other` into a copy of `self`; only non-`None` attributes of `other` win.

    the original `self` is never modified.

    # Raises:
    - `ValueError` : if `other` has a `path` attribute (e.g. it is a `StyledPath`)
    """
    overrides = other.__dict__
    # refuse up front: a partially-updated copy would be discarded anyway
    if "path" in overrides:
        err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
        raise ValueError(err_msg)
    result: PathFormat = deepcopy(self)
    for name, val in overrides.items():
        if val is not None:
            setattr(result, name, val)
    return result
combine with other PathFormat object, overwriting attributes with non-None values.
returns a modified copy of self.
@dataclass
class StyledPath(PathFormat):
    """a `StyledPath` is a `PathFormat` bundled with the path it formats

    `path` is the only field without a default, so it is the one required
    argument; the formatting fields inherited from `PathFormat` stay keyword-only.
    """

    # the coordinates of the path, in (row, col) notation
    path: CoordArray
a `StyledPath` is a `PathFormat` with a specific path
Inherited Members
def process_path_input(
    path: CoordList | CoordArray | StyledPath,
    _default_key: str,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> StyledPath:
    """convert a path, which might be a list or array of coords, into a `StyledPath`

    formatting is layered in this order, later layers winning:
    1. for a plain list/array: `DEFAULT_FORMATS[_default_key]`
       (a `StyledPath` input instead keeps its own formatting)
    2. the non-`None` attributes of `path_fmt`, if given
    3. every entry of `kwargs`, set as an attribute on the result

    the returned object is always new: a `StyledPath` input is copied, so neither
    this function nor callers that further modify the result (e.g.
    `MazePlot.add_predicted_path` filling in a label and color) change the object
    the caller passed in.

    # Raises:
    - `TypeError` : if `path` is not a `list`, `np.ndarray`, or `StyledPath`
    - `ValueError` : if `path_fmt` carries a `path` attribute (see `PathFormat.combine`)
    """
    styled_path: StyledPath
    if isinstance(path, StyledPath):
        # copy instead of aliasing: the `setattr`s below would otherwise mutate the
        # caller's object whenever `path_fmt` is None
        styled_path = deepcopy(path)
    elif isinstance(path, (np.ndarray, list)):
        coords = path if isinstance(path, np.ndarray) else np.array(path)
        # add default formatting
        styled_path = StyledPath(path=coords).combine(DEFAULT_FORMATS[_default_key])
    else:
        err_msg: str = (
            f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
        )
        raise TypeError(
            err_msg,
        )

    # add formatting from path_fmt
    if path_fmt is not None:
        styled_path = styled_path.combine(path_fmt)

    # add formatting from kwargs
    for key, value in kwargs.items():
        setattr(styled_path, key, value)

    return styled_path
convert a path, which might be a list or array of coords, into a StyledPath
class MazePlot:
    """Class for displaying mazes and paths

    build up the plot with `add_true_path`, `add_predicted_path`,
    `add_multiple_paths`, `add_node_values` and `mark_coords` (each returns `self`,
    so calls can be chained), then draw everything with `plot`.
    """

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields 13:1 ratio between node size and wall thickness

        if `maze` is a `SolvedMaze`, its `solution` becomes the true path; otherwise,
        if it is a `TargetedLatticeMaze`, the solution of
        `SolvedMaze.from_targeted_lattice_maze(maze)` becomes the true path.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
        # set by `add_node_values`; selects the colormap branch in `_plot_maze`
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord | None = None
        self.preceding_tokens_coords: CoordArray | None = None
        self.colormap_center: float | None = None
        # normally set by `add_node_values`; initialized here so that every
        # attribute `_plot_maze` reads exists on a freshly constructed instance
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        # colorbar axes; created on the first colorbar draw and reused afterwards
        self.cbar_ax = None
        # (coord, kwargs for `ax.plot`) pairs drawn as markers by `plot`
        self.marked_coords: list[tuple[Coord, dict]] = list()

        # marker style for coords passed as `preceeding_tokens_coords`
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        # marker style for the coord passed as `target_token_coord`
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        """get the underlying `SolvedMaze` object

        # Raises:
        - `ValueError` : if no true path has been added
        """
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """add a true path to the maze with optional formatting

        replaces any previously added true path. `path`, `path_fmt` and `kwargs`
        are handled by `process_path_input` using the `"true"` default format.
        returns `self`, for chaining.
        """
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive predicted path and formatting preferences from input and save in predicted_path list.

        Default formatting depends on number of paths already saved in predicted path list:
        a path without a label is labelled `"predicted path {k}"` (k counting from 1),
        and a path without a color gets the next entry of `DEFAULT_PREDICTED_PATH_COLORS`,
        wrapping around. returns `self`, for chaining.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths to the plot at once.

        each element of `path_list` may be a `CoordList`, a `CoordArray`, or a
        `StyledPath`, and is passed to `add_predicted_path` with no extra
        formatting, so default labels and colors are assigned in order.
        returns `self`, for chaining.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
        - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; must have shape `self.maze.grid_shape`. `plot`
            fails with an `AssertionError` if any value is NaN
        - `color_map : str`
            name of a matplotlib colormap, looked up in `mpl.colormaps`
            (defaults to `"Blues"`)
        - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next` (a green "P")
            (defaults to `None`)
        - `preceeding_tokens_coords : CoordArray`
            if given, each coord is marked with `self.marker_kwargs_current`
            (a green square) (defaults to `None`)
        - `colormap_center : float | None`
            if given, colors use a `TwoSlopeNorm` centered on this value; it must
            lie within the color range or `plot` raises a `ValueError`
            (defaults to `None`)
        - `colormap_max : float | None`
            if given, overrides the upper end of the color range, e.g. to keep
            colorbars consistent across several plots (defaults to `None`)
        - `hide_colorbar : bool`
            if `True`, no colorbar is drawn (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining

        # Raises:
        - `AssertionError` : if `node_values` does not have shape `self.maze.grid_shape`
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )
        # assert np.min(node_values) >= 0, "Please pass non-negative node values only."

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
        - `dpi : int`
            resolution of the new figure; unused when `fig_ax` is given
            (defaults to `100`)
        - `title : str`
            axes title; not drawn when `plain` is `True` (defaults to `""`)
        - `fig_ax : tuple | None`
            an existing `(figure, axes)` pair to draw into; if `None`, a new
            figure with a single subplot is created (defaults to `None`)
        - `plain : bool`
            if `True`, hide ticks, axis labels and the axis frame
            (defaults to `False`)

        # Returns:
        - `MazePlot`
            `self`, for chaining; the figure and axes are kept in `self.fig` and `self.ax`
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            # x runs along columns and y along rows, so each axis gets its own
            # tick count (they differ for non-square mazes)
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)
        else:
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.set_xlabel("")
            self.ax.set_ylabel("")
            self.ax.axis("off")

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

        the result is the pixel position of the center of that node's unit.
        """
        point = np.array([point[1], point[0]])
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`

        `kwargs` override those defaults and are passed to `ax.plot` when `plot`
        runs. returns `self`, for chaining.
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """draw `coords` (in (row, col) notation) on `self.ax` via `ax.plot(**kwargs)`"""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        without node values, the image is drawn in grayscale over [-1, 1], so walls
        (value -1) are black and nodes (value 1) are white. with node values, the
        chosen colormap is used over a range extended to include 0 (before any
        `colormap_max` override), walls (NaN in the image) are drawn black, and a
        colorbar is added unless `hide_colorbar` is set.
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if not self.custom_node_value_flag:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compared against None (not truthiness) so an explicit 0.0 is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # a center sitting exactly on an end of the range is nudged,
                    # since TwoSlopeNorm requires vmin < vcenter < vmax
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                # make sure 0 gets a tick when the range spans it
                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                # likewise for the colormap center
                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - take node_value if passed via .add_node_values()

        when node values are set, each node is drawn one pixel larger (covering its
        right/lower boundary, so connections take the node's value) and walls are
        then set to NaN, which `_plot_maze` renders black.

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float array of shape
        `(grid_shape[0] * ul + 1, grid_shape[1] * ul + 1)`.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Bool[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list, so the `not` checks below draw a connection
            # wherever the maze has one
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            # connections are already covered by the enlarged nodes; the `not`
            # checks below paint NaN wherever there is a wall. dtype=float so NaN
            # survives even when `node_values` is an integer array
            connection_values = np.full_like(scaled_node_values, np.nan, dtype=float)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: PathFormat) -> None:
        """draw a single `StyledPath` on `self.ax`

        if `path_format.quiver_kwargs` is set, the path is drawn as arrows between
        consecutive coords (colored along `path_format.cmap` if given, otherwise
        with `path_format.color`); otherwise it is drawn with `ax.plot` using
        `path_format.fmt`. in both cases an "o" marks the start and an "x" the end.
        empty paths are skipped with a warning.
        """
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        """wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`

        without a true path this falls back to `self.maze.as_ascii`, and
        `show_solution` is ignored.
        """
        if self.true_path is not None:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        return self.maze.as_ascii(show_endpoints=show_endpoints)
Class for displaying mazes and paths
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
    """Set up a plot of `maze`; nothing is drawn until `plot()` is called.

    # Parameters:
    - `maze : LatticeMaze`
        the maze to plot. If it is a `SolvedMaze`, its `solution` is added as
        the true path; otherwise, if it is a `TargetedLatticeMaze`, it is solved
        via `SolvedMaze.from_targeted_lattice_maze` and that solution is added.
    - `unit_length : int`
        sets the ratio between node size and wall thickness in the image.
        Wall thickness is fixed to 1px. A "unit" consists of a single node and
        the right and lower connection/wall, so e.g. `unit_length = 14` yields
        a 13:1 ratio between node size and wall thickness.
        (defaults to `14`)
    """
    self.unit_length: int = unit_length
    self.maze: LatticeMaze = maze
    self.true_path: StyledPath | None = None
    self.predicted_paths: list[StyledPath] = []
    self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
    self.custom_node_value_flag: bool = False
    self.node_color_map: str = "Blues"
    self.target_token_coord: Coord | None = None
    self.preceding_tokens_coords: CoordArray | None = None
    self.colormap_center: float | None = None
    # `add_node_values` overwrites these two, but the plotting code reads them
    # whenever node values are drawn -- initialize them here so they exist even
    # if `node_values` is assigned directly instead of via `add_node_values`.
    # The defaults match those of `add_node_values`.
    self.colormap_max: float | None = None
    self.hide_colorbar: bool = False
    self.cbar_ax = None
    self.marked_coords: list[tuple[Coord, dict]] = []

    # marker styles used by `add_node_values` for the target / preceding tokens
    self.marker_kwargs_current: dict = dict(
        marker="s",
        color="green",
        ms=12,
    )
    self.marker_kwargs_next: dict = dict(
        marker="P",
        color="green",
        ms=12,
    )

    # add the true path automatically when the maze type provides one
    # (`SolvedMaze` is checked first, as in the original ordering)
    if isinstance(maze, SolvedMaze):
        self.add_true_path(maze.solution)
    elif isinstance(maze, TargetedLatticeMaze):
        self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
unit_length: sets the ratio between node size and wall thickness in the image.
Wall thickness is fixed to 1px. A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.
@property
def solved_maze(self) -> SolvedMaze:
    """The plotted maze combined with `self.true_path`, as a `SolvedMaze`.

    Raises `ValueError` if no true path has been added yet (see `add_true_path`).
    """
    current_true_path = self.true_path
    if current_true_path is not None:
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=current_true_path.path,
        )
    raise ValueError(
        "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
    )
get the underlying SolvedMaze object
def add_true_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Set the true path of the maze, with optional formatting.

    `path` is converted by `process_path_input` using the `"true"` default
    format. Then `path_fmt` is combined in, and any `**kwargs` are set as
    attributes on the resulting `StyledPath`. Any previously set true path
    is replaced.

    # Returns:
    - `MazePlot` : `self`, for method chaining
    """
    converted: StyledPath = process_path_input(
        path=path,
        path_fmt=path_fmt,
        _default_key="true",
        **kwargs,
    )
    self.true_path = converted
    return self
add a true path to the maze with optional formatting
def add_predicted_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Add a predicted path, filling in a default label and color if unset.

    `path`, `path_fmt` and `**kwargs` are processed by `process_path_input`
    with the `"predicted"` default format. Defaults then depend on the number
    of predicted paths already stored:
    - if the result has no label, it is labelled `"predicted path <k>"`, where
      `k` is that count plus one;
    - if it has no color, one is taken from `DEFAULT_PREDICTED_PATH_COLORS`,
      cycling through the list by that count.

    # Returns:
    - `MazePlot` : `self`, for method chaining
    """
    n_existing: int = len(self.predicted_paths)
    new_path: StyledPath = process_path_input(
        path=path,
        _default_key="predicted",
        path_fmt=path_fmt,
        **kwargs,
    )

    if new_path.label is None:
        new_path.label = f"predicted path {n_existing + 1}"

    if new_path.color is None:
        palette_size: int = len(DEFAULT_PREDICTED_PATH_COLORS)
        new_path.color = DEFAULT_PREDICTED_PATH_COLORS[n_existing % palette_size]

    self.predicted_paths.append(new_path)
    return self
Receive the predicted path and formatting preferences from input and save them in the predicted_paths list.
The default formatting depends on the number of paths already saved in the predicted path list.
def add_multiple_paths(
    self,
    path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
    """Add several predicted paths at once, in order.

    Each element of `path_list` is passed to `add_predicted_path` with no
    extra formatting, so paths can be given in two ways:

    1. as raw coordinates (`CoordList` or `CoordArray`). These get the default
       predicted-path formatting plus an automatic label and color.
    2. as `StyledPath` objects. These keep their own formatting; only a
       missing label or color is filled in automatically.

    # Returns:
    - `MazePlot` : `self`, for method chaining
    """
    for single_path in path_list:
        self.add_predicted_path(single_path)
    return self
Function for adding multiple paths to MazePlot at once.
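This can be done in two ways:
- Passing a list of raw paths (CoordList or CoordArray), which receive the default predicted-path formatting
- Passing a list of StyledPath objects, which keep their own formatting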
def add_node_values(
    self,
    node_values: Float[np.ndarray, "grid_n grid_n"],
    color_map: str = "Blues",
    target_token_coord: Coord | None = None,
    preceeding_tokens_coords: CoordArray | None = None,
    colormap_center: float | None = None,
    colormap_max: float | None = None,
    hide_colorbar: bool = False,
) -> MazePlot:
    """Add per-node values to the maze, to be visualized as a heatmap.

    # Parameters:
    - `node_values : Float[np.ndarray, "grid_n grid_n"]`
        one value per node. Anything `np.asarray` accepts is allowed, but the
        shape must equal `self.maze.grid_shape`.
    - `color_map : str`
        name of the matplotlib colormap used for the node values
        (defaults to `"Blues"`)
    - `target_token_coord : Coord | None`
        if given, this coord is marked using `self.marker_kwargs_next`
        (defaults to `None`)
    - `preceeding_tokens_coords : CoordArray | None`
        if given, each coord is marked using `self.marker_kwargs_current`
        (defaults to `None`)
    - `colormap_center : float | None`
        if given, the colormap is centered on this value when plotting
        (defaults to `None`)
    - `colormap_max : float | None`
        upper limit for the colormap, e.g. to keep the colorbar consistent
        across several plots. The plotting code uses
        `colormap_max or <data max>`, so `0.0` behaves like `None`.
        (defaults to `None`)
    - `hide_colorbar : bool`
        if `True`, no colorbar is drawn
        (defaults to `False`)

    # Returns:
    - `MazePlot` : `self`, for method chaining

    # Raises:
    - `ValueError` : if the shape of `node_values` differs from
      `self.maze.grid_shape`
    """
    node_values = np.asarray(node_values)
    expected_shape: tuple = tuple(self.maze.grid_shape)
    # explicit check rather than `assert`, which is stripped under `python -O`
    if node_values.shape != expected_shape:
        err_msg: str = (
            "Please pass node values of the same shape as LatticeMaze.grid_shape: "
            f"expected {expected_shape}, got {node_values.shape}"
        )
        raise ValueError(err_msg)

    self.node_values = node_values
    # Set flag for choosing cmap while plotting maze
    self.custom_node_value_flag = True
    self.node_color_map = color_map
    self.colormap_center = colormap_center
    self.colormap_max = colormap_max
    self.hide_colorbar = hide_colorbar

    # queue markers; they are drawn by `plot()` together with other marked coords
    if target_token_coord is not None:
        self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
    if preceeding_tokens_coords is not None:
        for coord in preceeding_tokens_coords:
            self.marked_coords.append((coord, self.marker_kwargs_current))
    return self
add node values to the maze for visualization as a heatmap
DOCS: what are these arguments?
Parameters:
node_values : Float[np.ndarray, "grid_n grid_n"]
color_map : str
(defaults to "Blues")
target_token_coord : Coord | None
(defaults to None)
preceeding_tokens_coords : CoordArray
(defaults to None)
colormap_center : float | None
(defaults to None)
colormap_max : float | None
(defaults to None)
hide_colorbar : bool
(defaults to False)
Returns: MazePlot
def plot(
    self,
    dpi: int = 100,
    title: str = "",
    fig_ax: tuple | None = None,
    plain: bool = False,
) -> MazePlot:
    """Plot the maze and paths.

    Draws the maze (via `_plot_maze`), then the true path, the predicted
    paths, and finally any marked coordinates.

    # Parameters:
    - `dpi : int`
        dpi of the newly created figure; unused when `fig_ax` is given
        (defaults to `100`)
    - `title : str`
        axes title; only set when `plain` is `False`
        (defaults to `""`)
    - `fig_ax : tuple | None`
        an existing `(figure, axes)` pair to draw into. If `None`, a new
        figure with a single subplot is created.
        (defaults to `None`)
    - `plain : bool`
        if `True`, remove ticks and axis labels and turn the axis off
        (defaults to `False`)

    # Returns:
    - `MazePlot` : `self`, for method chaining
    """
    # set up figure
    if fig_ax is None:
        self.fig = plt.figure(dpi=dpi)
        self.ax = self.fig.add_subplot(1, 1, 1)
    else:
        self.fig, self.ax = fig_ax

    # plot maze
    self._plot_maze()

    # Plot labels
    if not plain:
        # the maze image has grid_shape[0] rows (y axis) and grid_shape[1]
        # columns (x axis), so each axis needs its own tick count --
        # otherwise non-square mazes get wrong ticks
        n_rows: int = self.maze.grid_shape[0]
        n_cols: int = self.maze.grid_shape[1]
        col_ticks = np.arange(n_cols)
        row_ticks = np.arange(n_rows)
        # ticks sit at the center of each unit
        self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
        self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
        self.ax.set_xlabel("col")
        self.ax.set_ylabel("row")
        self.ax.set_title(title)
    else:
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_xlabel("")
        self.ax.set_ylabel("")
        self.ax.axis("off")

    # plot paths
    if self.true_path is not None:
        self._plot_path(self.true_path)
    for path in self.predicted_paths:
        self._plot_path(path)

    # plot markers
    for coord, kwargs in self.marked_coords:
        self._place_marked_coords([coord], **kwargs)

    return self
Plot the maze and paths.
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
    """Queue coordinates to be drawn with a marker when `plot()` is called.

    `**kwargs` are merged over the default blue "+" style,
    `dict(marker="+", color="blue")`, and passed to `_place_marked_coords`
    at plot time. All coords from one call share the same style dict.

    # Returns:
    - `MazePlot` : `self`, for method chaining
    """
    marker_style: dict = {"marker": "+", "color": "blue"} | kwargs
    self.marked_coords.extend((single_coord, marker_style) for single_coord in coords)
    return self
Mark coordinates on the maze with a marker.
default marker is a blue "+":
dict(marker="+", color="blue")
def to_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """Render the maze as ASCII art.

    If a true path has been added, this delegates to
    `self.solved_maze.as_ascii()`, so the path can be shown. Otherwise it
    delegates to `self.maze.as_ascii()`, and `show_solution` is ignored.
    """
    if not self.true_path:
        return self.maze.as_ascii(show_endpoints=show_endpoints)
    return self.solved_maze.as_ascii(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )
wrapper for self.solved_maze.as_ascii(), shows the path if we have self.true_path