maze_dataset.plotting.print_tokens
Functions to print tokens with colors in different formats
you can color the tokens by their:
- type (i.e. adjacency list, origin, target, path) using
color_maze_tokens_AOTP
- custom weights (i.e. attention weights) using
color_tokens_cmap
- entirely custom colors using
color_tokens_rgb
and the output can be in different formats, specified by FormatType
(html, latex, terminal)
"""Functions to print tokens with colors in different formats

you can color the tokens by their:

- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP`
- custom weights (i.e. attention weights) using `color_tokens_cmap`
- entirely custom colors using `color_tokens_rgb`

and the output can be in different formats, specified by `FormatType` (html, latex, terminal)

"""

import html
import textwrap
from typing import Literal, Sequence

import matplotlib  # noqa: ICN001
import numpy as np
from IPython.display import HTML, display
from jaxtyping import Float, UInt8

from maze_dataset.constants import SPECIAL_TOKENS
from maze_dataset.token_utils import tokens_between

RGBArray = UInt8[np.ndarray, "n 3"]
"array of RGB values with shape `(n, 3)`, one RGB triple per row"

FormatType = Literal["html", "latex", "terminal", None]
"output format for the tokens"

TEMPLATES: dict[FormatType, str] = {
    "html": '<span style="color: black; background-color: rgb({clr})"> {tok} </span>',
    "latex": "\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}",
    "terminal": "\033[30m\033[48;2;{clr}m{tok}\033[0m",
}
"templates of printing tokens in different formats"

_COLOR_JOIN: dict[FormatType, str] = {
    "html": ",",
    "latex": ",",
    "terminal": ";",
}
"joiner for colors in different formats"

# `str.translate` replaces every character exactly once, so the backslashes and
# braces inserted by one replacement (e.g. `\textbackslash{}`) are never escaped
# again -- chained `.replace()` calls would corrupt them
_LATEX_ESCAPES: dict[int, str] = str.maketrans(
    {
        "\\": "\\textbackslash{}",
        "&": "\\&",
        "%": "\\%",
        "$": "\\$",
        "#": "\\#",
        "_": "\\_",
        "{": "\\{",
        "}": "\\}",
        "~": "\\textasciitilde{}",
        "^": "\\textasciicircum{}",
    },
)


def _escape_tok(
    tok: str,
    fmt: FormatType,
) -> str:
    """escape `tok` so it renders literally in the given output format

    - `"html"`: `html.escape` (escapes `& < > " '`)
    - `"latex"`: escapes the LaTeX special characters `\\ & % $ # _ { } ~ ^`
    - `"terminal"` or `None`: returned unchanged (with `None` the caller supplies
      its own template, so no escaping convention is known)

    # Raises:
     - `ValueError`: for any other `fmt`
    """
    if fmt == "html":
        return html.escape(tok)
    if fmt == "latex":
        return tok.translate(_LATEX_ESCAPES)
    if fmt in ("terminal", None):
        return tok
    err_msg: str = f"Unexpected format: {fmt}"
    raise ValueError(err_msg)


def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    tokens will not be escaped if `fmt` is None

    # Parameters:
     - `tokens: list`: tokens to color, each rendered in its own colorbox
     - `colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"]`: one RGB triple per
       token; each component is converted with `int()` (floats are truncated)
     - `fmt: FormatType`: output format, selects the template from `TEMPLATES` and the
       color-component joiner from `_COLOR_JOIN`. If `None`, both `template` and
       `clr_join` must be given
     - `template: str | None`: custom format string with `{clr}` and `{tok}` fields,
       only allowed when `fmt` is None
     - `clr_join: str | None`: separator between the color components, only allowed
       when `fmt` is None
     - `max_length: int | None`: Max number of characters before triggering a line wrap,
       i.e., making a new colorbox. If `None`, no limit on max length.

    # Returns:
     - `str`: the colored tokens joined by single spaces. If `colors` is shorter than
       `tokens`, the output is truncated to the shorter length.
    """
    # process format
    if fmt is None:
        assert template is not None
        assert clr_join is not None
    else:
        assert template is None
        assert clr_join is None
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # split each token at whitespace into chunks of at most `max_length` chars;
        # `wrap` returns [] for empty/whitespace-only tokens, so keep those as-is
        # instead of silently dropping them (and their color)
        wrapped: list[list[str]] = [
            textwrap.wrap(
                tok,
                width=max_length,
                break_long_words=False,
                break_on_hyphens=False,
            )
            or [tok]
            for tok in tokens
        ]
        # every chunk inherits the color of the token it came from
        colors = [colors[i] for i, chunks in enumerate(wrapped) for _ in chunks]
        tokens = [chunk for chunks in wrapped for chunk in chunks]

    # put everything together
    output: list[str] = [
        template.format(
            clr=clr_join.join(map(str, map(int, clr))),
            # with `fmt=None` the caller owns the template: insert tokens verbatim
            tok=tok if fmt is None else _escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors, strict=False)
    ]

    return " ".join(output)


# TYPING: would be nice to type hint as html, latex, or terminal string and overload depending on `FormatType`
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    the weights are min-max normalized to [0, 1] with `matplotlib.colors.Normalize`,
    mapped through `cmap`, and the resulting colors (alpha dropped, scaled to 0-255)
    are applied with `color_tokens_rgb`

    # Parameters:
     - `tokens: list[str]`: tokens to color
     - `weights: Sequence[float]`: one weight per token
     - `cmap: str | matplotlib.colors.Colormap`: colormap or its registered name
     - `fmt: FormatType`: output format, passed to `color_tokens_rgb`
     - `template: str | None`: custom template, required when `fmt` is None
     - `labels: bool`: if True, append a line with each weight (1 decimal place)
       aligned under its token. Only supported for `fmt="terminal"`
     - `clr_join: str | None`: color-component separator, required when `fmt` is None

    # Raises:
     - `AssertionError`: if `tokens` and `weights` differ in length
     - `NotImplementedError`: if `labels` is True and `fmt` is not `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # drop alpha and scale to 0-255; these are floats, truncated by `color_tokens_rgb`
    colors: Float[np.ndarray, "n_tok 3"] = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        clr_join=clr_join,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # terminal tokens carry no padding, so padding each label to its token's
        # width plus one space keeps the labels aligned under the tokens
        label_parts: list[str] = []
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place, left-aligned, padded to the token length
            weight_str: str = f"{weight:.1f}"
            # omit if longer than token
            if len(weight_str) > len(tok):
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            label_parts.append(f"{weight_str} ")
        output += "\n" + "".join(label_parts)

    return output


# colors roughly made to be similar to visual representation
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
    (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (
        176,
        152,
        232,
    ),  # purple
    (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (154, 239, 123),  # green
    (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (246, 136, 136),  # red
    (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (111, 187, 254),  # blue
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"


def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e.: adjacency list, origin, target, path

    each section (start and end tokens included, as extracted by `tokens_between`)
    is joined into one space-separated string and rendered as a single colorbox with
    its color from `_MAZE_TOKENS_DEFAULT_COLORS`. Extra `kwargs` are forwarded to
    `color_tokens_rgb`.
    """
    sections: list[str] = []
    section_colors: list[tuple[int, int, int]] = []
    for (start_tok, end_tok), section_color in _MAZE_TOKENS_DEFAULT_COLORS.items():
        sections.append(
            " ".join(
                tokens_between(
                    tokens,
                    start_tok,
                    end_tok,
                    include_start=True,
                    include_end=True,
                ),
            ),
        )
        section_colors.append(section_color)

    return color_tokens_rgb(
        tokens=sections,
        colors=np.array(section_colors, dtype=np.uint8),
        fmt=fmt,
        template=template,
        **kwargs,
    )


def display_html(html: str) -> None:
    "display html string"
    display(HTML(html))


def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    """display tokens (as html) with custom colors"""
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))


def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    """display tokens (as html) with color based on weights"""
    display_html(color_tokens_cmap(tokens, weights, cmap=cmap))


def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    """display maze tokens (as html) with AOTP coloring"""
    display_html(color_maze_tokens_AOTP(tokens))
RGBArray =
<class 'jaxtyping.UInt8[ndarray, 'n 3']'>
array of RGB values with shape (n, 3), one RGB triple per row
FormatType =
typing.Literal['html', 'latex', 'terminal', None]
output format for the tokens
TEMPLATES: dict[typing.Literal['html', 'latex', 'terminal', None], str] =
{'html': '<span style="color: black; background-color: rgb({clr})"> {tok} </span>', 'latex': '\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}', 'terminal': '\x1b[30m\x1b[48;2;{clr}m{tok}\x1b[0m'}
templates of printing tokens in different formats
def
color_tokens_rgb( tokens: list, colors: Union[Sequence[Sequence[int]], jaxtyping.Float[ndarray, 'n 3']], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, clr_join: str | None = None, max_length: int | None = None) -> str:
def color_tokens_rgb(
    tokens: list,
    colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
    fmt: FormatType = "html",
    template: str | None = None,
    clr_join: str | None = None,
    max_length: int | None = None,
) -> str:
    """color tokens from a list with an RGB color array

    tokens will not be escaped if `fmt` is None

    # Parameters:
     - `tokens: list`: tokens to color, each rendered in its own colorbox
     - `colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"]`: one RGB triple per
       token; each component is converted with `int()` (floats are truncated)
     - `fmt: FormatType`: output format, selects the template from `TEMPLATES` and the
       color-component joiner from `_COLOR_JOIN`. If `None`, both `template` and
       `clr_join` must be given
     - `template: str | None`: custom format string with `{clr}` and `{tok}` fields,
       only allowed when `fmt` is None
     - `clr_join: str | None`: separator between the color components, only allowed
       when `fmt` is None
     - `max_length: int | None`: Max number of characters before triggering a line wrap,
       i.e., making a new colorbox. If `None`, no limit on max length.

    # Returns:
     - `str`: the colored tokens joined by single spaces. If `colors` is shorter than
       `tokens`, the output is truncated to the shorter length.
    """
    # process format
    if fmt is None:
        assert template is not None
        assert clr_join is not None
    else:
        assert template is None
        assert clr_join is None
        template = TEMPLATES[fmt]
        clr_join = _COLOR_JOIN[fmt]

    if max_length is not None:
        # split each token at whitespace into chunks of at most `max_length` chars;
        # `wrap` returns [] for empty/whitespace-only tokens, so keep those as-is
        # instead of silently dropping them (and their color)
        wrapped: list[list[str]] = [
            textwrap.wrap(
                tok,
                width=max_length,
                break_long_words=False,
                break_on_hyphens=False,
            )
            or [tok]
            for tok in tokens
        ]
        # every chunk inherits the color of the token it came from
        colors = [colors[i] for i, chunks in enumerate(wrapped) for _ in chunks]
        tokens = [chunk for chunks in wrapped for chunk in chunks]

    # put everything together
    output: list[str] = [
        template.format(
            clr=clr_join.join(map(str, map(int, clr))),
            # with `fmt=None` the caller owns the template: insert tokens verbatim
            # (escaping helpers only know the built-in formats)
            tok=tok if fmt is None else _escape_tok(tok, fmt),
        )
        for tok, clr in zip(tokens, colors, strict=False)
    ]

    return " ".join(output)
color tokens from a list with an RGB color array
tokens will not be escaped if fmt
is None
Parameters:
max_length: int | None
: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If None
, no limit on max length.
def
color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = 'Blues', fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, labels: bool = False) -> str:
def color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
    fmt: FormatType = "html",
    template: str | None = None,
    labels: bool = False,
    clr_join: str | None = None,
) -> str:
    """color tokens given a list of weights and a colormap

    the weights are min-max normalized to [0, 1] with `matplotlib.colors.Normalize`,
    mapped through `cmap`, and the resulting colors (alpha dropped, scaled to 0-255)
    are applied with `color_tokens_rgb`

    # Parameters:
     - `tokens: list[str]`: tokens to color
     - `weights: Sequence[float]`: one weight per token
     - `cmap: str | matplotlib.colors.Colormap`: colormap or its registered name
     - `fmt: FormatType`: output format, passed to `color_tokens_rgb`
     - `template: str | None`: custom template, required when `fmt` is None
     - `labels: bool`: if True, append a line with each weight (1 decimal place)
       aligned under its token. Only supported for `fmt="terminal"`
     - `clr_join: str | None`: color-component separator, required when `fmt` is None

    # Raises:
     - `AssertionError`: if `tokens` and `weights` differ in length
     - `NotImplementedError`: if `labels` is True and `fmt` is not `"terminal"`
    """
    n_tok: int = len(tokens)
    assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
    weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
    # normalize weights to [0, 1]
    weights_norm = matplotlib.colors.Normalize()(weights_np)

    if isinstance(cmap, str):
        cmap = matplotlib.colormaps.get_cmap(cmap)

    # drop alpha and scale to 0-255; these are floats, truncated by `color_tokens_rgb`
    colors: Float[np.ndarray, "n_tok 3"] = cmap(weights_norm)[:, :3] * 255

    output: str = color_tokens_rgb(
        tokens=tokens,
        colors=colors,
        fmt=fmt,
        template=template,
        clr_join=clr_join,
    )

    if labels:
        if fmt != "terminal":
            raise NotImplementedError("labels only supported for terminal")
        # terminal tokens carry no padding, so padding each label to its token's
        # width plus one space keeps the labels aligned under the tokens
        label_parts: list[str] = []
        for tok, weight in zip(tokens, weights_np, strict=False):
            # 1 decimal place, left-aligned, padded to the token length
            weight_str: str = f"{weight:.1f}"
            # omit if longer than token
            if len(weight_str) > len(tok):
                weight_str = " " * len(tok)
            else:
                weight_str = weight_str.ljust(len(tok))
            label_parts.append(f"{weight_str} ")
        output += "\n" + "".join(label_parts)

    return output
color tokens given a list of weights and a colormap
def
color_maze_tokens_AOTP( tokens: list[str], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, **kwargs) -> str:
def color_maze_tokens_AOTP(
    tokens: list[str],
    fmt: FormatType = "html",
    template: str | None = None,
    **kwargs,
) -> str:
    """color tokens assuming AOTP format

    i.e.: adjacency list, origin, target, path

    each section (start and end tokens included, as extracted by `tokens_between`)
    is joined into one space-separated string and rendered as a single colorbox with
    its color from `_MAZE_TOKENS_DEFAULT_COLORS`. Extra `kwargs` are forwarded to
    `color_tokens_rgb`.
    """
    sections: list[str] = []
    section_colors: list[tuple[int, int, int]] = []
    for (start_tok, end_tok), section_color in _MAZE_TOKENS_DEFAULT_COLORS.items():
        section_tokens = tokens_between(
            tokens,
            start_tok,
            end_tok,
            include_start=True,
            include_end=True,
        )
        sections.append(" ".join(section_tokens))
        section_colors.append(section_color)

    colors_arr: RGBArray = np.array(section_colors, dtype=np.uint8)

    return color_tokens_rgb(
        tokens=sections,
        colors=colors_arr,
        fmt=fmt,
        template=template,
        **kwargs,
    )
color tokens assuming AOTP format
i.e.: adjacency list, origin, target, path
def
display_html(html: str) -> None:
display html string
def
display_color_tokens_rgb(tokens: list[str], colors: jaxtyping.UInt8[ndarray, 'n 3']) -> None:
def display_color_tokens_rgb(
    tokens: list[str],
    colors: RGBArray,
) -> None:
    """render `tokens` as HTML in the active IPython display, giving each token the
    background color at the same position in `colors`"""
    display_html(color_tokens_rgb(tokens, colors, fmt="html"))
display tokens (as html) with custom colors
def
display_color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = 'Blues') -> None:
def display_color_tokens_cmap(
    tokens: list[str],
    weights: Sequence[float],
    cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
    """render `tokens` as HTML in the active IPython display, colored by mapping
    their `weights` through `cmap` (see `color_tokens_cmap`)"""
    rendered: str = color_tokens_cmap(tokens, weights, cmap=cmap)
    display_html(rendered)
display tokens (as html) with color based on weights
def
display_color_maze_tokens_AOTP(tokens: list[str]) -> None:
def display_color_maze_tokens_AOTP(
    tokens: list[str],
) -> None:
    """render maze `tokens` as HTML in the active IPython display, with one colorbox
    per AOTP section (see `color_maze_tokens_AOTP`)"""
    display_html(color_maze_tokens_AOTP(tokens))
display maze tokens (as html) with AOTP coloring