Coverage for maze_dataset/plotting/print_tokens.py: 57%
79 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"""Functions to print tokens with colors in different formats
3you can color the tokens by their:
5- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP`
6- custom weights (i.e. attention weights) using `color_tokens_cmap`
7- entirely custom colors using `color_tokens_rgb`
9and the output can be in different formats, specified by `FormatType` (html, latex, terminal)
11"""
13import html
14import textwrap
15from typing import Literal, Sequence
17import matplotlib # noqa: ICN001
18import numpy as np
19from IPython.display import HTML, display
20from jaxtyping import Float, UInt8
21from muutils.misc import flatten
23from maze_dataset.constants import SPECIAL_TOKENS
24from maze_dataset.token_utils import tokens_between
RGBArray = UInt8[np.ndarray, "n 3"]
"2D array of RGB values with shape `(n, 3)`: one `(r, g, b)` row per token, each channel a uint8 in `0..255`"
FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens"

# One template per named format for rendering a single colored token.
# Placeholders:
#   `{clr}` -- the color channels, joined with the separator from `_COLOR_JOIN`
#   `{tok}` -- the (escaped) token text
# Literal braces in the LaTeX template are doubled so `str.format` keeps them;
# the terminal template uses ANSI SGR escape sequences.
TEMPLATES: dict[FormatType, str] = dict(
    html='<span style="color: black; background-color: rgb({clr})"> {tok} </span>',
    latex=r"\colorbox[RGB]{{ {clr} }}{{ \texttt{{ {tok} }} }}",
    terminal="\x1b[30m\x1b[48;2;{clr}m{tok}\x1b[0m",
)
"templates of printing tokens in different formats"

# Separator placed between the R, G and B channels inside `{clr}` for each format.
_COLOR_JOIN: dict[FormatType, str] = dict(html=",", latex=",", terminal=";")
"joiner for colors in different formats"
# Single-pass translation table for characters that are special in LaTeX text mode.
# `str.translate` maps every character exactly once, so the backslash replacement is
# never re-escaped by the brace replacements (which a chain of `str.replace` would do).
_LATEX_ESCAPES: dict[int, str] = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "{": r"\{",
        "}": r"\}",
        "$": r"\$",
        "&": r"\&",
        "%": r"\%",
        "#": r"\#",
        "_": r"\_",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    },
)


def _escape_tok(
    tok: str,
    fmt: FormatType,
) -> str:
    """escape a token so it renders literally in the given output format

    # Parameters:
     - `tok : str`
        the raw token text
     - `fmt : FormatType`
        one of `"html"`, `"latex"`, `"terminal"`, or `None`

    # Returns:
     - `str`
        - `"html"`: the token passed through `html.escape`
        - `"latex"`: LaTeX special characters escaped via `_LATEX_ESCAPES`
        - `"terminal"`: the token unchanged
        - `None`: the token unchanged. `None` means a caller-supplied template
          is in use, whose escaping rules are unknown, so nothing is escaped.

    # Raises:
     - `ValueError` : for any other value of `fmt`
    """
    if fmt is None or fmt == "terminal":
        return tok
    if fmt == "html":
        return html.escape(tok)
    if fmt == "latex":
        return tok.translate(_LATEX_ESCAPES)
    err_msg: str = f"Unexpected format: {fmt}"
    raise ValueError(err_msg)
def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    tokens will not be escaped if `fmt` is None

    # Parameters:
     - `tokens : list`
        tokens to color, one color per token
     - `colors : Sequence[Sequence[int]] | Float[np.ndarray, "n 3"]`
        one RGB triple per token; each channel is cast with `int` before formatting.
        tokens and colors are paired with `zip`, so any surplus on either side is dropped
     - `fmt : FormatType`
        output format. If `None`, both `template` and `clr_join` must be given and
        tokens are inserted into `template` without escaping.
        (defaults to `"html"`)
     - `template : str | None`
        custom template with `{clr}` and `{tok}` placeholders; only allowed when `fmt` is `None`
     - `clr_join : str | None`
        separator placed between color channels; only allowed when `fmt` is `None`
     - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.

    # Returns:
     - `str`
        the formatted tokens, separated by single spaces
    """
    # process format
    if fmt is None:
        assert template is not None, "`template` is required when `fmt` is None"
        assert clr_join is not None, "`clr_join` is required when `fmt` is None"
    else:
        assert template is None, f"`template` must be None when `fmt` is {fmt!r}"
        assert clr_join is None, f"`clr_join` must be None when `fmt` is {fmt!r}"
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # Split each token into pieces of at most `max_length` characters, breaking only
        # on whitespace (a single word longer than the limit stays whole). Each piece
        # keeps its source token's color and becomes its own colored box. Empty or
        # whitespace-only tokens yield no pieces from `textwrap.wrap`, so they are dropped.
        wrapped_tokens: list[str] = []
        wrapped_colors: list = []
        for tok, clr in zip(tokens, colors, strict=False):
            pieces: list[str] = textwrap.wrap(
                tok,
                width=max_length,
                break_long_words=False,
                break_on_hyphens=False,
            )
            wrapped_tokens.extend(pieces)
            wrapped_colors.extend([clr] * len(pieces))
        tokens, colors = wrapped_tokens, wrapped_colors

    # put everything together
    output: list[str] = [
        template.format(
            clr=clr_join.join(str(int(channel)) for channel in clr),
            # custom templates (`fmt is None`) receive the raw token, as documented above
            tok=tok if fmt is None else _escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors, strict=False)
    ]

    return " ".join(output)
# TYPING: would be nice to type hint as html, latex, or terminal string and overload depending on `FormatType`
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    Weights are rescaled to `[0, 1]` with `matplotlib.colors.Normalize()` and then
    mapped through `cmap` to get one color per token.

    # Parameters:
     - `tokens : list[str]`
        tokens to color
     - `weights : Sequence[float]`
        one weight per token (e.g. attention weights)
     - `cmap : str | matplotlib.colors.Colormap`
        colormap object, or the name of a registered colormap
        (defaults to `"Blues"`)
     - `fmt : FormatType`
        output format, forwarded to `color_tokens_rgb` (defaults to `"html"`)
     - `template : str | None`
        custom template, forwarded to `color_tokens_rgb`. Only allowed with `fmt=None`,
        in which case `clr_join` must also be given. (defaults to `None`)
     - `labels : bool`
        if `True`, append a second line with each weight printed under its token.
        Only supported for `fmt="terminal"`. (defaults to `False`)
     - `clr_join : str | None`
        separator between color channels, forwarded to `color_tokens_rgb`. Required
        together with `template` when `fmt` is `None`, and must be `None` otherwise.
        (defaults to `None`)

    # Returns:
     - `str`
        the colored tokens, optionally followed by a line of weight labels

    # Raises:
     - `AssertionError` : if `tokens` and `weights` differ in length
     - `NotImplementedError` : if `labels` is set and `fmt` is not `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    # fail fast, before any coloring work is done
    if labels and fmt != "terminal":
        raise NotImplementedError("labels only supported for terminal")

    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # keep only the RGB columns of the colormap output and scale [0, 1] floats to [0, 255]
    colors: Float[np.ndarray, "n_tok 3"] = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        clr_join=clr_join,
    )

    if labels:
        # second line: each weight left-aligned under its token
        output += "\n"
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place, left-aligned and padded with trailing spaces to the token length
            weight_str: str = f"{weight:.1f}"
            # blank out the label if it is wider than the token
            if len(weight_str) > len(tok):
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            output += f"{weight_str} "

    return output
# RGB color for each AOTP section, keyed by the (start, end) delimiter tokens of that
# section. The colors loosely follow the maze's visual rendering.
# Insertion order matters: `color_maze_tokens_AOTP` walks this dict in order, so keep
# it as adjacency list, origin, target, path.
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
    # adjacency list -> purple
    (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (176, 152, 232),
    # origin -> green
    (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (154, 239, 123),
    # target -> red
    (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (246, 136, 136),
    # path -> blue
    (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (111, 187, 254),
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"
def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e: adjacency list, origin, target, path

    Each section runs from its start token through its end token, inclusive. The
    section's tokens are joined with spaces into a single string and colored as one
    unit, using the color for that section in `_MAZE_TOKENS_DEFAULT_COLORS`.

    # Parameters:
     - `tokens : list[str]`
        maze tokens containing the AOTP section delimiter tokens
     - `fmt : FormatType`
        output format, forwarded to `color_tokens_rgb` (defaults to `"html"`)
     - `template : str | None`
        custom template, forwarded to `color_tokens_rgb` (defaults to `None`)
     - `**kwargs`
        extra keyword arguments for `color_tokens_rgb`, e.g. `clr_join` or `max_length`

    # Returns:
     - `str`
        the colored sections, separated by spaces
    """
    sections: list[str] = []
    section_colors: list[tuple[int, int, int]] = []
    # walk sections and their colors together so the two lists cannot drift apart
    for (start_tok, end_tok), color in _MAZE_TOKENS_DEFAULT_COLORS.items():
        sections.append(
            " ".join(
                tokens_between(
                    tokens,
                    start_tok,
                    end_tok,
                    include_start=True,
                    include_end=True,
                ),
            ),
        )
        section_colors.append(color)

    colors: RGBArray = np.array(section_colors, dtype=np.uint8)

    return color_tokens_rgb(
        tokens=sections,
        colors=colors,
        fmt=fmt,
        template=template,
        **kwargs,
    )
def display_html(html: str) -> None:
    """render an HTML string in the active IPython frontend

    # Parameters:
     - `html : str`
        raw HTML markup; wrapped in `IPython.display.HTML` and passed to `display`
    """
    rendered = HTML(html)
    display(rendered)
def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    """render tokens as HTML, each with its own RGB background color

    Formats with `color_tokens_rgb(..., fmt="html")` and hands the result to `display_html`.
    """
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))
def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    """render tokens as HTML, colored by their weights through a colormap

    Formats with `color_tokens_cmap` (default `"html"` format) and hands the result
    to `display_html`.
    """
    markup: str = color_tokens_cmap(tokens=tokens, weights=weights, cmap=cmap)
    display_html(markup)
def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    """render maze tokens as HTML, with one color per AOTP section

    Formats with `color_maze_tokens_AOTP` (default `"html"` format) and hands the
    result to `display_html`.
    """
    display_html(color_maze_tokens_AOTP(tokens))