In [1]:
import matplotlib.pyplot as plt
import numpy as np
from maze_dataset import (
LatticeMazeGenerators,
MazeDataset,
MazeDatasetConfig,
SolvedMaze,
)
from maze_dataset.plotting import plot_dataset_mazes
from maze_dataset.plotting.print_tokens import (
display_color_maze_tokens_AOTP,
display_color_tokens_cmap,
display_color_tokens_rgb,
)
from maze_dataset.tokenization import MazeTokenizer, TokenizationMode
from maze_dataset.utils import corner_first_ndindex
Let's get a basic dataset first:
In [2]:
# A tiny demo configuration: 5 mazes on a 5x5 lattice, each generated by
# `LatticeMazeGenerators.gen_dfs`.
CFG: MazeDatasetConfig = MazeDatasetConfig(
    name="test",
    maze_ctor=LatticeMazeGenerators.gen_dfs,
    grid_n=5,
    n_mazes=5,
)
In [3]:
# Build the dataset described by CFG. `local_base_path` is relative to the
# notebook's working directory.
DATASET: MazeDataset = MazeDataset.from_config(CFG, local_base_path="../data/maze_dataset/")
In [4]:
# Plot every maze in the dataset. The trailing `;` suppresses the echoed
# `(Figure, array([Axes, ...]))` return value so only the figure is shown.
plot_dataset_mazes(DATASET);
Out[4]:
(<Figure size 500x200 with 5 Axes>, array([<Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >], dtype=object))
Now, let's see how tokenization works:
In [5]:
# Grid-size bound passed as `max_grid_size` to both tokenizers. It is defined
# once so the two tokenizers cannot drift apart.
MAX_GRID_SIZE: int = 100

# Unique-token (UT) tokenizer: each coordinate is a single token, e.g. "(1,3)".
TOKENIZER: MazeTokenizer = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_UT_rasterized,
    max_grid_size=MAX_GRID_SIZE,
)
# Coordinate-tuple (CTT) tokenizer: each coordinate is spelled out as
# "( i , j )" using separate tokens.
TOKENIZER_INDEXED: MazeTokenizer = MazeTokenizer(
    tokenization_mode=TokenizationMode.AOTP_CTT_indexed,
    max_grid_size=MAX_GRID_SIZE,
)
In [6]:
# Tokenize every maze with both tokenizers. With
# join_tokens_individual_maze=True each maze comes back as one string.
STRINGIFIED: list[str] = DATASET.as_tokens(
    TOKENIZER,
    join_tokens_individual_maze=True,
)
STRINGIFIED_INDEXED: list[str] = DATASET.as_tokens(
    TOKENIZER_INDEXED,
    join_tokens_individual_maze=True,
)

# One maze per line: rasterized block first, then the indexed block.
print("Rasterized:\n" + "\n".join(STRINGIFIED))
print("\nIndexed:\n" + "\n".join(STRINGIFIED_INDEXED))
Rasterized: <ADJLIST_START> (0,0) <--> (1,0) ; (2,0) <--> (3,0) ; (4,1) <--> (4,0) ; (2,0) <--> (2,1) ; (1,0) <--> (1,1) ; (3,4) <--> (2,4) ; (4,2) <--> (4,3) ; (0,0) <--> (0,1) ; (0,3) <--> (0,2) ; (4,4) <--> (3,4) ; (4,3) <--> (4,4) ; (4,1) <--> (4,2) ; (2,1) <--> (2,2) ; (1,4) <--> (0,4) ; (1,2) <--> (0,2) ; (2,4) <--> (2,3) ; (4,0) <--> (3,0) ; (2,2) <--> (3,2) ; (1,2) <--> (2,2) ; (1,3) <--> (0,3) ; (3,2) <--> (3,3) ; (0,2) <--> (0,1) ; (3,1) <--> (3,2) ; (1,3) <--> (1,4) ; <ADJLIST_END> <ORIGIN_START> (1,3) <ORIGIN_END> <TARGET_START> (2,3) <TARGET_END> <PATH_START> (1,3) (0,3) (0,2) (1,2) (2,2) (2,1) (2,0) (3,0) (4,0) (4,1) (4,2) (4,3) (4,4) (3,4) (2,4) (2,3) <PATH_END> <ADJLIST_START> (2,1) <--> (2,2) ; (1,4) <--> (0,4) ; (2,4) <--> (1,4) ; (0,1) <--> (1,1) ; (1,0) <--> (2,0) ; (2,3) <--> (3,3) ; (3,3) <--> (3,4) ; (2,0) <--> (2,1) ; (1,3) <--> (1,2) ; (0,1) <--> (0,0) ; (1,3) <--> (1,4) ; (0,2) <--> (1,2) ; (4,0) <--> (3,0) ; (4,1) <--> (4,2) ; (3,1) <--> (2,1) ; (4,1) <--> (3,1) ; (2,3) <--> (2,4) ; (4,3) <--> (4,2) ; (2,0) <--> (3,0) ; (0,0) <--> (1,0) ; (4,3) <--> (4,4) ; (0,3) <--> (0,2) ; (3,2) <--> (2,2) ; (4,4) <--> (3,4) ; <ADJLIST_END> <ORIGIN_START> (3,4) <ORIGIN_END> <TARGET_START> (2,0) <TARGET_END> <PATH_START> (3,4) (4,4) (4,3) (4,2) (4,1) (3,1) (2,1) (2,0) <PATH_END> <ADJLIST_START> (0,0) <--> (1,0) ; (1,0) <--> (1,1) ; (0,1) <--> (0,0) ; (3,3) <--> (2,3) ; (2,3) <--> (2,4) ; (1,2) <--> (1,1) ; (3,3) <--> (4,3) ; (4,2) <--> (4,1) ; (3,4) <--> (2,4) ; (4,0) <--> (3,0) ; (1,3) <--> (0,3) ; (2,0) <--> (3,0) ; (1,4) <--> (0,4) ; (2,1) <--> (2,2) ; (4,0) <--> (4,1) ; (0,4) <--> (0,3) ; (3,2) <--> (3,3) ; (4,2) <--> (4,3) ; (3,1) <--> (4,1) ; (0,3) <--> (0,2) ; (4,4) <--> (3,4) ; (2,1) <--> (3,1) ; (1,4) <--> (2,4) ; (1,2) <--> (2,2) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (3,3) <TARGET_END> <PATH_START> (0,0) (1,0) (1,1) (1,2) (2,2) (2,1) (3,1) (4,1) (4,2) (4,3) (3,3) <PATH_END> <ADJLIST_START> (3,2) <--> (4,2) ; (0,0) 
<--> (0,1) ; (4,1) <--> (3,1) ; (3,3) <--> (2,3) ; (3,4) <--> (4,4) ; (1,1) <--> (2,1) ; (2,0) <--> (3,0) ; (3,1) <--> (2,1) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (2,3) <--> (2,4) ; (1,1) <--> (1,0) ; (2,0) <--> (1,0) ; (0,1) <--> (0,2) ; (1,0) <--> (0,0) ; (0,3) <--> (0,2) ; (4,2) <--> (4,3) ; (1,4) <--> (1,3) ; (1,2) <--> (0,2) ; (3,3) <--> (3,2) ; (4,2) <--> (4,1) ; (4,3) <--> (4,4) ; (1,2) <--> (2,2) ; (0,3) <--> (0,4) ; <ADJLIST_END> <ORIGIN_START> (0,1) <ORIGIN_END> <TARGET_START> (1,0) <TARGET_END> <PATH_START> (0,1) (0,0) (1,0) <PATH_END> <ADJLIST_START> (2,2) <--> (3,2) ; (0,2) <--> (0,3) ; (1,3) <--> (1,2) ; (2,1) <--> (1,1) ; (1,4) <--> (2,4) ; (4,4) <--> (4,3) ; (4,2) <--> (4,3) ; (2,3) <--> (2,4) ; (2,3) <--> (3,3) ; (4,1) <--> (4,2) ; (4,1) <--> (3,1) ; (0,0) <--> (0,1) ; (2,2) <--> (1,2) ; (4,4) <--> (3,4) ; (0,2) <--> (0,1) ; (3,3) <--> (3,2) ; (1,0) <--> (0,0) ; (0,4) <--> (1,4) ; (1,1) <--> (1,2) ; (4,0) <--> (4,1) ; (0,4) <--> (0,3) ; (3,1) <--> (2,1) ; (2,0) <--> (1,0) ; (3,0) <--> (4,0) ; <ADJLIST_END> <ORIGIN_START> (2,2) <ORIGIN_END> <TARGET_START> (1,0) <TARGET_END> <PATH_START> (2,2) (3,2) (3,3) (2,3) (2,4) (1,4) (0,4) (0,3) (0,2) (0,1) (0,0) (1,0) <PATH_END> Indexed: <ADJLIST_START> ( 2 , 2 ) <--> ( 3 , 2 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 4 , 1 ) <--> ( 4 , 0 ) ; ( 2 , 1 ) <--> ( 2 , 2 ) ; ( 3 , 1 ) <--> ( 3 , 2 ) ; ( 1 , 3 ) <--> ( 1 , 4 ) ; ( 4 , 0 ) <--> ( 3 , 0 ) ; ( 1 , 4 ) <--> ( 0 , 4 ) ; ( 1 , 3 ) <--> ( 0 , 3 ) ; ( 2 , 3 ) <--> ( 2 , 4 ) ; ( 0 , 1 ) <--> ( 0 , 2 ) ; ( 1 , 0 ) <--> ( 0 , 0 ) ; ( 2 , 2 ) <--> ( 1 , 2 ) ; ( 1 , 2 ) <--> ( 0 , 2 ) ; ( 2 , 0 ) <--> ( 2 , 1 ) ; ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 3 , 4 ) <--> ( 2 , 4 ) ; ( 0 , 2 ) <--> ( 0 , 3 ) ; ( 3 , 3 ) <--> ( 3 , 2 ) ; ( 3 , 0 ) <--> ( 2 , 0 ) ; ( 1 , 1 ) <--> ( 1 , 0 ) ; ( 4 , 4 ) <--> ( 4 , 3 ) ; ( 4 , 3 ) <--> ( 4 , 2 ) ; ( 4 , 1 ) <--> ( 4 , 2 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 3 ) <ORIGIN_END> <TARGET_START> ( 2 , 3 ) <TARGET_END> <PATH_START> ( 1 , 3 ) ( 0 , 3 
) ( 0 , 2 ) ( 1 , 2 ) ( 2 , 2 ) ( 2 , 1 ) ( 2 , 0 ) ( 3 , 0 ) ( 4 , 0 ) ( 4 , 1 ) ( 4 , 2 ) ( 4 , 3 ) ( 4 , 4 ) ( 3 , 4 ) ( 2 , 4 ) ( 2 , 3 ) <PATH_END> <ADJLIST_START> ( 3 , 3 ) <--> ( 3 , 4 ) ; ( 1 , 4 ) <--> ( 1 , 3 ) ; ( 2 , 4 ) <--> ( 2 , 3 ) ; ( 2 , 4 ) <--> ( 1 , 4 ) ; ( 4 , 2 ) <--> ( 4 , 3 ) ; ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 1 , 0 ) <--> ( 2 , 0 ) ; ( 1 , 4 ) <--> ( 0 , 4 ) ; ( 0 , 2 ) <--> ( 0 , 3 ) ; ( 2 , 1 ) <--> ( 2 , 2 ) ; ( 2 , 0 ) <--> ( 2 , 1 ) ; ( 3 , 1 ) <--> ( 4 , 1 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 4 , 3 ) <--> ( 4 , 4 ) ; ( 4 , 1 ) <--> ( 4 , 2 ) ; ( 1 , 1 ) <--> ( 0 , 1 ) ; ( 1 , 2 ) <--> ( 0 , 2 ) ; ( 2 , 3 ) <--> ( 3 , 3 ) ; ( 3 , 0 ) <--> ( 2 , 0 ) ; ( 1 , 3 ) <--> ( 1 , 2 ) ; ( 3 , 0 ) <--> ( 4 , 0 ) ; ( 2 , 2 ) <--> ( 3 , 2 ) ; ( 0 , 0 ) <--> ( 1 , 0 ) ; ( 3 , 1 ) <--> ( 2 , 1 ) ; <ADJLIST_END> <ORIGIN_START> ( 3 , 4 ) <ORIGIN_END> <TARGET_START> ( 2 , 0 ) <TARGET_END> <PATH_START> ( 3 , 4 ) ( 4 , 4 ) ( 4 , 3 ) ( 4 , 2 ) ( 4 , 1 ) ( 3 , 1 ) ( 2 , 1 ) ( 2 , 0 ) <PATH_END> <ADJLIST_START> ( 4 , 3 ) <--> ( 4 , 2 ) ; ( 0 , 4 ) <--> ( 1 , 4 ) ; ( 4 , 1 ) <--> ( 3 , 1 ) ; ( 2 , 2 ) <--> ( 1 , 2 ) ; ( 3 , 3 ) <--> ( 4 , 3 ) ; ( 3 , 0 ) <--> ( 2 , 0 ) ; ( 1 , 3 ) <--> ( 0 , 3 ) ; ( 3 , 3 ) <--> ( 3 , 2 ) ; ( 4 , 0 ) <--> ( 4 , 1 ) ; ( 0 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 3 ) <--> ( 0 , 2 ) ; ( 3 , 4 ) <--> ( 2 , 4 ) ; ( 2 , 1 ) <--> ( 3 , 1 ) ; ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 4 , 1 ) <--> ( 4 , 2 ) ; ( 2 , 2 ) <--> ( 2 , 1 ) ; ( 1 , 4 ) <--> ( 2 , 4 ) ; ( 1 , 1 ) <--> ( 1 , 0 ) ; ( 4 , 0 ) <--> ( 3 , 0 ) ; ( 1 , 1 ) <--> ( 1 , 2 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 0 , 3 ) <--> ( 0 , 4 ) ; ( 2 , 3 ) <--> ( 2 , 4 ) ; ( 2 , 3 ) <--> ( 3 , 3 ) ; <ADJLIST_END> <ORIGIN_START> ( 0 , 0 ) <ORIGIN_END> <TARGET_START> ( 3 , 3 ) <TARGET_END> <PATH_START> ( 0 , 0 ) ( 1 , 0 ) ( 1 , 1 ) ( 1 , 2 ) ( 2 , 2 ) ( 2 , 1 ) ( 3 , 1 ) ( 4 , 1 ) ( 4 , 2 ) ( 4 , 3 ) ( 3 , 3 ) <PATH_END> <ADJLIST_START> ( 0 , 1 ) <--> ( 0 , 2 ) ; ( 1 , 4 ) <--> ( 2 , 4 ) ; ( 4 , 4 ) <--> ( 4 , 3 ) ; ( 4 
, 1 ) <--> ( 4 , 2 ) ; ( 2 , 3 ) <--> ( 3 , 3 ) ; ( 2 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; ( 4 , 3 ) <--> ( 4 , 2 ) ; ( 0 , 4 ) <--> ( 0 , 3 ) ; ( 1 , 1 ) <--> ( 1 , 0 ) ; ( 3 , 1 ) <--> ( 4 , 1 ) ; ( 3 , 2 ) <--> ( 4 , 2 ) ; ( 3 , 3 ) <--> ( 3 , 2 ) ; ( 1 , 2 ) <--> ( 0 , 2 ) ; ( 3 , 0 ) <--> ( 2 , 0 ) ; ( 2 , 1 ) <--> ( 3 , 1 ) ; ( 0 , 2 ) <--> ( 0 , 3 ) ; ( 1 , 1 ) <--> ( 2 , 1 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 2 , 2 ) <--> ( 1 , 2 ) ; ( 1 , 0 ) <--> ( 0 , 0 ) ; ( 1 , 3 ) <--> ( 1 , 4 ) ; ( 3 , 0 ) <--> ( 4 , 0 ) ; ( 2 , 3 ) <--> ( 2 , 4 ) ; <ADJLIST_END> <ORIGIN_START> ( 0 , 1 ) <ORIGIN_END> <TARGET_START> ( 1 , 0 ) <TARGET_END> <PATH_START> ( 0 , 1 ) ( 0 , 0 ) ( 1 , 0 ) <PATH_END> <ADJLIST_START> ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 3 , 3 ) <--> ( 3 , 2 ) ; ( 1 , 2 ) <--> ( 1 , 3 ) ; ( 1 , 4 ) <--> ( 2 , 4 ) ; ( 2 , 4 ) <--> ( 2 , 3 ) ; ( 4 , 3 ) <--> ( 4 , 4 ) ; ( 4 , 2 ) <--> ( 4 , 1 ) ; ( 4 , 1 ) <--> ( 3 , 1 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 0 , 2 ) <--> ( 0 , 3 ) ; ( 4 , 3 ) <--> ( 4 , 2 ) ; ( 0 , 2 ) <--> ( 0 , 1 ) ; ( 2 , 2 ) <--> ( 3 , 2 ) ; ( 2 , 1 ) <--> ( 3 , 1 ) ; ( 1 , 0 ) <--> ( 0 , 0 ) ; ( 1 , 4 ) <--> ( 0 , 4 ) ; ( 2 , 1 ) <--> ( 1 , 1 ) ; ( 0 , 4 ) <--> ( 0 , 3 ) ; ( 4 , 1 ) <--> ( 4 , 0 ) ; ( 1 , 1 ) <--> ( 1 , 2 ) ; ( 1 , 0 ) <--> ( 2 , 0 ) ; ( 2 , 3 ) <--> ( 3 , 3 ) ; ( 2 , 2 ) <--> ( 1 , 2 ) ; ( 3 , 0 ) <--> ( 4 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 2 , 2 ) <ORIGIN_END> <TARGET_START> ( 1 , 0 ) <TARGET_END> <PATH_START> ( 2 , 2 ) ( 3 , 2 ) ( 3 , 3 ) ( 2 , 3 ) ( 2 , 4 ) ( 1 , 4 ) ( 0 , 4 ) ( 0 , 3 ) ( 0 , 2 ) ( 0 , 1 ) ( 0 , 0 ) ( 1 , 0 ) <PATH_END>
In [7]:
# Visualise the first rasterized maze's tokens with each colouring helper.
# A seeded generator keeps the random colours identical across re-runs.
tokens_rasterized: list[str] = STRINGIFIED[0].split()
rng = np.random.default_rng(0)
# One random RGB triple per token; `integers` excludes `high`, so 256 yields 0..255.
display_color_tokens_rgb(
    tokens_rasterized,
    rng.integers(0, 256, size=(len(tokens_rasterized), 3)),
)
# One random scalar per token (presumably mapped through a colormap by the helper).
display_color_tokens_cmap(
    tokens_rasterized,
    rng.integers(0, 255, size=len(tokens_rasterized)),
)
# AOTP-specific display helper; takes only the tokens.
display_color_maze_tokens_AOTP(tokens_rasterized)
<ADJLIST_START> (0,0) <--> (1,0) (2,0) <--> (3,0) (4,1) <--> (4,0) (2,0) <--> (2,1) (1,0) <--> (1,1) (3,4) <--> (2,4) (4,2) <--> (4,3) (0,0) <--> (0,1) (0,3) <--> (0,2) (4,4) <--> (3,4) (4,3) <--> (4,4) (4,1) <--> (4,2) (2,1) <--> (2,2) (1,4) <--> (0,4) (1,2) <--> (0,2) (2,4) <--> (2,3) (4,0) <--> (3,0) (2,2) <--> (3,2) (1,2) <--> (2,2) (1,3) <--> (0,3) (3,2) <--> (3,3) (0,2) <--> (0,1) (3,1) <--> (3,2) (1,3) <--> (1,4) <ADJLIST_END> <ORIGIN_START> (1,3) <ORIGIN_END> <TARGET_START> (2,3) <TARGET_END> <PATH_START> (1,3) (0,3) (0,2) (1,2) (2,2) (2,1) (2,0) (3,0) (4,0) (4,1) (4,2) (4,3) (4,4) (3,4) (2,4) (2,3) <PATH_END>
<ADJLIST_START> (0,0) <--> (1,0) (2,0) <--> (3,0) (4,1) <--> (4,0) (2,0) <--> (2,1) (1,0) <--> (1,1) (3,4) <--> (2,4) (4,2) <--> (4,3) (0,0) <--> (0,1) (0,3) <--> (0,2) (4,4) <--> (3,4) (4,3) <--> (4,4) (4,1) <--> (4,2) (2,1) <--> (2,2) (1,4) <--> (0,4) (1,2) <--> (0,2) (2,4) <--> (2,3) (4,0) <--> (3,0) (2,2) <--> (3,2) (1,2) <--> (2,2) (1,3) <--> (0,3) (3,2) <--> (3,3) (0,2) <--> (0,1) (3,1) <--> (3,2) (1,3) <--> (1,4) <ADJLIST_END> <ORIGIN_START> (1,3) <ORIGIN_END> <TARGET_START> (2,3) <TARGET_END> <PATH_START> (1,3) (0,3) (0,2) (1,2) (2,2) (2,1) (2,0) (3,0) (4,0) (4,1) (4,2) (4,3) (4,4) (3,4) (2,4) (2,3) <PATH_END>
<ADJLIST_START> (0,0) <--> (1,0) ; (2,0) <--> (3,0) ; (4,1) <--> (4,0) ; (2,0) <--> (2,1) ; (1,0) <--> (1,1) ; (3,4) <--> (2,4) ; (4,2) <--> (4,3) ; (0,0) <--> (0,1) ; (0,3) <--> (0,2) ; (4,4) <--> (3,4) ; (4,3) <--> (4,4) ; (4,1) <--> (4,2) ; (2,1) <--> (2,2) ; (1,4) <--> (0,4) ; (1,2) <--> (0,2) ; (2,4) <--> (2,3) ; (4,0) <--> (3,0) ; (2,2) <--> (3,2) ; (1,2) <--> (2,2) ; (1,3) <--> (0,3) ; (3,2) <--> (3,3) ; (0,2) <--> (0,1) ; (3,1) <--> (3,2) ; (1,3) <--> (1,4) ; <ADJLIST_END> <ORIGIN_START> (1,3) <ORIGIN_END> <TARGET_START> (2,3) <TARGET_END> <PATH_START> (1,3) (0,3) (0,2) (1,2) (2,2) (2,1) (2,0) (3,0) (4,0) (4,1) (4,2) (4,3) (4,4) (3,4) (2,4) (2,3) <PATH_END>
Now do the same for `TokenizationMode.AOTP_CTT_indexed`.
In [8]:
# Same visualisation as for the rasterized tokens, now for the first maze
# under the indexed (CTT) tokenizer. Seeded so colours are reproducible.
tokens_indexed: list[str] = STRINGIFIED_INDEXED[0].split()
rng = np.random.default_rng(0)
# One random RGB triple per token; `integers` excludes `high`, so 256 yields 0..255.
display_color_tokens_rgb(
    tokens_indexed,
    rng.integers(0, 256, size=(len(tokens_indexed), 3)),
)
# One random scalar per token (presumably mapped through a colormap by the helper).
display_color_tokens_cmap(
    tokens_indexed,
    rng.integers(0, 255, size=len(tokens_indexed)),
)
# AOTP-specific display helper; takes only the tokens.
display_color_maze_tokens_AOTP(tokens_indexed)
<ADJLIST_START> (  2 ,  2 ) <--> (  3 ,  2 ) (  3 ,  4 ) <--> (  4 ,  4 ) (  4 ,  1 ) <--> (  4 ,  0 ) (  2 ,  1 ) <--> (  2 ,  2 ) (  3 ,  1 ) <--> (  3 ,  2 ) (  1 ,  3 ) <--> (  1 ,  4 ) (  4 ,  0 ) <--> (  3 ,  0 ) (  1 ,  4 ) <--> (  0 ,  4 ) (  1 ,  3 ) <--> (  0 ,  3 ) (  2 ,  3 ) <--> (  2 ,  4 ) (  0 ,  1 ) <--> (  0 ,  2 ) (  1 ,  0 ) <--> (  0 ,  0 ) (  2 ,  2 ) <--> (  1 ,  2 ) (  1 ,  2 ) <--> (  0 ,  2 ) (  2 ,  0 ) <--> (  2 ,  1 ) (  0 ,  0 ) <--> (  0 ,  1 ) (  3 ,  4 ) <--> (  2 ,  4 ) (  0 ,  2 ) <--> (  0 ,  3 ) (  3 ,  3 ) <--> (  3 ,  2 ) (  3 ,  0 ) <--> (  2 ,  0 ) (  1 ,  1 ) <--> (  1 ,  0 ) (  4 ,  4 ) <--> (  4 ,  3 ) (  4 ,  3 ) <--> (  4 ,  2 ) (  4 ,  1 ) <--> (  4 ,  2 ) <ADJLIST_END> <ORIGIN_START> (  1 ,  3 ) <ORIGIN_END> <TARGET_START> (  2 ,  3 ) <TARGET_END> <PATH_START> (  1 ,  3 ) (  0 ,  3 ) (  0 ,  2 ) (  1 ,  2 ) (  2 ,  2 ) (  2 ,  1 ) (  2 ,  0 ) (  3 ,  0 ) (  4 ,  0 ) (  4 ,  1 ) (  4 ,  2 ) (  4 ,  3 ) (  4 ,  4 ) (  3 ,  4 ) (  2 ,  4 ) (  2 ,  3 ) <PATH_END>
<ADJLIST_START> (  2 ,  2 ) <--> (  3 ,  2 ) (  3 ,  4 ) <--> (  4 ,  4 ) (  4 ,  1 ) <--> (  4 ,  0 ) (  2 ,  1 ) <--> (  2 ,  2 ) (  3 ,  1 ) <--> (  3 ,  2 ) (  1 ,  3 ) <--> (  1 ,  4 ) (  4 ,  0 ) <--> (  3 ,  0 ) (  1 ,  4 ) <--> (  0 ,  4 ) (  1 ,  3 ) <--> (  0 ,  3 ) (  2 ,  3 ) <--> (  2 ,  4 ) (  0 ,  1 ) <--> (  0 ,  2 ) (  1 ,  0 ) <--> (  0 ,  0 ) (  2 ,  2 ) <--> (  1 ,  2 ) (  1 ,  2 ) <--> (  0 ,  2 ) (  2 ,  0 ) <--> (  2 ,  1 ) (  0 ,  0 ) <--> (  0 ,  1 ) (  3 ,  4 ) <--> (  2 ,  4 ) (  0 ,  2 ) <--> (  0 ,  3 ) (  3 ,  3 ) <--> (  3 ,  2 ) (  3 ,  0 ) <--> (  2 ,  0 ) (  1 ,  1 ) <--> (  1 ,  0 ) (  4 ,  4 ) <--> (  4 ,  3 ) (  4 ,  3 ) <--> (  4 ,  2 ) (  4 ,  1 ) <--> (  4 ,  2 ) <ADJLIST_END> <ORIGIN_START> (  1 ,  3 ) <ORIGIN_END> <TARGET_START> (  2 ,  3 ) <TARGET_END> <PATH_START> (  1 ,  3 ) (  0 ,  3 ) (  0 ,  2 ) (  1 ,  2 ) (  2 ,  2 ) (  2 ,  1 ) (  2 ,  0 ) (  3 ,  0 ) (  4 ,  0 ) (  4 ,  1 ) (  4 ,  2 ) (  4 ,  3 ) (  4 ,  4 ) (  3 ,  4 ) (  2 ,  4 ) (  2 ,  3 ) <PATH_END>
<ADJLIST_START> ( 2 , 2 ) <--> ( 3 , 2 ) ; ( 3 , 4 ) <--> ( 4 , 4 ) ; ( 4 , 1 ) <--> ( 4 , 0 ) ; ( 2 , 1 ) <--> ( 2 , 2 ) ; ( 3 , 1 ) <--> ( 3 , 2 ) ; ( 1 , 3 ) <--> ( 1 , 4 ) ; ( 4 , 0 ) <--> ( 3 , 0 ) ; ( 1 , 4 ) <--> ( 0 , 4 ) ; ( 1 , 3 ) <--> ( 0 , 3 ) ; ( 2 , 3 ) <--> ( 2 , 4 ) ; ( 0 , 1 ) <--> ( 0 , 2 ) ; ( 1 , 0 ) <--> ( 0 , 0 ) ; ( 2 , 2 ) <--> ( 1 , 2 ) ; ( 1 , 2 ) <--> ( 0 , 2 ) ; ( 2 , 0 ) <--> ( 2 , 1 ) ; ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 3 , 4 ) <--> ( 2 , 4 ) ; ( 0 , 2 ) <--> ( 0 , 3 ) ; ( 3 , 3 ) <--> ( 3 , 2 ) ; ( 3 , 0 ) <--> ( 2 , 0 ) ; ( 1 , 1 ) <--> ( 1 , 0 ) ; ( 4 , 4 ) <--> ( 4 , 3 ) ; ( 4 , 3 ) <--> ( 4 , 2 ) ; ( 4 , 1 ) <--> ( 4 , 2 ) ; <ADJLIST_END> <ORIGIN_START> ( 1 , 3 ) <ORIGIN_END> <TARGET_START> ( 2 , 3 ) <TARGET_END> <PATH_START> ( 1 , 3 ) ( 0 , 3 ) ( 0 , 2 ) ( 1 , 2 ) ( 2 , 2 ) ( 2 , 1 ) ( 2 , 0 ) ( 3 , 0 ) ( 4 , 0 ) ( 4 , 1 ) ( 4 , 2 ) ( 4 , 3 ) ( 4 , 4 ) ( 3 , 4 ) ( 2 , 4 ) ( 2 , 3 ) <PATH_END>
Now let's see how we can convert the actual tokenized data back into a `SolvedMaze`. This is only possible with legacy tokenizers or their `MazeTokenizerModular` equivalents.
In [9]:
# Hand-written token sequence for a small 3x3 maze in the rasterized (UT) format.
maze_toks: list[str] = """<ADJLIST_START> (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ; (2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>""".split()

# encode -> decode must give back exactly the tokens we started with
maze_encoded: list[int] = TOKENIZER.encode(maze_toks)
maze_tok_roundtrip: list[str] = TOKENIZER.decode(maze_encoded)
assert maze_toks == maze_tok_roundtrip

# Parse the tokens into a SolvedMaze, then render it as ASCII and re-tokenize it.
maze_from_toks: SolvedMaze = SolvedMaze.from_tokens(maze_toks, TOKENIZER)
print(maze_from_toks.as_ascii())
print(*maze_from_toks.as_tokens(TOKENIZER))
####### #S # #X### # #X# # # #X# ### #XXE # ####### <ADJLIST_START> (0,1) <--> (0,2) ; (2,1) <--> (2,0) ; (2,1) <--> (1,1) ; (2,0) <--> (1,0) ; (2,1) <--> (2,2) ; (0,0) <--> (0,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>
Now do the same for the CTT (`AOTP_CTT_indexed`) tokenizer.
In [10]:
# The same 3x3 maze, written in the indexed (CTT) format where every
# coordinate is spelled out as `( i , j )`.
maze_toks_indexed: list[str] = """<ADJLIST_START> ( 1 , 1 ) <--> ( 2 , 1 ) ; ( 2 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; ( 2 , 2 ) <--> ( 2 , 1 ) ; ( 2 , 0 ) <--> ( 2 , 1 ) ; ( 0 , 2 ) <--> ( 1 , 2 ) ; ( 0 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 2 ) <--> ( 0 , 1 ) ; <ADJLIST_END> <ORIGIN_START> ( 0 , 0 ) <ORIGIN_END> <TARGET_START> ( 2 , 1 ) <TARGET_END> <PATH_START> ( 0 , 0 ) ( 1 , 0 ) ( 2 , 0 ) ( 2 , 1 ) <PATH_END>""".split()

# encode -> decode round trip must be lossless for this tokenizer too
maze_encoded: list[int] = TOKENIZER_INDEXED.encode(maze_toks_indexed)
maze_tok_roundtrip: list[str] = TOKENIZER_INDEXED.decode(maze_encoded)
assert maze_toks_indexed == maze_tok_roundtrip

# Both token formats describe the same maze, so parsing must agree.
maze_from_toks_indexed: SolvedMaze = SolvedMaze.from_tokens(maze_toks_indexed, TOKENIZER_INDEXED)
assert maze_from_toks_indexed == maze_from_toks

print(maze_from_toks_indexed.as_ascii())
print(*maze_from_toks_indexed.as_tokens(TOKENIZER_INDEXED))
####### #S # #X### # #X# # # #X# ### #XXE # ####### <ADJLIST_START> ( 1 , 1 ) <--> ( 2 , 1 ) ; ( 2 , 1 ) <--> ( 2 , 2 ) ; ( 0 , 1 ) <--> ( 0 , 2 ) ; ( 2 , 0 ) <--> ( 1 , 0 ) ; ( 1 , 0 ) <--> ( 0 , 0 ) ; ( 0 , 0 ) <--> ( 0 , 1 ) ; ( 0 , 2 ) <--> ( 1 , 2 ) ; ( 2 , 1 ) <--> ( 2 , 0 ) ; <ADJLIST_END> <ORIGIN_START> ( 0 , 0 ) <ORIGIN_END> <TARGET_START> ( 2 , 1 ) <TARGET_END> <PATH_START> ( 0 , 0 ) ( 1 , 0 ) ( 2 , 0 ) ( 2 , 1 ) <PATH_END>
## Vocab index
Special tokens come first, but there are a few choices for how the remaining tokens are ordered:
- `TokenizationMode.AOTP_UT_rasterized`: a unique token for each coordinate; the order is simple rasterization.
- `TokenizationMode.AOTP_UT_uniform`: a unique token for each coordinate; the order is assembled to preserve uniformity regardless of maze size.
- `TokenizationMode.AOTP_CTT_indexed`: each coordinate is 5 tokens, `( i , j )`, where `i` and `j` are the coordinates.
In [11]:
def plot_corner_first_ndindex(n: int, ndim: int = 2, save_path: "str | None" = None) -> None:
    """Plot an ``n`` x ``n`` grid showing where each grid point falls in the
    order produced by ``corner_first_ndindex``.

    Each cell is shaded and labelled with its 1-based position in that order.

    # Parameters:
     - `n : int`
        side length of the grid
     - `ndim : int`
        dimensionality passed to ``corner_first_ndindex``; only 2 can be plotted
        (defaults to `2`)
     - `save_path : str | None`
        if given, the figure is also saved to this path before being shown
        (defaults to `None`)

    # Raises:
     - `ValueError` : if ``ndim != 2``
    """
    if ndim != 2:
        # each index is unpacked into exactly two coordinates below
        raise ValueError(f"only ndim=2 can be plotted, got ndim={ndim}")

    indices = corner_first_ndindex(n, ndim)

    # grid[i, j] holds the 1-based position of coordinate (i, j) in the ordering
    grid = np.zeros((n, n), dtype=int)
    for order, (i, j) in enumerate(indices, start=1):
        grid[i, j] = order

    fig, ax = plt.subplots(figsize=(2, 2))
    # matshow draws grid[i, j] at row i (vertical axis), column j (horizontal axis)
    im = ax.matshow(grid, cmap="Blues")
    for (i, j), value in np.ndenumerate(grid):
        # ax.text takes (x, y) = (column, row)
        ax.text(j, i, str(value), va="center", ha="center")

    ax.set_title("Order of Grid Points in Vocabulary")
    ax.set_xlabel("j (second coordinate)")
    ax.set_ylabel("i (first coordinate)")
    fig.colorbar(im, ax=ax)
    if save_path is not None:
        # save before show(): some backends clear the figure once shown
        fig.savefig(save_path)
    plt.show()
# Example plot for a 5x5 grid (n=5)
plot_corner_first_ndindex(5)