docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.plotting.print_tokens

Functions to print tokens with colors in different formats

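You can color the tokens by their:

- type (i.e. adjacency list, origin, target, path) using color_maze_tokens_AOTP
- custom weights (e.g. attention weights) using color_tokens_cmap
- entirely custom colors using color_tokens_rgb

and the output can be in different formats, specified by FormatType (html, latex, terminal).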


"""Functions to print tokens with colors in different formats

you can color the tokens by their:

- type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP`
- custom weights (e.g. attention weights) using `color_tokens_cmap`
- entirely custom colors using `color_tokens_rgb`

and the output can be in different formats, specified by `FormatType` (html, latex, terminal).
passing `fmt=None` to `color_tokens_rgb` instead requires a custom `template` and `clr_join`.

the `display_*` helpers render the html output via `IPython.display`

"""
 12
 13import html
 14import textwrap
 15from typing import Literal, Sequence
 16
 17import matplotlib  # noqa: ICN001
 18import numpy as np
 19from IPython.display import HTML, display
 20from jaxtyping import Float, UInt8
 21from muutils.misc import flatten
 22
 23from maze_dataset.constants import SPECIAL_TOKENS
 24from maze_dataset.token_utils import tokens_between
 25
 26RGBArray = UInt8[np.ndarray, "n 3"]
 27"1D array of RGB values"
 28
 29FormatType = Literal["html", "latex", "terminal", None]
 30"output format for the tokens"
 31
 32TEMPLATES: dict[FormatType, str] = {
 33	"html": '<span style="color: black; background-color: rgb({clr})">&nbsp{tok}&nbsp</span>',
 34	"latex": "\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}",
 35	"terminal": "\033[30m\033[48;2;{clr}m{tok}\033[0m",
 36}
 37"templates of printing tokens in different formats"
 38
 39_COLOR_JOIN: dict[FormatType, str] = {
 40	"html": ",",
 41	"latex": ",",
 42	"terminal": ";",
 43}
 44"joiner for colors in different formats"
 45
 46
 47def _escape_tok(
 48	tok: str,
 49	fmt: FormatType,
 50) -> str:
 51	"escape token based on format"
 52	if fmt == "html":
 53		return html.escape(tok)
 54	elif fmt == "latex":
 55		return tok.replace("_", "\\_").replace("#", "\\#")
 56	elif fmt == "terminal":
 57		return tok
 58	else:
 59		err_msg: str = f"Unexpected format: {fmt}"
 60		raise ValueError(err_msg)
 61
 62
 63def color_tokens_rgb(
 64	tokens: list,
 65	colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
 66	fmt: FormatType = "html",
 67	template: str | None = None,
 68	clr_join: str | None = None,
 69	max_length: int | None = None,
 70) -> str:
 71	"""color tokens from a list with an RGB color array
 72
 73	tokens will not be escaped if `fmt` is None
 74
 75	# Parameters:
 76	- `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.
 77	"""
 78	# process format
 79	if fmt is None:
 80		assert template is not None
 81		assert clr_join is not None
 82	else:
 83		assert template is None
 84		assert clr_join is None
 85		template = TEMPLATES[fmt]
 86		clr_join = _COLOR_JOIN[fmt]
 87
 88	if max_length is not None:
 89		# TODO: why are we using a map here again?
 90		# TYPING: this is missing a lot of type hints
 91		wrapped: list = list(  # noqa: C417
 92			map(
 93				lambda x: textwrap.wrap(
 94					x,
 95					width=max_length,
 96					break_long_words=False,
 97					break_on_hyphens=False,
 98				),
 99				tokens,
100			),
101		)
102		colors = list(
103			flatten(
104				[[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))],
105				levels_to_flatten=1,
106			),
107		)
108		wrapped = list(flatten(wrapped, levels_to_flatten=1))
109		tokens = wrapped
110
111	# put everything together
112	output = [
113		template.format(
114			clr=clr_join.join(map(str, map(int, clr))),
115			tok=_escape_tok(tok, fmt),
116		)
117		for tok, clr in zip(tokens, colors, strict=False)
118	]
119
120	return " ".join(output)
121
122
# TYPING: would be nice to type hint as html, latex, or terminal string and overload depending on `FormatType`
def color_tokens_cmap(
	tokens: list[str],
	weights: Sequence[float],
	cmap: str | matplotlib.colors.Colormap = "Blues",
	fmt: FormatType = "html",
	template: str | None = None,
	labels: bool = False,
	clr_join: str | None = None,
) -> str:
	"""color tokens given a list of weights and a colormap

	# Parameters:
	- `tokens: list[str]`: tokens to render
	- `weights: Sequence[float]`: one weight per token, rescaled with `matplotlib.colors.Normalize`
	  (min to max) before the colormap lookup
	- `cmap: str | matplotlib.colors.Colormap`: colormap object or name of a registered colormap
	  (defaults to `"Blues"`)
	- `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
	- `template: str | None`: custom template; only allowed (and then required) if `fmt` is `None`
	- `labels: bool`: if `True`, append a second line with each weight (1 decimal place) placed
	  under its token. only supported for `fmt="terminal"`
	- `clr_join: str | None`: color-channel separator; only allowed (and then required) if `fmt` is `None`

	# Returns:
	- `str`: the colored tokens, optionally followed by a line of weight labels

	# Raises:
	- `AssertionError`: if `tokens` and `weights` differ in length
	- `NotImplementedError`: if `labels` is set and `fmt` is not `"terminal"`
	"""
	n_tok: int = len(tokens)
	assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
	weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
	# normalize weights to [0, 1]
	weights_norm = matplotlib.colors.Normalize()(weights_np)

	if isinstance(cmap, str):
		cmap = matplotlib.colormaps.get_cmap(cmap)

	# drop the alpha channel and scale to [0, 255]
	colors: RGBArray = cmap(weights_norm)[:, :3] * 255

	output: str = color_tokens_rgb(
		tokens=tokens,
		colors=colors,
		fmt=fmt,
		template=template,
		clr_join=clr_join,
	)

	if labels:
		if fmt != "terminal":
			raise NotImplementedError("labels only supported for terminal")
		# align labels with the tokens: 1 decimal place, left-aligned and padded with
		# trailing spaces to the token's length, one space between labels like the tokens
		label_strs: list[str] = []
		for tok, weight in zip(tokens, weights_np, strict=False):
			weight_str: str = f"{weight:.1f}"
			# omit (blank out) the label if it is longer than the token
			if len(weight_str) > len(tok):
				weight_str = " " * len(tok)
			else:
				weight_str = weight_str.ljust(len(tok))
			label_strs.append(weight_str)
		output += "\n" + "".join(f"{label} " for label in label_strs)

	return output
167
168
# one RGB color per AOTP section, chosen to roughly resemble the visual rendering.
# insertion order matters: `color_maze_tokens_AOTP` relies on keys and values lining up
_MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = {
	# purple: adjacency list
	(SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): (176, 152, 232),
	# green: origin
	(SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (154, 239, 123),
	# red: target
	(SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (246, 136, 136),
	# blue: path
	(SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (111, 187, 254),
}
"default colors for maze tokens, roughly matches the format of `as_pixels`"
181
182
def color_maze_tokens_AOTP(
	tokens: list[str],
	fmt: FormatType = "html",
	template: str | None = None,
	**kwargs,
) -> str:
	"""color tokens assuming AOTP format

	i.e.: adjacency list, origin, target, path

	each section (its start token through its end token, both inclusive) is joined
	with spaces into a single string and rendered as one box in the matching color
	from `_MAZE_TOKENS_DEFAULT_COLORS`. extra `kwargs` are forwarded to `color_tokens_rgb`
	"""
	sections: list[str] = []
	for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS.keys():
		section_tokens = tokens_between(
			tokens,
			start_tok,
			end_tok,
			include_start=True,
			include_end=True,
		)
		sections.append(" ".join(section_tokens))

	section_colors: RGBArray = np.array(
		tuple(_MAZE_TOKENS_DEFAULT_COLORS.values()),
		dtype=np.uint8,
	)

	return color_tokens_rgb(
		sections,
		section_colors,
		fmt=fmt,
		template=template,
		**kwargs,
	)
219
220
def display_html(html: str) -> None:
	"""render an html string as rich output via IPython's `display`"""
	html_obj = HTML(html)
	display(html_obj)
224
225
def display_color_tokens_rgb(
	tokens: list[str],
	colors: RGBArray,
) -> None:
	"""display `tokens` as html boxes, each with its own RGB color from `colors`"""
	display_html(color_tokens_rgb(tokens, colors, fmt="html"))
233
234
def display_color_tokens_cmap(
	tokens: list[str],
	weights: Sequence[float],
	cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
	"""display `tokens` as html boxes colored by mapping `weights` through `cmap`"""
	rendered: str = color_tokens_cmap(tokens, weights, cmap=cmap)
	display_html(rendered)
243
244
def display_color_maze_tokens_AOTP(
	tokens: list[str],
) -> None:
	"""display maze tokens as html, one colored box per AOTP section"""
	display_html(color_maze_tokens_AOTP(tokens))

RGBArray = <class 'jaxtyping.UInt8[ndarray, 'n 3']'>

array of RGB values with shape (n, 3), one (r, g, b) row per token

FormatType = typing.Literal['html', 'latex', 'terminal', None]

output format for the tokens

TEMPLATES: dict[typing.Literal['html', 'latex', 'terminal', None], str] = {'html': '<span style="color: black; background-color: rgb({clr})">&nbsp{tok}&nbsp</span>', 'latex': '\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}', 'terminal': '\x1b[30m\x1b[48;2;{clr}m{tok}\x1b[0m'}

templates of printing tokens in different formats

def color_tokens_rgb( tokens: list, colors: Union[Sequence[Sequence[int]], jaxtyping.Float[ndarray, 'n 3']], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, clr_join: str | None = None, max_length: int | None = None) -> str:
 64def color_tokens_rgb(
 65	tokens: list,
 66	colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"],
 67	fmt: FormatType = "html",
 68	template: str | None = None,
 69	clr_join: str | None = None,
 70	max_length: int | None = None,
 71) -> str:
 72	"""color tokens from a list with an RGB color array
 73
 74	tokens will not be escaped if `fmt` is None
 75
 76	# Parameters:
 77	- `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length.
 78	"""
 79	# process format
 80	if fmt is None:
 81		assert template is not None
 82		assert clr_join is not None
 83	else:
 84		assert template is None
 85		assert clr_join is None
 86		template = TEMPLATES[fmt]
 87		clr_join = _COLOR_JOIN[fmt]
 88
 89	if max_length is not None:
 90		# TODO: why are we using a map here again?
 91		# TYPING: this is missing a lot of type hints
 92		wrapped: list = list(  # noqa: C417
 93			map(
 94				lambda x: textwrap.wrap(
 95					x,
 96					width=max_length,
 97					break_long_words=False,
 98					break_on_hyphens=False,
 99				),
100				tokens,
101			),
102		)
103		colors = list(
104			flatten(
105				[[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))],
106				levels_to_flatten=1,
107			),
108		)
109		wrapped = list(flatten(wrapped, levels_to_flatten=1))
110		tokens = wrapped
111
112	# put everything together
113	output = [
114		template.format(
115			clr=clr_join.join(map(str, map(int, clr))),
116			tok=_escape_tok(tok, fmt),
117		)
118		for tok, clr in zip(tokens, colors, strict=False)
119	]
120
121	return " ".join(output)

color tokens from a list with an RGB color array

tokens will not be escaped if fmt is None

Parameters:

  • max_length: int | None: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If None, no limit on max length.
def color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = 'Blues', fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, labels: bool = False) -> str:
def color_tokens_cmap(
	tokens: list[str],
	weights: Sequence[float],
	cmap: str | matplotlib.colors.Colormap = "Blues",
	fmt: FormatType = "html",
	template: str | None = None,
	labels: bool = False,
	clr_join: str | None = None,
) -> str:
	"""color tokens given a list of weights and a colormap

	# Parameters:
	- `tokens: list[str]`: tokens to render
	- `weights: Sequence[float]`: one weight per token, rescaled with `matplotlib.colors.Normalize`
	  (min to max) before the colormap lookup
	- `cmap: str | matplotlib.colors.Colormap`: colormap object or name of a registered colormap
	  (defaults to `"Blues"`)
	- `fmt: FormatType`: output format, forwarded to `color_tokens_rgb`
	- `template: str | None`: custom template; only allowed (and then required) if `fmt` is `None`
	- `labels: bool`: if `True`, append a second line with each weight (1 decimal place) placed
	  under its token. only supported for `fmt="terminal"`
	- `clr_join: str | None`: color-channel separator; only allowed (and then required) if `fmt` is `None`

	# Returns:
	- `str`: the colored tokens, optionally followed by a line of weight labels

	# Raises:
	- `AssertionError`: if `tokens` and `weights` differ in length
	- `NotImplementedError`: if `labels` is set and `fmt` is not `"terminal"`
	"""
	n_tok: int = len(tokens)
	assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'"
	weights_np: Float[np.ndarray, " n_tok"] = np.array(weights)
	# normalize weights to [0, 1]
	weights_norm = matplotlib.colors.Normalize()(weights_np)

	if isinstance(cmap, str):
		cmap = matplotlib.colormaps.get_cmap(cmap)

	# drop the alpha channel and scale to [0, 255]
	colors: RGBArray = cmap(weights_norm)[:, :3] * 255

	output: str = color_tokens_rgb(
		tokens=tokens,
		colors=colors,
		fmt=fmt,
		template=template,
		clr_join=clr_join,
	)

	if labels:
		if fmt != "terminal":
			raise NotImplementedError("labels only supported for terminal")
		# align labels with the tokens: 1 decimal place, left-aligned and padded with
		# trailing spaces to the token's length, one space between labels like the tokens
		label_strs: list[str] = []
		for tok, weight in zip(tokens, weights_np, strict=False):
			weight_str: str = f"{weight:.1f}"
			# omit (blank out) the label if it is longer than the token
			if len(weight_str) > len(tok):
				weight_str = " " * len(tok)
			else:
				weight_str = weight_str.ljust(len(tok))
			label_strs.append(weight_str)
		output += "\n" + "".join(f"{label} " for label in label_strs)

	return output

color tokens given a list of weights and a colormap

def color_maze_tokens_AOTP( tokens: list[str], fmt: Literal['html', 'latex', 'terminal', None] = 'html', template: str | None = None, **kwargs) -> str:
def color_maze_tokens_AOTP(
	tokens: list[str],
	fmt: FormatType = "html",
	template: str | None = None,
	**kwargs,
) -> str:
	"""color tokens assuming AOTP format

	i.e.: adjacency list, origin, target, path

	each section (its start token through its end token, both inclusive) is joined
	with spaces into a single string and rendered as one box in the matching color
	from `_MAZE_TOKENS_DEFAULT_COLORS`. extra `kwargs` are forwarded to `color_tokens_rgb`
	"""
	sections: list[str] = []
	for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS.keys():
		section_tokens = tokens_between(
			tokens,
			start_tok,
			end_tok,
			include_start=True,
			include_end=True,
		)
		sections.append(" ".join(section_tokens))

	section_colors: RGBArray = np.array(
		tuple(_MAZE_TOKENS_DEFAULT_COLORS.values()),
		dtype=np.uint8,
	)

	return color_tokens_rgb(
		sections,
		section_colors,
		fmt=fmt,
		template=template,
		**kwargs,
	)

color tokens assuming AOTP format

i.e.: adjacency list, origin, target, path

def display_html(html: str) -> None:
def display_html(html: str) -> None:
	"""render an html string as rich output via IPython's `display`"""
	html_obj = HTML(html)
	display(html_obj)

display html string

def display_color_tokens_rgb(tokens: list[str], colors: jaxtyping.UInt8[ndarray, 'n 3']) -> None:
def display_color_tokens_rgb(
	tokens: list[str],
	colors: RGBArray,
) -> None:
	"""display `tokens` as html boxes, each with its own RGB color from `colors`"""
	display_html(color_tokens_rgb(tokens, colors, fmt="html"))

display tokens (as html) with custom colors

def display_color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = 'Blues') -> None:
def display_color_tokens_cmap(
	tokens: list[str],
	weights: Sequence[float],
	cmap: str | matplotlib.colors.Colormap = "Blues",
) -> None:
	"""display `tokens` as html boxes colored by mapping `weights` through `cmap`"""
	rendered: str = color_tokens_cmap(tokens, weights, cmap=cmap)
	display_html(rendered)

display tokens (as html) with color based on weights

def display_color_maze_tokens_AOTP(tokens: list[str]) -> None:
def display_color_maze_tokens_AOTP(
	tokens: list[str],
) -> None:
	"""display maze tokens as html, one colored box per AOTP section"""
	display_html(color_maze_tokens_AOTP(tokens))

display maze tokens (as html) with AOTP coloring