Documentation for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.plotting.plot_tokens

`plot_colored_text`: a function that plots tokens on a matplotlib axis with colored backgrounds.


 1"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds"
 2
 3from typing import Any, Sequence
 4
 5import matplotlib.pyplot as plt
 6import numpy as np
 7
 8
 9def plot_colored_text(
10	tokens: Sequence[str],
11	weights: Sequence[float],
12	# assume its a colormap if not a string
13	cmap: str | Any,  # noqa: ANN401
14	ax: plt.Axes | None = None,
15	width_scale: float = 0.023,
16	width_offset: float = 0.005,
17	height_offset: float = 0.1,
18	rect_height: float = 0.7,
19	token_height: float = 0.7,
20	label_height: float = 0.3,
21	word_gap: float = 0.01,
22	fontsize: int = 12,
23	fig_height: float = 0.7,
24	fig_width_scale: float = 0.25,
25	char_min: int = 4,
26) -> plt.Axes:
27	"hacky function to plot tokens on a matplotlib axis with colored backgrounds"
28	assert len(tokens) == len(weights), (
29		f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
30	)
31	total_len_estimate: float = sum([max(len(tok), char_min) for tok in tokens])
32	# set up figure if needed
33	if ax is None:
34		fig, ax = plt.subplots(
35			figsize=(total_len_estimate * fig_width_scale, fig_height),
36		)
37	ax.axis("off")
38
39	# Normalize the weights to be between 0 and 1
40	norm_weights: Sequence[float] = (weights - np.min(weights)) / (
41		np.max(weights) - np.min(weights)
42	)
43
44	# Create a colormap instance
45	if isinstance(cmap, str):
46		colormap = plt.get_cmap(cmap)
47	else:
48		colormap = cmap
49
50	x_pos: float = 0.0
51	for i, (tok, weight, norm_wgt) in enumerate(  # noqa: B007
52		zip(tokens, weights, norm_weights, strict=False),
53	):
54		color = colormap(norm_wgt)[:3]
55
56		# Plot the background color
57		rect_width = width_scale * max(len(tok), char_min)
58		ax.add_patch(
59			plt.Rectangle(
60				(x_pos, height_offset),
61				rect_width,
62				height_offset + rect_height,
63				fc=color,
64				ec="none",
65			),
66		)
67
68		# Plot the token
69		ax.text(
70			x_pos + width_offset,
71			token_height,
72			tok,
73			fontsize=fontsize,
74			va="center",
75			ha="left",
76		)
77
78		# Plot the weight below the token
79		ax.text(
80			x_pos + width_offset,
81			label_height,
82			f"{weight:.2f}",
83			fontsize=fontsize,
84			va="center",
85			ha="left",
86		)
87
88		x_pos += rect_width + word_gap
89
90	return ax

def plot_colored_text( tokens: Sequence[str], weights: Sequence[float], cmap: str | typing.Any, ax: matplotlib.axes._axes.Axes | None = None, width_scale: float = 0.023, width_offset: float = 0.005, height_offset: float = 0.1, rect_height: float = 0.7, token_height: float = 0.7, label_height: float = 0.3, word_gap: float = 0.01, fontsize: int = 12, fig_height: float = 0.7, fig_width_scale: float = 0.25, char_min: int = 4) -> matplotlib.axes._axes.Axes:
def plot_colored_text(
	tokens: Sequence[str],
	weights: Sequence[float],
	# assume it's a colormap if not a string
	cmap: str | Any,  # noqa: ANN401
	ax: plt.Axes | None = None,
	width_scale: float = 0.023,
	width_offset: float = 0.005,
	height_offset: float = 0.1,
	rect_height: float = 0.7,
	token_height: float = 0.7,
	label_height: float = 0.3,
	word_gap: float = 0.01,
	fontsize: int = 12,
	fig_height: float = 0.7,
	fig_width_scale: float = 0.25,
	char_min: int = 4,
) -> plt.Axes:
	"""hacky function to plot tokens on a matplotlib axis with colored backgrounds

	Tokens are drawn left to right. Each token sits on a rectangle whose fill color is
	`cmap` evaluated at the token's weight, min-max normalized to [0, 1] over all
	weights. The raw weight is printed (2 decimal places) below the token. All
	positions are in data coordinates, and the default sizes are hand-tuned.

	# Parameters:
	- `tokens : Sequence[str]`
		tokens to draw, in order
	- `weights : Sequence[float]`
		one weight per token; selects the background color of that token
	- `cmap : str | Any`
		colormap name (looked up with `plt.get_cmap`), otherwise used directly as a
		callable mapping a value in [0, 1] to a color tuple
	- `ax : plt.Axes | None`
		axis to draw on; if `None`, a new figure is created whose width scales with
		the (padded) token lengths
	- `width_scale : float`
		rectangle width per character
	- `width_offset : float`
		horizontal offset of the text from the left edge of its rectangle
	- `height_offset : float`
		y coordinate of the bottom edge of each rectangle
	- `rect_height : float`
		rectangle height parameter (the drawn height is `height_offset + rect_height`)
	- `token_height : float`
		y coordinate of the token text
	- `label_height : float`
		y coordinate of the weight label
	- `word_gap : float`
		horizontal gap between consecutive rectangles
	- `fontsize : int`
		font size of both the token text and the weight label
	- `fig_height : float`
		figure height, only used when `ax` is `None`
	- `fig_width_scale : float`
		figure width per padded character, only used when `ax` is `None`
	- `char_min : int`
		minimum character count used when sizing each rectangle

	# Returns:
	- `plt.Axes`
		the axis that was drawn on (with axis decorations turned off)

	# Raises:
	- `AssertionError` : if `tokens` and `weights` have different lengths
	"""
	assert len(tokens) == len(weights), (
		f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
	)
	# every token is sized as if it had at least `char_min` characters
	total_len_estimate: float = sum(max(len(tok), char_min) for tok in tokens)
	# set up figure if needed
	if ax is None:
		_fig, ax = plt.subplots(
			figsize=(total_len_estimate * fig_width_scale, fig_height),
		)
	ax.axis("off")

	# Normalize the weights to be between 0 and 1.
	# When every weight is equal (e.g. a single token) the range is zero, and dividing
	# by it yields NaN, which the colormap renders as its "bad" color -- so map all
	# weights to 0 instead. Empty input skips min/max entirely, since those would raise.
	weights_arr: np.ndarray = np.asarray(weights, dtype=float)
	norm_weights: np.ndarray
	if weights_arr.size == 0:
		norm_weights = weights_arr
	else:
		w_min: float = weights_arr.min()
		w_range: float = weights_arr.max() - w_min
		if w_range == 0:
			norm_weights = np.zeros_like(weights_arr)
		else:
			norm_weights = (weights_arr - w_min) / w_range

	# Create a colormap instance
	colormap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

	x_pos: float = 0.0
	for tok, weight, norm_wgt in zip(tokens, weights, norm_weights, strict=True):
		# keep only the RGB channels of the colormap output
		color = colormap(norm_wgt)[:3]

		# Plot the background color
		rect_width: float = width_scale * max(len(tok), char_min)
		ax.add_patch(
			plt.Rectangle(
				(x_pos, height_offset),
				rect_width,
				# NOTE(review): height re-adds `height_offset`, so the box spans
				# [height_offset, 2 * height_offset + rect_height]; kept as-is to
				# preserve the existing layout
				height_offset + rect_height,
				fc=color,
				ec="none",
			),
		)

		# Plot the token
		ax.text(
			x_pos + width_offset,
			token_height,
			tok,
			fontsize=fontsize,
			va="center",
			ha="left",
		)

		# Plot the weight below the token
		ax.text(
			x_pos + width_offset,
			label_height,
			f"{weight:.2f}",
			fontsize=fontsize,
			va="center",
			ha="left",
		)

		x_pos += rect_width + word_gap

	return ax

A hacky function that plots tokens on a matplotlib axis with colored backgrounds.