docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.plotting.plot_maze

provides MazePlot, which has many tools for plotting mazes with multiple paths, colored nodes, and more


  1"""provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more"""
  2
  3from __future__ import annotations  # for type hinting self as return value
  4
  5import warnings
  6from copy import deepcopy
  7from dataclasses import dataclass
  8from typing import Sequence
  9
 10import matplotlib as mpl
 11import matplotlib.pyplot as plt
 12import numpy as np
 13from jaxtyping import Bool, Float
 14
 15from maze_dataset.constants import Coord, CoordArray, CoordList
 16from maze_dataset.maze import (
 17	LatticeMaze,
 18	SolvedMaze,
 19	TargetedLatticeMaze,
 20)
 21
 22LARGE_NEGATIVE_NUMBER: float = -1e10
 23
 24
 25@dataclass(kw_only=True)
 26class PathFormat:
 27	"""formatting options for path plot"""
 28
 29	label: str | None = None
 30	fmt: str = "o"
 31	color: str | None = None
 32	cmap: str | None = None
 33	line_width: float | None = None
 34	quiver_kwargs: dict | None = None
 35
 36	def combine(self, other: PathFormat) -> PathFormat:
 37		"""combine with other PathFormat object, overwriting attributes with non-None values.
 38
 39		returns a modified copy of self.
 40		"""
 41		output: PathFormat = deepcopy(self)
 42		for key, value in other.__dict__.items():
 43			if key == "path":
 44				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
 45				raise ValueError(
 46					err_msg,
 47				)
 48			if value is not None:
 49				setattr(output, key, value)
 50
 51		return output
 52
 53
 54# styled path
 55@dataclass
 56class StyledPath(PathFormat):
 57	"a `StyledPath` is a `PathFormat` with a specific path"
 58
 59	path: CoordArray
 60
 61
 62DEFAULT_FORMATS: dict[str, PathFormat] = {
 63	"true": PathFormat(
 64		label="true path",
 65		fmt="--",
 66		color="red",
 67		line_width=2.5,
 68		quiver_kwargs=None,
 69	),
 70	"predicted": PathFormat(
 71		label=None,
 72		fmt=":",
 73		color=None,
 74		line_width=2,
 75		quiver_kwargs={"width": 0.015},
 76	),
 77}
 78
 79
 80def process_path_input(
 81	path: CoordList | CoordArray | StyledPath,
 82	_default_key: str,
 83	path_fmt: PathFormat | None = None,
 84	**kwargs,
 85) -> StyledPath:
 86	"convert a path, which might be a list or array of coords, into a `StyledPath`"
 87	styled_path: StyledPath
 88	if isinstance(path, StyledPath):
 89		styled_path = path
 90	elif isinstance(path, np.ndarray):
 91		styled_path = StyledPath(path=path)
 92		# add default formatting
 93		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 94	elif isinstance(path, list):
 95		styled_path = StyledPath(path=np.array(path))
 96		# add default formatting
 97		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 98	else:
 99		err_msg: str = (
100			f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
101		)
102		raise TypeError(
103			err_msg,
104		)
105
106	# add formatting from path_fmt
107	if path_fmt is not None:
108		styled_path = styled_path.combine(path_fmt)
109
110	# add formatting from kwargs
111	for key, value in kwargs.items():
112		setattr(styled_path, key, value)
113
114	return styled_path
115
116
# colors for predicted paths without an explicit color; `MazePlot.add_predicted_path`
# picks them in order, cycling back to the start (index modulo length)
DEFAULT_PREDICTED_PATH_COLORS: list[str] = (
	"tab:orange tab:olive sienna mediumseagreen tab:purple slategrey".split()
)
125
126
class MazePlot:
	"""Class for displaying mazes and paths

	paths, node values and markers are registered with `add_true_path`,
	`add_predicted_path`, `add_multiple_paths`, `add_node_values` and
	`mark_coords`, and are only drawn when `plot` is called. These methods all
	return `self`, so calls can be chained.
	"""

	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.

		Wall thickness is fixed to 1px
		A "unit" consists of a single node and the right and lower connection/wall.
		Example: ul = 14 yields 13:1 ratio between node size and wall thickness

		If `maze` is a `SolvedMaze`, its solution is added as the true path.
		Otherwise, if it is a `TargetedLatticeMaze`, it is converted with
		`SolvedMaze.from_targeted_lattice_maze` and that solution is added instead.
		"""
		self.unit_length: int = unit_length
		self.maze: LatticeMaze = maze
		self.true_path: StyledPath | None = None
		self.predicted_paths: list[StyledPath] = []
		self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
		self.custom_node_value_flag: bool = False
		self.node_color_map: str = "Blues"
		self.target_token_coord: Coord | None = None
		self.preceding_tokens_coords: CoordArray | None = None
		self.colormap_center: float | None = None
		# initialized here (and overwritten by `add_node_values`) so that every
		# attribute read by `_plot_maze` exists on every instance
		self.colormap_max: float | None = None
		self.hide_colorbar: bool = False
		# colorbar axes; reused by `_plot_maze` when `plot` is called again
		self.cbar_ax = None
		# (coord, kwargs for `Axes.plot`) pairs, drawn as markers by `plot`
		self.marked_coords: list[tuple[Coord, dict]] = list()

		# marker style for `preceeding_tokens_coords` in `add_node_values`
		self.marker_kwargs_current: dict = dict(
			marker="s",
			color="green",
			ms=12,
		)
		# marker style for `target_token_coord` in `add_node_values`
		self.marker_kwargs_next: dict = dict(
			marker="P",
			color="green",
			ms=12,
		)

		if isinstance(maze, SolvedMaze):
			self.add_true_path(maze.solution)
		elif isinstance(maze, TargetedLatticeMaze):
			self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

	@property
	def solved_maze(self) -> SolvedMaze:
		"""get the underlying `SolvedMaze` object

		built from `self.maze`, with the true path as its solution.

		# Raises:
		- `ValueError` : if no true path has been added
		"""
		if self.true_path is None:
			raise ValueError(
				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
			)
		return SolvedMaze.from_lattice_maze(
			lattice_maze=self.maze,
			solution=self.true_path.path,
		)

	def add_true_path(
		self,
		path: CoordList | CoordArray | StyledPath,
		path_fmt: PathFormat | None = None,
		**kwargs,
	) -> MazePlot:
		"""add a true path to the maze with optional formatting

		replaces any previously added true path. Formatting is layered by
		`process_path_input`: `DEFAULT_FORMATS["true"]` (unless `path` is already a
		`StyledPath`), then `path_fmt`, then `kwargs`.
		"""
		self.true_path = process_path_input(
			path=path,
			_default_key="true",
			path_fmt=path_fmt,
			**kwargs,
		)

		return self

	def add_predicted_path(
		self,
		path: CoordList | CoordArray | StyledPath,
		path_fmt: PathFormat | None = None,
		**kwargs,
	) -> MazePlot:
		"""Receive a predicted path and formatting preferences, and save it in `predicted_paths`.

		Formatting is layered by `process_path_input`: `DEFAULT_FORMATS["predicted"]`
		(unless `path` is already a `StyledPath`), then `path_fmt`, then `kwargs`.
		If the result still has no label or color, the defaults depend on the number
		`n` of predicted paths already saved: label `"predicted path {n + 1}"` and
		color `DEFAULT_PREDICTED_PATH_COLORS[n % len(DEFAULT_PREDICTED_PATH_COLORS)]`.
		"""
		styled_path: StyledPath = process_path_input(
			path=path,
			_default_key="predicted",
			path_fmt=path_fmt,
			**kwargs,
		)

		# set default label and color if not specified
		if styled_path.label is None:
			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

		if styled_path.color is None:
			color_num: int = len(self.predicted_paths) % len(
				DEFAULT_PREDICTED_PATH_COLORS,
			)
			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

		self.predicted_paths.append(styled_path)
		return self

	def add_multiple_paths(
		self,
		path_list: Sequence[CoordList | CoordArray | StyledPath],
	) -> MazePlot:
		"""Add several predicted paths to the MazePlot at once.

		each element of `path_list` may be a list of coords, a coord array, or a
		`StyledPath`, and is passed to `add_predicted_path` without extra
		formatting. Plain coords therefore get the default predicted-path format,
		and any path without a label or color gets the numbered default label and
		the next default color.
		"""
		for path in path_list:
			self.add_predicted_path(path)
		return self

	def add_node_values(
		self,
		node_values: Float[np.ndarray, "grid_n grid_n"],
		color_map: str = "Blues",
		target_token_coord: Coord | None = None,
		preceeding_tokens_coords: CoordArray = None,
		colormap_center: float | None = None,
		colormap_max: float | None = None,
		hide_colorbar: bool = False,
	) -> MazePlot:
		"""add node values to the maze for visualization as a heatmap

		the values replace the default node color when plotting; see `_plot_maze`
		for how the color range is chosen.

		# Parameters:
		- `node_values : Float[np.ndarray, "grid_n grid_n"]`
			one value per node; the shape must equal `self.maze.grid_shape`.
			NaN values are rejected when plotting.
		- `color_map : str`
			name of a matplotlib colormap, looked up in `matplotlib.colormaps`
			(defaults to `"Blues"`)
		- `target_token_coord : Coord | None`
			if given, marked with `self.marker_kwargs_next` (green "P")
			(defaults to `None`)
		- `preceeding_tokens_coords : CoordArray`
			if given, each coord is marked with `self.marker_kwargs_current` (green square)
			(defaults to `None`)
		- `colormap_center : float | None`
			if given, the colormap is centered on this value with `TwoSlopeNorm`;
			it must lie within the color range, otherwise `plot` raises `ValueError`
			(defaults to `None`)
		- `colormap_max : float | None`
			if given, fixes the upper end of the color range, e.g. to keep a
			consistent colorbar across multiple plots
			(defaults to `None`)
		- `hide_colorbar : bool`
			if `True`, no colorbar is drawn
			(defaults to `False`)

		# Returns:
		- `MazePlot` : `self`

		note: markers are appended on every call, so repeated calls accumulate them.
		"""
		assert node_values.shape == self.maze.grid_shape, (
			"Please pass node values of the same sape as LatticeMaze.grid_shape"
		)

		self.node_values = node_values
		# Set flag for choosing cmap while plotting maze
		self.custom_node_value_flag = True
		self.node_color_map = color_map
		self.colormap_center = colormap_center
		self.colormap_max = colormap_max
		self.hide_colorbar = hide_colorbar

		if target_token_coord is not None:
			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
		if preceeding_tokens_coords is not None:
			for coord in preceeding_tokens_coords:
				self.marked_coords.append((coord, self.marker_kwargs_current))
		return self

	def plot(
		self,
		dpi: int = 100,
		title: str = "",
		fig_ax: tuple | None = None,
		plain: bool = False,
	) -> MazePlot:
		"""Plot the maze and paths.

		draws, in order: the maze image (plus a colorbar if node values were added),
		axis ticks and labels, the true path, the predicted paths, and marked coords.

		# Parameters:
		- `dpi : int`
			dpi of the new figure; ignored when `fig_ax` is given
			(defaults to `100`)
		- `title : str`
			axes title, only set when `plain` is `False`
			(defaults to `""`)
		- `fig_ax : tuple | None`
			existing `(figure, axes)` pair to draw on; a new figure is created if `None`
			(defaults to `None`)
		- `plain : bool`
			if `True`, remove ticks, axis labels and the axis itself
			(defaults to `False`)

		# Returns:
		- `MazePlot` : `self`, with `self.fig` and `self.ax` set
		"""
		# set up figure
		if fig_ax is None:
			self.fig = plt.figure(dpi=dpi)
			self.ax = self.fig.add_subplot(1, 1, 1)
		else:
			self.fig, self.ax = fig_ax

		# plot maze
		self._plot_maze()

		# Plot labels
		if not plain:
			# x is the column axis and y the row axis (see `_rowcol_to_coord`), so
			# each needs its own range -- they differ for non-square mazes
			row_ticks = np.arange(self.maze.grid_shape[0])
			col_ticks = np.arange(self.maze.grid_shape[1])
			self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
			self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
			self.ax.set_xlabel("col")
			self.ax.set_ylabel("row")
			self.ax.set_title(title)
		else:
			self.ax.set_xticks([])
			self.ax.set_yticks([])
			self.ax.set_xlabel("")
			self.ax.set_ylabel("")
			self.ax.axis("off")

		# plot paths
		if self.true_path is not None:
			self._plot_path(self.true_path)
		for path in self.predicted_paths:
			self._plot_path(path)

		# plot markers
		for coord, kwargs in self.marked_coords:
			self._place_marked_coords([coord], **kwargs)

		return self

	def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
		"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

		the result is the center of the node's cell in image pixel coordinates.
		"""
		point = np.array([point[1], point[0]])
		return self.unit_length * (point + 0.5)

	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
		"""Mark coordinates on the maze with a marker.

		default marker is a blue "+":
		`dict(marker="+", color="blue")`

		`kwargs` override these defaults and are passed to `Axes.plot`; the
		markers are drawn when `plot` is called.
		"""
		kwargs = {
			**dict(marker="+", color="blue"),
			**kwargs,
		}
		for coord in coords:
			self.marked_coords.append((coord, kwargs))

		return self

	def _place_marked_coords(
		self,
		coords: CoordArray | list[Coord],
		**kwargs,
	) -> MazePlot:
		"""draw `coords` on `self.ax` with `Axes.plot(**kwargs)`; requires `self.ax` (set by `plot`)"""
		coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
		self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

		return self

	def _plot_maze(self) -> None:  # noqa: C901, PLR0912
		"""Draw the maze image on `self.ax`, with a colormap and colorbar if node values were added.

		- without node values: grayscale image on [-1, 1], so walls (-1) are black,
		  nodes (1) white and connections (0.93 by default) near-white
		- with node values: `self.node_color_map` is used and NaN pixels (walls) are
		  drawn black. The color range spans the node values but always includes 0;
		  its upper end can be fixed with `colormap_max`. If `colormap_center` is set,
		  a `TwoSlopeNorm` centered on it is used.

		# Raises:
		- `ValueError` : if `colormap_center` lies outside the color range
		"""
		img = self._lattice_maze_to_img()

		# if no node_values have been passed (no colormap)
		if self.custom_node_value_flag is False:
			self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

		else:
			assert self.node_values is not None, "Please pass node values."
			assert not np.isnan(self.node_values).any(), (
				"Please pass node values, they cannot be nan."
			)

			vals_min: float = np.nanmin(self.node_values)
			vals_max: float = np.nanmax(self.node_values)
			# if both are negative or both are positive, set max/min to 0
			if vals_max < 0.0:
				vals_max = 0.0
			elif vals_min > 0.0:
				vals_min = 0.0

			# adjust vals_max, in case you need consistent colorbar across multiple plots.
			# compare against None explicitly: `colormap_max=0.0` is a valid choice
			# that the previous `or` silently ignored
			if self.colormap_max is not None:
				vals_max = self.colormap_max

			# create colormap
			cmap = mpl.colormaps[self.node_color_map]
			# TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
			cmap.set_bad(color="black")

			if self.colormap_center is not None:
				if not (vals_min < self.colormap_center < vals_max):
					# TwoSlopeNorm needs vmin < vcenter < vmax: nudge an endpoint that
					# coincides with the center, reject anything else
					if vals_min == self.colormap_center:
						vals_min -= 1e-10
					elif vals_max == self.colormap_center:
						vals_max += 1e-10
					else:
						err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
						raise ValueError(
							err_msg,
						)

				norm = mpl.colors.TwoSlopeNorm(
					vmin=vals_min,
					vcenter=self.colormap_center,
					vmax=vals_max,
				)
				_plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
			else:
				_plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

			# Add colorbar based on the condition of self.hide_colorbar
			if not self.hide_colorbar:
				ticks = np.linspace(vals_min, vals_max, 5)

				if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
					ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

				if (
					self.colormap_center is not None
					and self.colormap_center not in ticks
					and vals_min < self.colormap_center < vals_max
				):
					ticks = np.insert(
						ticks,
						np.searchsorted(ticks, self.colormap_center),
						self.colormap_center,
					)

				cbar = plt.colorbar(
					_plotted,
					ticks=ticks,
					ax=self.ax,
					cax=self.cbar_ax,
				)
				self.cbar_ax = cbar.ax

		# make the boundaries of the image thicker (walls look weird without this)
		for axis in ["top", "bottom", "left", "right"]:
			self.ax.spines[axis].set_linewidth(2)

	def _lattice_maze_to_img(
		self,
		connection_val_scale: float = 0.93,
	) -> Float[np.ndarray, "row col"]:
		"""Build an image to visualise the maze.

		Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
		- Nodes have area: (ul-1) * (ul-1) and value 1 by default
			- take node_value if passed via .add_node_values()
		- Walls have area: 1 * (ul-1) and value -1
		- Connections have area: 1 * (ul-1) and value `connection_val_scale` (0.93) by default

		When node values are set, each node's square is extended by one pixel to the
		right and down, so it also covers its connections; the pixels of missing
		connections (walls) are then set to NaN, which `_plot_maze` draws black.

		Axes definition:
		(0,0)     col
		----|----------->
			|
		row |
			|
			v

		Returns a float matrix of shape (ul * n_rows + 1, ul * n_cols + 1).
		"""
		# TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls disappear. if you dont add 1, you get ugly colors between nodes when they are colored
		node_bdry_hack: int
		# boolean mask: a connection/wall is drawn wherever this is False
		connection_list_processed: Bool[np.ndarray, "dim row col"]
		# Set node and connection values
		if self.node_values is None:
			scaled_node_values = np.ones(self.maze.grid_shape)
			connection_values = scaled_node_values * connection_val_scale
			node_bdry_hack = 0
			# TODO: hack
			# invert connection list, so existing connections get drawn
			connection_list_processed = np.logical_not(self.maze.connection_list)
		else:
			# TODO: hack
			scaled_node_values = self.node_values
			# missing connections (walls) become NaN, drawn black by `_plot_maze`
			connection_values = np.full_like(scaled_node_values, np.nan)
			node_bdry_hack = 1
			connection_list_processed = self.maze.connection_list

		# Create background image (all pixels set to -1, walls everywhere)
		img: Float[np.ndarray, "row col"] = -np.ones(
			(
				self.maze.grid_shape[0] * self.unit_length + 1,
				self.maze.grid_shape[1] * self.unit_length + 1,
			),
			dtype=float,
		)

		# Draw nodes and connections by iterating through lattice
		for row in range(self.maze.grid_shape[0]):
			for col in range(self.maze.grid_shape[1]):
				# Draw node
				img[
					row * self.unit_length + 1 : (row + 1) * self.unit_length
					+ node_bdry_hack,
					col * self.unit_length + 1 : (col + 1) * self.unit_length
					+ node_bdry_hack,
				] = scaled_node_values[row, col]

				# Down connection
				if not connection_list_processed[0, row, col]:
					img[
						(row + 1) * self.unit_length,
						col * self.unit_length + 1 : (col + 1) * self.unit_length,
					] = connection_values[row, col]

				# Right connection
				if not connection_list_processed[1, row, col]:
					img[
						row * self.unit_length + 1 : (row + 1) * self.unit_length,
						(col + 1) * self.unit_length,
					] = connection_values[row, col]

		return img

	def _plot_path(self, path_format: StyledPath) -> None:
		"""draw one path on `self.ax`, then mark its start ("o") and end ("x").

		if `path_format.quiver_kwargs` is set, the path is drawn as arrows between
		consecutive coords (colored along `path_format.cmap` if given); otherwise it
		is drawn as a line with `path_format.fmt`. Empty paths are skipped with a
		warning.
		"""
		if len(path_format.path) == 0:
			warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
			return
		p_transformed = np.array(
			[self._rowcol_to_coord(coord) for coord in path_format.path],
		)
		if path_format.quiver_kwargs is not None:
			try:
				x: np.ndarray = p_transformed[:, 0]
				y: np.ndarray = p_transformed[:, 1]
			except Exception as e:
				err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
				raise ValueError(
					err_msg,
				) from e

			# Generate colors from the colormap
			if path_format.cmap is not None:
				n = len(x) - 1  # Number of arrows
				cmap = plt.get_cmap(path_format.cmap)
				colors = [cmap(i / n) for i in range(n)]
			else:
				colors = path_format.color

			self.ax.quiver(
				x[:-1],
				y[:-1],
				x[1:] - x[:-1],
				y[1:] - y[:-1],
				scale_units="xy",
				angles="xy",
				scale=1,
				color=colors,
				**path_format.quiver_kwargs,
			)
		else:
			self.ax.plot(
				p_transformed[:, 0],
				p_transformed[:, 1],
				path_format.fmt,
				lw=path_format.line_width,
				color=path_format.color,
				label=path_format.label,
			)
		# mark endpoints
		self.ax.plot(
			[p_transformed[0][0]],
			[p_transformed[0][1]],
			"o",
			color=path_format.color,
			ms=10,
		)
		self.ax.plot(
			[p_transformed[-1][0]],
			[p_transformed[-1][1]],
			"x",
			color=path_format.color,
			ms=10,
		)

	def to_ascii(
		self,
		show_endpoints: bool = True,
		show_solution: bool = True,
	) -> str:
		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
		if self.true_path is not None:
			return self.solved_maze.as_ascii(
				show_endpoints=show_endpoints,
				show_solution=show_solution,
			)
		else:
			return self.maze.as_ascii(show_endpoints=show_endpoints)

LARGE_NEGATIVE_NUMBER: float = -10000000000.0
@dataclass(kw_only=True)
class PathFormat:
26@dataclass(kw_only=True)
27class PathFormat:
28	"""formatting options for path plot"""
29
30	label: str | None = None
31	fmt: str = "o"
32	color: str | None = None
33	cmap: str | None = None
34	line_width: float | None = None
35	quiver_kwargs: dict | None = None
36
37	def combine(self, other: PathFormat) -> PathFormat:
38		"""combine with other PathFormat object, overwriting attributes with non-None values.
39
40		returns a modified copy of self.
41		"""
42		output: PathFormat = deepcopy(self)
43		for key, value in other.__dict__.items():
44			if key == "path":
45				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
46				raise ValueError(
47					err_msg,
48				)
49			if value is not None:
50				setattr(output, key, value)
51
52		return output

formatting options for path plot

PathFormat( *, label: str | None = None, fmt: str = 'o', color: str | None = None, cmap: str | None = None, line_width: float | None = None, quiver_kwargs: dict | None = None)
label: str | None = None
fmt: str = 'o'
color: str | None = None
cmap: str | None = None
line_width: float | None = None
quiver_kwargs: dict | None = None
def combine( self, other: PathFormat) -> PathFormat:
37	def combine(self, other: PathFormat) -> PathFormat:
38		"""combine with other PathFormat object, overwriting attributes with non-None values.
39
40		returns a modified copy of self.
41		"""
42		output: PathFormat = deepcopy(self)
43		for key, value in other.__dict__.items():
44			if key == "path":
45				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
46				raise ValueError(
47					err_msg,
48				)
49			if value is not None:
50				setattr(output, key, value)
51
52		return output

combine with other PathFormat object, overwriting attributes with non-None values.

returns a modified copy of self.

@dataclass
class StyledPath(PathFormat):
@dataclass
class StyledPath(PathFormat):
	"""a `StyledPath` is a `PathFormat` with a specific path

	`path` is required and may be passed positionally; the formatting fields
	inherited from `PathFormat` remain keyword-only.
	"""

	# coordinates of the path, one (row, col) pair per step
	path: CoordArray

a StyledPath is a PathFormat with a specific path

StyledPath( path: jaxtyping.Int8[ndarray, 'coord row_col=2'], *, label: str | None = None, fmt: str = 'o', color: str | None = None, cmap: str | None = None, line_width: float | None = None, quiver_kwargs: dict | None = None)
path: jaxtyping.Int8[ndarray, 'coord row_col=2']
DEFAULT_FORMATS: dict[str, PathFormat] = {'true': PathFormat(label='true path', fmt='--', color='red', cmap=None, line_width=2.5, quiver_kwargs=None), 'predicted': PathFormat(label=None, fmt=':', color=None, cmap=None, line_width=2, quiver_kwargs={'width': 0.015})}
def process_path_input( path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, _default_key: str, path_fmt: PathFormat | None = None, **kwargs) -> StyledPath:
 81def process_path_input(
 82	path: CoordList | CoordArray | StyledPath,
 83	_default_key: str,
 84	path_fmt: PathFormat | None = None,
 85	**kwargs,
 86) -> StyledPath:
 87	"convert a path, which might be a list or array of coords, into a `StyledPath`"
 88	styled_path: StyledPath
 89	if isinstance(path, StyledPath):
 90		styled_path = path
 91	elif isinstance(path, np.ndarray):
 92		styled_path = StyledPath(path=path)
 93		# add default formatting
 94		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 95	elif isinstance(path, list):
 96		styled_path = StyledPath(path=np.array(path))
 97		# add default formatting
 98		styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key])
 99	else:
100		err_msg: str = (
101			f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}"
102		)
103		raise TypeError(
104			err_msg,
105		)
106
107	# add formatting from path_fmt
108	if path_fmt is not None:
109		styled_path = styled_path.combine(path_fmt)
110
111	# add formatting from kwargs
112	for key, value in kwargs.items():
113		setattr(styled_path, key, value)
114
115	return styled_path

convert a path, which might be a list or array of coords, into a StyledPath

DEFAULT_PREDICTED_PATH_COLORS: list[str] = ['tab:orange', 'tab:olive', 'sienna', 'mediumseagreen', 'tab:purple', 'slategrey']
class MazePlot:
128class MazePlot:
129	"""Class for displaying mazes and paths"""
130
131	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
132		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.
133
134		Wall thickness is fixed to 1px
135		A "unit" consists of a single node and the right and lower connection/wall.
136		Example: ul = 14 yields 13:1 ratio between node size and wall thickness
137		"""
138		self.unit_length: int = unit_length
139		self.maze: LatticeMaze = maze
140		self.true_path: StyledPath | None = None
141		self.predicted_paths: list[StyledPath] = []
142		self.node_values: Float[np.ndarray, "grid_n grid_n"] = None
143		self.custom_node_value_flag: bool = False
144		self.node_color_map: str = "Blues"
145		self.target_token_coord: Coord = None
146		self.preceding_tokens_coords: CoordArray = None
147		self.colormap_center: float | None = None
148		self.cbar_ax = None
149		self.marked_coords: list[tuple[Coord, dict]] = list()
150
151		self.marker_kwargs_current: dict = dict(
152			marker="s",
153			color="green",
154			ms=12,
155		)
156		self.marker_kwargs_next: dict = dict(
157			marker="P",
158			color="green",
159			ms=12,
160		)
161
162		if isinstance(maze, SolvedMaze):
163			self.add_true_path(maze.solution)
164		else:
165			if isinstance(maze, TargetedLatticeMaze):
166				self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)
167
168	@property
169	def solved_maze(self) -> SolvedMaze:
170		"get the underlying `SolvedMaze` object"
171		if self.true_path is None:
172			raise ValueError(
173				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
174			)
175		return SolvedMaze.from_lattice_maze(
176			lattice_maze=self.maze,
177			solution=self.true_path.path,
178		)
179
180	def add_true_path(
181		self,
182		path: CoordList | CoordArray | StyledPath,
183		path_fmt: PathFormat | None = None,
184		**kwargs,
185	) -> MazePlot:
186		"add a true path to the maze with optional formatting"
187		self.true_path = process_path_input(
188			path=path,
189			_default_key="true",
190			path_fmt=path_fmt,
191			**kwargs,
192		)
193
194		return self
195
196	def add_predicted_path(
197		self,
198		path: CoordList | CoordArray | StyledPath,
199		path_fmt: PathFormat | None = None,
200		**kwargs,
201	) -> MazePlot:
202		"""Recieve predicted path and formatting preferences from input and save in predicted_path list.
203
204		Default formatting depends on nuber of paths already saved in predicted path list.
205		"""
206		styled_path: StyledPath = process_path_input(
207			path=path,
208			_default_key="predicted",
209			path_fmt=path_fmt,
210			**kwargs,
211		)
212
213		# set default label and color if not specified
214		if styled_path.label is None:
215			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"
216
217		if styled_path.color is None:
218			color_num: int = len(self.predicted_paths) % len(
219				DEFAULT_PREDICTED_PATH_COLORS,
220			)
221			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]
222
223		self.predicted_paths.append(styled_path)
224		return self
225
226	def add_multiple_paths(
227		self,
228		path_list: Sequence[CoordList | CoordArray | StyledPath],
229	) -> MazePlot:
230		"""Function for adding multiple paths to MazePlot at once.
231
232		> DOCS: what are the two ways?
233		This can be done in two ways:
234		1. Passing a list of
235		"""
236		for path in path_list:
237			self.add_predicted_path(path)
238		return self
239
240	def add_node_values(
241		self,
242		node_values: Float[np.ndarray, "grid_n grid_n"],
243		color_map: str = "Blues",
244		target_token_coord: Coord | None = None,
245		preceeding_tokens_coords: CoordArray = None,
246		colormap_center: float | None = None,
247		colormap_max: float | None = None,
248		hide_colorbar: bool = False,
249	) -> MazePlot:
250		"""add node values to the maze for visualization as a heatmap
251
252		> DOCS: what are these arguments?
253
254		# Parameters:
255		- `node_values : Float[np.ndarray, &quot;grid_n grid_n&quot;]`
256		- `color_map : str`
257			(defaults to `"Blues"`)
258		- `target_token_coord : Coord | None`
259			(defaults to `None`)
260		- `preceeding_tokens_coords : CoordArray`
261			(defaults to `None`)
262		- `colormap_center : float | None`
263			(defaults to `None`)
264		- `colormap_max : float | None`
265			(defaults to `None`)
266		- `hide_colorbar : bool`
267			(defaults to `False`)
268
269		# Returns:
270		- `MazePlot`
271		"""
272		assert node_values.shape == self.maze.grid_shape, (
273			"Please pass node values of the same sape as LatticeMaze.grid_shape"
274		)
275		# assert np.min(node_values) >= 0, "Please pass non-negative node values only."
276
277		self.node_values = node_values
278		# Set flag for choosing cmap while plotting maze
279		self.custom_node_value_flag = True
280		# Retrieve Max node value for plotting, +1e-10 to avoid division by zero
281		self.node_color_map = color_map
282		self.colormap_center = colormap_center
283		self.colormap_max = colormap_max
284		self.hide_colorbar = hide_colorbar
285
286		if target_token_coord is not None:
287			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
288		if preceeding_tokens_coords is not None:
289			for coord in preceeding_tokens_coords:
290				self.marked_coords.append((coord, self.marker_kwargs_current))
291		return self
292
293	def plot(
294		self,
295		dpi: int = 100,
296		title: str = "",
297		fig_ax: tuple | None = None,
298		plain: bool = False,
299	) -> MazePlot:
300		"""Plot the maze and paths."""
301		# set up figure
302		if fig_ax is None:
303			self.fig = plt.figure(dpi=dpi)
304			self.ax = self.fig.add_subplot(1, 1, 1)
305		else:
306			self.fig, self.ax = fig_ax
307
308		# plot maze
309		self._plot_maze()
310
311		# Plot labels
312		if not plain:
313			tick_arr = np.arange(self.maze.grid_shape[0])
314			self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr)
315			self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr)
316			self.ax.set_xlabel("col")
317			self.ax.set_ylabel("row")
318			self.ax.set_title(title)
319		else:
320			self.ax.set_xticks([])
321			self.ax.set_yticks([])
322			self.ax.set_xlabel("")
323			self.ax.set_ylabel("")
324			self.ax.axis("off")
325
326		# plot paths
327		if self.true_path is not None:
328			self._plot_path(self.true_path)
329		for path in self.predicted_paths:
330			self._plot_path(path)
331
332		# plot markers
333		for coord, kwargs in self.marked_coords:
334			self._place_marked_coords([coord], **kwargs)
335
336		return self
337
338	def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
339		"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis."""
340		point = np.array([point[1], point[0]])
341		return self.unit_length * (point + 0.5)
342
343	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
344		"""Mark coordinates on the maze with a marker.
345
346		default marker is a blue "+":
347		`dict(marker="+", color="blue")`
348		"""
349		kwargs = {
350			**dict(marker="+", color="blue"),
351			**kwargs,
352		}
353		for coord in coords:
354			self.marked_coords.append((coord, kwargs))
355
356		return self
357
358	def _place_marked_coords(
359		self,
360		coords: CoordArray | list[Coord],
361		**kwargs,
362	) -> MazePlot:
363		coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
364		self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)
365
366		return self
367
368	def _plot_maze(self) -> None:  # noqa: C901, PLR0912
369		"""Define Colormap and plot maze.
370
371		Colormap: x is -inf: black
372		else: use colormap
373		"""
374		img = self._lattice_maze_to_img()
375
376		# if no node_values have been passed (no colormap)
377		if self.custom_node_value_flag is False:
378			self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)
379
380		else:
381			assert self.node_values is not None, "Please pass node values."
382			assert not np.isnan(self.node_values).any(), (
383				"Please pass node values, they cannot be nan."
384			)
385
386			vals_min: float = np.nanmin(self.node_values)
387			vals_max: float = np.nanmax(self.node_values)
388			# if both are negative or both are positive, set max/min to 0
389			if vals_max < 0.0:
390				vals_max = 0.0
391			elif vals_min > 0.0:
392				vals_min = 0.0
393
394			# adjust vals_max, in case you need consistent colorbar across multiple plots
395			vals_max = self.colormap_max or vals_max
396
397			# create colormap
398			cmap = mpl.colormaps[self.node_color_map]
399			# TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
400			cmap.set_bad(color="black")
401
402			if self.colormap_center is not None:
403				if not (vals_min < self.colormap_center < vals_max):
404					if vals_min == self.colormap_center:
405						vals_min -= 1e-10
406					elif vals_max == self.colormap_center:
407						vals_max += 1e-10
408					else:
409						err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
410						raise ValueError(
411							err_msg,
412						)
413
414				norm = mpl.colors.TwoSlopeNorm(
415					vmin=vals_min,
416					vcenter=self.colormap_center,
417					vmax=vals_max,
418				)
419				_plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
420			else:
421				_plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)
422
423			# Add colorbar based on the condition of self.hide_colorbar
424			if not self.hide_colorbar:
425				ticks = np.linspace(vals_min, vals_max, 5)
426
427				if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
428					ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)
429
430				if (
431					self.colormap_center is not None
432					and self.colormap_center not in ticks
433					and vals_min < self.colormap_center < vals_max
434				):
435					ticks = np.insert(
436						ticks,
437						np.searchsorted(ticks, self.colormap_center),
438						self.colormap_center,
439					)
440
441				cbar = plt.colorbar(
442					_plotted,
443					ticks=ticks,
444					ax=self.ax,
445					cax=self.cbar_ax,
446				)
447				self.cbar_ax = cbar.ax
448
449		# make the boundaries of the image thicker (walls look weird without this)
450		for axis in ["top", "bottom", "left", "right"]:
451			self.ax.spines[axis].set_linewidth(2)
452
453	def _lattice_maze_to_img(
454		self,
455		connection_val_scale: float = 0.93,
456	) -> Bool[np.ndarray, "row col"]:
457		"""Build an image to visualise the maze.
458
459		Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
460		- Nodes have area: (ul-1) * (ul-1) and value 1 by default
461			- take node_value if passed via .add_node_values()
462		- Walls have area: 1 * (ul-1) and value -1
463		- Connections have area: 1 * (ul-1); color and value 0.93 by default
464			- take node_value if passed via .add_node_values()
465
466		Axes definition:
467		(0,0)     col
468		----|----------->
469			|
470		row |
471			|
472			v
473
474		Returns a matrix of side length (ul) * n + 1 where n is the number of nodes.
475		"""
476		# TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
477		node_bdry_hack: int
478		connection_list_processed: Float[np.ndarray, "dim row col"]
479		# Set node and connection values
480		if self.node_values is None:
481			scaled_node_values = np.ones(self.maze.grid_shape)
482			connection_values = scaled_node_values * connection_val_scale
483			node_bdry_hack = 0
484			# TODO: hack
485			# invert connection list
486			connection_list_processed = np.logical_not(self.maze.connection_list)
487		else:
488			# TODO: hack
489			scaled_node_values = self.node_values
490			# connection_values = scaled_node_values
491			connection_values = np.full_like(scaled_node_values, np.nan)
492			node_bdry_hack = 1
493			connection_list_processed = self.maze.connection_list
494
495		# Create background image (all pixels set to -1, walls everywhere)
496		img: Float[np.ndarray, "row col"] = -np.ones(
497			(
498				self.maze.grid_shape[0] * self.unit_length + 1,
499				self.maze.grid_shape[1] * self.unit_length + 1,
500			),
501			dtype=float,
502		)
503
504		# Draw nodes and connections by iterating through lattice
505		for row in range(self.maze.grid_shape[0]):
506			for col in range(self.maze.grid_shape[1]):
507				# Draw node
508				img[
509					row * self.unit_length + 1 : (row + 1) * self.unit_length
510					+ node_bdry_hack,
511					col * self.unit_length + 1 : (col + 1) * self.unit_length
512					+ node_bdry_hack,
513				] = scaled_node_values[row, col]
514
515				# Down connection
516				if not connection_list_processed[0, row, col]:
517					img[
518						(row + 1) * self.unit_length,
519						col * self.unit_length + 1 : (col + 1) * self.unit_length,
520					] = connection_values[row, col]
521
522				# Right connection
523				if not connection_list_processed[1, row, col]:
524					img[
525						row * self.unit_length + 1 : (row + 1) * self.unit_length,
526						(col + 1) * self.unit_length,
527					] = connection_values[row, col]
528
529		return img
530
531	def _plot_path(self, path_format: PathFormat) -> None:
532		if len(path_format.path) == 0:
533			warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
534			return
535		p_transformed = np.array(
536			[self._rowcol_to_coord(coord) for coord in path_format.path],
537		)
538		if path_format.quiver_kwargs is not None:
539			try:
540				x: np.ndarray = p_transformed[:, 0]
541				y: np.ndarray = p_transformed[:, 1]
542			except Exception as e:
543				err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
544				raise ValueError(
545					err_msg,
546				) from e
547
548			# Generate colors from the colormap
549			if path_format.cmap is not None:
550				n = len(x) - 1  # Number of arrows
551				cmap = plt.get_cmap(path_format.cmap)
552				colors = [cmap(i / n) for i in range(n)]
553			else:
554				colors = path_format.color
555
556			self.ax.quiver(
557				x[:-1],
558				y[:-1],
559				x[1:] - x[:-1],
560				y[1:] - y[:-1],
561				scale_units="xy",
562				angles="xy",
563				scale=1,
564				color=colors,
565				**path_format.quiver_kwargs,
566			)
567		else:
568			self.ax.plot(
569				p_transformed[:, 0],
570				p_transformed[:, 1],
571				path_format.fmt,
572				lw=path_format.line_width,
573				color=path_format.color,
574				label=path_format.label,
575			)
576		# mark endpoints
577		self.ax.plot(
578			[p_transformed[0][0]],
579			[p_transformed[0][1]],
580			"o",
581			color=path_format.color,
582			ms=10,
583		)
584		self.ax.plot(
585			[p_transformed[-1][0]],
586			[p_transformed[-1][1]],
587			"x",
588			color=path_format.color,
589			ms=10,
590		)
591
592	def to_ascii(
593		self,
594		show_endpoints: bool = True,
595		show_solution: bool = True,
596	) -> str:
597		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
598		if self.true_path:
599			return self.solved_maze.as_ascii(
600				show_endpoints=show_endpoints,
601				show_solution=show_solution,
602			)
603		else:
604			return self.maze.as_ascii(show_endpoints=show_endpoints)

Class for displaying mazes and paths

MazePlot( maze: maze_dataset.LatticeMaze, unit_length: int = 14)
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
	"""Set up a plotter for `maze`.

	`unit_length` sets the ratio between node size and wall thickness in the
	image. Wall thickness is fixed to 1px. A "unit" consists of a single node
	and its right and lower connection/wall. Example: ul = 14 yields a 13:1
	ratio between node size and wall thickness.

	If `maze` is a `SolvedMaze`, its solution is added as the true path. If it
	is a `TargetedLatticeMaze`, it is solved first and that solution is added.
	"""
	self.unit_length: int = unit_length
	self.maze: LatticeMaze = maze
	self.true_path: StyledPath | None = None
	self.predicted_paths: list[StyledPath] = []
	self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
	self.custom_node_value_flag: bool = False
	self.node_color_map: str = "Blues"
	self.target_token_coord: Coord | None = None
	self.preceding_tokens_coords: CoordArray | None = None
	self.colormap_center: float | None = None
	# heatmap options normally set by `add_node_values`; initialized here so
	# the attributes always exist on every instance
	self.colormap_max: float | None = None
	self.hide_colorbar: bool = False
	self.cbar_ax = None
	self.marked_coords: list[tuple[Coord, dict]] = []

	# marker styles used by `add_node_values`
	self.marker_kwargs_current: dict = dict(
		marker="s",
		color="green",
		ms=12,
	)
	self.marker_kwargs_next: dict = dict(
		marker="P",
		color="green",
		ms=12,
	)

	if isinstance(maze, SolvedMaze):
		self.add_true_path(maze.solution)
	elif isinstance(maze, TargetedLatticeMaze):
		self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

UNIT_LENGTH: Set ratio between node size and wall thickness in image.

Wall thickness is fixed to 1px. A "unit" consists of a single node and its right and lower connection/wall. Example: ul = 14 yields a 13:1 ratio between node size and wall thickness.

unit_length: int
true_path: StyledPath | None
predicted_paths: list[StyledPath]
node_values: jaxtyping.Float[ndarray, 'grid_n grid_n']
custom_node_value_flag: bool
node_color_map: str
target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2']
preceding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2']
colormap_center: float | None
cbar_ax
marked_coords: list[tuple[jaxtyping.Int8[ndarray, 'row_col=2'], dict]]
marker_kwargs_current: dict
marker_kwargs_next: dict
solved_maze: maze_dataset.SolvedMaze
@property
def solved_maze(self) -> SolvedMaze:
	"""The maze combined with the current true path, as a `SolvedMaze`.

	# Raises:
	- `ValueError` if no true path has been added yet
	"""
	true_path = self.true_path
	if true_path is not None:
		return SolvedMaze.from_lattice_maze(
			lattice_maze=self.maze,
			solution=true_path.path,
		)
	raise ValueError(
		"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
	)

get the underlying SolvedMaze object

def add_true_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
def add_true_path(
	self,
	path: CoordList | CoordArray | StyledPath,
	path_fmt: PathFormat | None = None,
	**kwargs,
) -> MazePlot:
	"""Set the true path of the maze, with optional formatting.

	`path`, `path_fmt` and any extra keyword arguments are converted by
	`process_path_input` using the `"true"` default format key. Any previously
	set true path is replaced. Returns `self` for chaining.
	"""
	styled: StyledPath = process_path_input(
		path=path,
		_default_key="true",
		path_fmt=path_fmt,
		**kwargs,
	)
	self.true_path = styled
	return self

add a true path to the maze with optional formatting

def add_predicted_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
def add_predicted_path(
	self,
	path: CoordList | CoordArray | StyledPath,
	path_fmt: PathFormat | None = None,
	**kwargs,
) -> MazePlot:
	"""Receive a predicted path and formatting preferences, and append it to `self.predicted_paths`.

	The path is converted by `process_path_input` using the `"predicted"`
	default format key. If the result has no label, it gets
	"predicted path {k}", where k is its 1-based position. If it has no color,
	one is taken from `DEFAULT_PREDICTED_PATH_COLORS`, cycling by the number of
	predicted paths already stored. Returns `self` for chaining.
	"""
	styled_path: StyledPath = process_path_input(
		path=path,
		_default_key="predicted",
		path_fmt=path_fmt,
		**kwargs,
	)
	n_existing: int = len(self.predicted_paths)

	# fill in defaults that depend on how many paths were added before
	if styled_path.label is None:
		styled_path.label = f"predicted path {n_existing + 1}"
	if styled_path.color is None:
		palette = DEFAULT_PREDICTED_PATH_COLORS
		styled_path.color = palette[n_existing % len(palette)]

	self.predicted_paths.append(styled_path)
	return self

Receive a predicted path and formatting preferences from input and save them in the predicted_paths list.

Default formatting depends on the number of paths already saved in the predicted path list.

def add_multiple_paths( self, path_list: Sequence[list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | StyledPath]) -> MazePlot:
def add_multiple_paths(
	self,
	path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
	"""Add several predicted paths to the plot at once.

	Each element of `path_list` is passed on its own to `add_predicted_path`,
	with no extra formatting arguments. An element can be given in one of two
	ways:
	1. a plain path (`CoordList` or `CoordArray`), which receives the default
	   predicted-path formatting; a missing label/color is filled in
	   automatically, or
	2. a `StyledPath` (a `PathFormat` bundled with its path), which carries
	   its own formatting.

	Returns `self` for chaining.
	"""
	for path in path_list:
		self.add_predicted_path(path)
	return self

Function for adding multiple paths to MazePlot at once.

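Each element of path_list is passed to add_predicted_path, so this can be done in two ways:

  1. Passing a list of plain paths (CoordList or CoordArray), which receive the default predicted-path formatting, with a label and color assigned automatically when missing.
  2. Passing a list of StyledPath objects (a PathFormat bundled with its path), which carry their own formatting.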
def add_node_values( self, node_values: jaxtyping.Float[ndarray, 'grid_n grid_n'], color_map: str = 'Blues', target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, preceeding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] = None, colormap_center: float | None = None, colormap_max: float | None = None, hide_colorbar: bool = False) -> MazePlot:
def add_node_values(
	self,
	node_values: Float[np.ndarray, "grid_n grid_n"],
	color_map: str = "Blues",
	target_token_coord: Coord | None = None,
	preceeding_tokens_coords: CoordArray | None = None,
	colormap_center: float | None = None,
	colormap_max: float | None = None,
	hide_colorbar: bool = False,
) -> MazePlot:
	"""add node values to the maze for visualization as a heatmap

	The values and colormap options are stored on `self` and used when the
	maze is drawn by `plot`. Markers for the target token and the preceding
	tokens can optionally be queued as well.

	# Parameters:
	- `node_values : Float[np.ndarray, "grid_n grid_n"]`
		one value per node; shape must equal `self.maze.grid_shape`.
		must not contain NaN (checked when plotting)
	- `color_map : str`
		name of the matplotlib colormap used for the heatmap
		(defaults to `"Blues"`)
	- `target_token_coord : Coord | None`
		if given, marked with `self.marker_kwargs_next`
		(defaults to `None`)
	- `preceeding_tokens_coords : CoordArray | None`
		if given, each coordinate is marked with `self.marker_kwargs_current`
		(defaults to `None`)
	- `colormap_center : float | None`
		if given, the colormap is centered on this value (`TwoSlopeNorm`);
		it must lie within the color range
		(defaults to `None`)
	- `colormap_max : float | None`
		if set, used as the upper limit of the color range, e.g. to keep
		colorbars consistent across several plots
		(defaults to `None`)
	- `hide_colorbar : bool`
		if True, no colorbar is drawn
		(defaults to `False`)

	# Returns:
	- `MazePlot` -- `self`, for chaining

	# Raises:
	- `AssertionError` if `node_values.shape != self.maze.grid_shape`
	"""
	assert node_values.shape == self.maze.grid_shape, (
		"Please pass node values of the same shape as LatticeMaze.grid_shape"
	)
	# negative values are allowed: the color range is extended to include 0

	self.node_values = node_values
	# Set flag for choosing cmap while plotting maze
	self.custom_node_value_flag = True
	self.node_color_map = color_map
	self.colormap_center = colormap_center
	self.colormap_max = colormap_max
	self.hide_colorbar = hide_colorbar

	if target_token_coord is not None:
		self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
	if preceeding_tokens_coords is not None:
		for coord in preceeding_tokens_coords:
			self.marked_coords.append((coord, self.marker_kwargs_current))
	return self

add node values to the maze for visualization as a heatmap

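Colors each node of the maze by its value to draw a heatmap, and optionally marks the target token and preceding token coordinates.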

Parameters:

  • node_values : Float[np.ndarray, "grid_n grid_n"]
  • color_map : str (defaults to "Blues")
  • target_token_coord : Coord | None (defaults to None)
  • preceeding_tokens_coords : CoordArray (defaults to None)
  • colormap_center : float | None (defaults to None)
  • colormap_max : float | None (defaults to None)
  • hide_colorbar : bool (defaults to False)

Returns:

  • MazePlot

def plot( self, dpi: int = 100, title: str = '', fig_ax: tuple | None = None, plain: bool = False) -> MazePlot:
def plot(
	self,
	dpi: int = 100,
	title: str = "",
	fig_ax: tuple | None = None,
	plain: bool = False,
) -> MazePlot:
	"""Plot the maze, then the paths, then the marked coordinates.

	# Parameters:
	- `dpi : int`
		resolution of the new figure; only used when `fig_ax` is None
		(defaults to `100`)
	- `title : str`
		axes title; not set when `plain` is True
		(defaults to `""`)
	- `fig_ax : tuple | None`
		existing `(figure, axes)` pair to draw on; if None, a new figure with
		a single subplot is created
		(defaults to `None`)
	- `plain : bool`
		if True, hide ticks, axis labels and the axis frame
		(defaults to `False`)

	# Returns:
	- `MazePlot` -- `self`, with `self.fig` and `self.ax` set
	"""
	# set up figure
	if fig_ax is None:
		self.fig = plt.figure(dpi=dpi)
		self.ax = self.fig.add_subplot(1, 1, 1)
	else:
		self.fig, self.ax = fig_ax

	# plot maze
	self._plot_maze()

	# Plot labels
	if not plain:
		# node centers sit in the middle of each unit. Rows and columns are
		# ticked separately so non-square mazes get correct labels
		# (x runs over columns, y over rows).
		row_ticks = np.arange(self.maze.grid_shape[0])
		col_ticks = np.arange(self.maze.grid_shape[1])
		self.ax.set_xticks(self.unit_length * (col_ticks + 0.5), col_ticks)
		self.ax.set_yticks(self.unit_length * (row_ticks + 0.5), row_ticks)
		self.ax.set_xlabel("col")
		self.ax.set_ylabel("row")
		self.ax.set_title(title)
	else:
		self.ax.set_xticks([])
		self.ax.set_yticks([])
		self.ax.set_xlabel("")
		self.ax.set_ylabel("")
		self.ax.axis("off")

	# plot paths
	if self.true_path is not None:
		self._plot_path(self.true_path)
	for path in self.predicted_paths:
		self._plot_path(path)

	# plot markers
	for coord, kwargs in self.marked_coords:
		self._place_marked_coords([coord], **kwargs)

	return self

Plot the maze and paths.

def mark_coords( self, coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] | list[jaxtyping.Int8[ndarray, 'row_col=2']], **kwargs) -> MazePlot:
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
	"""Queue `coords` to be marked with a marker when `plot` is called.

	Keyword arguments are marker style options forwarded to `ax.plot`. Unless
	overridden, the marker is a blue "+":
	`dict(marker="+", color="blue")`
	"""
	style: dict = dict(marker="+", color="blue")
	style.update(kwargs)
	# all queued coords share one style dict
	self.marked_coords.extend((coord, style) for coord in coords)
	return self

Mark coordinates on the maze with a marker.

default marker is a blue "+": dict(marker="+", color="blue")

def to_ascii(self, show_endpoints: bool = True, show_solution: bool = True) -> str:
def to_ascii(
	self,
	show_endpoints: bool = True,
	show_solution: bool = True,
) -> str:
	"""Return an ASCII rendering of the maze.

	If a true path has been added, this delegates to
	`self.solved_maze.as_ascii()` so the solution can be shown. Otherwise it
	delegates to `self.maze.as_ascii()`, and `show_solution` is ignored.
	"""
	# explicit None check instead of relying on the truthiness of a StyledPath
	if self.true_path is None:
		return self.maze.as_ascii(show_endpoints=show_endpoints)
	return self.solved_maze.as_ascii(
		show_endpoints=show_endpoints,
		show_solution=show_solution,
	)

wrapper for self.solved_maze.as_ascii(), shows the path if we have self.true_path