docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.plotting

utilities for plotting mazes and printing tokens

  • any LatticeMaze or SolvedMaze comes with an as_pixels() method that returns a 2D numpy array of pixel values, but this is somewhat limited
  • MazePlot is a class that can be used to plot mazes and paths in a more customizable way
  • print_tokens contains utilities for printing tokens, colored by their type, position, or some custom weights (e.g. attention weights)

 1"""utilities for plotting mazes and printing tokens
 2
 3- any `LatticeMaze` or `SolvedMaze` comes with a `as_pixels()` method that returns
 4  a 2D numpy array of pixel values, but this is somewhat limited
 5- `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way
 6- `print_tokens` contains utilities for printing tokens, colored by their type, position, or some custom weights (i.e. attention weights)
 7"""
 8
 9from maze_dataset.plotting.plot_dataset import plot_dataset_mazes, print_dataset_mazes
10from maze_dataset.plotting.plot_maze import DEFAULT_FORMATS, MazePlot, PathFormat
11from maze_dataset.plotting.print_tokens import (
12	color_maze_tokens_AOTP,
13	color_tokens_cmap,
14	color_tokens_rgb,
15)
16
# public API of `maze_dataset.plotting`:
# first the plotting submodules, then the names re-exported from them
__all__ = [
	# submodules
	"plot_dataset", "plot_maze", "plot_svg_fancy", "plot_tokens", "print_tokens",
	# re-exported functions, classes and constants
	"plot_dataset_mazes", "print_dataset_mazes",
	"DEFAULT_FORMATS", "MazePlot", "PathFormat",
	"color_tokens_cmap", "color_maze_tokens_AOTP", "color_tokens_rgb",
]

def plot_dataset_mazes( ds: maze_dataset.MazeDataset, count: int | None = None, figsize_mult: tuple[float, float] = (1.0, 2.0), title: bool | str = True) -> tuple | None:
def plot_dataset_mazes(
	ds: MazeDataset,
	count: int | None = None,
	figsize_mult: tuple[float, float] = (1.0, 2.0),
	title: bool | str = True,
) -> tuple | None:
	"""plot `count` mazes from the dataset `ds` in a single figure using `SolvedMaze.as_pixels()`

	# Parameters:
	- `ds : MazeDataset`
		dataset whose first `count` mazes are plotted, one subplot each
	- `count : int | None`
		number of mazes to plot. `None` plots every maze; values larger than
		`len(ds)` are clamped to `len(ds)`; `0` (or less) plots nothing.
		(defaults to `None`)
	- `figsize_mult : tuple[float, float]`
		the figure size is `(count * figsize_mult[0], figsize_mult[1])`
		(defaults to `(1.0, 2.0)`)
	- `title : bool | str`
		a non-empty string is used verbatim as the figure title; `True` builds a
		title from the dataset config; `False` or `""` adds no title
		(defaults to `True`)

	# Returns:
	- `tuple | None`
		`(fig, axes)`, where `axes` is indexable even for a single maze, or `None`
		if there is nothing to plot
	"""
	# `None` means "all mazes"; an explicit `0` must NOT fall back to "all" (as
	# `count or len(ds)` did), and asking for more mazes than exist is clamped
	count = len(ds) if count is None else min(count, len(ds))
	if count <= 0:
		print("No mazes to plot for dataset")
		return None

	fig, axes = plt.subplots(
		1,
		count,
		figsize=(count * figsize_mult[0], figsize_mult[1]),
	)
	# with a single subplot `plt.subplots` returns a bare Axes, not an array
	if count == 1:
		axes = [axes]
	for i in range(count):
		axes[i].imshow(ds[i].as_pixels())
		# remove ticks
		axes[i].set_xticks([])
		axes[i].set_yticks([])

	# set title
	if title:
		if isinstance(title, str):
			fig.suptitle(title)
		else:
			kwargs: dict = {
				"grid_n": ds.cfg.grid_n,
				**ds.cfg.maze_ctor_kwargs,
			}
			fig.suptitle(
				f"{ds.cfg.to_fname()}\n{ds.cfg.maze_ctor.__name__}({', '.join(f'{k}={v}' for k, v in kwargs.items())})",
			)

	# tight layout
	fig.tight_layout()
	# remove whitespace between title and subplots
	fig.subplots_adjust(top=1.0)

	return fig, axes

plot count mazes from the dataset ds in a single figure using SolvedMaze.as_pixels()

# default path styles, keyed by path kind: the true path is a thick red dashed
# line; predicted paths are dotted and drawn with quiver arrows
DEFAULT_FORMATS = {
	"true": PathFormat(
		label="true path",
		fmt="--",
		color="red",
		cmap=None,
		line_width=2.5,
		quiver_kwargs=None,
	),
	"predicted": PathFormat(
		label=None,
		fmt=":",
		color=None,
		cmap=None,
		line_width=2,
		quiver_kwargs={"width": 0.015},
	),
}
class MazePlot:
class MazePlot:
	"""Class for displaying mazes and paths

	Construct it from a `LatticeMaze`. If the maze is a `SolvedMaze`, its solution
	becomes the true path. A `TargetedLatticeMaze` is solved first, and that
	solution becomes the true path.

	Add predicted paths, node values (heatmap) or marked coordinates with the
	`add_*` / `mark_coords` methods, then call `plot()`. These methods all return
	`self`, so calls can be chained.
	"""

	def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
		"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.

		Wall thickness is fixed to 1px
		A "unit" consists of a single node and the right and lower connection/wall.
		Example: ul = 14 yields 13:1 ratio between node size and wall thickness
		"""
		self.unit_length: int = unit_length
		self.maze: LatticeMaze = maze
		self.true_path: StyledPath | None = None
		self.predicted_paths: list[StyledPath] = []
		self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
		self.custom_node_value_flag: bool = False
		self.node_color_map: str = "Blues"
		self.target_token_coord: Coord | None = None
		self.preceding_tokens_coords: CoordArray | None = None
		self.colormap_center: float | None = None
		# normally set by `add_node_values`; initialized here so these attributes
		# always exist, even if the object is inspected before node values are added
		self.colormap_max: float | None = None
		self.hide_colorbar: bool = False
		# axes of the colorbar: stored after the first colorbar is drawn and passed
		# as `cax` to later colorbar draws
		self.cbar_ax = None
		# (coord, `Axes.plot` kwargs) pairs drawn as markers by `plot()`
		self.marked_coords: list[tuple[Coord, dict]] = list()

		# marker styles used by `add_node_values` for preceding / target tokens
		self.marker_kwargs_current: dict = dict(
			marker="s",
			color="green",
			ms=12,
		)
		self.marker_kwargs_next: dict = dict(
			marker="P",
			color="green",
			ms=12,
		)

		if isinstance(maze, SolvedMaze):
			self.add_true_path(maze.solution)
		elif isinstance(maze, TargetedLatticeMaze):
			self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

	@property
	def solved_maze(self) -> SolvedMaze:
		"""get the underlying `SolvedMaze` object

		# Raises:
		- `ValueError` : if no true path has been added
		"""
		if self.true_path is None:
			raise ValueError(
				"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
			)
		return SolvedMaze.from_lattice_maze(
			lattice_maze=self.maze,
			solution=self.true_path.path,
		)

	def add_true_path(
		self,
		path: CoordList | CoordArray | StyledPath,
		path_fmt: PathFormat | None = None,
		**kwargs,
	) -> MazePlot:
		"""add a true path to the maze with optional formatting

		replaces any previously set true path; returns `self` for chaining
		"""
		self.true_path = process_path_input(
			path=path,
			_default_key="true",
			path_fmt=path_fmt,
			**kwargs,
		)

		return self

	def add_predicted_path(
		self,
		path: CoordList | CoordArray | StyledPath,
		path_fmt: PathFormat | None = None,
		**kwargs,
	) -> MazePlot:
		"""Receive a predicted path and formatting preferences and append it to `self.predicted_paths`.

		Default formatting depends on the number of paths already saved: a path
		without a label gets `"predicted path {k}"` (1-based), and a path without a
		color gets the next entry of `DEFAULT_PREDICTED_PATH_COLORS`, cycling.
		"""
		styled_path: StyledPath = process_path_input(
			path=path,
			_default_key="predicted",
			path_fmt=path_fmt,
			**kwargs,
		)

		# set default label and color if not specified
		if styled_path.label is None:
			styled_path.label = f"predicted path {len(self.predicted_paths) + 1}"

		if styled_path.color is None:
			color_num: int = len(self.predicted_paths) % len(
				DEFAULT_PREDICTED_PATH_COLORS,
			)
			styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num]

		self.predicted_paths.append(styled_path)
		return self

	def add_multiple_paths(
		self,
		path_list: Sequence[CoordList | CoordArray | StyledPath],
	) -> MazePlot:
		"""Add several predicted paths to the MazePlot at once.

		Each element of `path_list` is passed, in order, to `add_predicted_path`
		with default formatting. Elements may therefore be anything
		`add_predicted_path` accepts: a list of coordinates, a coordinate array, or
		a `StyledPath`. Missing labels/colors are filled in by `add_predicted_path`.

		# Returns:
		- `MazePlot`
			`self`, for chaining
		"""
		for path in path_list:
			self.add_predicted_path(path)
		return self

	def add_node_values(
		self,
		node_values: Float[np.ndarray, "grid_n grid_n"],
		color_map: str = "Blues",
		target_token_coord: Coord | None = None,
		preceeding_tokens_coords: CoordArray = None,
		colormap_center: float | None = None,
		colormap_max: float | None = None,
		hide_colorbar: bool = False,
	) -> MazePlot:
		"""add node values to the maze for visualization as a heatmap

		# Parameters:
		- `node_values : Float[np.ndarray, "grid_n grid_n"]`
			one value per node; the shape must equal `self.maze.grid_shape`.
			nan values are rejected when plotting.
		- `color_map : str`
			name of the matplotlib colormap used for the heatmap
			(defaults to `"Blues"`)
		- `target_token_coord : Coord | None`
			if given, marked on the plot with `self.marker_kwargs_next`
			(defaults to `None`)
		- `preceeding_tokens_coords : CoordArray`
			if given, every coordinate is marked with `self.marker_kwargs_current`
			(defaults to `None`)
		- `colormap_center : float | None`
			if given, a diverging `TwoSlopeNorm` centered on this value is used;
			plotting raises `ValueError` if it lies outside the plotted value range
			(defaults to `None`)
		- `colormap_max : float | None`
			if given, overrides the upper end of the color scale, e.g. to get a
			consistent colorbar across multiple plots
			(defaults to `None`)
		- `hide_colorbar : bool`
			if `True`, no colorbar is drawn
			(defaults to `False`)

		# Returns:
		- `MazePlot`
			`self`, for chaining
		"""
		assert node_values.shape == self.maze.grid_shape, (
			"Please pass node values of the same shape as LatticeMaze.grid_shape"
		)

		self.node_values = node_values
		# Set flag for choosing cmap while plotting maze
		self.custom_node_value_flag = True
		self.node_color_map = color_map
		self.colormap_center = colormap_center
		self.colormap_max = colormap_max
		self.hide_colorbar = hide_colorbar

		if target_token_coord is not None:
			self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
		if preceeding_tokens_coords is not None:
			for coord in preceeding_tokens_coords:
				self.marked_coords.append((coord, self.marker_kwargs_current))
		return self

	def plot(
		self,
		dpi: int = 100,
		title: str = "",
		fig_ax: tuple | None = None,
		plain: bool = False,
	) -> MazePlot:
		"""Plot the maze and paths.

		# Parameters:
		- `dpi : int`
			resolution of the newly created figure; unused when `fig_ax` is given
			(defaults to `100`)
		- `title : str`
			axes title, only set when `plain` is `False`
			(defaults to `""`)
		- `fig_ax : tuple | None`
			existing `(fig, ax)` pair to draw into; if `None`, a new figure with a
			single subplot is created
			(defaults to `None`)
		- `plain : bool`
			if `True`, hide ticks, axis labels and the axis frame
			(defaults to `False`)

		# Returns:
		- `MazePlot`
			`self`, with `self.fig` and `self.ax` set
		"""
		# set up figure
		if fig_ax is None:
			self.fig = plt.figure(dpi=dpi)
			self.ax = self.fig.add_subplot(1, 1, 1)
		else:
			self.fig, self.ax = fig_ax

		# plot maze
		self._plot_maze()

		# Plot labels
		if not plain:
			# x is the column axis and y the row axis (see `_rowcol_to_coord`), so for
			# non-square mazes the two tick sets come from different grid dimensions
			n_rows: int = self.maze.grid_shape[0]
			n_cols: int = self.maze.grid_shape[1]
			row_idx = np.arange(n_rows)
			col_idx = np.arange(n_cols)
			self.ax.set_xticks(self.unit_length * (col_idx + 0.5), col_idx)
			self.ax.set_yticks(self.unit_length * (row_idx + 0.5), row_idx)
			self.ax.set_xlabel("col")
			self.ax.set_ylabel("row")
			self.ax.set_title(title)
		else:
			self.ax.set_xticks([])
			self.ax.set_yticks([])
			self.ax.set_xlabel("")
			self.ax.set_ylabel("")
			self.ax.axis("off")

		# plot paths
		if self.true_path is not None:
			self._plot_path(self.true_path)
		for path in self.predicted_paths:
			self._plot_path(path)

		# plot markers
		for coord, kwargs in self.marked_coords:
			self._place_marked_coords([coord], **kwargs)

		return self

	def _rowcol_to_coord(self, point: Coord) -> np.ndarray:
		"""Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.

		The result is the pixel position of the node's center in the image built by
		`_lattice_maze_to_img`.
		"""
		point = np.array([point[1], point[0]])
		return self.unit_length * (point + 0.5)

	def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
		"""Mark coordinates on the maze with a marker.

		default marker is a blue "+":
		`dict(marker="+", color="blue")`

		`kwargs` override the defaults and are passed to `Axes.plot` when `plot()`
		is called.
		"""
		kwargs = {
			**dict(marker="+", color="blue"),
			**kwargs,
		}
		for coord in coords:
			self.marked_coords.append((coord, kwargs))

		return self

	def _place_marked_coords(
		self,
		coords: CoordArray | list[Coord],
		**kwargs,
	) -> MazePlot:
		"""draw `coords` on `self.ax` with `Axes.plot(**kwargs)`"""
		coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords])
		self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs)

		return self

	def _plot_maze(self) -> None:  # noqa: C901, PLR0912
		"""Define Colormap and plot maze.

		Without node values, the image is drawn with the "gray" colormap on [-1, 1],
		so walls (-1) are black.
		With node values, the colormap given to `add_node_values` is used, wall
		pixels (nan in the image) are drawn black via `set_bad`, and a colorbar is
		added unless `hide_colorbar` is set.
		"""
		img = self._lattice_maze_to_img()

		# if no node_values have been passed (no colormap)
		if self.custom_node_value_flag is False:
			self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1)

		else:
			assert self.node_values is not None, "Please pass node values."
			assert not np.isnan(self.node_values).any(), (
				"Please pass node values, they cannot be nan."
			)

			vals_min: float = np.nanmin(self.node_values)
			vals_max: float = np.nanmax(self.node_values)
			# if both are negative or both are positive, set max/min to 0
			if vals_max < 0.0:
				vals_max = 0.0
			elif vals_min > 0.0:
				vals_min = 0.0

			# adjust vals_max, in case you need consistent colorbar across multiple plots.
			# compare against None so that an explicit `colormap_max=0.0` is honored
			if self.colormap_max is not None:
				vals_max = self.colormap_max

			# create colormap (`mpl.colormaps[...]` returns a copy, so `set_bad` does
			# not modify the globally registered colormap)
			cmap = mpl.colormaps[self.node_color_map]
			# TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black
			cmap.set_bad(color="black")

			if self.colormap_center is not None:
				if not (vals_min < self.colormap_center < vals_max):
					if vals_min == self.colormap_center:
						vals_min -= 1e-10
					elif vals_max == self.colormap_center:
						vals_max += 1e-10
					else:
						err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}"
						raise ValueError(
							err_msg,
						)

				norm = mpl.colors.TwoSlopeNorm(
					vmin=vals_min,
					vcenter=self.colormap_center,
					vmax=vals_max,
				)
				_plotted = self.ax.imshow(img, cmap=cmap, norm=norm)
			else:
				_plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max)

			# Add colorbar based on the condition of self.hide_colorbar
			if not self.hide_colorbar:
				ticks = np.linspace(vals_min, vals_max, 5)

				if (vals_min < 0.0 < vals_max) and (0.0 not in ticks):
					ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0)

				if (
					self.colormap_center is not None
					and self.colormap_center not in ticks
					and vals_min < self.colormap_center < vals_max
				):
					ticks = np.insert(
						ticks,
						np.searchsorted(ticks, self.colormap_center),
						self.colormap_center,
					)

				cbar = plt.colorbar(
					_plotted,
					ticks=ticks,
					ax=self.ax,
					cax=self.cbar_ax,
				)
				self.cbar_ax = cbar.ax

		# make the boundaries of the image thicker (walls look weird without this)
		for axis in ["top", "bottom", "left", "right"]:
			self.ax.spines[axis].set_linewidth(2)

	def _lattice_maze_to_img(
		self,
		connection_val_scale: float = 0.93,
	) -> Float[np.ndarray, "row col"]:
		"""Build an image to visualise the maze.

		Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul.
		- Nodes have area: (ul-1) * (ul-1) and value 1 by default
			- take node_value if passed via .add_node_values()
		- Walls have area: 1 * (ul-1) and value -1
		- Connections have area: 1 * (ul-1); color and value 0.93 by default
			- take node_value if passed via .add_node_values()

		When node values are set, nodes are drawn one pixel larger so they also cover
		their right/lower boundary line, and missing connections are then painted
		nan (drawn black by `_plot_maze`).

		Axes definition:
		(0,0)     col
		----|----------->
			|
		row |
			|
			v

		Returns a float matrix of side length (ul) * n + 1 where n is the number of nodes.
		"""
		# TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. if you dont add 1, you get ugly colors between nodes when they are colored
		node_bdry_hack: int
		connection_list_processed: Float[np.ndarray, "dim row col"]
		# Set node and connection values
		if self.node_values is None:
			scaled_node_values = np.ones(self.maze.grid_shape)
			connection_values = scaled_node_values * connection_val_scale
			node_bdry_hack = 0
			# TODO: hack
			# invert connection list, so that existing connections get drawn below
			connection_list_processed = np.logical_not(self.maze.connection_list)
		else:
			# TODO: hack
			scaled_node_values = self.node_values
			# always a float array: `np.full_like` would inherit an integer dtype
			# from integer node values, and nan cannot be represented there
			connection_values = np.full(np.shape(scaled_node_values), np.nan)
			node_bdry_hack = 1
			# not inverted: missing connections get painted nan below
			connection_list_processed = self.maze.connection_list

		# Create background image (all pixels set to -1, walls everywhere)
		img: Float[np.ndarray, "row col"] = -np.ones(
			(
				self.maze.grid_shape[0] * self.unit_length + 1,
				self.maze.grid_shape[1] * self.unit_length + 1,
			),
			dtype=float,
		)

		# Draw nodes and connections by iterating through lattice
		for row in range(self.maze.grid_shape[0]):
			for col in range(self.maze.grid_shape[1]):
				# Draw node
				img[
					row * self.unit_length + 1 : (row + 1) * self.unit_length
					+ node_bdry_hack,
					col * self.unit_length + 1 : (col + 1) * self.unit_length
					+ node_bdry_hack,
				] = scaled_node_values[row, col]

				# Down connection
				if not connection_list_processed[0, row, col]:
					img[
						(row + 1) * self.unit_length,
						col * self.unit_length + 1 : (col + 1) * self.unit_length,
					] = connection_values[row, col]

				# Right connection
				if not connection_list_processed[1, row, col]:
					img[
						row * self.unit_length + 1 : (row + 1) * self.unit_length,
						(col + 1) * self.unit_length,
					] = connection_values[row, col]

		return img

	def _plot_path(self, path_format: StyledPath) -> None:
		"""Draw one styled path on `self.ax`.

		Drawn as arrows with `Axes.quiver` when `quiver_kwargs` is set (colored along
		`cmap` if given), otherwise as a line with `Axes.plot`. The first point is
		marked with "o" and the last with "x". Empty paths are skipped with a warning.
		"""
		if len(path_format.path) == 0:
			warnings.warn(f"Empty path, skipping plotting\n{path_format = }")
			return
		p_transformed = np.array(
			[self._rowcol_to_coord(coord) for coord in path_format.path],
		)
		if path_format.quiver_kwargs is not None:
			try:
				x: np.ndarray = p_transformed[:, 0]
				y: np.ndarray = p_transformed[:, 1]
			except Exception as e:
				err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}"
				raise ValueError(
					err_msg,
				) from e

			# Generate colors from the colormap
			if path_format.cmap is not None:
				n = len(x) - 1  # Number of arrows
				cmap = plt.get_cmap(path_format.cmap)
				colors = [cmap(i / n) for i in range(n)]
			else:
				colors = path_format.color

			self.ax.quiver(
				x[:-1],
				y[:-1],
				x[1:] - x[:-1],
				y[1:] - y[:-1],
				scale_units="xy",
				angles="xy",
				scale=1,
				color=colors,
				**path_format.quiver_kwargs,
			)
		else:
			self.ax.plot(
				p_transformed[:, 0],
				p_transformed[:, 1],
				path_format.fmt,
				lw=path_format.line_width,
				color=path_format.color,
				label=path_format.label,
			)
		# mark endpoints
		self.ax.plot(
			[p_transformed[0][0]],
			[p_transformed[0][1]],
			"o",
			color=path_format.color,
			ms=10,
		)
		self.ax.plot(
			[p_transformed[-1][0]],
			[p_transformed[-1][1]],
			"x",
			color=path_format.color,
			ms=10,
		)

	def to_ascii(
		self,
		show_endpoints: bool = True,
		show_solution: bool = True,
	) -> str:
		"wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`"
		if self.true_path is not None:
			return self.solved_maze.as_ascii(
				show_endpoints=show_endpoints,
				show_solution=show_solution,
			)
		return self.maze.as_ascii(show_endpoints=show_endpoints)

Class for displaying mazes and paths

MazePlot( maze: maze_dataset.LatticeMaze, unit_length: int = 14)
def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None:
	"""UNIT_LENGTH: Set ratio between node size and wall thickness in image.

	Wall thickness is fixed to 1px
	A "unit" consists of a single node and the right and lower connection/wall.
	Example: ul = 14 yields 13:1 ratio between node size and wall thickness

	If `maze` is a `SolvedMaze`, its solution is added as the true path. If it is
	a `TargetedLatticeMaze`, it is solved first and that solution is added.
	"""
	self.unit_length: int = unit_length
	self.maze: LatticeMaze = maze
	self.true_path: StyledPath | None = None
	self.predicted_paths: list[StyledPath] = []
	self.node_values: Float[np.ndarray, "grid_n grid_n"] | None = None
	self.custom_node_value_flag: bool = False
	self.node_color_map: str = "Blues"
	self.target_token_coord: Coord | None = None
	self.preceding_tokens_coords: CoordArray | None = None
	self.colormap_center: float | None = None
	# normally set by `add_node_values`; initialized here so these attributes
	# always exist, even before node values are added
	self.colormap_max: float | None = None
	self.hide_colorbar: bool = False
	# axes of the colorbar, filled in once a colorbar has been drawn
	self.cbar_ax = None
	# (coord, `Axes.plot` kwargs) pairs drawn as markers when plotting
	self.marked_coords: list[tuple[Coord, dict]] = list()

	# marker styles used by `add_node_values` for preceding / target tokens
	self.marker_kwargs_current: dict = dict(
		marker="s",
		color="green",
		ms=12,
	)
	self.marker_kwargs_next: dict = dict(
		marker="P",
		color="green",
		ms=12,
	)

	if isinstance(maze, SolvedMaze):
		self.add_true_path(maze.solution)
	elif isinstance(maze, TargetedLatticeMaze):
		self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution)

UNIT_LENGTH: Set ratio between node size and wall thickness in image.

Wall thickness is fixed to 1px A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields 13:1 ratio between node size and wall thickness

unit_length: int
node_values: jaxtyping.Float[ndarray, 'grid_n grid_n']
custom_node_value_flag: bool
node_color_map: str
target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2']
preceding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2']
colormap_center: float | None
cbar_ax
marked_coords: list[tuple[jaxtyping.Int8[ndarray, 'row_col=2'], dict]]
marker_kwargs_current: dict
marker_kwargs_next: dict
solved_maze: maze_dataset.SolvedMaze
@property
def solved_maze(self) -> SolvedMaze:
	"""the maze combined with `self.true_path`, as a `SolvedMaze`

	# Raises:
	- `ValueError` : when no true path has been added yet
	"""
	true_path = self.true_path
	if true_path is not None:
		return SolvedMaze.from_lattice_maze(
			lattice_maze=self.maze,
			solution=true_path.path,
		)
	raise ValueError(
		"Cannot return SolvedMaze object without true path. Add true path with add_true_path method.",
	)

get the underlying SolvedMaze object

def add_true_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | maze_dataset.plotting.plot_maze.StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
def add_true_path(
	self,
	path: CoordList | CoordArray | StyledPath,
	path_fmt: PathFormat | None = None,
	**kwargs,
) -> MazePlot:
	"""set the true path of the maze, with optional formatting

	`path`, `path_fmt` and any extra `kwargs` are forwarded to
	`process_path_input` using the `"true"` default format. The result replaces
	any previously stored true path.

	# Returns:
	- `MazePlot` : `self`, for chaining
	"""
	styled: StyledPath = process_path_input(
		path=path,
		_default_key="true",
		path_fmt=path_fmt,
		**kwargs,
	)
	self.true_path = styled
	return self

add a true path to the maze with optional formatting

def add_predicted_path( self, path: list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | maze_dataset.plotting.plot_maze.StyledPath, path_fmt: PathFormat | None = None, **kwargs) -> MazePlot:
def add_predicted_path(
	self,
	path: CoordList | CoordArray | StyledPath,
	path_fmt: PathFormat | None = None,
	**kwargs,
) -> MazePlot:
	"""Receive a predicted path plus formatting preferences and append it to `self.predicted_paths`.

	Defaults depend on how many predicted paths are already stored (`k`):
	- a path without a label is labeled `"predicted path {k + 1}"`
	- a path without a color gets `DEFAULT_PREDICTED_PATH_COLORS[k % len(...)]`

	# Returns:
	- `MazePlot` : `self`, for chaining
	"""
	n_existing: int = len(self.predicted_paths)
	styled_path: StyledPath = process_path_input(
		path=path,
		_default_key="predicted",
		path_fmt=path_fmt,
		**kwargs,
	)

	if styled_path.label is None:
		styled_path.label = f"predicted path {n_existing + 1}"

	if styled_path.color is None:
		palette = DEFAULT_PREDICTED_PATH_COLORS
		styled_path.color = palette[n_existing % len(palette)]

	self.predicted_paths.append(styled_path)
	return self

Receive a predicted path and formatting preferences from input and save it in the predicted_paths list.

Default formatting depends on the number of paths already saved in the predicted_paths list.

def add_multiple_paths( self, path_list: Sequence[list[tuple[int, int]] | jaxtyping.Int8[ndarray, 'coord row_col=2'] | maze_dataset.plotting.plot_maze.StyledPath]) -> MazePlot:
def add_multiple_paths(
	self,
	path_list: Sequence[CoordList | CoordArray | StyledPath],
) -> MazePlot:
	"""Add several predicted paths to the MazePlot at once.

	Each element of `path_list` is passed, in order, to `add_predicted_path`
	with default formatting. Elements may therefore be anything
	`add_predicted_path` accepts: a list of coordinates, a coordinate array, or a
	`StyledPath`. Missing labels/colors are filled in by `add_predicted_path`.

	# Returns:
	- `MazePlot`
		`self`, for chaining
	"""
	for path in path_list:
		self.add_predicted_path(path)
	return self

Function for adding multiple paths to MazePlot at once.

Each element of path_list is passed, in order, to add_predicted_path with default formatting, so it may be a list of coordinates, a coordinate array, or a StyledPath.
def add_node_values( self, node_values: jaxtyping.Float[ndarray, 'grid_n grid_n'], color_map: str = 'Blues', target_token_coord: jaxtyping.Int8[ndarray, 'row_col=2'] | None = None, preceeding_tokens_coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] = None, colormap_center: float | None = None, colormap_max: float | None = None, hide_colorbar: bool = False) -> MazePlot:
def add_node_values(
	self,
	node_values: Float[np.ndarray, "grid_n grid_n"],
	color_map: str = "Blues",
	target_token_coord: Coord | None = None,
	preceeding_tokens_coords: CoordArray = None,
	colormap_center: float | None = None,
	colormap_max: float | None = None,
	hide_colorbar: bool = False,
) -> MazePlot:
	"""add node values to the maze for visualization as a heatmap

	# Parameters:
	- `node_values : Float[np.ndarray, "grid_n grid_n"]`
		one value per node; the shape must equal `self.maze.grid_shape`.
		nan values are rejected when plotting.
	- `color_map : str`
		name of the matplotlib colormap used for the heatmap
		(defaults to `"Blues"`)
	- `target_token_coord : Coord | None`
		if given, marked on the plot with `self.marker_kwargs_next`
		(defaults to `None`)
	- `preceeding_tokens_coords : CoordArray`
		if given, every coordinate is marked with `self.marker_kwargs_current`
		(defaults to `None`)
	- `colormap_center : float | None`
		if given, a diverging `TwoSlopeNorm` centered on this value is used;
		plotting raises `ValueError` if it lies outside the plotted value range
		(defaults to `None`)
	- `colormap_max : float | None`
		if given, overrides the upper end of the color scale, e.g. to get a
		consistent colorbar across multiple plots
		(defaults to `None`)
	- `hide_colorbar : bool`
		if `True`, no colorbar is drawn
		(defaults to `False`)

	# Returns:
	- `MazePlot`
		`self`, for chaining
	"""
	assert node_values.shape == self.maze.grid_shape, (
		"Please pass node values of the same shape as LatticeMaze.grid_shape"
	)

	self.node_values = node_values
	# Set flag for choosing cmap while plotting maze
	self.custom_node_value_flag = True
	self.node_color_map = color_map
	self.colormap_center = colormap_center
	self.colormap_max = colormap_max
	self.hide_colorbar = hide_colorbar

	if target_token_coord is not None:
		self.marked_coords.append((target_token_coord, self.marker_kwargs_next))
	if preceeding_tokens_coords is not None:
		for coord in preceeding_tokens_coords:
			self.marked_coords.append((coord, self.marker_kwargs_current))
	return self

add node values to the maze for visualization as a heatmap

DOCS: what are these arguments?

Parameters:

  • node_values : Float[np.ndarray, "grid_n grid_n"]
  • color_map : str (defaults to "Blues")
  • target_token_coord : Coord | None (defaults to None)
  • preceeding_tokens_coords : CoordArray (defaults to None)
  • colormap_center : float | None (defaults to None)
  • colormap_max : float | None (defaults to None)
  • hide_colorbar : bool (defaults to False)

Returns:

  • MazePlot

def plot( self, dpi: int = 100, title: str = '', fig_ax: tuple | None = None, plain: bool = False) -> MazePlot:
def plot(
	self,
	dpi: int = 100,
	title: str = "",
	fig_ax: tuple | None = None,
	plain: bool = False,
) -> MazePlot:
	"""Plot the maze and paths.

	# Parameters:
	- `dpi : int`
		resolution of the newly created figure; unused when `fig_ax` is given
		(defaults to `100`)
	- `title : str`
		axes title, only set when `plain` is `False`
		(defaults to `""`)
	- `fig_ax : tuple | None`
		existing `(fig, ax)` pair to draw into; if `None`, a new figure with a
		single subplot is created
		(defaults to `None`)
	- `plain : bool`
		if `True`, hide ticks, axis labels and the axis frame
		(defaults to `False`)

	# Returns:
	- `MazePlot`
		`self`, with `self.fig` and `self.ax` set
	"""
	# set up figure
	if fig_ax is None:
		self.fig = plt.figure(dpi=dpi)
		self.ax = self.fig.add_subplot(1, 1, 1)
	else:
		self.fig, self.ax = fig_ax

	# plot maze
	self._plot_maze()

	# Plot labels
	if not plain:
		# x is the column axis and y the row axis, so for non-square mazes the two
		# tick sets must come from different grid dimensions
		n_rows: int = self.maze.grid_shape[0]
		n_cols: int = self.maze.grid_shape[1]
		row_idx = np.arange(n_rows)
		col_idx = np.arange(n_cols)
		self.ax.set_xticks(self.unit_length * (col_idx + 0.5), col_idx)
		self.ax.set_yticks(self.unit_length * (row_idx + 0.5), row_idx)
		self.ax.set_xlabel("col")
		self.ax.set_ylabel("row")
		self.ax.set_title(title)
	else:
		self.ax.set_xticks([])
		self.ax.set_yticks([])
		self.ax.set_xlabel("")
		self.ax.set_ylabel("")
		self.ax.axis("off")

	# plot paths
	if self.true_path is not None:
		self._plot_path(self.true_path)
	for path in self.predicted_paths:
		self._plot_path(path)

	# plot markers
	for coord, kwargs in self.marked_coords:
		self._place_marked_coords([coord], **kwargs)

	return self

Plot the maze and paths.

def mark_coords( self, coords: jaxtyping.Int8[ndarray, 'coord row_col=2'] | list[jaxtyping.Int8[ndarray, 'row_col=2']], **kwargs) -> MazePlot:
def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot:
	"""Queue markers to be drawn at each of `coords` when plotting.

	The marker style defaults to a blue "+", i.e. `dict(marker="+", color="blue")`.
	Any `kwargs` override those defaults. A single style dict is shared by all
	of the given coordinates.

	# Returns:
	- `MazePlot` : `self`, for chaining
	"""
	style: dict = dict(marker="+", color="blue")
	style.update(kwargs)
	self.marked_coords.extend((coord, style) for coord in coords)
	return self

Mark coordinates on the maze with a marker.

default marker is a blue "+": dict(marker="+", color="blue")

def to_ascii(self, show_endpoints: bool = True, show_solution: bool = True) -> str:
def to_ascii(
	self,
	show_endpoints: bool = True,
	show_solution: bool = True,
) -> str:
	"""wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`

	Without a true path, falls back to `self.maze.as_ascii()`, for which
	`show_solution` has no effect.
	"""
	# explicit None check: the intent is "has a true path been set", not truthiness
	if self.true_path is not None:
		return self.solved_maze.as_ascii(
			show_endpoints=show_endpoints,
			show_solution=show_solution,
		)
	return self.maze.as_ascii(show_endpoints=show_endpoints)

wrapper for self.solved_maze.as_ascii(), shows the path if we have self.true_path

@dataclass(kw_only=True)
class PathFormat:
26@dataclass(kw_only=True)
27class PathFormat:
28	"""formatting options for path plot"""
29
30	label: str | None = None
31	fmt: str = "o"
32	color: str | None = None
33	cmap: str | None = None
34	line_width: float | None = None
35	quiver_kwargs: dict | None = None
36
37	def combine(self, other: PathFormat) -> PathFormat:
38		"""combine with other PathFormat object, overwriting attributes with non-None values.
39
40		returns a modified copy of self.
41		"""
42		output: PathFormat = deepcopy(self)
43		for key, value in other.__dict__.items():
44			if key == "path":
45				err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
46				raise ValueError(
47					err_msg,
48				)
49			if value is not None:
50				setattr(output, key, value)
51
52		return output

formatting options for path plot

PathFormat( *, label: str | None = None, fmt: str = 'o', color: str | None = None, cmap: str | None = None, line_width: float | None = None, quiver_kwargs: dict | None = None)
label: str | None = None
fmt: str = 'o'
color: str | None = None
cmap: str | None = None
line_width: float | None = None
quiver_kwargs: dict | None = None
def combine( self, other: PathFormat) -> PathFormat:
def combine(self, other: PathFormat) -> PathFormat:
	"""combine with other PathFormat object, overwriting attributes with non-None values.

	returns a modified copy of self; `self` itself is left untouched.

	# Raises:
	- `ValueError` : if `other` carries a `path` attribute
	"""
	merged: PathFormat = deepcopy(self)
	for attr_name, attr_value in other.__dict__.items():
		if attr_name == "path":
			err_msg: str = f"Cannot overwrite path attribute! {self = }, {other = }"
			raise ValueError(err_msg)
		if attr_value is None:
			continue
		setattr(merged, attr_name, attr_value)
	return merged

combine with other PathFormat object, overwriting attributes with non-None values.

returns a modified copy of self.

def color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = 'Blues', fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, labels: bool = False) -> str:
def color_tokens_cmap(
	tokens: list[str],
	weights: Sequence[float],
	cmap: str | matplotlib.colors.Colormap = "Blues",
	fmt: FormatType = "html",
	template: str | None = None,
	labels: bool = False,
) -> str:
	"""color tokens given a list of weights and a colormap

	The weights are rescaled linearly to `[0, 1]` with
	`matplotlib.colors.Normalize` (min -> 0, max -> 1) and mapped through `cmap`.
	The resulting RGB colors, scaled to 0-255, are passed to `color_tokens_rgb`.

	# Parameters:
	- `tokens : list[str]`
		tokens to color
	- `weights : Sequence[float]`
		one weight per token
	- `cmap : str | matplotlib.colors.Colormap`
		colormap object, or the name of a registered matplotlib colormap
		(defaults to `"Blues"`)
	- `fmt : FormatType`
		output format, forwarded to `color_tokens_rgb`
		(defaults to `"html"`)
	- `template : str | None`
		custom template, forwarded to `color_tokens_rgb`
		(defaults to `None`)
	- `labels : bool`
		if `True`, append a line showing each weight (1 decimal place) under its
		token; only supported for `fmt="terminal"`
		(defaults to `False`)

	# Returns:
	- `str`
		the colored tokens, optionally followed by the weight labels

	# Raises:
	- `NotImplementedError` : if `labels` is set and `fmt` is not `"terminal"`
	"""
	n_tok: int = len(tokens)
	assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
	# reject unsupported option combinations before doing any color work
	if labels and fmt != "terminal":
		raise NotImplementedError("labels only supported for terminal")

	weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
	# normalize weights to [0, 1]
	weights_norm = matplotlib.colors.Normalize()(weights_np)

	if isinstance(cmap, str):
		cmap = matplotlib.colormaps.get_cmap(cmap)

	# the colormap returns RGBA in [0, 1]; keep RGB and scale to [0, 255]
	colors: RGBArray = cmap(weights_norm)[:, :3] * 255

	output: str = color_tokens_rgb(
		tokens=tokens,
		colors=colors,
		fmt=fmt,
		template=template,
	)

	if labels:
		# align labels with the tokens
		output += "\n"
		for tok, weight in zip(tokens, weights_np, strict=False):
			# 1 decimal place, left-aligned and padded with spaces to the token length
			weight_str: str = f"{weight:.1f}"
			# omit the number (keep only the padding) if it is longer than the token
			if len(weight_str) > len(tok):
				weight_str = " " * len(tok)
			else:
				weight_str = weight_str.ljust(len(tok))
			output += f"{weight_str} "

	return output

color tokens given a list of weights and a colormap

def color_maze_tokens_AOTP( tokens: list[str], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, **kwargs) -> str:
def color_maze_tokens_AOTP(
	tokens: list[str],
	fmt: FormatType = "html",
	template: str | None = None,
	**kwargs,
) -> str:
	"""color tokens assuming AOTP format

	i.e. adjacency list, origin, target, path

	The token list is cut into one section per `(start, end)` key of
	`_MAZE_TOKENS_DEFAULT_COLORS`, with both delimiters included. Each section is
	joined into a single space-separated string and colored with the RGB value
	stored for that key, via `color_tokens_rgb`.
	"""
	sections: list[str] = []
	for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS.keys():
		section_tokens = tokens_between(
			tokens,
			start_tok,
			end_tok,
			include_start=True,
			include_end=True,
		)
		sections.append(" ".join(section_tokens))

	section_colors: RGBArray = np.array(
		list(_MAZE_TOKENS_DEFAULT_COLORS.values()),
		dtype=np.uint8,
	)

	return color_tokens_rgb(
		tokens=sections,
		colors=section_colors,
		fmt=fmt,
		template=template,
		**kwargs,
	)

color tokens assuming AOTP format

i.e. adjacency list, origin, target, path

def color_tokens_rgb( tokens: list, colors: Union[Sequence[Sequence[int]], jaxtyping.Float[ndarray, 'n 3']], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, clr_join: str | None = None, max_length: int | None = None) -> str:
 64def color_tokens_rgb(
 65	tokens: list,
 66	colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
 67	fmt: FormatType = "html",
 68	template: str | None = None,
 69	clr_join: str | None = None,
 70	max_length: int | None = None,
 71) -> str:
 72	"""color tokens from a list with an RGB color array
 73
 74	tokens will not be escaped if `fmt` is None
 75
 76	# Parameters:
 77	- `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.
 78	"""
 79	# process format
 80	if fmt is None:
 81		assert template is not None
 82		assert clr_join is not None
 83	else:
 84		assert template is None
 85		assert clr_join is None
 86		template = TEMPLATES[fmt]
 87		clr_join = _COLOR_JOIN[fmt]
 88
 89	if max_length is not None:
 90		# TODO: why are we using a map here again?
 91		# TYPING: this is missing a lot of type hints
 92		wrapped: list = list(  # noqa: C417
 93			map(
 94				lambda x: textwrap.wrap(
 95					x,
 96					width=max_length,
 97					break_long_words=False,
 98					break_on_hyphens=False,
 99				),
100				tokens,
101			),
102		)
103		colors = list(
104			flatten(
105				[[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))],
106				levels_to_flatten=1,
107			),
108		)
109		wrapped = list(flatten(wrapped, levels_to_flatten=1))
110		tokens = wrapped
111
112	# put everything together
113	output = [
114		template.format(
115			clr=clr_join.join(map(str, map(int, clr))),
116			tok=_escape_tok(tok, fmt),
117		)
118		for tok, clr in zip(tokens, colors, strict=False)
119	]
120
121	return " ".join(output)

color tokens from a list with an RGB color array

tokens will not be escaped if fmt is None

Parameters:

  • max_length: int | None: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If None, no limit on max length.