Coverage for maze_dataset/plotting/print_tokens.py: 57%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"""Functions to print tokens with colors in different formats 

2 

3you can color the tokens by their: 

4 

5- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP` 

6- custom weights (i.e. attention weights) using `color_tokens_cmap` 

7- entirely custom colors using `color_tokens_rgb` 

8 

9and the output can be in different formats, specified by `FormatType` (html, latex, terminal) 

10 

11""" 

12 

13import html 

14import textwrap 

15from typing import Literal, Sequence 

16 

17import matplotlib # noqa: ICN001 

18import numpy as np 

19from IPython.display import HTML, display 

20from jaxtyping import Float, UInt8 

21from muutils.misc import flatten 

22 

23from maze_dataset.constants import SPECIAL_TOKENS 

24from maze_dataset.token_utils import tokens_between 

25 

# array alias for RGB colors: per the jaxtyping shape string, `n` rows of 3 uint8 channels
RGBArray = UInt8[np.ndarray, "n 3"]
"1D array of RGB values"

# `None` is part of the literal: `color_tokens_rgb` treats it as "no built-in format",
# in which case the caller has to pass its own `template` and `clr_join`
FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens"

# each template is filled with `str.format`, using the fields
# `clr` (the joined color channels) and `tok` (the token text).
# the doubled braces in the latex entry are `str.format` escapes for literal `{` and `}`
TEMPLATES: dict[FormatType, str] = {
    "html": '<span style="color: black; background-color: rgb({clr})">&nbsp{tok}&nbsp</span>',
    "latex": "\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}",
    "terminal": "\033[30m\033[48;2;{clr}m{tok}\033[0m",
}
"templates of printing tokens in different formats"

# separator put between the color channel values when building the `clr` field;
# there is deliberately no entry for `None` (the caller passes `clr_join` itself)
_COLOR_JOIN: dict[FormatType, str] = {
    "html": ",",
    "latex": ",",
    "terminal": ";",
}
"joiner for colors in different formats"

45 

46 

47def _escape_tok( 

48 tok: str, 

49 fmt: FormatType, 

50) -> str: 

51 "escape token based on format" 

52 if fmt == "html": 

53 return html.escape(tok) 

54 elif fmt == "latex": 

55 return tok.replace("_", "\\_").replace("#", "\\#") 

56 elif fmt == "terminal": 

57 return tok 

58 else: 

59 err_msg: str = f"Unexpected format: {fmt}" 

60 raise ValueError(err_msg) 

61 

62 

63def color_tokens_rgb( 

64 tokens: list, 

65 colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"], 

66 fmt: FormatType = "html", 

67 template: str | None = None, 

68 clr_join: str | None = None, 

69 max_length: int | None = None, 

70) -> str: 

71 """color tokens from a list with an RGB color array 

72 

73 tokens will not be escaped if `fmt` is None 

74 

75 # Parameters: 

76 - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length. 

77 """ 

78 # process format 

79 if fmt is None: 

80 assert template is not None 

81 assert clr_join is not None 

82 else: 

83 assert template is None 

84 assert clr_join is None 

85 template = TEMPLATES[fmt] 

86 clr_join = _COLOR_JOIN[fmt] 

87 

88 if max_length is not None: 

89 # TODO: why are we using a map here again? 

90 # TYPING: this is missing a lot of type hints 

91 wrapped: list = list( # noqa: C417 

92 map( 

93 lambda x: textwrap.wrap( 

94 x, 

95 width=max_length, 

96 break_long_words=False, 

97 break_on_hyphens=False, 

98 ), 

99 tokens, 

100 ), 

101 ) 

102 colors = list( 

103 flatten( 

104 [[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))], 

105 levels_to_flatten=1, 

106 ), 

107 ) 

108 wrapped = list(flatten(wrapped, levels_to_flatten=1)) 

109 tokens = wrapped 

110 

111 # put everything together 

112 output = [ 

113 template.format( 

114 clr=clr_join.join(map(str, map(int, clr))), 

115 tok=_escape_tok(tok, fmt), 

116 ) 

117 for tok, clr in zip(tokens, colors, strict=False) 

118 ] 

119 

120 return " ".join(output) 

121 

122 

123# TYPING: would be nice to type hint as html, latex, or terminal string and overload depending on `FormatType` 

def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    the weights are normalized with `matplotlib.colors.Normalize`, mapped through `cmap`,
    and the resulting colors are handed to `color_tokens_rgb`.

    # Parameters:
    - `tokens: list[str]`: tokens to render
    - `weights: Sequence[float]`: one weight per token, must have the same length as `tokens`
    - `cmap: str | matplotlib.colors.Colormap`: colormap, or the name of one to look up in `matplotlib.colormaps`
        (defaults to `"Blues"`)
    - `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
        (defaults to `"html"`)
    - `template: str | None`: custom template, forwarded to `color_tokens_rgb`.
        only usable with `fmt=None`, and then `clr_join` has to be passed as well
    - `labels: bool`: if `True`, append a second line with the weights printed under their tokens.
        only implemented for `fmt="terminal"`
        (defaults to `False`)
    - `clr_join: str | None`: custom separator for the color channels, forwarded to `color_tokens_rgb`.
        required whenever `fmt` is `None`

    # Returns:
    - `str`: the colored tokens, plus the label line if `labels` is set

    # Raises:
    - `AssertionError`: if `tokens` and `weights` differ in length,
        or if `color_tokens_rgb` rejects the `fmt`/`template`/`clr_join` combination
    - `NotImplementedError`: if `labels` is set and `fmt` is not `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # keep only the first three channels of the colormap output and rescale them by 255.
    # the values are still floats here, `color_tokens_rgb` casts each channel with `int()`
    colors: Float[np.ndarray, "n_tok 3"] = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        # without forwarding this, a custom `template` (which needs `fmt=None`) could never be used
        clr_join=clr_join,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # align labels with the tokens
        label_cells: list[str] = []
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place, left-aligned and padded with trailing spaces to match token length
            weight_str: str = f"{weight:.1f}"
            if len(weight_str) > len(tok):
                # omit if longer than token
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            # the extra space mirrors the single space between rendered tokens
            label_cells.append(f"{weight_str} ")
        output += "\n" + "".join(label_cells)

    return output

167 

168 

# colors roughly made to be similar to visual representation
# keys are (start token, end token) pairs taken from `SPECIAL_TOKENS`, values are (R, G, B) tuples.
# `color_maze_tokens_AOTP` iterates the keys and the values of this dict separately and pairs
# them up by position, so the insertion order here also fixes the order of the colored sections
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
    (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (
        176,
        152,
        232,
    ),  # purple
    (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (154, 239, 123),  # green
    (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (246, 136, 136),  # red
    (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (111, 187, 254),  # blue
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"

181 

182 

def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e: adjacency list, origin, target, path

    for every (start, end) delimiter pair in `_MAZE_TOKENS_DEFAULT_COLORS`, the tokens between
    the delimiters (delimiters included) are joined into one space-separated string, and that
    string is colored with the pair's color via `color_tokens_rgb`.

    # Parameters:
    - `tokens: list[str]`: the maze tokens
    - `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
        (defaults to `"html"`)
    - `template: str | None`: custom template, forwarded to `color_tokens_rgb`
    - `**kwargs`: forwarded to `color_tokens_rgb` unchanged

    # Returns:
    - `str`: the colored sections as returned by `color_tokens_rgb`
    """
    sections: list[str] = []
    palette: list[tuple[int, int, int]] = []
    for (start_tok, end_tok), rgb in _MAZE_TOKENS_DEFAULT_COLORS.items():
        section_tokens = tokens_between(
            tokens,
            start_tok,
            end_tok,
            include_start=True,
            include_end=True,
        )
        sections.append(" ".join(section_tokens))
        palette.append(rgb)

    colors: RGBArray = np.array(palette, dtype=np.uint8)

    return color_tokens_rgb(
        tokens=sections,
        colors=colors,
        fmt=fmt,
        template=template,
        **kwargs,
    )

219 

220 

def display_html(html: str) -> None:
    "wrap an html string in an IPython `HTML` object and pass it to `display`"
    # note: the parameter name hides the stdlib `html` module inside this function only
    rendered = HTML(html)
    display(rendered)

224 

225 

def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    """render `tokens` as html with the given per-token `colors`, then display the result"""
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))

233 

234 

def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    """render `tokens` as html, colored by `weights` through `cmap`, then display the result"""
    display_html(color_tokens_cmap(tokens, weights, cmap))

243 

244 

def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    """render maze `tokens` as html with the default AOTP coloring, then display the result"""
    display_html(color_maze_tokens_AOTP(tokens))