maze_dataset.plotting.plot_tokens
plot_colored_text
function to plot tokens on a matplotlib axis with colored backgrounds
"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds"

from typing import Any, Sequence

import matplotlib.pyplot as plt
import numpy as np


def plot_colored_text(
    tokens: Sequence[str],
    weights: Sequence[float],
    # assume it's a colormap if not a string
    cmap: str | Any,  # noqa: ANN401
    ax: plt.Axes | None = None,
    width_scale: float = 0.023,
    width_offset: float = 0.005,
    height_offset: float = 0.1,
    rect_height: float = 0.7,
    token_height: float = 0.7,
    label_height: float = 0.3,
    word_gap: float = 0.01,
    fontsize: int = 12,
    fig_height: float = 0.7,
    fig_width_scale: float = 0.25,
    char_min: int = 4,
) -> plt.Axes:
    """hacky function to plot tokens on a matplotlib axis with colored backgrounds

    Tokens are laid out left to right. Each token gets a rectangle whose fill
    color is `cmap` evaluated at the token's min-max normalized weight, the
    token text drawn at `token_height`, and the raw weight (formatted with two
    decimals) drawn at `label_height`.

    # Parameters:
    - `tokens`: strings to draw, in order
    - `weights`: one number per token; must have the same length as `tokens`
    - `cmap`: colormap name (resolved with `plt.get_cmap`) or any callable
      that maps a normalized weight to a color tuple (only the first three
      components are used)
    - `ax`: axis to draw on; when `None`, a new figure is created whose width
      is the estimated total character count times `fig_width_scale` and
      whose height is `fig_height`
    - `width_scale`: rectangle width per character
    - `width_offset`: horizontal offset of the text inside its rectangle
    - `height_offset`: y position of the bottom edge of each rectangle
    - `rect_height`: rectangle height, in addition to `height_offset`
    - `token_height`, `label_height`: y positions of token text and weight label
    - `word_gap`: horizontal gap between consecutive rectangles
    - `fontsize`: font size for both the token and the weight label
    - `char_min`: minimum character count assumed for each token when sizing

    # Returns:
    the axis that was drawn on

    # Raises:
    - `AssertionError`: if `tokens` and `weights` differ in length
    """
    assert len(tokens) == len(weights), (
        f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
    )
    total_len_estimate: float = sum(max(len(tok), char_min) for tok in tokens)
    # set up figure if needed
    if ax is None:
        _, ax = plt.subplots(
            figsize=(total_len_estimate * fig_width_scale, fig_height),
        )
    ax.axis("off")

    # Normalize the weights to be between 0 and 1
    weights_arr = np.asarray(weights)
    if weights_arr.size == 0:
        # nothing to normalize; np.min / np.max would raise on empty input
        norm_weights = weights_arr
    else:
        w_min = np.min(weights_arr)
        w_range = np.max(weights_arr) - w_min
        if w_range == 0:
            # all weights are equal: dividing by the zero range would give NaN,
            # so place every token at the midpoint of the colormap instead
            norm_weights = np.full(weights_arr.shape, 0.5)
        else:
            norm_weights = (weights_arr - w_min) / w_range

    # Create a colormap instance
    if isinstance(cmap, str):
        colormap = plt.get_cmap(cmap)
    else:
        colormap = cmap

    x_pos: float = 0.0
    for tok, weight, norm_wgt in zip(tokens, weights, norm_weights, strict=False):
        # drop the alpha channel, keep RGB only
        color = colormap(norm_wgt)[:3]

        # Plot the background color
        rect_width = width_scale * max(len(tok), char_min)
        ax.add_patch(
            plt.Rectangle(
                (x_pos, height_offset),
                rect_width,
                height_offset + rect_height,
                fc=color,
                ec="none",
            ),
        )

        # Plot the token
        ax.text(
            x_pos + width_offset,
            token_height,
            tok,
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        # Plot the weight below the token (raw value, not the normalized one)
        ax.text(
            x_pos + width_offset,
            label_height,
            f"{weight:.2f}",
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        x_pos += rect_width + word_gap

    return ax
def
plot_colored_text( tokens: Sequence[str], weights: Sequence[float], cmap: str | typing.Any, ax: matplotlib.axes._axes.Axes | None = None, width_scale: float = 0.023, width_offset: float = 0.005, height_offset: float = 0.1, rect_height: float = 0.7, token_height: float = 0.7, label_height: float = 0.3, word_gap: float = 0.01, fontsize: int = 12, fig_height: float = 0.7, fig_width_scale: float = 0.25, char_min: int = 4) -> matplotlib.axes._axes.Axes:
def plot_colored_text(
    tokens: Sequence[str],
    weights: Sequence[float],
    # assume it's a colormap if not a string
    cmap: str | Any,  # noqa: ANN401
    ax: plt.Axes | None = None,
    width_scale: float = 0.023,
    width_offset: float = 0.005,
    height_offset: float = 0.1,
    rect_height: float = 0.7,
    token_height: float = 0.7,
    label_height: float = 0.3,
    word_gap: float = 0.01,
    fontsize: int = 12,
    fig_height: float = 0.7,
    fig_width_scale: float = 0.25,
    char_min: int = 4,
) -> plt.Axes:
    """hacky function to plot tokens on a matplotlib axis with colored backgrounds

    Tokens are laid out left to right. Each token gets a rectangle whose fill
    color is `cmap` evaluated at the token's min-max normalized weight, the
    token text drawn at `token_height`, and the raw weight (formatted with two
    decimals) drawn at `label_height`.

    # Parameters:
    - `tokens`: strings to draw, in order
    - `weights`: one number per token; must have the same length as `tokens`
    - `cmap`: colormap name (resolved with `plt.get_cmap`) or any callable
      that maps a normalized weight to a color tuple (only the first three
      components are used)
    - `ax`: axis to draw on; when `None`, a new figure is created whose width
      is the estimated total character count times `fig_width_scale` and
      whose height is `fig_height`
    - `width_scale`: rectangle width per character
    - `width_offset`: horizontal offset of the text inside its rectangle
    - `height_offset`: y position of the bottom edge of each rectangle
    - `rect_height`: rectangle height, in addition to `height_offset`
    - `token_height`, `label_height`: y positions of token text and weight label
    - `word_gap`: horizontal gap between consecutive rectangles
    - `fontsize`: font size for both the token and the weight label
    - `char_min`: minimum character count assumed for each token when sizing

    # Returns:
    the axis that was drawn on

    # Raises:
    - `AssertionError`: if `tokens` and `weights` differ in length
    """
    assert len(tokens) == len(weights), (
        f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
    )
    total_len_estimate: float = sum(max(len(tok), char_min) for tok in tokens)
    # set up figure if needed
    if ax is None:
        _, ax = plt.subplots(
            figsize=(total_len_estimate * fig_width_scale, fig_height),
        )
    ax.axis("off")

    # Normalize the weights to be between 0 and 1
    weights_arr = np.asarray(weights)
    if weights_arr.size == 0:
        # nothing to normalize; np.min / np.max would raise on empty input
        norm_weights = weights_arr
    else:
        w_min = np.min(weights_arr)
        w_range = np.max(weights_arr) - w_min
        if w_range == 0:
            # all weights are equal: dividing by the zero range would give NaN,
            # so place every token at the midpoint of the colormap instead
            norm_weights = np.full(weights_arr.shape, 0.5)
        else:
            norm_weights = (weights_arr - w_min) / w_range

    # Create a colormap instance
    if isinstance(cmap, str):
        colormap = plt.get_cmap(cmap)
    else:
        colormap = cmap

    x_pos: float = 0.0
    for tok, weight, norm_wgt in zip(tokens, weights, norm_weights, strict=False):
        # drop the alpha channel, keep RGB only
        color = colormap(norm_wgt)[:3]

        # Plot the background color
        rect_width = width_scale * max(len(tok), char_min)
        ax.add_patch(
            plt.Rectangle(
                (x_pos, height_offset),
                rect_width,
                height_offset + rect_height,
                fc=color,
                ec="none",
            ),
        )

        # Plot the token
        ax.text(
            x_pos + width_offset,
            token_height,
            tok,
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        # Plot the weight below the token (raw value, not the normalized one)
        ax.text(
            x_pos + width_offset,
            label_height,
            f"{weight:.2f}",
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        x_pos += rect_width + word_gap

    return ax
hacky function to plot tokens on a matplotlib axis with colored backgrounds