Coverage for maze_dataset/plotting/plot_tokens.py: 0%

22 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1"`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds" 

2 

3from typing import Any, Sequence 

4 

5import matplotlib.pyplot as plt 

6import numpy as np 

7 

8 

def plot_colored_text(
    tokens: Sequence[str],
    weights: Sequence[float],
    # assume its a colormap if not a string
    cmap: str | Any,  # noqa: ANN401
    ax: plt.Axes | None = None,
    width_scale: float = 0.023,
    width_offset: float = 0.005,
    height_offset: float = 0.1,
    rect_height: float = 0.7,
    token_height: float = 0.7,
    label_height: float = 0.3,
    word_gap: float = 0.01,
    fontsize: int = 12,
    fig_height: float = 0.7,
    fig_width_scale: float = 0.25,
    char_min: int = 4,
) -> plt.Axes:
    """Plot tokens left to right, each on a background colored by its weight.

    Each token gets a colored rectangle whose color is ``cmap`` evaluated at
    the token's min-max-normalized weight. The token text is drawn at
    ``token_height`` and its raw weight (formatted to 2 decimals) at
    ``label_height``. All positions are in the axis' data coordinates,
    starting at ``x = 0``.

    # Parameters:
     - `tokens`: strings to draw, in order
     - `weights`: one value per token; must be the same length as `tokens`
     - `cmap`: a colormap name (resolved with `plt.get_cmap`), otherwise
       assumed to be a callable colormap returning an RGB(A) tuple
     - `ax`: axis to draw on; if `None`, a new figure is created whose width
       is proportional to the estimated total token length
     - `width_scale`: rectangle width per (padded) character
     - `width_offset`: horizontal offset of the text from the rectangle's
       left edge
     - `height_offset`: y coordinate of the rectangle's bottom edge
     - `rect_height`: added to `height_offset` to form the rectangle height
     - `token_height`: y coordinate of the token text
     - `label_height`: y coordinate of the weight label
     - `word_gap`: horizontal gap between consecutive rectangles
     - `fontsize`: font size of both token text and weight label
     - `fig_height`: figure height when `ax` is `None`
     - `fig_width_scale`: figure width per (padded) character when `ax` is `None`
     - `char_min`: tokens shorter than this are given this many characters
       of width

    # Returns:
     - the axis that was drawn on, with its axis decorations turned off

    # Raises:
     - `AssertionError`: if `tokens` and `weights` differ in length

    If every weight is equal (including a single token), all tokens map to
    colormap position 0 rather than producing NaN colors.
    """
    assert len(tokens) == len(weights), (
        f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}"
    )
    total_len_estimate: float = sum(max(len(tok), char_min) for tok in tokens)

    # set up figure if needed
    if ax is None:
        _, ax = plt.subplots(
            figsize=(total_len_estimate * fig_width_scale, fig_height),
        )
    ax.axis("off")

    # Normalize the weights to [0, 1] for the colormap. Coerce to a float
    # array so plain lists/tuples are handled explicitly, and guard the
    # degenerate cases: an empty input (min/max of an empty array raises) and
    # a zero range (0/0 -> NaN, which a colormap renders as its "bad" color,
    # i.e. black once alpha is dropped, hiding the black text).
    weights_arr: np.ndarray = np.asarray(weights, dtype=float)
    norm_weights: np.ndarray
    if weights_arr.size == 0:
        norm_weights = weights_arr
    else:
        w_min = weights_arr.min()
        w_range = weights_arr.max() - w_min
        if w_range == 0:
            norm_weights = np.zeros_like(weights_arr)
        else:
            norm_weights = (weights_arr - w_min) / w_range

    # Create a colormap instance
    colormap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap

    x_pos: float = 0.0
    # strict=True backs up the length assert when asserts are stripped (-O)
    for tok, weight, norm_wgt in zip(tokens, weights, norm_weights, strict=True):
        # drop alpha so the rectangle is always fully opaque
        color = colormap(norm_wgt)[:3]

        # Plot the background color.
        # NOTE: the rectangle height is `height_offset + rect_height` (not just
        # `rect_height`), so it spans [height_offset, 2*height_offset + rect_height];
        # kept as-is to preserve the existing layout.
        rect_width: float = width_scale * max(len(tok), char_min)
        ax.add_patch(
            plt.Rectangle(
                (x_pos, height_offset),
                rect_width,
                height_offset + rect_height,
                fc=color,
                ec="none",
            ),
        )

        # Plot the token
        ax.text(
            x_pos + width_offset,
            token_height,
            tok,
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        # Plot the (un-normalized) weight below the token
        ax.text(
            x_pos + width_offset,
            label_height,
            f"{weight:.2f}",
            fontsize=fontsize,
            va="center",
            ha="left",
        )

        x_pos += rect_width + word_gap

    return ax