maze_dataset.plotting
utilities for plotting mazes and printing tokens
- any LatticeMaze or SolvedMaze comes with an as_pixels() method that returns a 2D numpy array of pixel values, but this is somewhat limited
- MazePlot is a class that can be used to plot mazes and paths in a more customizable way
- print_tokens contains utilities for printing tokens, colored by their type, position, or some custom weights (e.g. attention weights)
1"""utilities for plotting mazes and printing tokens 2 3- any `LatticeMaze` or `SolvedMaze` comes with a `as_pixels()` method that returns 4 a 2D numpy array of pixel values, but this is somewhat limited 5- `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way 6- `print_tokens` contains utilities for printing tokens, colored by their type, position, or some custom weights (i.e. attention weights) 7""" 8 9from maze_dataset.plotting.plot_dataset import plot_dataset_mazes, print_dataset_mazes 10from maze_dataset.plotting.plot_maze import DEFAULT_FORMATS, MazePlot, PathFormat 11from maze_dataset.plotting.print_tokens import ( 12 color_maze_tokens_AOTP, 13 color_tokens_cmap, 14 color_tokens_rgb, 15) 16 17__all__ = [ 18 # submodules 19 "plot_dataset", 20 "plot_maze", 21 "plot_svg_fancy", 22 "plot_tokens", 23 "print_tokens", 24 # imports 25 "plot_dataset_mazes", 26 "print_dataset_mazes", 27 "DEFAULT_FORMATS", 28 "MazePlot", 29 "PathFormat", 30 "color_tokens_cmap", 31 "color_maze_tokens_AOTP", 32 "color_tokens_rgb", 33]
def plot_dataset_mazes(
    ds: MazeDataset,
    count: int | None = None,
    figsize_mult: tuple[float, float] = (1.0, 2.0),
    title: bool | str = True,
) -> tuple | None:
    """plot `count` mazes from the dataset `ds` in a single figure using `SolvedMaze.as_pixels()`

    # Parameters:
     - `ds : MazeDataset`
        dataset to plot from; `ds[0]` through `ds[count - 1]` are drawn, one per subplot
     - `count : int | None`
        number of mazes to plot. `None` plots every maze in `ds`; `0` plots nothing
        (defaults to `None`)
     - `figsize_mult : tuple[float, float]`
        `(width per maze, figure height)` used to size the figure
        (defaults to `(1.0, 2.0)`)
     - `title : bool | str`
        a string is used verbatim as the figure title; `True` builds a title from
        `ds.cfg`; a falsy value adds no title
        (defaults to `True`)

    # Returns:
     - `tuple | None`
        `(fig, axes)`, or `None` if there was nothing to plot
    """
    # compare against `None` explicitly: `count or len(ds)` would silently turn an
    # explicit `count=0` into "plot everything"
    count = len(ds) if count is None else count
    if count == 0:
        print("No mazes to plot for dataset")
        return None
    fig, axes = plt.subplots(
        1,
        count,
        figsize=(count * figsize_mult[0], figsize_mult[1]),
    )
    # with a single column, `plt.subplots` returns one Axes rather than an array
    if count == 1:
        axes = [axes]
    for i in range(count):
        axes[i].imshow(ds[i].as_pixels())
        # remove ticks
        axes[i].set_xticks([])
        axes[i].set_yticks([])

    # set title
    if title:
        if isinstance(title, str):
            fig.suptitle(title)
        else:
            kwargs: dict = {
                "grid_n": ds.cfg.grid_n,
                **ds.cfg.maze_ctor_kwargs,
            }
            fig.suptitle(
                f"{ds.cfg.to_fname()}\n{ds.cfg.maze_ctor.__name__}({', '.join(f'{k}={v}' for k, v in kwargs.items())})",
            )

    # tight layout
    fig.tight_layout()
    # remove whitespace between title and subplots
    fig.subplots_adjust(top=1.0)

    return fig, axes
plot count mazes from the dataset ds in a single figure using SolvedMaze.as_pixels()
def print_dataset_mazes(ds: MazeDataset, count: int | None = None) -> None:
    """print the ascii representation of `count` mazes from the dataset `ds`

    `count=None` prints every maze in `ds`; `count=0` prints only a notice.
    Each maze is followed by a `-----` separator.
    """
    # explicit `None` check so that `count=0` is not treated as "all mazes"
    count = len(ds) if count is None else count
    if count == 0:
        print("No mazes to print for dataset")
        return
    for i in range(count):
        print(ds[i].as_ascii(), "\n\n-----\n")
print the ASCII representation of count mazes from the dataset ds
class MazePlot:
    """Class for displaying mazes and paths"""

    def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
        """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

        Wall thickness is fixed to 1px.
        A "unit" consists of a single node and the right and lower connection/wall.
        Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

        If `maze` is a `SolvedMaze`, its solution is added as the true path; a
        `TargetedLatticeMaze` is solved first and its solution added the same way.
        """
        self.unit_length: int = unit_length
        self.maze: LatticeMaze = maze
        self.true_path: StyledPath | None = None
        self.predicted_paths: list[StyledPath] = []
        self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
        self.custom_node_value_flag: bool = False
        self.node_color_map: str = "Blues"
        self.target_token_coord: Coord = None
        self.preceding_tokens_coords: CoordArray = None
        self.colormap_center: float | None = None
        # normally set by `add_node_values`; given defaults here so `_plot_maze`
        # cannot hit an AttributeError if the heatmap flag is set some other way
        self.colormap_max: float | None = None
        self.hide_colorbar: bool = False
        self.cbar_ax = None
        self.marked_coords: list[tuple[Coord, dict]] = list()

        # marker styles used by `add_node_values` for the preceding / target tokens
        self.marker_kwargs_current: dict = dict(
            marker="s",
            color="green",
            ms=12,
        )
        self.marker_kwargs_next: dict = dict(
            marker="P",
            color="green",
            ms=12,
        )

        if isinstance(maze, SolvedMaze):
            self.add_true_path(maze.solution)
        elif isinstance(maze, TargetedLatticeMaze):
            self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

    @property
    def solved_maze(self) -> SolvedMaze:
        "get the underlying `SolvedMaze` object"
        if self.true_path is None:
            raise ValueError(
                "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
            )
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )

    def add_true_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        "add a true path to the maze with optional formatting"
        self.true_path = process_path_input(
            path=path,
            _default_key="true",
            path_fmt=path_fmt,
            **kwargs,
        )

        return self

    def add_predicted_path(
        self,
        path: CoordList | CoordArray | StyledPath,
        path_fmt: PathFormat | None = None,
        **kwargs,
    ) -> MazePlot:
        """Receive a predicted path and formatting preferences and save it in the `predicted_paths` list.

        Default formatting depends on the number of paths already saved in the `predicted_paths` list.
        """
        styled_path: StyledPath = process_path_input(
            path=path,
            _default_key="predicted",
            path_fmt=path_fmt,
            **kwargs,
        )

        # set default label and color if not specified
        if styled_path.label is None:
            styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

        if styled_path.color is None:
            # cycle through the default palette
            color_num: int = len(self.predicted_paths) % len(
                DEFAULT_PREDICTED_PATH_COLORS,
            )
            styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

        self.predicted_paths.append(styled_path)
        return self

    def add_multiple_paths(
        self,
        path_list: Sequence[CoordList | CoordArray | StyledPath],
    ) -> MazePlot:
        """Add several predicted paths to the MazePlot at once.

        Equivalent to calling `add_predicted_path(path)` once for each element of
        `path_list`, in order. Elements may be coordinate lists, coordinate arrays,
        or `StyledPath` objects.

        Returns `self` to allow chaining.
        """
        for path in path_list:
            self.add_predicted_path(path)
        return self

    def add_node_values(
        self,
        node_values: Float[np.ndarray, "grid_n grid_n"],
        color_map: str = "Blues",
        target_token_coord: Coord | None = None,
        preceeding_tokens_coords: CoordArray = None,
        colormap_center: float | None = None,
        colormap_max: float | None = None,
        hide_colorbar: bool = False,
    ) -> MazePlot:
        """add node values to the maze for visualization as a heatmap

        # Parameters:
         - `node_values : Float[np.ndarray, "grid_n grid_n"]`
            one value per node; shape must equal `self.maze.grid_shape`
         - `color_map : str`
            name of the colormap used for the heatmap
            (defaults to `"Blues"`)
         - `target_token_coord : Coord | None`
            if given, marked with `self.marker_kwargs_next`
            (defaults to `None`)
         - `preceeding_tokens_coords : CoordArray`
            if given, each coordinate is marked with `self.marker_kwargs_current`
            (defaults to `None`)
         - `colormap_center : float | None`
            if given, the colormap is centered at this value (`TwoSlopeNorm`)
            (defaults to `None`)
         - `colormap_max : float | None`
            fixed upper limit of the colormap, e.g. for consistent colorbars across plots
            (defaults to `None`)
         - `hide_colorbar : bool`
            if `True`, no colorbar is drawn
            (defaults to `False`)

        # Returns:
         - `MazePlot`
            `self`, to allow chaining
        """
        assert node_values.shape == self.maze.grid_shape, (
            "Please pass node values of the same shape as LatticeMaze.grid_shape"
        )

        self.node_values = node_values
        # Set flag for choosing cmap while plotting maze
        self.custom_node_value_flag = True
        self.node_color_map = color_map
        self.colormap_center = colormap_center
        self.colormap_max = colormap_max
        self.hide_colorbar = hide_colorbar

        if target_token_coord is not None:
            self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
        if preceeding_tokens_coords is not None:
            for coord in preceeding_tokens_coords:
                self.marked_coords.append((coord, self.marker_kwargs_current))
        return self

    def plot(
        self,
        dpi: int = 100,
        title: str = "",
        fig_ax: tuple | None = None,
        plain: bool = False,
    ) -> MazePlot:
        """Plot the maze and paths.

        # Parameters:
         - `dpi : int`
            resolution of the new figure; ignored when `fig_ax` is given
            (defaults to `100`)
         - `title : str`
            axes title; not shown when `plain` is set
            (defaults to `""`)
         - `fig_ax : tuple | None`
            existing `(figure, axes)` to draw into; a new figure is created if `None`
            (defaults to `None`)
         - `plain : bool`
            if `True`, hide ticks, labels and axis frame
            (defaults to `False`)

        # Returns:
         - `MazePlot`
            `self`, to allow chaining
        """
        # set up figure
        if fig_ax is None:
            self.fig = plt.figure(dpi=dpi)
            self.ax = self.fig.add_subplot(1, 1, 1)
        else:
            self.fig, self.ax = fig_ax

        # plot maze
        self._plot_maze()

        # Plot labels
        if not plain:
            row_ticks = np.arange(self.maze.grid_shape[0])
            col_ticks = np.arange(self.maze.grid_shape[1])
            # x runs along columns and y along rows (see `_rowcol_to_coord`), so
            # non-square mazes need separate tick ranges for each axis
            self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
            self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
            self.ax.set_xlabel("col")
            self.ax.set_ylabel("row")
            self.ax.set_title(title)
        else:
            self.ax.set_xticks([])
            self.ax.set_yticks([])
            self.ax.set_xlabel("")
            self.ax.set_ylabel("")
            self.ax.axis("off")

        # plot paths
        if self.true_path is not None:
            self._plot_path(self.true_path)
        for path in self.predicted_paths:
            self._plot_path(path)

        # plot markers
        for coord, kwargs in self.marked_coords:
            self._place_marked_coords([coord], **kwargs)

        return self

    def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
        """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
        point = np.array([point[1], point[0]])
        # shift by half a unit so the point lands in the middle of its node
        return self.unit_length * (point + 0.5)

    def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
        """Mark coordinates on the maze with a marker.

        default marker is a blue "+":
        `dict(marker="+", color="blue")`
        """
        kwargs = {
            **dict(marker="+", color="blue"),
            **kwargs,
        }
        for coord in coords:
            self.marked_coords.append((coord, kwargs))

        return self

    def _place_marked_coords(
        self,
        coords: CoordArray | list[Coord],
        **kwargs,
    ) -> MazePlot:
        """draw `coords` on `self.ax`, passing `kwargs` through to `Axes.plot`"""
        coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
        self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

        return self

    def _plot_maze(self) -> None:  # noqa: C901, PLR0912
        """Define Colormap and plot maze.

        Without node values the image is drawn in grayscale, with walls at -1 (black).
        With node values, walls are NaN in the image and drawn black via `cmap.set_bad`;
        everything else uses the chosen colormap.
        """
        img = self._lattice_maze_to_img()

        # if no node_values have been passed (no colormap)
        if self.custom_node_value_flag is False:
            self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

        else:
            assert self.node_values is not None, "Please pass node values."
            assert not np.isnan(self.node_values).any(), (
                "Please pass node values, they cannot be nan."
            )

            vals_min: float = np.nanmin(self.node_values)
            vals_max: float = np.nanmax(self.node_values)
            # if both are negative or both are positive, set max/min to 0
            if vals_max < 0.0:
                vals_max = 0.0
            elif vals_min > 0.0:
                vals_min = 0.0

            # adjust vals_max, in case you need consistent colorbar across multiple plots.
            # compare against None so that an explicit `colormap_max=0.0` is honored
            if self.colormap_max is not None:
                vals_max = self.colormap_max

            # create colormap
            cmap = mpl.colormaps[self.node_color_map]
            # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
            cmap.set_bad(color="black")

            if self.colormap_center is not None:
                if not (vals_min < self.colormap_center < vals_max):
                    # nudge the limits so a center sitting exactly on a limit is accepted
                    if vals_min == self.colormap_center:
                        vals_min -= 1e-10
                    elif vals_max == self.colormap_center:
                        vals_max += 1e-10
                    else:
                        err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
                        raise ValueError(
                            err_msg,
                        )

                norm = mpl.colors.TwoSlopeNorm(
                    vmin=vals_min,
                    vcenter=self.colormap_center,
                    vmax=vals_max,
                )
                _plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
            else:
                _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

            # Add colorbar based on the condition of self.hide_colorbar
            if not self.hide_colorbar:
                ticks = np.linspace(vals_min, vals_max, 5)

                if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
                    ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

                if (
                    self.colormap_center is not None
                    and self.colormap_center not in ticks
                    and vals_min < self.colormap_center < vals_max
                ):
                    ticks = np.insert(
                        ticks,
                        np.searchsorted(ticks, self.colormap_center),
                        self.colormap_center,
                    )

                cbar = plt.colorbar(
                    _plotted,
                    ticks=ticks,
                    ax=self.ax,
                    cax=self.cbar_ax,
                )
                # remember the colorbar axes so later calls reuse it
                self.cbar_ax = cbar.ax

        # make the boundaries of the image thicker (walls look weird without this)
        for axis in ["top", "bottom", "left", "right"]:
            self.ax.spines[axis].set_linewidth(2)

    def _lattice_maze_to_img(
        self,
        connection_val_scale: float = 0.93,
    ) -> Float[np.ndarray, "row col"]:
        """Build an image to visualise the maze.

        Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
        - Nodes have area: (ul-1) * (ul-1) and value 1 by default
            - take node_value if passed via .add_node_values()
        - Walls have area: 1 * (ul-1) and value -1
        - Connections have area: 1 * (ul-1); color and value 0.93 by default
            - take node_value if passed via .add_node_values()

        With node values, each node fill is extended by 1px so it also covers its
        right/lower connection, and missing connections (walls) are set to NaN.

        Axes definition:
        (0,0)     col
        ----|----------->
            |
        row |
            |
            v

        Returns a float matrix of shape `(ul * n_rows + 1, ul * n_cols + 1)`.
        """
        # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
        node_bdry_hack: int
        connection_list_processed: Float[np.ndarray, "dim row col"]
        # Set node and connection values
        if self.node_values is None:
            scaled_node_values = np.ones(self.maze.grid_shape)
            connection_values = scaled_node_values * connection_val_scale
            node_bdry_hack = 0
            # TODO: hack
            # invert connection list
            connection_list_processed = np.logical_not(self.maze.connection_list)
        else:
            # TODO: hack
            scaled_node_values = self.node_values
            connection_values = np.full_like(scaled_node_values, np.nan)
            node_bdry_hack = 1
            connection_list_processed = self.maze.connection_list

        # Create background image (all pixels set to -1, walls everywhere)
        img: Float[np.ndarray, "row col"] = -np.ones(
            (
                self.maze.grid_shape[0] * self.unit_length + 1,
                self.maze.grid_shape[1] * self.unit_length + 1,
            ),
            dtype=float,
        )

        # Draw nodes and connections by iterating through lattice
        for row in range(self.maze.grid_shape[0]):
            for col in range(self.maze.grid_shape[1]):
                # Draw node
                img[
                    row * self.unit_length + 1 : (row + 1) * self.unit_length
                    + node_bdry_hack,
                    col * self.unit_length + 1 : (col + 1) * self.unit_length
                    + node_bdry_hack,
                ] = scaled_node_values[row, col]

                # Down connection
                if not connection_list_processed[0, row, col]:
                    img[
                        (row + 1) * self.unit_length,
                        col * self.unit_length + 1 : (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

                # Right connection
                if not connection_list_processed[1, row, col]:
                    img[
                        row * self.unit_length + 1 : (row + 1) * self.unit_length,
                        (col + 1) * self.unit_length,
                    ] = connection_values[row, col]

        return img

    def _plot_path(self, path_format: PathFormat) -> None:
        """draw one path on `self.ax`, as arrows if `quiver_kwargs` is set, else as a line with start/end markers"""
        if len(path_format.path) == 0:
            warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
            return
        p_transformed = np.array(
            [self._rowcol_to_coord(coord) for coord in path_format.path],
        )
        if path_format.quiver_kwargs is not None:
            try:
                x: np.ndarray = p_transformed[:, 0]
                y: np.ndarray = p_transformed[:, 1]
            except Exception as e:
                err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
                raise ValueError(
                    err_msg,
                ) from e

            # Generate colors from the colormap
            if path_format.cmap is not None:
                n = len(x) - 1  # Number of arrows
                cmap = plt.get_cmap(path_format.cmap)
                colors = [cmap(i / n) for i in range(n)]
            else:
                colors = path_format.color

            self.ax.quiver(
                x[:-1],
                y[:-1],
                x[1:] - x[:-1],
                y[1:] - y[:-1],
                scale_units="xy",
                angles="xy",
                scale=1,
                color=colors,
                **path_format.quiver_kwargs,
            )
        else:
            self.ax.plot(
                p_transformed[:, 0],
                p_transformed[:, 1],
                path_format.fmt,
                lw=path_format.line_width,
                color=path_format.color,
                label=path_format.label,
            )
        # mark endpoints
        self.ax.plot(
            [p_transformed[0][0]],
            [p_transformed[0][1]],
            "o",
            color=path_format.color,
            ms=10,
        )
        self.ax.plot(
            [p_transformed[-1][0]],
            [p_transformed[-1][1]],
            "x",
            color=path_format.color,
            ms=10,
        )

    def to_ascii(
        self,
        show_endpoints: bool = True,
        show_solution: bool = True,
    ) -> str:
        "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
        if self.true_path:
            return self.solved_maze.as_ascii(
                show_endpoints=show_endpoints,
                show_solution=show_solution,
            )
        else:
            return self.maze.as_ascii(show_endpoints=show_endpoints)
Class for displaying mazes and paths
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
    """UNIT_LENGTH: Set ratio between node size and wall thickness in image.

    Wall thickness is fixed to 1px.
    A "unit" consists of a single node and the right and lower connection/wall.
    Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

    If `maze` is a `SolvedMaze`, its solution is added as the true path; a
    `TargetedLatticeMaze` is solved first and its solution added the same way.
    """
    self.unit_length: int = unit_length
    self.maze: LatticeMaze = maze
    self.true_path: StyledPath | None = None
    self.predicted_paths: list[StyledPath] = []
    self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
    self.custom_node_value_flag: bool = False
    self.node_color_map: str = "Blues"
    self.target_token_coord: Coord = None
    self.preceding_tokens_coords: CoordArray = None
    self.colormap_center: float | None = None
    # normally set by `add_node_values`; given defaults here so later plotting
    # code cannot hit an AttributeError if the heatmap flag is set some other way
    self.colormap_max: float | None = None
    self.hide_colorbar: bool = False
    self.cbar_ax = None
    self.marked_coords: list[tuple[Coord, dict]] = list()

    # marker styles used by `add_node_values` for the preceding / target tokens
    self.marker_kwargs_current: dict = dict(
        marker="s",
        color="green",
        ms=12,
    )
    self.marker_kwargs_next: dict = dict(
        marker="P",
        color="green",
        ms=12,
    )

    if isinstance(maze, SolvedMaze):
        self.add_true_path(maze.solution)
    elif isinstance(maze, TargetedLatticeMaze):
        self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
UNIT_LENGTH: Set ratio between node size and wall thickness in image.
Wall thickness is fixed to 1px. A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.
@property
def solved_maze(self) -> SolvedMaze:
    """the maze as a `SolvedMaze`, using the stored true path as its solution

    # Raises:
     - `ValueError` if no true path has been added yet
    """
    if self.true_path is not None:
        return SolvedMaze.from_lattice_maze(
            lattice_maze=self.maze,
            solution=self.true_path.path,
        )
    raise ValueError(
        "Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
    )
get the underlying SolvedMaze object
def add_true_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """set the true path of the maze, replacing any previous one

    `path`, `path_fmt` and any extra keyword arguments are handed to
    `process_path_input` under the `"true"` default-format key.
    Returns `self` to allow chaining.
    """
    styled: StyledPath = process_path_input(
        path=path,
        _default_key="true",
        path_fmt=path_fmt,
        **kwargs,
    )
    self.true_path = styled
    return self
add a true path to the maze with optional formatting
def add_predicted_path(
    self,
    path: CoordList | CoordArray | StyledPath,
    path_fmt: PathFormat | None = None,
    **kwargs,
) -> MazePlot:
    """Receive a predicted path and formatting preferences and append it to `self.predicted_paths`.

    A missing label becomes `"predicted path <k>"` (1-based position in the list);
    a missing color is taken from `DEFAULT_PREDICTED_PATH_COLORS`, cycling by the
    number of paths already stored. Returns `self` to allow chaining.
    """
    new_path: StyledPath = process_path_input(
        path=path,
        _default_key="predicted",
        path_fmt=path_fmt,
        **kwargs,
    )
    already_stored: int = len(self.predicted_paths)

    if new_path.label is None:
        new_path.label = f"predicted path {already_stored + 1}"

    if new_path.color is None:
        palette = DEFAULT_PREDICTED_PATH_COLORS
        new_path.color = palette[already_stored % len(palette)]

    self.predicted_paths.append(new_path)
    return self
Receive a predicted path and formatting preferences from input and save it in the predicted_paths list.
Default formatting depends on the number of paths already saved in the predicted_paths list.
def add_multiple_paths(
    self,
    path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
    """Add several predicted paths to the MazePlot at once.

    Equivalent to calling `add_predicted_path(path)` once for each element of
    `path_list`, in order. Elements may be coordinate lists, coordinate arrays,
    or `StyledPath` objects.

    Returns `self` to allow chaining.
    """
    for path in path_list:
        self.add_predicted_path(path)
    return self
Function for adding multiple paths to MazePlot at once.
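Each element of path_list is passed, in order, to add_predicted_path, so this is equivalent to calling add_predicted_path once per path. Elements may be coordinate lists, coordinate arrays, or StyledPath objects.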
def add_node_values(
    self,
    node_values: Float[np.ndarray, "grid_n grid_n"],
    color_map: str = "Blues",
    target_token_coord: Coord | None = None,
    preceeding_tokens_coords: CoordArray = None,
    colormap_center: float | None = None,
    colormap_max: float | None = None,
    hide_colorbar: bool = False,
) -> MazePlot:
    """add node values to the maze for visualization as a heatmap

    # Parameters:
     - `node_values : Float[np.ndarray, "grid_n grid_n"]`
        one value per node; shape must equal `self.maze.grid_shape`
     - `color_map : str`
        name of the colormap used for the heatmap
        (defaults to `"Blues"`)
     - `target_token_coord : Coord | None`
        if given, marked with `self.marker_kwargs_next`
        (defaults to `None`)
     - `preceeding_tokens_coords : CoordArray`
        if given, each coordinate is marked with `self.marker_kwargs_current`
        (defaults to `None`)
     - `colormap_center : float | None`
        if given, the value at which the colormap is centered
        (defaults to `None`)
     - `colormap_max : float | None`
        fixed upper limit of the colormap, e.g. for consistent colorbars across plots
        (defaults to `None`)
     - `hide_colorbar : bool`
        if `True`, no colorbar is drawn
        (defaults to `False`)

    # Returns:
     - `MazePlot`
        `self`, to allow chaining
    """
    assert node_values.shape == self.maze.grid_shape, (
        "Please pass node values of the same shape as LatticeMaze.grid_shape"
    )

    self.node_values = node_values
    # Set flag for choosing cmap while plotting maze
    self.custom_node_value_flag = True
    self.node_color_map = color_map
    self.colormap_center = colormap_center
    self.colormap_max = colormap_max
    self.hide_colorbar = hide_colorbar

    if target_token_coord is not None:
        self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
    if preceeding_tokens_coords is not None:
        for coord in preceeding_tokens_coords:
            self.marked_coords.append((coord, self.marker_kwargs_current))
    return self
add node values to the maze for visualization as a heatmap
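Parameters:
- node_values : Float[np.ndarray, "grid_n grid_n"] — one value per node; its shape must equal LatticeMaze.grid_shape
- color_map : str (defaults to "Blues") — name of the colormap used for the heatmap
- target_token_coord : Coord | None (defaults to None) — if given, this coordinate is marked on the plot
- preceeding_tokens_coords : CoordArray (defaults to None) — if given, each of these coordinates is marked on the plot
- colormap_center : float | None (defaults to None) — if given, the value at which the colormap is centered
- colormap_max : float | None (defaults to None) — fixed upper limit of the colormap, e.g. for consistent colorbars across plots
- hide_colorbar : bool (defaults to False) — if True, no colorbar is drawn
Returns:
- MazePlot (the same object, to allow chaining)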
def plot(
    self,
    dpi: int = 100,
    title: str = "",
    fig_ax: tuple | None = None,
    plain: bool = False,
) -> MazePlot:
    """Plot the maze and paths.

    # Parameters:
     - `dpi : int`
        resolution of the new figure; ignored when `fig_ax` is given
        (defaults to `100`)
     - `title : str`
        axes title; not shown when `plain` is set
        (defaults to `""`)
     - `fig_ax : tuple | None`
        existing `(figure, axes)` to draw into; a new figure is created if `None`
        (defaults to `None`)
     - `plain : bool`
        if `True`, hide ticks, labels and axis frame
        (defaults to `False`)

    # Returns:
     - `MazePlot`
        `self`, to allow chaining
    """
    # set up figure
    if fig_ax is None:
        self.fig = plt.figure(dpi=dpi)
        self.ax = self.fig.add_subplot(1, 1, 1)
    else:
        self.fig, self.ax = fig_ax

    # plot maze
    self._plot_maze()

    # Plot labels
    if not plain:
        row_ticks = np.arange(self.maze.grid_shape[0])
        col_ticks = np.arange(self.maze.grid_shape[1])
        # x runs along columns and y along rows, so non-square mazes need
        # separate tick ranges for each axis
        self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
        self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
        self.ax.set_xlabel("col")
        self.ax.set_ylabel("row")
        self.ax.set_title(title)
    else:
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ax.set_xlabel("")
        self.ax.set_ylabel("")
        self.ax.axis("off")

    # plot paths
    if self.true_path is not None:
        self._plot_path(self.true_path)
    for path in self.predicted_paths:
        self._plot_path(path)

    # plot markers
    for coord, kwargs in self.marked_coords:
        self._place_marked_coords([coord], **kwargs)

    return self
Plot the maze and paths.
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
    """Queue a marker for every coordinate in `coords`.

    The default style is a blue "+" (`dict(marker="+", color="blue")`); any
    keyword arguments override it. All queued coordinates share one style dict.
    Returns `self` to allow chaining.
    """
    marker_style: dict = dict(marker="+", color="blue")
    marker_style.update(kwargs)
    self.marked_coords.extend((coord, marker_style) for coord in coords)
    return self
Mark coordinates on the maze with a marker.
default marker is a blue "+":
dict(marker="+", color="blue")
def to_ascii(
    self,
    show_endpoints: bool = True,
    show_solution: bool = True,
) -> str:
    """ASCII rendering of the maze, including the solution when a true path is set

    Without a (truthy) `self.true_path` this falls back to
    `self.maze.as_ascii(show_endpoints=...)` and `show_solution` is ignored.
    """
    if not self.true_path:
        return self.maze.as_ascii(show_endpoints=show_endpoints)
    return self.solved_maze.as_ascii(
        show_endpoints=show_endpoints,
        show_solution=show_solution,
    )
wrapper for self.solved_maze.as_ascii(), shows the path if we have self.true_path
26@dataclass(kw_only=True) 27class PathFormat: 28 """formatting options for path plot""" 29 30 label: str | None = None 31 fmt: str = "o" 32 color: str | None = None 33 cmap: str | None = None 34 line_width: float | None = None 35 quiver_kwargs: dict | None = None 36 37 def combine(self, other: PathFormat) -> PathFormat: 38 """combine with other PathFormat object, overwriting attributes with non-None values. 39 40 returns a modified copy of self. 41 """ 42 output: PathFormat = deepcopy(self) 43 for key, value in other.__dict__.items(): 44 if key == "path": 45 err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }" 46 raise ValueError( 47 err_msg, 48 ) 49 if value is not None: 50 setattr(output, key, value) 51 52 return output
formatting options for path plot
def combine(self, other: PathFormat) -> PathFormat:
    """merge `other` into a deep copy of `self`

    Each attribute of `other` whose value is not `None` overwrites the copy's
    attribute of the same name; `self` is not modified.

    # Raises:
     - `ValueError` if `other` has a `path` attribute
    """
    result: PathFormat = deepcopy(self)
    for name, new_value in vars(other).items():
        if name == "path":
            err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
            raise ValueError(err_msg)
        if new_value is None:
            continue
        setattr(result, name, new_value)
    return result
combine with other PathFormat object, overwriting attributes with non-None values.
returns a modified copy of self.
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
) -> str:
    """color tokens given a list of weights and a colormap

    The weights are rescaled to `[0, 1]` with `matplotlib.colors.Normalize`,
    mapped through `cmap`, and the resulting RGB colors are handed to
    `color_tokens_rgb`.

    # Parameters:
     - `tokens : list[str]`
        tokens to color
     - `weights : Sequence[float]`
        one weight per token; must have the same length as `tokens`
     - `cmap : str | matplotlib.colors.Colormap`
        a colormap, or the name of a registered colormap
        (defaults to `"Blues"`)
     - `fmt : FormatType`
        output format, forwarded to `color_tokens_rgb`
        (defaults to `"html"`)
     - `template : str | None`
        forwarded to `color_tokens_rgb`
        (defaults to `None`)
     - `labels : bool`
        if `True`, append a line showing each weight (1 decimal place) under its
        token; only supported for `fmt="terminal"`
        (defaults to `False`)

    # Returns:
     - `str`

    # Raises:
     - `NotImplementedError` : if `labels` is `True` and `fmt` is not `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # keep only RGB from the colormap's RGBA output and scale to [0, 255]
    colors: RGBArray = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # align labels with the tokens: one cell per token, padded to the
        # token's width and followed by a space
        label_cells: list[str] = []
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place
            weight_str: str = f"{weight:.1f}"
            if len(weight_str) > len(tok):
                # label wider than its token: leave the cell blank to keep alignment
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            label_cells.append(f"{weight_str} ")
        output += "\n" + "".join(label_cells)

    return output
color tokens given a list of weights and a colormap
def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e.: adjacency list, origin, target, path

    Each `(start_tok, end_tok)` key of `_MAZE_TOKENS_DEFAULT_COLORS` selects one
    section of `tokens` via `tokens_between` (with `include_start`/`include_end`
    set). That section is joined into a single space-separated string and colored
    with the RGB value stored under that key. Extra keyword arguments go to
    `color_tokens_rgb`.
    """
    sections: list[str] = []
    for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS:
        section_tokens = tokens_between(
            tokens,
            start_tok,
            end_tok,
            include_start=True,
            include_end=True,
        )
        sections.append(" ".join(section_tokens))

    section_colors: RGBArray = np.array(
        [*_MAZE_TOKENS_DEFAULT_COLORS.values()],
        dtype=np.uint8,
    )

    return color_tokens_rgb(
        tokens=sections,
        colors=section_colors,
        fmt=fmt,
        template=template,
        **kwargs,
    )
color tokens assuming AOTP format
i.e.: adjacency list, origin, target, path
def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    tokens will not be escaped if `fmt` is None

    # Parameters:
     - `tokens : list`
        tokens to color; `tokens[i]` is colored with `colors[i]`
     - `colors : Sequence[Sequence[int]] | Float[np.ndarray, "n 3"]`
        one RGB triple per token; each channel is cast to `int` before formatting
     - `fmt : FormatType`
        key into `TEMPLATES` and `_COLOR_JOIN`. If `None`, `template` and
        `clr_join` must both be given; otherwise both must be left as `None`.
        (defaults to `"html"`)
     - `template : str | None`
        format string with `{clr}` and `{tok}` fields; only supplied when `fmt` is `None`
     - `clr_join : str | None`
        separator between the RGB channel values in `{clr}`; only supplied when `fmt` is `None`
     - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.

    # Returns:
     - `str`
        the formatted tokens joined by single spaces
    """
    # process format
    if fmt is None:
        assert template is not None
        assert clr_join is not None
    else:
        assert template is None
        assert clr_join is None
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # wrap each token into chunks of about `max_length` characters (long words
        # are not broken); every chunk keeps the color of the token it came from
        wrapped: list[list[str]] = [
            textwrap.wrap(
                tok,
                width=max_length,
                break_long_words=False,
                break_on_hyphens=False,
            )
            for tok in tokens
        ]
        chunk_colors: list = [
            clr
            for i, chunks in enumerate(wrapped)
            for clr in [colors[i]] * len(chunks)
        ]
        colors = chunk_colors
        tokens = [chunk for chunks in wrapped for chunk in chunks]

    # put everything together
    output: list[str] = [
        template.format(
            clr=clr_join.join(map(str, map(int, clr))),
            tok=_escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors, strict=False)
    ]

    return " ".join(output)
color tokens from a list with an RGB color array
tokens will not be escaped if fmt is None
Parameters:
max_length: int | None: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If None, no limit on max length.