Coverage for maze_dataset/plotting/plot_tokens.py: 0%
22 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-11 00:49 -0600
1"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds"
3from typing import Any, Sequence
5import matplotlib.pyplot as plt
6import numpy as np
def plot_colored_text(
    tokens: Sequence[str],
    weights: Sequence[float],
    # assume its a colormap if not a string
    cmap: str | Any,  # noqa: ANN401
    ax: plt.Axes | None = None,
    width_scale: float = 0.023,
    width_offset: float = 0.005,
    height_offset: float = 0.1,
    rect_height: float = 0.7,
    token_height: float = 0.7,
    label_height: float = 0.3,
    word_gap: float = 0.01,
    fontsize: int = 12,
    fig_height: float = 0.7,
    fig_width_scale: float = 0.25,
    char_min: int = 4,
) -> plt.Axes:
    """Plot tokens on a matplotlib axis with colored backgrounds (hacky).

    Each token is drawn left to right on top of a rectangle whose color is
    taken from `cmap` at the token's min-max normalized weight. The raw
    weight, formatted to two decimals, is drawn as a second text label at
    `label_height`.

    Parameters:
     - `tokens`: strings to draw, in order.
     - `weights`: one numeric weight per token.
     - `cmap`: a colormap name (resolved with `plt.get_cmap`), or any
       callable that maps a normalized weight to a color sequence whose
       first three entries are used as the rectangle face color.
     - `ax`: axis to draw on. If `None`, a new figure is created whose width
       is `fig_width_scale` times the estimated total character count and
       whose height is `fig_height`; the new axis has its frame turned off.
     - `width_scale`: rectangle width per character.
     - `width_offset`: horizontal offset of both text labels from the left
       edge of the rectangle.
     - `height_offset`: y coordinate of the rectangle's lower edge; it is
       also added to `rect_height` to obtain the rectangle's height.
     - `rect_height`: rectangle height before `height_offset` is added.
     - `token_height`: y coordinate of the token text.
     - `label_height`: y coordinate of the weight label.
     - `word_gap`: horizontal gap between consecutive rectangles.
     - `fontsize`: font size for both the token and the weight label.
     - `char_min`: tokens shorter than this are laid out as if they had this
       many characters.

    Returns:
        The axis that was drawn on (either `ax` or the newly created one).

    Raises:
        AssertionError: if `tokens` and `weights` differ in length.
    """
    assert len(tokens) == len(weights), (
        f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
    )

    # set up figure if needed
    if ax is None:
        total_len_estimate: float = sum(max(len(tok), char_min) for tok in tokens)
        _fig, ax = plt.subplots(
            figsize=(total_len_estimate * fig_width_scale, fig_height),
        )
        ax.axis("off")

    # Normalize the weights to be between 0 and 1
    weight_arr: np.ndarray = np.asarray(weights, dtype=float)
    norm_weights: np.ndarray
    if weight_arr.size == 0:
        # nothing to normalize; np.min / np.max would raise on empty input
        norm_weights = weight_arr
    else:
        lowest: float = float(weight_arr.min())
        span: float = float(weight_arr.max()) - lowest
        if span > 0:
            norm_weights = (weight_arr - lowest) / span
        else:
            # All weights are equal (e.g. a single token), so there is no
            # positive range to scale by. Dividing would give 0/0 = NaN,
            # which colormaps treat as invalid data; use the midpoint of
            # the colormap for every token instead.
            norm_weights = np.full_like(weight_arr, 0.5)

    # Create a colormap instance
    colormap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

    x_pos: float = 0.0
    for tok, weight, norm_wgt in zip(tokens, weights, norm_weights, strict=False):
        # keep only the first three color components for the face color
        color = colormap(norm_wgt)[:3]

        # Plot the background color
        rect_width = width_scale * max(len(tok), char_min)
        ax.add_patch(
            plt.Rectangle(
                (x_pos, height_offset),
                rect_width,
                height_offset + rect_height,
                fc=color,
                ec="none",
            ),
        )

        # Plot the token
        ax.text(
            x_pos + width_offset,
            token_height,
            tok,
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        # Plot the weight label at `label_height`
        ax.text(
            x_pos + width_offset,
            label_height,
            f"{weight:.2f}",
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        x_pos += rect_width + word_gap

    return ax