# Stats - 64 files - 16400 (16K) lines - 532427 (532K) chars - 60025 (60K) `whitespace-split` tokens # File Tree ``` maze-dataset ├── .github │ └── workflows │ ├── all-tok-checks.yml [ 55L 1,728C 155T] │ ├── build-paper.yml [ 28L 785C 85T] │ ├── checks.yml [ 94L 2,164C 224T] │ ├── make-docs.yml [ 47L 1,262C 141T] │ └── speed-benchmark.yml [ 48L 1,261C 109T] ├── maze_dataset │ ├── benchmark │ │ ├── __init__.py [ 9L 220C 24T] │ │ ├── config_sweep.py [ 564L 17,188C 1,819T] │ │ ├── speed.py [ 132L 3,053C 327T] │ │ └── sweep_fit.py [ 407L 11,569C 1,369T] │ ├── dataset │ │ ├── __init__.py [ 45L 2,111C 243T] │ │ ├── collected_dataset.py [ 241L 8,113C 848T] │ │ ├── configs.py [ 285L 8,000C 668T] │ │ ├── dataset.py [ 553L 18,539C 2,133T] │ │ ├── filters.py [ 295L 9,571C 1,100T] │ │ ├── maze_dataset.py [ 632L 20,310C 2,073T] │ │ ├── maze_dataset_config.py [ 496L 17,917C 2,019T] │ │ ├── rasterized.py [ 330L 9,476C 1,001T] │ │ └── success_predict_math.py [ 106L 2,809C 460T] │ ├── generation │ │ ├── __init__.py [ 23L 534C 54T] │ │ ├── default_generators.py [ 23L 817C 72T] │ │ ├── generators.py [ 651L 23,337C 2,932T] │ │ └── seed.py [ 3L 45C 7T] │ ├── maze │ │ ├── __init__.py [ 34L 1,905C 272T] │ │ └── lattice_maze.py [1,532L 49,541C 5,753T] │ ├── plotting │ │ ├── __init__.py [ 33L 1,008C 110T] │ │ ├── plot_dataset.py [ 64L 1,635C 200T] │ │ ├── plot_maze.py [ 603L 16,819C 1,950T] │ │ ├── plot_svg_fancy.py [ 199L 5,617C 863T] │ │ ├── plot_tokens.py [ 90L 2,127C 280T] │ │ └── print_tokens.py [ 250L 6,431C 799T] │ ├── tokenization │ │ ├── modular │ │ │ ├── __init__.py [ 90L 5,930C 797T] │ │ │ ├── all_instances.py [ 263L 10,748C 1,388T] │ │ │ ├── all_tokenizers.py [ 219L 8,375C 855T] │ │ │ ├── element_base.py [ 314L 11,410C 1,396T] │ │ │ ├── elements.py [1,294L 41,681C 4,161T] │ │ │ ├── fst.py [ 200L 5,706C 619T] │ │ │ ├── fst_load.py [ 95L 3,718C 485T] │ │ │ ├── hashing.py [ 95L 3,359C 390T] │ │ │ ├── maze_tokenizer_modular.py [ 349L 11,291C 1,189T] │ │ │ └── save_hashes.py [ 124L 3,399C 294T] │ │ ├── __init__.py [ 125L 8,653C 1,156T] │ │ ├── common.py [ 7L 106C 11T] │ │ ├── maze_tokenizer.py [ 15L 307C 23T] │ │ └── maze_tokenizer_legacy.py [ 500L 15,213C 1,641T] │ ├── __init__.py [ 69L 1,412C 106T] │ ├── constants.py [ 232L 7,193C 773T] │ ├── py.typed [ 0L 0C 0T] │ ├── testing_utils.py [ 178L 8,548C 1,369T] │ ├── token_utils.py [ 540L 17,351C 2,009T] │ └── utils.py [ 184L 5,237C 699T] ├── notebooks │ ├── demo_dataset.ipynb [1,037L 507,898C 2,974T] │ ├── demo_generator.ipynb [ 197L 56,358C 356T] │ ├── demo_latticemaze.ipynb [ 751L 293,508C 2,318T] │ ├── demo_mazetokenizermodular.ipynb [1,788L 92,501C 7,020T] │ ├── demo_tokenization.ipynb [ 434L 145,864C 9,259T] │ ├── estimate_dataset_fractions.ipynb [ 503L 3,265,942C 1,368T] │ ├── forking_points.ipynb [ 221L 28,462C 621T] │ ├── iterated_backfilling.ipynb [ 207L 30,509C 547T] │ ├── output_formats.ipynb [ 235L 20,199C 1,348T] │ └── profile_dataset_save_read.ipynb [1,133L 269,709C 3,406T] ├── templates ├── LICENSE.md [ 427L 20,126C 2,768T] ├── README.md [ 186L 8,054C 696T] ├── makefile [1,851L 57,868C 6,652T] ├── pyproject.toml [ 364L 12,919C 1,470T] ``` # File Contents ``````{ path=".github/workflows/all-tok-checks.yml" } name: All Tokenizer Checks on: push: paths: - 'maze_dataset/utils.py' # temporary - 'maze_dataset/token_utils.py' # temporary - 'maze_dataset/constants.py' - 'maze_dataset/tokenization/*.py' - 'maze_dataset/tokenization/modular/*.py' - 'maze_dataset/tokenization/modular/MazeTokenizerModular_tested.fst' - 
'notebooks/demo_mazetokenizermodular.ipynb' - 'tests/all_tokenizers/*.py' - 'pyproject.toml' # on new version or update deps - '.github/workflows/all-tok-checks.yml' # changing this file - '.lastversion' # on new release workflow_dispatch: inputs: n_to_test: description: 'Number of tokenizers to test' required: false default: 10000 type: number pytest_parallel: description: '1 to parallelize tests with -n auto, to run without parallelization' required: false default: 1 type: number jobs: all_tok_test: name: All Tokenizer Tests runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: "3.11" - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: install dependencies run: make dep - name: tokenizer fst check (check all, this will take a while) run: | make tokenizer-fst-check - name: long tokenizer tests run: | N_TO_TEST=${{ github.event.inputs.n_to_test || '10000' }} PYTEST_PARALLEL=${{ github.event.inputs.pytest_parallel || '1' }} make tokenizer-test-long NUM_TOKENIZERS_TO_TEST=$N_TO_TEST PYTEST_PARALLEL=$PYTEST_PARALLEL ``````{ end_of_file=".github/workflows/all-tok-checks.yml" } ``````{ path=".github/workflows/build-paper.yml" } name: build paper draft pdf on: push: paths: - docs/paper/** - .github/workflows/build-paper.yml jobs: paper: runs-on: ubuntu-latest name: Paper Draft steps: - name: Checkout uses: actions/checkout@v4 - name: Build draft PDF uses: openjournals/openjournals-draft-action@master with: journal: joss # This should be the path to the paper within your repo. paper-path: docs/paper/paper.md - name: Upload uses: actions/upload-artifact@v4 with: name: paper # This is the output path where Pandoc will write the compiled # PDF. Note, this should be the same directory as the input # paper.md path: docs/paper/paper.pdf ``````{ end_of_file=".github/workflows/build-paper.yml" } ``````{ path=".github/workflows/checks.yml" } name: Checks on: workflow_dispatch: pull_request: branches: - '*' push: branches: - main jobs: lint: name: Formatting runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 0 - name: Install linters run: pip install -r .meta/requirements/requirements-lint.txt - name: Run Format Checks run: make format-check RUN_GLOBAL=1 test: name: Test runs-on: ubuntu-latest strategy: matrix: python: ["3.10", "3.11", "3.12", "3.13"] pkg: - numpy: "1.24.4" group: "legacy" - numpy: "" group: "latest" exclude: - python: "3.12" pkg: group: "legacy" - python: "3.13" pkg: group: "legacy" steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 1 - name: Set up python uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: check dependencies run: make dep-check - name: install dependencies and package run: make dep - name: Install different numpy version if: ${{ matrix.pkg.numpy != '' }} run: uv pip install numpy==${{ matrix.pkg.numpy }} - name: info run: make info-long - name: format check run: make format-check - name: Unit tests run: make test-unit - name: Notebook tests (nbmake) run: make test-notebooks-nbmake - name: Notebook tests (muutils) run: make test-notebooks-muutils - name: check typing (3.11+) run: make typing if: ${{ matrix.python != '3.10' }} - name: run benchmarks tests run: make benchmark-test - name: Test Tokenizer fst # test a reduced number of tokenizers run: make tokenizer-fst-check NUM_TOKENIZERS_TO_TEST=1000 ``````{ 
end_of_file=".github/workflows/checks.yml" } ``````{ path=".github/workflows/make-docs.yml" } # this workflow partially copied from # https://github.com/TransformerLensOrg/TransformerLens/blob/main/.github/workflows/checks.yml name: make docs on: pull_request: branches: - main push: branches: - main jobs: build-docs: # When running on a PR, this just checks we can build the docs without errors # When running on merge to main, it builds the docs and then another job deploys them name: 'Build Docs' runs-on: ubuntu-latest if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev') || contains(github.head_ref, 'docs') steps: - name: Install pandoc uses: awalsh128/cache-apt-pkgs-action@latest with: packages: pandoc version: '3.3' - name: Check pandoc version run: pandoc --version - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python uses: actions/setup-python@v5 with: python-version: '3.11' - name: set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: Install package and deps run: make dep - name: Build Docs run: make docs ``````{ end_of_file=".github/workflows/make-docs.yml" } ``````{ path=".github/workflows/speed-benchmark.yml" } name: Benchmark Generation on: workflow_dispatch: inputs: analysis_type: description: 'Benchmark analysis type to run' required: true default: 'large' type: choice options: - test - default - large jobs: benchmark: name: Run Benchmark runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 with: fetch-depth: 1 - name: Set up python uses: actions/setup-python@v5 with: python-version: '3.11' - name: Set up uv run: curl -LsSf https://astral.sh/uv/install.sh | sh - name: Install dependencies and package run: make dep - name: Info run: make info-long - name: Run benchmark run: uv run python docs/benchmarks/benchmark_generation.py ${{ github.event.inputs.analysis_type }} --save-path benchmarks/results/benchmark_data.jsonl - name: Upload benchmark results uses: actions/upload-artifact@v4 with: name: benchmark-results-${{ github.event.inputs.analysis_type }} path: benchmarks/results/benchmark_data.jsonl retention-days: 90 ``````{ end_of_file=".github/workflows/speed-benchmark.yml" } ``````{ path="maze_dataset/benchmark/__init__.py" } """benchmarking the speed or success rate of maze generation. 
you can view generated benchmark results here: https://understanding-search.github.io/maze-dataset/benchmarks/ """ __all__ = [ "config_sweep", "speed", ] ``````{ end_of_file="maze_dataset/benchmark/__init__.py" } ``````{ path="maze_dataset/benchmark/config_sweep.py" } """Benchmarking of how successful maze generation is for various values of percolation""" import functools import json import warnings from pathlib import Path from typing import Any, Callable, Generic, Literal, Sequence, TypeVar import matplotlib.pyplot as plt import numpy as np from jaxtyping import Float from muutils.dictmagic import dotlist_to_nested_dict, update_with_nested_dict from muutils.json_serialize import ( JSONitem, SerializableDataclass, json_serialize, serializable_dataclass, serializable_field, ) from muutils.parallel import run_maybe_parallel from zanj import ZANJ from maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.generation import LatticeMazeGenerators SweepReturnType = TypeVar("SweepReturnType") ParamType = TypeVar("ParamType") AnalysisFunc = Callable[[MazeDatasetConfig], SweepReturnType] def dataset_success_fraction(cfg: MazeDatasetConfig) -> float: """empirical success fraction of maze generation for use as an `analyze_func` in `sweep()` """ dataset: MazeDataset = MazeDataset.from_config( cfg, do_download=False, load_local=False, save_local=False, verbose=False, ) return len(dataset) / cfg.n_mazes ANALYSIS_FUNCS: dict[str, AnalysisFunc] = dict( dataset_success_fraction=dataset_success_fraction, ) def sweep( cfg_base: MazeDatasetConfig, param_values: list[ParamType], param_key: str, analyze_func: Callable[[MazeDatasetConfig], SweepReturnType], ) -> list[SweepReturnType]: """given a base config, parameter values list, key, and analysis function, return the results of the analysis function for each parameter value # Parameters: - `cfg_base : MazeDatasetConfig` base config on which we will modify the value at `param_key` with values from `param_values` - `param_values : list[ParamType]` list of values to try - `param_key : str` value to modify in `cfg_base` - `analyze_func : Callable[[MazeDatasetConfig], SweepReturnType]` function which analyzes the resulting config. 
originally built for `dataset_success_fraction` # Returns: - `list[SweepReturnType]` _description_ """ outputs: list[SweepReturnType] = [] for p in param_values: # update the config cfg_dict: dict = cfg_base.serialize() update_with_nested_dict( cfg_dict, dotlist_to_nested_dict({param_key: p}), ) cfg_test: MazeDatasetConfig = MazeDatasetConfig.load(cfg_dict) outputs.append(analyze_func(cfg_test)) return outputs @serializable_dataclass() class SweepResult(SerializableDataclass, Generic[ParamType, SweepReturnType]): """result of a parameter sweep""" configs: list[MazeDatasetConfig] = serializable_field( serialization_fn=lambda cfgs: [cfg.serialize() for cfg in cfgs], deserialize_fn=lambda cfgs: [MazeDatasetConfig.load(cfg) for cfg in cfgs], ) param_values: list[ParamType] = serializable_field( serialization_fn=lambda x: json_serialize(x), deserialize_fn=lambda x: x, assert_type=False, ) result_values: dict[str, Sequence[SweepReturnType]] = serializable_field( serialization_fn=lambda x: json_serialize(x), deserialize_fn=lambda x: x, assert_type=False, ) param_key: str analyze_func: Callable[[MazeDatasetConfig], SweepReturnType] = serializable_field( serialization_fn=lambda f: f.__name__, deserialize_fn=ANALYSIS_FUNCS.get, assert_type=False, ) def summary(self) -> JSONitem: "human-readable and json-dumpable short summary of the result" return { "len(configs)": len(self.configs), "len(param_values)": len(self.param_values), "len(result_values)": len(self.result_values), "param_key": self.param_key, "analyze_func": self.analyze_func.__name__, } def save(self, path: str | Path, z: ZANJ | None = None) -> None: "save to a file with zanj" if z is None: z = ZANJ() z.save(self, path) @classmethod def read(cls, path: str | Path, z: ZANJ | None = None) -> "SweepResult": "read from a file with zanj" if z is None: z = ZANJ() return z.read(path) def configs_by_name(self) -> dict[str, MazeDatasetConfig]: "return configs by name" return {cfg.name: cfg for cfg in self.configs} def configs_by_key(self) -> dict[str, MazeDatasetConfig]: "return configs by the key used in `result_values`, which is the filename of the config" return {cfg.to_fname(): cfg for cfg in self.configs} def configs_shared(self) -> dict[str, Any]: "return key: value pairs that are shared across all configs" # we know that the configs all have the same keys, # so this way of doing it is fine config_vals: dict[str, set[Any]] = dict() for cfg in self.configs: for k, v in cfg.serialize().items(): if k not in config_vals: config_vals[k] = set() config_vals[k].add(json.dumps(v)) shared_vals: dict[str, Any] = dict() cfg_ser: dict = self.configs[0].serialize() for k, v in config_vals.items(): if len(v) == 1: shared_vals[k] = cfg_ser[k] return shared_vals def configs_differing_keys(self) -> set[str]: "return keys that differ across configs" shared_vals: dict[str, Any] = self.configs_shared() differing_keys: set[str] = set() for k in MazeDatasetConfig.__dataclass_fields__: if k not in shared_vals: differing_keys.add(k) return differing_keys def configs_value_set(self, key: str) -> list[Any]: "return a list of the unique values for a given key" d: dict[str, Any] = { json.dumps(json_serialize(getattr(cfg, key))): getattr(cfg, key) for cfg in self.configs } return list(d.values()) def get_where(self, key: str, val_check: Callable[[Any], bool]) -> "SweepResult": "get a subset of this `Result` where the configs has `key` satisfying `val_check`" configs_list: list[MazeDatasetConfig] = [ cfg for cfg in self.configs if val_check(getattr(cfg, key)) ] 
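		# restrict the saved result_values (keyed by config filename) to just the configs that passed val_check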
configs_keys: set[str] = {cfg.to_fname() for cfg in configs_list} result_values: dict[str, Sequence[SweepReturnType]] = { k: self.result_values[k] for k in configs_keys } return SweepResult( configs=configs_list, param_values=self.param_values, result_values=result_values, param_key=self.param_key, analyze_func=self.analyze_func, ) @classmethod def analyze( cls, configs: list[MazeDatasetConfig], param_values: list[ParamType], param_key: str, analyze_func: Callable[[MazeDatasetConfig], SweepReturnType], parallel: bool | int = False, **kwargs, ) -> "SweepResult": """Analyze success rate of maze generation for different percolation values # Parameters: - `configs : list[MazeDatasetConfig]` configs to try - `param_values : np.ndarray` numpy array of values to try # Returns: - `SweepResult` """ n_pvals: int = len(param_values) result_values_list: list[float] = run_maybe_parallel( # TYPING: error: Argument "func" to "run_maybe_parallel" has incompatible type "partial[list[SweepReturnType]]"; expected "Callable[[MazeDatasetConfig], float]" [arg-type] func=functools.partial( # type: ignore[arg-type] sweep, param_values=param_values, param_key=param_key, analyze_func=analyze_func, ), iterable=configs, keep_ordered=True, parallel=parallel, pbar_kwargs=dict(total=len(configs)), **kwargs, ) result_values: dict[str, Float[np.ndarray, n_pvals]] = { cfg.to_fname(): np.array(res) for cfg, res in zip(configs, result_values_list, strict=False) } return cls( configs=configs, param_values=param_values, # TYPING: error: Argument "result_values" to "SweepResult" has incompatible type "dict[str, ndarray[Any, Any]]"; expected "dict[str, Sequence[SweepReturnType]]" [arg-type] result_values=result_values, # type: ignore[arg-type] param_key=param_key, analyze_func=analyze_func, ) def plot( self, save_path: str | None = None, cfg_keys: list[str] | None = None, cmap_name: str | None = "viridis", plot_only: bool = False, show: bool = True, ax: plt.Axes | None = None, minify_title: bool = False, legend_kwargs: dict[str, Any] | None = None, ) -> plt.Axes: """Plot the results of percolation analysis""" # set up figure if not ax: fig: plt.Figure ax_: plt.Axes fig, ax_ = plt.subplots(1, 1, figsize=(22, 10)) else: ax_ = ax # plot cmap = plt.get_cmap(cmap_name) n_cfgs: int = len(self.result_values) for i, (ep_cfg_name, result_values) in enumerate( sorted( self.result_values.items(), # HACK: sort by grid size # |--< name of config # | |-----------< gets 'g{n}' # | | |--< gets '{n}' # | | | key=lambda x: int(x[0].split("-")[0][1:]), ), ): ax_.plot( # TYPING: error: Argument 1 to "plot" of "Axes" has incompatible type "list[ParamType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type] self.param_values, # type: ignore[arg-type] # TYPING: error: Argument 2 to "plot" of "Axes" has incompatible type "Sequence[SweepReturnType]"; expected "float | Buffer | _SupportsArray[dtype[Any]] | _NestedSequence[_SupportsArray[dtype[Any]]] | bool | int | float | complex | str | bytes | _NestedSequence[bool | int | float | complex | str | bytes] | str" [arg-type] result_values, # type: ignore[arg-type] ".-", label=self.configs_by_key()[ep_cfg_name].name, color=cmap((i + 0.5) / (n_cfgs - 0.5)), ) # repr of config cfg_shared: dict = self.configs_shared() if minify_title: cfg_shared["endpoint_kwargs"] = { k: v for k, v in cfg_shared["endpoint_kwargs"].items() if k != 
"except_on_no_valid_endpoint" } cfg_repr: str = ( str(cfg_shared) if cfg_keys is None else ( "MazeDatasetConfig(" + ", ".join( [ f"{k}={cfg_shared[k].__name__}" # TYPING: error: Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type] if isinstance(cfg_shared[k], Callable) # type: ignore[arg-type] else f"{k}={cfg_shared[k]}" for k in cfg_keys ], ) + ")" ) ) # add title and stuff if not plot_only: ax_.set_xlabel(self.param_key) ax_.set_ylabel(self.analyze_func.__name__) ax_.set_title( f"{self.param_key} vs {self.analyze_func.__name__}\n{cfg_repr}", ) ax_.grid(True) # ax_.legend(loc="upper center", ncol=2, bbox_to_anchor=(0.5, -0.1)) legend_kwargs = { **dict(loc="center left"), **(legend_kwargs or dict()), } ax_.legend(**legend_kwargs) # save and show if save_path: plt.savefig(save_path) if show: plt.show() return ax_ DEFAULT_ENDPOINT_KWARGS: list[tuple[str, dict]] = [ ( "any", dict(deadend_start=False, deadend_end=False, except_on_no_valid_endpoint=False), ), ( "deadends", dict( deadend_start=True, deadend_end=True, endpoints_not_equal=False, except_on_no_valid_endpoint=False, ), ), ( "deadends_unique", dict( deadend_start=True, deadend_end=True, endpoints_not_equal=True, except_on_no_valid_endpoint=False, ), ), ] def endpoint_kwargs_to_name(ep_kwargs: dict) -> str: """convert endpoint kwargs options to a human-readable name""" if ep_kwargs.get("deadend_start", False) or ep_kwargs.get("deadend_end", False): if ep_kwargs.get("endpoints_not_equal", False): return "deadends_unique" else: return "deadends" else: return "any" def full_percolation_analysis( n_mazes: int, p_val_count: int, grid_sizes: list[int], ep_kwargs: list[tuple[str, dict]] | None = None, generators: Sequence[Callable] = ( LatticeMazeGenerators.gen_percolation, LatticeMazeGenerators.gen_dfs_percolation, ), save_dir: Path = Path("../docs/benchmarks/percolation_fractions"), parallel: bool | int = False, **analyze_kwargs, ) -> SweepResult: "run the full analysis of how percolation affects maze generation success" if ep_kwargs is None: ep_kwargs = DEFAULT_ENDPOINT_KWARGS # configs configs: list[MazeDatasetConfig] = list() # TODO: B007 noqaed because we dont use `ep_kw_name` or `gf_idx` for ep_kw_name, ep_kw in ep_kwargs: # noqa: B007 for gf_idx, gen_func in enumerate(generators): # noqa: B007 configs.extend( [ MazeDatasetConfig( name=f"g{grid_n}-{gen_func.__name__.removeprefix('gen_').removesuffix('olation')}", grid_n=grid_n, n_mazes=n_mazes, maze_ctor=gen_func, maze_ctor_kwargs=dict(p=float("nan")), endpoint_kwargs=ep_kw, ) for grid_n in grid_sizes ], ) # get results result: SweepResult = SweepResult.analyze( configs=configs, # type: ignore[misc] # TYPING: error: Argument "param_values" to "analyze" of "SweepResult" has incompatible type "float | list[float] | list[list[float]] | list[list[list[Any]]]"; expected "list[Any]" [arg-type] param_values=np.linspace(0.0, 1.0, p_val_count).tolist(), # type: ignore[arg-type] param_key="maze_ctor_kwargs.p", analyze_func=dataset_success_fraction, parallel=parallel, **analyze_kwargs, ) # save the result results_path: Path = ( save_dir / f"result-n{n_mazes}-c{len(configs)}-p{p_val_count}.zanj" ) print(f"Saving results to {results_path.as_posix()}") result.save(results_path) return result def _is_eq(a, b) -> bool: # noqa: ANN001 """check if two objects are equal""" return a == b def plot_grouped( # noqa: C901 results: SweepResult, predict_fn: Callable[[MazeDatasetConfig], float] | None = None, prediction_density: int = 50, save_dir: Path | None = None, show: bool 
= True, logy: bool = False, save_fmt: str = "svg", figsize: tuple[int, int] = (22, 10), minify_title: bool = False, legend_kwargs: dict[str, Any] | None = None, manual_titles: dict[Literal["x", "y", "title"], str] | None = None, ) -> None: """Plot grouped sweep percolation value results for each distinct `endpoint_kwargs` in the configs with separate colormaps for each maze generator function # Parameters: - `results : SweepResult` The sweep results to plot - `predict_fn : Callable[[MazeDatasetConfig], float] | None` Optional function that predicts success rate from a config. If provided, will plot predictions as dashed lines. - `prediction_density : int` Number of points to use for prediction curves (default: 50) - `save_dir : Path | None` Directory to save plots (defaults to `None`, meaning no saving) - `show : bool` Whether to display the plots (defaults to `True`) # Usage: ```python >>> result = full_analysis(n_mazes=100, p_val_count=11, grid_sizes=[8,16]) >>> plot_grouped(result, save_dir=Path("./plots"), show=False) ``` """ # groups endpoint_kwargs_set: list[dict] = results.configs_value_set("endpoint_kwargs") # type: ignore[assignment] generator_funcs_names: list[str] = list( {cfg.maze_ctor.__name__ for cfg in results.configs}, ) # if predicting, create denser p values if predict_fn is not None: p_dense = np.linspace(0.0, 1.0, prediction_density) # separate plot for each set of endpoint kwargs for ep_kw in endpoint_kwargs_set: results_epkw: SweepResult = results.get_where( "endpoint_kwargs", functools.partial(_is_eq, b=ep_kw), # lambda x: x == ep_kw, ) shared_keys: set[str] = set(results_epkw.configs_shared().keys()) cfg_keys: set[str] = shared_keys.intersection({"n_mazes", "endpoint_kwargs"}) fig, ax = plt.subplots(1, 1, figsize=figsize) for gf_idx, gen_func in enumerate(generator_funcs_names): results_filtered: SweepResult = results_epkw.get_where( "maze_ctor", # HACK: big hassle to do this without a lambda, is it really that bad? lambda x: x.__name__ == gen_func, # noqa: B023 ) if len(results_filtered.configs) < 1: warnings.warn( f"No results for {gen_func} and {ep_kw}. 
Skipping.", ) continue cmap_name = "Reds" if gf_idx == 0 else "Blues" cmap = plt.get_cmap(cmap_name) # Plot actual results ax = results_filtered.plot( cfg_keys=list(cfg_keys), ax=ax, show=False, cmap_name=cmap_name, minify_title=minify_title, legend_kwargs=legend_kwargs, ) if logy: ax.set_yscale("log") # Plot predictions if function provided if predict_fn is not None: for cfg_idx, cfg in enumerate(results_filtered.configs): predictions = [] for p in p_dense: cfg_temp = MazeDatasetConfig.load(cfg.serialize()) cfg_temp.maze_ctor_kwargs["p"] = p predictions.append(predict_fn(cfg_temp)) # Get the same color as the actual data n_cfgs: int = len(results_filtered.configs) color = cmap((cfg_idx + 0.5) / (n_cfgs - 0.5)) # Plot prediction as dashed line ax.plot(p_dense, predictions, "--", color=color, alpha=0.8) if manual_titles: ax.set_xlabel(manual_titles["x"]) ax.set_ylabel(manual_titles["y"]) ax.set_title(manual_titles["title"]) # save and show if save_dir: save_path: Path = ( save_dir / f"ep_{endpoint_kwargs_to_name(ep_kw)}.{save_fmt}" ) print(f"Saving plot to {save_path.as_posix()}") save_path.parent.mkdir(exist_ok=True, parents=True) plt.savefig(save_path) if show: plt.show() ``````{ end_of_file="maze_dataset/benchmark/config_sweep.py" } ``````{ path="maze_dataset/benchmark/speed.py" } "benchmark the speed of maze generation" import functools import random import timeit from pathlib import Path from typing import Any, Sequence from tqdm import tqdm from maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.generation.default_generators import DEFAULT_GENERATORS from maze_dataset.generation.generators import GENERATORS_MAP _BASE_CFG_KWARGS: dict = dict( grid_n=None, n_mazes=None, ) _GENERATE_KWARGS: dict = dict( gen_parallel=False, pool_kwargs=None, verbose=False, # do_generate = True, # load_local = False, # save_local = False, # zanj = None, # do_download = False, # local_base_path = "INVALID", # except_on_config_mismatch = True, # verbose = False, ) def time_generation( base_configs: list[tuple[str, dict]], grid_n_vals: list[int], n_mazes_vals: list[int], trials: int = 10, verbose: bool = False, ) -> list[dict[str, Any]]: "time the generation of mazes for various configurations" # assemble configs configs: list[MazeDatasetConfig] = list() for b_cfg in base_configs: for grid_n in grid_n_vals: for n_mazes in n_mazes_vals: configs.append( MazeDatasetConfig( name="benchmark", grid_n=grid_n, n_mazes=n_mazes, maze_ctor=GENERATORS_MAP[b_cfg[0]], maze_ctor_kwargs=b_cfg[1], ), ) # shuffle configs (in place) (otherwise progress bar is annoying) random.shuffle(configs) # time generation for each config times: list[dict[str, Any]] = list() total: int = len(configs) for idx, cfg in tqdm( enumerate(configs), desc="Timing generation", unit="config", total=total, disable=verbose, ): if verbose: print(f"Timing generation for config {idx + 1}/{total}\n{cfg}") t: float = ( timeit.timeit( stmt=functools.partial(MazeDataset.generate, cfg, **_GENERATE_KWARGS), # type: ignore[arg-type] number=trials, ) / trials ) if verbose: print(f"avg time: {t:.3f} s") times.append( dict( cfg_name=cfg.name, grid_n=cfg.grid_n, n_mazes=cfg.n_mazes, maze_ctor=cfg.maze_ctor.__name__, maze_ctor_kwargs=cfg.maze_ctor_kwargs, trials=trials, time=t, ), ) return times def run_benchmark( save_path: str, base_configs: list[tuple[str, dict]] | None = None, grid_n_vals: Sequence[int] = (2, 3, 4, 5, 8, 10, 16, 25, 32), n_mazes_vals: Sequence[int] = tuple(range(1, 12, 2)), trials: int = 10, verbose: bool = True, ) -> 
"pd.DataFrame": # type: ignore[name-defined] # noqa: F821 "run the benchmark and save the results to a file" import pandas as pd if base_configs is None: base_configs = DEFAULT_GENERATORS times: list[dict] = time_generation( base_configs=base_configs, grid_n_vals=list(grid_n_vals), n_mazes_vals=list(n_mazes_vals), trials=trials, verbose=verbose, ) df: pd.DataFrame = pd.DataFrame(times) # print the whole dataframe contents to console as csv print(df.to_csv()) # save to file Path(save_path).parent.mkdir(parents=True, exist_ok=True) df.to_json(save_path, orient="records", lines=True) return df ``````{ end_of_file="maze_dataset/benchmark/speed.py" } ``````{ path="maze_dataset/benchmark/sweep_fit.py" } """Fit a PySR model to a sweep result and plot the results""" from pathlib import Path from typing import TYPE_CHECKING, Callable import numpy as np import sympy as sp # type: ignore[import-untyped] from jaxtyping import Float from pysr import PySRRegressor # type: ignore[import-untyped] from maze_dataset import MazeDatasetConfig from maze_dataset.benchmark.config_sweep import ( SweepResult, plot_grouped, ) def extract_training_data( sweep_result: SweepResult, ) -> tuple[Float[np.ndarray, "num_rows 5"], Float[np.ndarray, " num_rows"]]: """Extract data (X, y) from a SweepResult. # Parameters: - `sweep_result : SweepResult` The sweep result holding configs and success arrays. # Returns: - `X : Float[np.ndarray, "num_rows 5"]` Stacked [p, grid_n, deadends, endpoints_not_equal, generator_func] for each config & param-value - `y : Float[np.ndarray, "num_rows"]` The corresponding success rate """ x_list: list[list[float]] = [] y_list: list[float] = [] for cfg in sweep_result.configs: # success_arr is an array of success rates for param_values success_arr = sweep_result.result_values[cfg.to_fname()] for i, p in enumerate(sweep_result.param_values): # Temporarily override p in the config's array representation: arr = cfg._to_ps_array().copy() arr[0] = p # index 0 is 'p' x_list.append(arr) # type: ignore[arg-type] y_list.append(success_arr[i]) return np.array(x_list, dtype=np.float64), np.array(y_list, dtype=np.float64) DEFAULT_PYSR_KWARGS: dict = dict( niterations=50, unary_operators=[ "exp", "log", "square(x) = x^2", "cube(x) = x^3", "sigmoid(x) = 1/(1 + exp(-x))", ], extra_sympy_mappings={ "square": lambda x: x**2, "cube": lambda x: x**3, "sigmoid": lambda x: 1 / (1 + sp.exp(-x)), }, binary_operators=["+", "-", "*", "/", "^"], # populations=50, progress=True, model_selection="best", ) def train_pysr_model( data: SweepResult, **pysr_kwargs, ) -> PySRRegressor: """Train a PySR model on the given sweep result data""" # Convert to arrays x, y = extract_training_data(data) print(f"training data extracted: {x.shape = }, {y.shape = }") # Fit the PySR model model: PySRRegressor = PySRRegressor(**{**DEFAULT_PYSR_KWARGS, **pysr_kwargs}) model.fit(x, y) return model def plot_model( data: SweepResult, model: PySRRegressor, save_dir: Path, show: bool = True, ) -> None: """Plot the model predictions against the sweep data""" # save all the equations save_dir.mkdir(parents=True, exist_ok=True) equations_file: Path = save_dir / "equations.txt" equations_file.write_text(repr(model)) print(f"Equations saved to: {equations_file = }") # Create a callable that predicts from MazeDatasetConfig predict_fn: Callable = model.get_best()["lambda_format"] print(f"Best PySR Equation: {model.get_best()['equation'] = }") print(f"{predict_fn =}") def predict_config(cfg: MazeDatasetConfig) -> float: arr = cfg._to_ps_array() result = 
predict_fn(arr)[0] return float(result) # pass the array as separate args plot_grouped( data, predict_fn=predict_config, save_dir=save_dir, show=show, ) def sweep_fit( data_path: Path, save_dir: Path, **pysr_kwargs, ) -> None: """read a sweep result, train a PySR model, and plot the results""" # Load the sweep result data: SweepResult = SweepResult.read(data_path) print(f"loaded data: {data.summary() = }") # Train the PySR model model: PySRRegressor = train_pysr_model(data, **pysr_kwargs) # Plot the model plot_model(data, model, save_dir, show=False) if __name__ == "__main__": import argparse argparser: argparse.ArgumentParser = argparse.ArgumentParser() argparser.add_argument( "data_path", type=Path, help="Path to the sweep result file", ) argparser.add_argument( "--save_dir", type=Path, default=Path("tests/_temp/percolation_fractions/fit_plots/"), help="Path to save the plots", ) argparser.add_argument( "--niterations", type=int, default=50, help="Number of iterations for PySR", ) args: argparse.Namespace = argparser.parse_args() sweep_fit( args.data_path, args.save_dir, niterations=args.niterations, # add any additional kwargs here if running in CLI populations=50, # ^ Assuming we have 4 cores, this means 2 populations per core, so one is always running. population_size=50, # ^ Generations between migrations. timeout_in_seconds=60 * 60 * 7, # ^ stop after 7 hours have passed. maxsize=50, # ^ Allow greater complexity. weight_randomize=0.01, # ^ Randomize the tree much more frequently turbo=True, # ^ Faster evaluation (experimental) ) def create_interactive_plot(heatmap: bool = True) -> None: # noqa: C901, PLR0915 """Create an interactive plot with the specified grid layout # Parameters: - `heatmap : bool` Whether to show heatmaps (defaults to `True`) """ import ipywidgets as widgets # type: ignore[import-untyped] import matplotlib.pyplot as plt from ipywidgets import FloatSlider, HBox, Layout, VBox from matplotlib.gridspec import GridSpec from maze_dataset.dataset.success_predict_math import soft_step # Create sliders with better layout x_slider = FloatSlider( min=0.0, max=1.0, step=0.01, value=0.5, description="x:", style={"description_width": "30px"}, layout=Layout(width="98%"), ) p_slider = FloatSlider( min=0.0, max=1.0, step=0.01, value=0.5, description="p:", style={"description_width": "30px"}, layout=Layout(width="98%"), ) alpha_slider = FloatSlider( min=0.1, max=30.0, step=0.1, value=10.0, description="α:", # noqa: RUF001 style={"description_width": "30px"}, layout=Layout(width="98%"), ) w_slider = FloatSlider( min=0.0, max=20, step=0.5, value=4.0 / 7.0, description="w:", style={"description_width": "30px"}, layout=Layout(width="98%"), ) # Slider layout control slider_box = VBox( [ widgets.Label("Adjust parameters:"), HBox( [x_slider, w_slider], layout=Layout(width="100%", justify_content="space-between"), ), HBox( [p_slider, alpha_slider], layout=Layout(width="100%", justify_content="space-between"), ), ], ) def update_plot(x: float, p: float, alpha: float, w: float) -> None: # noqa: PLR0915 """Update the plot with current slider values # Parameters: - `x : float` x value - `p : float` p value - `k : float` k value - `alpha : float` alpha value """ # Set up the figure and grid - now 2x2 grid fig = plt.figure(figsize=(14, 10)) gs = GridSpec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1]) # Create x and p values focused on [0,1] range xs = np.linspace(0.0, 1.0, 500) ps = np.linspace(0.0, 1.0, 500) # Plot 1: f(x) vs x (top left) ax1 = fig.add_subplot(gs[0, 0]) ys = soft_step(xs, p, 
alpha, w) ax1.plot(xs, ys, "b-", linewidth=2.5) # Add guidelines ax1.axvline(x=p, color="red", linestyle="--", alpha=0.7, label=f"p = {p:.2f}") ax1.axvline(x=w, color="green", linestyle="--", alpha=0.7, label=f"w = {w:.2f}") ax1.axvline(x=x, color="blue", linestyle=":", alpha=0.7, label=f"x = {x:.2f}") # Add identity line for reference ax1.plot(xs, xs, "k--", alpha=0.3, label="f(x) = x") ax1.set_xlim(0, 1) ax1.set_ylim(0, 1) ax1.set_xlabel("x") ax1.set_ylabel("f(x)") ax1.set_title(f"f(x) with p={p:.2f}, w={w:.2f}, α={alpha:.1f}") # noqa: RUF001 ax1.grid(True, alpha=0.3) ax1.legend(loc="best") # Plot 2: f(p) vs p with fixed x (top right) ax2 = fig.add_subplot(gs[0, 1]) # Plot the main curve with current x value f_p_values = np.array([soft_step(x, p_val, alpha, w) for p_val in ps]) ax2.plot(ps, f_p_values, "blue", linewidth=2.5, label=f"x = {x:.2f}") # Create additional curves for different x values x_values = [0.2, 0.4, 0.6, 0.8] colors = ["purple", "orange", "magenta", "green"] for x_val, color in zip(x_values, colors, strict=False): # Don't draw if too close to current x if abs(x_val - x) > 0.05: # noqa: PLR2004 f_p_values = np.array( [soft_step(x_val, p_val, alpha, w) for p_val in ps], ) ax2.plot( ps, f_p_values, color=color, linewidth=1.5, alpha=0.4, label=f"x = {x_val}", ) # Add guideline for current p value ax2.axvline(x=p, color="red", linestyle="--", alpha=0.7) ax2.set_xlim(0, 1) ax2.set_ylim(0, 1) ax2.set_xlabel("p") ax2.set_ylabel("f(x,p)") ax2.set_title(f"f(x,p) for fixed x={x:.2f}, w={w:.2f}, α={alpha:.1f}") # noqa: RUF001 ax2.grid(True, alpha=0.3) ax2.legend(loc="best") if heatmap: # Plot 3: Heatmap of f(x,p) (bottom left) ax3 = fig.add_subplot(gs[1, 0]) X, P = np.meshgrid(xs, ps) # noqa: N806 Z = np.zeros_like(X) # noqa: N806 # Calculate f(x,p) for all combinations for i, p_val in enumerate(ps): # TYPING: error: Incompatible types in assignment (expression has type "floating[Any]", variable has type "float") [assignment] for j, x_val in enumerate(xs): # type: ignore[assignment] Z[i, j] = soft_step(x_val, p_val, alpha, w) c = ax3.pcolormesh(X, P, Z, cmap="viridis", shading="auto") # Add current parameter values as lines ax3.axhline(y=p, color="red", linestyle="--", label=f"p = {p:.2f}") ax3.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}") ax3.axvline(x=x, color="blue", linestyle="--", label=f"x = {x:.2f}") # Add lines for the reference x values used in the top-right plot for x_val, color in zip(x_values, colors, strict=False): # Don't draw if too close to current x, magic value is fine if abs(x_val - x) > 0.05: # noqa: PLR2004 ax3.axvline(x=x_val, color=color, linestyle=":", alpha=0.4) # Mark the specific point corresponding to the current x and p values ax3.plot(x, p, "ro", markersize=8) # yes we mean to use alpha here (RUF001) ax3.set_xlabel("x") ax3.set_ylabel("p") ax3.set_title(f"f(x,p) heatmap with w={w:.2f}, α={alpha:.1f}") # noqa: RUF001 fig.colorbar(c, ax=ax3, label="f(x,p)") # Plot 4: NEW Heatmap of f(x,p) as function of k and alpha (bottom right) ax4 = fig.add_subplot(gs[1, 1]) # Create k and alpha ranges ws = np.linspace(0.0, 1.0, 100) alphas = np.linspace(0.1, 30.0, 100) K, A = np.meshgrid(ws, alphas) # noqa: N806 Z_ka = np.zeros_like(K) # noqa: N806 # Calculate f(x,p) for all combinations of k and alpha for i, alpha_val in enumerate(alphas): for j, w_val in enumerate(ws): Z_ka[i, j] = soft_step(x, p, alpha_val, w_val) c2 = ax4.pcolormesh(K, A, Z_ka, cmap="plasma", shading="auto") # Add current parameter values as lines # yes we mean to use alpha 
here (RUF001) ax4.axhline( y=alpha, color="purple", linestyle="--", label=f"α = {alpha:.1f}", # noqa: RUF001 ) ax4.axvline(x=w, color="green", linestyle="--", label=f"w = {w:.2f}") # Mark the specific point corresponding to the current w and alpha values ax4.plot(w, alpha, "ro", markersize=8) # yes we mean to use alpha here (RUF001) ax4.set_xlabel("w") ax4.set_ylabel("α") # noqa: RUF001 ax4.set_title(f"f(x,p) heatmap with fixed x={x:.2f}, p={p:.2f}") fig.colorbar(c2, ax=ax4, label="f(x,p,w,α)") # noqa: RUF001 plt.tight_layout() plt.show() # Display the interactive widget interactive_output = widgets.interactive_output( update_plot, {"x": x_slider, "p": p_slider, "w": w_slider, "alpha": alpha_slider}, ) # we noqa here because we will only call this function inside a notebook if not TYPE_CHECKING: display(VBox([slider_box, interactive_output])) # noqa: F821 ``````{ end_of_file="maze_dataset/benchmark/sweep_fit.py" } ``````{ path="maze_dataset/dataset/__init__.py" } """`MazeDatasetConfig`s are used to create a `MazeDataset` via `MazeDataset.from_config(cfg)`" When initializing mazes, further configuration options can be specified through the `from_config()` factory method as necessary. Options include 1) whether to generate the dataset during runtime or load an existing dataset, 2) if and how to parallelize generation, and 3) where to store the generated dataset. Full documentation of configuration options is available in our repository [@maze-dataset-github]. Available maze generation algorithms are static methods of the `LatticeMazeGenerators` class. Furthermore, a dataset of mazes can be filtered to satisfy certain properties: ```python dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3) ``` Custom filters can be specified, and several filters are included: - `path_length(min_length: int)`: shortest length from the origin to target should be at least `min_length`. - `start_end_distance(min_distance: int)`: Manhattan distance between start and end should be at least `min_distance`, ignoring walls. - `remove_duplicates(...)`: remove mazes which are similar to others in the dataset, measured via Hamming distance. - `remove_duplicates_fast()`: remove mazes which are exactly identical to others in the dataset. All implemented maze generation algorithms are stochastic by nature. For reproducibility, the `seed` parameter of `MazeDatasetConfig` may be set. In practice, we do not find that exact duplicates of mazes are generated with any meaningful frequency, even when generating large datasets. """ from maze_dataset.dataset.collected_dataset import ( MazeDatasetCollection, MazeDatasetCollectionConfig, ) from maze_dataset.dataset.maze_dataset import MazeDataset from maze_dataset.dataset.maze_dataset_config import MazeDatasetConfig __all__ = [ # submodules "collected_dataset", "configs", "dataset", "filters", "maze_dataset_config", "maze_dataset", "rasterized", "success_predict_math", # dataset classes "MazeDataset", "MazeDatasetConfig", "MazeDatasetCollection", "MazeDatasetCollectionConfig", ] ``````{ end_of_file="maze_dataset/dataset/__init__.py" } ``````{ path="maze_dataset/dataset/collected_dataset.py" } """collecting different maze datasets into a single dataset, for greater variety in a training or validation set > [!CAUTION] > `MazeDatasetCollection` is not thoroughly tested and is not guaranteed to work. 
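A rough usage sketch (untested, consistent with the caution above; the names and sizes are illustrative):

```python
from maze_dataset import MazeDatasetConfig
from maze_dataset.dataset.collected_dataset import (
	MazeDatasetCollection,
	MazeDatasetCollectionConfig,
)
from maze_dataset.generation import LatticeMazeGenerators

# two small DFS datasets of different grid sizes, combined into one collection
sub_configs: list[MazeDatasetConfig] = [
	MazeDatasetConfig(
		name=f"part-g{grid_n}",
		grid_n=grid_n,
		n_mazes=10,
		maze_ctor=LatticeMazeGenerators.gen_dfs,
	)
	for grid_n in (3, 4)
]
collection_cfg = MazeDatasetCollectionConfig(
	name="combined",
	maze_dataset_configs=sub_configs,
)
collection = MazeDatasetCollection.generate(collection_cfg)
```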
""" import itertools import json import typing from functools import cached_property import numpy as np from jaxtyping import Int from muutils.json_serialize import ( json_serialize, serializable_dataclass, serializable_field, ) from muutils.json_serialize.util import _FORMAT_KEY, JSONdict from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler from maze_dataset.constants import Coord, CoordTup from maze_dataset.dataset.dataset import GPTDataset, GPTDatasetConfig from maze_dataset.dataset.maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.maze import LatticeMaze @serializable_dataclass(kw_only=True) class MazeDatasetCollectionConfig(GPTDatasetConfig): """maze dataset collection configuration, including tokenizers and shuffle""" # Attributes without a default cannot follow attributes with one [misc] maze_dataset_configs: list[MazeDatasetConfig] = serializable_field( # type: ignore[misc] serialization_fn=lambda configs: [config.serialize() for config in configs], loading_fn=lambda data: [ MazeDatasetConfig.load(config) for config in data["maze_dataset_configs"] ], ) def summary(self) -> dict: """return a summary of the config""" return dict( n_mazes=self.n_mazes, max_grid_n=self.max_grid_n, max_grid_shape=self.max_grid_shape, fname=self.to_fname(), cfg_summaries=[c.summary() for c in self.maze_dataset_configs], ) @property def n_mazes(self) -> int: """return the total number of mazes in the collection across all dataset""" return sum(config.n_mazes for config in self.maze_dataset_configs) @property def max_grid_n(self) -> int: """return the maximum grid size of the mazes in the collection""" return max(config.grid_n for config in self.maze_dataset_configs) @property def max_grid_shape(self) -> CoordTup: """return the maximum grid shape of the mazes in the collection""" return (self.max_grid_n, self.max_grid_n) @property def max_grid_shape_np(self) -> Coord: """return the maximum grid shape of the mazes in the collection as a numpy array""" return np.array(self.max_grid_shape, dtype=np.int32) def stable_hash_cfg(self) -> int: """return a stable hash of the config""" return stable_hash(json.dumps(self.serialize())) def to_fname(self) -> str: """convert config to a filename""" return sanitize_fname( f"collected-{self.name}-n{shorten_numerical_to_str(self.n_mazes)}-h{self.stable_hash_cfg() % 10**5}", ) class MazeDatasetCollection(GPTDataset): """a collection of maze datasets""" def __init__( self, cfg: MazeDatasetCollectionConfig, maze_datasets: list[MazeDataset], generation_metadata_collected: dict | None = None, ) -> None: "initialize the dataset collection from a `MazeDatasetCollectionConfig` and a list of `MazeDataset`s" super().__init__() self.cfg: MazeDatasetCollectionConfig = cfg self.maze_datasets: list[MazeDataset] = list(maze_datasets) for c, ds in zip( self.cfg.maze_dataset_configs, self.maze_datasets, strict=False, ): assert c.name == ds.cfg.name assert c == ds.cfg self.generation_metadata_collected: dict | None = generation_metadata_collected @property def dataset_lengths(self) -> list[int]: """return the lengths of each dataset in the collection""" return [len(dataset) for dataset in self.maze_datasets] @property def dataset_cum_lengths(self) -> Int[np.ndarray, " indices"]: """return the cumulative lengths of each dataset in the collection""" return np.array(list(itertools.accumulate(self.dataset_lengths))) @cached_property def mazes(self) -> 
list[LatticeMaze]: "single list of all mazes in the collection" return list( itertools.chain.from_iterable( dataset.mazes for dataset in self.maze_datasets ), ) def __len__(self) -> int: """return the total number of mazes in the collection""" return sum(len(dataset) for dataset in self.maze_datasets) def __getitem__(self, index: int) -> LatticeMaze: "get a maze by index" # find which dataset the index belongs to # we add 1, since np.searchsorted returns the # index of the last element that is strictly less than the target # while we want the index of the last element less than or equal to the target dataset_idx: int = int(np.searchsorted(self.dataset_cum_lengths, index + 1)) index_adjusted: int = index if dataset_idx > 0: # if the index is 0, `dataset_idx - 1` will be -1. # We just want to use the base index index_adjusted -= self.dataset_cum_lengths[dataset_idx - 1] return self.maze_datasets[dataset_idx][index_adjusted] @classmethod def generate( cls, cfg: MazeDatasetCollectionConfig, **kwargs, ) -> "MazeDatasetCollection": """generate a dataset collection from a config""" datasets = [ MazeDataset.generate(config, **kwargs) for config in cfg.maze_dataset_configs ] return cls(cfg, datasets) @classmethod def download( cls, cfg: MazeDatasetCollectionConfig, **kwargs, ) -> "MazeDatasetCollection": "(not implemented!) download a dataset collection from a config" datasets = [ MazeDataset.download(config, **kwargs) for config in cfg.maze_dataset_configs ] return cls(cfg, datasets) def serialize(self) -> JSONdict: """serialize the dataset collection""" return { _FORMAT_KEY: "MazeDatasetCollection", "cfg": self.cfg.serialize(), "maze_datasets": [dataset.serialize() for dataset in self.maze_datasets], "generation_metadata_collected": json_serialize( self.generation_metadata_collected, ), } @classmethod def load(cls, data: JSONdict) -> "MazeDatasetCollection": """load the dataset collection from the representation created by `serialize`""" assert data[_FORMAT_KEY] == "MazeDatasetCollection" return cls( **{ key: load_item_recursive(data[key], tuple()) for key in ["cfg", "maze_datasets", "generation_metadata_collected"] }, ) # TODO: remove duplication with MazeDatasetConfig().as_tokens() somehow? def as_tokens( self, # TODO: MazeTokenizer maze_tokenizer, # noqa: ANN001 limit: int | None = None, join_tokens_individual_maze: bool = False, ) -> list[list[str]] | list[str]: """return the dataset as tokens if join_tokens_individual_maze is True, then the tokens of each maze are joined with a space, and the result is a list of strings. i.e.: >>> dataset.as_tokens(join_tokens_individual_maze=False) [["a", "b", "c"], ["d", "e", "f"]] >>> dataset.as_tokens(join_tokens_individual_maze=True) ["a b c", "d e f"] """ output: list[list[str]] = [ maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] ] if join_tokens_individual_maze: return [" ".join(tokens) for tokens in output] else: return output def update_self_config(self) -> None: "update the config to match the number of mazes, and update the underlying configs of each dataset" # TODO: why cant we set this directly? 
its not frozen, and it seems to work in a regular MazeDataset self.cfg.__dict__["n_mazes"] = len(self) for dataset in self.maze_datasets: dataset.update_self_config() self.cfg.maze_dataset_configs = [dataset.cfg for dataset in self.maze_datasets] MazeDatasetCollectionConfig._dataset_class = MazeDatasetCollection # type: ignore[method-assign, assignment] register_loader_handler( LoaderHandler( check=lambda json_item, path=None, z=None: ( # type: ignore[misc] # noqa: ARG005 isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item and json_item[_FORMAT_KEY].startswith("MazeDatasetCollection") ), load=lambda json_item, path=None, z=None: MazeDatasetCollection.load(json_item), # type: ignore[misc] # noqa: ARG005 uid="MazeDatasetCollection", source_pckg="maze_dataset.generation.maze_dataset_collection", desc="MazeDatasetCollection", ), ) ``````{ end_of_file="maze_dataset/dataset/collected_dataset.py" } ``````{ path="maze_dataset/dataset/configs.py" } "`MAZE_DATASET_CONFIGS` contains some default configs for tests and demos" import copy from typing import Callable, Iterator, Mapping from maze_dataset.dataset.maze_dataset import MazeDatasetConfig from maze_dataset.generation.generators import LatticeMazeGenerators _MAZE_DATASET_CONFIGS_SRC: dict[str, MazeDatasetConfig] = { cfg.to_fname(): cfg for cfg in [ MazeDatasetConfig( name="test", grid_n=3, n_mazes=5, maze_ctor=LatticeMazeGenerators.gen_dfs, ), MazeDatasetConfig( name="test-perc", grid_n=3, n_mazes=5, maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, maze_ctor_kwargs={"p": 0.7}, ), MazeDatasetConfig( name="demo_small", grid_n=3, n_mazes=100, maze_ctor=LatticeMazeGenerators.gen_dfs, ), MazeDatasetConfig( name="demo", grid_n=6, n_mazes=10000, maze_ctor=LatticeMazeGenerators.gen_dfs, ), ] } class _MazeDatsetConfigsWrapper(Mapping[str, MazeDatasetConfig]): "wrap the default configs in a read-only dict-like object" def __init__(self, configs: dict[str, MazeDatasetConfig]) -> None: "initialize with a dict of configs" self._configs = configs def __getitem__(self, item: str) -> MazeDatasetConfig: return self._configs[item] def __len__(self) -> int: return len(self._configs) def __iter__(self) -> Iterator: "iterate over the keys" return iter(self._configs) # TYPING: error: Return type "list[str]" of "keys" incompatible with return type "KeysView[str]" in supertype "Mapping" [override] def keys(self) -> list[str]: # type: ignore[override] "return the keys" return list(self._configs.keys()) # TYPING: error: Return type "list[tuple[str, MazeDatasetConfig]]" of "items" incompatible with return type "ItemsView[str, MazeDatasetConfig]" in supertype "Mapping" [override] def items(self) -> list[tuple[str, MazeDatasetConfig]]: # type: ignore[override] "return the items" return [(k, copy.deepcopy(v)) for k, v in self._configs.items()] # TYPING: error: Return type "list[MazeDatasetConfig]" of "values" incompatible with return type "ValuesView[MazeDatasetConfig]" in supertype "Mapping" [override] def values(self) -> list[MazeDatasetConfig]: # type: ignore[override] return [copy.deepcopy(v) for v in self._configs.values()] MAZE_DATASET_CONFIGS: _MazeDatsetConfigsWrapper = _MazeDatsetConfigsWrapper( _MAZE_DATASET_CONFIGS_SRC, ) def _get_configs_for_examples() -> list[dict]: """Generate a comprehensive list of diverse maze configurations. 
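Each entry is a plain `dict` of `MazeDatasetConfig` keyword arguments plus a `description` and a `tags` list, e.g. the first basic DFS entry:

```python
dict(
	name="basic",
	grid_n=5,
	maze_ctor=LatticeMazeGenerators.gen_dfs,  # imported at module level
	maze_ctor_kwargs={},
	description="Basic DFS maze (5x5)",
	tags=["algo:dfs", "basic", "grid:5"],
)
```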
# Returns: - `list[dict]` List of configuration dictionaries for maze generation """ configs: list[dict] = [] # Define the grid sizes to test grid_sizes: list[int] = [5, 8, 12, 15, 20] # Define percolation probabilities percolation_probs: list[float] = [0.3, 0.5, 0.7] # Core algorithms with basic configurations basic_algorithms: dict[str, tuple[Callable, dict]] = { "dfs": (LatticeMazeGenerators.gen_dfs, {}), "wilson": (LatticeMazeGenerators.gen_wilson, {}), "kruskal": (LatticeMazeGenerators.gen_kruskal, {}), "recursive_division": (LatticeMazeGenerators.gen_recursive_division, {}), } # Generate basic configurations for each algorithm and grid size for grid_n in grid_sizes: for algo_name, (maze_ctor, base_kwargs) in basic_algorithms.items(): configs.append( dict( name="basic", grid_n=grid_n, maze_ctor=maze_ctor, maze_ctor_kwargs=base_kwargs, description=f"Basic {algo_name.upper()} maze ({grid_n}x{grid_n})", tags=[f"algo:{algo_name}", "basic", f"grid:{grid_n}"], ) ) # Generate percolation configurations for grid_n in grid_sizes: for p in percolation_probs: # Pure percolation configs.append( dict( name=f"p{p}", grid_n=grid_n, maze_ctor=LatticeMazeGenerators.gen_percolation, maze_ctor_kwargs=dict(p=p), description=f"Pure percolation (p={p}) ({grid_n}x{grid_n})", tags=[ "algo:percolation", "percolation", f"percolation:{p}", f"grid:{grid_n}", ], ) ) # DFS with percolation configs.append( dict( name=f"p{p}", grid_n=grid_n, maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, maze_ctor_kwargs=dict(p=p), description=f"DFS with percolation (p={p}) ({grid_n}x{grid_n})", tags=[ "algo:dfs_percolation", "dfs", "percolation", f"percolation:{p}", f"grid:{grid_n}", ], ) ) # Generate specialized constraint configurations constraint_base_config: dict = dict( grid_n=10, maze_ctor=LatticeMazeGenerators.gen_dfs, ) constraint_base_tags: list[str] = [ "algo:dfs", "dfs", "constrained_dfs", f"grid:{constraint_base_config['grid_n']}", ] constraint_configs: list[dict] = [ # DFS without forks (simple path) dict( name="forkless", maze_ctor_kwargs=dict(do_forks=False), description="DFS without forks (10x10)", tags=["forkless"], ), # Accessible cells constraints dict( name="accessible_cells_count", maze_ctor_kwargs=dict(accessible_cells=50), description="DFS with limited accessible cells (50)", tags=["limited:cells", "limited:absolute"], ), dict( name="accessible_cells_ratio", maze_ctor_kwargs=dict(accessible_cells=0.6), description="DFS with 60% accessible cells", tags=["limited:cells", "limited:ratio"], ), # Tree depth constraints dict( name="max_tree_depth_absolute", maze_ctor_kwargs=dict(max_tree_depth=10), description="DFS with max tree depth of 10", tags=["limited:depth", "limited:absolute"], ), dict( name="max_tree_depth_ratio", maze_ctor_kwargs=dict(max_tree_depth=0.3), description="DFS with max tree depth 30% of grid size", tags=["limited:depth", "limited:ratio"], ), # Start position constraint dict( name="start_center", maze_ctor_kwargs=dict(start_coord=[5, 5]), description="DFS starting from center of grid", tags=["custom_start"], ), dict( name="start_corner", maze_ctor_kwargs=dict(start_coord=[0, 0]), description="DFS starting from corner of grid", tags=["custom_start"], ), ] # Add combined constraints as special case configs.append( dict( name="combined_constraints", grid_n=15, maze_ctor=LatticeMazeGenerators.gen_dfs, maze_ctor_kwargs=dict( accessible_cells=100, max_tree_depth=25, start_coord=[7, 7], ), description="DFS with multiple constraints (100 cells, depth 25, center start)", tags=["algo:dfs", "dfs", 
"constrained_dfs", "grid:15"], ) ) # Apply the base config to all constraint configs and add to main configs list for config in constraint_configs: full_config = constraint_base_config.copy() full_config.update(config) full_config["tags"] = constraint_base_tags + config["tags"] configs.append(full_config) # Generate endpoint options endpoint_variations: list[tuple[bool, bool, str]] = [ (True, False, "deadend start only"), (False, True, "deadend end only"), (True, True, "deadend start and end"), ] for deadend_start, deadend_end, desc in endpoint_variations: configs.append( dict( name=f"deadend_s{int(deadend_start)}_e{int(deadend_end)}", grid_n=8, maze_ctor=LatticeMazeGenerators.gen_dfs, maze_ctor_kwargs={}, endpoint_kwargs=dict( deadend_start=deadend_start, deadend_end=deadend_end, endpoints_not_equal=True, ), description=f"DFS with {desc}", tags=["algo:dfs", "dfs", "deadend_endpoints", "grid:8"], ) ) # Add percolation with deadend endpoints configs.append( dict( name="deadends", grid_n=8, maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, maze_ctor_kwargs=dict(p=0.3), endpoint_kwargs=dict( deadend_start=True, deadend_end=True, endpoints_not_equal=True, except_on_no_valid_endpoint=False, ), description="DFS percolation (p=0.3) with deadend endpoints", tags=[ "algo:dfs_percolation", "dfs", "percolation", "deadend_endpoints", "grid:8", ], ) ) return configs ``````{ end_of_file="maze_dataset/dataset/configs.py" } ``````{ path="maze_dataset/dataset/dataset.py" } """`GPTDatasetConfig` and `GPTDataset` are base classes for datasets they implement some basic functionality, saving/loading, the `from_config` pipeline, and filtering > [!NOTE] > these should probably be moved into a different package, so don't rely on them being here """ import functools import json import random import typing import warnings from pathlib import Path from typing import Callable, Type, TypeVar import numpy as np from muutils.json_serialize import ( JSONitem, SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.json_serialize.util import ( JSONdict, ) from muutils.misc import sanitize_fname, shorten_numerical_to_str, stable_hash from zanj import ZANJ from maze_dataset.generation.seed import GLOBAL_SEED def set_reproducibility(seed: int) -> None: "set reproducibility in stdlib random and numpy (but not torch)" random.seed(seed) np.random.seed(seed) class FilterInfoMismatchError(ValueError): """raised when the filter info in a dataset config does not match the filter info in the dataset""" pass def _load_applied_filters( filters: list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]], ) -> list[dict[typing.Literal["name", "args", "kwargs"], str | tuple | dict]]: try: return [ dict( name=filter_info["name"], args=tuple( filter_info["args"], ), # muutils/zanj save tuples as lists, and this causes problems kwargs=dict(filter_info["kwargs"]), # type: ignore[arg-type] ) for filter_info in filters ] except Exception as e: err_msg: str = f"failed to load applied filters:\n{filters}" raise ValueError(err_msg) from e @serializable_dataclass(kw_only=True) class GPTDatasetConfig(SerializableDataclass): """base GPTDatasetConfig class""" name: str # TODO: get rid of all these things as part of migration to tokenizer-free dataset config # -------------------------------------------------- seq_len_min: int = serializable_field(default=1) seq_len_max: int = serializable_field(default=512) # -------------------------------------------------- seed: int | None = 
serializable_field(default=GLOBAL_SEED) applied_filters: list[ dict[typing.Literal["name", "args", "kwargs"], str | list | tuple | dict] ] = serializable_field( default_factory=list, deserialize_fn=_load_applied_filters, assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures ) def __post_init__(self) -> None: "post init, where we set a random seed if none is set" assert self.seq_len_min <= self.seq_len_max # if seed set to None, then generate a new random seed if self.seed is None: self.seed = np.random.randint(2**31) # TODO: something here is broken if self.seed != GLOBAL_SEED: warnings.warn( f"in GPTDatasetConfig {self.name=}, {self.seed=} is trying to override {GLOBAL_SEED = }", ) set_reproducibility(self.seed) def summary(self) -> dict: """return a summary of the config""" # do we run this to make sure it doesn't error? self_ser: dict = self.serialize() assert self_ser return dict( name=self.name, seq_len_min=self.seq_len_min, seq_len_max=self.seq_len_max, seed=self.seed, applied_filters=self.applied_filters, ) @property def _dataset_class(self) -> type: raise NotImplementedError("this should be implemented by subclasses!") def to_fname(self) -> str: """convert config to a filename""" self_json_str: str = json.dumps(self.serialize()) self_json_hash: int = int(abs(stable_hash(self_json_str)) % 1e10) warnings.warn( f"using fallblack to_fname() method for {self.__class__.__name__}, this should be implemented by subclasses!", ) return sanitize_fname( # TYPING: error: Argument 1 to "len" has incompatible type "GPTDatasetConfig"; expected "Sized" [arg-type] f"f{self.name}-n{shorten_numerical_to_str(len(self))}-h{self_json_hash}", # type: ignore[arg-type] ) def _dataset_config_load(*args, **kwargs) -> "GPTDatasetConfig": err_msg: str = f"this `load` function should be implemented by subclasses! got: {args=}, {kwargs=}" raise NotImplementedError( err_msg, ) # abstract function, hence we dont care that `self` is unused def _dataset_config_serialize(self, *args, **kwargs) -> JSONitem: # noqa: ANN001, ARG001 err_msg: str = f"this `serialize` function should be implemented by subclasses! got: {args=}, {kwargs=}" raise NotImplementedError( err_msg, ) GPTDatasetConfig.load = _dataset_config_load # type: ignore[method-assign] GPTDatasetConfig.serialize = _dataset_config_serialize # type: ignore[method-assign,assignment] T_DatasetConfig = TypeVar("T_DatasetConfig", bound=GPTDatasetConfig) class GPTDataset(typing.Generic[T_DatasetConfig]): """wrapper for torch dataset with some extra functionality (meaning the functionality should be inherited in downstream classes) > [!NOTE] > `GPTDatasetConfig` should implement a `to_fname` method that returns a unique filename for the config # Requires: the following methods should be implemented in subclasses: - `__init__(self, cfg: GPTDatasetConfig, **kwargs)` initialize the dataset from a given config. kwargs are not passed through, the kwargs should take the actual generated or loaded data (a list of objects or sequences probably) - `generate(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` generate the dataset from a given config. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. 
how many threads to use for generation) - `serialize(self) -> JSONitem` serialize the dataset to a ZANJ-serializable object, including: - config - data in formats specified by `self.save_formats` - `load(cls, data: JSONitem) -> GPTDataset` load the dataset from a ZANJ-serializable object - `download(cls, cfg: GPTDatasetConfig, **kwargs) -> GPTDataset` given a config, try to download a dataset from some source. kwargs are passed through from `from_config`, and should only contain things that dont belong in the config (i.e. some kind of auth token or source url) - `__len__(self) -> int` return the length of the dataset, required to match interface of `torch.utils.data.Dataset` - `__getitem__(self, i: int) -> list[str]` return the ith item in the dataset, required to match interface of `torch.utils.data.Dataset` - `update_self_config(self) -> None` update the config of the dataset to match the current state of the dataset, used primarily in filtering and validation - decorating the appropriate filter namespace with `register_filter_namespace_for_dataset(your_dataset_class)` if you want to use filters # Parameters: - `cfg : GPTDatasetConfig` config for the dataset, used to generate the dataset - `do_generate : bool` whether to generate the dataset if it isn't found (defaults to `True`) - `load_local : bool` whether to try finding the dataset locally (defaults to `True`) - `save_local : bool` whether to save the dataset locally if it is generated or downloaded (defaults to `True`) - `do_download : bool` whether to try downloading the dataset (defaults to `True`) - `local_base_path : Path` where to save the dataset (defaults to `Path("data/maze_dataset")`) # Returns: - `GPTDataset` the dataset, as you wanted it # Implements: - `save(self, file_path: str) -> None` save the dataset to a file, using ZANJ - `read(cls, file_path: str) -> GPTDataset` read the dataset from a file, using ZANJ get all items in the dataset, in the specified format - `filter_by(self)` returns a namespace class - `_filter_namespace(self) -> Class` returns a namespace class for filtering the dataset, checking that method - `_apply_filters_from_config(self) -> None` apply filters to the dataset, as specified in the config. used in `from_config()` but only when generating """ _FILTER_NAMESPACE: type = "this isn't a filter namespace! you have to initialize this by registering with `register_filter_namespace_for_dataset`" # type: ignore cfg: "T_DatasetConfig" @classmethod def from_config( # noqa: C901, PLR0912 cls: "type[T_Dataset]", cfg: "T_DatasetConfig", do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: ZANJ | None = None, do_download: bool = True, local_base_path: Path = Path("data/maze_dataset"), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs, ) -> "T_Dataset": """base class for gpt datasets priority of loading: 1. load from local 2. download 3. generate """ print_log: Callable = print if verbose else lambda *_a, **_kw: None local_base_path = Path(local_base_path) fname: Path = Path(f"{cfg.to_fname()}.zanj") output: T_Dataset | None = None did_load_local: bool = False if zanj is None: zanj = ZANJ() print_log(f"trying to get the dataset '{cfg.to_fname()}'") if not (load_local or do_download or do_generate): raise ValueError( "no way to load dataset! 
you said not to load local, not to download, and not to generate", ) dataset_path: Path = local_base_path / fname # try loading if load_local: # noqa: SIM102 if dataset_path.exists(): print_log(f"loading dataset from {dataset_path.as_posix()}") try: output = cls.read(dataset_path, zanj=zanj) did_load_local = True print_log("load successful!") except Exception as e: # noqa: BLE001 print_log(f"failed to load dataset: {e}") if do_download and output is None: print_log("seeing if we can download the dataset...") try: output = cls.download(cfg, **kwargs) print_log("download successful!") except NotImplementedError: print_log("no download found, or download failed") if do_generate and output is None: print_log("generating dataset...") output = cls.generate(cfg, verbose=verbose, **kwargs) # only if we generated it, apply filters output = output._apply_filters_from_config() # check and save if output is None: raise ValueError("failed to load dataset!") cfg_diff: dict = cfg.diff(output.cfg, of_serialized=True) if cfg_diff: if except_on_config_mismatch: if allow_generation_metadata_filter_mismatch and ( cfg_diff == { "applied_filters": { "self": [], "other": [ { "name": "collect_generation_meta", "args": (), "kwargs": {}, }, ], }, } ): pass else: err_msg: str = f"config mismatch: {cfg_diff = }" raise ValueError(err_msg) else: warnings.warn(f"config mismatch: {cfg_diff = }") if save_local and not did_load_local: print_log(f"saving dataset to {dataset_path}") output.save(dataset_path, zanj=zanj) print_log( f"Got dataset {output.cfg.name} with {len(output)} items. {output.cfg.to_fname() = }", ) return output def save(self, file_path: Path | str, zanj: ZANJ | None = None) -> None: "save dataset to a file with zanj" if zanj is None: zanj = ZANJ() zanj.save(self.serialize(), file_path) # serialization & loading @classmethod def read( cls: "type[T_Dataset]", file_path: str | Path, zanj: ZANJ | None = None ) -> "T_Dataset": "read dataset from a file with zanj" if zanj is None: zanj = ZANJ() return zanj.read(file_path) def serialize(self: "T_Dataset") -> JSONdict: "(implement in subclass!) serialize to something we can save with zanj" raise NotImplementedError def data_hash(self: "T_Dataset") -> int: "(implement in subclass!) return a hash of the data" raise NotImplementedError @classmethod def load(cls: "type[T_Dataset]", data: JSONdict) -> "T_Dataset": "(implement in subclass!) load a dataset from what we made with `.serialize()`" raise NotImplementedError # generating & downloading @classmethod def generate( cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs ) -> "T_Dataset": "(implement in subclass!) generative given the config" raise NotImplementedError @classmethod def download( cls: "type[T_Dataset]", cfg: "T_DatasetConfig", **kwargs ) -> "T_Dataset": "(implement in subclass!) download the dataset given the config" raise NotImplementedError # filtering def update_self_config(self) -> None: """(implement in subclass!) 
update the config of the dataset to match the actual data, if needed for example, adjust number of mazes after filtering """ pass def __len__(self) -> int: "return the length of the dataset" raise NotImplementedError("implement in subclass!") class FilterBy: """thanks GPT-4""" def __init__(self, dataset: "T_Dataset") -> None: "mock class so we can call `my_dataset.filter_by.some_registered_filter()`" self.dataset: T_Dataset = dataset def __getattr__(self, name: str) -> typing.Callable[..., "T_Dataset"]: "override getattr so we can call `my_dataset.filter_by.some_registered_filter()`" filter_func: DatasetFilterFunc = getattr( self.dataset._FILTER_NAMESPACE, name, ) def wrapped_filter_func(*args, **kwargs): # noqa: ANN202 return filter_func(self.dataset, *args, **kwargs) return wrapped_filter_func @property def filter_by(self) -> "FilterBy": "can call `my_dataset.filter_by.some_registered_filter()` to filter the dataset" return self.FilterBy(self) def _apply_filters_from_config(self: "T_Dataset") -> "T_Dataset": """apply filters to the dataset, as specified in the config. used in `from_config()`""" output: T_Dataset = self # copy the list, and then clear it in the config. we do this because each time we apply a filter it will update config.applied_filters applied_filters_old: list[ dict[typing.Literal["name", "args", "kwargs"], typing.Any] ] = self.cfg.applied_filters output.cfg.applied_filters = list() # apply the filters for filter_info in applied_filters_old: filter_name: str = filter_info["name"] if filter_name not in output._FILTER_NAMESPACE.__dict__: if filter_name.startswith("__custom__:"): err_msg = f"the dataset {output.cfg.to_fname()} was filtering using a custom filter: '{filter_name}', which we don't know about. add it to MazeDatasetFilters!" raise ValueError( err_msg, ) err_msg = f"the dataset {output.cfg.to_fname()} was filtering using an unknown filter: '{filter_name}'" raise ValueError( err_msg, ) filter_args: list = filter_info.get("args", list()) filter_kwargs: dict = filter_info.get("kwargs", dict()) output = getattr(output.filter_by, filter_name)( *filter_args, **filter_kwargs, ) # update the config, perform checks # TODO: some funny business with manually specified filters here? 
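# NOTE (illustration only, mirroring the record shape appended by `register_dataset_filter`):
# a dataset filtered with `dataset.filter_by.cut_percentile_shortest(percentile=10.0)`
# stores the record
#     {"name": "cut_percentile_shortest", "args": (), "kwargs": {"percentile": 10.0}}
# in `cfg.applied_filters`; `_check_filter_equality()` below asserts that re-applying
# the filters reproduced exactly this list of records.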
output.update_self_config() _check_filter_equality( filters_old=applied_filters_old, filters_new=output.cfg.applied_filters, # type: ignore[arg-type] ) return output def _check_filter_equality( filters_old: list[ dict[typing.Literal["name", "args", "kwargs"], str | list | dict] ], filters_new: list[ dict[typing.Literal["name", "args", "kwargs"], str | list | dict] ], ) -> None: try: assert len(filters_old) == len(filters_new) for filterinfo_new, filterinfo_old in zip( filters_old, filters_new, strict=False, ): # basic checks assert isinstance(filterinfo_new, dict), "filterinfo_new is not a dict" assert isinstance(filterinfo_old, dict), "filterinfo_old is not a dict" assert all(key in filterinfo_new for key in ["name", "args", "kwargs"]), ( "missing keys in filterinfo_new" ) assert all(key in filterinfo_old for key in ["name", "args", "kwargs"]), ( "missing keys in filterinfo_old" ) # name assert filterinfo_new["name"] == filterinfo_old["name"], ( "filter names don't match" ) # args assert len(filterinfo_new["args"]) == len(filterinfo_old["args"]), ( "filter args of different lengths" ) for arg_new, arg_old in zip( filterinfo_new["args"], filterinfo_old["args"], strict=False, ): assert arg_new == arg_old, "filter args don't match" # kwargs assert len(filterinfo_new["kwargs"]) == len(filterinfo_old["kwargs"]), ( "filter kwargs of different lengths" ) for key in filterinfo_old["kwargs"]: assert key in filterinfo_new["kwargs"], ( f"filter kwargs don't match: missing key '{key}'" ) assert filterinfo_new["kwargs"][key] == filterinfo_old["kwargs"][key], ( # type: ignore[index] f"filter kwargs don't match: values for key '{key}' don't match" ) except AssertionError as e: err_msg: str = ( f"config mismatch in applied filters: {filters_new} != {filters_old}" ) raise FilterInfoMismatchError( err_msg, ) from e def register_filter_namespace_for_dataset( dataset_cls: Type[GPTDataset], ) -> Callable[[Type], Type]: """register the namespace class with the given dataset class""" def decorator(filter_namespace_cls: Type) -> Type: dataset_cls._FILTER_NAMESPACE = filter_namespace_cls filter_namespace_cls._BASE_DATASET = dataset_cls return filter_namespace_cls return decorator T_Dataset = TypeVar("T_Dataset", bound=GPTDataset) P_FilterKwargs = typing.ParamSpec("P_FilterKwargs") DatasetFilterFunc = Callable[typing.Concatenate[T_Dataset, P_FilterKwargs], T_Dataset] def register_dataset_filter( method: DatasetFilterFunc, ) -> DatasetFilterFunc: """register a dataset filter, copying the underlying dataset and updating the config be sure to return a COPY, not the original? # TODO: what the heck do we mean by the above? why the question mark? it should be a copy right? 
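# NOTE: illustrative sketch only -- `MyDataset`, `MyDatasetFilters`, and `first_n` are
# hypothetical names mirroring the `MazeDataset`/`MazeDatasetFilters` pair in
# `maze_dataset.dataset.filters`, not part of this module:
#
#     @register_filter_namespace_for_dataset(MyDataset)
#     class MyDatasetFilters:
#         @register_dataset_filter
#         @staticmethod
#         def first_n(dataset: MyDataset, n: int) -> MyDataset:
#             # return a new dataset object, never mutate the one passed in
#             return MyDataset(cfg=copy.deepcopy(dataset.cfg), mazes=dataset.mazes[:n])
#
#     truncated = my_dataset.filter_by.first_n(100)
#
# the wrapper below then appends {"name": "first_n", "args": (100,), "kwargs": {}} to
# `truncated.cfg.applied_filters` and calls `update_self_config()`.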
method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset` """ @functools.wraps(method) def wrapper( # TYPING: error: ParamSpec "P_FilterKwargs" is unbound [valid-type] dataset: T_Dataset, *args: P_FilterKwargs.args, # type: ignore[valid-type] **kwargs: P_FilterKwargs.kwargs, # type: ignore[valid-type] ) -> T_Dataset: new_dataset = method(dataset, *args, **kwargs) # update the config new_dataset.cfg.applied_filters.append( dict(name=method.__name__, args=args, kwargs=kwargs), # type: ignore[attr-defined] ) new_dataset.update_self_config() return new_dataset # TYPING: error: Incompatible return value type (got "_Wrapped[[Any, KwArg(Any)], Any, [Never, VarArg(Any), KwArg(Any)], Never]", expected "DatasetFilterProtocol[Any]") [return-value] return wrapper # type: ignore[return-value] ``````{ end_of_file="maze_dataset/dataset/dataset.py" } ``````{ path="maze_dataset/dataset/filters.py" } "filtering `MazeDataset`s" import copy import functools import typing from collections import Counter, defaultdict import numpy as np from maze_dataset.constants import CoordTup from maze_dataset.dataset.dataset import ( DatasetFilterFunc, register_dataset_filter, register_filter_namespace_for_dataset, ) from maze_dataset.dataset.maze_dataset import MazeDataset from maze_dataset.maze import SolvedMaze def register_maze_filter( method: typing.Callable[[SolvedMaze, typing.Any], bool], ) -> DatasetFilterFunc: """register a maze filter, casting it to operate over the whole list of mazes method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset` this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays """ @functools.wraps(method) def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset: # copy and filter new_dataset: MazeDataset = copy.deepcopy( MazeDataset( cfg=dataset.cfg, mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)], ), ) # update the config new_dataset.cfg.applied_filters.append( dict(name=method.__name__, args=args, kwargs=kwargs), ) new_dataset.update_self_config() return new_dataset return wrapper @register_filter_namespace_for_dataset(MazeDataset) class MazeDatasetFilters: "namespace for filters for `MazeDataset`s" @register_maze_filter @staticmethod def path_length(maze: SolvedMaze, min_length: int) -> bool: """filter out mazes with a solution length less than `min_length`""" return len(maze.solution) >= min_length @register_maze_filter @staticmethod def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool: """filter out datasets where the start and end pos are less than `min_distance` apart on the manhattan distance (ignoring walls)""" return bool( (np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all() ) @register_dataset_filter @staticmethod def cut_percentile_shortest( dataset: MazeDataset, percentile: float = 10.0, ) -> MazeDataset: """cut the shortest `percentile` of mazes from the dataset `percentile` is 1-100, not 0-1, as this is what `np.percentile` expects """ lengths: np.ndarray = np.array([len(m.solution) for m in dataset]) cutoff: int = int(np.percentile(lengths, percentile)) filtered_mazes: list[SolvedMaze] = [ m for m in dataset if len(m.solution) > cutoff ] new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes) return copy.deepcopy(new_dataset) @register_dataset_filter @staticmethod def truncate_count( dataset: MazeDataset, max_count: int, ) 
-> MazeDataset: """truncate the dataset to be at most `max_count` mazes""" new_dataset: MazeDataset = MazeDataset( cfg=dataset.cfg, mazes=dataset.mazes[:max_count], ) return copy.deepcopy(new_dataset) @register_dataset_filter @staticmethod def remove_duplicates( dataset: MazeDataset, minimum_difference_connection_list: int | None = 1, minimum_difference_solution: int | None = 1, _max_dataset_len_threshold: int = 1000, ) -> MazeDataset: """remove duplicates from a dataset, keeping the **LAST** unique maze set minimum either minimum difference to `None` to disable checking if you want to avoid mazes which have more overlap, set the minimum difference to be greater Gotchas: - if two mazes are of different sizes, they will never be considered duplicates - if two solutions are of different lengths, they will never be considered duplicates TODO: check for overlap? """ if len(dataset) > _max_dataset_len_threshold: raise ValueError( "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n", "if you know what you're doing, change `_max_dataset_len_threshold`", ) unique_mazes: list[SolvedMaze] = list() maze_a: SolvedMaze maze_b: SolvedMaze for i, maze_a in enumerate(dataset.mazes): a_unique: bool = True for maze_b in dataset.mazes[i + 1 :]: # after all that nesting, more nesting to perform checks if (minimum_difference_connection_list is not None) and ( # noqa: SIM102 maze_a.connection_list.shape == maze_b.connection_list.shape ): if ( np.sum(maze_a.connection_list != maze_b.connection_list) <= minimum_difference_connection_list ): a_unique = False break if (minimum_difference_solution is not None) and ( # noqa: SIM102 maze_a.solution.shape == maze_b.solution.shape ): if ( np.sum(maze_a.solution != maze_b.solution) <= minimum_difference_solution ): a_unique = False break if a_unique: unique_mazes.append(maze_a) return copy.deepcopy( MazeDataset( cfg=dataset.cfg, mazes=unique_mazes, generation_metadata_collected=dataset.generation_metadata_collected, ), ) @register_dataset_filter @staticmethod def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset: """remove duplicates from a dataset""" unique_mazes = list(dict.fromkeys(dataset.mazes)) return copy.deepcopy( MazeDataset( cfg=dataset.cfg, mazes=unique_mazes, generation_metadata_collected=dataset.generation_metadata_collected, ), ) @register_dataset_filter @staticmethod def strip_generation_meta(dataset: MazeDataset) -> MazeDataset: """strip the generation meta from the dataset""" new_dataset: MazeDataset = copy.deepcopy(dataset) for maze in new_dataset: # hacky because it's a frozen dataclass maze.__dict__["generation_meta"] = None return new_dataset @register_dataset_filter @staticmethod # yes, this function is complicated hence the noqa def collect_generation_meta( # noqa: C901, PLR0912 dataset: MazeDataset, clear_in_mazes: bool = True, inplace: bool = True, allow_fail: bool = False, ) -> MazeDataset: """collect the generation metadata from each maze into a dataset-level metadata (saves space) # Parameters: - `dataset : MazeDataset` - `clear_in_mazes : bool` whether to clear the generation meta in the mazes after collecting it, keep it there if `False` (defaults to `True`) - `inplace : bool` whether to modify the dataset in place or return a new one (defaults to `True`) - `allow_fail : bool` whether to allow the collection to fail if the generation meta is not present in a maze (defaults to `False`) # Returns: - `MazeDataset` the dataset with the generation metadata collected # Raises: - 
`ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False` - `ValueError` : if we have other problems converting the generation metadata - `TypeError` : if the generation meta on a maze is of an unexpected type """ if dataset.generation_metadata_collected is not None: return dataset else: assert dataset[0].generation_meta is not None, ( "generation meta is not collected and original is not present" ) # if the generation meta is already collected, don't collect it again, do nothing new_dataset: MazeDataset if inplace: new_dataset = dataset else: new_dataset = copy.deepcopy(dataset) gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = ( defaultdict(Counter) ) for maze in new_dataset: if maze.generation_meta is None: if allow_fail: break raise ValueError( "generation meta is not present in a maze, cannot collect generation meta", ) for key, value in maze.generation_meta.items(): if isinstance(value, (bool, int, float, str)): # noqa: UP038 gen_meta_lists[key][value] += 1 elif isinstance(value, set): # special case for visited_cells gen_meta_lists[key].update(value) elif isinstance(value, (list, np.ndarray)): # noqa: UP038 if isinstance(value, list): # TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901) try: value = np.array(value) # noqa: PLW2901 except ValueError as convert_to_np_err: err_msg = ( f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'" "\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords" ) raise ValueError(err_msg) from convert_to_np_err if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim): # assume its a single coordinate gen_meta_lists[key][tuple(value)] += 1 # magic value is fine here elif (len(value.shape) == 2) and ( # noqa: PLR2004 value.shape[1] == maze.lattice_dim ): # assume its a list of coordinates gen_meta_lists[key].update([tuple(v) for v in value]) else: err_msg = ( f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n" "expected either a coord of shape (2,) or a list of coords of shape (n, 2)" ) raise ValueError(err_msg) else: err_msg = ( f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n" "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords" ) raise TypeError(err_msg) # clear the data if clear_in_mazes: # hacky because it's a frozen dataclass maze.__dict__["generation_meta"] = None new_dataset.generation_metadata_collected = { key: dict(value) for key, value in gen_meta_lists.items() } return new_dataset ``````{ end_of_file="maze_dataset/dataset/filters.py" } ``````{ path="maze_dataset/dataset/maze_dataset.py" } """`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset. 
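# NOTE: minimal usage sketch of the pipeline described in this module; the config values
# below are arbitrary examples, not recommendations
def _example_small_dataset() -> "MazeDataset":
	"""build a tiny local dataset, then keep only mazes with a reasonably long solution"""
	from maze_dataset import MazeDataset, MazeDatasetConfig

	cfg: MazeDatasetConfig = MazeDatasetConfig(
		name="example",  # anything filesystem-safe
		grid_n=5,  # 5x5 maze grid
		n_mazes=16,  # number of mazes to request
	)
	# loads from `data/maze_dataset/` if a matching file exists, otherwise generates (and saves)
	dataset: MazeDataset = MazeDataset.from_config(cfg)
	# filters are provided by the registered `MazeDatasetFilters` namespace
	return dataset.filter_by.path_length(min_length=3)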
see [demo_dataset notebook](../../notebooks/demo_dataset) """ import copy import json import multiprocessing import typing import warnings from pathlib import Path from typing import Literal, Optional, cast, overload import numpy as np import tqdm from jaxtyping import Int from muutils.json_serialize import ( json_serialize, ) from muutils.json_serialize.util import ( _FORMAT_KEY, JSONdict, ) from muutils.misc import stable_hash from zanj import ZANJ from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler from maze_dataset.constants import CoordArray from maze_dataset.dataset.dataset import ( GPTDataset, ) from maze_dataset.dataset.maze_dataset_config import ( SERIALIZE_MINIMAL_THRESHOLD, EndpointKwargsType, MazeDatasetConfig, ) from maze_dataset.generation.seed import GLOBAL_SEED from maze_dataset.maze import LatticeMaze, SolvedMaze _GLOBAL_WORKER_CONFIG: MazeDatasetConfig def _generate_maze_helper(index: int) -> Optional[SolvedMaze]: # noqa: ARG001 """Helper function for generating mazes in parallel. > [!CAUTION] > don't use this unless generating in parallel! """ global _GLOBAL_WORKER_CONFIG # noqa: PLW0602 # TODO: don't use this unless generating in parallel! maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor( grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np, **_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs, ) endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy() # Generate the solution # mypy doesnt realize EndpointKwargsType has only string keys: `Keywords must be strings [misc]` # TYPING: error: No overload variant of "generate_random_path" of "LatticeMaze" matches argument type "dict[Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | list[tuple[int, int]] | None]" [call-overload] solution: Optional[CoordArray] = maze.generate_random_path(**endpoint_kwargs) # type: ignore[misc, call-overload] # Validate the solution if ( solution is None or len(solution) == 0 or not isinstance(solution, np.ndarray) # magic value is fine here or len(solution.shape) != 2 # noqa: PLR2004 ): return None # Return None if the solution is invalid return SolvedMaze.from_lattice_maze( lattice_maze=maze, solution=solution, ) def _maze_gen_init_worker(config: MazeDatasetConfig) -> None: """special worker helper > [!CAUTION] > this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad! """ # TODO: dont use globals here! global _GLOBAL_WORKER_CONFIG # noqa: PLW0603 _GLOBAL_WORKER_CONFIG = config process_id: tuple[int, ...] = multiprocessing.current_process()._identity if len(process_id) == 0: # no multiprocessing, seed was already set pass elif len(process_id) == 1: # multiprocessing, adjust seed based on process id # only set numpy seed, since we do not use other random gens np.random.seed( _GLOBAL_WORKER_CONFIG.seed or GLOBAL_SEED # if the seed is None, use the global seed + process_id[0] ) else: err_msg = ( f"unexpected process id: {process_id = }\n{multiprocessing.Process() = }" ) raise ValueError( err_msg, ) class MazeDataset(GPTDataset[MazeDatasetConfig]): """a maze dataset class. 
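# NOTE (usage illustration only): generation-only keyword arguments pass straight through
# `MazeDataset.from_config` to `MazeDataset.generate`, e.g.
#
#     dataset = MazeDataset.from_config(cfg, gen_parallel=True, pool_kwargs=dict(processes=4))
#
# with `gen_parallel=True` a `multiprocessing.Pool` is used with `_maze_gen_init_worker`
# (above) as its initializer, so per-worker seeding depends on the process id -- see the
# caution note in that helper.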
This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`""" def __init__( self, cfg: MazeDatasetConfig, mazes: typing.Sequence[SolvedMaze], generation_metadata_collected: dict | None = None, ) -> None: """initialize a maze dataset from a config and a list of solved mazes""" super().__init__() self.cfg: MazeDatasetConfig = cfg self.mazes: list[SolvedMaze] = list(mazes) self.generation_metadata_collected: dict | None = generation_metadata_collected # TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset" [override] @classmethod def from_config( # type: ignore[override] cls, # TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig" [override] cfg: MazeDatasetConfig, # type: ignore[override] do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: ZANJ | None = None, do_download: bool = True, local_base_path: Path = Path("data/maze_dataset"), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs, ) -> "MazeDataset": """create a maze dataset from a config priority of loading: 1. load from local 2. download 3. generate """ return cast( MazeDataset, super().from_config( cfg=cfg, do_generate=do_generate, load_local=load_local, save_local=save_local, zanj=zanj, do_download=do_download, local_base_path=local_base_path, except_on_config_mismatch=except_on_config_mismatch, allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, verbose=verbose, **kwargs, ), ) def data_hash(self) -> int: """return a hash of the data""" return stable_hash(str(tuple([x.serialize() for x in self.mazes]))) def __getitem__(self, i: int) -> SolvedMaze: """get a maze by index""" return self.mazes[i] def __iter__(self) -> typing.Iterator[SolvedMaze]: """iterate over the mazes""" return iter(self.mazes) def __deepcopy__(self, memo) -> "MazeDataset": # noqa: ANN001 """deepcopy the dataset FIX: this isnt actually a deepcopy I think? """ return MazeDataset.load(self._serialize_full()) # TYPING: get type hints on the tokenizer here @overload def as_tokens( self, maze_tokenizer, # noqa: ANN001 limit: int | None = None, join_tokens_individual_maze: Literal[False] = False, ) -> list[list[str]]: ... @overload def as_tokens( self, maze_tokenizer, # noqa: ANN001 limit: int | None = None, join_tokens_individual_maze: Literal[True] = True, ) -> list[str]: ... def as_tokens( self, maze_tokenizer, # TODO: MazeTokenizer limit: int | None = None, join_tokens_individual_maze: bool = False, ) -> list[list[str]] | list[str]: """return the dataset as tokens according to the passed `maze_tokenizer` the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular` if `join_tokens_individual_maze` is True, then the tokens of each maze are joined with a space, and the result is a list of strings. 
i.e.: >>> dataset.as_tokens(join_tokens_individual_maze=False) [["a", "b", "c"], ["d", "e", "f"]] >>> dataset.as_tokens(join_tokens_individual_maze=True) ["a b c", "d e f"] """ output: list[list[str]] = [ maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit] ] if join_tokens_individual_maze: return [" ".join(tokens) for tokens in output] else: return output def __len__(self) -> int: """return the number of mazes in the dataset""" return len(self.mazes) def __eq__(self, other: object) -> bool: """compare two datasets""" if not isinstance(other, MazeDataset): raise NotImplementedError( "can only compare with other MazeDataset objects", ) # TODO: compare hashes of data instead of the data itself? return self.cfg == other.cfg and self.mazes == other.mazes def assert_equal(self, other: "MazeDataset") -> None: """assert that two datasets are equal""" assert isinstance(other, MazeDataset) assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }" assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }" @classmethod def generate( cls, cfg: MazeDatasetConfig, gen_parallel: bool = False, pool_kwargs: dict | None = None, verbose: bool = False, # TODO: what to do when unexpected kwargs are passed? **kwargs, # noqa: ARG003 ) -> "MazeDataset": """Generate a maze dataset given a config and some generation parameters""" # Copy the config to avoid modifying the original cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load( json.loads(json.dumps(cfg.serialize())), ) if pool_kwargs is None: pool_kwargs = dict() maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes) # type: ignore[assignment] solved_mazes: list[SolvedMaze | None] # Configure tqdm for progress bar tqdm_kwargs: dict = dict( total=cfg_cpy.n_mazes, unit="maze", desc="generating & solving mazes", disable=not verbose, ) # TODO: don't use the global unless generating in parallel! if gen_parallel: with multiprocessing.Pool( **pool_kwargs, initializer=_maze_gen_init_worker, initargs=(cfg_cpy,), ) as pool: solved_mazes = list( tqdm.tqdm( pool.imap(_generate_maze_helper, maze_indexes), **tqdm_kwargs, ), ) else: _maze_gen_init_worker(cfg_cpy) solved_mazes = list( tqdm.tqdm( map( # TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]" [arg-type] # why does it think tolist() returns a string? _generate_maze_helper, # type: ignore[arg-type] maze_indexes.tolist(), ), **tqdm_kwargs, ), ) # Filter out None values explicitly after ensuring all results are collected solved_mazes_: list[SolvedMaze] = [ maze for maze in solved_mazes if maze is not None ] # solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes)) # Update the config with the actual number of mazes cfg_cpy.n_mazes = len(solved_mazes_) dataset: MazeDataset = cls( cfg=cfg_cpy, mazes=solved_mazes_, ) dataset.update_self_config() # Call `update_self_config()` to ensure the dataset's config reflects changes np.random.seed(cfg_cpy.seed) # Reset the seed to the value in the config copy return dataset @classmethod def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset": "(not implemented yet!) 
download a maze dataset from the internet" raise NotImplementedError("not implemented yet") @classmethod def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset": """load from zanj/json""" if data[_FORMAT_KEY] == "MazeDataset:minimal": return cls._load_minimal(data) elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat": return cls._load_minimal_soln_cat(data) elif data[_FORMAT_KEY] == "MazeDataset": if ( SERIALIZE_MINIMAL_THRESHOLD == -1 ): # Allow access to `_load_legacy` for profiling return cls._load_legacy(data) return cls._load_full(data) else: err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })" raise KeyError( err_msg, ) @classmethod def _load_full(cls, data: JSONdict) -> "MazeDataset": assert data[_FORMAT_KEY] == "MazeDataset" return cls( cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] mazes=load_item_recursive(data["mazes"], tuple()), generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] ) @classmethod def _load_minimal(cls, data: JSONdict) -> "MazeDataset": assert data[_FORMAT_KEY] == "MazeDataset:minimal" return cls( cfg=MazeDatasetConfig.load(data["cfg"]), # type: ignore[arg-type] generation_metadata_collected=data["generation_metadata_collected"], # type: ignore[arg-type] mazes=[ SolvedMaze( clist, soln[:slen, ...], ) for clist, slen, soln in zip( load_item_recursive(data["maze_connection_lists"], tuple()), load_item_recursive(data["maze_solution_lengths"], tuple()), load_item_recursive(data["maze_solutions"], tuple()), strict=False, # load_item_recursive(data["maze_endpoints"], tuple()), ) ], ) @classmethod def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset": assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat" maze_solution_lengths = load_item_recursive( data["maze_solution_lengths"], tuple(), ) maze_solutions_concat = load_item_recursive( data["maze_solutions_concat"], tuple(), ) maze_solutions = np.split( maze_solutions_concat, np.cumsum(maze_solution_lengths)[:-1], axis=0, ) return cls( cfg=load_item_recursive(data["cfg"], tuple()), generation_metadata_collected=load_item_recursive( data["generation_metadata_collected"], tuple(), ), mazes=[ SolvedMaze( connection_list=clist, solution=soln, ) for clist, soln in zip( load_item_recursive(data["maze_connection_lists"], tuple()), # load_item_recursive(data["maze_endpoints"], tuple()), maze_solutions, strict=False, ) ], ) @classmethod def _load_legacy(cls, data: JSONdict) -> "MazeDataset": """Legacy `load` method from <0.5.2. 
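# NOTE (format illustration, derived from `_serialize_minimal()` below): the
# "MazeDataset:minimal" payload consumed by `_load_minimal` stores, for `n_mazes` mazes on
# an `n x n` grid whose longest solution has length `L`:
#     maze_connection_lists : bool array of shape (n_mazes, 2, n, n)
#     maze_solution_lengths : int32 array of shape (n_mazes,)
#     maze_solutions        : int8 array of shape (n_mazes, L, 2), where only the first
#                             `maze_solution_lengths[i]` rows of entry `i` are meaningful
# each maze is then rebuilt as `SolvedMaze(clist, soln[:slen, ...])`.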
Used exclusively for profiling comparison.""" assert data[_FORMAT_KEY] == "MazeDataset" return cls( **{ key: load_item_recursive(data[key], tuple()) for key in ["cfg", "mazes", "generation_metadata_collected"] }, ) def serialize(self) -> JSONdict: """serialize to zanj/json""" if ( SERIALIZE_MINIMAL_THRESHOLD is not None and len(self) >= SERIALIZE_MINIMAL_THRESHOLD ): return self._serialize_minimal() return self._serialize_full() def _serialize_full(self) -> JSONdict: return { _FORMAT_KEY: "MazeDataset", "cfg": json_serialize(self.cfg), "fname": self.cfg.to_fname(), "mazes": json_serialize(self.mazes), "generation_metadata_collected": json_serialize( self.generation_metadata_collected, ), } def _serialize_minimal(self) -> JSONdict: "alternate serialization where metadata is collected and mazes are stored in concatenated form" filtered_meta: MazeDataset if self.generation_metadata_collected is None: filtered_meta = self.filter_by.collect_generation_meta() else: filtered_meta = self max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes) n_mazes: int = len(filtered_meta.mazes) grid_n: int = filtered_meta.cfg.grid_n maze_connection_lists: np.ndarray = np.empty( (n_mazes, 2, grid_n, grid_n), dtype=np.bool_, ) # maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32) maze_solutions: np.ndarray = np.empty( (n_mazes, max_solution_len, 2), dtype=np.int8, ) for idx, maze in enumerate(filtered_meta.mazes): maze_connection_lists[idx] = maze.connection_list # maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) maze_solution_lengths[idx] = maze.solution.shape[0] maze_solutions[idx, : maze.solution.shape[0]] = maze.solution return { _FORMAT_KEY: "MazeDataset:minimal", "cfg": json_serialize(filtered_meta.cfg), "fname": filtered_meta.cfg.to_fname(), "generation_metadata_collected": json_serialize( filtered_meta.generation_metadata_collected, ), "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] # "maze_endpoints": maze_endpoints, "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] "maze_solutions": maze_solutions, # type: ignore[dict-item] } def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict: "alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form" filtered_meta: MazeDataset if self.generation_metadata_collected is None: filtered_meta = self.filter_by.collect_generation_meta() else: filtered_meta = self maze_solution_lengths: np.ndarray = np.array( [m.solution.shape[0] for m in filtered_meta.mazes], dtype=np.int32, ) n_mazes: int = len(filtered_meta.mazes) grid_n: int = filtered_meta.cfg.grid_n total_solution_len: int = np.sum(maze_solution_lengths) maze_connection_lists: np.ndarray = np.empty( (n_mazes, 2, grid_n, grid_n), dtype=np.bool_, ) maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8) maze_solutions_concat: np.ndarray = np.empty( (total_solution_len, 2), dtype=np.int8, ) solutions_running_idx: int = 0 for idx, maze in enumerate(filtered_meta.mazes): maze_connection_lists[idx] = maze.connection_list maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos]) soln_len: int = maze.solution.shape[0] maze_solution_lengths[idx] = soln_len maze_solutions_concat[ solutions_running_idx : solutions_running_idx + soln_len ] = maze.solution solutions_running_idx += soln_len return { _FORMAT_KEY: "MazeDataset:minimal_soln_cat", "cfg": 
json_serialize(filtered_meta.cfg), "fname": filtered_meta.cfg.to_fname(), "generation_metadata_collected": json_serialize( filtered_meta.generation_metadata_collected, ), "maze_connection_lists": maze_connection_lists, # type: ignore[dict-item] "maze_endpoints": maze_endpoints, # type: ignore[dict-item] "maze_solution_lengths": maze_solution_lengths, # type: ignore[dict-item] "maze_solutions_concat": maze_solutions_concat, # type: ignore[dict-item] } def update_self_config(self) -> None: """update the config to match the current state of the dataset (number of mazes, such as after filtering)""" if self.cfg.n_mazes != len(self.mazes): warnings.warn( f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}", ) self.cfg.n_mazes = len(self.mazes) def custom_maze_filter( self, method: typing.Callable[[SolvedMaze], bool], **kwargs, ) -> "MazeDataset": """filter the dataset using a custom method""" output: MazeDataset = MazeDataset( cfg=copy.deepcopy(self.cfg), mazes=[m for m in self.mazes if method(m, **kwargs)], ) output.cfg.applied_filters.append( { "name": f"__custom__:{method.__name__}", "kwargs": kwargs, }, ) output.update_self_config() return output MazeDatasetConfig._dataset_class = property( # type: ignore[method-assign, assignment] lambda self: MazeDataset, # noqa: ARG005 ) # register things with zanj register_loader_handler( LoaderHandler( check=lambda json_item, path=None, z=None: ( # type: ignore[misc] # noqa: ARG005 isinstance(json_item, typing.Mapping) and _FORMAT_KEY in json_item and json_item[_FORMAT_KEY].startswith("MazeDataset") ), load=lambda json_item, path=None, z=None: MazeDataset.load(json_item), # type: ignore[misc] # noqa: ARG005 uid="MazeDataset", source_pckg="maze_dataset.generation.maze_dataset", desc="MazeDataset", ), ) # TODO: the code below is for doing some smarter collecting and type checking. Probably will delete. """ collect either the type at the field, or the shape of the field if it is an array metadata_types: dict[str, set[type, tuple]] = dict() for maze in new_dataset: for key, value in maze.generation_meta.items(): if key not in metadata_types: metadata_types[key] = set() if isinstance(value, np.ndarray): metadata_types[key].add(value.shape) else: metadata_types[key].add(type(value)) # figure out what to do for each field metadata_actions: dict[str, typing.Callable] = dict() for key, key_type in metadata_types.items(): if all(isinstance(kt, tuple) for kt in key_type): if all(kt == (2,) for kt in key_type): # its all coords, do a statcounter on those coords metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals) elif all( (len(kt) == 2) and (kt[1] == 2) for kt in key_type ): # its a list of coords, do a statcounter on those coords metadata_actions[key] = lambda vals: Counter( tuple(x) for x in np.concatenate(vals) ) else: # its a list of something else, do a counter on those # TODO: throw except here? metadata_actions[key] = Counter elif all(kt in (bool, int, float) for kt in key_type): # statcounter for numeric types metadata_actions[key] = StatCounter elif all(kt == str for kt in key_type): # counter for string types metadata_actions[key] = Counter else: # counter for everything else # TODO: throw except here? 
metadata_actions[key] = Counter """ ``````{ end_of_file="maze_dataset/dataset/maze_dataset.py" } ``````{ path="maze_dataset/dataset/maze_dataset_config.py" } "implements `MazeDatasetConfig` which is used to generate or load a dataset" import hashlib import importlib.metadata import json import typing import warnings from typing import Callable import numpy as np from jaxtyping import Float from muutils.json_serialize import ( serializable_dataclass, serializable_field, ) from muutils.json_serialize.util import ( safe_getsource, string_as_lines, ) from muutils.misc import sanitize_fname, shorten_numerical_to_str from maze_dataset.constants import Coord, CoordTup from maze_dataset.dataset.dataset import ( GPTDatasetConfig, ) from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 """If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. Setting to None means that `serialize_minimal` will never be used. Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.""" MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5 "length of the has, in characters, of the hash in the fname of a `MazeDatasetConfig`" _PercolationSuccessArray = Float[ np.ndarray, "p/grid_n/deadends/endpoints_not_equal/generator_func=5", ] class NoPercolationInConfigError(ValueError): """raised when trying to predict the success fraction of a config that doesn't have percolation""" pass class SuccessChanceTooSmallError(ValueError): """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`""" pass def set_serialize_minimal_threshold(threshold: int | None) -> None: "get the global SERIALIZE_MINIMAL_THRESHOLD" global SERIALIZE_MINIMAL_THRESHOLD # noqa: PLW0603 SERIALIZE_MINIMAL_THRESHOLD = threshold def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable: "get the maze constructor from `GENERATORS_MAP`" if isinstance(maze_ctor_serialized, dict): # this is both the new and old version of the serialization return GENERATORS_MAP[maze_ctor_serialized["__name__"]] elif isinstance(maze_ctor_serialized, str): # this is a version I switched to for a while but now we are switching back warnings.warn( "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: " "https://github.com/understanding-search/maze-dataset/issues/new", ) return GENERATORS_MAP[maze_ctor_serialized] else: err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }" raise TypeError(err_msg) EndpointKwargsType = dict[ typing.Literal[ "allowed_start", "allowed_end", "deadend_start", "deadend_end", "endpoints_not_equal", "except_on_no_valid_endpoint", ], bool | None | list[tuple[int, int]], ] """type hint for `MazeDatasetConfig.endpoint_kwargs` - `except_on_no_valid_endpoint : bool` (default: `True`) some of the conditions (dead ends if a maze is very open, no path between given start and end) can cause the maze generation to fail. if `except_on_no_valid_endpoint` is `True`, then the maze generation will raise an error if it fails to generate a valid maze. however, if `False`, then the maze generation will return a dataset with fewer mazes than requested. 
If you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()` this uses a pysr-created function to roughly estimate the success fraction of the dataset. - `allowed_start : list[tuple[int, int]]` (default: `None`) list of allowed starting position coordinates - `allowed_end : list[tuple[int, int]]` (default: `None`) list of allowed ending position coordinates - `deadend_start : bool` (default: `False`) if `True`, the starting position must be a dead end - `deadend_end : bool` (default: `False`) if `True`, the ending position must be a dead end - `endpoints_not_equal : bool` (default: `True`) if `True`, the starting and ending positions must be different """ def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType: if data.get("endpoint_kwargs") is None: return dict() else: return { k: ( # bools and Nones are fine v if (isinstance(v, bool) or v is None) # assume its a CoordList else [tuple(x) for x in v] # muutils/zanj saves tuples as lists ) for k, v in data["endpoint_kwargs"].items() } # not private because we need this to show up in docs @serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"]) class MazeDatasetConfig_base(GPTDatasetConfig): # noqa: N801 """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here""" # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only grid_n: int = serializable_field() # type: ignore[misc] # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters n_mazes: int = serializable_field(compare=False) # type: ignore[misc] maze_ctor: Callable = serializable_field( default=GENERATORS_MAP["gen_dfs"], serialization_fn=lambda gen_func: { "__name__": gen_func.__name__, "__module__": gen_func.__module__, # NOTE: this was causing hashing issues on 3.13 vs older versions because somehow, # the `__doc__` variable is different across versions??????? WHY???????? IT TREATS WHITESPACE DIFFERENTLY # so we just uh. strip it all now. # see: # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53 # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53 # https://www.diffchecker.com/tqIMSevy/ # update: we also need to filter for empty lines. B) "__doc__": [ line.strip() for line in string_as_lines(gen_func.__doc__) if line.strip() ], "source_code": safe_getsource(gen_func), }, loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]), assert_type=False, # TODO: check the type here once muutils supports checking Callable signatures ) maze_ctor_kwargs: dict = serializable_field( default_factory=dict, serialization_fn=lambda kwargs: kwargs, loading_fn=lambda data: ( dict() if data.get("maze_ctor_kwargs", None) is None # this should handle the backwards compatibility else data["maze_ctor_kwargs"] ), ) endpoint_kwargs: EndpointKwargsType = serializable_field( default_factory=dict, serialization_fn=lambda kwargs: kwargs, loading_fn=_load_endpoint_kwargs, assert_type=False, ) # NOTE: this part is very hacky. 
the way muutils works is that it iterates over the *keys in the serialized data*, # and so we need to save an `None` here or this wont load the `fname` field on load # this is a total mess, and very confusing, and entirely my fault _fname_loaded: str | None = serializable_field( default=None, compare=False, serialization_fn=lambda _: None, loading_fn=lambda data: data.get("fname", None), ) @property def grid_shape(self) -> CoordTup: """return the shape of the grid as a tuple""" return (self.grid_n, self.grid_n) @property def grid_shape_np(self) -> Coord: """return the shape of the grid as a numpy array""" return np.array(self.grid_shape) @property def max_grid_n(self) -> int: """return the maximum of the grid shape""" return max(self.grid_shape) def _serialize_base( self, applied_filters__skip__collect_generation_meta: bool = True ) -> dict: """serialize the base config for user in `stable_hash_cfg()` and `to_fname()` - note that the _fname_loaded will always be `None` to avoid infinite recursion - note that we **do not** by default include information about metadata collection here, since otherwise loading a dataset that we minified by collecting the metadata would be impossible but for comparing things, we do store it when serializing properly by setting `applied_filters__skip__collect_generation_meta=False` """ serialized: dict = MazeDatasetConfig_base.serialize(self) if applied_filters__skip__collect_generation_meta: serialized["applied_filters"] = [ x for x in serialized["applied_filters"] if x.get("name", None) != "collect_generation_meta" ] return serialized def _stable_str_dump(self) -> str: return json.dumps( self._serialize_base(), sort_keys=True, indent=None, ) def stable_hash_cfg(self) -> int: """return a stable hash of the config""" return int.from_bytes( hashlib.md5( # noqa: S324 bytes(self._stable_str_dump(), "ascii") ).digest(), "big", ) def to_fname(self) -> str: """return a unique identifier (valid as a filename) for this config""" n_mazes_str: str = shorten_numerical_to_str(self.n_mazes) maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_") hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH return sanitize_fname( f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}", ) # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only @serializable_dataclass(kw_only=True, methods_no_override=["serialize"]) class MazeDatasetConfig(MazeDatasetConfig_base): # type: ignore[misc] """config object which is passed to `MazeDataset.from_config` to generate or load a dataset # Parameters: - `name : str` name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname` - `grid_n : int` grid size of the maze (number of rows/columns) - `n_mazes : int` number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate. see `EndpointKwargsType` for more details. - `maze_ctor : Callable` maze generator function. This should be a function that takes a grid size and returns a maze. This will usually be one of the functions in `LatticeMazeGenerators`. - `maze_ctor_kwargs : dict` keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using. - `endpoint_kwargs : EndpointKwargsType` keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info. 
- `applied_filters : list[dict]` list of filters that have been applied to the dataset. We recommend applying filters to datasets directly, but these are stored with the config in case you want to re-generate the dataset with the same filters. """ @property def config_version(self) -> str: """return the version of the config. added in maze_dataset v1.3.0, previous versions had no dataset config""" return "1.0" @property def versions(self) -> dict: """return the versions of the config and the maze_dataset""" return dict( config=self.config_version, maze_dataset=importlib.metadata.version("maze_dataset"), ) def serialize(self) -> dict: "serialize the MazeDatasetConfig with all fields and fname" return { **self._serialize_base( applied_filters__skip__collect_generation_meta=False ), "fname": self.to_fname(), "versions": self.versions, } def summary(self) -> dict: """return a summary of the config""" # do we run this to make sure it doesn't error? super_summary: dict = super().summary() assert super_summary self_ser: dict = self.serialize() return dict( name=self.name, fname=self.to_fname(), sdc_hash=self.stable_hash_cfg(), seed=self.seed, seq_len_min=self.seq_len_min, seq_len_max=self.seq_len_max, applied_filters=self.applied_filters, grid_n=self_ser["grid_n"], n_mazes=self_ser["n_mazes"], maze_ctor_name=self_ser["maze_ctor"]["__name__"], maze_ctor_kwargs=self_ser["maze_ctor_kwargs"], endpoint_kwargs=self_ser["endpoint_kwargs"], ) def _to_ps_array(self) -> _PercolationSuccessArray: """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector. used in predicting the success rate """ try: assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, ( f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }" ) assert "p" in self.maze_ctor_kwargs, ( f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }" ) assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), ( f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }" ) except AssertionError as e: err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }" raise NoPercolationInConfigError( err_msg, ) from e endpoints_unique_flag: int = int( # we are pretty sure it will be an int or bool here self.endpoint_kwargs.get("endpoints_not_equal", True), # type: ignore[arg-type] ) # adjustment for bknutson0 if not ( self.endpoint_kwargs.get("deadend_start", False) and self.endpoint_kwargs.get("deadend_end", False) ): # we didnt train on this, but if either endpoint is not required to be in a dead end # then requiring the endpoints to be unique does not really affect the success rate # (except for very small percolation values, pure percolation generation) endpoints_unique_flag = 0 return np.array( [ float(self.maze_ctor_kwargs["p"]), float(self.grid_n), float( int( self.endpoint_kwargs.get("deadend_start", False) # type: ignore[arg-type] or self.endpoint_kwargs.get("deadend_end", False), ), ), float(endpoints_unique_flag), float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)), ], dtype=np.float64, ) @classmethod def _from_ps_array( cls, arr: _PercolationSuccessArray, name: str = "predict", n_mazes: int = 100, **kwargs, ) -> "MazeDatasetConfig": """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters. 
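# NOTE (layout illustration only; the integer at index 4 depends on the ordering of
# `_GENERATORS_PERCOLATED`): `_to_ps_array()` and `_from_ps_array()` round-trip a config
# through a 5-vector:
#     arr[0] = p                  (percolation probability, `maze_ctor_kwargs["p"]`)
#     arr[1] = grid_n
#     arr[2] = 1.0 if deadend_start or deadend_end is required, else 0.0
#     arr[3] = 1.0 if endpoints_not_equal (only kept when both endpoints must be deadends)
#     arr[4] = index of the generator in `_GENERATORS_PERCOLATED`
# e.g. `MazeDatasetConfig._from_ps_array(cfg._to_ps_array())` rebuilds a corresponding
# percolation config with `except_on_no_valid_endpoint=False`.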
# Returns: - `MazeDatasetConfig` Config corresponding to `arr` """ return cls( name=name, grid_n=int(arr[1]), n_mazes=n_mazes, maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]], maze_ctor_kwargs={"p": float(arr[0])}, endpoint_kwargs=dict( deadend_start=bool(arr[2]), deadend_end=bool(arr[2]), endpoints_not_equal=bool(arr[3]), except_on_no_valid_endpoint=False, ), **kwargs, ) def success_fraction_estimate( self, except_if_all_success_expected: bool = False, ) -> float: """Estimate the success fraction of this config. only valid when the generator is a percolation generator, and endpoints are enforced to be dead ends more information on where this comes from can be found in - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` - `estimate_dataset_fractions.ipynb` - `maze_dataset.benchmarks.sweep_fit` # Parameters: - `except_if_all_success_expected : bool` if `True`, don't raise an error if the success fraction is below the threshold. will always return `1.0` if the config is not expected to fail # Returns: - `float` estimated success fraction # Raises: - `NoPercolationInConfigError` : if the config is not expected to fail, and `except_if_all_success_expected` is `False` """ try: return cfg_success_predict_fn(self) except NoPercolationInConfigError as e: if except_if_all_success_expected: raise e # noqa: TRY201 return 1.0 def success_fraction_compensate( self, safety_margin: float = 1.2, except_if_all_success_expected: bool = False, epsilon: float = 1e-2, ) -> "MazeDatasetConfig": """return a new `MazeDatasetConfig` like this one with `n_mazes` adjusted to compensate for the success fraction calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then computes the new number of mazes as `n_mazes = n_mazes * safety_margin / success_fraction + 1` more information on where this comes from can be found in - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math` - `estimate_dataset_fractions.ipynb` - `maze_dataset.benchmarks.sweep_fit` # Parameters: - `safety_margin : float` safety margin to apply to the success fraction estimate (defaults to `1.2`, or 20% more mazes than estimated) - `except_if_all_success_expected : bool` if `True`, don't raise an error if the success fraction is below the threshold. this is passed to `MazeDatasetConfig.success_fraction_estimate`. if your config isn't expected to fail, passing this might mean you generate more mazes than needed since `safety_margin` is still applied. 
(defaults to `False`) - `epsilon : float` raise `SuccessChanceTooSmallError` if the success fraction is below this threshold (defaults to `1e-2`) # Returns: - `MazeDatasetConfig` new config with adjusted `n_mazes` # Raises: - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon` """ # compute and check the success fraction success_fraction: float = self.success_fraction_estimate( except_if_all_success_expected=except_if_all_success_expected, ) if success_fraction < epsilon: err_msg: str = ( f"{success_fraction = } is below the threshold of {epsilon = }" ) raise SuccessChanceTooSmallError( err_msg, ) # compute the new number of mazes n_mazes: int = self.n_mazes new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1 # put it in a new config and return cfg_dict: dict = self.serialize() cfg_dict["n_mazes"] = new_n_mazes return MazeDatasetConfig.load(cfg_dict) ``````{ end_of_file="maze_dataset/dataset/maze_dataset_config.py" } ``````{ path="maze_dataset/dataset/rasterized.py" } """a special `RasterizedMazeDataset` that returns 2 images, one for input and one for target, for each maze this lets you match the input and target format of the [`easy_2_hard`](https://github.com/aks2203/easy-to-hard) dataset see their paper: ```bibtex @misc{schwarzschild2021learn, title={Can You Learn an Algorithm? Generalizing from Easy to Hard Problems with Recurrent Networks}, author={Avi Schwarzschild and Eitan Borgnia and Arjun Gupta and Furong Huang and Uzi Vishkin and Micah Goldblum and Tom Goldstein}, year={2021}, eprint={2106.04537}, archivePrefix={arXiv}, primaryClass={cs.LG} } ``` """ import typing from pathlib import Path import numpy as np from jaxtyping import Float, Int from muutils.json_serialize import serializable_dataclass, serializable_field from zanj import ZANJ from maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.maze import PixelColors, SolvedMaze from maze_dataset.maze.lattice_maze import PixelGrid, _remove_isolated_cells def _extend_pixels( image: Int[np.ndarray, "x y rgb"], n_mult: int = 2, n_bdry: int = 1, ) -> Int[np.ndarray, "n_mult*x+2*n_bdry n_mult*y+2*n_bdry rgb"]: wall_fill: int = PixelColors.WALL[0] assert all(x == wall_fill for x in PixelColors.WALL), ( "PixelColors.WALL must be a single value" ) output: np.ndarray = np.repeat( np.repeat( image, n_mult, axis=0, ), n_mult, axis=1, ) # pad on all sides by n_bdry return np.pad( output, pad_width=((n_bdry, n_bdry), (n_bdry, n_bdry), (0, 0)), mode="constant", constant_values=wall_fill, ) _RASTERIZED_CFG_ADDED_PARAMS: list[str] = [ "remove_isolated_cells", "extend_pixels", "endpoints_as_open", ] def process_maze_rasterized_input_target( maze: SolvedMaze, remove_isolated_cells: bool = True, extend_pixels: bool = True, endpoints_as_open: bool = False, ) -> Float[np.ndarray, "in/tgt=2 x y rgb=3"]: """turn a single `SolvedMaze` into an array representation has extra options for matching the format in https://github.com/aks2203/easy-to-hard # Parameters: - `maze: SolvedMaze` the maze to process - `remove_isolated_cells: bool` whether to set isolated cells (no connections) to walls (default: `True`) - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) (default: `True`) - `endpoints_as_open: bool` whether to set endpoints to open (default: `False`) """ # problem and solution mazes maze_pixels: PixelGrid = maze.as_pixels(show_endpoints=True, show_solution=True) problem_maze: PixelGrid = maze_pixels.copy() 
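# the two copies diverge below:
# - `problem_maze` hides the solution by recoloring PATH pixels to OPEN,
#   leaving only walls, open cells, and the start/end markers
# - `solution_maze` keeps only the solution corridor: OPEN pixels become WALL,
#   then PATH pixels become OPEN (and, if `endpoints_as_open`, the start/end
#   markers are recolored to OPEN as well)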
solution_maze: PixelGrid = maze_pixels.copy() # in problem maze, set path to open problem_maze[(problem_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN # wherever solution maze is PixelColors.OPEN, set it to PixelColors.WALL solution_maze[(solution_maze == PixelColors.OPEN).all(axis=-1)] = PixelColors.WALL # wherever it is solution, set it to PixelColors.OPEN solution_maze[(solution_maze == PixelColors.PATH).all(axis=-1)] = PixelColors.OPEN if endpoints_as_open: for color in (PixelColors.START, PixelColors.END): solution_maze[(solution_maze == color).all(axis=-1)] = PixelColors.OPEN # postprocess to match original easy_2_hard dataset if remove_isolated_cells: problem_maze = _remove_isolated_cells(problem_maze) solution_maze = _remove_isolated_cells(solution_maze) if extend_pixels: problem_maze = _extend_pixels(problem_maze) solution_maze = _extend_pixels(solution_maze) return np.array([problem_maze, solution_maze]) # TYPING: error: Attributes without a default cannot follow attributes with one [misc] @serializable_dataclass class RasterizedMazeDatasetConfig(MazeDatasetConfig): # type: ignore[misc] """adds options which we then pass to `process_maze_rasterized_input_target` - `remove_isolated_cells: bool` whether to set isolated cells to walls - `extend_pixels: bool` whether to extend pixels to match easy_2_hard dataset (2x2 cells, extra 1 pixel row of wall around maze) - `endpoints_as_open: bool` whether to set endpoints to open """ remove_isolated_cells: bool = serializable_field(default=True) extend_pixels: bool = serializable_field(default=True) endpoints_as_open: bool = serializable_field(default=False) class RasterizedMazeDataset(MazeDataset): "subclass of `MazeDataset` that uses a `RasterizedMazeDatasetConfig`" cfg: RasterizedMazeDatasetConfig # this override here is intentional def __getitem__(self, idx: int) -> Float[np.ndarray, "item in/tgt=2 x y rgb=3"]: # type: ignore[override] """get a single maze""" # get the solved maze solved_maze: SolvedMaze = self.mazes[idx] return process_maze_rasterized_input_target( maze=solved_maze, remove_isolated_cells=self.cfg.remove_isolated_cells, extend_pixels=self.cfg.extend_pixels, endpoints_as_open=self.cfg.endpoints_as_open, ) def get_batch( self, idxs: list[int] | None, ) -> Float[np.ndarray, "in/tgt=2 item x y rgb=3"]: """get a batch of mazes as a tensor, from a list of indices""" if idxs is None: idxs = list(range(len(self))) inputs: list[Float[np.ndarray, "x y rgb=3"]] targets: list[Float[np.ndarray, "x y rgb=3"]] inputs, targets = zip(*[self[i] for i in idxs], strict=False) # type: ignore[assignment] return np.array([inputs, targets]) # override here is intentional @classmethod def from_config( cls, cfg: RasterizedMazeDatasetConfig | MazeDatasetConfig, # type: ignore[override] do_generate: bool = True, load_local: bool = True, save_local: bool = True, zanj: ZANJ | None = None, do_download: bool = True, local_base_path: Path = Path("data/maze_dataset"), except_on_config_mismatch: bool = True, allow_generation_metadata_filter_mismatch: bool = True, verbose: bool = False, **kwargs, ) -> "RasterizedMazeDataset": """create a rasterized maze dataset from a config priority of loading: 1. load from local 2. download 3. 
generate """ return typing.cast( RasterizedMazeDataset, super().from_config( cfg=cfg, do_generate=do_generate, load_local=load_local, save_local=save_local, zanj=zanj, do_download=do_download, local_base_path=local_base_path, except_on_config_mismatch=except_on_config_mismatch, allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch, verbose=verbose, **kwargs, ), ) @classmethod def from_config_augmented( cls, cfg: RasterizedMazeDatasetConfig, **kwargs, ) -> "RasterizedMazeDataset": """loads either a maze transformer dataset or an easy_2_hard dataset""" _cfg_temp: MazeDatasetConfig = MazeDatasetConfig.load(cfg.serialize()) return cls.from_base_MazeDataset( cls.from_config(cfg=_cfg_temp, **kwargs), added_params={ k: v for k, v in cfg.serialize().items() if k in _RASTERIZED_CFG_ADDED_PARAMS }, ) @classmethod def from_base_MazeDataset( cls, base_dataset: MazeDataset, added_params: dict | None = None, ) -> "RasterizedMazeDataset": """loads either a maze transformer dataset or an easy_2_hard dataset""" if added_params is None: added_params = dict( remove_isolated_cells=True, extend_pixels=True, ) cfg: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load( { **base_dataset.cfg.serialize(), **added_params, }, ) output: RasterizedMazeDataset = cls( cfg=cfg, mazes=base_dataset.mazes, ) return output def plot(self, count: int | None = None, show: bool = True) -> tuple | None: """plot the first `count` mazes in the dataset""" import matplotlib.pyplot as plt print(f"{self[0][0].shape = }, {self[0][1].shape = }") count = count or len(self) if count == 0: print("No mazes to plot for dataset") return None fig, axes = plt.subplots(2, count, figsize=(15, 5)) if count == 1: axes = [axes] for i in range(count): axes[0, i].imshow(self[i][0]) axes[1, i].imshow(self[i][1]) # remove ticks axes[0, i].set_xticks([]) axes[0, i].set_yticks([]) axes[1, i].set_xticks([]) axes[1, i].set_yticks([]) if show: plt.show() return fig, axes def make_numpy_collection( base_cfg: RasterizedMazeDatasetConfig, grid_sizes: list[int], from_config_kwargs: dict | None = None, verbose: bool = True, key_fmt: str = "{size}x{size}", ) -> dict[ typing.Literal["configs", "arrays"], dict[str, RasterizedMazeDatasetConfig | np.ndarray], ]: """create a collection of configs and arrays for different grid sizes, in plain tensor form output is of structure: ``` { "configs": { "x": RasterizedMazeDatasetConfig, ... }, "arrays": { "x": np.ndarray, ... 
}, } ``` """ if from_config_kwargs is None: from_config_kwargs = {} datasets: dict[int, RasterizedMazeDataset] = {} for size in grid_sizes: if verbose: print(f"Generating dataset for maze size {size}...") cfg_temp: RasterizedMazeDatasetConfig = RasterizedMazeDatasetConfig.load( base_cfg.serialize(), ) cfg_temp.grid_n = size datasets[size] = RasterizedMazeDataset.from_config_augmented( cfg=cfg_temp, **from_config_kwargs, ) return dict( configs={ key_fmt.format(size=size): dataset.cfg for size, dataset in datasets.items() }, arrays={ # get_batch(None) returns a single tensor of shape (n, 2, x, y, 3) key_fmt.format(size=size): dataset.get_batch(None) for size, dataset in datasets.items() }, ) ``````{ end_of_file="maze_dataset/dataset/rasterized.py" } ``````{ path="maze_dataset/dataset/success_predict_math.py" } """math for getting the `MazeDatasetConfig.success_fraction_estimate()` function to work Desmos link: https://www.desmos.com/calculator/qllvhwftvy """ import numpy as np from jaxtyping import Float def sigmoid(x: float) -> float: r"$\sigma(x) = \frac{1}{1 + e^{-x}}$" return 1 / (1 + np.exp(-x)) # sigmoid_shifted = lambda x: 1 / (1 + np.exp(-1000 * (x - 0.5))) # r"sigmoid(x)= 1 / (1 + e^{-b(x-0.5)})" # g_poly = lambda q, a: 1 - np.abs(2 * q - 1) ** a # r"g(q,a) = 1 - (|2q-1|)^{a}" # f_poly = lambda q, a: q * g_poly(q, a) # r"f(q,a) = q * g(q,a)" # h_func = lambda q, a: f_poly(q, a) * (1 - sigmoid_shifted(q)) + (1 - f_poly(1 - q, a)) * sigmoid_shifted(q) # r"h(q,a,b) = f(q,a) * (1-s(q,b)) + (1-f(1-q,a)) * s(q,b)" # A_scaling = lambda q, a, w: w * g_poly(q, a) # r"A(q) = b * g(q, a)" def sigmoid_shifted(x: float) -> float: r"\sigma_s(x)= \frac{1}{1 + e^{-10^3 \cdot (x-0.5)}}" return 1 / (1 + np.exp(-1000 * (x - 0.5))) def g_poly(q: float, a: float) -> float: r"$g(q,a) = 1 - (|2q-1|)^{a}$" return 1 - np.abs(2 * q - 1) ** a def f_poly(q: float, a: float) -> float: r"$f(q,a) = q \cdot g(q,a)$" return q * g_poly(q, a) def h_func(q: float, a: float) -> float: r"""$h(q,a,b) = f(q,a) \cdot (1-\sigma_s(q)) + (1-f(1-q,a)) \cdot \sigma_s(q)$""" return f_poly(q, a) * (1 - sigmoid_shifted(q)) + ( 1 - f_poly(1 - q, a) ) * sigmoid_shifted(q) def A_scaling(q: float, a: float, w: float) -> float: r"$A(q) = w \cdot g(q, a)$" return w * g_poly(q, a) def soft_step( x: float | np.floating | Float[np.ndarray, " n"], p: float | np.floating, alpha: float | np.floating = 5, w: float | np.floating = 50, ) -> float: """when p is close to 0.5 acts like the identity wrt x, but when p is close to 0 or 1, pushes x to 0 or 1 (whichever is closest) https://www.desmos.com/calculator/qllvhwftvy """ # TYPING: this is messed up, some of these args can be arrays but i dont remember which? 
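# the composition below gives the behavior described in the docstring:
# `A_scaling(p, alpha, w) = w * (1 - |2p - 1|**alpha)` is the sharpness passed to `h_func` --
# for p near 0.5 it is close to w (large), and `h_func(x, large)` is approximately the identity in x;
# for p near 0 or 1 it is close to 0, and `h_func(x, small)` pushes x toward 0 (x < 0.5) or 1 (x > 0.5)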
return h_func( x, # type: ignore[arg-type] A_scaling(p, alpha, w), # type: ignore[arg-type] ) # `cfg: MazeDatasetConfig` but we can't import that because it would create a circular import def cfg_success_predict_fn(cfg) -> float: # noqa: ANN001 "learned by pysr, see `estimate_dataset_fractions.ipynb` and `maze_dataset.benchmark.config_fit`" x = cfg._to_ps_array() raw_val: float = sigmoid( ( ( ((sigmoid((x[1] - x[3]) ** 3) * -4.721228) - (x[3] * 1.4636494)) * ( x[2] * ( x[4] + (((x[0] + 0.048765484) ** 9.746339) + (0.8998194 ** x[1])) ) ) ) + (2.4524326 ** (2.9501643 - x[0])) ) * ( ( (((0.9077277 - x[0]) * ((x[4] * 1.0520288) ** x[1])) + x[0]) * sigmoid(x[1]) ** 3 ) + -0.18268494 ), ) return soft_step( x=raw_val, p=x[0], alpha=5, # manually tuned w=10, # manually tuned ) ``````{ end_of_file="maze_dataset/dataset/success_predict_math.py" } ``````{ path="maze_dataset/generation/__init__.py" } """generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators` `DEFAULT_GENERATORS` is a list of generator name, generator kwargs pairs used in tests and demos """ from maze_dataset.generation.generators import ( GENERATORS_MAP, LatticeMazeGenerators, get_maze_with_solution, numpy_rng, ) __all__ = [ # submodules "default_generators", "generators", "seed", # imports "LatticeMazeGenerators", "GENERATORS_MAP", "get_maze_with_solution", "numpy_rng", ] ``````{ end_of_file="maze_dataset/generation/__init__.py" } ``````{ path="maze_dataset/generation/default_generators.py" } """`DEFAULT_GENERATORS` is a list of generator name, generator kwargs pairs used in tests and demos""" DEFAULT_GENERATORS: list[tuple[str, dict]] = [ ("gen_dfs", dict()), ("gen_dfs", dict(do_forks=False)), ("gen_dfs", dict(accessible_cells=20)), ("gen_dfs", dict(max_tree_depth=0.5)), ("gen_wilson", dict()), # ("gen_percolation", dict(p=0.1)), ( "gen_percolation", dict(p=1.0), ), # anything less than this and tests will stochastically fail ("gen_dfs_percolation", dict(p=0.1)), ("gen_dfs_percolation", dict(p=0.4)), # ("gen_prim", dict()), # ("gen_prim", dict(do_forks=False)), # ("gen_prim", dict(accessible_cells=0.5)), # ("gen_prim", dict(max_tree_depth=0.5)), # ("gen_prim", dict(accessible_cells=0.5, max_tree_depth=0.5)), ("gen_kruskal", dict()), ("gen_recursive_division", dict()), ] ``````{ end_of_file="maze_dataset/generation/default_generators.py" } ``````{ path="maze_dataset/generation/generators.py" } """generation functions have signature `(grid_shape: Coord, **kwargs) -> LatticeMaze` and are methods in `LatticeMazeGenerators`""" import random import warnings from typing import Any, Callable import numpy as np from jaxtyping import Bool from maze_dataset.constants import CoordArray, CoordTup from maze_dataset.generation.seed import GLOBAL_SEED from maze_dataset.maze import ConnectionList, Coord, LatticeMaze, SolvedMaze from maze_dataset.maze.lattice_maze import NEIGHBORS_MASK, _fill_edges_with_walls numpy_rng = np.random.default_rng(GLOBAL_SEED) random.seed(GLOBAL_SEED) def _random_start_coord( grid_shape: Coord, start_coord: Coord | CoordTup | None, ) -> Coord: "picking a random start coord within the bounds of `grid_shape` if none is provided" start_coord_: Coord if start_coord is None: start_coord_ = np.random.randint( 0, # lower bound np.maximum(grid_shape - 1, 1), # upper bound (at least 1) size=len(grid_shape), # dimensionality ) else: start_coord_ = np.array(start_coord) return start_coord_ def get_neighbors_in_bounds( coord: Coord, grid_shape: Coord, ) -> 
CoordArray: "get all neighbors of a coordinate that are within the bounds of the grid" # get all neighbors neighbors: CoordArray = coord + NEIGHBORS_MASK # filter neighbors by being within grid bounds neighbors_in_bounds: CoordArray = neighbors[ (neighbors >= 0).all(axis=1) & (neighbors < grid_shape).all(axis=1) ] return neighbors_in_bounds class LatticeMazeGenerators: """namespace for lattice maze generation algorithms examples of generated mazes can be found here: https://understanding-search.github.io/maze-dataset/examples/maze_examples.html """ @staticmethod def gen_dfs( grid_shape: Coord | CoordTup, lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, randomized_stack: bool = False, start_coord: Coord | None = None, ) -> LatticeMaze: """generate a lattice maze using depth first search, iterative # Arguments - `grid_shape: Coord`: the shape of the grid - `lattice_dim: int`: the dimension of the lattice (default: `2`) - `accessible_cells: int | float |None`: the number of accessible cells in the maze. If `None`, defaults to the total number of cells in the grid. if a float, asserts it is <= 1 and treats it as a proportion of **total cells** (default: `None`) - `max_tree_depth: int | float | None`: the maximum depth of the tree. If `None`, defaults to `2 * accessible_cells`. if a float, asserts it is <= 1 and treats it as a proportion of the **sum of the grid shape** (default: `None`) - `do_forks: bool`: whether to allow forks in the maze. If `False`, the maze will be have no forks and will be a simple hallway. - `start_coord: Coord | None`: the starting coordinate of the generation algorithm. If `None`, defaults to a random coordinate. # algorithm 1. Choose the initial cell, mark it as visited and push it to the stack 2. While the stack is not empty 1. Pop a cell from the stack and make it a current cell 2. If the current cell has any neighbours which have not been visited 1. Push the current cell to the stack 2. Choose one of the unvisited neighbours 3. Remove the wall between the current cell and the chosen cell 4. Mark the chosen cell as visited and push it to the stack """ # Default values if no constraints have been passed grid_shape_: Coord = np.array(grid_shape) n_total_cells: int = int(np.prod(grid_shape_)) n_accessible_cells: int if accessible_cells is None: n_accessible_cells = n_total_cells elif isinstance(accessible_cells, float): assert accessible_cells <= 1, ( f"accessible_cells must be an int (count) or a float in the range [0, 1] (proportion), got {accessible_cells}" ) n_accessible_cells = int(accessible_cells * n_total_cells) else: assert isinstance(accessible_cells, int) n_accessible_cells = accessible_cells if max_tree_depth is None: max_tree_depth = ( 2 * n_total_cells ) # We define max tree depth counting from the start coord in two directions. Therefore we divide by two in the if clause for neighboring sites later and multiply by two here. 
elif isinstance(max_tree_depth, float): assert max_tree_depth <= 1, ( f"max_tree_depth must be an int (count) or a float in the range [0, 1] (proportion), got {max_tree_depth}" ) max_tree_depth = int(max_tree_depth * np.sum(grid_shape_)) # choose a random start coord start_coord = _random_start_coord(grid_shape_, start_coord) # initialize the maze with no connections connection_list: ConnectionList = np.zeros( (lattice_dim, grid_shape_[0], grid_shape_[1]), dtype=np.bool_, ) # initialize the stack with the target coord visited_cells: set[tuple[int, int]] = set() visited_cells.add(tuple(start_coord)) # this wasnt a bug after all lol stack: list[Coord] = [start_coord] # initialize tree_depth_counter current_tree_depth: int = 1 # loop until the stack is empty or n_connected_cells is reached while stack and (len(visited_cells) < n_accessible_cells): # get the current coord from the stack current_coord: Coord if randomized_stack: current_coord = stack.pop(random.randint(0, len(stack) - 1)) else: current_coord = stack.pop() # filter neighbors by being within grid bounds and being unvisited unvisited_neighbors_deltas: list[tuple[Coord, Coord]] = [ (neighbor, delta) for neighbor, delta in zip( current_coord + NEIGHBORS_MASK, NEIGHBORS_MASK, strict=False, ) if ( (tuple(neighbor) not in visited_cells) and (0 <= neighbor[0] < grid_shape_[0]) and (0 <= neighbor[1] < grid_shape_[1]) ) ] # don't continue if max_tree_depth/2 is already reached (divide by 2 because we can branch to multiple directions) if unvisited_neighbors_deltas and ( current_tree_depth <= max_tree_depth / 2 ): # if we want a maze without forks, simply don't add the current coord back to the stack if do_forks and (len(unvisited_neighbors_deltas) > 1): stack.append(current_coord) # choose one of the unvisited neighbors chosen_neighbor, delta = random.choice(unvisited_neighbors_deltas) # add connection dim: int = int(np.argmax(np.abs(delta))) # if positive, down/right from current coord # if negative, up/left from current coord (down/right from neighbor) clist_node: Coord = ( current_coord if (delta.sum() > 0) else chosen_neighbor ) connection_list[dim, clist_node[0], clist_node[1]] = True # add to visited cells and stack visited_cells.add(tuple(chosen_neighbor)) stack.append(chosen_neighbor) # Update current tree depth current_tree_depth += 1 else: current_tree_depth -= 1 return LatticeMaze( connection_list=connection_list, generation_meta=dict( func_name="gen_dfs", grid_shape=grid_shape_, start_coord=start_coord, n_accessible_cells=int(n_accessible_cells), max_tree_depth=int(max_tree_depth), # oh my god this took so long to track down. its almost 5am and I've spent like 2 hours on this bug # it was checking that len(visited_cells) == n_accessible_cells, but this means that the maze is # treated as fully connected even when it is most certainly not, causing solving the maze to break fully_connected=bool(len(visited_cells) == n_total_cells), visited_cells={tuple(int(x) for x in coord) for coord in visited_cells}, ), ) @staticmethod def gen_prim( grid_shape: Coord | CoordTup, lattice_dim: int = 2, accessible_cells: float | None = None, max_tree_depth: float | None = None, do_forks: bool = True, start_coord: Coord | None = None, ) -> LatticeMaze: "(broken!) 
generate a lattice maze using Prim's algorithm" warnings.warn( "gen_prim does not correctly implement prim's algorithm, see issue: https://github.com/understanding-search/maze-dataset/issues/12", ) return LatticeMazeGenerators.gen_dfs( grid_shape=grid_shape, lattice_dim=lattice_dim, accessible_cells=accessible_cells, max_tree_depth=max_tree_depth, do_forks=do_forks, start_coord=start_coord, randomized_stack=True, ) @staticmethod def gen_wilson( grid_shape: Coord | CoordTup, **kwargs, ) -> LatticeMaze: """Generate a lattice maze using Wilson's algorithm. # Algorithm Wilson's algorithm generates an unbiased (random) maze sampled from the uniform distribution over all mazes, using loop-erased random walks. The generated maze is acyclic and all cells are part of a unique connected space. https://en.wikipedia.org/wiki/Maze_generation_algorithm#Wilson's_algorithm """ assert not kwargs, ( f"gen_wilson does not take any additional arguments, got {kwargs = }" ) grid_shape_: Coord = np.array(grid_shape) # Initialize grid and visited cells connection_list: ConnectionList = np.zeros((2, *grid_shape_), dtype=np.bool_) visited: Bool[np.ndarray, "x y"] = np.zeros(grid_shape_, dtype=np.bool_) # Choose a random cell and mark it as visited start_coord: Coord = _random_start_coord(grid_shape_, None) visited[start_coord[0], start_coord[1]] = True del start_coord while not visited.all(): # Perform loop-erased random walk from another random cell # Choose walk_start only from unvisited cells unvisited_coords: CoordArray = np.column_stack(np.where(~visited)) walk_start: Coord = unvisited_coords[ np.random.choice(unvisited_coords.shape[0]) ] # Perform the random walk path: list[Coord] = [walk_start] current: Coord = walk_start # exit the loop once the current path hits a visited cell while not visited[current[0], current[1]]: # find a valid neighbor (one always exists on a lattice) neighbors: CoordArray = get_neighbors_in_bounds(current, grid_shape_) next_cell: Coord = neighbors[np.random.choice(neighbors.shape[0])] # Check for loop loop_exit: int | None = None for i, p in enumerate(path): if np.array_equal(next_cell, p): loop_exit = i break # erase the loop, or continue the walk if loop_exit is not None: # this removes everything after and including the loop start path = path[: loop_exit + 1] # reset current cell to end of path current = path[-1] else: path.append(next_cell) current = next_cell # Add the path to the maze for i in range(len(path) - 1): c_1: Coord = path[i] c_2: Coord = path[i + 1] # find the dimension of the connection delta: Coord = c_2 - c_1 dim: int = int(np.argmax(np.abs(delta))) # if positive, down/right from current coord # if negative, up/left from current coord (down/right from neighbor) clist_node: Coord = c_1 if (delta.sum() > 0) else c_2 connection_list[dim, clist_node[0], clist_node[1]] = True visited[c_1[0], c_1[1]] = True # we dont add c_2 because the last c_2 will have already been visited return LatticeMaze( connection_list=connection_list, generation_meta=dict( func_name="gen_wilson", grid_shape=grid_shape_, fully_connected=True, ), ) @staticmethod def gen_percolation( grid_shape: Coord | CoordTup, p: float = 0.4, lattice_dim: int = 2, start_coord: Coord | None = None, ) -> LatticeMaze: """generate a lattice maze using simple percolation note that p in the range (0.4, 0.7) gives the most interesting mazes # Arguments - `grid_shape: Coord`: the shape of the grid - `lattice_dim: int`: the dimension of the lattice (default: `2`) - `p: float`: the probability of a cell being 
accessible (default: `0.5`) - `start_coord: Coord | None`: the starting coordinate for the connected component (default: `None` will give a random start) """ assert p >= 0 and p <= 1, f"p must be between 0 and 1, got {p}" # noqa: PT018 grid_shape_: Coord = np.array(grid_shape) start_coord = _random_start_coord(grid_shape_, start_coord) connection_list: ConnectionList = np.random.rand(lattice_dim, *grid_shape_) < p connection_list = _fill_edges_with_walls(connection_list) output: LatticeMaze = LatticeMaze( connection_list=connection_list, generation_meta=dict( func_name="gen_percolation", grid_shape=grid_shape_, percolation_p=p, start_coord=start_coord, ), ) # generation_meta is sometimes None, but not here since we just made it a dict above output.generation_meta["visited_cells"] = output.gen_connected_component_from( # type: ignore[index] start_coord, ) return output @staticmethod def gen_dfs_percolation( grid_shape: Coord | CoordTup, p: float = 0.4, lattice_dim: int = 2, accessible_cells: int | None = None, max_tree_depth: int | None = None, start_coord: Coord | None = None, ) -> LatticeMaze: """dfs and then percolation (adds cycles)""" grid_shape_: Coord = np.array(grid_shape) start_coord = _random_start_coord(grid_shape_, start_coord) # generate initial maze via dfs maze: LatticeMaze = LatticeMazeGenerators.gen_dfs( grid_shape=grid_shape_, lattice_dim=lattice_dim, accessible_cells=accessible_cells, max_tree_depth=max_tree_depth, start_coord=start_coord, ) # percolate connection_list_perc: np.ndarray = ( np.random.rand(*maze.connection_list.shape) < p ) connection_list_perc = _fill_edges_with_walls(connection_list_perc) maze.__dict__["connection_list"] = np.logical_or( maze.connection_list, connection_list_perc, ) # generation_meta is sometimes None, but not here since we just made it a dict above maze.generation_meta["func_name"] = "gen_dfs_percolation" # type: ignore[index] maze.generation_meta["percolation_p"] = p # type: ignore[index] maze.generation_meta["visited_cells"] = maze.gen_connected_component_from( # type: ignore[index] start_coord, ) return maze @staticmethod def gen_kruskal( grid_shape: "Coord | CoordTup", lattice_dim: int = 2, start_coord: "Coord | None" = None, ) -> "LatticeMaze": """Generate a maze using Kruskal's algorithm. This function generates a random spanning tree over a grid using Kruskal's algorithm. Each cell is treated as a node, and all valid adjacent edges are listed and processed in random order. An edge is added (i.e. its passage carved) only if it connects two cells that are not already connected. The resulting maze is a perfect maze (i.e. a spanning tree) without cycles. https://en.wikipedia.org/wiki/Kruskal's_algorithm # Parameters: - `grid_shape : Coord | CoordTup` The shape of the maze grid (for example, `(n_rows, n_cols)`). - `lattice_dim : int` The lattice dimension (default is `2`). - `start_coord : Coord | None` Optionally, specify a starting coordinate. If `None`, a random coordinate will be chosen. - `**kwargs` Additional keyword arguments (currently unused). # Returns: - `LatticeMaze` A maze represented by a connection list, generated as a spanning tree using Kruskal's algorithm. # Usage: ```python maze = gen_kruskal((10, 10)) ``` """ assert lattice_dim == 2, ( # noqa: PLR2004 "Kruskal's algorithm is only implemented for 2D lattices." ) # Convert grid_shape to a tuple of ints grid_shape_: CoordTup = tuple(int(x) for x in grid_shape) # type: ignore[assignment] n_rows, n_cols = grid_shape_ # Initialize union-find data structure. 
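# `find` walks parent pointers up to the set's root, halving the path as it goes
# (`parent[cell] = parent[parent[cell]]`); `union` simply re-roots one tree under
# the other. No union-by-rank is used, which is fine at typical maze grid sizes.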
parent: dict[tuple[int, int], tuple[int, int]] = {} def find(cell: tuple[int, int]) -> tuple[int, int]: while parent[cell] != cell: parent[cell] = parent[parent[cell]] cell = parent[cell] return cell def union(cell1: tuple[int, int], cell2: tuple[int, int]) -> None: root1 = find(cell1) root2 = find(cell2) parent[root2] = root1 # Initialize each cell as its own set. for i in range(n_rows): for j in range(n_cols): parent[(i, j)] = (i, j) # List all possible edges. # For vertical edges (i.e. connecting a cell to its right neighbor): edges: list[tuple[tuple[int, int], tuple[int, int], int]] = [] for i in range(n_rows): for j in range(n_cols - 1): edges.append(((i, j), (i, j + 1), 1)) # For horizontal edges (i.e. connecting a cell to its bottom neighbor): for i in range(n_rows - 1): for j in range(n_cols): edges.append(((i, j), (i + 1, j), 0)) # Shuffle the list of edges. import random random.shuffle(edges) # Initialize connection_list with no connections. # connection_list[0] stores downward connections (from cell (i,j) to (i+1,j)). # connection_list[1] stores rightward connections (from cell (i,j) to (i,j+1)). import numpy as np connection_list = np.zeros((2, n_rows, n_cols), dtype=bool) # Process each edge; if it connects two different trees, union them and carve the passage. for cell1, cell2, direction in edges: if find(cell1) != find(cell2): union(cell1, cell2) if direction == 0: # Horizontal edge: connection is stored in connection_list[0] at cell1. connection_list[0, cell1[0], cell1[1]] = True else: # Vertical edge: connection is stored in connection_list[1] at cell1. connection_list[1, cell1[0], cell1[1]] = True if start_coord is None: start_coord = tuple(np.random.randint(0, n) for n in grid_shape_) # type: ignore[assignment] generation_meta: dict = dict( func_name="gen_kruskal", grid_shape=grid_shape_, start_coord=start_coord, algorithm="kruskal", fully_connected=True, ) return LatticeMaze( connection_list=connection_list, generation_meta=generation_meta ) @staticmethod def gen_recursive_division( grid_shape: "Coord | CoordTup", lattice_dim: int = 2, start_coord: "Coord | None" = None, ) -> "LatticeMaze": """Generate a maze using the recursive division algorithm. This function generates a maze by recursively dividing the grid with walls and carving a single passage through each wall. The algorithm begins with a fully connected grid (i.e. every pair of adjacent cells is connected) and then removes connections along a chosen division line—leaving one gap as a passage. The resulting maze is a perfect maze, meaning there is exactly one path between any two cells. # Parameters: - `grid_shape : Coord | CoordTup` The shape of the maze grid (e.g., `(n_rows, n_cols)`). - `lattice_dim : int` The lattice dimension (default is `2`). - `start_coord : Coord | None` Optionally, specify a starting coordinate. If `None`, a random coordinate is chosen. - `**kwargs` Additional keyword arguments (currently unused). # Returns: - `LatticeMaze` A maze represented by a connection list, generated using recursive division. # Usage: ```python maze = gen_recursive_division((10, 10)) ``` """ assert lattice_dim == 2, ( # noqa: PLR2004 "Recursive division algorithm is only implemented for 2D lattices." ) # Convert grid_shape to a tuple of ints. grid_shape_: CoordTup = tuple(int(x) for x in grid_shape) # type: ignore[assignment] n_rows, n_cols = grid_shape_ # Initialize connection_list as a fully connected grid. # For horizontal connections: for each cell (i,j) with i in [0, n_rows-2], set connection to True. 
# For vertical connections: for each cell (i,j) with j in [0, n_cols-2], set connection to True. connection_list = np.zeros((2, n_rows, n_cols), dtype=bool) connection_list[0, : n_rows - 1, :] = True connection_list[1, :, : n_cols - 1] = True def divide(x: int, y: int, width: int, height: int) -> None: """Recursively divide the region starting at (x, y) with the given width and height. Removes connections along the chosen division line except for one randomly chosen gap. """ if width < 2 or height < 2: # noqa: PLR2004 return if width > height: # Vertical division. wall_col = random.randint(x + 1, x + width - 1) gap_row = random.randint(y, y + height - 1) for row in range(y, y + height): if row == gap_row: continue # Remove the vertical connection between (row, wall_col-1) and (row, wall_col). if wall_col - 1 < n_cols - 1: connection_list[1, row, wall_col - 1] = False # Recurse on the left and right subregions. divide(x, y, wall_col - x, height) divide(wall_col, y, x + width - wall_col, height) else: # Horizontal division. wall_row = random.randint(y + 1, y + height - 1) gap_col = random.randint(x, x + width - 1) for col in range(x, x + width): if col == gap_col: continue # Remove the horizontal connection between (wall_row-1, col) and (wall_row, col). if wall_row - 1 < n_rows - 1: connection_list[0, wall_row - 1, col] = False # Recurse on the top and bottom subregions. divide(x, y, width, wall_row - y) divide(x, wall_row, width, y + height - wall_row) # Begin the division on the full grid. divide(0, 0, n_cols, n_rows) if start_coord is None: start_coord = tuple(np.random.randint(0, n) for n in grid_shape_) # type: ignore[assignment] generation_meta: dict = dict( func_name="gen_recursive_division", grid_shape=grid_shape_, start_coord=start_coord, algorithm="recursive_division", fully_connected=True, ) return LatticeMaze( connection_list=connection_list, generation_meta=generation_meta ) # cant automatically populate this because it messes with pickling :( GENERATORS_MAP: dict[str, Callable[[Coord | CoordTup, Any], "LatticeMaze"]] = { "gen_dfs": LatticeMazeGenerators.gen_dfs, # TYPING: error: Dict entry 1 has incompatible type # "str": "Callable[[ndarray[Any, Any] | tuple[int, int], KwArg(Any)], LatticeMaze]"; # expected "str": "Callable[[ndarray[Any, Any] | tuple[int, int], Any], LatticeMaze]" [dict-item] # gen_wilson takes no kwargs and we check that the kwargs are empty # but mypy doesnt like this, `Any` != `KwArg(Any)` "gen_wilson": LatticeMazeGenerators.gen_wilson, # type: ignore[dict-item] "gen_percolation": LatticeMazeGenerators.gen_percolation, "gen_dfs_percolation": LatticeMazeGenerators.gen_dfs_percolation, "gen_prim": LatticeMazeGenerators.gen_prim, "gen_kruskal": LatticeMazeGenerators.gen_kruskal, "gen_recursive_division": LatticeMazeGenerators.gen_recursive_division, } "mapping of generator names to generator functions, useful for loading `MazeDatasetConfig`" _GENERATORS_PERCOLATED: list[str] = [ "gen_percolation", "gen_dfs_percolation", ] """list of generator names that generate percolated mazes we use this to figure out the expected success rate, since depending on the endpoint kwargs this might fail this variable is primarily used in `MazeDatasetConfig._to_ps_array` and `MazeDatasetConfig._from_ps_array` """ def get_maze_with_solution( gen_name: str, grid_shape: Coord | CoordTup, maze_ctor_kwargs: dict | None = None, ) -> SolvedMaze: "helper function to get a maze already with a solution" if maze_ctor_kwargs is None: maze_ctor_kwargs = dict() # TYPING: error: Too few 
arguments [call-arg] # not sure why this is happening -- doesnt recognize the kwargs? maze: LatticeMaze = GENERATORS_MAP[gen_name](grid_shape, **maze_ctor_kwargs) # type: ignore[call-arg] solution: CoordArray = np.array(maze.generate_random_path()) return SolvedMaze.from_lattice_maze(lattice_maze=maze, solution=solution) ``````{ end_of_file="maze_dataset/generation/generators.py" } ``````{ path="maze_dataset/generation/seed.py" } "global default seed" GLOBAL_SEED: int = 42 ``````{ end_of_file="maze_dataset/generation/seed.py" } ``````{ path="maze_dataset/maze/__init__.py" } r"""`LatticeMaze` and the classes like `SolvedMaze` that inherit from it, along with a variety of helper functions" This package utilizes a simple, efficient representation of mazes. Using an adjacency list to represent mazes would lead to a poor lookup time of whether any given connection exists, whilst using a dense adjacency matrix would waste memory by failing to exploit the structure (e.g., only 4 of the diagonals would be filled in). Instead, we describe mazes with the following simple representation: for a $d$-dimensional lattice with $r$ rows and $c$ columns, we initialize a boolean array $A = \{0, 1\}^{d \times r \times c}$, which we refer to in the code as a `connection_list`. The value at $A[0,i,j]$ determines whether a downward connection exists from node $[i,j]$ to $[i+1, j]$. Likewise, the value at $A[1,i,j]$ determines whether a rightwards connection to $[i, j+1]$ exists. Thus, we avoid duplication of data about the existence of connections, at the cost of requiring additional care with indexing when looking for a connection upwards or to the left. Note that this setup allows for a periodic lattice. To produce solutions to mazes, two points are selected uniformly at random without replacement from the connected component of the maze, and the $A^*$ algorithm is applied to find the shortest path between them. Parallelization is implemented via the `multiprocessing` module in the Python standard library, and parallel generation can be controlled via keyword arguments to the `MazeDataset.from_config()` function. """ from maze_dataset.maze.lattice_maze import ( AsciiChars, ConnectionList, Coord, CoordArray, LatticeMaze, PixelColors, SolvedMaze, TargetedLatticeMaze, ) __all__ = [ # submodules "lattice_maze", # imports "SolvedMaze", "TargetedLatticeMaze", "LatticeMaze", "ConnectionList", "AsciiChars", "Coord", "CoordArray", "PixelColors", ] ``````{ end_of_file="maze_dataset/maze/__init__.py" } ``````{ path="maze_dataset/maze/lattice_maze.py" } """Implements `LatticeMaze`, and the `TargetedLatticeMaze` and `SolvedMaze` subclasses. also includes basic utilities, including converting to/from ascii and pixel representations. 
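A minimal usage sketch of the ascii/pixel conversions mentioned above (a rough example, not exhaustive):

```python
from maze_dataset.generation import get_maze_with_solution

maze = get_maze_with_solution("gen_dfs", (4, 4))
print(maze.as_ascii())  # "#", " ", "S", "E", "X" characters, see `AsciiChars`
pixels = maze.as_pixels(show_endpoints=True, show_solution=True)  # RGB grid, see `PixelColors`
```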
""" import typing import warnings from dataclasses import dataclass from itertools import chain import numpy as np from jaxtyping import Bool, Int, Int8, Shaped from muutils.json_serialize.serializable_dataclass import ( SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.misc import isinstance_by_type_name, list_split from maze_dataset.constants import ( NEIGHBORS_MASK, SPECIAL_TOKENS, ConnectionList, Coord, CoordArray, CoordList, CoordTup, ) from maze_dataset.token_utils import ( TokenizerDeprecationWarning, connection_list_to_adj_list, get_adj_list_tokens, get_origin_tokens, get_path_tokens, get_target_tokens, ) if typing.TYPE_CHECKING: from maze_dataset.tokenization import ( MazeTokenizer, MazeTokenizerModular, TokenizationMode, ) RGB = tuple[int, int, int] "rgb tuple of values 0-255" PixelGrid = Int[np.ndarray, "x y rgb"] "rgb grid of pixels" BinaryPixelGrid = Bool[np.ndarray, "x y"] "boolean grid of pixels" DIM_2: int = 2 "2 dimensions" class NoValidEndpointException(Exception): # noqa: N818 """Raised when no valid start or end positions are found in a maze.""" pass def _fill_edges_with_walls(connection_list: ConnectionList) -> ConnectionList: """fill the last elements of the connections lists as false for each dim""" for dim in range(connection_list.shape[0]): # last row for down if dim == 0: connection_list[dim, -1, :] = False # last column for right elif dim == 1: connection_list[dim, :, -1] = False else: err_msg: str = f"only 2d lattices supported. got {dim=}" raise NotImplementedError(err_msg) return connection_list def color_in_pixel_grid(pixel_grid: PixelGrid, color: RGB) -> bool: """check if a color is in a pixel grid""" for row in pixel_grid: for pixel in row: if np.all(pixel == color): return True return False @dataclass(frozen=True) class PixelColors: "standard colors for pixel grids" WALL: RGB = (0, 0, 0) OPEN: RGB = (255, 255, 255) START: RGB = (0, 255, 0) END: RGB = (255, 0, 0) PATH: RGB = (0, 0, 255) @dataclass(frozen=True) class AsciiChars: "standard ascii characters for mazes" WALL: str = "#" OPEN: str = " " START: str = "S" END: str = "E" PATH: str = "X" ASCII_PIXEL_PAIRINGS: dict[str, RGB] = { AsciiChars.WALL: PixelColors.WALL, AsciiChars.OPEN: PixelColors.OPEN, AsciiChars.START: PixelColors.START, AsciiChars.END: PixelColors.END, AsciiChars.PATH: PixelColors.PATH, } "map ascii characters to pixel colors" @serializable_dataclass( frozen=True, kw_only=True, properties_to_serialize=["lattice_dim", "generation_meta"], ) class LatticeMaze(SerializableDataclass): """lattice maze (nodes on a lattice, connections only to neighboring nodes) Connection List represents which nodes (N) are connected in each direction. First and second elements represent rightward and downward connections, respectively. Example: Connection list: [ [ # down [F T], [F F] ], [ # right [T F], [T F] ] ] Nodes with connections N T N F F T N T N F F F Graph: N - N | N - N Note: the bottom row connections going down, and the right-hand connections going right, will always be False. 
""" connection_list: ConnectionList generation_meta: dict | None = serializable_field(default=None, compare=False) lattice_dim = property(lambda self: self.connection_list.shape[0]) grid_shape = property(lambda self: self.connection_list.shape[1:]) n_connections = property(lambda self: self.connection_list.sum()) @property def grid_n(self) -> int: "grid size as int, raises `AssertionError` if not square" assert self.grid_shape[0] == self.grid_shape[1], "only square mazes supported" return self.grid_shape[0] # ============================================================ # basic methods # ============================================================ def __eq__(self, other: object) -> bool: "equality check calls super" return super().__eq__(other) @staticmethod def heuristic(a: CoordTup, b: CoordTup) -> float: """return manhattan distance between two points""" return np.abs(a[0] - b[0]) + np.abs(a[1] - b[1]) def __hash__(self) -> int: """hash the connection list by converting connection list to bytes""" return hash(self.connection_list.tobytes()) def nodes_connected(self, a: Coord, b: Coord, /) -> bool: """returns whether two nodes are connected""" delta: Coord = b - a if np.abs(delta).sum() != 1: # return false if not even adjacent return False else: # test for wall dim: int = int(np.argmax(np.abs(delta))) clist_node: Coord = a if (delta.sum() > 0) else b return self.connection_list[dim, clist_node[0], clist_node[1]] def is_valid_path(self, path: CoordArray, empty_is_valid: bool = False) -> bool: """check if a path is valid""" # check path is not empty if len(path) == 0: return empty_is_valid # check all coords in bounds of maze if not np.all((path >= 0) & (path < self.grid_shape)): return False # check all nodes connected for i in range(len(path) - 1): if not self.nodes_connected(path[i], path[i + 1]): return False return True def coord_degrees(self) -> Int8[np.ndarray, "row col"]: """Returns an array with the connectivity degree of each coord. I.e., how many neighbors each coord has. 
""" int_conn: Int8[np.ndarray, "lattice_dim=2 row col"] = ( self.connection_list.astype(np.int8) ) degrees: Int8[np.ndarray, "row col"] = np.sum( int_conn, axis=0, ) # Connections to east and south degrees[:, 1:] += int_conn[1, :, :-1] # Connections to west degrees[1:, :] += int_conn[0, :-1, :] # Connections to north return degrees def get_coord_neighbors(self, c: Coord | CoordTup) -> CoordArray: """Returns an array of the neighboring, connected coords of `c`.""" c = np.array(c) # type: ignore[assignment] neighbors: list[Coord] = [ neighbor for neighbor in (c + NEIGHBORS_MASK) if ( (0 <= neighbor[0] < self.grid_shape[0]) # in x bounds and (0 <= neighbor[1] < self.grid_shape[1]) # in y bounds and self.nodes_connected(c, neighbor) # connected ) ] output: CoordArray = np.array(neighbors) if len(neighbors) > 0: assert output.shape == ( len(neighbors), 2, ), ( f"invalid shape: {output.shape}, expected ({len(neighbors)}, 2))\n{c = }\n{neighbors = }\n{self.as_ascii()}" ) return output def gen_connected_component_from(self, c: Coord) -> CoordArray: """return the connected component from a given coordinate""" # Stack for DFS stack: list[Coord] = [c] # Set to store visited nodes visited: set[CoordTup] = set() while stack: current_node: Coord = stack.pop() # this is fine since we know current_node is a coord and thus of length 2 visited.add(tuple(current_node)) # type: ignore[arg-type] # Get the neighbors of the current node neighbors = self.get_coord_neighbors(current_node) # Iterate over neighbors for neighbor in neighbors: if tuple(neighbor) not in visited: stack.append(neighbor) return np.array(list(visited)) def find_shortest_path( self, c_start: CoordTup | Coord, c_end: CoordTup | Coord, ) -> CoordArray: """find the shortest path between two coordinates, using A*""" c_start = tuple(c_start) # type: ignore[assignment] c_end = tuple(c_end) # type: ignore[assignment] g_score: dict[CoordTup, float] = ( dict() ) # cost of cheapest path to node from start currently known f_score: dict[CoordTup, float] = { c_start: 0.0, } # estimated total cost of path thru a node: f_score[c] := g_score[c] + heuristic(c, c_end) # init g_score[c_start] = 0.0 g_score[c_start] = self.heuristic(c_start, c_end) closed_vtx: set[CoordTup] = set() # nodes already evaluated # nodes to be evaluated # we need a set of the tuples, dont place the ints in the set open_vtx: set[CoordTup] = set([c_start]) # noqa: C405 source: dict[CoordTup, CoordTup] = ( dict() ) # node immediately preceding each node in the path (currently known shortest path) while open_vtx: # get lowest f_score node # mypy cant tell that c is of length 2 c_current: CoordTup = min(open_vtx, key=lambda c: f_score[tuple(c)]) # type: ignore[index] # f_current: float = f_score[c_current] # check if goal is reached if c_end == c_current: path: list[CoordTup] = [c_current] p_current: CoordTup = c_current while p_current in source: p_current = source[p_current] path.append(p_current) # ---------------------------------------------------------------------- # this is the only return statement return np.array(path[::-1]) # ---------------------------------------------------------------------- # close current node closed_vtx.add(c_current) open_vtx.remove(c_current) # update g_score of neighbors _np_neighbor: Coord for _np_neighbor in self.get_coord_neighbors(c_current): neighbor: CoordTup = tuple(_np_neighbor) if neighbor in closed_vtx: # already checked continue g_temp: float = g_score[c_current] + 1 # always 1 for maze neighbors if neighbor not in open_vtx: # found new 
vtx, so add open_vtx.add(neighbor) elif g_temp >= g_score[neighbor]: # if already knew about this one, but current g_score is worse, skip continue # store g_score and source source[neighbor] = c_current g_score[neighbor] = g_temp f_score[neighbor] = g_score[neighbor] + self.heuristic(neighbor, c_end) raise ValueError( "A solution could not be found!", f"{c_start = }, {c_end = }", self.as_ascii(), ) def get_nodes(self) -> CoordArray: """return a list of all nodes in the maze""" rows: Int[np.ndarray, "x y"] cols: Int[np.ndarray, "x y"] rows, cols = np.meshgrid( range(self.grid_shape[0]), range(self.grid_shape[1]), indexing="ij", ) nodes: CoordArray = np.vstack((rows.ravel(), cols.ravel())).T return nodes def get_connected_component(self) -> CoordArray: """get the largest (and assumed only nonsingular) connected component of the maze TODO: other connected components? """ if (self.generation_meta is None) or ( self.generation_meta.get("fully_connected", False) ): # for fully connected case, pick any two positions return self.get_nodes() else: # if metadata provided, use visited cells visited_cells: set[CoordTup] | None = self.generation_meta.get( "visited_cells", None, ) if visited_cells is None: # TODO: dynamically generate visited_cells? err_msg: str = f"a maze which is not marked as fully connected must have a visited_cells field in its generation_meta: {self.generation_meta}\n{self}\n{self.as_ascii()}" raise ValueError( err_msg, ) visited_cells_np: Int[np.ndarray, "N 2"] = np.array(list(visited_cells)) return visited_cells_np @typing.overload def generate_random_path( self, allowed_start: CoordList | None = None, allowed_end: CoordList | None = None, deadend_start: bool = False, deadend_end: bool = False, endpoints_not_equal: bool = False, except_on_no_valid_endpoint: typing.Literal[True] = True, ) -> CoordArray: ... @typing.overload def generate_random_path( self, allowed_start: CoordList | None = None, allowed_end: CoordList | None = None, deadend_start: bool = False, deadend_end: bool = False, endpoints_not_equal: bool = False, except_on_no_valid_endpoint: typing.Literal[False] = False, ) -> typing.Optional[CoordArray]: ... def generate_random_path( # noqa: C901 self, allowed_start: CoordList | None = None, allowed_end: CoordList | None = None, deadend_start: bool = False, deadend_end: bool = False, endpoints_not_equal: bool = False, except_on_no_valid_endpoint: bool = True, ) -> typing.Optional[CoordArray]: """return a path between randomly chosen start and end nodes within the connected component Note that setting special conditions on start and end positions might cause the same position to be selected as both start and end. # Parameters: - `allowed_start : CoordList | None` a list of allowed start positions. If `None`, any position in the connected component is allowed (defaults to `None`) - `allowed_end : CoordList | None` a list of allowed end positions. 
If `None`, any position in the connected component is allowed (defaults to `None`) - `deadend_start : bool` whether to ***force*** the start position to be a deadend (defaults to `False`) (defaults to `False`) - `deadend_end : bool` whether to ***force*** the end position to be a deadend (defaults to `False`) (defaults to `False`) - `endpoints_not_equal : bool` whether to ensure tha the start and end point are not the same (defaults to `False`) - `except_on_no_valid_endpoint : bool` whether to raise an error if no valid start or end positions are found if this is `False`, the function might return `None` and this must be handled by the caller (defaults to `True`) # Returns: - `CoordArray` a path between the selected start and end positions # Raises: - `NoValidEndpointException` : if no valid start or end positions are found, and `except_on_no_valid_endpoint` is `True` """ # we can't create a "path" in a single-node maze assert self.grid_shape[0] > 1 and self.grid_shape[1] > 1, ( # noqa: PT018 f"can't create path in single-node maze: {self.as_ascii()}" ) # get connected component connected_component: CoordArray = self.get_connected_component() # initialize start and end positions positions: Int[np.int8, "2 2"] # if no special conditions on start and end positions if (allowed_start, allowed_end, deadend_start, deadend_end) == ( None, None, False, False, ): try: positions = connected_component[ # type: ignore[assignment] np.random.choice( len(connected_component), size=2, replace=False, ) ] except ValueError as e: if except_on_no_valid_endpoint: err_msg: str = f"No valid start or end positions found because we could not sample from {connected_component = }" raise NoValidEndpointException( err_msg, ) from e return None return self.find_shortest_path(positions[0], positions[1]) # type: ignore[index] # handle special conditions connected_component_set: set[CoordTup] = set(map(tuple, connected_component)) # copy connected component set allowed_start_set: set[CoordTup] = connected_component_set.copy() allowed_end_set: set[CoordTup] = connected_component_set.copy() # filter by explicitly allowed start and end positions # '# type: ignore[assignment]' here because the returned tuple can be of any length if allowed_start is not None: allowed_start_set = set(map(tuple, allowed_start)) & connected_component_set # type: ignore[assignment] if allowed_end is not None: allowed_end_set = set(map(tuple, allowed_end)) & connected_component_set # type: ignore[assignment] # filter by forcing deadends if deadend_start: allowed_start_set = set( filter( lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_start_set, ), ) if deadend_end: allowed_end_set = set( filter( lambda x: len(self.get_coord_neighbors(x)) == 1, allowed_end_set, ), ) # check we have valid positions if len(allowed_start_set) == 0 or len(allowed_end_set) == 0: if except_on_no_valid_endpoint: err_msg = f"No valid start (or end?) 
positions found: {allowed_start_set = }, {allowed_end_set = }" raise NoValidEndpointException( err_msg, ) return None # randomly select start and end positions try: # ignore assignment here since `tuple()` returns a tuple of any length, but we know it will be ok start_pos: CoordTup = tuple( # type: ignore[assignment] list(allowed_start_set)[np.random.randint(0, len(allowed_start_set))], ) if endpoints_not_equal: # remove start position from end positions allowed_end_set.discard(start_pos) end_pos: CoordTup = tuple( # type: ignore[assignment] list(allowed_end_set)[np.random.randint(0, len(allowed_end_set))], ) except ValueError as e: if except_on_no_valid_endpoint: err_msg = f"No valid start or end positions found, maybe can't find an endpoint after we removed the start point: {allowed_start_set = }, {allowed_end_set = }" raise NoValidEndpointException( err_msg, ) from e return None return self.find_shortest_path(start_pos, end_pos) # ============================================================ # to and from adjacency list # ============================================================ def as_adj_list( self, shuffle_d0: bool = True, shuffle_d1: bool = True, ) -> Int8[np.ndarray, "conn start_end coord"]: """return the maze as an adjacency list, wraps `maze_dataset.token_utils.connection_list_to_adj_list`""" return connection_list_to_adj_list(self.connection_list, shuffle_d0, shuffle_d1) @classmethod def from_adj_list( cls, adj_list: Int8[np.ndarray, "conn start_end coord"], ) -> "LatticeMaze": """create a LatticeMaze from a list of connections > [!NOTE] > This has only been tested for square mazes. Might need to change some things if rectangular mazes are needed. """ # this is where it would probably break for rectangular mazes grid_n: int = adj_list.max() + 1 connection_list: ConnectionList = np.zeros( (2, grid_n, grid_n), dtype=np.bool_, ) for c_start, c_end in adj_list: # check that exactly 1 coordinate matches if (c_start == c_end).sum() != 1: raise ValueError("invalid connection") # get the direction d: int = (c_start != c_end).argmax() x: int y: int # pick whichever has the lesser value in the direction `d` if c_start[d] < c_end[d]: x, y = c_start else: x, y = c_end connection_list[d, x, y] = True return LatticeMaze( connection_list=connection_list, ) def as_adj_list_tokens(self) -> list[str | CoordTup]: """(deprecated!) 
turn the maze into adjacency list tokens, use `MazeTokenizerModular` instead""" warnings.warn( "`LatticeMaze.as_adj_list_tokens` will be removed from the public API in a future release.", TokenizerDeprecationWarning, ) return [ SPECIAL_TOKENS.ADJLIST_START, *chain.from_iterable( # type: ignore[list-item] [ [ tuple(c_s), SPECIAL_TOKENS.CONNECTOR, tuple(c_e), SPECIAL_TOKENS.ADJACENCY_ENDLINE, ] for c_s, c_e in self.as_adj_list() ], ), SPECIAL_TOKENS.ADJLIST_END, ] def _as_adj_list_tokens(self) -> list[str | CoordTup]: return [ SPECIAL_TOKENS.ADJLIST_START, *chain.from_iterable( # type: ignore[list-item] [ [ tuple(c_s), SPECIAL_TOKENS.CONNECTOR, tuple(c_e), SPECIAL_TOKENS.ADJACENCY_ENDLINE, ] for c_s, c_e in self.as_adj_list() ], ), SPECIAL_TOKENS.ADJLIST_END, ] def _as_coords_and_special_AOTP(self) -> list[CoordTup | str]: """turn the maze into adjacency list, origin, target, and solution -- keep coords as tuples""" output: list[CoordTup | str] = self._as_adj_list_tokens() # if getattr(self, "start_pos", None) is not None: if isinstance(self, TargetedLatticeMaze): output += self._get_start_pos_tokens() if isinstance(self, TargetedLatticeMaze): output += self._get_end_pos_tokens() if isinstance(self, SolvedMaze): output += self._get_solution_tokens() return output def _as_tokens( self, maze_tokenizer: "MazeTokenizer | TokenizationMode", ) -> list[str]: # type ignores here fine since we check the instance if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] if ( isinstance_by_type_name(maze_tokenizer, "MazeTokenizer") and maze_tokenizer.is_AOTP() # type: ignore[union-attr] ): coords_raw: list[CoordTup | str] = self._as_coords_and_special_AOTP() coords_processed: list[str] = maze_tokenizer.coords_to_strings( # type: ignore[union-attr] coords=coords_raw, when_noncoord="include", ) return coords_processed else: err_msg: str = f"Unsupported tokenizer type: {maze_tokenizer}" raise NotImplementedError(err_msg) def as_tokens( self, maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", ) -> list[str]: """serialize maze and solution to tokens""" if isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular"): return maze_tokenizer.to_tokens(self) # type: ignore[union-attr] else: return self._as_tokens(maze_tokenizer) # type: ignore[union-attr,arg-type] @classmethod def _from_tokens_AOTP( cls, tokens: list[str], maze_tokenizer: "MazeTokenizer | MazeTokenizerModular", ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": """create a LatticeMaze from a list of tokens""" # figure out what input format # ======================================== if tokens[0] == SPECIAL_TOKENS.ADJLIST_START: adj_list_tokens = get_adj_list_tokens(tokens) else: # If we're not getting a "complete" tokenized maze, assume it's just a the adjacency list tokens adj_list_tokens = tokens warnings.warn( "Assuming input is just adjacency list tokens, no special tokens found", ) # process edges for adjacency list # ======================================== edges: list[list[str]] = list_split( adj_list_tokens, SPECIAL_TOKENS.ADJACENCY_ENDLINE, ) coordinates: list[tuple[CoordTup, CoordTup]] = list() for e in edges: # skip last endline if len(e) != 0: # convert to coords, split start and end e_coords: list[str | CoordTup] = maze_tokenizer.strings_to_coords( e, when_noncoord="include", ) # this assertion depends on the tokenizer having exactly one token for the connector # which is also why we "include" above # the connector 
token is discarded below assert len(e_coords) == 3, f"invalid edge: {e = } {e_coords = }" # noqa: PLR2004 assert e_coords[1] == SPECIAL_TOKENS.CONNECTOR, ( f"invalid edge: {e = } {e_coords = }" ) e_coords_first: CoordTup = e_coords[0] # type: ignore[assignment] e_coords_last: CoordTup = e_coords[-1] # type: ignore[assignment] coordinates.append((e_coords_first, e_coords_last)) assert all(len(c) == DIM_2 for c in coordinates), ( f"invalid coordinates: {coordinates = }" ) adj_list: Int8[np.ndarray, "conn start_end coord"] = np.array(coordinates) assert tuple(adj_list.shape) == ( len(coordinates), 2, 2, ), f"invalid adj_list: {adj_list.shape = } {coordinates = }" output_maze: LatticeMaze = cls.from_adj_list(adj_list) # add start and end positions # ======================================== is_targeted: bool = False if all( x in tokens for x in ( SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END, SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END, ) ): start_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords( get_origin_tokens(tokens), when_noncoord="error", ) end_pos_list: list[CoordTup] = maze_tokenizer.strings_to_coords( get_target_tokens(tokens), when_noncoord="error", ) assert len(start_pos_list) == 1, ( f"invalid start_pos_list: {start_pos_list = }" ) assert len(end_pos_list) == 1, f"invalid end_pos_list: {end_pos_list = }" start_pos: CoordTup = start_pos_list[0] end_pos: CoordTup = end_pos_list[0] output_maze = TargetedLatticeMaze.from_lattice_maze( lattice_maze=output_maze, start_pos=start_pos, end_pos=end_pos, ) is_targeted = True if all( x in tokens for x in (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END) ): assert is_targeted, "maze must be targeted to have a solution" solution: list[CoordTup] = maze_tokenizer.strings_to_coords( get_path_tokens(tokens, trim_end=True), when_noncoord="error", ) output_maze = SolvedMaze.from_targeted_lattice_maze( # HACK: I think this is fine, but im not sure targeted_lattice_maze=output_maze, # type: ignore[arg-type] solution=solution, ) return output_maze # TODO: any way to get return type hinting working for this? @classmethod def from_tokens( cls, tokens: list[str], maze_tokenizer: "MazeTokenizer | TokenizationMode | MazeTokenizerModular", ) -> "LatticeMaze | TargetedLatticeMaze | SolvedMaze": """Constructs a maze from a tokenization. Only legacy tokenizers and their `MazeTokenizerModular` analogs are supported. """ # HACK: type ignores here fine since we check the instance if isinstance_by_type_name(maze_tokenizer, "TokenizationMode"): maze_tokenizer = maze_tokenizer.to_legacy_tokenizer() # type: ignore[union-attr] if ( isinstance_by_type_name(maze_tokenizer, "MazeTokenizerModular") and not maze_tokenizer.is_legacy_equivalent() # type: ignore[union-attr] ): err_msg: str = f"Only legacy tokenizers and their exact `MazeTokenizerModular` analogs supported, not {maze_tokenizer}." 
raise NotImplementedError( err_msg, ) if isinstance(tokens, str): tokens = tokens.split() if maze_tokenizer.is_AOTP(): # type: ignore[union-attr] return cls._from_tokens_AOTP(tokens, maze_tokenizer) # type: ignore[arg-type] else: raise NotImplementedError("only AOTP tokenization is supported") # ============================================================ # to and from pixels # ============================================================ def _as_pixels_bw(self) -> BinaryPixelGrid: assert self.lattice_dim == DIM_2, "only 2D mazes are supported" # Create an empty pixel grid with walls pixel_grid: Int[np.ndarray, "x y"] = np.full( (self.grid_shape[0] * 2 + 1, self.grid_shape[1] * 2 + 1), False, dtype=np.bool_, ) # Set white nodes pixel_grid[1::2, 1::2] = True # Set white connections (downward) for i, row in enumerate(self.connection_list[0]): for j, connected in enumerate(row): if connected: pixel_grid[i * 2 + 2, j * 2 + 1] = True # Set white connections (rightward) for i, row in enumerate(self.connection_list[1]): for j, connected in enumerate(row): if connected: pixel_grid[i * 2 + 1, j * 2 + 2] = True return pixel_grid def as_pixels( self, show_endpoints: bool = True, show_solution: bool = True, ) -> PixelGrid: """convert the maze to a pixel grid - useful as a simpler way of plotting the maze than the more complex `MazePlot` - the same underlying representation as `as_ascii` but as an image - used in `RasterizedMazeDataset`, which mimics the mazes in https://github.com/aks2203/easy-to-hard-data """ # HACK: lots of `# type: ignore[attr-defined]` here since its defined for any `LatticeMaze` # but solution, start_pos, end_pos not always defined # but its fine since we explicitly check the type if show_solution and not show_endpoints: raise ValueError("show_solution=True requires show_endpoints=True") # convert original bool pixel grid to RGB pixel_grid_bw: BinaryPixelGrid = self._as_pixels_bw() pixel_grid: PixelGrid = np.full( (*pixel_grid_bw.shape, 3), PixelColors.WALL, dtype=np.uint8, ) pixel_grid[pixel_grid_bw == True] = PixelColors.OPEN # noqa: E712 if self.__class__ == LatticeMaze: return pixel_grid # set endpoints for TargetedLatticeMaze if self.__class__ == TargetedLatticeMaze: if show_endpoints: pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] PixelColors.START ) pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] PixelColors.END ) return pixel_grid # set solution -- we only reach this part if `self.__class__ == SolvedMaze` if show_solution: for coord in self.solution: # type: ignore[attr-defined] pixel_grid[coord[0] * 2 + 1, coord[1] * 2 + 1] = PixelColors.PATH # set pixels between coords for index, coord in enumerate(self.solution[:-1]): # type: ignore[attr-defined] next_coord = self.solution[index + 1] # type: ignore[attr-defined] # check they are adjacent using norm assert np.linalg.norm(np.array(coord) - np.array(next_coord)) == 1, ( f"Coords {coord} and {next_coord} are not adjacent" ) # set pixel between them pixel_grid[ coord[0] * 2 + 1 + next_coord[0] - coord[0], coord[1] * 2 + 1 + next_coord[1] - coord[1], ] = PixelColors.PATH # set endpoints (again, since path would overwrite them) pixel_grid[self.start_pos[0] * 2 + 1, self.start_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] PixelColors.START ) pixel_grid[self.end_pos[0] * 2 + 1, self.end_pos[1] * 2 + 1] = ( # type: ignore[attr-defined] PixelColors.END ) return pixel_grid @classmethod def _from_pixel_grid_bw( cls, pixel_grid: 
BinaryPixelGrid, ) -> tuple[ConnectionList, tuple[int, int]]: grid_shape: tuple[int, int] = ( pixel_grid.shape[0] // 2, pixel_grid.shape[1] // 2, ) connection_list: ConnectionList = np.zeros((2, *grid_shape), dtype=np.bool_) # Extract downward connections connection_list[0] = pixel_grid[2::2, 1::2] # Extract rightward connections connection_list[1] = pixel_grid[1::2, 2::2] return connection_list, grid_shape @classmethod def _from_pixel_grid_with_positions( cls, pixel_grid: PixelGrid | BinaryPixelGrid, marked_positions: dict[str, RGB], ) -> tuple[ConnectionList, tuple[int, int], dict[str, CoordArray]]: # Convert RGB pixel grid to Bool pixel grid # error: Incompatible types in assignment (expression has type # "numpy.bool[builtins.bool] | ndarray[tuple[int, ...], dtype[numpy.bool[builtins.bool]]]", # variable has type "ndarray[Any, Any]") [assignment] pixel_grid_bw: BinaryPixelGrid = ~np.all( # type: ignore[assignment] pixel_grid == PixelColors.WALL, axis=-1, ) connection_list: ConnectionList grid_shape: tuple[int, int] connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid_bw) # Find any marked positions out_positions: dict[str, CoordArray] = dict() for key, color in marked_positions.items(): pos_temp: Int[np.ndarray, "x y"] = np.argwhere( np.all(pixel_grid == color, axis=-1), ) pos_save: list[CoordTup] = list() for pos in pos_temp: # if it is a coordinate and not connection (transform position, %2==1) if pos[0] % 2 == 1 and pos[1] % 2 == 1: pos_save.append((pos[0] // 2, pos[1] // 2)) out_positions[key] = np.array(pos_save) return connection_list, grid_shape, out_positions @classmethod def from_pixels( cls, pixel_grid: PixelGrid, ) -> "LatticeMaze": """create a LatticeMaze from a pixel grid. reverse of `as_pixels` # Raises: - `ValueError` : if the pixel grid cannot be cast to a `LatticeMaze` -- it's probably a `TargetedLatticeMaze` or `SolvedMaze` """ connection_list: ConnectionList grid_shape: tuple[int, int] # if a binary pixel grid, return regular LatticeMaze if len(pixel_grid.shape) == 2: # noqa: PLR2004 connection_list, grid_shape = cls._from_pixel_grid_bw(pixel_grid) return LatticeMaze(connection_list=connection_list) # otherwise, detect and check it's valid cls_detected: typing.Type[LatticeMaze] = detect_pixels_type(pixel_grid) if cls not in cls_detected.__mro__: err_msg: str = f"Pixel grid cannot be cast to {cls.__name__ = }, detected type {cls_detected.__name__ = }" raise ValueError( err_msg, ) ( connection_list, grid_shape, marked_pos, ) = cls._from_pixel_grid_with_positions( pixel_grid=pixel_grid, marked_positions=dict( start=PixelColors.START, end=PixelColors.END, solution=PixelColors.PATH, ), ) # if we wanted a LatticeMaze, return it if cls == LatticeMaze: return LatticeMaze(connection_list=connection_list) # otherwise, keep going temp_maze: LatticeMaze = LatticeMaze(connection_list=connection_list) # start and end pos start_pos_arr, end_pos_arr = marked_pos["start"], marked_pos["end"] assert start_pos_arr.shape == ( 1, 2, ), ( f"start_pos_arr {start_pos_arr} has shape {start_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" ) assert end_pos_arr.shape == ( 1, 2, ), ( f"end_pos_arr {end_pos_arr} has shape {end_pos_arr.shape}, expected shape (1, 2) -- a single coordinate" ) start_pos: Coord = start_pos_arr[0] end_pos: Coord = end_pos_arr[0] # return a TargetedLatticeMaze if that's what we wanted if cls == TargetedLatticeMaze: return TargetedLatticeMaze( connection_list=connection_list, start_pos=start_pos, end_pos=end_pos, ) # raw solution, only contains 
path elements and not start or end solution_raw: CoordArray = marked_pos["solution"] if len(solution_raw.shape) == 2: # noqa: PLR2004 assert solution_raw.shape[1] == 2, ( # noqa: PLR2004 f"solution {solution_raw} has shape {solution_raw.shape}, expected shape (n, 2)" ) elif solution_raw.shape == (0,): # the solution and end should be immediately adjacent assert np.sum(np.abs(start_pos - end_pos)) == 1, ( f"start_pos {start_pos} and end_pos {end_pos} are not adjacent, but no solution was given" ) # order the solution, by creating a list from the start to the end # add end pos, since we will iterate over all these starting from the start pos solution_raw_list: list[CoordTup] = [tuple(c) for c in solution_raw] + [ tuple(end_pos), ] # solution starts with start point solution: list[CoordTup] = [tuple(start_pos)] while solution[-1] != tuple(end_pos): # use `get_coord_neighbors` to find connected neighbors neighbors: CoordArray = temp_maze.get_coord_neighbors(solution[-1]) # TODO: make this less ugly assert (len(neighbors.shape) == 2) and (neighbors.shape[1] == 2), ( # noqa: PT018, PLR2004 f"neighbors {neighbors} has shape {neighbors.shape}, expected shape (n, 2)\n{neighbors = }\n{solution = }\n{solution_raw = }\n{temp_maze.as_ascii()}" ) # neighbors = neighbors[:, [1, 0]] # filter out neighbors that are not in the raw solution neighbors_filtered: CoordArray = np.array( [ coord for coord in neighbors if ( tuple(coord) in solution_raw_list and tuple(coord) not in solution ) ], ) # assert only one element is left, and then add it to the solution assert neighbors_filtered.shape == ( 1, 2, ), ( f"neighbors_filtered has shape {neighbors_filtered.shape}, expected shape (1, 2)\n{neighbors = }\n{neighbors_filtered = }\n{solution = }\n{solution_raw_list = }\n{temp_maze.as_ascii()}" ) solution.append(tuple(neighbors_filtered[0])) # assert the solution is complete assert solution[0] == tuple(start_pos), ( f"solution {solution} does not start at start_pos {start_pos}" ) assert solution[-1] == tuple(end_pos), ( f"solution {solution} does not end at end_pos {end_pos}" ) return cls( connection_list=np.array(connection_list), solution=np.array(solution), # type: ignore[call-arg] ) # ============================================================ # to and from ASCII # ============================================================ def _as_ascii_grid(self) -> Shaped[np.ndarray, "x y"]: # Get the pixel grid using to_pixels(). pixel_grid: Bool[np.ndarray, "x y"] = self._as_pixels_bw() # Replace pixel values with ASCII characters. 
ascii_grid: Shaped[np.ndarray, "x y"] = np.full( pixel_grid.shape, AsciiChars.WALL, dtype=str, ) ascii_grid[pixel_grid == True] = AsciiChars.OPEN # noqa: E712 return ascii_grid def as_ascii( self, show_endpoints: bool = True, show_solution: bool = True, ) -> str: """return an ASCII grid of the maze useful for debugging in the terminal, or as it's own format can be reversed with `LatticeMaze.from_ascii()` """ ascii_grid: Shaped[np.ndarray, "x y"] = self._as_ascii_grid() pixel_grid: PixelGrid = self.as_pixels( show_endpoints=show_endpoints, show_solution=show_solution, ) chars_replace: tuple = tuple() if show_endpoints: chars_replace += (AsciiChars.START, AsciiChars.END) if show_solution: chars_replace += (AsciiChars.PATH,) for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): if ascii_char in chars_replace: ascii_grid[(pixel_grid == pixel_color).all(axis=-1)] = ascii_char return "\n".join("".join(row) for row in ascii_grid) @classmethod def from_ascii(cls, ascii_str: str) -> "LatticeMaze": "get a `LatticeMaze` from an ASCII representation (reverses `LaticeMaze.as_ascii`)" lines: list[str] = ascii_str.strip().split("\n") lines = [line.strip() for line in lines] ascii_grid: Shaped[np.ndarray, "x y"] = np.array( [list(line) for line in lines], dtype=str, ) pixel_grid: PixelGrid = np.zeros((*ascii_grid.shape, 3), dtype=np.uint8) for ascii_char, pixel_color in ASCII_PIXEL_PAIRINGS.items(): pixel_grid[ascii_grid == ascii_char] = pixel_color return cls.from_pixels(pixel_grid) # type ignore here even though theyre all frozen # maybe `SerializeableDataclass` itself is not frozen, but thats an ABC # error: Cannot inherit frozen dataclass from a non-frozen one [misc] @serializable_dataclass(frozen=True, kw_only=True) class TargetedLatticeMaze(LatticeMaze): # type: ignore[misc] """A LatticeMaze with a start and end position""" # this jank is so that SolvedMaze can inherit from this class without needing arguments for start_pos and end_pos # type ignore here because even though its a kw-only dataclass, # mypy doesn't like that non-default arguments are after default arguments start_pos: Coord = serializable_field( # type: ignore[misc] assert_type=False, ) end_pos: Coord = serializable_field( # type: ignore[misc] assert_type=False, ) def __post_init__(self) -> None: "post init converts start and end pos to numpy arrays, checks they exist and are in bounds" # make things numpy arrays (very jank to override frozen dataclass) self.__dict__["start_pos"] = np.array(self.start_pos) self.__dict__["end_pos"] = np.array(self.end_pos) assert self.start_pos is not None assert self.end_pos is not None # check that start and end are in bounds if ( self.start_pos[0] >= self.grid_shape[0] or self.start_pos[1] >= self.grid_shape[1] ): err_msg: str = f"start_pos {self.start_pos} is out of bounds for grid shape {self.grid_shape}" raise ValueError( err_msg, ) if ( self.end_pos[0] >= self.grid_shape[0] or self.end_pos[1] >= self.grid_shape[1] ): err_msg = f"end_pos {self.end_pos = } is out of bounds for grid shape {self.grid_shape = }" raise ValueError( err_msg, ) def __eq__(self, other: object) -> bool: "check equality, calls parent class equality check" return super().__eq__(other) def _get_start_pos_tokens(self) -> list[str | CoordTup]: return [ SPECIAL_TOKENS.ORIGIN_START, tuple(self.start_pos), SPECIAL_TOKENS.ORIGIN_END, ] def get_start_pos_tokens(self) -> list[str | CoordTup]: "(deprecated!) 
return the start position as a list of tokens" warnings.warn( "`TargetedLatticeMaze.get_start_pos_tokens` will be removed from the public API in a future release.", TokenizerDeprecationWarning, ) return self._get_start_pos_tokens() def _get_end_pos_tokens(self) -> list[str | CoordTup]: return [ SPECIAL_TOKENS.TARGET_START, tuple(self.end_pos), SPECIAL_TOKENS.TARGET_END, ] def get_end_pos_tokens(self) -> list[str | CoordTup]: "(deprecated!) return the end position as a list of tokens" warnings.warn( "`TargetedLatticeMaze.get_end_pos_tokens` will be removed from the public API in a future release.", TokenizerDeprecationWarning, ) return self._get_end_pos_tokens() @classmethod def from_lattice_maze( cls, lattice_maze: LatticeMaze, start_pos: Coord | CoordTup, end_pos: Coord | CoordTup, ) -> "TargetedLatticeMaze": "get a `TargetedLatticeMaze` from a `LatticeMaze` by specifying start and end positions" return cls( connection_list=lattice_maze.connection_list, start_pos=np.array(start_pos), end_pos=np.array(end_pos), generation_meta=lattice_maze.generation_meta, ) @serializable_dataclass(frozen=True, kw_only=True) class SolvedMaze(TargetedLatticeMaze): # type: ignore[misc] """Stores a maze and a solution""" solution: CoordArray = serializable_field( # type: ignore[misc] assert_type=False, ) def __init__( self, connection_list: ConnectionList, solution: CoordArray, generation_meta: dict | None = None, start_pos: Coord | None = None, end_pos: Coord | None = None, allow_invalid: bool = False, ) -> None: """Create a SolvedMaze from a connection list and a solution > DOCS: better documentation for this init method """ # figure out the solution solution_valid: bool = False if solution is not None: solution = np.array(solution) # note that a path length of 1 here is valid, since the start and end pos could be the same if (solution.shape[0] > 0) and (solution.shape[1] == 2): # noqa: PLR2004 solution_valid = True if not solution_valid and not allow_invalid: err_msg: str = f"invalid solution: {solution.shape = } {solution = } {solution_valid = } {allow_invalid = }" raise ValueError( err_msg, f"{connection_list = }", ) # init the TargetedLatticeMaze super().__init__( connection_list=connection_list, generation_meta=generation_meta, # TODO: the argument type is stricter than the expected type but it still fails? # error: Argument "start_pos" to "__init__" of "TargetedLatticeMaze" has incompatible type # "ndarray[tuple[int, ...], dtype[Any]] | None"; expected "ndarray[Any, Any]" [arg-type] start_pos=np.array(solution[0]) if solution_valid else None, # type: ignore[arg-type] end_pos=np.array(solution[-1]) if solution_valid else None, # type: ignore[arg-type] ) self.__dict__["solution"] = solution # adjust the endpoints if not allow_invalid: if start_pos is not None: assert np.array_equal(np.array(start_pos), self.start_pos), ( f"when trying to create a SolvedMaze, the given start_pos does not match the one in the solution: given={start_pos}, solution={self.start_pos}" ) if end_pos is not None: assert np.array_equal(np.array(end_pos), self.end_pos), ( f"when trying to create a SolvedMaze, the given end_pos does not match the one in the solution: given={end_pos}, solution={self.end_pos}" ) # TODO: assert the path does not backtrack, walk through walls, etc? 
def __eq__(self, other: object) -> bool: "check equality, calls parent class equality check" return super().__eq__(other) def __hash__(self) -> int: "hash the `SolvedMaze` by hashing a tuple of the connection list and solution arrays as bytes" return hash((self.connection_list.tobytes(), self.solution.tobytes())) def _get_solution_tokens(self) -> list[str | CoordTup]: return [ SPECIAL_TOKENS.PATH_START, *[tuple(c) for c in self.solution], SPECIAL_TOKENS.PATH_END, ] def get_solution_tokens(self) -> list[str | CoordTup]: "(deprecated!) return the solution as a list of tokens" warnings.warn( "`LatticeMaze.get_solution_tokens` is deprecated.", TokenizerDeprecationWarning, ) return self._get_solution_tokens() # for backwards compatibility @property def maze(self) -> LatticeMaze: "(deprecated!) return the maze without the solution" warnings.warn( "`maze` is deprecated, SolvedMaze now inherits from LatticeMaze.", DeprecationWarning, ) return LatticeMaze(connection_list=self.connection_list) # type ignore here since we're overriding a method with a different signature @classmethod def from_lattice_maze( # type: ignore[override] cls, lattice_maze: LatticeMaze, solution: list[CoordTup] | CoordArray, ) -> "SolvedMaze": "get a `SolvedMaze` from a `LatticeMaze` by specifying a solution" return cls( connection_list=lattice_maze.connection_list, solution=np.array(solution), generation_meta=lattice_maze.generation_meta, ) @classmethod def from_targeted_lattice_maze( cls, targeted_lattice_maze: TargetedLatticeMaze, solution: list[CoordTup] | CoordArray | None = None, ) -> "SolvedMaze": """solves the given targeted lattice maze and returns a SolvedMaze""" if solution is None: solution = targeted_lattice_maze.find_shortest_path( targeted_lattice_maze.start_pos, targeted_lattice_maze.end_pos, ) return cls( connection_list=targeted_lattice_maze.connection_list, solution=np.array(solution), generation_meta=targeted_lattice_maze.generation_meta, ) def get_solution_forking_points( self, always_include_endpoints: bool = False, ) -> tuple[list[int], CoordArray]: """coordinates and their indices from the solution where a fork is present - if the start point is not a dead end, this counts as a fork - if the end point is not a dead end, this counts as a fork """ output_idxs: list[int] = list() output_coords: list[CoordTup] = list() for idx, coord in enumerate(self.solution): # more than one choice for first coord, or more than 2 for any other # since the previous coord doesn't count as a choice is_endpoint: bool = idx == 0 or idx == self.solution.shape[0] - 1 threshold: int = 1 if is_endpoint else 2 if self.get_coord_neighbors(coord).shape[0] > threshold or ( is_endpoint and always_include_endpoints ): output_idxs.append(idx) output_coords.append(coord) return output_idxs, np.array(output_coords) def get_solution_path_following_points(self) -> tuple[list[int], CoordArray]: """coordinates from the solution where there is only a single (non-backtracking) point to move to returns the complement of `get_solution_forking_points` from the path """ forks_idxs, _ = self.get_solution_forking_points() # HACK: idk why type ignore here return ( # type: ignore[return-value] np.delete(np.arange(self.solution.shape[0]), forks_idxs, axis=0), np.delete(self.solution, forks_idxs, axis=0), ) def detect_pixels_type(data: PixelGrid) -> typing.Type[LatticeMaze]: """Detects the type of pixels data by checking for the presence of start and end pixels""" if color_in_pixel_grid(data, PixelColors.START) or color_in_pixel_grid( data, 
PixelColors.END, ): if color_in_pixel_grid(data, PixelColors.PATH): return SolvedMaze else: return TargetedLatticeMaze else: return LatticeMaze def _remove_isolated_cells( image: Int[np.ndarray, "RGB x y"], ) -> Int[np.ndarray, "RGB x y"]: """Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides.""" # Create a binary mask where True represents walls wall_mask = np.all(image == PixelColors.WALL, axis=-1) # Pad the wall mask to handle edge cases padded_wall_mask = np.pad( wall_mask, ((1, 1), (1, 1)), mode="constant", constant_values=True, ) # Check neighbors in all four directions isolated_mask = ( padded_wall_mask[1:-1, 2:] # right & padded_wall_mask[1:-1, :-2] # left & padded_wall_mask[2:, 1:-1] # down & padded_wall_mask[:-2, 1:-1] # up ) # Combine with non-wall mask to only affect open cells non_wall_mask = ~wall_mask isolated_mask = isolated_mask & non_wall_mask # Create the output image output_image = image.copy() output_image[isolated_mask] = PixelColors.WALL return output_image _RIC_PADS: dict = { "left": ((1, 0), (0, 0)), "right": ((0, 1), (0, 0)), "up": ((0, 0), (1, 0)), "down": ((0, 0), (0, 1)), } # Define slices for each direction _RIC_SLICES: dict = { "left": (slice(1, None), slice(None, None)), "right": (slice(None, -1), slice(None, None)), "up": (slice(None, None), slice(1, None)), "down": (slice(None, None), slice(None, -1)), } # TODO: figure out why this function doesnt work, or maybe just get rid of it # def _remove_isolated_cells_old( # image: Int[np.ndarray, "RGB x y"], # ) -> Int[np.ndarray, "RGB x y"]: # """ # Removes isolated cells from an image. An isolated cell is a cell that is surrounded by walls on all sides. # """ # warnings.warn("this functin doesn't work and I have no idea why!!!") # masks: dict[str, np.ndarray] = { # d: np.all( # np.pad( # image[_RIC_SLICES[d][0], _RIC_SLICES[d][1], :] == PixelColors.WALL, # np.array((*_RIC_PADS[d], (0, 0)), dtype=np.int8), # mode="constant", # constant_values=True, # ), # axis=2, # ) # for d in _RIC_SLICES.keys() # } # # Create a mask for non-wall cells # mask_non_wall = np.all(image != PixelColors.WALL, axis=2) # # print(f"{mask_non_wall.shape = }") # # print(f"{ {k: masks[k].shape for k in masks.keys()} = }") # # print(f"{mask_non_wall = }") # # print(f"{masks['down'] = }") # # Combine the masks # mask = mask_non_wall & masks["left"] & masks["right"] & masks["up"] & masks["down"] # # Apply the mask # output_image = np.where( # np.stack([mask] * 3, axis=-1), # PixelColors.WALL, # image, # ) # return output_image ``````{ end_of_file="maze_dataset/maze/lattice_maze.py" } ``````{ path="maze_dataset/plotting/__init__.py" } """utilities for plotting mazes and printing tokens - any `LatticeMaze` or `SolvedMaze` comes with a `as_pixels()` method that returns a 2D numpy array of pixel values, but this is somewhat limited - `MazePlot` is a class that can be used to plot mazes and paths in a more customizable way - `print_tokens` contains utilities for printing tokens, colored by their type, position, or some custom weights (i.e. 
attention weights) """ from maze_dataset.plotting.plot_dataset import plot_dataset_mazes, print_dataset_mazes from maze_dataset.plotting.plot_maze import DEFAULT_FORMATS, MazePlot, PathFormat from maze_dataset.plotting.print_tokens import ( color_maze_tokens_AOTP, color_tokens_cmap, color_tokens_rgb, ) __all__ = [ # submodules "plot_dataset", "plot_maze", "plot_svg_fancy", "plot_tokens", "print_tokens", # imports "plot_dataset_mazes", "print_dataset_mazes", "DEFAULT_FORMATS", "MazePlot", "PathFormat", "color_tokens_cmap", "color_maze_tokens_AOTP", "color_tokens_rgb", ] ``````{ end_of_file="maze_dataset/plotting/__init__.py" } ``````{ path="maze_dataset/plotting/plot_dataset.py" } """`plot_dataset_mazes` will plot several mazes using `as_pixels` `print_dataset_mazes` will use `as_ascii` to print several mazes """ import matplotlib.pyplot as plt # type: ignore[import] from maze_dataset.dataset.maze_dataset import MazeDataset def plot_dataset_mazes( ds: MazeDataset, count: int | None = None, figsize_mult: tuple[float, float] = (1.0, 2.0), title: bool | str = True, ) -> tuple | None: "plot `count` mazes from the dataset `d` in a single figure using `SolvedMaze.as_pixels()`" count = count or len(ds) if count == 0: print("No mazes to plot for dataset") return None fig, axes = plt.subplots( 1, count, figsize=(count * figsize_mult[0], figsize_mult[1]), ) if count == 1: axes = [axes] for i in range(count): axes[i].imshow(ds[i].as_pixels()) # remove ticks axes[i].set_xticks([]) axes[i].set_yticks([]) # set title if title: if isinstance(title, str): fig.suptitle(title) else: kwargs: dict = { "grid_n": ds.cfg.grid_n, # "n_mazes": ds.cfg.n_mazes, **ds.cfg.maze_ctor_kwargs, } fig.suptitle( f"{ds.cfg.to_fname()}\n{ds.cfg.maze_ctor.__name__}({', '.join(f'{k}={v}' for k, v in kwargs.items())})", ) # tight layout fig.tight_layout() # remove whitespace between title and subplots fig.subplots_adjust(top=1.0) return fig, axes def print_dataset_mazes(ds: MazeDataset, count: int | None = None) -> None: "print ascii representation of `count` mazes from the dataset `d`" count = count or len(ds) if count == 0: print("No mazes to print for dataset") return for i in range(count): print(ds[i].as_ascii(), "\n\n-----\n") ``````{ end_of_file="maze_dataset/plotting/plot_dataset.py" } ``````{ path="maze_dataset/plotting/plot_maze.py" } """provides `MazePlot`, which has many tools for plotting mazes with multiple paths, colored nodes, and more""" from __future__ import annotations # for type hinting self as return value import warnings from copy import deepcopy from dataclasses import dataclass from typing import Sequence import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np from jaxtyping import Bool, Float from maze_dataset.constants import Coord, CoordArray, CoordList from maze_dataset.maze import ( LatticeMaze, SolvedMaze, TargetedLatticeMaze, ) LARGE_NEGATIVE_NUMBER: float = -1e10 @dataclass(kw_only=True) class PathFormat: """formatting options for path plot""" label: str | None = None fmt: str = "o" color: str | None = None cmap: str | None = None line_width: float | None = None quiver_kwargs: dict | None = None def combine(self, other: PathFormat) -> PathFormat: """combine with other PathFormat object, overwriting attributes with non-None values. returns a modified copy of self. """ output: PathFormat = deepcopy(self) for key, value in other.__dict__.items(): if key == "path": err_msg: str = f"Cannot overwrite path attribute! 
{self = }, {other = }" raise ValueError( err_msg, ) if value is not None: setattr(output, key, value) return output # styled path @dataclass class StyledPath(PathFormat): "a `StyledPath` is a `PathFormat` with a specific path" path: CoordArray DEFAULT_FORMATS: dict[str, PathFormat] = { "true": PathFormat( label="true path", fmt="--", color="red", line_width=2.5, quiver_kwargs=None, ), "predicted": PathFormat( label=None, fmt=":", color=None, line_width=2, quiver_kwargs={"width": 0.015}, ), } def process_path_input( path: CoordList | CoordArray | StyledPath, _default_key: str, path_fmt: PathFormat | None = None, **kwargs, ) -> StyledPath: "convert a path, which might be a list or array of coords, into a `StyledPath`" styled_path: StyledPath if isinstance(path, StyledPath): styled_path = path elif isinstance(path, np.ndarray): styled_path = StyledPath(path=path) # add default formatting styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key]) elif isinstance(path, list): styled_path = StyledPath(path=np.array(path)) # add default formatting styled_path = styled_path.combine(DEFAULT_FORMATS[_default_key]) else: err_msg: str = ( f"Expected CoordList, CoordArray or StyledPath, got {type(path)}: {path}" ) raise TypeError( err_msg, ) # add formatting from path_fmt if path_fmt is not None: styled_path = styled_path.combine(path_fmt) # add formatting from kwargs for key, value in kwargs.items(): setattr(styled_path, key, value) return styled_path DEFAULT_PREDICTED_PATH_COLORS: list[str] = [ "tab:orange", "tab:olive", "sienna", "mediumseagreen", "tab:purple", "slategrey", ] class MazePlot: """Class for displaying mazes and paths""" def __init__(self, maze: LatticeMaze, unit_length: int = 14) -> None: """UNIT_LENGTH: Set ratio between node size and wall thickness in image. Wall thickness is fixed to 1px A "unit" consists of a single node and the right and lower connection/wall. Example: ul = 14 yields 13:1 ratio between node size and wall thickness """ self.unit_length: int = unit_length self.maze: LatticeMaze = maze self.true_path: StyledPath | None = None self.predicted_paths: list[StyledPath] = [] self.node_values: Float[np.ndarray, "grid_n grid_n"] = None self.custom_node_value_flag: bool = False self.node_color_map: str = "Blues" self.target_token_coord: Coord = None self.preceding_tokens_coords: CoordArray = None self.colormap_center: float | None = None self.cbar_ax = None self.marked_coords: list[tuple[Coord, dict]] = list() self.marker_kwargs_current: dict = dict( marker="s", color="green", ms=12, ) self.marker_kwargs_next: dict = dict( marker="P", color="green", ms=12, ) if isinstance(maze, SolvedMaze): self.add_true_path(maze.solution) else: if isinstance(maze, TargetedLatticeMaze): self.add_true_path(SolvedMaze.from_targeted_lattice_maze(maze).solution) @property def solved_maze(self) -> SolvedMaze: "get the underlying `SolvedMaze` object" if self.true_path is None: raise ValueError( "Cannot return SolvedMaze object without true path. 
Add true path with add_true_path method.", ) return SolvedMaze.from_lattice_maze( lattice_maze=self.maze, solution=self.true_path.path, ) def add_true_path( self, path: CoordList | CoordArray | StyledPath, path_fmt: PathFormat | None = None, **kwargs, ) -> MazePlot: "add a true path to the maze with optional formatting" self.true_path = process_path_input( path=path, _default_key="true", path_fmt=path_fmt, **kwargs, ) return self def add_predicted_path( self, path: CoordList | CoordArray | StyledPath, path_fmt: PathFormat | None = None, **kwargs, ) -> MazePlot: """Receive predicted path and formatting preferences from input and save in predicted_path list. Default formatting depends on the number of paths already saved in the predicted path list. """ styled_path: StyledPath = process_path_input( path=path, _default_key="predicted", path_fmt=path_fmt, **kwargs, ) # set default label and color if not specified if styled_path.label is None: styled_path.label = f"predicted path {len(self.predicted_paths) + 1}" if styled_path.color is None: color_num: int = len(self.predicted_paths) % len( DEFAULT_PREDICTED_PATH_COLORS, ) styled_path.color = DEFAULT_PREDICTED_PATH_COLORS[color_num] self.predicted_paths.append(styled_path) return self def add_multiple_paths( self, path_list: Sequence[CoordList | CoordArray | StyledPath], ) -> MazePlot: """Function for adding multiple paths to MazePlot at once. > DOCS: what are the two ways? This can be done in two ways: 1. Passing a list of """ for path in path_list: self.add_predicted_path(path) return self def add_node_values( self, node_values: Float[np.ndarray, "grid_n grid_n"], color_map: str = "Blues", target_token_coord: Coord | None = None, preceeding_tokens_coords: CoordArray = None, colormap_center: float | None = None, colormap_max: float | None = None, hide_colorbar: bool = False, ) -> MazePlot: """add node values to the maze for visualization as a heatmap > DOCS: what are these arguments? # Parameters: - `node_values : Float[np.ndarray, "grid_n grid_n"]` - `color_map : str` (defaults to `"Blues"`) - `target_token_coord : Coord | None` (defaults to `None`) - `preceeding_tokens_coords : CoordArray` (defaults to `None`) - `colormap_center : float | None` (defaults to `None`) - `colormap_max : float | None` (defaults to `None`) - `hide_colorbar : bool` (defaults to `False`) # Returns: - `MazePlot` """ assert node_values.shape == self.maze.grid_shape, ( "Please pass node values of the same shape as LatticeMaze.grid_shape" ) # assert np.min(node_values) >= 0, "Please pass non-negative node values only." 
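# The remainder of this method only records state: the raw node values, the colormap
# configuration (name, center, max, colorbar visibility), and any coordinates to mark;
# normalization and colorbar drawing happen later in `_plot_maze()`.
#
# Usage sketch (hypothetical values), assuming `maze` is a `LatticeMaze` with square
# `grid_shape`:
#
#   values = np.random.rand(*maze.grid_shape)  # one scalar per node
#   MazePlot(maze).add_node_values(values, color_map="viridis").plot()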
self.node_values = node_values # Set flag for choosing cmap while plotting maze self.custom_node_value_flag = True # Retrieve Max node value for plotting, +1e-10 to avoid division by zero self.node_color_map = color_map self.colormap_center = colormap_center self.colormap_max = colormap_max self.hide_colorbar = hide_colorbar if target_token_coord is not None: self.marked_coords.append((target_token_coord, self.marker_kwargs_next)) if preceeding_tokens_coords is not None: for coord in preceeding_tokens_coords: self.marked_coords.append((coord, self.marker_kwargs_current)) return self def plot( self, dpi: int = 100, title: str = "", fig_ax: tuple | None = None, plain: bool = False, ) -> MazePlot: """Plot the maze and paths.""" # set up figure if fig_ax is None: self.fig = plt.figure(dpi=dpi) self.ax = self.fig.add_subplot(1, 1, 1) else: self.fig, self.ax = fig_ax # plot maze self._plot_maze() # Plot labels if not plain: tick_arr = np.arange(self.maze.grid_shape[0]) self.ax.set_xticks(self.unit_length * (tick_arr + 0.5), tick_arr) self.ax.set_yticks(self.unit_length * (tick_arr + 0.5), tick_arr) self.ax.set_xlabel("col") self.ax.set_ylabel("row") self.ax.set_title(title) else: self.ax.set_xticks([]) self.ax.set_yticks([]) self.ax.set_xlabel("") self.ax.set_ylabel("") self.ax.axis("off") # plot paths if self.true_path is not None: self._plot_path(self.true_path) for path in self.predicted_paths: self._plot_path(path) # plot markers for coord, kwargs in self.marked_coords: self._place_marked_coords([coord], **kwargs) return self def _rowcol_to_coord(self, point: Coord) -> np.ndarray: """Transform Point from MazeTransformer (row, column) notation to matplotlib default (x, y) notation where x is the horizontal axis.""" point = np.array([point[1], point[0]]) return self.unit_length * (point + 0.5) def mark_coords(self, coords: CoordArray | list[Coord], **kwargs) -> MazePlot: """Mark coordinates on the maze with a marker. default marker is a blue "+": `dict(marker="+", color="blue")` """ kwargs = { **dict(marker="+", color="blue"), **kwargs, } for coord in coords: self.marked_coords.append((coord, kwargs)) return self def _place_marked_coords( self, coords: CoordArray | list[Coord], **kwargs, ) -> MazePlot: coords_tp = np.array([self._rowcol_to_coord(coord) for coord in coords]) self.ax.plot(coords_tp[:, 0], coords_tp[:, 1], **kwargs) return self def _plot_maze(self) -> None: # noqa: C901, PLR0912 """Define Colormap and plot maze. Colormap: x is -inf: black else: use colormap """ img = self._lattice_maze_to_img() # if no node_values have been passed (no colormap) if self.custom_node_value_flag is False: self.ax.imshow(img, cmap="gray", vmin=-1, vmax=1) else: assert self.node_values is not None, "Please pass node values." assert not np.isnan(self.node_values).any(), ( "Please pass node values, they cannot be nan." 
) vals_min: float = np.nanmin(self.node_values) vals_max: float = np.nanmax(self.node_values) # if both are negative or both are positive, set max/min to 0 if vals_max < 0.0: vals_max = 0.0 elif vals_min > 0.0: vals_min = 0.0 # adjust vals_max, in case you need consistent colorbar across multiple plots vals_max = self.colormap_max or vals_max # create colormap cmap = mpl.colormaps[self.node_color_map] # TODO: this is a hack, we make the walls black (while still allowing negative values) by setting the nan color to black cmap.set_bad(color="black") if self.colormap_center is not None: if not (vals_min < self.colormap_center < vals_max): if vals_min == self.colormap_center: vals_min -= 1e-10 elif vals_max == self.colormap_center: vals_max += 1e-10 else: err_msg: str = f"Please pass colormap_center value between {vals_min} and {vals_max}" raise ValueError( err_msg, ) norm = mpl.colors.TwoSlopeNorm( vmin=vals_min, vcenter=self.colormap_center, vmax=vals_max, ) _plotted = self.ax.imshow(img, cmap=cmap, norm=norm) else: _plotted = self.ax.imshow(img, cmap=cmap, vmin=vals_min, vmax=vals_max) # Add colorbar based on the condition of self.hide_colorbar if not self.hide_colorbar: ticks = np.linspace(vals_min, vals_max, 5) if (vals_min < 0.0 < vals_max) and (0.0 not in ticks): ticks = np.insert(ticks, np.searchsorted(ticks, 0.0), 0.0) if ( self.colormap_center is not None and self.colormap_center not in ticks and vals_min < self.colormap_center < vals_max ): ticks = np.insert( ticks, np.searchsorted(ticks, self.colormap_center), self.colormap_center, ) cbar = plt.colorbar( _plotted, ticks=ticks, ax=self.ax, cax=self.cbar_ax, ) self.cbar_ax = cbar.ax # make the boundaries of the image thicker (walls look weird without this) for axis in ["top", "bottom", "left", "right"]: self.ax.spines[axis].set_linewidth(2) def _lattice_maze_to_img( self, connection_val_scale: float = 0.93, ) -> Bool[np.ndarray, "row col"]: """Build an image to visualise the maze. Each "unit" consists of a node and the right and lower adjacent wall/connection. Its area is ul * ul. - Nodes have area: (ul-1) * (ul-1) and value 1 by default - take node_value if passed via .add_node_values() - Walls have area: 1 * (ul-1) and value -1 - Connections have area: 1 * (ul-1); color and value 0.93 by default - take node_value if passed via .add_node_values() Axes definition: (0,0) col ----|-----------> | row | | v Returns a matrix of side length (ul) * n + 1 where n is the number of nodes. """ # TODO: this is a hack, but if you add 1 always then non-node valued plots have their walls dissapear. 
if you dont add 1, you get ugly colors between nodes when they are colored node_bdry_hack: int connection_list_processed: Float[np.ndarray, "dim row col"] # Set node and connection values if self.node_values is None: scaled_node_values = np.ones(self.maze.grid_shape) connection_values = scaled_node_values * connection_val_scale node_bdry_hack = 0 # TODO: hack # invert connection list connection_list_processed = np.logical_not(self.maze.connection_list) else: # TODO: hack scaled_node_values = self.node_values # connection_values = scaled_node_values connection_values = np.full_like(scaled_node_values, np.nan) node_bdry_hack = 1 connection_list_processed = self.maze.connection_list # Create background image (all pixels set to -1, walls everywhere) img: Float[np.ndarray, "row col"] = -np.ones( ( self.maze.grid_shape[0] * self.unit_length + 1, self.maze.grid_shape[1] * self.unit_length + 1, ), dtype=float, ) # Draw nodes and connections by iterating through lattice for row in range(self.maze.grid_shape[0]): for col in range(self.maze.grid_shape[1]): # Draw node img[ row * self.unit_length + 1 : (row + 1) * self.unit_length + node_bdry_hack, col * self.unit_length + 1 : (col + 1) * self.unit_length + node_bdry_hack, ] = scaled_node_values[row, col] # Down connection if not connection_list_processed[0, row, col]: img[ (row + 1) * self.unit_length, col * self.unit_length + 1 : (col + 1) * self.unit_length, ] = connection_values[row, col] # Right connection if not connection_list_processed[1, row, col]: img[ row * self.unit_length + 1 : (row + 1) * self.unit_length, (col + 1) * self.unit_length, ] = connection_values[row, col] return img def _plot_path(self, path_format: PathFormat) -> None: if len(path_format.path) == 0: warnings.warn(f"Empty path, skipping plotting\n{path_format = }") return p_transformed = np.array( [self._rowcol_to_coord(coord) for coord in path_format.path], ) if path_format.quiver_kwargs is not None: try: x: np.ndarray = p_transformed[:, 0] y: np.ndarray = p_transformed[:, 1] except Exception as e: err_msg: str = f"Error in plotting quiver path:\n{path_format = }\n{p_transformed = }\n{e}" raise ValueError( err_msg, ) from e # Generate colors from the colormap if path_format.cmap is not None: n = len(x) - 1 # Number of arrows cmap = plt.get_cmap(path_format.cmap) colors = [cmap(i / n) for i in range(n)] else: colors = path_format.color self.ax.quiver( x[:-1], y[:-1], x[1:] - x[:-1], y[1:] - y[:-1], scale_units="xy", angles="xy", scale=1, color=colors, **path_format.quiver_kwargs, ) else: self.ax.plot( p_transformed[:, 0], p_transformed[:, 1], path_format.fmt, lw=path_format.line_width, color=path_format.color, label=path_format.label, ) # mark endpoints self.ax.plot( [p_transformed[0][0]], [p_transformed[0][1]], "o", color=path_format.color, ms=10, ) self.ax.plot( [p_transformed[-1][0]], [p_transformed[-1][1]], "x", color=path_format.color, ms=10, ) def to_ascii( self, show_endpoints: bool = True, show_solution: bool = True, ) -> str: "wrapper for `self.solved_maze.as_ascii()`, shows the path if we have `self.true_path`" if self.true_path: return self.solved_maze.as_ascii( show_endpoints=show_endpoints, show_solution=show_solution, ) else: return self.maze.as_ascii(show_endpoints=show_endpoints) ``````{ end_of_file="maze_dataset/plotting/plot_maze.py" } ``````{ path="maze_dataset/plotting/plot_svg_fancy.py" } """Plot a maze as SVG with rounded corners.""" from xml.dom.minidom import parseString from xml.etree.ElementTree import Element, SubElement, tostring import numpy as np 
# Known color map (excluding walls). COLOR_MAP: dict[tuple[int, int, int], str] = { (255, 255, 255): "#f0f0f0", (0, 255, 0): "#4caf50", (255, 0, 0): "#f44336", (0, 0, 255): "#2196f3", } WALL_COLOR_HEX: str = "#222" # (0,0,0) in hex WALL_RGB: tuple[int, int, int] = (0, 0, 0) # Offsets in the order [top, right, bottom, left] _NEIGHBORS: np.ndarray = np.array( [ [-1, 0], # top [0, +1], # right [+1, 0], # bottom [0, -1], # left ], dtype=int, ) def is_wall(y: int, x: int, grid: np.ndarray) -> bool: """True if (y, x) is out of bounds or has the wall color.""" h, w, _ = grid.shape if not (0 <= y < h and 0 <= x < w): return True return bool((grid[y, x] == WALL_RGB).all()) def create_tile_path( origin: tuple[float, float], tile_size: float, corner_radius: float, edges: tuple[bool, bool, bool, bool], ) -> str: """Generate an SVG path for a tile at `origin` with side length `tile_size`. `edges` is (top, right, bottom, left) booleans, where True means that edge borders a wall/outside. If both edges meeting at a corner are True and corner_radius>0, we draw a rounded corner; else it's a sharp corner. Corner order (clockwise): c0 = top-left c1 = top-right c2 = bottom-right c3 = bottom-left edges = (top, right, bottom, left). corner c0 is formed by edges top + left => edges[0] & edges[3] corner c1 => top + right => edges[0] & edges[1] corner c2 => right + bottom => edges[1] & edges[2] corner c3 => bottom + left => edges[2] & edges[3] """ x0, y0 = origin top, right, bottom, left = edges # A corner is "exposed" if both adjoining edges are True c0_exposed: bool = top and left # top-left c1_exposed: bool = top and right # top-right c2_exposed: bool = right and bottom # bottom-right c3_exposed: bool = bottom and left # bottom-left # If corner_radius=0, arcs become straight lines. r: float = corner_radius # We'll construct the path in a standard top-left -> top-right -> bottom-right -> bottom-left order. path_cmds = [] # Move to top-left corner, possibly offset if c0 is exposed # (meaning both top and left edges are external). start_x = x0 + (r if c0_exposed else 0) start_y = y0 path_cmds.append(f"M {start_x},{start_y}") # === TOP edge to top-right corner end_x = x0 + tile_size - (r if c1_exposed else 0) end_y = y0 path_cmds.append(f"L {end_x},{end_y}") # Arc if c1_exposed if c1_exposed and r > 0: path_cmds.append(f"A {r} {r} 0 0 1 {x0 + tile_size},{y0 + r}") # === RIGHT edge to bottom-right corner path_cmds.append(f"L {x0 + tile_size},{y0 + tile_size - (r if c2_exposed else 0)}") if c2_exposed and r > 0: path_cmds.append(f"A {r} {r} 0 0 1 {x0 + tile_size - r},{y0 + tile_size}") # === BOTTOM edge to bottom-left corner path_cmds.append(f"L {x0 + (r if c3_exposed else 0)},{y0 + tile_size}") if c3_exposed and r > 0: path_cmds.append(f"A {r} {r} 0 0 1 {x0},{y0 + tile_size - r}") # === LEFT edge back up to top-left corner path_cmds.append(f"L {x0},{y0 + (r if c0_exposed else 0)}") if c0_exposed and r > 0: path_cmds.append(f"A {r} {r} 0 0 1 {x0 + r},{y0}") path_cmds.append("Z") return " ".join(path_cmds) def plot_svg_fancy( pixel_grid: np.ndarray, size: int = 40, corner_radius: float = 8.0, bounding_corner_radius: float = 20.0, ) -> str: """plot the output of SolvedMaze(...).as_pixels() as a nice svg Create an SVG with: - A single rounded-square background (walls). - Each non-wall cell is drawn via create_tile_path, with corner_radius controlling whether corners are rounded. (Set corner_radius=0 for squares.) 
# Parameters: - `pixel_grid : np.ndarray` 3D array of shape (h, w, 3) with RGB values - `size : int` Size (in px) of each grid cell - `corner_radius : float` Radius for rounding corners of each tile (0 => squares) - `bounding_corner_radius : float` Radius for rounding the outer bounding rectangle # Returns: `str`: A pretty-printed SVG string """ h, w, _ = pixel_grid.shape # Create the root svg = Element( "svg", xmlns="http://www.w3.org/2000/svg", width=str(w * size), height=str(h * size), viewBox=f"0 0 {w * size} {h * size}", ) # Single rounded-square background for the walls SubElement( svg, "rect", { "x": "0", "y": "0", "width": str(w * size), "height": str(h * size), "fill": WALL_COLOR_HEX, "rx": str(bounding_corner_radius), "ry": str(bounding_corner_radius), }, ) for yy in range(h): for xx in range(w): rgb_tuple = tuple(pixel_grid[yy, xx]) if rgb_tuple == WALL_RGB: # It's a wall => skip (already covered by background) continue fill_color: str | None = COLOR_MAP.get(rgb_tuple, None) # noqa: SIM910 if fill_color is None: # Unknown color => skip or handle differently continue # Check which edges are "external" => next cell is wall # edges in the order (top, right, bottom, left) edges_bool = [ is_wall(yy + dy, xx + dx, pixel_grid) for (dy, dx) in _NEIGHBORS ] d_path = create_tile_path( origin=(xx * size, yy * size), tile_size=size, corner_radius=corner_radius, edges=tuple(edges_bool), # type: ignore[arg-type] ) SubElement( svg, "path", { "d": d_path, "fill": fill_color, "stroke": "none", }, ) raw_svg = tostring(svg, encoding="unicode") # we are in charge of the svg so it's safe to decode return parseString(raw_svg).toprettyxml(indent=" ") # noqa: S318 ``````{ end_of_file="maze_dataset/plotting/plot_svg_fancy.py" } ``````{ path="maze_dataset/plotting/plot_tokens.py" } "`plot_colored_text` function to plot tokens on a matplotlib axis with colored backgrounds" from typing import Any, Sequence import matplotlib.pyplot as plt import numpy as np def plot_colored_text( tokens: Sequence[str], weights: Sequence[float], # assume its a colormap if not a string cmap: str | Any, # noqa: ANN401 ax: plt.Axes | None = None, width_scale: float = 0.023, width_offset: float = 0.005, height_offset: float = 0.1, rect_height: float = 0.7, token_height: float = 0.7, label_height: float = 0.3, word_gap: float = 0.01, fontsize: int = 12, fig_height: float = 0.7, fig_width_scale: float = 0.25, char_min: int = 4, ) -> plt.Axes: "hacky function to plot tokens on a matplotlib axis with colored backgrounds" assert len(tokens) == len(weights), ( f"The number of tokens and weights must be the same: {len(tokens)} != {len(weights)}" ) total_len_estimate: float = sum([max(len(tok), char_min) for tok in tokens]) # set up figure if needed if ax is None: fig, ax = plt.subplots( figsize=(total_len_estimate * fig_width_scale, fig_height), ) ax.axis("off") # Normalize the weights to be between 0 and 1 norm_weights: Sequence[float] = (weights - np.min(weights)) / ( np.max(weights) - np.min(weights) ) # Create a colormap instance if isinstance(cmap, str): colormap = plt.get_cmap(cmap) else: colormap = cmap x_pos: float = 0.0 for i, (tok, weight, norm_wgt) in enumerate( # noqa: B007 zip(tokens, weights, norm_weights, strict=False), ): color = colormap(norm_wgt)[:3] # Plot the background color rect_width = width_scale * max(len(tok), char_min) ax.add_patch( plt.Rectangle( (x_pos, height_offset), rect_width, height_offset + rect_height, fc=color, ec="none", ), ) # Plot the token ax.text( x_pos + width_offset, token_height, tok, 
fontsize=fontsize, va="center", ha="left", ) # Plot the weight below the token ax.text( x_pos + width_offset, label_height, f"{weight:.2f}", fontsize=fontsize, va="center", ha="left", ) x_pos += rect_width + word_gap return ax ``````{ end_of_file="maze_dataset/plotting/plot_tokens.py" } ``````{ path="maze_dataset/plotting/print_tokens.py" } """Functions to print tokens with colors in different formats you can color the tokens by their: - type (i.e. adjacency list, origin, target, path) using `color_maze_tokens_AOTP` - custom weights (i.e. attention weights) using `color_tokens_cmap` - entirely custom colors using `color_tokens_rgb` and the output can be in different formats, specified by `FormatType` (html, latex, terminal) """ import html import textwrap from typing import Literal, Sequence import matplotlib # noqa: ICN001 import numpy as np from IPython.display import HTML, display from jaxtyping import Float, UInt8 from muutils.misc import flatten from maze_dataset.constants import SPECIAL_TOKENS from maze_dataset.token_utils import tokens_between RGBArray = UInt8[np.ndarray, "n 3"] "1D array of RGB values" FormatType = Literal["html", "latex", "terminal", None] "output format for the tokens" TEMPLATES: dict[FormatType, str] = { "html": ' {tok} ', "latex": "\\colorbox[RGB]{{ {clr} }}{{ \\texttt{{ {tok} }} }}", "terminal": "\033[30m\033[48;2;{clr}m{tok}\033[0m", } "templates of printing tokens in different formats" _COLOR_JOIN: dict[FormatType, str] = { "html": ",", "latex": ",", "terminal": ";", } "joiner for colors in different formats" def _escape_tok( tok: str, fmt: FormatType, ) -> str: "escape token based on format" if fmt == "html": return html.escape(tok) elif fmt == "latex": return tok.replace("_", "\\_").replace("#", "\\#") elif fmt == "terminal": return tok else: err_msg: str = f"Unexpected format: {fmt}" raise ValueError(err_msg) def color_tokens_rgb( tokens: list, colors: Sequence[Sequence[int]] | Float[np.ndarray, "n 3"], fmt: FormatType = "html", template: str | None = None, clr_join: str | None = None, max_length: int | None = None, ) -> str: """color tokens from a list with an RGB color array tokens will not be escaped if `fmt` is None # Parameters: - `max_length: int | None`: Max number of characters before triggering a line wrap, i.e., making a new colorbox. If `None`, no limit on max length. """ # process format if fmt is None: assert template is not None assert clr_join is not None else: assert template is None assert clr_join is None template = TEMPLATES[fmt] clr_join = _COLOR_JOIN[fmt] if max_length is not None: # TODO: why are we using a map here again? 
# TYPING: this is missing a lot of type hints wrapped: list = list( # noqa: C417 map( lambda x: textwrap.wrap( x, width=max_length, break_long_words=False, break_on_hyphens=False, ), tokens, ), ) colors = list( flatten( [[colors[i]] * len(wrapped[i]) for i in range(len(wrapped))], levels_to_flatten=1, ), ) wrapped = list(flatten(wrapped, levels_to_flatten=1)) tokens = wrapped # put everything together output = [ template.format( clr=clr_join.join(map(str, map(int, clr))), tok=_escape_tok(tok, fmt), ) for tok, clr in zip(tokens, colors, strict=False) ] return " ".join(output) # TYPING: would be nice to type hint as html, latex, or terminal string and overload depending on `FormatType` def color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = "Blues", fmt: FormatType = "html", template: str | None = None, labels: bool = False, ) -> str: "color tokens given a list of weights and a colormap" n_tok: int = len(tokens) assert n_tok == len(weights), f"'{len(tokens) = }' != '{len(weights) = }'" weights_np: Float[np.ndarray, " n_tok"] = np.array(weights) # normalize weights to [0, 1] weights_norm = matplotlib.colors.Normalize()(weights_np) if isinstance(cmap, str): cmap = matplotlib.colormaps.get_cmap(cmap) colors: RGBArray = cmap(weights_norm)[:, :3] * 255 output: str = color_tokens_rgb( tokens=tokens, colors=colors, fmt=fmt, template=template, ) if labels: if fmt != "terminal": raise NotImplementedError("labels only supported for terminal") # align labels with the tokens output += "\n" for tok, weight in zip(tokens, weights_np, strict=False): # 2 decimal points, left-aligned and trailing spaces to match token length weight_str: str = f"{weight:.1f}" # omit if longer than token if len(weight_str) > len(tok): weight_str = " " * len(tok) else: weight_str = weight_str.ljust(len(tok)) output += f"{weight_str} " return output # colors roughly made to be similar to visual representation _MAZE_TOKENS_DEFAULT_COLORS: dict[tuple[str, str], tuple[int, int, int]] = { (SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END): ( 176, 152, 232, ), # purple (SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END): (154, 239, 123), # green (SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END): (246, 136, 136), # red (SPECIAL_TOKENS.PATH_START, SPECIAL_TOKENS.PATH_END): (111, 187, 254), # blue } "default colors for maze tokens, roughly matches the format of `as_pixels`" def color_maze_tokens_AOTP( tokens: list[str], fmt: FormatType = "html", template: str | None = None, **kwargs, ) -> str: """color tokens assuming AOTP format i.e: adjaceny list, origin, target, path """ output: list[str] = [ " ".join( tokens_between( tokens, start_tok, end_tok, include_start=True, include_end=True, ), ) for start_tok, end_tok in _MAZE_TOKENS_DEFAULT_COLORS ] colors: RGBArray = np.array( list(_MAZE_TOKENS_DEFAULT_COLORS.values()), dtype=np.uint8, ) return color_tokens_rgb( tokens=output, colors=colors, fmt=fmt, template=template, **kwargs, ) def display_html(html: str) -> None: "display html string" display(HTML(html)) def display_color_tokens_rgb( tokens: list[str], colors: RGBArray, ) -> None: """display tokens (as html) with custom colors""" html: str = color_tokens_rgb(tokens, colors, fmt="html") display_html(html) def display_color_tokens_cmap( tokens: list[str], weights: Sequence[float], cmap: str | matplotlib.colors.Colormap = "Blues", ) -> None: """display tokens (as html) with color based on weights""" html: str = color_tokens_cmap(tokens, weights, cmap) 
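# weights are normalized and mapped through `cmap` inside `color_tokens_cmap`;
# the resulting HTML is rendered inline via IPython.display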
display_html(html) def display_color_maze_tokens_AOTP( tokens: list[str], ) -> None: """display maze tokens (as html) with AOTP coloring""" html: str = color_maze_tokens_AOTP(tokens) display_html(html) ``````{ end_of_file="maze_dataset/plotting/print_tokens.py" } ``````{ path="maze_dataset/tokenization/modular/__init__.py" } """implements `ModularMazeTokenizer` and related code the structure of a typical `MazeTokenizerModular` is something like this: ``` +----------------------------------------------------+ | MazeTokenizerModular | | +-----------------------------------------------+ | | | _PromptSequencer | | | | +-----------------------------+ | | | | | _CoordTokenizer | | | | | +-----------------------------+ | | | | +------------------------------------+ | | | | | _AdjListTokenizer | | | | | | +-----------+ +-------------+ | | | | | | |_EdgeSubset| |_EdgeGrouping| | | | | | | +-----------+ +-------------+ | | | | | | +-------------+ | | | | | | |_EdgePermuter| | | | | | | +-------------+ | | | | | +------------------------------------+ | | | | +-----------------------------+ | | | | | _TargetTokenizer | | | | | +-----------------------------+ | | | | +------------------------------------------+ | | | | | _PathTokenizer | | | | | | +---------------+ +----------------+ | | | | | | | _StepSize | | _StepTokenizer | | | | | | | +---------------+ +----------------+ | | | | | | | _StepTokenizer | | | | | | | +----------------+ | | | | | | : | | | | | +------------------------------------------+ | | | +-----------------------------------------------+ | +----------------------------------------------------+ ``` Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named `pre`, `intra`, and `post` in various `_TokenizerElement` classes. Each option controls a unique delimiter token. Here we describe each `_TokenizerElement` and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options. ### Coordinates {#coordtokenizer} The `_CoordTokenizer` object controls how coordinates in the lattice are represented in across all token regions. Options include: - **Unique tokens**: Each coordinate is represented as a single unique token `"(i,j)"` - **Coordinate tuple tokens**: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: `["i", ",", "j"]` ### Adjacency List {#adjlisttokenizer} The `_AdjListTokenizer` object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice. - `_EdgeSubset`: Specifies the subset of lattice edges to be tokenized - **All edges**: Every edge in the lattice - **Connections**: Only edges which contain a connection - **Walls**: Only edges which contain a wall - `_EdgePermuter`: Specifies how to sequence the two coordinates in each lattice edge - **Random** - **Sorted**: The smaller coordinate always comes first - **Both permutations**: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on if the coordinate sequence in the path matches the sequence in the adjacency list. 
- `shuffle_d0`: Whether to shuffle the edges randomly or sort them in the output by their first coordinate - `connection_token_ordinal`: Location in the sequence of the token representing whether the edge is a connection or a wall ### Path {#pathtokenizer} The `_PathTokenizer` object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position. - `_StepSize`: Specifies the size of each step - **Singles**: Every coordinate traversed between start and end is directly represented - **Forks**: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path. - `_StepTokenizer`: Specifies how an individual step is represented - **Coordinate**: The coordinates of each step are directly tokenized using a `_CoordTokenizer` - **Cardinal direction**: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., `NORTH`, `SOUTH`. If using a `_StepSize` other than **Singles**, this direction may not correspond to the final direction traveled to arrive at the end position of the step. - **Relative direction**: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., `RIGHT`, `LEFT`. - **Distance**: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a `_StepSize` of **Singles**, the **Distance** token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a `_StepSize` other than **Singles**. A `_PathTokenizer` contains a sequence of one or more unique `_StepTokenizer` objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once. """ __all__ = [ # modules "all_instances", "all_tokenizers", "element_base", "elements", "fst_load", "fst", "hashing", "maze_tokenizer_modular", "save_hashes", ] ``````{ end_of_file="maze_dataset/tokenization/modular/__init__.py" } ``````{ path="maze_dataset/tokenization/modular/all_instances.py" } "`all_instances`, `FiniteValued`, and related code for tokenizers" import enum import itertools import typing from dataclasses import Field # noqa: TC003 from functools import cache, wraps from types import UnionType from typing import ( Callable, Generator, Iterable, Literal, TypeVar, get_args, get_origin, ) try: import frozendict except ImportError as e: raise ImportError( "You need to install the `frozendict` package to use `all_instances` -- try installing `maze_dataset[tokenization]`" ) from e from muutils.misc import IsDataclass, flatten, is_abstract FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum) """ # `FiniteValued` The details of this type are not possible to fully define via the Python 3.10 typing library. This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space. `FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing. These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below). The leaves of the tree must always be Primitive Types. 
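For illustration only (these classes are hypothetical, not part of the library), a nested type whose leaves are all Primitive Types is `FiniteValued`, and its range is the cartesian product of its fields:

```python
from dataclasses import dataclass
from typing import Literal

@dataclass(frozen=True)
class _Flags:  # hypothetical Container Type: 2 * 3 = 6 possible instances
    verbose: bool                 # Primitive Type: 2 values
    mode: Literal["a", "b", "c"]  # Primitive Type: 3 values

@dataclass(frozen=True)
class _Config:  # hypothetical Container Type: 6 * 4 = 24 possible instances
    flags: _Flags                 # nested Container Type: 6 values
    pair: tuple[bool, bool]       # fixed-length generic tuple: 4 values
```

Something like `list(all_instances(_Config))` should then enumerate all 24 instances.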
# `FiniteValued` Subtypes *: Indicates that this subtype is not yet supported by `all_instances` ## Non-`FiniteValued` (Unbounded) Types These are NOT valid subtypes, and are listed for illustrative purposes only. This list is not comprehensive. While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite, they are considered unbounded types in this context. - No Container subtype may contain any of these unbounded subtypes. - `int` - `float` - `str` - `list` - `set`: Set types without a `FiniteValued` argument are unbounded - `tuple`: Tuple types without a fixed length are unbounded ## Primitive Types Primitive types are non-nested types which resolve directly to a concrete range of values - `bool`: has 2 possible values - *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members - `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition. This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy. ## Container Types Container types are types which contain zero or more fields of `FiniteValued` type. The range of a container type is the cartesian product of their field types, except for `set[FiniteValued]`. - `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`. - `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`. - *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed. - *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type. ## Superclass Types Superclass types don't directly contain data members like container types. Their range is the union of the ranges of their subtypes. - Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types - *`IsDataclass`: Concrete dataclasses which also have their own subclasses. - *Standard abstract classes: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types - `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3] """ def _apply_validation_func( type_: FiniteValued, vals: Generator[FiniteValued, None, None], validation_funcs: ( frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None ) = None, ) -> Generator[FiniteValued, None, None]: """Helper function for `all_instances`. Filters `vals` according to `validation_funcs`. If `type_` is a regular type, searches in MRO order in `validation_funcs` and applies the first match, if any. Handles generic types supported by `all_instances` with special `if` clauses. 
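A sketch of the MRO-based lookup (hypothetical classes; only the first matching superclass's function is applied):

```python
class _Base: ...
class _Child(_Base): ...

validation_funcs = {_Base: lambda x: True}
# for `type_ = _Child`, `_Child` itself is not a key, so the walk over
# `_Child.__mro__` finds `_Base` and uses its function to filter `vals`;
# later superclasses in the MRO are ignored.
```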
# Parameters - `type_: FiniteValued`: A type - `vals: Generator[FiniteValued, None, None]`: Instances of `type_` - `validation_funcs: dict`: Collection of types mapped to filtering validation functions """ if validation_funcs is None: return vals if type_ in validation_funcs: # Only possible catch of UnionTypes # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]") [return-value] return filter(validation_funcs[type_], vals) elif hasattr( type_, "__mro__", ): # Generic types like UnionType, Literal don't have `__mro__` for superclass in type_.__mro__: if superclass not in validation_funcs: continue # TYPING: error: Incompatible types in assignment (expression has type "filter[FiniteValued]", variable has type "Generator[FiniteValued, None, None]") [assignment] vals = filter(validation_funcs[superclass], vals) break # Only the first validation function hit in the mro is applied elif get_origin(type_) == Literal: return flatten( ( _apply_validation_func(type(v), [v], validation_funcs) for v in get_args(type_) ), levels_to_flatten=1, ) return vals # TYPING: some better type hints would be nice here def _all_instances_wrapper(f: Callable) -> Callable: """Converts dicts to frozendicts to allow caching and applies `_apply_validation_func`.""" @wraps(f) def wrapper(*args, **kwargs): # noqa: ANN202 @cache def cached_wrapper( # noqa: ANN202 type_: type, all_instances_func: Callable, validation_funcs: ( frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None ), ): return _apply_validation_func( type_, all_instances_func(type_, validation_funcs), validation_funcs, ) validation_funcs: frozendict.frozendict # TODO: what is this magic value here exactly? if len(args) >= 2 and args[1] is not None: # noqa: PLR2004 validation_funcs = frozendict.frozendict(args[1]) elif "validation_funcs" in kwargs and kwargs["validation_funcs"] is not None: validation_funcs = frozendict.frozendict(kwargs["validation_funcs"]) else: validation_funcs = None return cached_wrapper(args[0], f, validation_funcs) return wrapper class UnsupportedAllInstancesError(TypeError): """Raised when `all_instances` is called on an unsupported type either has unbounded possible values or is not supported (Enum is not supported) """ def __init__(self, type_: type) -> None: "constructs an error message with the type and mro of the type" msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. {type_.__mro__ = }" super().__init__(msg) @_all_instances_wrapper def all_instances( type_: FiniteValued, validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None, ) -> Generator[FiniteValued, None, None]: """Returns all possible values of an instance of `type_` if finite instances exist. Uses type hinting to construct the possible values. All nested elements of `type_` must themselves be typed. Do not use with types whose members contain circular references. Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`. # Parameters - `type_: FiniteValued` A finite-valued type. See docstring on `FiniteValued` for full details. - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None` A mapping of types to auxiliary functions to validate instances of that type. This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide. 
See `validation_funcs` Details section below. (default: `None`) ## Supported `type_` Values See docstring on `FiniteValued` for full details. `type_` may be: - `FiniteValued` - A finite-valued, fixed-length Generic tuple type. E.g., `tuple[bool]`, `tuple[bool, MyEnum]` are OK. `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed. - Nested versions of any of the types in this list - A `UnionType` of any of the types in this list ## `validation_funcs` Details - `validation_funcs` is applied after all instances have been generated according to type hints. - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`. - `validation_funcs` is passed down for all recursive calls of `all_instances`. - This allows for improved performance through maximal pruning of the exponential tree. - `validation_funcs` supports subclass checking. - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order. - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned. - If no superclass of `type_` is found, then no filter is applied. # Raises: - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`. """ if type_ == bool: # noqa: E721 yield from [True, False] elif hasattr(type_, "__dataclass_fields__"): if is_abstract(type_): # Abstract dataclass: call `all_instances` on each subclass yield from flatten( ( all_instances(sub, validation_funcs) for sub in type_.__subclasses__() ), levels_to_flatten=1, ) else: # Concrete dataclass: construct dataclass instances with all possible combinations of fields fields: list[Field] = type_.__dataclass_fields__ fields_to_types: dict[str, type] = {f: fields[f].type for f in fields} all_arg_sequences: Iterable = itertools.product( *[ all_instances(arg_type, validation_funcs) for arg_type in fields_to_types.values() ], ) yield from ( type_( **dict(zip(fields_to_types.keys(), args, strict=False)), ) for args in all_arg_sequences ) else: type_origin = get_origin(type_) if type_origin == tuple: # noqa: E721 # Only matches Generic type tuple since regular tuple is not finite-valued # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields. yield from ( tuple(combo) for combo in itertools.product( *( all_instances(tup_item, validation_funcs) for tup_item in get_args(type_) ), ) ) elif type_origin in (UnionType, typing.Union): # Union: call `all_instances` for each type in the Union yield from flatten( [all_instances(sub, validation_funcs) for sub in get_args(type_)], levels_to_flatten=1, ) elif type_origin is Literal: # Literal: return all Literal arguments yield from get_args(type_) else: raise UnsupportedAllInstancesError(type_) ``````{ end_of_file="maze_dataset/tokenization/modular/all_instances.py" } ``````{ path="maze_dataset/tokenization/modular/all_tokenizers.py" } """Contains `get_all_tokenizers()` and supporting limited-use functions. # `get_all_tokenizers()` returns a comprehensive collection of all valid `MazeTokenizerModular` objects. This is an overwhelming majority subset of the set of all possible `MazeTokenizerModular` objects. Other tokenizers not contained in `get_all_tokenizers()` may be possible to construct, but they are untested and not guaranteed to work. This collection is in a separate module since it is expensive to compute and will grow more expensive as features are added to `MazeTokenizerModular`. 
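A minimal usage sketch (the names below are from this module; the exact tokenizer count depends on the library version):

```python
from maze_dataset.tokenization.modular.all_tokenizers import (
    get_all_tokenizers,
    sample_tokenizers_for_test,
)

all_toks = get_all_tokenizers()              # expensive: enumerates every valid tokenizer
print(len(all_toks))                         # large, version-dependent count
test_toks = sample_tokenizers_for_test(100)  # random sample, always includes EVERY_TEST_TOKENIZERS
```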
## Use Cases In general, uses for this module are limited to development of the library and specific research studying many tokenization behaviors. - Unit testing: - Tokenizers to use in unit tests are sampled from `get_all_tokenizers()` - Large-scale tokenizer research: - Specific research training models on many tokenization behaviors can use `get_all_tokenizers()` as the maximally inclusive collection - `get_all_tokenizers()` may be subsequently filtered using `MazeTokenizerModular.has_element` For other uses, it's likely that the computational expense can be avoided by using - `maze_tokenizer.get_all_tokenizer_hashes()` for membership checks - `utils.all_instances` for generating smaller subsets of `MazeTokenizerModular` or `_TokenizerElement` objects # `EVERY_TEST_TOKENIZERS` A collection of the tokenizers which should always be included in unit tests when test fuzzing is used. This collection should be expanded as specific tokenizers become canonical or popular. """ import functools import multiprocessing import random from functools import cache from pathlib import Path from typing import Callable import frozendict import numpy as np from muutils.spinner import NoOpContextManager, SpinnerContext from tqdm import tqdm from maze_dataset.tokenization import ( CoordTokenizers, MazeTokenizerModular, PromptSequencers, StepTokenizers, _TokenizerElement, ) from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances from maze_dataset.tokenization.modular.hashing import ( AllTokenizersHashBitLength, AllTokenizersHashDtype, AllTokenizersHashesArray, ) # Always include this as the first item in the dict `validation_funcs` whenever using `all_instances` with `MazeTokenizerModular` # TYPING: error: Type variable "maze_dataset.utils.FiniteValued" is unbound [valid-type] # note: (Hint: Use "Generic[FiniteValued]" or "Protocol[FiniteValued]" base class to bind "FiniteValued" inside a class) # note: (Hint: Use "FiniteValued" in function signature to bind "FiniteValued" inside a function) MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS: frozendict.frozendict[ type[FiniteValued], Callable[[FiniteValued], bool], ] = frozendict.frozendict( { # TYPING: Item "bool" of the upper bound "bool | IsDataclass | Enum" of type variable "FiniteValued" has no attribute "is_valid" [union-attr] _TokenizerElement: lambda x: x.is_valid(), # Currently no need for `MazeTokenizerModular.is_valid` since that method contains no special cases not already covered by `_TokenizerElement.is_valid` # MazeTokenizerModular: lambda x: x.is_valid(), # TYPING: error: No overload variant of "set" matches argument type "FiniteValued" [call-overload] # note: Possible overload variants: # note: def [_T] set(self) -> set[_T] # note: def [_T] set(self, Iterable[_T], /) -> set[_T] # TYPING: error: Argument 1 to "len" has incompatible type "FiniteValued"; expected "Sized" [arg-type] StepTokenizers.StepTokenizerPermutation: lambda x: len(set(x)) == len(x) and x != (StepTokenizers.Distance(),), }, ) DOWNLOAD_URL: str = "https://raw.githubusercontent.com/understanding-search/maze-dataset/main/maze_dataset/tokenization/MazeTokenizerModular_hashes.npz" @cache def get_all_tokenizers() -> list[MazeTokenizerModular]: """Computes a complete list of all valid tokenizers. Warning: This is an expensive function. 
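If only a smaller family of elements is needed, a much cheaper pattern (a sketch using names already imported in this module) is to enumerate a single `_TokenizerElement` hierarchy instead:

```python
coord_tokenizers = list(
    all_instances(
        CoordTokenizers._CoordTokenizer,
        validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS,
    )
)
```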
""" return list( all_instances( MazeTokenizerModular, validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, ), ) @cache def get_all_tokenizers_names() -> list[str]: """computes the sorted list of names of all tokenizers""" return sorted([tokenizer.name for tokenizer in get_all_tokenizers()]) EVERY_TEST_TOKENIZERS: list[MazeTokenizerModular] = [ MazeTokenizerModular(), MazeTokenizerModular( prompt_sequencer=PromptSequencers.AOTP(coord_tokenizer=CoordTokenizers.CTT()), ), # TODO: add more here as specific tokenizers become canonical and frequently used ] @cache def all_tokenizers_set() -> set[MazeTokenizerModular]: """Casts `get_all_tokenizers()` to a set.""" return set(get_all_tokenizers()) @cache def _all_tokenizers_except_every_test_tokenizers() -> list[MazeTokenizerModular]: """Returns""" return list(all_tokenizers_set().difference(EVERY_TEST_TOKENIZERS)) def sample_all_tokenizers(n: int) -> list[MazeTokenizerModular]: """Samples `n` tokenizers from `get_all_tokenizers()`.""" return random.sample(get_all_tokenizers(), n) def sample_tokenizers_for_test(n: int | None) -> list[MazeTokenizerModular]: """Returns a sample of size `n` of unique elements from `get_all_tokenizers()`, always including every element in `EVERY_TEST_TOKENIZERS`. """ if n is None: return get_all_tokenizers() if n < len(EVERY_TEST_TOKENIZERS): err_msg: str = f"`n` must be at least {len(EVERY_TEST_TOKENIZERS) = } such that the sample can contain `EVERY_TEST_TOKENIZERS`." raise ValueError( err_msg, ) sample: list[MazeTokenizerModular] = random.sample( _all_tokenizers_except_every_test_tokenizers(), n - len(EVERY_TEST_TOKENIZERS), ) sample.extend(EVERY_TEST_TOKENIZERS) return sample def save_hashes( path: Path | None = None, verbose: bool = False, parallelize: bool | int = False, ) -> AllTokenizersHashesArray: """Computes, sorts, and saves the hashes of every member of `get_all_tokenizers()`.""" spinner = ( functools.partial(SpinnerContext, spinner_chars="square_dot") if verbose else NoOpContextManager ) # get all tokenizers with spinner(initial_value="getting all tokenizers...", update_interval=2.0): all_tokenizers = get_all_tokenizers() # compute hashes hashes_array_np64: AllTokenizersHashesArray if parallelize: n_cpus: int = ( parallelize if int(parallelize) > 1 else multiprocessing.cpu_count() ) with spinner( # noqa: SIM117 initial_value=f"using {n_cpus} processes to compute {len(all_tokenizers)} tokenizer hashes...", update_interval=2.0, ): with multiprocessing.Pool(processes=n_cpus) as pool: hashes_list: list[int] = list(pool.map(hash, all_tokenizers)) with spinner(initial_value="converting hashes to numpy array..."): hashes_array_np64 = np.array(hashes_list, dtype=np.int64) else: with spinner( initial_value=f"computing {len(all_tokenizers)} tokenizer hashes...", ): hashes_array_np64 = np.array( [ hash(obj) # uses stable hash for obj in tqdm(all_tokenizers, disable=not verbose) ], dtype=np.int64, ) # convert to correct dtype hashes_array: AllTokenizersHashesArray = ( hashes_array_np64 % (1 << AllTokenizersHashBitLength) if AllTokenizersHashBitLength < 64 # noqa: PLR2004 else hashes_array_np64 ).astype(AllTokenizersHashDtype) # make sure there are no dupes with spinner(initial_value="sorting and checking for hash collisions..."): sorted_hashes, counts = np.unique(hashes_array, return_counts=True) if sorted_hashes.shape[0] != hashes_array.shape[0]: collisions: np.array = sorted_hashes[counts > 1] n_collisions: int = hashes_array.shape[0] - sorted_hashes.shape[0] err_msg: str = ( f"{n_collisions} tokenizer 
hash collisions: {collisions}\n" "Report error to the developer to increase the hash size or otherwise update the tokenizer hashing size:\n" f"https://github.com/understanding-search/maze-dataset/issues/new?labels=bug,tokenization&title=Tokenizer+hash+collision+error&body={n_collisions}+collisions+out+of+{hashes_array.shape[0]}+total+hashes", ) raise ValueError( err_msg, ) # save and return with spinner(initial_value="saving hashes...", update_interval=0.5): if path is None: path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz" np.savez_compressed( path, hashes=sorted_hashes, ) return sorted_hashes ``````{ end_of_file="maze_dataset/tokenization/modular/all_tokenizers.py" } ``````{ path="maze_dataset/tokenization/modular/element_base.py" } """provides the base `_TokenizerElement` class and related functionality for modular maze tokenization see the code in `maze_dataset.tokenization.modular.elements` for examples of subclasses of `_TokenizerElement` """ import abc from typing import ( Any, Callable, Literal, TypeVar, ) from muutils.json_serialize import ( SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.json_serialize.util import _FORMAT_KEY from muutils.misc import flatten from zanj.loading import load_item_recursive from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name # from maze_dataset import SolvedMaze @serializable_dataclass(frozen=True, kw_only=True) class _TokenizerElement(SerializableDataclass, abc.ABC): """Superclass for tokenizer elements. Subclasses contain modular functionality for maze tokenization. # Development > [!TIP] > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses > may only contain fields of type `utils.FiniteValued`. > Implementing a subclass with an `int` or `float`-typed field, for example, is not supported. > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated. """ # TYPING: type hint `v` more specifically @staticmethod def _stringify(k: str, v: Any) -> str: # noqa: ANN401 if isinstance(v, bool): return f"{k}={str(v)[0]}" if isinstance(v, _TokenizerElement): return v.name if isinstance(v, tuple): return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}" else: return f"{k}={v}" @property def name(self) -> str: members_str: str = ", ".join( [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"], ) output: str = f"{type(self).__name__}({members_str})" if "." in output and output.index("(") > output.index("."): return "".join(output.split(".")[1:]) else: return output def __str__(self) -> str: return self.name # TYPING: type hints for `__init_subclass__`? def __init_subclass__(cls, **kwargs): # noqa: ANN204 """Hack: dataclass hashes don't include the class itself in the hash function inputs. This causes dataclasses with identical fields but different types to hash identically. This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`. To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value. So we type it as a singleton `Literal` type. muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`. Ignore Pylance complaining about the arg to `Literal` being an expression. 
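The underlying issue, in isolation (hypothetical classes, not part of the library):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class _A:
    x: bool = True

@dataclass(frozen=True)
class _B:
    x: bool = True

assert hash(_A()) == hash(_B())  # identical hashes despite being different types
```

The injected `_type_` field gives each subclass a distinct default value (`repr(cls)`), which removes this collision.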
""" super().__init_subclass__(**kwargs) # we are adding a new attr here intentionally cls._type_ = serializable_field( # type: ignore[attr-defined] init=True, repr=False, default=repr(cls), assert_type=False, ) cls.__annotations__["_type_"] = Literal[repr(cls)] def __hash__(self) -> int: "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" return _hash_tokenizer_name(self.name) @classmethod def _level_one_subclass(cls) -> type["_TokenizerElement"]: """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance.""" return ( set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop() ) def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]: """Returns a list of all `_TokenizerElement` instances contained in the subtree. Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or which sit inside a `tuple` without further nesting. # Parameters - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer. """ if not any(type(el) == tuple for el in self.__dict__.values()): # noqa: E721 return list( flatten( [ [el, *el.tokenizer_elements()] for el in self.__dict__.values() if isinstance(el, _TokenizerElement) ], ) if deep else filter( lambda x: isinstance(x, _TokenizerElement), self.__dict__.values(), ), ) else: non_tuple_elems: list[_TokenizerElement] = list( flatten( [ [el, *el.tokenizer_elements()] for el in self.__dict__.values() if isinstance(el, _TokenizerElement) ] if deep else filter( lambda x: isinstance(x, _TokenizerElement), self.__dict__.values(), ), ), ) tuple_elems: list[_TokenizerElement] = list( flatten( [ ( [ [tup_el, *tup_el.tokenizer_elements()] for tup_el in el if isinstance(tup_el, _TokenizerElement) ] if deep else filter(lambda x: isinstance(x, _TokenizerElement), el) ) for el in self.__dict__.values() if isinstance(el, tuple) ], ), ) non_tuple_elems.extend(tuple_elems) return non_tuple_elems def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str: """Returns a string representation of the tree of tokenizer elements contained in `self`. # Parameters - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify. - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. """ name: str = "\t" * depth + ( type(self).__name__ if not abstract else type(self)._level_one_subclass().__name__ ) return ( name + "\n" + "".join( el.tokenizer_element_tree(depth + 1, abstract) for el in self.tokenizer_elements(deep=False) ) ) def tokenizer_element_dict(self) -> dict: """Returns a dictionary representation of the tree of tokenizer elements contained in `self`.""" return { type(self).__name__: { key: ( val.tokenizer_element_dict() if isinstance(val, _TokenizerElement) else ( val if not isinstance(val, tuple) else [ ( el.tokenizer_element_dict() if isinstance(el, _TokenizerElement) else el ) for el in val ] ) ) for key, val in self.__dict__.items() if key != "_type_" }, } @classmethod @abc.abstractmethod def attribute_key(cls) -> str: """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.""" raise NotImplementedError def to_tokens(self, *args, **kwargs) -> list[str]: """Converts a maze element into a list of tokens. Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method. 
Those subclasses which do produce tokens should override this method. """ raise NotImplementedError @abc.abstractmethod def is_valid(self, do_except: bool = False) -> bool: """Returns if `self` contains data members capable of producing an overall valid `MazeTokenizerModular`. Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints. `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone. If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass. # Types of Invalidity In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below. Invalidity types, in ascending order of invalidity: - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study. E.g., `_TokenizerElement`s which are strictly worse than some alternative. - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers. - Untrainable: Training functional models using these tokenizers would be (nearly) impossible. - Erroneous: These tokenizers might raise exceptions during use. # Development `is_invalid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid. When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as neccesary. ## Nesting In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class. In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree. `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected. ## Types of Invalidity If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments. This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances. """ raise NotImplementedError T = TypeVar("T", bound=_TokenizerElement) def _unsupported_is_invalid(self, do_except: bool = False) -> bool: # noqa: ANN001 """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes""" if do_except: err_msg: str = ( f"Class `{type(self).__name__ = }, marked as unsupported, is not valid." f"{type(self) = }, {self = }" ) raise ValueError(err_msg) return False # TYPING: better type hints for this function def mark_as_unsupported(is_valid: Callable[[T, bool], bool]) -> Callable[[T], T]: """mark a _TokenizerElement as unsupported. Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus wont be tested. The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method. The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage. Previously, the space was too large, resulting in impractical runtimes. These decorators could be removed in future releases to expand the space of possible tokenizers. 
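Usage sketch (the class name here is hypothetical; see `elements.py` for real uses such as `StepSizes.Straightaways`):

```python
@serializable_dataclass(frozen=True, kw_only=True)
@mark_as_unsupported(_unsupported_is_invalid)
class _SomeExperimentalElement(_TokenizerElement):
    ...
# instances of the decorated class report `is_valid() == False`
# (or raise when `do_except=True`), so they are pruned from
# `get_all_tokenizers()` when the default validation functions are used.
```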
""" def wrapper(cls: T) -> T: # intentionally modifying method here # idk why it things `T`/`self` should not be an argument cls.is_valid = is_valid # type: ignore[assignment, method-assign] return cls return wrapper # TODO: why noqa here? `B024 `__TokenizerElementNamespace` is an abstract base class, but it has no abstract methods or properties` class __TokenizerElementNamespace(abc.ABC): # noqa: B024 """ABC for namespaces # Properties - key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`. """ # HACK: this is not the right way of doing this lol key: str = NotImplementedError # type: ignore[assignment] def _load_tokenizer_element( data: dict[str, Any], namespace: type[__TokenizerElementNamespace], ) -> _TokenizerElement: """Loads a `TokenizerElement` stored via zanj.""" key: str = namespace.key format_: str = data[key][_FORMAT_KEY] cls_name: str = format_.split("(")[0] cls: type[_TokenizerElement] = getattr(namespace, cls_name) kwargs: dict[str, Any] = { k: load_item_recursive(data[key][k], tuple()) for k, v in data[key].items() } if _FORMAT_KEY in kwargs: kwargs.pop(_FORMAT_KEY) return cls(**kwargs) ``````{ end_of_file="maze_dataset/tokenization/modular/element_base.py" } ``````{ path="maze_dataset/tokenization/modular/elements.py" } """implements subclasses of `_TokenizerElement` to be used in `MazeTokenizerModular`""" import abc import random from typing import ( Callable, Literal, Sequence, TypedDict, ) import numpy as np from jaxtyping import Bool, Int from muutils.json_serialize import ( serializable_dataclass, serializable_field, ) from muutils.misc import empty_sequence_if_attr_false, flatten # from maze_dataset import SolvedMaze from maze_dataset.constants import ( VOCAB, ConnectionArray, ConnectionList, Coord, CoordTup, ) from maze_dataset.generation import numpy_rng from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze from maze_dataset.token_utils import ( connection_list_to_adj_list, get_cardinal_direction, get_relative_direction, is_connection, tokens_between, ) from maze_dataset.tokenization.modular.element_base import ( __TokenizerElementNamespace, _load_tokenizer_element, _TokenizerElement, _unsupported_is_invalid, mark_as_unsupported, ) from maze_dataset.utils import lattice_connection_array class CoordTokenizers(__TokenizerElementNamespace): """Namespace for `_CoordTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "coord_tokenizer" @serializable_dataclass(frozen=True, kw_only=True) class _CoordTokenizer(_TokenizerElement, abc.ABC): """Superclass for classes which tokenize singular coords in a maze.""" @abc.abstractmethod def to_tokens(self, coord: Coord | CoordTup) -> list[str]: pass @classmethod def attribute_key(cls) -> str: return CoordTokenizers.key def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @serializable_dataclass(frozen=True, kw_only=True) class UT(_CoordTokenizer): """Unique token coordinate tokenizer.""" # inherit docstring def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 return ["".join(["(", str(coord[0]), ",", str(coord[1]), ")"])] @serializable_dataclass(frozen=True, kw_only=True) class CTT(_CoordTokenizer): """Coordinate tuple tokenizer # Parameters - `pre`: Whether all coords include an integral preceding delimiter token - `intra`: Whether all coords include a delimiter token between coordinates - `post`: Whether all coords include an 
integral following delimiter token """ pre: bool = serializable_field(default=True) intra: bool = serializable_field(default=True) post: bool = serializable_field(default=True) # Implement methods # inherit docstring def to_tokens(self, coord: Coord | CoordTup) -> list[str]: # noqa: D102 return [ *empty_sequence_if_attr_false([VOCAB.COORD_PRE], self, "pre"), str(coord[0]), *empty_sequence_if_attr_false([VOCAB.COORD_INTRA], self, "intra"), str(coord[1]), *empty_sequence_if_attr_false([VOCAB.COORD_POST], self, "post"), ] class EdgeGroupings(__TokenizerElementNamespace): """Namespace for `_EdgeGrouping` subclass hierarchy used by `_AdjListTokenizer`.""" key = "edge_grouping" class _GroupingTokenParams(TypedDict): """A uniform private hyperparameter interface used by `AdjListTokenizer`.""" connection_token_ordinal: Literal[0, 1, 2] intra: bool grouped: bool @serializable_dataclass(frozen=True, kw_only=True) class _EdgeGrouping(_TokenizerElement, abc.ABC): """Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called a edge grouping.""" @classmethod def attribute_key(cls) -> str: return EdgeGroupings.key def is_valid(self, do_except: bool = False) -> bool: return True @abc.abstractmethod def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: """Divides a ConnectionArray into groups of edges. Shuffles/sequences within each group if applicable. """ pass @abc.abstractmethod def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": """Returns the tok.nization hyperparameters necessary for an `AdjListTokenizer` to tokenize. These hyperparameters are not used by `_EdgeGrouping` internally. They are located in `_EdgeGrouping` rather than in `AdjListTokenizer` since the hyperparameter space is a function of the `_EdgeGrouping` subclass. This function resolves the `_EdgeGrouping` hyperparameter space which is non-uniform across subclasses into a uniform private interface used by `AdjListTokenizer`. """ pass @serializable_dataclass(frozen=True, kw_only=True) class Ungrouped(_EdgeGrouping): """No grouping occurs, each edge is tokenized individually. # Parameters - `connection_token_ordinal`: At which index in the edge tokenization the connector (or wall) token appears. Edge tokenizations contain 3 parts: a leading coord, a connector (or wall) token, and either a second coord or cardinal direction tokenization. """ connection_token_ordinal: Literal[0, 1, 2] = serializable_field( default=1, assert_type=False, ) def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": return EdgeGroupings._GroupingTokenParams( connection_token_ordinal=self.connection_token_ordinal, intra=False, grouped=False, ) def _group_edges(self, edges: ConnectionList) -> Sequence[ConnectionList]: return np.expand_dims(edges, 1) @serializable_dataclass(frozen=True, kw_only=True) @mark_as_unsupported(_unsupported_is_invalid) class ByLeadingCoord(_EdgeGrouping): """All edges with the same leading coord are grouped together. # Parameters - `intra`: Whether all edge groupings include a delimiter token between individual edge representations. Note that each edge representation will already always include a connector token (`VOCAB.CONNECTOR`, or possibly `) - `shuffle_group`: Whether the sequence of edges within the group should be shuffled or appear in a fixed order. If false, the fixed order is lexicographical by (row, col). 
In effect, lexicographical sorting sorts edges by their cardinal direction in the sequence NORTH, WEST, EAST, SOUTH, where the directions indicate the position of the trailing coord relative to the leading coord. - `connection_token_ordinal`: At which index in token sequence representing a single edge the connector (or wall) token appears. Edge tokenizations contain 2 parts: a connector (or wall) token and a coord or cardinal tokenization. """ intra: bool = serializable_field(default=True) shuffle_group: bool = serializable_field(default=True) connection_token_ordinal: Literal[0, 1] = serializable_field( default=0, assert_type=False, ) def _token_params(self) -> "EdgeGroupings._GroupingTokenParams": return EdgeGroupings._GroupingTokenParams( connection_token_ordinal=self.connection_token_ordinal, intra=self.intra, grouped=True, ) def _group_edges(self, edges: ConnectionArray) -> Sequence[ConnectionArray]: # Adapted from: https://stackoverflow.com/questions/38013778/is-there-any-numpy-group-by-function index_array: Int[np.ndarray, "sort_indices=edges"] = np.lexsort( (edges[:, 1, 1], edges[:, 1, 0], edges[:, 0, 1], edges[:, 0, 0]), ) sorted_edges: ConnectionArray = edges[index_array, ...] groups: list[ConnectionArray] = np.split( sorted_edges, np.unique(sorted_edges[:, 0, :], return_index=True, axis=0)[1][1:], ) if self.shuffle_group: [numpy_rng.shuffle(g, axis=0) for g in groups] return groups class EdgePermuters(__TokenizerElementNamespace): """Namespace for `_EdgePermuter` subclass hierarchy used by `_AdjListTokenizer`.""" key = "edge_permuter" @serializable_dataclass(frozen=True, kw_only=True) class _EdgePermuter(_TokenizerElement, abc.ABC): """Specifies how to sequence the two coords that encode a lattice edge.""" @classmethod def attribute_key(cls) -> str: return EdgePermuters.key def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @staticmethod @abc.abstractmethod def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: """Executes a permutation. Warning: Caller should be aware that `lattice_edges` may be modified in-place depending on the subclass's implementation. # Parameters - `lattice_edges`: Array of lattice edges. The two coords in shape[1] must be adjacent in the lattice. # Returns - Array of lattice edges with entries along shape[1] systematically permuted. - shape[0] of the returned array is NOT guaranteed to match `lattice_edges.shape[1]`. """ pass @serializable_dataclass(frozen=True, kw_only=True) class SortedCoords(_EdgePermuter): """returns a sorted representation. useful for checking consistency""" @staticmethod def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: return lattice_edges[ np.lexsort( ( lattice_edges[:, 1, 1], lattice_edges[:, 1, 0], lattice_edges[:, 0, 1], lattice_edges[:, 0, 0], ), ), ..., ] @serializable_dataclass(frozen=True, kw_only=True) class RandomCoords(_EdgePermuter): """Permutes each edge randomly.""" @staticmethod def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: numpy_rng.permuted(lattice_edges, axis=1, out=lattice_edges) return lattice_edges @serializable_dataclass(frozen=True, kw_only=True) class BothCoords(_EdgePermuter): """Includes both possible permutations of every edge in the output. Since input ConnectionList has only 1 instance of each edge, a call to `BothCoords._permute` will modify `lattice_edges` in-place, doubling `shape[0]`. 
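Shape illustration (a sketch with a single hypothetical edge):

```python
import numpy as np

edges = np.array([[[0, 0], [0, 1]]])                     # shape (1, 2, 2)
both = np.append(edges, np.flip(edges, axis=1), axis=0)  # shape (2, 2, 2)
# the edge now appears once per coord ordering:
# [[0, 0], [0, 1]] and [[0, 1], [0, 0]]
```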
""" @staticmethod def _permute(lattice_edges: ConnectionArray) -> ConnectionArray: return np.append(lattice_edges, np.flip(lattice_edges, axis=1), axis=0) class EdgeSubsets(__TokenizerElementNamespace): """Namespace for `_EdgeSubset` subclass hierarchy used by `_AdjListTokenizer`.""" key = "edge_subset" @serializable_dataclass(frozen=True, kw_only=True) class _EdgeSubset(_TokenizerElement, abc.ABC): """Component of an `AdjListTokenizers._AdjListTokenizer` which specifies the subset of lattice edges to be tokenized.""" @classmethod def attribute_key(cls) -> str: return EdgeSubsets.key def is_valid(self, do_except: bool = False) -> bool: return True @abc.abstractmethod def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: """Returns the set of lattice edges to be tokenized.""" pass @serializable_dataclass(frozen=True, kw_only=True) class AllLatticeEdges(_EdgeSubset): """All 2n**2-2n edges of the lattice are tokenized. If a wall exists on that edge, the edge is tokenized in the same manner, using `VOCAB.ADJLIST_WALL` in place of `VOCAB.CONNECTOR`. """ def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: return lattice_connection_array(maze.grid_n) @serializable_dataclass(frozen=True, kw_only=True) class ConnectionEdges(_EdgeSubset): """Only edges which contain a connection are tokenized. Alternatively, only edges which contain a wall are tokenized. # Parameters - `walls`: Whether wall edges or connection edges are tokenized. If true, `VOCAB.ADJLIST_WALL` is used in place of `VOCAB.CONNECTOR`. """ walls: bool = serializable_field(default=False) def _get_edges(self, maze: LatticeMaze) -> ConnectionArray: conn_list: ConnectionList = maze.connection_list if self.walls: conn_list = np.logical_not(conn_list) conn_list[0, -1, :] = False conn_list[1, :, -1] = False return connection_list_to_adj_list( conn_list, shuffle_d0=False, shuffle_d1=False, ) def _adjlist_no_pre_unsupported(self_, do_except: bool = False) -> bool: # noqa: ANN001 """Returns False if `pre` is True, True otherwise.""" output: bool = self_.pre is False if do_except and not output: raise ValueError( "AdjListCoord does not support `pre == False`.", ) return output class AdjListTokenizers(__TokenizerElementNamespace): """Namespace for `_AdjListTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "adj_list_tokenizer" @serializable_dataclass(frozen=True, kw_only=True) @mark_as_unsupported(_adjlist_no_pre_unsupported) class _AdjListTokenizer(_TokenizerElement, abc.ABC): """Specifies how the adjacency list is tokenized. Tokenization behavior is decomposed into specification of edge subsets, groupings, and permutations. See documentation of `EdgeSubset` and `EdgeGrouping` classes for more details. # Parameters - `pre`: Whether all edge groupings include a preceding delimiter token - `post`: Whether all edge groupings include a following delimiter token - `shuffle_d0`: Specifies how to sequence the edge groupings. If true, groupings are shuffled randomly. If false, groupings are sorted by the leading coord of each group. - `edge_grouping`: Specifies if/how multiple coord-coord connections are grouped together in a token subsequence called an edge grouping. - `edge_subset`: Specifies the subset of lattice edges to be tokenized. - `edge_permuter`: Specifies, in each edge tokenization, which coord either: 1. Appears first in the tokenization, for `AdjListCoord`. 2. Is tokenized directly as a coord, for `AdjListCardinal`. - `shuffle`: For each edge, the leading coord is selected randomly. 
- `all`: Each edge appears twice in the tokenization, appearing with both leading coords. - `evens`, `odds`: The leading coord is the one belonging to that coord subset. See `EdgeSubsets.ChessboardSublattice` for details. """ pre: bool = serializable_field(default=False, assert_type=False) post: bool = serializable_field(default=True) shuffle_d0: bool = serializable_field(default=True) edge_grouping: EdgeGroupings._EdgeGrouping = serializable_field( default=EdgeGroupings.Ungrouped(), loading_fn=lambda x: _load_tokenizer_element(x, EdgeGroupings), ) edge_subset: EdgeSubsets._EdgeSubset = serializable_field( default=EdgeSubsets.ConnectionEdges(), loading_fn=lambda x: _load_tokenizer_element(x, EdgeSubsets), ) edge_permuter: EdgePermuters._EdgePermuter = serializable_field( default=EdgePermuters.RandomCoords(), loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), ) @classmethod def attribute_key(cls) -> str: return AdjListTokenizers.key def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @abc.abstractmethod def _tokenization_callables( self, edges: ConnectionArray, is_conn: Bool[np.ndarray, " edges"], coord_tokenizer: CoordTokenizers._CoordTokenizer, *args, **kwargs, ) -> list[Callable]: """Returns a sequence of callables which take an index in `edges` and return parts of that edge tokenization. # Returns - `[0]`: leading coord tokens - `[1]`: connector tokens - `[2]`: trailing coord tokens """ pass def _tokenize_edge_grouping( self, edges: ConnectionArray, maze: LatticeMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer, group_params: EdgeGroupings._GroupingTokenParams, ) -> Sequence[str]: """Tokenizes a single edge grouping.""" cxn_ord: int = group_params["connection_token_ordinal"] is_conn: Bool[np.ndarray, edges] = is_connection( edges, maze.connection_list, ) tokenize_callables = self._tokenization_callables( edges, is_conn, coord_tokenizer, ) if group_params["grouped"]: # If grouped callable_permutation: list[int] = [1, 2] if cxn_ord == 0 else [2, 1] repeated_callables = [ tokenize_callables[i] for i in callable_permutation ] return flatten( [ tokenize_callables[0](0), [ [ *[ tok_callable(i) for tok_callable in repeated_callables ], *( (VOCAB.ADJLIST_INTRA,) if group_params["intra"] else () ), ] for i in range(edges.shape[0]) ], ], ) else: # If ungrouped callable_permutation = [0, 2] callable_permutation.insert(cxn_ord, 1) tokenize_callables = [ tokenize_callables[i] for i in callable_permutation ] return flatten( [ [ [ *[ tok_callable(i) for tok_callable in tokenize_callables ], *empty_sequence_if_attr_false( (VOCAB.ADJLIST_INTRA,), group_params, "intra", ), ] for i in range(edges.shape[0]) ], ], ) def to_tokens( self, maze: LatticeMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: # Get the set of edges to be tokenized edges: ConnectionArray = self.edge_subset._get_edges(maze) # Systematically permute the leading coord of each edge edges: ConnectionArray = self.edge_permuter._permute(edges) group_params: EdgeGroupings._GroupingTokenParams = ( self.edge_grouping._token_params() ) # then, we need to group the edges groups: Sequence[ConnectionArray] = self.edge_grouping._group_edges(edges) # shuffle the groups if specified if self.shuffle_d0: if isinstance(groups, np.ndarray): numpy_rng.shuffle(groups, axis=0) elif isinstance(groups, list): random.shuffle(groups) else: err_msg: str = f"`groups` is an unexpected type {type(groups)}. 
Only types `list` and `np.ndarray` are currently supported." raise TypeError(err_msg) # Tokenize each group with optional delimiters tokens: list[str] = list( flatten( [ [ *empty_sequence_if_attr_false( (VOCAB.ADJLIST_PRE,), self, "pre", ), *self._tokenize_edge_grouping( group, maze, coord_tokenizer, group_params, ), *empty_sequence_if_attr_false( (VOCAB.ADJACENCY_ENDLINE,), self, "post", ), ] for group in groups ], ), ) return tokens @serializable_dataclass(frozen=True, kw_only=True) class AdjListCoord(_AdjListTokenizer): """Represents an edge group as tokens for the leading coord followed by coord tokens for the other group members.""" edge_permuter: EdgePermuters._EdgePermuter = serializable_field( default=EdgePermuters.RandomCoords(), loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), ) def _tokenization_callables( self, edges: ConnectionArray, is_conn: Bool[np.ndarray, " edges"], coord_tokenizer: CoordTokenizers._CoordTokenizer, *args, **kwargs, ) -> list[Callable]: # Map from `is_conn` to the tokens which represent connections and walls conn_token_map: dict[bool, str] = { True: VOCAB.CONNECTOR, False: VOCAB.ADJLIST_WALL, } return [ lambda i: coord_tokenizer.to_tokens(edges[i, 0]), lambda i: conn_token_map[is_conn[i]], lambda i: coord_tokenizer.to_tokens(edges[i, 1]), ] @serializable_dataclass(frozen=True, kw_only=True) class AdjListCardinal(_AdjListTokenizer): """Represents an edge group as coord tokens for the leading coord and cardinal tokens relative to the leading coord for the other group members. # Parameters - `coord_first`: Whether the leading coord token(s) should come before or after the sequence of cardinal tokens. """ edge_permuter: EdgePermuters._EdgePermuter = serializable_field( default=EdgePermuters.BothCoords(), loading_fn=lambda x: _load_tokenizer_element(x, EdgePermuters), ) def _tokenization_callables( self, edges: ConnectionArray, is_conn: Bool[np.ndarray, " edges"], coord_tokenizer: CoordTokenizers._CoordTokenizer, *args, **kwargs, ) -> list[Callable]: # Map from `is_conn` to the tokens which represent connections and walls conn_token_map: dict[bool, str] = { True: VOCAB.CONNECTOR, False: VOCAB.ADJLIST_WALL, } return [ lambda i: coord_tokenizer.to_tokens(edges[i, 0]), lambda i: conn_token_map[is_conn[i]], lambda i: get_cardinal_direction(edges[i]), ] class TargetTokenizers(__TokenizerElementNamespace): """Namespace for `_TargetTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "target_tokenizer" @serializable_dataclass(frozen=True, kw_only=True) class _TargetTokenizer(_TokenizerElement, abc.ABC): """Superclass of tokenizers for maze targets.""" @abc.abstractmethod def to_tokens( self, targets: Sequence[Coord], coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: """Returns tokens representing the target.""" pass @classmethod def attribute_key(cls) -> str: return TargetTokenizers.key @serializable_dataclass(frozen=True, kw_only=True) class Unlabeled(_TargetTokenizer): """Targets are simply listed as coord tokens. 
- `post`: Whether all coords include an integral following delimiter token """ post: bool = serializable_field(default=False) # inherit docstring def to_tokens( # noqa: D102 self, targets: Sequence[Coord], coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: return list( flatten( [ [ *coord_tokenizer.to_tokens(target), *empty_sequence_if_attr_false( [VOCAB.TARGET_POST], self, "post", ), ] for target in targets ], ), ) # inherit docstring def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 # No invalid instances possible within data member type hint bounds return True class StepSizes(__TokenizerElementNamespace): """Namespace for `_StepSize` subclass hierarchy used by `MazeTokenizerModular`.""" key = "step_size" @serializable_dataclass(frozen=True, kw_only=True) class _StepSize(_TokenizerElement, abc.ABC): """Specifies which coords in `maze.solution` are used to represent the path.""" @classmethod def attribute_key(cls) -> str: return StepSizes.key @abc.abstractmethod # TODO: make this a static/class method, allowing ForksAndStraightaways to skip object construction at every call def _step_single_indices(self, maze: SolvedMaze) -> list[int]: """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" raise NotImplementedError( "Subclasses must implement `StepSize.step_indices.", ) def step_start_end_indices(self, maze: SolvedMaze) -> list[tuple[int, int]]: """Returns steps as tuples of starting and ending positions for each step.""" indices: list[int] = self._step_single_indices(maze) # TODO: RUF007 Prefer `itertools.pairwise()` over `zip()` when iterating over successive pairs return [ (start, end) for start, end in zip(indices[:-1], indices[1:], strict=False) # noqa: RUF007 ] def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @serializable_dataclass(frozen=True, kw_only=True) class Singles(_StepSize): """Every coord in `maze.solution` is represented. Legacy tokenizers all use this behavior. """ def _step_single_indices(self, maze: SolvedMaze) -> list[int]: """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" return list(range(maze.solution.shape[0])) @serializable_dataclass(frozen=True, kw_only=True) @mark_as_unsupported(_unsupported_is_invalid) class Straightaways(_StepSize): """Only coords where the path turns are represented in the path. I.e., the path is represented as a sequence of straightaways, specified by the coords at the turns. """ def _step_single_indices(self, maze: SolvedMaze) -> list[int]: """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" last_turn_coord: Coord = maze.solution[0, ...] indices: list[int] = [0] for i, coord in enumerate(maze.solution): if coord[0] != last_turn_coord[0] and coord[1] != last_turn_coord[1]: indices.append(i - 1) last_turn_coord = maze.solution[i - 1, ...] indices.append(i) return indices @serializable_dataclass(frozen=True, kw_only=True) class Forks(_StepSize): """Only coords at forks, where the path has >=2 options for the next step are included. Excludes the option of backtracking. The starting and ending coords are always included. 
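For example (illustrative): a solution that runs along a single corridor with no junctions contains no forks, so under this step size it reduces to just its starting and ending coords.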
""" def _step_single_indices(self, maze: SolvedMaze) -> list[int]: """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" return maze.get_solution_forking_points(always_include_endpoints=True)[0] @serializable_dataclass(frozen=True, kw_only=True) @mark_as_unsupported(_unsupported_is_invalid) class ForksAndStraightaways(_StepSize): """Includes the union of the coords included by `Forks` and `Straightaways`. See documentation for those classes for details. """ def _step_single_indices(self, maze: SolvedMaze) -> list[int]: """Returns the indices of `maze.solution` corresponding to the steps to be tokenized.""" return list( np.unique( np.concatenate( ( StepSizes.Straightaways()._step_single_indices(maze), StepSizes.Forks()._step_single_indices(maze), ), ), ), ) class StepTokenizers(__TokenizerElementNamespace): """Namespace for `_StepTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "step_tokenizers" @serializable_dataclass(frozen=True, kw_only=True) class _StepTokenizer(_TokenizerElement, abc.ABC): """Specifies how a single step (as specified by an instance of `_StepSize`) is tokenized.""" @classmethod def attribute_key(cls) -> str: return StepTokenizers.key @abc.abstractmethod def to_tokens( self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs, ) -> list[str]: """Tokenizes a single step in the solution. # Parameters - `maze`: Maze to be tokenized - `start_index`: The index of the Coord in `maze.solution` at which the current step starts - `end_index`: The index of the Coord in `maze.solution` at which the current step ends """ raise NotImplementedError( "Subclasses must implement `StepTokenizer.to_tokens.", ) def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @serializable_dataclass(frozen=True, kw_only=True) class Coord(_StepTokenizer): """A direct tokenization of the end position coord represents the step.""" # inherit docstring def to_tokens( # noqa: D102 self, maze: SolvedMaze, start_index: int, end_index: int, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: return coord_tokenizer.to_tokens(maze.solution[end_index, ...]) @serializable_dataclass(frozen=True, kw_only=True) class Cardinal(_StepTokenizer): """A step is tokenized with a cardinal direction token. It is the direction of the step from the starting position along the solution. """ # inherit docstring def to_tokens( # noqa: D102 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs, ) -> list[str]: return [ get_cardinal_direction(maze.solution[start_index : start_index + 2]), ] @serializable_dataclass(frozen=True, kw_only=True) class Relative(_StepTokenizer): """Tokenizes a solution step using relative first-person directions (right, left, forward, etc.). To simplify the indeterminacy, at the start of a solution the "agent" solving the maze is assumed to be facing NORTH. Similarly to `Cardinal`, the direction is that of the step from the starting position. 
""" # inherit docstring def to_tokens( # noqa: D102 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs, ) -> list[str]: if start_index == 0: start = maze.solution[0] previous = start + np.array([1, 0]) return [ get_relative_direction( np.concatenate( ( np.expand_dims(previous, 0), maze.solution[start_index : start_index + 2], ), axis=0, ), ), ] return [ get_relative_direction( maze.solution[start_index - 1 : start_index + 2], ), ] @serializable_dataclass(frozen=True, kw_only=True) class Distance(_StepTokenizer): """A count of the number of individual steps from the starting point to the end point. Contains no information about directionality, only the distance traveled in the step. `Distance` must be combined with at least one other `_StepTokenizer` in a `StepTokenizerPermutation`. This constraint is enforced in `_PathTokenizer.is_valid`. """ # inherit docstring def to_tokens( # noqa: D102 self, maze: SolvedMaze, start_index: int, end_index: int, **kwargs, ) -> list[str]: d: int = end_index - start_index return [getattr(VOCAB, f"I_{d:03}")] """ `StepTokenizerPermutation` A sequence of unique `_StepTokenizer`s. This type exists mostly just for the clarity and convenience of `_PathTokenizer` code. """ StepTokenizerPermutation: type = ( tuple[_StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer] | tuple[_StepTokenizer, _StepTokenizer, _StepTokenizer, _StepTokenizer] ) class PathTokenizers(__TokenizerElementNamespace): """Namespace for `_PathTokenizer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "path_tokenizer" @serializable_dataclass(frozen=True, kw_only=True) class _PathTokenizer(_TokenizerElement, abc.ABC): """Superclass of tokenizers for maze solution paths.""" @abc.abstractmethod def to_tokens( self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: """Returns tokens representing the solution path.""" pass @classmethod def attribute_key(cls) -> str: return PathTokenizers.key @serializable_dataclass(frozen=True, kw_only=True) class StepSequence(_PathTokenizer, abc.ABC): """Any `PathTokenizer` where the tokenization may be assembled from token subsequences, each of which represents a step along the path. Allows for a sequence of leading and trailing tokens which don't fit the step pattern. # Parameters - `step_size`: Selects the size of a single step in the sequence - `step_tokenizers`: Selects the combination and permutation of tokens - `pre`: Whether all steps include an integral preceding delimiter token - `intra`: Whether all steps include a delimiter token after each individual `_StepTokenizer` tokenization. 
- `post`: Whether all steps include an integral following delimiter token """ step_size: StepSizes._StepSize = serializable_field( default=StepSizes.Singles(), loading_fn=lambda x: _load_tokenizer_element(x, StepSizes), ) step_tokenizers: StepTokenizers.StepTokenizerPermutation = serializable_field( default=(StepTokenizers.Coord(),), serialization_fn=lambda x: [y.serialize() for y in x], loading_fn=lambda x: tuple(x[StepTokenizers.key]), ) pre: bool = serializable_field(default=False) intra: bool = serializable_field(default=False) post: bool = serializable_field(default=False) # inherit docstring def to_tokens( # noqa: D102 self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: return [ *self._leading_tokens(maze, coord_tokenizer), *flatten( [ self._single_step_tokens(maze, start, end, coord_tokenizer) for start, end in self.step_size.step_start_end_indices(maze) ], ), *self._trailing_tokens(maze, coord_tokenizer), ] def _single_step_tokens( self, maze: SolvedMaze, i: int, j: int, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: """Returns the token sequence representing a single step along the path.""" step_rep_tokens: list[list[str]] = [ step_tokenizer.to_tokens(maze, i, j, coord_tokenizer=coord_tokenizer) for step_tokenizer in self.step_tokenizers ] if self.intra: step_rep_tokens_and_intra: list[str] = [None] * ( len(step_rep_tokens) * 2 ) step_rep_tokens_and_intra[::2] = step_rep_tokens step_rep_tokens_and_intra[1::2] = [VOCAB.PATH_INTRA] * len( step_rep_tokens, ) step_rep_tokens = list(flatten(step_rep_tokens_and_intra)) all_tokens: list[str] = [ *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"), *flatten(step_rep_tokens), *empty_sequence_if_attr_false((VOCAB.PATH_POST,), self, "post"), ] return all_tokens def _leading_tokens( self, maze: SolvedMaze, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: """Returns tokens preceding those from the sequence from `_single_step_tokens`. Since the for loop in `to_tokens` iterates `len(path)-1` times, a fencepost problem exists with `StepTokenizers.Coord`. should NOT be included. """ if StepTokenizers.Coord() in self.step_tokenizers: return [ *empty_sequence_if_attr_false((VOCAB.PATH_PRE,), self, "pre"), *coord_tokenizer.to_tokens(maze.solution[0, ...]), *empty_sequence_if_attr_false((VOCAB.PATH_INTRA,), self, "intra"), ] return [] def _trailing_tokens( self, c: Coord, coord_tokenizer: CoordTokenizers._CoordTokenizer, ) -> list[str]: """Returns tokens following those from the sequence from `_single_step_tokens`. should NOT be included. """ return [] # inherits docstring def is_valid(self, do_except: bool = False) -> bool: # noqa: D102 output: bool if len(set(self.step_tokenizers)) != len(self.step_tokenizers): # Uninteresting: repeated elements are not useful output = False else: # we do noqa for the comment if false if len(self.step_tokenizers) == 1 and isinstance( self.step_tokenizers[0], StepTokenizers.Distance, ): # Untrainable: `Distance` alone cannot encode a path. >=1 `StepTokenizer` which indicates direction/location is required. 
output = False else: output = True if not output and do_except: raise ValueError( "PathTokenizer must contain at least one `StepTokenizer` which indicates direction/location, or it will be untrainable.", ) return output class PromptSequencers(__TokenizerElementNamespace): """Namespace for `_PromptSequencer` subclass hierarchy used by `MazeTokenizerModular`.""" key = "prompt_sequencer" @serializable_dataclass(frozen=True, kw_only=True) class _PromptSequencer(_TokenizerElement, abc.ABC): """Sequences token regions into a complete maze tokenization. # Parameters - `coord_tokenizer`: Tokenizer element which tokenizes a single `Coord` aka maze position. - `adj_list_tokenizer`: Tokenizer element which tokenizes the adjacency list of a `LatticeMaze`. Uses `coord_tokenizer` to tokenize coords if needed in other `TokenizerElement`s. """ coord_tokenizer: CoordTokenizers._CoordTokenizer = serializable_field( default=CoordTokenizers.UT(), loading_fn=lambda x: _load_tokenizer_element(x, CoordTokenizers), ) adj_list_tokenizer: AdjListTokenizers._AdjListTokenizer = serializable_field( default=AdjListTokenizers.AdjListCoord(), loading_fn=lambda x: _load_tokenizer_element(x, AdjListTokenizers), ) @classmethod def attribute_key(cls) -> str: return PromptSequencers.key @staticmethod def _trim_if_unsolved_maze( untrimmed: list[str], is_untargeted: bool = False, is_unsolved: bool = False, ) -> list[str]: """Trims a full `SolvedMaze` prompt if the maze data reflects an unsolved or untargeted maze. # Development This implementation should function for `AOTP`, `AOP`, and other concrete classes using any subsequence of AOTP. It is not located in `token_utils.py` because it may need to be overridden in more exotic `PromptSequencer` subclasses. """ if is_untargeted: return tokens_between( untrimmed, VOCAB.ADJLIST_START, VOCAB.ADJLIST_END, include_start=True, include_end=True, ) if is_unsolved: if VOCAB.TARGET_END in untrimmed: return tokens_between( untrimmed, VOCAB.ADJLIST_START, VOCAB.TARGET_END, include_start=True, include_end=True, ) else: return tokens_between( untrimmed, VOCAB.ADJLIST_START, VOCAB.ORIGIN_END, include_start=True, include_end=True, ) return untrimmed def to_tokens( self, maze: LatticeMaze, *args, **kwargs, ) -> list[str]: """Returns a complete list of tokens for a given set of maze elements.""" untrimmed: list[str] = self._sequence_tokens( *self._get_prompt_regions(maze), ) return self._trim_if_unsolved_maze( untrimmed, not hasattr(maze, "start_pos"), not hasattr(maze, "solution"), ) def _get_prompt_regions( self, maze: LatticeMaze, *args, **kwargs, ) -> list[list[str]]: """Gets the prompt regions of a maze in a fixed sequence. This method is NOT responsible for including/excluding any prompt regions. Always return according to the API described under Returns. This implementation is expected to be suitable for most `PromptSequencer` subclasses. Subclasses may override this method if needed for special behavior. # Returns - [0]: list[str] Adjacency list tokens - [1]: list[str] Origin tokens - [2]: list[str] Target tokens - [3]: list[str] Path tokens # `None`-valued Args If one or more of `origin`, `target`, or `path` are `None`, that indicates that an unsolved or untargeted maze is being tokenized. To ensure unpackability in `_sequence_tokens`, these `None` values are substituted for empty iterables. 
""" origin: Coord | None = getattr(maze, "start_pos", None) target: list[Coord] | None = [ getattr(maze, "end_pos", None), ] # TargetTokenizer requires target: Sequence[Coord] return [ ( self.adj_list_tokenizer.to_tokens( maze, coord_tokenizer=self.coord_tokenizer, ) if hasattr(self, "adj_list_tokenizer") else [] ), self.coord_tokenizer.to_tokens(origin) if origin is not None else [], ( self.target_tokenizer.to_tokens( target, coord_tokenizer=self.coord_tokenizer, ) if target[0] is not None and hasattr(self, "target_tokenizer") else [] ), ( self.path_tokenizer.to_tokens( maze, coord_tokenizer=self.coord_tokenizer, ) if hasattr(maze, "solution") and hasattr(self, "path_tokenizer") else [] ), ] @abc.abstractmethod def _sequence_tokens( self, adj_list: list[str], origin: list[str] | None, target: list[str] | None, path: list[str] | None, ) -> list[str]: """Sequences token regions into a complete prompt. Includes any boundary tokens in `constants.SPECIAL_TOKENS` such as , , etc. # Parameters - `adj_list`: Tokens representing the adjacency list - `origin`: Tokens representing the origin - `target`: Tokens representing the target - `path`: Tokens representing the path """ pass def is_valid(self, do_except: bool = False) -> bool: # No invalid instances possible within data member type hint bounds return True @serializable_dataclass(frozen=True, kw_only=True) class AOTP(_PromptSequencer): """Sequences a prompt as [adjacency list, origin, target, path]. # Parameters - `target_tokenizer`: Tokenizer element which tokenizes the target(s) of a `TargetedLatticeMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `TargetTokenizer`. - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. """ target_tokenizer: TargetTokenizers._TargetTokenizer = serializable_field( default=TargetTokenizers.Unlabeled(), loading_fn=lambda x: _load_tokenizer_element(x, TargetTokenizers), ) path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( default=PathTokenizers.StepSequence(), loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), ) def _sequence_tokens( self, adj_list: list[str], origin: list[str], target: list[str], path: list[str], ) -> list[str]: return [ VOCAB.ADJLIST_START, *adj_list, VOCAB.ADJLIST_END, VOCAB.ORIGIN_START, *origin, VOCAB.ORIGIN_END, VOCAB.TARGET_START, *target, VOCAB.TARGET_END, VOCAB.PATH_START, *path, VOCAB.PATH_END, ] @serializable_dataclass(frozen=True, kw_only=True) class AOP(_PromptSequencer): """Sequences a prompt as [adjacency list, origin, path]. Still includes "" and "" tokens, but no representation of the target itself. # Parameters - `path_tokenizer`: Tokenizer element which tokenizes the solution path of a `SolvedMaze`. Uses `coord_tokenizer` to tokenize coords if that is part of the design of that `PathTokenizer`. 
""" path_tokenizer: PathTokenizers._PathTokenizer = serializable_field( default=PathTokenizers.StepSequence(), loading_fn=lambda x: _load_tokenizer_element(x, PathTokenizers), ) def _sequence_tokens( self, adj_list: list[str], origin: list[str], # explicitly no target in this tokenizer target: list[str], path: list[str], ) -> list[str]: return [ VOCAB.ADJLIST_START, *adj_list, VOCAB.ADJLIST_END, VOCAB.ORIGIN_START, *origin, VOCAB.ORIGIN_END, VOCAB.TARGET_START, VOCAB.TARGET_END, VOCAB.PATH_START, *path, VOCAB.PATH_END, ] ``````{ end_of_file="maze_dataset/tokenization/modular/elements.py" } ``````{ path="maze_dataset/tokenization/modular/fst.py" } """to check if a tokenizer is one of our "approved" ones, we store this in a fst set using `rust_fst` this file handles the creation of this fst file, which we ship to the user this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`. as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load` module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and we cannot circularly import """ import functools import random import tqdm from muutils.misc.numerical import shorten_numerical_to_str from muutils.parallel import run_maybe_parallel from muutils.spinner import NoOpContextManager, SpinnerContext from rust_fst import Set as FstSet # type: ignore[import-untyped] from maze_dataset.tokenization.modular.all_tokenizers import get_all_tokenizers from maze_dataset.tokenization.modular.fst_load import ( MMT_FST_PATH, check_tokenizer_in_fst, get_tokenizers_fst, ) def _get_tokenizer_name(tokenizer) -> str: # noqa: ANN001 return tokenizer.name def save_all_tokenizers_fst( verbose: bool = True, parallel: bool | int = False ) -> FstSet: """get all the tokenizers, save an fst file at `MMT_FST_PATH` and return the set""" # TYPING: add a protocol or abc for both of these which is a context manager that takes the args we care about # probably do this in muutils sp: type[SpinnerContext | NoOpContextManager] = ( SpinnerContext if verbose else NoOpContextManager ) with sp(message="getting all tokenizers"): all_tokenizers: list = get_all_tokenizers() n_tokenizers: int = len(all_tokenizers) all_tokenizers_names: list[str] = run_maybe_parallel( func=_get_tokenizer_name, iterable=all_tokenizers, parallel=parallel, pbar=tqdm.tqdm, pbar_kwargs=dict( total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose ), ) assert n_tokenizers == len(all_tokenizers_names) print( f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names" ) with sp(message="sorting tokenizer names"): all_tokenizers_names_sorted: list[str] = sorted(all_tokenizers_names) # construct an fst set and save it # we expect it to be 1.6kb or so with sp(message="constructing and saving tokenizers fst set"): tok_set: FstSet = FstSet.from_iter( all_tokenizers_names_sorted, path=MMT_FST_PATH.as_posix(), ) print( f"# tokenizers fst set saved to {MMT_FST_PATH}, size: {MMT_FST_PATH.stat().st_size} bytes" ) return tok_set def check_tokenizers_fst( verbose: bool = True, parallel: bool | int = False, n_check: int | None = None, ) -> FstSet: "regen all tokenizers, check they are in the pre-existing fst set" sp: type[SpinnerContext | NoOpContextManager] = ( SpinnerContext if verbose else NoOpContextManager ) with sp(message="getting all tokenizers from scratch"): all_tokenizers: list = get_all_tokenizers() with sp(message="load the pre-existing 
tokenizers fst set"): get_tokenizers_fst() n_tokenizers: int = len(all_tokenizers) selected_tokenizers: list if n_check is not None: selected_tokenizers = random.sample(all_tokenizers, n_check) else: selected_tokenizers = all_tokenizers tokenizers_names: list[str] = run_maybe_parallel( func=_get_tokenizer_name, iterable=selected_tokenizers, parallel=parallel, pbar=tqdm.tqdm, pbar_kwargs=dict( total=n_tokenizers, desc="get name of each tokenizer", disable=not verbose ), ) if n_check is None: assert n_tokenizers == len(tokenizers_names) print( f"# got {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) tokenizers names" ) else: assert n_check == len(tokenizers_names) print( f"# selected {n_check} tokenizers to check out of {shorten_numerical_to_str(n_tokenizers)} ({n_tokenizers}) total" ) check_tokenizer_in_fst__do_except = functools.partial( check_tokenizer_in_fst, do_except=True ) run_maybe_parallel( func=check_tokenizer_in_fst__do_except, iterable=tokenizers_names, parallel=parallel, pbar=tqdm.tqdm, pbar_kwargs=dict( total=len(selected_tokenizers), desc="checking tokenizers in fst", disable=not verbose, ), ) if n_check is None: print("# all tokenizers are in the pre-existing fst set!") else: print(f"# all {n_check} selected tokenizers are in the pre-existing fst set!") if __name__ == "__main__": import argparse arg_parser: argparse.ArgumentParser = argparse.ArgumentParser( description="save the tokenizers fst set" ) arg_parser.add_argument( "-c", "--check", action="store_true", help="check that all tokenizers are in the pre-existing fst set", ) arg_parser.add_argument( "-q", "--quiet", action="store_true", help="don't show spinners and progress bars", ) arg_parser.add_argument( "-p", "--parallel", action="store", nargs="?", type=int, const=True, default=False, help="Control parallelization. will run in serial if nothing specified, use all cpus if flag passed without args, or number of cpus if int passed.", ) arg_parser.add_argument( "-n", "--n-check", action="store", default=None, help="if passed, check n random tokenizers. pass an int to check that many. pass 'none' or a -1 to check all", ) args: argparse.Namespace = arg_parser.parse_args() n_check: int | None = ( int(args.n_check) if (args.n_check is not None and args.n_check.lower() != "none") else None ) if n_check is not None and n_check < 0: n_check = None if args.check: check_tokenizers_fst( verbose=not args.quiet, parallel=args.parallel, n_check=n_check, ) else: save_all_tokenizers_fst(verbose=not args.quiet, parallel=args.parallel) ``````{ end_of_file="maze_dataset/tokenization/modular/fst.py" } ``````{ path="maze_dataset/tokenization/modular/fst_load.py" } """to check if a tokenizer is one of our "approved" ones, look in an fst set we made with `rust_fst` this file handles the creation of this fst file, which we ship to the user this file relies on importing `get_all_tokenizers` and thus `MazeTokenizerModular`. 
as such, loading this file for validating a tokenizer is the separate `maze_dataset.tokenization.modular.fst_load` module, since we need to be able to import that from `maze_dataset.tokenization.modular.maze_tokenizer_modular` and we cannot circularly import thanks to https://github.com/rozbb for suggesting doing this instead of storing a whole bunch of hashes like we were doing before """ import warnings from functools import cache from pathlib import Path _RUST_FST_LOADED: bool = False """if the rust_fst module was loaded successfully""" _RUST_FST_ERR_MSG: str = ( "you need the `rust_fst` package to use `maze_dataset.tokenization.modular` properly. installing `maze-dataset[tokenization]` will install it\n" "Note that rust-fst doesn't work on mac, see https://github.com/understanding-search/maze-dataset/issues/57\n" "and this makes modular tokenizers not checkable on mac. Things should still work, but you will have no guarantee that a tokenizer is tested.\n" "If you can find away around this, please let us know!\n" ) class RustFstNotLoadedWarning(UserWarning): """warning for when `rust_fst` is not loaded""" try: from rust_fst import Set as FstSet # type: ignore[import-untyped] _RUST_FST_LOADED = True except ImportError as e: warnings.warn(_RUST_FST_ERR_MSG + str(e), RustFstNotLoadedWarning) _RUST_FST_LOADED = False MMT_FST_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_tested.fst" @cache def get_tokenizers_fst() -> "FstSet": """(cached) load the tokenizers fst set from `MMT_FST_PATH`""" return FstSet(MMT_FST_PATH.as_posix()) def check_tokenizer_in_fst(tokenizer_name: str, do_except: bool = False) -> bool: """check if a tokenizer is in the fst set prints nearest matches if `do_except` is `True` and the tokenizer is not found """ search_0: list[str] = list(get_tokenizers_fst().search(tokenizer_name, 0)) in_fst: bool = len(search_0) == 1 and search_0[0] == tokenizer_name if do_except and not in_fst: search_1: list[str] | None = None search_2: list[str] | None = None try: search_1 = list(get_tokenizers_fst().search(tokenizer_name, 1)) search_2 = list(get_tokenizers_fst().search(tokenizer_name, 2)) except Exception: # noqa: BLE001, S110 # the only thing failing here is getting possible match tokenizers, so it's fine to just ignore the errors pass err_msg: str = ( f"Tokenizer `{tokenizer_name}` not found in the list of tested tokenizers, and {do_except = }. 
We found the following matches based on edit distance:" f"\nedit dist 0 (should be empty?): {search_0}" + (f"\nedit dist 1: {search_1}" if search_1 is not None else "") + (f"\nedit dist 2: {search_2}" if search_2 is not None else "") ) raise ValueError(err_msg) return in_fst def _check_tokenizer_in_fst_mock(tokenizer_name: str, do_except: bool = False) -> bool: # noqa: ARG001 """mock function for `check_tokenizer_in_fst` runs when we cant import `rust_fst` which sets `_RUST_FST_LOADED` to `False` """ warnings.warn( _RUST_FST_ERR_MSG + "you are seeing this warning probably because you tried to run" "`MazeTokenizerModular(...).is_tested_tokenizer()` on a mac or without `rust_fst` installed" + "this is fine, but note that the tokenizer will be checked for validity, but is not part of the tested set" ) return True # override the function if we can't load rust_fst if not _RUST_FST_LOADED: check_tokenizer_in_fst = _check_tokenizer_in_fst_mock ``````{ end_of_file="maze_dataset/tokenization/modular/fst_load.py" } ``````{ path="maze_dataset/tokenization/modular/hashing.py" } """legacy system for checking a `ModularMazeTokenizer` is valid -- compare its hash to a table of known hashes this has been superseded by the fst system """ import hashlib from pathlib import Path import numpy as np from jaxtyping import UInt32 # NOTE: these all need to match! AllTokenizersHashBitLength = 32 "bit length of the hashes of all tokenizers, must match `AllTokenizersHashDtype` and `AllTokenizersHashesArray`" AllTokenizersHashDtype = np.uint32 "numpy data type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashesArray`" AllTokenizersHashesArray = UInt32[np.ndarray, " n_tokens"] "jaxtyping type of the hashes of all tokenizers, must match `AllTokenizersHashBitLength` and `AllTokenizersHashDtype`" def _hash_tokenizer_name(s: str) -> int: h64: int = int.from_bytes( hashlib.shake_256(s.encode("utf-8")).digest(64), byteorder="big", ) return (h64 >> 32) ^ (h64 & 0xFFFFFFFF) _ALL_TOKENIZER_HASHES: AllTokenizersHashesArray "private array of all tokenizer hashes" _TOKENIZER_HASHES_PATH: Path = Path(__file__).parent / "MazeTokenizerModular_hashes.npz" "path to where we expect the hashes file -- in the same dir as this file, by default. change with `set_tokenizer_hashes_path`" def set_tokenizer_hashes_path(path: Path) -> None: """set path to tokenizer hashes, and reload the hashes if needed the hashes are expected to be stored in and read from `_TOKENIZER_HASHES_PATH`, which by default is `Path(__file__).parent / "MazeTokenizerModular_hashes.npz"` or in this file's directory. However, this might not always work, so we provide a way to change this. 
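For example (illustrative): passing a directory, e.g. `set_tokenizer_hashes_path(Path("./data"))`, resolves to `./data/MazeTokenizerModular_hashes.npz`; the hashes are reloaded only when the resolved path differs from the one currently set, and a `FileNotFoundError` is raised if no file exists there.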
""" global _TOKENIZER_HASHES_PATH, _ALL_TOKENIZER_HASHES # noqa: PLW0603 path = Path(path) if path.is_dir(): path = path / "MazeTokenizerModular_hashes.npz" if not path.is_file(): err_msg: str = f"could not find maze tokenizer hashes file at: {path}" raise FileNotFoundError(err_msg) if _TOKENIZER_HASHES_PATH.absolute() != path.absolute(): # reload if they aren't equal _TOKENIZER_HASHES_PATH = path _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() else: # always set to new path _TOKENIZER_HASHES_PATH = path def _load_tokenizer_hashes() -> AllTokenizersHashesArray: """Loads the sorted list of `all_tokenizers.get_all_tokenizers()` hashes from disk.""" global _TOKENIZER_HASHES_PATH # noqa: PLW0602 try: path: Path = _TOKENIZER_HASHES_PATH return np.load(path)["hashes"] except FileNotFoundError as e: err_msg: str = ( "Tokenizers hashes cannot be loaded. To fix this, run" "\n`python -m maze-dataset.tokenization.save_hashes` which will save the hashes to" "\n`data/MazeTokenizerModular_hashes.npz`" "\nrelative to the current working directory -- this is where the code looks for them." ) raise FileNotFoundError(err_msg) from e def get_all_tokenizer_hashes() -> AllTokenizersHashesArray: """returns all the tokenizer hashes in an `AllTokenizersHashesDtype` array, setting global variable if needed""" # naughty use of globals global _ALL_TOKENIZER_HASHES # noqa: PLW0603 try: got_tokenizers: bool = len(_ALL_TOKENIZER_HASHES) > 0 if got_tokenizers: return _ALL_TOKENIZER_HASHES else: _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() except NameError: _ALL_TOKENIZER_HASHES = _load_tokenizer_hashes() return _ALL_TOKENIZER_HASHES ``````{ end_of_file="maze_dataset/tokenization/modular/hashing.py" } ``````{ path="maze_dataset/tokenization/modular/maze_tokenizer_modular.py" } "implements the actual `MazeTokenizerModular` class" import base64 import warnings from functools import cached_property from typing import ( Iterable, Literal, Sequence, overload, ) from muutils.json_serialize import ( SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.misc import flatten from muutils.misc.sequence import WhenMissing # from maze_dataset import SolvedMaze from maze_dataset.constants import ( VOCAB, VOCAB_LIST, VOCAB_TOKEN_TO_INDEX, Coord, CoordTup, ) from maze_dataset.maze.lattice_maze import LatticeMaze from maze_dataset.token_utils import ( TokenizerPendingDeprecationWarning, strings_to_coords, ) from maze_dataset.tokenization.common import TokenError from maze_dataset.tokenization.maze_tokenizer_legacy import ( MazeTokenizer, TokenizationMode, ) from maze_dataset.tokenization.modular.element_base import ( _load_tokenizer_element, _TokenizerElement, ) from maze_dataset.tokenization.modular.elements import CoordTokenizers, PromptSequencers from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst from maze_dataset.tokenization.modular.hashing import ( _hash_tokenizer_name, ) @serializable_dataclass( frozen=True, kw_only=True, properties_to_serialize=["tokenizer_element_tree_concrete", "name"], ) class MazeTokenizerModular(SerializableDataclass): """Tokenizer for mazes # Parameters - `prompt_sequencer`: Tokenizer element which assembles token regions (adjacency list, origin, target, path) into a complete prompt. # Development - To ensure backwards compatibility, the default constructor must always return a tokenizer equivalent to the legacy `TokenizationMode.AOTP_UT_Uniform`. - Furthermore, the mapping reflected in `from_legacy` must also be maintained. 
- Updates to `MazeTokenizerModular` or the `_TokenizerElement` hierarchy must maintain that behavior. """ prompt_sequencer: PromptSequencers._PromptSequencer = serializable_field( default=PromptSequencers.AOTP(), loading_fn=lambda x: _load_tokenizer_element(x, PromptSequencers), ) def hash_int(self) -> int: "return integer hash using blake2b" return _hash_tokenizer_name(self.name) def __hash__(self) -> int: "Stable hash to identify unique `MazeTokenizerModular` instances. uses name" return self.hash_int() def hash_b64(self, n_bytes: int = 8) -> str: """filename-safe base64 encoding of the hash""" # Use modulus to ensure the integer fits within n_bytes * 8 bits hash_mod: int = self.hash_int() % (1 << (n_bytes * 8)) encoded = base64.b64encode( hash_mod.to_bytes(n_bytes, byteorder="big"), altchars=b"-_", ).decode() # Remove any padding equals signs return encoded.rstrip("=") # Information Querying Methods @cached_property def tokenizer_elements(self) -> list[_TokenizerElement]: "returns a list of all the elements of this tokenizer" return [self.prompt_sequencer, *self.prompt_sequencer.tokenizer_elements()] def tokenizer_element_tree(self, abstract: bool = False) -> str: """Returns a string representation of the tree of tokenizer elements contained in `self`. # Parameters - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. """ return "\n".join( [ type(self).__name__, self.prompt_sequencer.tokenizer_element_tree( abstract=abstract, depth=1, ), ], ) @property def tokenizer_element_tree_concrete(self) -> str: """Property wrapper for `tokenizer_element_tree` so that it can be used in `properties_to_serialize`.""" return self.tokenizer_element_tree() def tokenizer_element_dict(self) -> dict: """Nested dictionary of the internal `TokenizerElement`s.""" return {type(self).__name__: self.prompt_sequencer.tokenizer_element_dict()} @property def name(self) -> str: """Serializes MazeTokenizer into a key for encoding in zanj""" return "-".join([type(self).__name__, self.prompt_sequencer.name]) # noqa: FLY002 def summary(self) -> dict[str, str]: """Single-level dictionary of the internal `TokenizerElement`s.""" return { # "prompt_sequencer": self.prompt_sequencer.name, **{elem.attribute_key(): elem.name for elem in self.tokenizer_elements}, } @staticmethod def _type_check(obj: any) -> None: """Helper method for `has_element`""" if not ( isinstance(obj, _TokenizerElement) or (isinstance(obj, type) and issubclass(obj, _TokenizerElement)) ): err_msg: str = f"{obj} is not a `_TokenizerElement` instance or subclass." raise TypeError(err_msg) def _has_element_singular( self, el: type[_TokenizerElement] | _TokenizerElement, ) -> bool: """Helper method for `has_element`""" self._type_check(el) if isinstance(el, type): return any(isinstance(e, el) for e in self.tokenizer_elements) else: return el in self.tokenizer_elements def has_element( self, *elements: Sequence[type[_TokenizerElement] | _TokenizerElement], ) -> bool: """Returns True if the `MazeTokenizerModular` instance contains ALL of the items specified in `elements`. Querying with a partial subset of `_TokenizerElement` fields is not currently supported. To do such a query, assemble multiple calls to `has_elements`. # Parameters - `elements`: Singleton or iterable of `_TokenizerElement` instances or classes. If an instance is provided, then comparison is done via instance equality. If a class is provided, then comparison isdone via `isinstance`. 
I.e., any instance of that class is accepted. """ if len(elements) == 1 and isinstance(elements[0], Iterable): elements = elements[0] return all(self._has_element_singular(e) for e in elements) def is_valid(self, do_except: bool = False) -> bool: """Returns `True` if `self` is a valid tokenizer. Evaluates the validity of all of `self.tokenizer_elements` according to each one's method. """ return all(el.is_valid(do_except=do_except) for el in self.tokenizer_elements) def is_legacy_equivalent(self) -> bool: """Returns if `self` has identical stringification behavior as any legacy `MazeTokenizer`.""" return any( self == MazeTokenizerModular.from_legacy(tok_mode) for tok_mode in TokenizationMode ) def is_tested_tokenizer(self, do_except: bool = False) -> bool: """Returns if the tokenizer is returned by `all_tokenizers.get_all_tokenizers`, the set of tested and reliable tokenizers. uses an fst on the `name` attributes of all the tokenizers if `do_assert` is `True`, raises an `AssertionError` if the tokenizer is not tested. """ is_valid: bool = self.is_valid(do_except=do_except) in_tested_fst: bool = check_tokenizer_in_fst(self.name, do_except=do_except) if do_except: assert is_valid, "self.is_valid returns False" return True else: return in_tested_fst and is_valid def is_AOTP(self) -> bool: "is this tokenizer an AOTP tokenizer? AOTP = Adjacency list, Origin, Target, Path" return self.has_element(PromptSequencers.AOTP) def is_UT(self) -> bool: "is this tokenizer a UT tokenizer? UT = Unique Token (for each coord)" return self.has_element(CoordTokenizers.UT) # Alternate Constructors # ====================== @classmethod def from_legacy( cls, legacy_maze_tokenizer: MazeTokenizer | TokenizationMode, ) -> "MazeTokenizerModular": """Maps a legacy `MazeTokenizer` or `TokenizationMode` to its equivalent `MazeTokenizerModular` instance.""" if isinstance(legacy_maze_tokenizer, MazeTokenizer): legacy_maze_tokenizer = legacy_maze_tokenizer.tokenization_mode return { TokenizationMode.AOTP_UT_uniform: MazeTokenizerModular(), TokenizationMode.AOTP_UT_rasterized: MazeTokenizerModular(), TokenizationMode.AOTP_CTT_indexed: MazeTokenizerModular( prompt_sequencer=PromptSequencers.AOTP( coord_tokenizer=CoordTokenizers.CTT(), ), ), }[legacy_maze_tokenizer] # Simple properties # ================= @classmethod def from_tokens( cls, tokens: str | list[str], ) -> "MazeTokenizerModular": """Infers most `MazeTokenizerModular` parameters from a full sequence of tokens.""" raise NotImplementedError( "Recovering tokenizer objects from MazeTokenizerModular-produced strings is not supported", ) @property def token_arr(self) -> list[str] | None: """map from index to token""" return VOCAB_LIST @property def tokenizer_map(self) -> dict[str, int]: """map from token to index""" return VOCAB_TOKEN_TO_INDEX @property def vocab_size(self) -> int: """Number of tokens in the static vocab""" return len(VOCAB_LIST) @property def n_tokens(self) -> int: "get the number of tokens in the vocabulary (deprecated)" err_msg: str = "`MazeTokenizerModular.n_tokens` has been removed. Use `len(maze_dataset.VOCAB_LIST)` instead." 
raise NameError(err_msg) @property def padding_token_index(self) -> int: "get the index of the padding token" return VOCAB_TOKEN_TO_INDEX[VOCAB.PADDING] # conversion functions # ============================================================ def to_tokens( self, maze: LatticeMaze, ) -> list[str]: """Converts maze into a list of tokens.""" return self.prompt_sequencer.to_tokens(maze) def coords_to_strings(self, coords: list[CoordTup | Coord]) -> list[str]: "calls self.prompt_sequencer.coord_tokenizer.to_tokens(c) for each c in coords" return list( flatten( [self.prompt_sequencer.coord_tokenizer.to_tokens(c) for c in coords], ), ) # TODO: unclear why we need to use `noqa: N805` here since its a classmethod # maybe we need to hit every overload with `@classmethod`? @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["skip"] = "skip", ) -> list[CoordTup]: ... @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["error"] = "error", ) -> list[CoordTup]: ... @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["include"] = "include", ) -> list[str | CoordTup]: ... @classmethod def strings_to_coords( cls, text: str | list[str], when_noncoord: WhenMissing = "skip", ) -> list[str | CoordTup]: "wrapper for maze_dataset.token_utils.strings_to_coords" warnings.warn( "`MazeTokenizerModular.strings_to_coords` only supports legacy UT strings.", TokenizerPendingDeprecationWarning, ) return strings_to_coords(text=text, when_noncoord=when_noncoord) @staticmethod def encode(text: str | list[str]) -> list[int]: """encode a string or list of strings into a list of tokens""" try: if isinstance(text, str): text = text.split() return [VOCAB_TOKEN_TO_INDEX[token] for token in text] except KeyError as e: err_msg: str = f"Token {e} not found in `VOCAB`." raise TokenError(err_msg) from e @staticmethod def decode( token_ids: Sequence[int], joined_tokens: bool = False, ) -> list[str] | str: """decode a list of tokens into a string or list of strings""" try: output: list[str] = [VOCAB_LIST[token_id] for token_id in token_ids] except IndexError as e: err_msg: str = f"Token index '{e}' not found in `VOCAB`." 
raise TokenError(err_msg) from e if joined_tokens: return " ".join(output) else: return output ``````{ end_of_file="maze_dataset/tokenization/modular/maze_tokenizer_modular.py" } ``````{ path="maze_dataset/tokenization/modular/save_hashes.py" } """generate and save the hashes of all supported tokenizers > [!CAUTION] > using hashes to validate validate a `MazeTokenizerModular` is deprecated in favor of using fst calls `maze_dataset.tokenization.all_tokenizers.save_hashes()` Usage: To save to the default location (inside package, `maze_dataset/tokenization/MazeTokenizerModular_hashes.npy`): ```bash python -m maze_dataset.tokenization.save_hashes ``` to save to a custom location: ```bash python -m maze_dataset.tokenization.save_hashes /path/to/save/to.npy ``` to check hashes shipped with the package: ```bash python -m maze_dataset.tokenization.save_hashes --check ``` """ from pathlib import Path import numpy as np from muutils.spinner import SpinnerContext from maze_dataset.tokenization.modular import all_tokenizers from maze_dataset.tokenization.modular.hashing import ( _load_tokenizer_hashes, get_all_tokenizer_hashes, ) if __name__ == "__main__": # parse args # ================================================== import argparse parser: argparse.ArgumentParser = argparse.ArgumentParser( description="generate and save (or download) the hashes of all supported tokenizers", ) parser.add_argument("path", type=str, nargs="?", help="path to save the hashes to") parser.add_argument( "--quiet", "-q", action="store_true", help="disable progress bar and spinner", ) parser.add_argument( "--parallelize", "-p", action="store_true", help="parallelize the computation", ) parser.add_argument( "--check", "-c", action="store_true", help="save to temp location, then compare to existing", ) parser.add_argument( "--download", "-d", action="store_true", help=f"download the hashes from github: {all_tokenizers.DOWNLOAD_URL}", ) args: argparse.Namespace = parser.parse_args() if not args.check: # write new hashes # ================================================== all_tokenizers.save_hashes( path=args.path, verbose=not args.quiet, parallelize=args.parallelize, ) else: # check hashes only # ================================================== # set up path if args.path is not None: raise ValueError("cannot use --check with a custom path") temp_path: Path = Path("tests/_temp/tok_hashes.npz") temp_path.parent.mkdir(parents=True, exist_ok=True) # generate and save to temp location returned_hashes: np.ndarray = all_tokenizers.save_hashes( path=temp_path, verbose=not args.quiet, parallelize=args.parallelize, ) # load saved hashes with SpinnerContext( spinner_chars="square_dot", update_interval=0.5, message="loading saved hashes...", ): read_hashes: np.ndarray = np.load(temp_path)["hashes"] read_hashes_pkg: np.ndarray = _load_tokenizer_hashes() read_hashes_wrapped: np.ndarray = get_all_tokenizer_hashes() # compare with SpinnerContext( spinner_chars="square_dot", update_interval=0.01, message="checking hashes: ", format_string="\r{spinner} ({elapsed_time:.2f}s) {message}{value} ", format_string_when_updated=True, ) as sp: sp.update_value("returned vs read") assert np.array_equal(returned_hashes, read_hashes) sp.update_value("returned vs _load_tokenizer_hashes") assert np.array_equal(returned_hashes, read_hashes_pkg) sp.update_value("returned vs get_all_tokenizer_hashes()") assert np.array_equal(read_hashes, read_hashes_wrapped) ``````{ end_of_file="maze_dataset/tokenization/modular/save_hashes.py" } ``````{ 
path="maze_dataset/tokenization/__init__.py" } """turning a maze into text - `MazeTokenizerModular` is the new recommended way to do this as of 1.0.0 - legacy `TokenizationMode` enum and `MazeTokenizer` class for supporting existing code - a variety of helper classes and functions There are many algorithms by which one might tokenize a 2D maze into a 1D format usable by autoregressive text models. Training multiple models on the encodings output from each of these algorithms may produce very different internal representations, learned solution algorithms, and levels of performance. To explore how different maze tokenization algorithms affect these models, the `MazeTokenizerModular` class contains a rich set of options to customize how mazes are stringified. This class contains 19 discrete parameters, resulting in 5.9 million unique tokenizers. But wait, there's more! There are 6 additional parameters available in the library which are untested but further expand the the number of tokenizers by a factor of $44/3$ to 86 million. All output sequences consist of four token regions representing different features of the maze. These regions are distinguished by color in Figure below. - Adjacency list: A text representation of the lattice graph - Origin: Starting coordinate - Target: Ending coordinate - Path: Maze solution sequence from the start to the end ![Example text output format with token regions highlighted.](figures/outputs-tokens-colored.tex) Each `MazeTokenizerModular` is constructed from a set of several `_TokenizerElement` objects, each of which specifies how different token regions or other elements of the stringification are produced. ![Nested internal structure of `_TokenizerElement` objects inside a typical `MazeTokenizerModular` object.](figures/TokenizerElement_structure.pdf) Optional delimiter tokens may be added in many places in the output. Delimiter options are all configured using the parameters named `pre`, `intra`, and `post` in various `_TokenizerElement` classes. Each option controls a unique delimiter token. Here we describe each `_TokenizerElement` and the behaviors they support. We also discuss some of the model behaviors and properties that may be investigated using these options. ### Coordinates The `_CoordTokenizer` object controls how coordinates in the lattice are represented in across all token regions. Options include: - **Unique tokens**: Each coordinate is represented as a single unique token `"(i,j)"` - **Coordinate tuple tokens**: Each coordinate is represented as a sequence of 2 tokens, respectively encoding the row and column positions: `["i", ",", "j"]` ### Adjacency List The `_AdjListTokenizer` object controls this token region. All tokenizations represent the maze connectivity as a sequence of connections or walls between pairs of adjacent coordinates in the lattice. - `_EdgeSubset`: Specifies the subset of lattice edges to be tokenized - **All edges**: Every edge in the lattice - **Connections**: Only edges which contain a connection - **Walls**: Only edges which contain a wall - `_EdgePermuter`: Specifies how to sequence the two coordinates in each lattice edge - **Random** - **Sorted**: The smaller coordinate always comes first - **Both permutations**: Each edge is represented twice, once with each permutation. This option attempts to represent connections in a more directionally symmetric manner. 
Including only one permutation of each edge may affect models' internal representations of edges, treating a path traversing the edge differently depending on if the coordinate sequence in the path matches the sequence in the adjacency list. - `shuffle_d0`: Whether to shuffle the edges randomly or sort them in the output by their first coordinate - `connection_token_ordinal`: Location in the sequence of the token representing whether the edge is a connection or a wall ### Path The `_PathTokenizer` object controls this token region. Paths are all represented as a sequence of steps moving from the start to the end position. - `_StepSize`: Specifies the size of each step - **Singles**: Every coordinate traversed between start and end is directly represented - **Forks**: Only coordinates at forking points in the maze are represented. The paths between forking points are implicit. Using this option might train models more directly to represent forking points differently from coordinates where the maze connectivity implies an obvious next step in the path. - `_StepTokenizer`: Specifies how an individual step is represented - **Coordinate**: The coordinates of each step are directly tokenized using a `_CoordTokenizer` - **Cardinal direction**: A single token corresponding to the cardinal direction taken at the starting position of that step. E.g., `NORTH`, `SOUTH`. If using a `_StepSize` other than **Singles**, this direction may not correspond to the final direction traveled to arrive at the end position of the step. - **Relative direction**: A single token corresponding to the first-person perspective relative direction taken at the starting position of that step. E.g., `RIGHT`, `LEFT`. - **Distance**: A single token corresponding to the number of coordinate positions traversed in that step. E.g., using a `_StepSize` of **Singles**, the **Distance** token would be the same for each step, corresponding to a distance of 1 coordinate. This option is only of interest in combination with a `_StepSize` other than **Singles**. A `_PathTokenizer` contains a sequence of one or more unique `_StepTokenizer` objects. Different step representations may be mixed and permuted, allowing for investigation of model representations of multiple aspects of a maze solution at once. ## Tokenized Outputs for Training and Evaluation {#token-training} During deployment we provide only the prompt up to the `` token. Examples of usage of this dataset to train autoregressive transformers can be found in our `maze-transformer` library [@maze-transformer-github]. Other tokenization and vocabulary schemes are also included, such as representing each coordinate as a pair of $i,j$ index tokens. ## Extensibility The tokenizer architecture is purposefully designed such that adding and testing a wide variety of new tokenization algorithms is fast and minimizes disturbances to functioning code. This is enabled by the modular architecture and the automatic inclusion of any new tokenizers in integration tests. To create a new tokenizer, developers forking the library may simply create their own `_TokenizerElement` subclass and implement the abstract methods. If the behavior change is sufficiently small, simply adding a parameter to an existing `_TokenizerElement` subclass and updating its implementation will suffice. For small additions, simply adding new cases to existing unit tests will suffice. The breadth of tokenizers is also easily scaled in the opposite direction. 
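Before turning to how that pruning works, here is a minimal, illustrative construction sketch. The element and field names are the ones documented above, but this particular combination is arbitrary, and the `SolvedMaze` being tokenized is assumed to come from elsewhere in the library (its construction is not shown).

```python
# illustrative sketch: compose a tokenizer from the elements described above.
# the combination chosen here is arbitrary, and `maze` (a `SolvedMaze`) is
# assumed to have been generated elsewhere in the library.
from maze_dataset.constants import VOCAB
from maze_dataset.tokenization import (
	AdjListTokenizers,
	CoordTokenizers,
	MazeTokenizerModular,
	PathTokenizers,
	PromptSequencers,
	StepSizes,
	StepTokenizers,
)

# coordinate-tuple coord tokens, default adjacency list handling,
# and a path written as one cardinal-direction token per step
tokenizer: MazeTokenizerModular = MazeTokenizerModular(
	prompt_sequencer=PromptSequencers.AOTP(
		coord_tokenizer=CoordTokenizers.CTT(),
		adj_list_tokenizer=AdjListTokenizers.AdjListCoord(),
		path_tokenizer=PathTokenizers.StepSequence(
			step_size=StepSizes.Singles(),
			step_tokenizers=(StepTokenizers.Cardinal(),),
		),
	),
)

assert tokenizer.is_valid()  # each `_TokenizerElement` validates itself
print(tokenizer.name)        # name string encoding the full element tree
print(tokenizer.vocab_size)  # size of the fixed `VOCAB_LIST`

# round-trip a couple of known special tokens through the fixed vocabulary
ids: list[int] = tokenizer.encode([VOCAB.PATH_START, VOCAB.PATH_END])
print(tokenizer.decode(ids, joined_tokens=True))

# with a `SolvedMaze` in hand, the full prompt would be produced by:
# tokens: list[str] = tokenizer.to_tokens(maze)
# membership in the shipped set of tested tokenizers can be checked with
# `tokenizer.is_tested_tokenizer()`
```

The `name` property encodes the chosen element tree, and it is what the shipped fst of tested tokenizers is keyed on.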
Due to the exponential scaling of parameter combinations, adding a small number of new features can significantly slow certain procedures which rely on constructing all possible tokenizers, such as integration tests. If any existing subclass contains features which aren't needed, a developer tool decorator is provided which can be applied to the unneeded `_TokenizerElement` subclasses to prune those features and compact the available space of tokenizers. """ from maze_dataset.tokenization.maze_tokenizer_legacy import ( MazeTokenizer, TokenizationMode, get_tokens_up_to_path_start, ) from maze_dataset.tokenization.modular.element_base import _TokenizerElement from maze_dataset.tokenization.modular.elements import ( AdjListTokenizers, CoordTokenizers, EdgeGroupings, EdgePermuters, EdgeSubsets, PathTokenizers, PromptSequencers, StepSizes, StepTokenizers, TargetTokenizers, ) from maze_dataset.tokenization.modular.maze_tokenizer_modular import ( MazeTokenizerModular, ) # we don't sort alphabetically on purpose, we sort by the type __all__ = [ # submodules "modular", "common", "maze_tokenizer_legacy", "maze_tokenizer", # legacy tokenizer "MazeTokenizer", "TokenizationMode", # MMT "MazeTokenizerModular", # element base "_TokenizerElement", # elements "PromptSequencers", "CoordTokenizers", "AdjListTokenizers", "EdgeGroupings", "EdgePermuters", "EdgeSubsets", "TargetTokenizers", "StepSizes", "StepTokenizers", "PathTokenizers", # helpers "get_tokens_up_to_path_start", ] ``````{ end_of_file="maze_dataset/tokenization/__init__.py" } ``````{ path="maze_dataset/tokenization/common.py" } "common code for various tokenizers" class TokenError(ValueError): """error for tokenization""" pass ``````{ end_of_file="maze_dataset/tokenization/common.py" } ``````{ path="maze_dataset/tokenization/maze_tokenizer.py" } """preserving legacy imports""" from maze_dataset.tokenization.maze_tokenizer_legacy import ( MazeTokenizer, TokenizationMode, ) from maze_dataset.tokenization.modular.maze_tokenizer_modular import ( MazeTokenizerModular, ) __all__ = [ "MazeTokenizer", "TokenizationMode", "MazeTokenizerModular", ] ``````{ end_of_file="maze_dataset/tokenization/maze_tokenizer.py" } ``````{ path="maze_dataset/tokenization/maze_tokenizer_legacy.py" } """legacy tokenizer which uses a `TokenizationMode` enum and a `MazeTokenizer` class > [!CAUTION] > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended > for use, but will remain for compatibility with existing code. """ import warnings from enum import Enum from functools import cached_property from typing import ( Callable, Iterable, Literal, Mapping, Sequence, overload, ) import numpy as np from muutils.json_serialize import ( SerializableDataclass, serializable_dataclass, serializable_field, ) from muutils.kappa import Kappa from muutils.misc.sequence import WhenMissing # from maze_dataset import SolvedMaze from maze_dataset.constants import ( SPECIAL_TOKENS, CoordTup, ) from maze_dataset.token_utils import ( TokenizerPendingDeprecationWarning, _coord_to_strings_indexed, _coord_to_strings_UT, coords_to_strings, strings_to_coords, ) from maze_dataset.tokenization.common import TokenError from maze_dataset.utils import corner_first_ndindex class TokenizationMode(Enum): """legacy tokenization modes > [!CAUTION] > Legacy mode of tokenization. will still be around in future releases, but is no longer recommended for use. > Use `MazeTokenizerModular` instead. 
# Abbreviations: - `AOTP`: Ajacency list, Origin, Target, Path - `UT`: Unique Token (for each coordiate) - `CTT`: Coordinate Tuple Tokens (each coordinate is tokenized as a tuple of integers) # Modes: - `AOTP_UT_rasterized`: the "classic" mode: assigning tokens to each coordinate is done via rasterization example: for a 3x3 maze, token order is `(0,0), (0,1), (0,2), (1,0), (1,1), (1,2), (2,0), (2,1), (2,2)` - `AOTP_UT_uniform`: new mode, where a 3x3 tokenization scheme and 5x5 tokenizations scheme are compatible uses `corner_first_ndindex` function to order the tokens - `AOTP_CTT_indexed`: each coordinate is a tuple of integers """ AOTP_UT_rasterized = "AOTP_UT_rasterized" AOTP_UT_uniform = "AOTP_UT_uniform" AOTP_CTT_indexed = "AOTP_CTT_indexed" def to_legacy_tokenizer(self, max_grid_size: int | None = None) -> "MazeTokenizer": "convert the mode to a legacy `MazeTokenizer` object given a `max_grid_size`" return MazeTokenizer(tokenization_mode=self, max_grid_size=max_grid_size) _NDINDEX_FUNC_MAP: dict[ TokenizationMode, Callable[[int], Iterable[tuple[int, ...]]], ] = { TokenizationMode.AOTP_UT_rasterized: lambda n: list(np.ndindex(n, n)), TokenizationMode.AOTP_UT_uniform: lambda n: corner_first_ndindex(n, 2), } def is_UT(tokenization_mode: TokenizationMode) -> bool: "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)" return tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, ) def get_tokens_up_to_path_start( tokens: list[str], include_start_coord: bool = True, tokenization_mode: TokenizationMode = TokenizationMode.AOTP_UT_uniform, ) -> list[str]: """get tokens up to the path start token # Parameters: - `tokens : list[str]` - `include_start_coord : bool` (defaults to `True`) - `tokenization_mode : TokenizationMode` (defaults to `TokenizationMode.AOTP_UT_uniform`) # Returns: - `list[str]` subsequence of `tokens` up to the path start token # Raises: - `ValueError` : if `tokenization_mode` is invalid """ warnings.warn( "`maze_tokenizer.get_tokens_up_to_path_start` will be deprecated for a `MazeTokenizerModular`-compatible function in a future release.", TokenizerPendingDeprecationWarning, ) path_start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + 1 if include_start_coord: if is_UT(tokenization_mode): return tokens[: path_start_idx + 1] elif tokenization_mode == TokenizationMode.AOTP_CTT_indexed: return tokens[: path_start_idx + 5] else: err_msg: str = f"Invalid tokenization mode: {tokenization_mode}" raise ValueError(err_msg) else: return tokens[:path_start_idx] _MAZETOKENIZER_PROPERTIES_TO_SERIALIZE: list[str] = [ "name", "max_grid_size", "token_arr", "tokenizer_map", "vocab_size", "padding_token_index", ] @serializable_dataclass( properties_to_serialize=_MAZETOKENIZER_PROPERTIES_TO_SERIALIZE, kw_only=True, ) class MazeTokenizer(SerializableDataclass): """LEGACY Tokenizer for mazes > [!CAUTION] > `MazeTokenizerModular` is the new standard for tokenization. This class is no longer recommended > for use, but will remain for compatibility with existing code. # Parameters: - `tokenization_mode: TokenizationMode` mode of tokenization. required. - `max_grid_size: int | None` maximum grid size. required for actually turning text tokens to numerical tokens, but not for moving between coordinates/mazes and text # Properties - `name: str` auto-generated name of the tokenizer from mode and size ## Conditional Properties - `node_strings_map: Mapping[CoordTup, str]` map from node to string. 
This returns a `muutils.kappa.Kappa` object which you can use like a dictionary. returns `None` if not a `UT` mode these all return `None` if `max_grid_size` is `None`. Prepend `_` to the name to get a guaranteed type, and cause an exception if `max_grid_size` is `None` - `token_arr: list[str]` list of tokens, in order of their indices in the vocabulary - `tokenizer_map: Mapping[str, int]` map from token to index - `vocab_size: int` size of the vocabulary - `padding_token_index: int` index of the padding token # Methods - `coords_to_strings(coords: list[CoordTup]) -> list[str]` convert a list of coordinates to a list of tokens. Optionally except, skip, or ignore non-coordinates - `strings_to_coords(strings: list[str]) -> list[CoordTup]` convert a list of tokens to a list of coordinates. Optionally except, skip, or ignore non-coordinates """ # parameters # ============================================================ tokenization_mode: TokenizationMode = serializable_field( default=TokenizationMode.AOTP_UT_uniform, serialization_fn=lambda x: x.value, loading_fn=lambda x: TokenizationMode[x["tokenization_mode"]], ) max_grid_size: int | None = serializable_field(default=None) # properties # ============================================================ @property def name(self) -> str: """auto-generated name of the tokenizer from mode and size""" max_grid_size_str: str = ( f"-g{self.max_grid_size}" if self.max_grid_size is not None else "" ) return f"maze_tokenizer-{self.tokenization_mode.value}{max_grid_size_str}" @cached_property def _node_strings_map(self) -> Mapping[CoordTup, list[str]]: """map a coordinate to a token""" if self.tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, ): return Kappa(_coord_to_strings_UT) elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: return Kappa(_coord_to_strings_indexed) else: err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" raise ValueError(err_msg) @cached_property def node_strings_map(self) -> Mapping[CoordTup, list[str]] | None: """map a coordinate to a token""" if self.tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, ): return None else: return self._node_strings_map # conditional properties (on max_grid_size existing) # ------------------------------------------------------------ @cached_property def _token_arr(self) -> list[str]: """map from index to token""" if self.max_grid_size is None: err_msg: str = f"max_grid_size must be specified to use token_arr property: {self.max_grid_size = }" raise ValueError(err_msg) output: list[str] = list(SPECIAL_TOKENS.values()) if self.tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, ): output.extend( [ self._node_strings_map[coord][0] for coord in _NDINDEX_FUNC_MAP[self.tokenization_mode]( self.max_grid_size, ) ], ) elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: # TODO: this is hacky, but we don't want to modify the original SPECIAL_TOKENS since that will break old models output.extend( [ "(", ",", ")", # new special chars *map(str, range(self.max_grid_size)), # numbers ], ) else: err_msg: str = ( f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}", ) raise ValueError(err_msg) return output @cached_property def token_arr(self) -> list[str] | None: "get the token array if the max_grid_size is specified" if self.max_grid_size is None: 
return None return self._token_arr @cached_property def _tokenizer_map(self) -> dict[str, int]: """map from token to index""" return {token: i for i, token in enumerate(self._token_arr)} @cached_property def tokenizer_map(self) -> dict[str, int] | None: "get the tokenizer map if the max_grid_size is specified" if self.max_grid_size is None: return None return self._tokenizer_map @property def _vocab_size(self) -> int: return len(self._token_arr) @property def vocab_size(self) -> int | None: "get the size of the vocabulary if the max_grid_size is specified" if self.max_grid_size is None: return None return self._vocab_size @property def _n_tokens(self) -> int: # TODO: deprecate return self._vocab_size @property def n_tokens(self) -> int | None: "get the number of tokens if the max_grid_size is specified" if self.max_grid_size is None: return None return self._n_tokens @cached_property def _padding_token_index(self) -> int: return self.tokenizer_map[SPECIAL_TOKENS.PADDING] @cached_property def padding_token_index(self) -> int | None: "get the index of the padding token if it exists" if self.max_grid_size is None: return None return self._padding_token_index # conversion functions # ============================================================ @overload def coords_to_strings( self, coords: list[str | CoordTup], when_noncoord: Literal["include", "skip"] = "skip", ) -> list[str]: ... @overload def coords_to_strings( self, coords: list[CoordTup], when_noncoord: Literal["error"] = "error", ) -> list[str]: ... def coords_to_strings( self, coords: list[CoordTup], when_noncoord: WhenMissing = "skip", ) -> list[str]: """map a list of coordinate tuples (and maybe other tokens) to strings wraps `maze_dataset.token_utils.coords_to_strings` with either `_coord_to_strings_UT` or `_coord_to_strings_indexed` depending on the tokenization mode """ if self.tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, ): return coords_to_strings( coords=coords, coord_to_strings_func=_coord_to_strings_UT, when_noncoord=when_noncoord, ) elif self.tokenization_mode == TokenizationMode.AOTP_CTT_indexed: return coords_to_strings( coords=coords, coord_to_strings_func=_coord_to_strings_indexed, when_noncoord=when_noncoord, ) else: err_msg: str = f"Invalid tokenization mode {self.tokenization_mode}, expected one of {TokenizationMode.__members__}" raise ValueError(err_msg) @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["skip"] = "skip", ) -> list[CoordTup]: ... @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["error"] = "error", ) -> list[CoordTup]: ... @overload def strings_to_coords( cls, # noqa: N805 text: str | list[str], when_noncoord: Literal["include"] = "include", ) -> list[str | CoordTup]: ... 
@classmethod def strings_to_coords( cls, text: str | list[str], when_noncoord: WhenMissing = "skip", ) -> list[str | CoordTup]: "wrapper for `maze_dataset.token_utils.strings_to_coords`" return strings_to_coords(text=text, when_noncoord=when_noncoord) def encode(self, text: str | list[str]) -> list[int]: """encode a string or list of strings into a list of tokens""" try: if isinstance(text, str): text = text.split() return [self.tokenizer_map[token] for token in text] except KeyError as e: err_msg: str = ( f"Token {e} not found in vocabulary of {self}:\n{self.token_arr}" ) raise TokenError(err_msg) from e def decode( self, tokens: Sequence[int], joined_tokens: bool = False, ) -> list[str] | str: """decode a list of tokens into a string or list of strings""" try: output: list[str] = [self.token_arr[token] for token in tokens] except IndexError as e: err_msg: str = ( f"Token index '{e}' not found in vocabulary of length {self.vocab_size}" ) raise TokenError(err_msg) from e if joined_tokens: return " ".join(output) else: return output # UT-only coordinate stuff # ============================================================ @cached_property def coordinate_tokens_coords(self) -> dict[CoordTup, int]: "map of coordiante tuples to their token ids, only valid for UT" # print(f"{self.tokenization_mode = }") if not self.is_UT(): err_msg: str = f"coordinate_tokens_coords is only valid for UT tokenization modes, got {self.tokenization_mode = }" raise ValueError(err_msg) if self.max_grid_size is None: err_msg: str = f"max_grid_size must be specified to use coordinate_tokens: {self.max_grid_size = }" raise ValueError(err_msg) raw_converted: list[CoordTup | str] = self.strings_to_coords( self.token_arr, when_noncoord="include", ) # filter out non-coordinates return { coord: i for i, coord in enumerate(raw_converted) if not isinstance(coord, str) } @cached_property def coordinate_tokens_ids(self) -> dict[str, int]: "map of coordinate tokens to their token ids, only valid for UT" # checks performed in call output: dict[str, int] = dict() for coord, index in self.coordinate_tokens_coords.items(): _for_key: list[str] = self.coords_to_strings([coord]) assert len(_for_key) == 1 output[_for_key[0]] = index return output # other # ============================================================ def summary(self) -> dict: """returns a summary of the tokenization mode""" return { "tokenization_mode": self.tokenization_mode.value, "max_grid_size": self.max_grid_size, "vocab_size": self.vocab_size, } def is_AOTP(self) -> bool: """returns true if a tokenization mode is Adjacency list, Origin, Target, Path""" return self.tokenization_mode in ( TokenizationMode.AOTP_UT_rasterized, TokenizationMode.AOTP_UT_uniform, TokenizationMode.AOTP_CTT_indexed, ) def is_UT(self) -> bool: "returns true if a tokenization mode is a UT mode: UT = Unique Token (for each coordinate)" return is_UT(self.tokenization_mode) def clear_cache(self) -> None: """clears all cached properties""" # delete the properties only if they exist for name, prop in self.__class__.__dict__.items(): if isinstance(prop, cached_property): # if the property exists, delete it try: # noqa: SIM105 delattr(self, name) except AttributeError: pass ``````{ end_of_file="maze_dataset/tokenization/maze_tokenizer_legacy.py" } ``````{ path="maze_dataset/__init__.py" } """.. 
include:: ../README.md""" from maze_dataset.constants import ( SPECIAL_TOKENS, VOCAB, VOCAB_LIST, VOCAB_TOKEN_TO_INDEX, Connection, ConnectionArray, ConnectionList, Coord, CoordArray, CoordList, CoordTup, ) from maze_dataset.dataset.collected_dataset import ( MazeDatasetCollection, MazeDatasetCollectionConfig, ) from maze_dataset.dataset.filters import register_maze_filter from maze_dataset.dataset.maze_dataset import ( MazeDataset, MazeDatasetConfig, ) from maze_dataset.dataset.maze_dataset_config import set_serialize_minimal_threshold from maze_dataset.generation.generators import LatticeMazeGenerators from maze_dataset.maze.lattice_maze import LatticeMaze, SolvedMaze, TargetedLatticeMaze __all__ = [ # submodules (with sub-submodules) "benchmark", "dataset", "generation", "maze", "plotting", "tokenization", # submodules "constants", "testing_utils", "token_utils", "utils", # main "SolvedMaze", "MazeDatasetConfig", "MazeDataset", # dataset classes "MazeDatasetCollection", "MazeDatasetCollectionConfig", # maze classes "TargetedLatticeMaze", "LatticeMaze", # other "set_serialize_minimal_threshold", "register_maze_filter", "LatticeMazeGenerators", # types "Coord", "CoordTup", "CoordList", "CoordArray", "Connection", "ConnectionList", "ConnectionArray", # constants "SPECIAL_TOKENS", "VOCAB", "VOCAB_LIST", "VOCAB_TOKEN_TO_INDEX", ] ``````{ end_of_file="maze_dataset/__init__.py" } ``````{ path="maze_dataset/constants.py" } """constants and type hints used accross the package""" import warnings from dataclasses import dataclass, field, make_dataclass from typing import Iterator import numpy as np from jaxtyping import Bool, Int8 from maze_dataset.utils import corner_first_ndindex # various type hints for coordinates, connections, etc. Coord = Int8[np.ndarray, "row_col=2"] "single coordinate as array" CoordTup = tuple[int, int] "single coordinate as tuple" CoordArray = Int8[np.ndarray, "coord row_col=2"] "array of coordinates" CoordList = list[CoordTup] "list of tuple coordinates" Connection = Int8[np.ndarray, "coord=2 row_col=2"] "single connection (pair of coords) as array" ConnectionList = Bool[np.ndarray, "lattice_dim=2 row col"] "internal representation used in `LatticeMaze`" ConnectionArray = Int8[np.ndarray, "edges leading_trailing_coord=2 row_col=2"] "n_edges * 2 * 2 array of connections, like an adjacency list" class SpecialTokensError(Exception): "(unused!) errors related to special tokens" pass _SPECIAL_TOKENS_ABBREVIATIONS: dict[str, str] = { "": "", "": "", "": "", "": "", "": "", "": "", "": "", "": "", "<-->": "<-->", ";": ";", "": "", } "map abbreviations for (some) special tokens" @dataclass(frozen=True) class _SPECIAL_TOKENS_BASE: # noqa: N801 "special dataclass used for handling special tokens" ADJLIST_START: str = "" ADJLIST_END: str = "" TARGET_START: str = "" TARGET_END: str = "" ORIGIN_START: str = "" ORIGIN_END: str = "" PATH_START: str = "" PATH_END: str = "" CONNECTOR: str = "<-->" ADJACENCY_ENDLINE: str = ";" PADDING: str = "" def __getitem__(self, key: str) -> str: key_upper: str = key.upper() if not isinstance(key, str): err_msg: str = f"key must be str, not {type(key)}" raise TypeError(err_msg) # error checking for old lowercase format if key != key_upper: warnings.warn( f"Accessing special token '{key}' without uppercase. 
this is deprecated and will be removed in the future.", DeprecationWarning, ) key = key_upper # `ADJLIST` used to be `adj_list`, changed to match actual token content if key_upper not in self.keys(): key_upper_modified: str = key_upper.replace("ADJ_LIST", "ADJLIST") if key_upper_modified in self.keys(): warnings.warn( f"Accessing '{key}' in old format, should use {key_upper_modified}. this is deprecated and will be removed in the future.", DeprecationWarning, ) return getattr(self, key_upper_modified) else: err_msg: str = f"invalid special token '{key}'" raise KeyError(err_msg) # normal return return getattr(self, key.upper()) def get_abbrev(self, key: str) -> str: return _SPECIAL_TOKENS_ABBREVIATIONS[self[key]] def __iter__(self) -> Iterator[str]: return iter(self.__dict__.keys()) def __len__(self) -> int: return len(self.__dict__.keys()) def __contains__(self, key: str) -> bool: return key in self.__dict__ def values(self) -> Iterator[str]: return self.__dict__.values() def items(self) -> Iterator[tuple[str, str]]: return self.__dict__.items() def keys(self) -> Iterator[str]: return self.__dict__.keys() SPECIAL_TOKENS: _SPECIAL_TOKENS_BASE = _SPECIAL_TOKENS_BASE() "special tokens" DIRECTIONS_MAP: Int8[np.ndarray, "direction axes"] = np.array( [ [0, 1], # down [0, -1], # up [1, 1], # right [1, -1], # left ], ) "down, up, right, left directions for when inside a `ConnectionList`" NEIGHBORS_MASK: Int8[np.ndarray, "coord point"] = np.array( [ [0, 1], # down [0, -1], # up [1, 0], # right [-1, 0], # left ], ) "down, up, right, left as vectors" # last element of the tuple is actually a Field[str], but mypy complains _VOCAB_FIELDS: list[tuple[str, type[str], str]] = [ # *[(k, str, field(default=v)) for k, v in SPECIAL_TOKENS.items()], ("COORD_PRE", str, field(default="(")), ("COORD_INTRA", str, field(default=",")), ("COORD_POST", str, field(default=")")), ("TARGET_INTRA", str, field(default="=")), ("TARGET_POST", str, field(default="||")), ("PATH_INTRA", str, field(default=":")), ("PATH_POST", str, field(default="THEN")), ("NEGATIVE", str, field(default="-")), ("UNKNOWN", str, field(default="")), *[ (f"TARGET_{a}", str, field(default=f"TARGET_{a}")) for a in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ], ("TARGET_NORTH", str, field(default="TARGET_NORTH")), ("TARGET_SOUTH", str, field(default="TARGET_SOUTH")), ("TARGET_EAST", str, field(default="TARGET_EAST")), ("TARGET_WEST", str, field(default="TARGET_WEST")), ("TARGET_NORTHEAST", str, field(default="TARGET_NORTHEAST")), ("TARGET_NORTHWEST", str, field(default="TARGET_NORTHWEST")), ("TARGET_SOUTHEAST", str, field(default="TARGET_SOUTHEAST")), ("TARGET_SOUTHWEST", str, field(default="TARGET_SOUTHWEST")), ("TARGET_CENTER", str, field(default="TARGET_CENTER")), ("PATH_NORTH", str, field(default="NORTH")), ("PATH_SOUTH", str, field(default="SOUTH")), ("PATH_EAST", str, field(default="EAST")), ("PATH_WEST", str, field(default="WEST")), ("PATH_FORWARD", str, field(default="FORWARD")), ("PATH_BACKWARD", str, field(default="BACKWARD")), ("PATH_LEFT", str, field(default="LEFT")), ("PATH_RIGHT", str, field(default="RIGHT")), ("PATH_STAY", str, field(default="STAY")), *[ (f"I_{i:03}", str, field(default=f"+{i}")) for i in range(256) ], # General purpose positive int tokens. Used by `StepTokenizers.Distance`. 
*[ (f"CTT_{i}", str, field(default=f"{i}")) for i in range(128) ], # Coord tuple tokens *[ (f"I_N{-i:03}", str, field(default=f"{i}")) for i in range(-256, 0) ], # General purpose negative int tokens ("PATH_PRE", str, field(default="STEP")), ("ADJLIST_PRE", str, field(default="ADJ_GROUP")), ("ADJLIST_INTRA", str, field(default="&")), ("ADJLIST_WALL", str, field(default="")), *[(f"RESERVE_{i}", str, field(default=f"")) for i in range(708, 1596)], *[ (f"UT_{x:02}_{y:02}", str, field(default=f"({x},{y})")) for x, y in corner_first_ndindex(50) ], ] "fields for the `MazeTokenizerModular` style combined vocab" _VOCAB_BASE: type = make_dataclass( "_VOCAB_BASE", fields=_VOCAB_FIELDS, bases=(_SPECIAL_TOKENS_BASE,), frozen=True, ) "combined vocab class, private" # TODO: edit __getitem__ to add warning for accessing a RESERVE token # HACK: mypy doesn't recognize the fields in this dataclass VOCAB: _VOCAB_BASE = _VOCAB_BASE() # type: ignore "public access to universal vocabulary for `MazeTokenizerModular`" VOCAB_LIST: list[str] = list(VOCAB.values()) "list of `VOCAB` tokens, in order" VOCAB_TOKEN_TO_INDEX: dict[str, int] = {token: i for i, token in enumerate(VOCAB_LIST)} "map of `VOCAB` tokens to their indices" # CARDINAL_MAP: Maps tuple(coord1 - coord0) : cardinal direction CARDINAL_MAP: dict[tuple[int, int], str] = { (-1, 0): VOCAB.PATH_NORTH, (1, 0): VOCAB.PATH_SOUTH, (0, -1): VOCAB.PATH_WEST, (0, 1): VOCAB.PATH_EAST, } "map of cardinal directions to appropriate tokens" ``````{ end_of_file="maze_dataset/constants.py" } ``````{ path="maze_dataset/py.typed" } ``````{ end_of_file="maze_dataset/py.typed" } ``````{ path="maze_dataset/testing_utils.py" } """Shared utilities for tests only. Do not import into any module outside of the tests directory """ import itertools from typing import Final, NamedTuple, Sequence import frozendict import numpy as np from maze_dataset import ( CoordArray, LatticeMaze, LatticeMazeGenerators, MazeDataset, MazeDatasetConfig, SolvedMaze, TargetedLatticeMaze, ) from maze_dataset.tokenization import ( MazeTokenizer, MazeTokenizerModular, TokenizationMode, ) GRID_N: Final[int] = 5 N_MAZES: Final[int] = 5 CFG: Final[MazeDatasetConfig] = MazeDatasetConfig( name="test", grid_n=GRID_N, n_mazes=N_MAZES, maze_ctor=LatticeMazeGenerators.gen_dfs, ) MAZE_DATASET: Final[MazeDataset] = MazeDataset.from_config( CFG, do_download=False, load_local=False, do_generate=True, save_local=False, verbose=True, gen_parallel=False, ) LATTICE_MAZES: Final[tuple[LatticeMaze, ...]] = tuple( LatticeMazeGenerators.gen_dfs(np.array([GRID_N, GRID_N])) for _ in range(N_MAZES) ) _PATHS = tuple(maze.generate_random_path() for maze in LATTICE_MAZES) TARGETED_MAZES: Final[tuple[TargetedLatticeMaze, ...]] = tuple( TargetedLatticeMaze.from_lattice_maze(maze, path[0], path[-1]) for maze, path in zip(LATTICE_MAZES, _PATHS, strict=False) ) # MIXED_MAZES alternates the maze types, so you can slice a contiguous subset and still get all types MIXED_MAZES: Final[tuple[LatticeMaze | TargetedLatticeMaze | SolvedMaze, ...]] = tuple( x for x in itertools.chain.from_iterable( itertools.zip_longest(MAZE_DATASET.mazes, TARGETED_MAZES, LATTICE_MAZES), ) ) class MANUAL_MAZE(NamedTuple): # noqa: N801 """A named tuple for manual maze definitions""" tokens: str ascii: Sequence[str] straightaway_footprints: CoordArray ASCII_MAZES: Final[frozendict.frozendict[str, MANUAL_MAZE]] = frozendict.frozendict( small_3x3=MANUAL_MAZE( tokens=" (2,0) <--> (2,1) ; (0,0) <--> (0,1) ; (0,0) <--> (1,0) ; (0,2) <--> (1,2) ; (1,0) <--> (2,0) ; (0,2) 
<--> (0,1) ; (2,2) <--> (2,1) ; (1,1) <--> (2,1) ; (0,0) (2,1) (0,0) (1,0) (2,0) (2,1) ", ascii=( "#######", "#S #", "#X### #", "#X# # #", "#X# ###", "#XXE #", "#######", ), straightaway_footprints=np.array( [ [0, 0], [2, 0], [2, 1], ], ), ), big_10x10=MANUAL_MAZE( tokens=" (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) <--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; (6,2) (2,1) (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) ", ascii=( "#####################", "# # # #", "# # # # ### # # #####", "# # # # # # #", "# ####### ##### # # #", "# #E # # # #", "###X# ########### # #", "#XXX# # # #", "#X##### ########### #", "#X# # # #", "#X# ######### ### # #", "#X# # # # #", "#X######### # # ### #", "#XXXXS# # # #", "# ########### #######", "# # # # #", "# # ####### ### # ###", "# # # # # #", "# # # ####### ##### #", "# # #", "#####################", ), straightaway_footprints=np.array( [ [6, 2], [6, 0], [3, 0], [3, 1], [2, 1], ], ), ), longer_10x10=MANUAL_MAZE( tokens=" (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) 
<--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; (6,2) (2,1) (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) (2,2) (2,3) (2,4) (1,4) (1,3) (0,3) (0,4) (0,5) (1,5) (1,6) (0,6) (0,7) (0,8) ", ascii=( "#####################", "# # XXXXX#XXXXE #", "# # # #X###X#X# #####", "# # #XXX#XXX# # #", "# #######X##### # # #", "# #XXXXXXX# # # #", "###X# ########### # #", "#XXX# # # #", "#X##### ########### #", "#X# # # #", "#X# ######### ### # #", "#X# # # # #", "#X######### # # ### #", "#XXXXS# # # #", "# ########### #######", "# # # # #", "# # ####### ### # ###", "# # # # # #", "# # # ####### ##### #", "# # #", "#####################", ), straightaway_footprints=np.array( [ [6, 2], [6, 0], [3, 0], [3, 1], [2, 1], [2, 4], [1, 4], [1, 3], [0, 3], [0, 5], [1, 5], [1, 6], [0, 6], [0, 8], ], ), ), ) # A list of legacy `MazeTokenizer`s and their `MazeTokenizerModular` equivalents. 
# Used for unit tests where both versions are supported LEGACY_AND_EQUIVALENT_TOKENIZERS: list[MazeTokenizer | MazeTokenizerModular] = [ *[ MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=20) for tok_mode in TokenizationMode ], *[MazeTokenizerModular.from_legacy(tok_mode) for tok_mode in TokenizationMode], ] ``````{ end_of_file="maze_dataset/testing_utils.py" } ``````{ path="maze_dataset/token_utils.py" } """a whole bunch of utilities for tokenization""" import re import typing import warnings from collections import Counter from typing import Callable, Literal, overload import numpy as np from jaxtyping import Bool, Float, Int, Int8 from muutils.errormode import ErrorMode from muutils.misc import list_join from muutils.misc.sequence import WhenMissing from maze_dataset.constants import ( CARDINAL_MAP, SPECIAL_TOKENS, VOCAB, ConnectionArray, ConnectionList, CoordTup, ) # filtering things from a prompt or generated text # ================================================== def remove_padding_from_token_str(token_str: str) -> str: """remove padding tokens from a joined token string""" token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING} ", "") token_str = token_str.replace(f"{SPECIAL_TOKENS.PADDING}", "") return token_str # noqa: RET504 def tokens_between( tokens: list[str], start_value: str, end_value: str, include_start: bool = False, include_end: bool = False, except_when_tokens_not_unique: bool = False, ) -> list[str]: """given a list `tokens`, get the tokens between `start_value` and `end_value` _extended_summary_ # Parameters: - `tokens : list[str]` - `start_value : str` - `end_value : str` - `include_start : bool` (defaults to `False`) - `include_end : bool` (defaults to `False`) - `except_when_tokens_not_unique : bool` when `True`, raise an error if `start_value` or `end_value` are not unique in the input tokens (defaults to `False`) # Returns: - `list[str]` # Raises: - `ValueError` : if `start_value` and `end_value` are the same - `ValueError` : if `except_when_tokens_not_unique` is `True` and `start_value` or `end_value` are not unique in the input tokens - `ValueError` : if `start_value` or `end_value` are not present in the input tokens """ if start_value == end_value: err_msg: str = f"start_value and end_value cannot be the same: {start_value = } {end_value = }" raise ValueError( err_msg, ) if except_when_tokens_not_unique: if (tokens.count(start_value) != 1) or (tokens.count(end_value) != 1): err_msg: str = ( "start_value or end_value is not unique in the input tokens:" f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }" f"\n{start_value = } {end_value = }" f"\n{tokens = }" ) raise ValueError(err_msg) else: if (tokens.count(start_value) < 1) or (tokens.count(end_value) < 1): err_msg: str = ( "start_value or end_value is not present in the input tokens:" f"\n{tokens.count(start_value) = } {tokens.count(end_value) = }" f"\n{start_value = } {end_value = }" f"\n{tokens = }" ) raise ValueError(err_msg) start_idx: int = tokens.index(start_value) + int(not include_start) end_idx: int = tokens.index(end_value) + int(include_end) assert start_idx < end_idx, "Start must come before end" return tokens[start_idx:end_idx] def get_adj_list_tokens(tokens: list[str]) -> list[str]: "get tokens between ADJLIST_START and ADJLIST_END, without the special tokens themselves" return tokens_between( tokens, SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.ADJLIST_END, ) def get_path_tokens(tokens: list[str], trim_end: bool = False) -> list[str]: """The path is considered everything 
from the first path coord to the path_end token, if it exists.""" if SPECIAL_TOKENS.PATH_START not in tokens: err_msg: str = f"Path start token {SPECIAL_TOKENS.PATH_START} not found in tokens:\n{tokens}" raise ValueError( err_msg, ) start_idx: int = tokens.index(SPECIAL_TOKENS.PATH_START) + int(trim_end) end_idx: int | None = None if trim_end and (SPECIAL_TOKENS.PATH_END in tokens): end_idx = tokens.index(SPECIAL_TOKENS.PATH_END) return tokens[start_idx:end_idx] def get_context_tokens(tokens: list[str]) -> list[str]: "get tokens between ADJLIST_START and PATH_START" return tokens_between( tokens, SPECIAL_TOKENS.ADJLIST_START, SPECIAL_TOKENS.PATH_START, include_start=True, include_end=True, ) def get_origin_tokens(tokens: list[str]) -> list[str]: "get tokens_between ORIGIN_START and ORIGIN_END" return tokens_between( tokens, SPECIAL_TOKENS.ORIGIN_START, SPECIAL_TOKENS.ORIGIN_END, include_start=False, include_end=False, ) def get_target_tokens(tokens: list[str]) -> list[str]: "get tokens_between TARGET_START and TARGET_END" return tokens_between( tokens, SPECIAL_TOKENS.TARGET_START, SPECIAL_TOKENS.TARGET_END, include_start=False, include_end=False, ) def get_cardinal_direction(coords: Int[np.ndarray, "start_end=2 row_col=2"]) -> str: """Returns the cardinal direction token corresponding to traveling from `coords[0]` to `coords[1]`.""" return CARDINAL_MAP[tuple(coords[1] - coords[0])] def get_relative_direction(coords: Int[np.ndarray, "prev_cur_next=3 row_col=2"]) -> str: """Returns the relative first-person direction token corresponding to traveling from `coords[1]` to `coords[2]`. # Parameters - `coords`: Contains 3 Coords, each of which must neighbor the previous Coord. - `coords[0]`: The previous location, used to determine the current absolute direction that the "agent" is facing. - `coords[1]`: The current location - `coords[2]`: The next location. May be equal to the current location. """ if coords.shape != (3, 2): err_msg: str = f"`coords` must have shape (3,2). Got {coords.shape} instead." raise ValueError(err_msg) directions = coords[1:] - coords[:-1] if not np.all(np.linalg.norm(directions, axis=1) <= np.array([1.1, 1.1])): # Use floats as constant since `np.linalg.norm` returns float array err_msg: str = f"Adjacent `coords` must be neighboring or equivalent. Got {coords} instead." raise ValueError( err_msg, ) if np.array_equal(coords[1], coords[2]): return VOCAB.PATH_STAY if np.array_equal(coords[0], coords[2]): return VOCAB.PATH_BACKWARD if np.array_equal(coords[0], coords[1]): err_msg: str = f"Previous first-person direction indeterminate from {coords=}." raise ValueError( err_msg, ) if np.array_equal(directions[0], directions[1]): return VOCAB.PATH_FORWARD directions = np.append( directions, [[0], [0]], axis=1, ) # Augment to represent unit basis vectors in 3D match np.cross(directions[0], directions[1])[-1]: case 1: return VOCAB.PATH_LEFT case -1: return VOCAB.PATH_RIGHT class TokenizerPendingDeprecationWarning(PendingDeprecationWarning): """Pending deprecation warnings related to the `MazeTokenizerModular` upgrade.""" pass def str_is_coord(coord_str: str, allow_whitespace: bool = True) -> bool: """return True if the string represents a coordinate, False otherwise""" warnings.warn( "`util.str_is_coord` only supports legacy UT strings. 
Function will be replaced with a generalized version in a future release.", TokenizerPendingDeprecationWarning, ) strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x # noqa: E731 coord_str = strip_func(coord_str) return all( [ coord_str.startswith("("), coord_str.endswith(")"), "," in coord_str, all( strip_func(x).isdigit() for x in strip_func(coord_str.lstrip("(").rstrip(")")).split(",") ), ], ) class TokenizerDeprecationWarning(DeprecationWarning): """Deprecation warnings related to the `MazeTokenizerModular` upgrade.""" pass # coordinate to strings # ================================================== def _coord_to_strings_UT(coord: typing.Sequence[int]) -> list[str]: """convert a coordinate to a string: `(i,j)`->"(i,j)" always returns a list of length 1 """ return [f"({','.join(str(c) for c in coord)})"] def _coord_to_strings_indexed(coord: typing.Sequence[int]) -> list[str]: """convert a coordinate to a list of indexed strings: `(i,j)`->"(", "i", ",", "j", ")" always returns a list of length 5 """ return [ "(", *list_join([str(c) for c in coord], lambda: ","), ")", ] def coord_str_to_tuple( coord_str: str, allow_whitespace: bool = True, ) -> tuple[int, ...]: """convert a coordinate string to a tuple""" strip_func: Callable[[str], str] = lambda x: x.strip() if allow_whitespace else x # noqa: E731 coord_str = strip_func(coord_str) stripped: str = strip_func(coord_str.lstrip("(").rstrip(")")) return tuple(int(strip_func(x)) for x in stripped.split(",")) def coord_str_to_coord_np(coord_str: str, allow_whitespace: bool = True) -> np.ndarray: """convert a coordinate string to a numpy array""" return np.array(coord_str_to_tuple(coord_str, allow_whitespace=allow_whitespace)) def coord_str_to_tuple_noneable(coord_str: str) -> CoordTup | None: """convert a coordinate string to a tuple, or None if the string is not a coordinate string""" if not str_is_coord(coord_str): return None return coord_str_to_tuple(coord_str) def coords_string_split_UT(coords: str) -> list[str]: """Splits a string of tokens into a list containing the UT tokens for each coordinate. Not capable of producing indexed tokens ("(", "1", ",", "2", ")"), only unique tokens ("(1,2)"). Non-whitespace portions of the input string not matched are preserved in the same list: "(1,2) (5,6)" -> ["(1,2)", "", "(5,6)"] """ # ty gpt4 return re.findall(r"\([^)]*\)|\S+", coords) # back and forth in wrapped form # ================================================== @overload def strings_to_coords( text: str | list[str], when_noncoord: Literal["skip"] = "skip", ) -> list[CoordTup]: ... @overload def strings_to_coords( text: str | list[str], when_noncoord: Literal["error"] = "error", ) -> list[CoordTup]: ... @overload def strings_to_coords( text: str | list[str], when_noncoord: Literal["include"] = "include", ) -> list[str | CoordTup]: ... def strings_to_coords( text: str | list[str], when_noncoord: WhenMissing = "skip", ) -> list[str | CoordTup]: """converts a list of tokens to a list of coordinates returns list[CoordTup] if `when_noncoord` is "skip" or "error" returns list[str | CoordTup] if `when_noncoord` is "include" """ warnings.warn( "`util.strings_to_coords` only supports legacy UT strings. 
Function will be replaced with a generalized version in a future release.", TokenizerPendingDeprecationWarning, ) tokens_joined: str = text if isinstance(text, str) else " ".join(text) tokens_processed: list[str] = coords_string_split_UT(tokens_joined) result: list[str] = list() for token in tokens_processed: coord: CoordTup | None = coord_str_to_tuple_noneable(token) if coord is None: if when_noncoord == "skip": continue if when_noncoord == "error": err_msg: str = ( f"Invalid non-coordinate token '{token}' in text: '{text}'" ) raise ValueError( err_msg, ) if when_noncoord == "include": result.append(token) else: err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'" raise ValueError(err_msg) else: result.append(coord) return result @overload def coords_to_strings( coords: list[str | CoordTup], coord_to_strings_func: Callable[[CoordTup], list[str]], when_noncoord: Literal["include", "skip"] = "skip", ) -> list[str]: ... @overload def coords_to_strings( coords: list[CoordTup], coord_to_strings_func: Callable[[CoordTup], list[str]], when_noncoord: Literal["error"] = "error", ) -> list[str]: ... def coords_to_strings( coords: list[str | CoordTup], coord_to_strings_func: Callable[[CoordTup], list[str]], when_noncoord: WhenMissing = "skip", ) -> list[str]: """converts a list of coordinates to a list of strings (tokens) expects list[CoordTup] if `when_noncoord` is "error" expects list[str | CoordTup] if `when_noncoord` is "include" or "skip" """ result: list[str] = list() for coord in coords: if isinstance(coord, str): if when_noncoord == "skip": continue if when_noncoord == "error": err_msg: str = ( f"Invalid non-coordinate '{coord}' in list of coords: '{coords}'" ) raise ValueError( err_msg, ) if when_noncoord == "include": result.append(coord) else: err_msg: str = f"Invalid when_noncoord value '{when_noncoord}'" raise ValueError(err_msg) else: result.extend(coord_to_strings_func(coord)) return result def get_token_regions(toks: list[str]) -> tuple[list[str], list[str]]: """Splits a list of tokens into adjacency list tokens and non-adjacency list tokens.""" adj_list_start, adj_list_end = ( toks.index("") + 1, toks.index(""), ) adj_list = toks[adj_list_start:adj_list_end] non_adj_list = toks[:adj_list_start] + toks[adj_list_end:] return adj_list, non_adj_list def equal_except_adj_list_sequence( # noqa: C901 rollout1: list[str], rollout2: list[str], do_except: bool = False, when_counter_mismatch: ErrorMode = ErrorMode.EXCEPT, when_len_mismatch: ErrorMode = ErrorMode.EXCEPT, ) -> bool: """Returns if the rollout strings are equal, allowing for differently sequenced adjacency lists. and tokens must be in the rollouts. Intended ONLY for determining if two tokenization schemes are the same for rollouts generated from the same maze. This function should NOT be used to determine if two rollouts encode the same `LatticeMaze` object. # Warning: CTT False Positives This function is not robustly correct for some corner cases using `CoordTokenizers.CTT`. If rollouts are passed for identical tokenizers processing two slightly different mazes, a false positive is possible. More specifically, some cases of zero-sum adding and removing of connections in a maze within square regions along the diagonal will produce a false positive. 
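# Example:

For two rollouts produced from the same maze by the same tokenizer, reordering the connections inside the adjacency list does not change the result. (`rollout_a` and `rollout_b` below are hypothetical token lists, not values defined in this module.)

```
equal_except_adj_list_sequence(rollout_a, rollout_b)  # True when only the adjacency-list ordering differs
```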
""" if len(rollout1) != len(rollout2): if do_except: when_len_mismatch.process( f"Rollouts are not the same length: {len(rollout1)} != {len(rollout2)}", ) return False if ("" in rollout1) ^ ("" in rollout2): if do_except: err_msg: str = f"Rollouts do not have the same token: `{'' in rollout1 = }` != `{'' in rollout2 = }`" raise ValueError( err_msg, ) return False if ("" in rollout1) ^ ("" in rollout2): if do_except: err_msg: str = f"Rollouts do not have the same token: `{'' in rollout1 = }` != `{'' in rollout2 = }`" raise ValueError( err_msg, ) return False adj_list1, non_adj_list1 = get_token_regions(rollout1) adj_list2, non_adj_list2 = get_token_regions(rollout2) if non_adj_list1 != non_adj_list2: if do_except: when_len_mismatch.process( f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}", ) err_msg: str = f"Non-adjacency list tokens are not the same:\n{non_adj_list1}\n!=\n{non_adj_list2}" raise ValueError( err_msg, ) return False counter1: Counter = Counter(adj_list1) counter2: Counter = Counter(adj_list2) counters_eq: bool = counter1 == counter2 if not counters_eq: if do_except: when_counter_mismatch.process( f"Adjacency list counters are not the same:\n{counter1}\n!=\n{counter2}\n{counter1 - counter2 = }", ) return False return True def connection_list_to_adj_list( conn_list: ConnectionList, shuffle_d0: bool = True, shuffle_d1: bool = True, ) -> Int8[np.ndarray, "conn start_end=2 coord=2"]: """converts a `ConnectionList` (special lattice format) to a shuffled adjacency list # Parameters: - `conn_list: ConnectionList` special internal format for graphs which are subgraphs of a lattice - `shuffle_d0: bool` shuffle the adjacency list along the 0th axis (order of pairs) - `shuffle_d1: bool` shuffle the adjacency list along the 1st axis (order of coordinates in each pair). If `False`, all pairs have the smaller coord first. 
# Returns: - `Int8[np.ndarray, "conn start_end=2 coord=2"]` adjacency list in the shape `(n_connections, 2, 2)` """ n_connections: int = conn_list.sum() adj_list: Int8[np.ndarray, "conn start_end=2 coord=2"] = np.full( (n_connections, 2, 2), -1, dtype=np.int8, ) if shuffle_d1: flip_d1: Float[np.ndarray, " conn"] = np.random.rand(n_connections) # loop over all nonzero elements of the connection list i: int = 0 for d, x, y in np.ndindex(conn_list.shape): if conn_list[d, x, y]: c_start: CoordTup = (x, y) c_end: CoordTup = ( x + (1 if d == 0 else 0), y + (1 if d == 1 else 0), ) adj_list[i, 0] = np.array(c_start, dtype=np.int8) adj_list[i, 1] = np.array(c_end, dtype=np.int8) # flip if shuffling # magic value is fine here if shuffle_d1 and (flip_d1[i] > 0.5): # noqa: PLR2004 c_s, c_e = adj_list[i, 0].copy(), adj_list[i, 1].copy() adj_list[i, 0] = c_e adj_list[i, 1] = c_s i += 1 if shuffle_d0: np.random.shuffle(adj_list) return adj_list def is_connection( edges: ConnectionArray, connection_list: ConnectionList, ) -> Bool[np.ndarray, "is_connection=edges"]: """Returns if each edge in `edges` is a connection (`True`) or wall (`False`) in `connection_list`.""" sorted_edges = np.sort(edges, axis=1) edge_direction = ( (sorted_edges[:, 1, :] - sorted_edges[:, 0, :])[:, 0] == 0 ).astype(np.int8) return connection_list[edge_direction, sorted_edges[:, 0, 0], sorted_edges[:, 0, 1]] # string to coordinate representation # ================================================== ``````{ end_of_file="maze_dataset/token_utils.py" } ``````{ path="maze_dataset/utils.py" } "misc utilities for the `maze_dataset` package" import math from typing import ( overload, ) import numpy as np from jaxtyping import Bool, Int, Int8 def bool_array_from_string( string: str, shape: list[int], true_symbol: str = "T", ) -> Bool[np.ndarray, "*shape"]: """Transform a string into an ndarray of bools. Parameters ---------- string: str The string representation of the array shape: list[int] The shape of the resulting array true_symbol: The character to parse as True. Whitespace will be removed. All other characters will be parsed as False. Returns ------- np.ndarray A ndarray with dtype bool of shape `shape` Examples -------- >>> bool_array_from_string( ... "TT TF", shape=[2,2] ... ) array([[ True, True], [ True, False]]) """ stripped = "".join(string.split()) expected_symbol_count = math.prod(shape) symbol_count = len(stripped) if len(stripped) != expected_symbol_count: err_msg: str = f"Connection List contains the wrong number of symbols. Expected {expected_symbol_count}. Found {symbol_count} in {stripped}." 
raise ValueError(err_msg) bools = [(symbol == true_symbol) for symbol in stripped] return np.array(bools).reshape(*shape) def corner_first_ndindex(n: int, ndim: int = 2) -> list[tuple]: """returns an array of indices, sorted by distance from the corner this gives the property that `np.ndindex((n,n))` is equal to the first n^2 elements of `np.ndindex((n+1, n+1))` ``` >>> corner_first_ndindex(1) [(0, 0)] >>> corner_first_ndindex(2) [(0, 0), (0, 1), (1, 0), (1, 1)] >>> corner_first_ndindex(3) [(0, 0), (0, 1), (1, 0), (1, 1), (0, 2), (2, 0), (1, 2), (2, 1), (2, 2)] ``` """ unsorted: list = list(np.ndindex(tuple([n for _ in range(ndim)]))) return sorted(unsorted, key=lambda x: (max(x), x if x[0] % 2 == 0 else x[::-1])) # alternate numpy version from GPT-4: """ # Create all index combinations indices = np.indices([n]*ndim).reshape(ndim, -1).T # Find the max value for each index max_indices = np.max(indices, axis=1) # Identify the odd max values odd_mask = max_indices % 2 != 0 # Make a copy of indices to avoid changing the original one indices_copy = indices.copy() # Reverse the order of the coordinates for indices with odd max value indices_copy[odd_mask] = indices_copy[odd_mask, ::-1] # Sort by max index value, then by coordinates sorted_order = np.lexsort((*indices_copy.T, max_indices)) return indices[sorted_order] """ @overload def manhattan_distance( edges: Int[np.ndarray, "edges coord=2 row_col=2"], ) -> Int8[np.ndarray, " edges"]: ... # TYPING: error: Overloaded function signature 2 will never be matched: signature 1's parameter type(s) are the same or broader [overload-cannot-match] # this is because mypy doesn't play nice with jaxtyping @overload def manhattan_distance( # type: ignore[overload-cannot-match] edges: Int[np.ndarray, "coord=2 row_col=2"], ) -> int: ... def manhattan_distance( edges: ( Int[np.ndarray, "edges coord=2 row_col=2"] | Int[np.ndarray, "coord=2 row_col=2"] ), ) -> Int8[np.ndarray, " edges"] | int: """Returns the Manhattan distance between two coords.""" # magic values for dims fine here if len(edges.shape) == 3: # noqa: PLR2004 return np.linalg.norm(edges[:, 0, :] - edges[:, 1, :], axis=1, ord=1).astype( np.int8, ) elif len(edges.shape) == 2: # noqa: PLR2004 return int(np.linalg.norm(edges[0, :] - edges[1, :], ord=1).astype(np.int8)) else: err_msg: str = f"{edges} has shape {edges.shape}, but must be match the shape in the type hints." raise ValueError(err_msg) def lattice_max_degrees(n: int) -> Int8[np.ndarray, "row col"]: """Returns an array with the maximum possible degree for each coord.""" out = np.full((n, n), 2) out[1:-1, :] += 1 out[:, 1:-1] += 1 return out def lattice_connection_array( n: int, ) -> Int8[np.ndarray, "edges=2*n*(n-1) leading_trailing_coord=2 row_col=2"]: """Returns a 3D NumPy array containing all the edges in a 2D square lattice of size n x n. Thanks Claude. # Parameters - `n`: The size of the square lattice. # Returns np.ndarray: A 3D NumPy array of shape containing the coordinates of the edges in the 2D square lattice. In each pair, the coord with the smaller sum always comes first. 
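# Example:

A minimal check of the shape and ordering described above, traced against the implementation below:

```
>>> lattice_connection_array(2).shape
(4, 2, 2)
>>> lattice_connection_array(2)[0].tolist()
[[0, 0], [0, 1]]
```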
""" row_coords, col_coords = np.meshgrid( np.arange(n, dtype=np.int8), np.arange(n, dtype=np.int8), indexing="ij", ) # Horizontal edges horiz_edges = np.column_stack( ( row_coords[:, :-1].ravel(), col_coords[:, :-1].ravel(), row_coords[:, 1:].ravel(), col_coords[:, 1:].ravel(), ), ) # Vertical edges vert_edges = np.column_stack( ( row_coords[:-1, :].ravel(), col_coords[:-1, :].ravel(), row_coords[1:, :].ravel(), col_coords[1:, :].ravel(), ), ) return np.concatenate( (horiz_edges.reshape(n**2 - n, 2, 2), vert_edges.reshape(n**2 - n, 2, 2)), axis=0, ) def adj_list_to_nested_set(adj_list: list) -> set: """Used for comparison of adj_lists Adj_list looks like [[[0, 1], [1, 1]], [[0, 0], [0, 1]], ...] We don't care about order of coordinate pairs within the adj_list or coordinates within each coordinate pair. """ return { frozenset([tuple(start_coord), tuple(end_coord)]) for start_coord, end_coord in adj_list } ``````{ end_of_file="maze_dataset/utils.py" } ``````{ path="notebooks/demo_dataset.ipynb" processed_with="ipynb_to_md" } # Basics to start, let's import a few things we'll need: ```python # other package imports import json import matplotlib.pyplot as plt # keep this import for CI to work from zanj import ZANJ # saving/loading data # maze_dataset imports from maze_dataset import MazeDataset, MazeDatasetConfig, SolvedMaze from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS from maze_dataset.generation import GENERATORS_MAP, LatticeMazeGenerators from maze_dataset.generation.default_generators import DEFAULT_GENERATORS from maze_dataset.plotting import plot_dataset_mazes # check the configs print(MAZE_DATASET_CONFIGS.keys()) # for saving/loading things LOCAL_DATA_PATH: str = "../data/maze_dataset/" zanj: ZANJ = ZANJ(external_list_threshold=256) def pprint_summary(summary: dict) -> None: "pretty print as json" print(json.dumps(summary, indent=2)) ``` You should always see `test-g3-n5-a_dfs-h84385` in the list of available dataset configs above. Now, let's set up our initial config and dataset: ```python cfg: MazeDatasetConfig = MazeDatasetConfig( name="test", # name is only for you to keep track of things grid_n=5, # number of rows/columns in the lattice n_mazes=4, # number of mazes to generate maze_ctor=LatticeMazeGenerators.gen_dfs, # algorithm to generate the maze # there are a few more arguments here, to be discussed later ) # each config will use this function to get the name of the dataset # it contains some basic info about the algorithm, size, and number of mazes # at the end after "h" is a stable hash of the config to avoid collisions print(cfg.to_fname()) ``` ```python # to create a dataset, just call MazeDataset.from_config dataset: MazeDataset = MazeDataset.from_config( # your config cfg, # and all this below is completely optional do_download=False, load_local=False, do_generate=True, save_local=True, local_base_path=LOCAL_DATA_PATH, verbose=True, zanj=zanj, gen_parallel=False, # parallel generation has overhead, not worth it unless you're doing a lot of mazes ) ``` now that we have our dataset, let's take a look at it! 
```python plot_dataset_mazes( dataset, count=None, ) # for large datasets, set the count to some int to just plot the first few ``` # Filtering you can also filter datasets by a variety of parameters: ```python dataset_filtered: MazeDataset = dataset.filter_by.path_length(min_length=3) print(f"{len(dataset) = }") print(f"{len(dataset_filtered) = }") ``` ```python plot_dataset_mazes(dataset_filtered) ``` ```python pprint_summary(dataset_filtered.cfg.serialize()["applied_filters"]) print(f"{MazeDataset._FILTER_NAMESPACE = }") ``` ```python # filters can also be specified at generation time -- but it will still generate the whole dataset and then filter it dataset_filtered_from_scratch: MazeDataset = MazeDataset.from_config( dataset_filtered.cfg, do_download=False, load_local=False, do_generate=True, save_local=False, local_base_path=LOCAL_DATA_PATH, verbose=True, zanj=zanj, gen_parallel=False, ) ``` ```python plot_dataset_mazes(dataset_filtered_from_scratch) dataset_filtered_nodupe = dataset_filtered_from_scratch.filter_by.remove_duplicates() plot_dataset_mazes(dataset_filtered_nodupe) ``` ```python dataset_filtered_custom: MazeDataset = dataset.custom_maze_filter( lambda m, p: len(m.solution) == p, p=5, ) plot_dataset_mazes(dataset) plot_dataset_mazes(dataset_filtered_custom) ``` ## metadata by default, each maze stores some metadata about generation in a dictionary. if you don't care about this, you can filter it out (but keep some statistics) to save on storage space: ```python dataset_with_meta = dataset.filter_by.collect_generation_meta() metadata = dataset_with_meta.serialize()["generation_metadata_collected"] metadata["visited_cells"] = "..." # this is a huge list and unweildy to print pprint_summary(metadata) ``` # output formats ```python from maze_dataset.dataset.rasterized import process_maze_rasterized_input_target from maze_dataset.plotting import MazePlot from maze_dataset.plotting.print_tokens import ( color_maze_tokens_AOTP, display_color_maze_tokens_AOTP, ) from maze_dataset.tokenization import MazeTokenizer, TokenizationMode maze: SolvedMaze = dataset[0] # as pixels (what you've already seen) plt.imshow(maze.as_pixels()) # as ascii (useful for debugging) print("ASCII:\n") print(maze.as_ascii()) # as e2h style input/target input_, target = process_maze_rasterized_input_target(maze) fig, ax = plt.subplots(1, 2) ax[0].imshow(input_) ax[1].imshow(target) # remove ticks for a in ax: a.set_xticks([]) a.set_yticks([]) plt.show() # as a MazePlot MazePlot(maze).plot() # as tokens # first, initialize a tokenizer -- more about this in the `notebooks/demo_tokenization.ipynb` notebook tokenizer: MazeTokenizer = MazeTokenizer( tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100, ) maze_tok = maze.as_tokens(maze_tokenizer=tokenizer) # you can view the tokens directly print("\nRaw tokens:\n") print(" ".join(maze_tok)) # or color and print them in various formats print("\nColored tokens, raw html:\n") print(color_maze_tokens_AOTP(maze_tok, fmt="html")) print("\nColored tokens, raw latex:\n") print(color_maze_tokens_AOTP(maze_tok, fmt="latex")) print("\nColored tokens, terminal:\n") print(color_maze_tokens_AOTP(maze_tok, fmt="terminal")) display_color_maze_tokens_AOTP(maze_tok) ``` # endpoint options ```python for endpoint_kwargs in [ dict(), dict(allowed_start=[(0, 0)]), dict(allowed_end=[(2, 2)]), dict(allowed_start=[(0, 1), (0, 2), (0, 3)]), dict(allowed_start=[(0, 0)], allowed_end=[(4, 4)]), dict(deadend_start=True), dict(deadend_end=True), dict(deadend_start=True, 
deadend_end=True), dict(deadend_start=True, deadend_end=True, except_on_no_valid_endpoint=False), ]: d = MazeDataset.from_config( MazeDatasetConfig( name="endpoint-test", grid_n=5, n_mazes=4, maze_ctor=LatticeMazeGenerators.gen_dfs, endpoint_kwargs=endpoint_kwargs, ), ) plot_dataset_mazes(d, title=str(endpoint_kwargs)) ``` ```python # endpoint options with percolation d_endpt = MazeDataset.from_config( MazeDatasetConfig( name="endpoint-test", grid_n=5, n_mazes=4, maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, maze_ctor_kwargs=dict(p=0.8), endpoint_kwargs=dict( deadend_start=True, deadend_end=True, except_on_no_valid_endpoint=False, ), ), ) plot_dataset_mazes(d_endpt, title=str(endpoint_kwargs)) ``` # more algorithms there are a bunch of algorithms included, some with various parameters. Here's a few: ```python DATASETS: dict[int, list[MazeDataset]] = dict() for grid_n in [4, 8]: DATASETS[grid_n] = list() for gen_name, gen_kwargs in DEFAULT_GENERATORS: print(f"Generating {gen_name} for grid_n={grid_n}") DATASETS[grid_n].append( MazeDataset.from_config( MazeDatasetConfig( name="demo", maze_ctor=GENERATORS_MAP[gen_name], grid_n=grid_n, n_mazes=8, maze_ctor_kwargs=gen_kwargs, ), local_base_path=LOCAL_DATA_PATH, load_local=False, verbose=False, zanj=zanj, ), ) ``` ```python for ds_list in DATASETS.values(): for ds in ds_list: plot_dataset_mazes(ds, figsize_mult=(2, 4)) ``` ``````{ end_of_file="notebooks/demo_dataset.ipynb" } ``````{ path="notebooks/demo_generator.ipynb" processed_with="ipynb_to_md" } ```python # Imports import matplotlib.pyplot as plt from maze_dataset.generation import LatticeMazeGenerators from maze_dataset.plotting import MazePlot ``` # Generate maze using depth first search Constrain number of accessible cells Only applying this constrain tends to yield mazes with few forks. ```python sample_lattice_maze = LatticeMazeGenerators.gen_dfs( grid_shape=(6, 6), lattice_dim=2, accessible_cells=20, max_tree_depth=None, start_coord=None, ) MazePlot(sample_lattice_maze).plot() plt.show() ``` Constrain maximum tree depth ```python sample_lattice_maze = LatticeMazeGenerators.gen_dfs( grid_shape=(6, 6), lattice_dim=2, accessible_cells=None, max_tree_depth=5, start_coord=(0, 0), ) MazePlot(sample_lattice_maze).plot() plt.show() ``` Shift start coord of DFS algorithm ```python sample_lattice_maze = LatticeMazeGenerators.gen_dfs( grid_shape=(6, 6), lattice_dim=2, accessible_cells=None, max_tree_depth=None, start_coord=None, ) MazePlot(sample_lattice_maze).plot() plt.show() ``` All constraints enabled ```python sample_lattice_maze = LatticeMazeGenerators.gen_dfs( grid_shape=(10, 10), lattice_dim=2, accessible_cells=20, max_tree_depth=5, start_coord=(5, 5), ) MazePlot(sample_lattice_maze).plot() plt.show() ``` ``````{ end_of_file="notebooks/demo_generator.ipynb" } ``````{ path="notebooks/demo_latticemaze.ipynb" processed_with="ipynb_to_md" } # LatticeMaze Demo This notebook contains a tutorial for [LatticeMaze](../maze_dataset/generation/latticemaze.py), the central maze object in the `maze_dataset` library. ```python import matplotlib.pyplot as plt import numpy as np from maze_dataset.generation import LatticeMazeGenerators from maze_dataset.maze import SolvedMaze, TargetedLatticeMaze from maze_dataset.plotting import MazePlot ``` ## Maze representation The maze can be thought of as a grid of nodes, where an edge between nodes represents a path, and the lack of an edge represents a wall. The following generates a 4x4 maze using depth-first search. 
```python N: int = 10 maze = LatticeMazeGenerators.gen_dfs(np.array([N, N])) tgt_maze: TargetedLatticeMaze = TargetedLatticeMaze.from_lattice_maze( maze, (0, 0), (N - 1, N - 1), ) solved_maze: SolvedMaze = SolvedMaze.from_targeted_lattice_maze(tgt_maze) fig, ax = plt.subplots(1, 3, figsize=(15, 5)) for ax_i, temp_maze in zip(ax, [maze, tgt_maze, solved_maze], strict=False): ax_i.set_title(temp_maze.as_ascii(), fontfamily="monospace") ax_i.imshow(temp_maze.as_pixels()) x = temp_maze.__class__.from_pixels(temp_maze.as_pixels()) assert temp_maze == temp_maze.__class__.from_pixels(temp_maze.as_pixels()) assert temp_maze == temp_maze.__class__.from_ascii(temp_maze.as_ascii()) plt.show() ``` ### Connection List In the above cell, we can see the canonical representation of the maze, the *connection list*. To understand this representation, consider the following connection list for a 2x2 maze. ``` [ [ # down [F T], [F F] ], [ # right [T F], [T F] ] ] ``` The two matrices in the connection list represent the *downward* and *rightward* connections, respectively. It tells us whether a given node has a connection in that direction. ``` down: N N right: N - N | N N N - N ``` Note that the bottom row connections going down, and the right-hand column connections going right, will always be False. We can superimpose the downward and rightward connections to visualize the maze: ``` N - N | N - N ``` --- Using the same method, we can interpret the connection list for the original maze: ```python maze.connection_list ``` ``` N N - N - N | | N - N - N - N | N - N N - N | | N - N - N - N ``` ### Adjacency list Another common maze representation structure is an adjacency list, which is literally a list of every pair of adjacent nodes in the maze. We can view the adjacency list representation of the graph using `LatticeMaze.as_adjlist` ```python for start, end in maze.as_adj_list(): print(f"({start[0]}, {start[1]}) <--> ({end[0]}, {end[1]})") ``` ## Plotting a maze The `MazePlot` class bundles our plotting functionality. We can use `.show()` to display the maze: ```python MazePlot(maze).plot() plt.show() plt.imshow(maze.as_pixels()) plt.show() print(maze.as_ascii()) ``` Note that the adjacency list contains coordinates in `(row, column)` notation. This is the inverse of Cartesian Coordinates `(x, y)` with a horizontal x-axis. ## Solving the maze algorithmically `LatticeMaze.find_shortest_path` uses the A* algorithm to find the optimal path through the maze. ```python true_path = maze.find_shortest_path(c_start=(0, 0), c_end=(3, 3)) print(f"{true_path =}") ``` We can plot the shortest path with `.add_true_path()`. ```python MazePlot(maze).add_true_path(true_path).plot() plt.show() ``` ## Other Plotting functionality Displaying one or more predicted paths ```python pred_path1 = [(0, 0), (1, 0), (2, 0), (3, 0), (3, 1), (3, 2), (3, 3)] pred_path2 = [(0, 0), (0, 1), (0, 2), (0, 3), (1, 3), (2, 3), (2, 2), (3, 2), (3, 3)] ( MazePlot(maze) .add_true_path(true_path) .add_predicted_path(pred_path1) .add_predicted_path(pred_path2) .plot() ) plt.show() ``` ```python # node_values = np.random.uniform(size=maze.grid_shape) node_values = np.random.randn(*maze.grid_shape) MazePlot(maze).add_node_values(node_values, color_map="bwr").add_true_path( true_path, ).plot() plt.show() ``` ```python MazePlot(maze).add_node_values( node_values, color_map="Blues", target_token_coord=np.array([2, 0]), preceeding_tokens_coords=np.array([[0, 0], [3, 1]]), ).plot() plt.show() ``` Alternatively, plotting multiple paths at once is available. 
Paths must be of type `CoordArray` or `PathFormat`. ```python pred_paths = [pred_path1, pred_path2] MazePlot(maze).add_multiple_paths(pred_paths).plot() plt.show() ``` Plotting a maze as a string (e.g. for quick debugging via the command line): ```python ascii_maze = MazePlot(maze).to_ascii() print(ascii_maze) ``` ``````{ end_of_file="notebooks/demo_latticemaze.ipynb" } ``````{ path="notebooks/demo_mazetokenizermodular.ipynb" processed_with="ipynb_to_md" } # Imports ```python import random import matplotlib.pyplot as plt import pandas as pd import yaml from muutils.misc import shorten_numerical_to_str from tqdm import tqdm from maze_dataset import ( VOCAB, VOCAB_LIST, VOCAB_TOKEN_TO_INDEX, LatticeMazeGenerators, MazeDataset, MazeDatasetConfig, SolvedMaze, ) from maze_dataset.plotting import MazePlot from maze_dataset.tokenization import ( AdjListTokenizers, CoordTokenizers, EdgePermuters, EdgeSubsets, MazeTokenizer, MazeTokenizerModular, PathTokenizers, PromptSequencers, StepSizes, StepTokenizers, TargetTokenizers, TokenizationMode, _TokenizerElement, ) from maze_dataset.tokenization.modular.all_instances import all_instances from maze_dataset.tokenization.modular.all_tokenizers import ( MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, get_all_tokenizers, ) ``` ```python # magic autoreload %load_ext autoreload %autoreload 2 ``` # `MazeTokenizerModular` Initialization and Structure Initialization can be done via the default constructor or via `MazeTokenizerModular.from_legacy`. The latter is useful for converting a legacy `MazeTokenizer` into its equivalent `MazeTokenizerModular`. Most of the API for these tokenizers is contained in the `MazeTokenizerModular` class. The only time users need to interact with the internal components of a `MazeTokenizerModular` is when initializing a non-default tokenizer. ```python mt_default: MazeTokenizerModular = MazeTokenizerModular() mt_ctt: MazeTokenizerModular = MazeTokenizerModular.from_legacy( TokenizationMode.AOTP_CTT_indexed, ) ``` The objects composing `MazeTokenizerModular` are all instances of `_TokenizerElement`. ```python print("\n".join([str(elem) for elem in _TokenizerElement.__subclasses__()])) assert all( issubclass(elem, _TokenizerElement) for elem in _TokenizerElement.__subclasses__() ) ``` Within a tokenizer, these `_TokenizerElement`s are structured in a nested dataclass tree. The tree is slightly different depending on the particular options selected. Three different tree representations of `mt_default` are shown below. ```python print("\nAOTP `_TokenizerElement` Structure:\n") print(mt_default.tokenizer_element_tree(abstract=True)) print("Default tokenizer elements:\n") print(mt_default.tokenizer_element_tree()) print("\nDefault tokenizer `name`:\n") print(mt_default.name) print("`MazeTokenizerModular` structure with all fields:\n") print(yaml.dump(mt_default.tokenizer_element_dict())) ``` There are currently no other constructor methods. To construct a `MazeTokenizerModular` with other `TokenizerElement`s besides those available via `from_legacy`, the standard constructor must be used, with all parent `TokenizerElement`s in the tree specified. Some `TokenizerElement`s also have their own initialization arguments, most of which are `boolean`-typed. The most common arguments across all `TokenizerElement`s are named `pre`, `intra`, and `post`, which control whether delimiter tokens are added to that part of the output. Other args are more specialized; see the class docstrings for more details.
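To make the constructor pattern concrete, here is a minimal sketch of building a non-default tokenizer by nesting `_TokenizerElement`s, and of setting the `pre`/`intra`/`post` delimiter flags on a single element. The keyword arguments shown for `CoordTokenizers.CTT` are assumed from its field names; consult the class docstrings before relying on them, and note that an arbitrary combination like this is not necessarily a tested tokenizer.

```python
# a non-default tokenizer: an AOP prompt sequencer whose path steps are
# tokenized as cardinal directions (one possible combination among many)
mt_custom: MazeTokenizerModular = MazeTokenizerModular(
    prompt_sequencer=PromptSequencers.AOP(
        path_tokenizer=PathTokenizers.StepSequence(
            step_tokenizers=(StepTokenizers.Cardinal(),),
        ),
    ),
)
print(mt_custom.name)

# delimiter flags on an individual element (keyword names assumed to match the
# `pre` / `intra` / `post` fields described above)
ctt_full_delims = CoordTokenizers.CTT(pre=True, intra=True, post=True)
ctt_no_delims = CoordTokenizers.CTT(pre=False, intra=False, post=False)
print(ctt_full_delims.name)
print(ctt_no_delims.name)
```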
# Vocabulary All instances of `MazeTokenizerModular` use a static vocabulary `VOCAB`, which is one of the main functional differences from `MazeTokenizer`. The static vocabulary can be accessed directly through 3 constants: - `VOCAB` - Extension of the `SPECIAL_TOKENS` dataclass - Supports direct attribute access - `VOCAB_LIST: list[str]` - Contains the vocabulary in a list - Index of a token is its unique ID - `VOCAB_TOKEN_TO_INDEX: dict[str, int]` - Inverse mapping of `VOCAB_LIST`, maps tokens to unique IDs The following shows the first 5 elements of each constant. ```python print("`VOCAB`: IsDataclass") for i, t in enumerate(VOCAB): if i >= 5: break print(f"\tVOCAB.{t} =\t'{getattr(VOCAB, t)}'") print("\t...") print("\n`VOCAB_LIST`: list[str]") for t in VOCAB_LIST[:5]: print(f"\t'{t}'") print("\t...") print("\n`VOCAB_TOKEN_TO_INDEX`: dict[str, int]") for t in VOCAB_TOKEN_TO_INDEX: if VOCAB_TOKEN_TO_INDEX[t] >= 5: break print(f"\t'{t}': \t{VOCAB_TOKEN_TO_INDEX[t]}") print("\t...") ``` ### Considerations of Static Vocabulary - No more rasterized vs uniform indexing, it's all fixed as uniform now - Fixed max grid size - There is now a fixed maximum supported maze size - Unique tokens (`CoordTokenizers.UT`): 50x50 - Coordinate tuple tokens (`CoordTokenizers.CTT`): 128x128 - Mazes larger than these sizes are not supported - There should be fewer compatibility issues with tokenizers using different `max_grid_size` parameters - Vocabulary access - Since maze-dataset 1.0, there is no need to pass around a tokenizer object or any data structure to access its custom vocabulary ### Refactoring your code from legacy `MazeTokenizer` and `TokenizationMode` Since `MazeTokenizerModular` uses a static vocabulary, it is not backwards compatible with any models trained using a legacy `MazeTokenizer`. The `maze-transformer` library is updated in vX.X.X to use `MazeTokenizerModular` by default. If you've manually specified a `MazeTokenizer` or `TokenizationMode` in your research code, the easiest way to refactor is using `MazeTokenizerModular.from_legacy`, which will convert a `MazeTokenizer` or `TokenizationMode` to its corresponding `MazeTokenizerModular` instance. Note that this correspondence means only that the stringifications of mazes are equivalent; the encodings of strings to integer vocabulary indices are not. ```python legacy_maze_tokenizer: MazeTokenizer = ( TokenizationMode.AOTP_UT_uniform.to_legacy_tokenizer() ) modular_tokenizer_equivalent: MazeTokenizerModular = MazeTokenizerModular.from_legacy( legacy_maze_tokenizer, ) print(legacy_maze_tokenizer, "\n", modular_tokenizer_equivalent) ``` ## `get_all_tokenizers` Most combinations of `TokenizerElement`s and their arguments will produce a valid and unique `MazeTokenizerModular`. However, it is not guaranteed that every possible `MazeTokenizerModular` that can be constructed will make practical sense or have been put through testing. `get_all_tokenizers` constructs and caches all the tested tokenizers at once. For research investigating many different tokenization schemes, one practical way to access them is by looping through or sampling from `get_all_tokenizers()`. Be aware that the indexing of specific tokenizers may change without notice. ```python all_tokenizers = get_all_tokenizers() ``` ```python print( f"{len(all_tokenizers)} or {shorten_numerical_to_str(len(all_tokenizers))} tokenizers found.", ) ```
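As a concrete illustration of the sampling approach mentioned above, the sketch below draws a handful of tested tokenizers at random and prints their names; the seed is arbitrary and only included so the sketch is reproducible.

```python
# sample a few tested tokenizers at random; `get_all_tokenizers` caches its
# result, so reusing `all_tokenizers` from the cell above avoids recomputation
random.seed(42)  # arbitrary seed, purely for reproducibility of this sketch
sampled: list[MazeTokenizerModular] = random.sample(all_tokenizers, 5)
for tok in sampled:
    print(tok.name)
```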
Other possible tokenizers which aren't in `get_all_tokenizers` are not guaranteed to function. Instead of running the expensive call to `get_all_tokenizers` yourself, you can check whether a tokenizer is tested using `MazeTokenizerModular.is_tested_tokenizer` or `MazeTokenizerModular.is_valid`. ```python assert mt_default.is_tested_tokenizer(do_except=True) assert mt_default.is_valid() assert mt_ctt.is_tested_tokenizer() assert mt_ctt.is_valid() custom_untested_tokenizer = MazeTokenizerModular( prompt_sequencer=PromptSequencers.AOP( path_tokenizer=PathTokenizers.StepSequence( step_tokenizers=(StepTokenizers.Distance(),), ), ), ) assert not custom_untested_tokenizer.is_tested_tokenizer() assert not custom_untested_tokenizer.is_valid() # Danger, use this tokenizer at your own risk! ``` This uses the file below, shipped with the package, to keep track of which tokenizer names are valid. The code for generating it is in `maze_dataset.tokenization.modular.fst`. ```python from maze_dataset.tokenization.modular.fst_load import MMT_FST_PATH print(f"{MMT_FST_PATH = }") print(f"{MMT_FST_PATH.stat().st_size = }") ``` ```python # we can also use `check_tokenizer_in_fst` manually, and if it can't find a tokenizer it will give us similar ones from maze_dataset.tokenization.modular.fst_load import check_tokenizer_in_fst print(mt_default.name) mt_name_modified: str = mt_default.name.replace( "ConnectionEdges(walls=F),", "ConnectionEdges(walls=X)," ) print(mt_name_modified) try: check_tokenizer_in_fst(mt_name_modified, do_except=True) except Exception as e: # noqa: BLE001 print("[ERROR]: ", e) ``` # Filtering Tokenizer Collections There are several practical ways to filter down a collection of tokenizers, or alternatively, to generate a new collection with a filter. **WARNING: Applying `filter` to the output of `get_all_tokenizers` is extremely slow due to the size of the initial population. Only use the first 3 methods for filtering much smaller collections of tokenizers. To generate a new collection based on filters, always use `utils.all_instances`.** In order of increasing speed and power, and decreasing syntactic concision: 1. `MazeTokenizerModular.has_element` - Use case: Use with `filter` for concise, basic filtering on an existing collection 1. `MazeTokenizerModular.tokenizer_elements` - Use case: Use with `filter` for more precise filtering on an existing collection 1. `MazeTokenizerModular.summary` - Use case: Use with `filter` for more precise filtering on an existing collection 1. `utils.all_instances` - Use case: Generate a new collection with filter(s). - Anytime you don't already have a small collection of tokenizers as the starting population.
```python len_all = len(get_all_tokenizers()) ``` ```python filtered_1: list[MazeTokenizerModular] = list( all_instances( MazeTokenizerModular, { **MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement` CoordTokenizers._CoordTokenizer: lambda x: isinstance( x, CoordTokenizers.UT, ), StepTokenizers.StepTokenizerPermutation: lambda x: x[0] == StepTokenizers.Cardinal() and len(x) < 3, AdjListTokenizers._AdjListTokenizer: lambda x: isinstance( x, AdjListTokenizers.AdjListCardinal, ), EdgeSubsets._EdgeSubset: lambda x: x == EdgeSubsets.ConnectionEdges(walls=False), }, ), ) filtered_2: list[MazeTokenizerModular] = list( all_instances( MazeTokenizerModular, { **MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement` _TokenizerElement: lambda x: x.is_valid() and not getattr(x, "pre", False) and not getattr(x, "intra", False) and not getattr(x, "post", False), # Minimal delimiters everywhere... CoordTokenizers.CTT: lambda x: x.pre and x.intra and x.post, # ...except for the coord tokens }, ), ) filtered_3: list[MazeTokenizerModular] = list( all_instances( MazeTokenizerModular, { **MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, # Always include this as the first item in the dict whenever calling `all_instances` with `MazeTokenizerModular` or any `_TokenizerElement` PromptSequencers._PromptSequencer: lambda x: isinstance( x, PromptSequencers.AOTP, ), TargetTokenizers._TargetTokenizer: lambda x: x == TargetTokenizers.Unlabeled(), StepSizes.Singles: lambda x: False, # noqa: ARG005 }, ), ) print(f"filtered 1: {len(filtered_1)} tokenizers / {len_all} tokenizers") print(f"filtered 2: {len(filtered_2)} tokenizers / {len_all} tokenizers") print(f"filtered 3: {len(filtered_3)} tokenizers / {len_all} tokenizers") ``` The examples below show equivalent methods of filtering one of the smaller collections above using options 1-3. ```python filtered_has_element: list[MazeTokenizerModular] = list( filter(lambda x: x.has_element(EdgePermuters.BothCoords()), filtered_1), ) filtered_tokenizer_elements: list[MazeTokenizerModular] = list( filter(lambda x: EdgePermuters.BothCoords() in x.tokenizer_elements, filtered_1), ) filtered_summary: list[MazeTokenizerModular] = list( filter( lambda x: x.summary()["edge_permuter"] == EdgePermuters.BothCoords().name, filtered_1, ), ) print(f"filtered: {len(filtered_has_element)} tokenizers / {len_all} tokenizers") assert set(filtered_has_element) == set(filtered_tokenizer_elements) print(f"{set(filtered_has_element).symmetric_difference(set(filtered_summary)) = }") assert set(filtered_has_element) == set(filtered_summary) ``` # TokenizerElement Behavior Reference For each primary `TokenizerElement`, tokenizations and encodings derived from the maze below are logged in DataFrames for reference.
```python cfg: MazeDatasetConfig = MazeDatasetConfig( name="test", grid_n=3, n_mazes=1, maze_ctor=LatticeMazeGenerators.gen_dfs, ) dataset: MazeDataset = MazeDataset.from_config( cfg, do_download=False, load_local=False, do_generate=True, save_local=False, verbose=True, gen_parallel=False, ) ``` ```python pd.set_option("display.max_colwidth", None) mz: SolvedMaze = dataset[0] MazePlot(mz).plot() plt.show() ``` ```python def all_elements_df( elem_type: type[_TokenizerElement], encoding: bool = True, **to_tokens_kwargs, ) -> pd.DataFrame: columns = ["_TokenizerElement", "tokens"] if encoding: columns.append("encoding") tokenizers: pd.DataFrame = pd.DataFrame(columns=columns) tokenizers["_TokenizerElement"] = list( all_instances( elem_type, validation_funcs=MAZE_TOKENIZER_MODULAR_DEFAULT_VALIDATION_FUNCS, ), ) tokenizers["tokens"] = tokenizers["_TokenizerElement"].apply( lambda x: " ".join(x.to_tokens(**to_tokens_kwargs)), ) if encoding: tokenizers["encoding"] = tokenizers["tokens"].apply( lambda x: MazeTokenizerModular.encode(x), ) return tokenizers ``` ## `CoordTokenizers` ```python coord_tokenizers = all_elements_df( CoordTokenizers._CoordTokenizer, coord=mz.solution[0], ) coord_tokenizers ``` ## Adjacency List Tokenizers ```python adjlist_tokenizers = all_elements_df( AdjListTokenizers._AdjListTokenizer, encoding=False, maze=mz, coord_tokenizer=CoordTokenizers.UT(), ) adjlist_tokenizers ``` ## Target Tokenizers ```python target_tokenizers = all_elements_df( TargetTokenizers._TargetTokenizer, targets=[mz.end_pos], coord_tokenizer=CoordTokenizers.UT(), ) target_tokenizers ``` ## Path Tokenizers ```python path_tokenizers = all_elements_df( PathTokenizers._PathTokenizer, maze=mz, coord_tokenizer=CoordTokenizers.UT(), ) path_tokenizers ``` ## Prompt Sequencers Currently, the only difference in possible prompt sequencers is the inclusion/exclusion of target tokens. 
```python prompt_sequencers = [PromptSequencers.AOTP(), PromptSequencers.AOP()] columns = ["_TokenizerElement", "tokens"] tokenizers: pd.DataFrame = pd.DataFrame(columns=columns) tokenizers["_TokenizerElement"] = prompt_sequencers tokenizers["tokens"] = tokenizers["_TokenizerElement"].apply( lambda x: " ".join(x.to_tokens(maze=mz)), ) tokenizers ``` ## Random Sample of `MazeTokenizerModular`s ```python random_sample_size: int = 1_000 tokenizers: list[MazeTokenizerModular] = random.sample( get_all_tokenizers(), random_sample_size, ) columns = ["MazeTokenizerModular", "tokens", "encoding", *mt_default.summary().keys()] df: pd.DataFrame = pd.DataFrame(columns=columns) df["MazeTokenizerModular"] = tokenizers df["tokens"] = df["MazeTokenizerModular"].apply( lambda x: " ".join(x.to_tokens(maze=mz)), ) df.encoding = df.tokens.apply(MazeTokenizerModular.encode) ``` ```python for k in tqdm( mt_default.summary().keys(), desc="Tokenizers", total=len(mt_default.summary()), ): df[k] = df.apply( lambda x: x.MazeTokenizerModular.summary().get(k, None), # noqa: B023 axis=1, ) pd.set_option("display.max_colwidth", 50) df ``` ```python ``` ``````{ end_of_file="notebooks/demo_mazetokenizermodular.ipynb" } ``````{ path="notebooks/demo_tokenization.ipynb" processed_with="ipynb_to_md" } ```python import matplotlib.pyplot as plt import numpy as np from maze_dataset import ( LatticeMazeGenerators, MazeDataset, MazeDatasetConfig, SolvedMaze, ) from maze_dataset.plotting import plot_dataset_mazes from maze_dataset.plotting.print_tokens import ( display_color_maze_tokens_AOTP, display_color_tokens_cmap, display_color_tokens_rgb, ) from maze_dataset.tokenization import MazeTokenizer, TokenizationMode from maze_dataset.utils import corner_first_ndindex ``` Let's get a basic dataset first: ```python CFG: MazeDatasetConfig = MazeDatasetConfig( name="test", grid_n=5, n_mazes=5, maze_ctor=LatticeMazeGenerators.gen_dfs, ) ``` ```python DATASET: MazeDataset = MazeDataset.from_config( CFG, local_base_path="../data/maze_dataset/", ) ``` ```python plot_dataset_mazes(DATASET) ``` Now, let's see how tokenization works: ```python TOKENIZER: MazeTokenizer = MazeTokenizer( tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100, ) TOKENIZER_INDEXED: MazeTokenizer = MazeTokenizer( tokenization_mode=TokenizationMode.AOTP_CTT_indexed, max_grid_size=100, ) ``` ```python STRINGIFIED: list[str] = DATASET.as_tokens(TOKENIZER, join_tokens_individual_maze=True) STRINGIFIED_INDEXED: list[str] = DATASET.as_tokens( TOKENIZER_INDEXED, join_tokens_individual_maze=True, ) print("Rasterized:\n" + "\n".join(STRINGIFIED)) print("\nIndexed:\n" + "\n".join(STRINGIFIED_INDEXED)) ``` ```python x = STRINGIFIED[0].split() display_color_tokens_rgb(x, np.random.randint(0, 255, (len(x), 3))) display_color_tokens_cmap(x, np.random.randint(0, 255, len(x))) display_color_maze_tokens_AOTP(x) ``` Now do the same for `TokenizationMode.AOTP_CTT_indexed`. ```python x = STRINGIFIED_INDEXED[0].split() display_color_tokens_rgb(x, np.random.randint(0, 255, (len(x), 3))) display_color_tokens_cmap(x, np.random.randint(0, 255, len(x))) display_color_maze_tokens_AOTP(x) ``` Now let's see how we can convert the tokenized data back into a `SolvedMaze`. This is only possible with legacy tokenizers or their `MazeTokenizerModular` equivalents.
```python maze_toks: list[str] = ( """ (1,1) <--> (2,1) ; (2,0) <--> (1,0) ; (0,1) <--> (0,0) ; (2,2) <--> (2,1) ; (2,0) <--> (2,1) ; (0,2) <--> (1,2) ; (0,0) <--> (1,0) ; (0,2) <--> (0,1) ; (0,0) (2,1) (0,0) (1,0) (2,0) (2,1) """.split() ) maze_encoded: list[int] = TOKENIZER.encode(maze_toks) maze_tok_roundtrip: list[str] = TOKENIZER.decode(maze_encoded) assert maze_toks == maze_tok_roundtrip maze_from_toks: SolvedMaze = SolvedMaze.from_tokens(maze_toks, TOKENIZER) print(maze_from_toks.as_ascii()) print(" ".join(maze_from_toks.as_tokens(TOKENIZER))) ``` Now do the same for the `CTT` tokenizer. ```python maze_toks_indexed: list[str] = ( """ ( 1 , 1 ) <--> ( 2 , 1 ) ; ( 2 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 1 ) <--> ( 0 , 0 ) ; ( 2 , 2 ) <--> ( 2 , 1 ) ; ( 2 , 0 ) <--> ( 2 , 1 ) ; ( 0 , 2 ) <--> ( 1 , 2 ) ; ( 0 , 0 ) <--> ( 1 , 0 ) ; ( 0 , 2 ) <--> ( 0 , 1 ) ; ( 0 , 0 ) ( 2 , 1 ) ( 0 , 0 ) ( 1 , 0 ) ( 2 , 0 ) ( 2 , 1 ) """.split() ) maze_encoded: list[int] = TOKENIZER_INDEXED.encode(maze_toks_indexed) maze_tok_roundtrip: list[str] = TOKENIZER_INDEXED.decode(maze_encoded) assert maze_toks_indexed == maze_tok_roundtrip maze_from_toks_indexed: SolvedMaze = SolvedMaze.from_tokens( maze_toks_indexed, TOKENIZER_INDEXED, ) assert maze_from_toks_indexed == maze_from_toks print(maze_from_toks_indexed.as_ascii()) print(" ".join(maze_from_toks_indexed.as_tokens(TOKENIZER_INDEXED))) ``` # Vocab index Special tokens come first, but then there are a few choices for the rest of the tokens: - `TokenizationMode.AOTP_UT_rasterized`: unique token for each coord, order is simple rasterization - `TokenizationMode.AOTP_UT_uniform`: unique token for each coord, order assembled to preserve uniformity regardless of maze size - `TokenizationMode.AOTP_CTT_indexed`: each coordinate is 5 tokens: `( i , j )` where `i` and `j` are the coordinates ```python def plot_corner_first_ndindex(n: int, ndim: int = 2) -> None: """Plot a figure that shows the order of each grid point in the list provided by the function corner_first_ndindex. """ indices = corner_first_ndindex(n, ndim) # Create a 2D grid to store the order of each index grid = np.zeros((n, n), dtype=int) for order, (x, y) in enumerate(indices): grid[x, y] = order + 1 # Adding 1 to start the order from 1 instead of 0 fig, ax = plt.subplots(figsize=(2, 2)) # Plot the grid cax = ax.matshow(grid, cmap=plt.cm.Blues) # Annotate each cell with its order for i in range(n): for j in range(n): c = grid[j, i] ax.text(i, j, str(c), va="center", ha="center") plt.title("Order of Grid Points in Vocabulary") plt.xlabel("X-axis") plt.ylabel("Y-axis") plt.colorbar(cax) # plt.savefig("corner-first-vocab.pdf") plt.show() # Example plot for n=5 plot_corner_first_ndindex(5) ``` ``````{ end_of_file="notebooks/demo_tokenization.ipynb" } ``````{ path="notebooks/estimate_dataset_fractions.ipynb" processed_with="ipynb_to_md" } ```python from pathlib import Path import numpy as np # import pysr before torch to avoid # UserWarning: torch was imported before juliacall. This may cause a segfault. To avoid this, import juliacall before importing torch. For updates, see https://github.com/pytorch/pytorch/issues/78829.
import pysr # noqa: F401 from zanj import ZANJ from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig from maze_dataset.benchmark.config_sweep import ( SweepResult, dataset_success_fraction, full_percolation_analysis, plot_grouped, ) from maze_dataset.benchmark.sweep_fit import sweep_fit ``` # run a basic analysis ```python # Run the analysis results: SweepResult = SweepResult.analyze( configs=[ MazeDatasetConfig( name=f"g{grid_n}-perc", grid_n=grid_n, n_mazes=32, maze_ctor=LatticeMazeGenerators.gen_percolation, maze_ctor_kwargs=dict(), endpoint_kwargs=dict( deadend_start=False, deadend_end=False, endpoints_not_equal=False, except_on_no_valid_endpoint=False, ), ) for grid_n in [2, 4, 6] ], param_values=np.linspace(0.0, 1.0, 16).tolist(), param_key="maze_ctor_kwargs.p", analyze_func=dataset_success_fraction, parallel=False, ) # Plot results results.plot(save_path=None, cfg_keys=["n_mazes", "endpoint_kwargs"]) ``` ## check saving/loading ```python path = Path("../tests/_temp/dataset_frac_sweep/results_small.zanj") results.save(path) ZANJ().read(path).plot(cfg_keys=["n_mazes", "endpoint_kwargs"]) ``` # sweep across all endpoint kwargs and generator funcs ```python results_sweep: SweepResult = full_percolation_analysis( n_mazes=16, p_val_count=11, grid_sizes=[2, 4, 6], parallel=False, save_dir=Path("tests/_temp/dataset_frac_sweep"), ) ``` ```python results_medium: SweepResult = SweepResult.read( "../docs/benchmarks/percolation_fractions/medium/result-n128-c42-p50.zanj", # "../docs/benchmarks/percolation_fractions/large/result-n256-c54-p100.zanj" ) ``` ```python plot_grouped( results_medium, predict_fn=lambda x: x.success_fraction_estimate(), prediction_density=100, # for paper version # figsize=(11, 4), # save_fmt="pdf", # save_dir=Path("../docs/paper/figures/ep"), # minify_title=True, # legend_kwargs=dict(loc="center right", ncol=1, bbox_to_anchor=(1.14, 0.5)), # manual_titles=dict( # x="percolation probability $p$", # y="success fraction", # title="", # ), ) ``` # perform a pysr regression on a dataset we load ```python DATA_PATH_DIR: Path = Path("../docs/benchmarks/percolation_fractions/") # DATA_PATH: str = DATA_PATH_DIR / "large/result-n256-c54-p100.zanj" # DATA_PATH: str = DATA_PATH_DIR / "medium/result-n128-c42-p50.zanj" DATA_PATH: str = DATA_PATH_DIR / "small/result-n64-c30-p25.zanj" # DATA_PATH: str = DATA_PATH_DIR / "test/result-n16-c12-p16.zanj" sweep_fit( DATA_PATH, Path("tests/_temp/fit_plots/"), niterations=3, ) ``` # interactive plots for figuring out `maze_dataset.math.soft_step()` ```python # Run the interactive visualization if in a Jupyter notebook if "__vsc_ipynb_file__" in globals(): from maze_dataset.benchmark.sweep_fit import create_interactive_plot create_interactive_plot(True) ``` ```python cfg = MazeDatasetConfig( name="test", seed=3, grid_n=5, n_mazes=10, maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, maze_ctor_kwargs=dict(p=0.7), endpoint_kwargs=dict( deadend_start=True, # deadend_end=True, endpoints_not_equal=True, except_on_no_valid_endpoint=False, ), ) print(f"{cfg.success_fraction_estimate() = }") cfg_new = cfg.success_fraction_compensate() print(f"{cfg_new.n_mazes = }") ``` ```python len(MazeDataset.from_config(cfg_new)) ``` ``````{ end_of_file="notebooks/estimate_dataset_fractions.ipynb" } ``````{ path="notebooks/forking_points.ipynb" processed_with="ipynb_to_md" } ```python # other package imports import matplotlib.pyplot as plt # keep this import for CI to work from zanj import ZANJ # saving/loading data # maze_dataset imports
from maze_dataset import MazeDataset, MazeDatasetConfig from maze_dataset.dataset.configs import MAZE_DATASET_CONFIGS from maze_dataset.generation import LatticeMazeGenerators from maze_dataset.plotting import MazePlot # check the configs print(MAZE_DATASET_CONFIGS.keys()) # for saving/loading things LOCAL_DATA_PATH: str = "../data/maze_dataset/" zanj: ZANJ = ZANJ(external_list_threshold=256) ``` ```python # magic autoreload %load_ext autoreload %autoreload 2 ``` ```python cfg: MazeDatasetConfig = MazeDatasetConfig( name="test", # name is only for you to keep track of things grid_n=5, # number of rows/columns in the lattice n_mazes=4, # number of mazes to generate maze_ctor=LatticeMazeGenerators.gen_dfs, # algorithm to generate the maze ) # each config will use this function to get the name of the dataset # it contains some basic info about the algorithm, size, and number of mazes # at the end after "h" is a stable hash of the config to avoid collisions print(cfg.to_fname()) ``` ```python data = MazeDataset.from_config(cfg) maze = data[0] ``` ```python print(maze.as_ascii()) plt.imshow(maze.as_pixels()) ``` ```python fork_idxs, fork_coords = maze.get_solution_forking_points() follow_idxs, follow_coords = maze.get_solution_path_following_points() print( dict( fork_idxs=fork_idxs, fork_coords=fork_coords.tolist(), follow_idxs=follow_idxs, follow_coords=follow_coords.tolist(), ), ) ``` ```python mp: MazePlot = MazePlot(maze) mp.mark_coords(fork_coords, color="green", marker="s") mp.mark_coords(follow_coords) print(mp.marked_coords) mp.plot() ``` ``````{ end_of_file="notebooks/forking_points.ipynb" } ``````{ path="notebooks/iterated_backfilling.ipynb" processed_with="ipynb_to_md" } ```python import matplotlib.pyplot as plt import numpy as np from maze_dataset import ( CoordTup, LatticeMazeGenerators, MazeDataset, MazeDatasetConfig, ) from maze_dataset.maze import TargetedLatticeMaze from maze_dataset.maze.lattice_maze import _remove_isolated_cells ``` ```python def iterated_backfilling(maze: TargetedLatticeMaze) -> TargetedLatticeMaze: """Perform iterated backfilling on a TargetedLatticeMaze object. This algorithm iteratively removes dead ends (nodes with only one neighbor) that are not the start or target nodes until no more such nodes exist. Args: maze (TargetedLatticeMaze): The input maze to perform backfilling on. Returns: TargetedLatticeMaze: A new TargetedLatticeMaze object with dead ends removed. 
""" # Create a copy of the connection list to modify new_connection_list = maze.connection_list.copy() # Create a temporary TargetedLatticeMaze object for using its methods temp_maze = TargetedLatticeMaze( connection_list=new_connection_list, start_pos=maze.start_pos, end_pos=maze.end_pos, ) changed = True while changed: changed = False for i in range(maze.grid_shape[0]): for j in range(maze.grid_shape[1]): pos = (i, j) if _should_remove_node(temp_maze, pos): _remove_node(new_connection_list, pos) changed = True # Update the temporary maze with the new connection list temp_maze = TargetedLatticeMaze( connection_list=new_connection_list, start_pos=maze.start_pos, end_pos=maze.end_pos, ) return TargetedLatticeMaze( connection_list=new_connection_list, start_pos=maze.start_pos, end_pos=maze.end_pos, ) def _should_remove_node(maze: TargetedLatticeMaze, pos: CoordTup) -> bool: """Check if a node should be removed.""" if pos == tuple(maze.start_pos) or pos == tuple(maze.end_pos): return False neighbors = maze.get_coord_neighbors(np.array(pos)) return len(neighbors) == 1 def _remove_node(connection_list: np.ndarray, pos: CoordTup) -> None: """Remove a node by disconnecting all its connections.""" i, j = pos # Remove up connection if i > 0: connection_list[0, i - 1, j] = False # Remove down connection if i < connection_list.shape[1] - 1: connection_list[0, i, j] = False # Remove left connection if j > 0: connection_list[1, i, j - 1] = False # Remove right connection if j < connection_list.shape[2] - 1: connection_list[1, i, j] = False ``` ```python cfg: MazeDatasetConfig = MazeDatasetConfig( name="test", # name is only for you to keep track of things grid_n=10, # number of rows/columns in the lattice n_mazes=4, # number of mazes to generate maze_ctor=LatticeMazeGenerators.gen_dfs_percolation, # algorithm to generate the maze maze_ctor_kwargs={"p": 0.01}, # keyword arguments to pass to the maze ) # to create a dataset, just call MazeDataset.from_config dataset: MazeDataset = MazeDataset.from_config(cfg) ``` ```python maze = dataset[0] plt.imshow(maze.as_pixels()) ``` ```python maze_bf = iterated_backfilling(maze) plt.imshow(_remove_isolated_cells(maze_bf.as_pixels())) ``` ``````{ end_of_file="notebooks/iterated_backfilling.ipynb" } ``````{ path="notebooks/output_formats.ipynb" processed_with="ipynb_to_md" } ```python import matplotlib.pyplot as plt from IPython.display import SVG, display # noqa: A004 from maze_dataset import LatticeMazeGenerators, MazeDataset, MazeDatasetConfig from maze_dataset.plotting.plot_svg_fancy import plot_svg_fancy cfg = MazeDatasetConfig( name="test", grid_n=5, n_mazes=5, maze_ctor=LatticeMazeGenerators.gen_dfs, ) ds = MazeDataset.from_config(cfg) maze = ds[0] pixels = maze.as_pixels() ``` ```python plt.imshow(pixels) plt.axis("off") ``` ```python svg_string: str = plot_svg_fancy(pixels) display(SVG(svg_string)) ``` ```python from maze_dataset.plotting import MazePlot MazePlot(maze).plot(plain=True) ``` ```python print(maze.as_ascii()) ``` ```python from maze_dataset.plotting.print_tokens import ( color_maze_tokens_AOTP, ) from maze_dataset.tokenization import MazeTokenizer print(color_maze_tokens_AOTP(maze.as_tokens(MazeTokenizer()), "html")) ``` ``````{ end_of_file="notebooks/output_formats.ipynb" } ``````{ path="notebooks/profile_dataset_save_read.ipynb" processed_with="ipynb_to_md" } # Profiling of `maze_dataset` serializing/loading/saving/reading ```python import copy import itertools import warnings from typing import Any, Callable, Sequence import 
matplotlib.pyplot as plt import numpy as np import pandas as pd from muutils.statcounter import StatCounter from muutils.timeit_fancy import FancyTimeitResult, timeit_fancy from maze_dataset import ( MazeDataset, MazeDatasetConfig, set_serialize_minimal_threshold, ) from maze_dataset.generation.generators import GENERATORS_MAP ``` ## Generate Datasets ```python cfgs: list[MazeDatasetConfig] = [ MazeDatasetConfig( name="test", grid_n=grid_n, n_mazes=n_mazes, maze_ctor=GENERATORS_MAP["gen_dfs"], ) for grid_n, n_mazes in itertools.product( [10], np.logspace(1, 2, 2, dtype=int).tolist(), # 100, for CI tests # np.logspace(1, 3, 5, dtype=int).tolist(), # 1k # np.logspace(0, 4, 9, dtype=int).tolist(), # 10k, notebook results from this set ) ] datasets: list[MazeDataset] = [ MazeDataset.from_config(cfg, load_local=False) for cfg in cfgs ] ``` ## Profile ```python columns: list[str] = [ "grid_n", "n_mazes", "serialize", "serialize_minimal", "load", "load_minimal", "save", "save_minimal", "read", "read_minimal", ] speeds_data: list[dict] = list() ``` ```python def wrapped_timeit_fancy( name: str, function: Callable, do_profiling: bool, repeats: int, timing_stat: Callable[[StatCounter], float], ) -> tuple[dict, Any]: output: dict = dict() result: FancyTimeitResult = timeit_fancy( function, get_return=True, do_profiling=do_profiling, repeats=repeats, ) output[name] = timing_stat(result.timings) output[f"{name}:stats"] = result.timings if do_profiling: output[f"{name}:profiling"] = result.profile return output, result.return_value def measure_dataset_speed( d: MazeDataset, do_profiling: bool = True, repeats: int = 1, timing_stat: Callable[[StatCounter], float] = StatCounter.min, ) -> dict: if repeats > 1: warnings.warn( "Repeats > 1, results might not be accurate due to generation metadata being collected.", ) kwargs_fancy_timeit: dict = dict( do_profiling=do_profiling, timing_stat=timing_stat, repeats=repeats, ) set_serialize_minimal_threshold(None) _d_cpy: MazeDataset = copy.deepcopy(d) # set up row data row_data: dict = dict( grid_n=d.cfg.grid_n, n_mazes=d.cfg.n_mazes, ) # serialization & loading info_serialize, result_serialize = wrapped_timeit_fancy( "serialize_full", _d_cpy._serialize_full, **kwargs_fancy_timeit, ) row_data.update(info_serialize) _d_cpy = copy.deepcopy(d) info_serialize_min, result_serialize_min = wrapped_timeit_fancy( "serialize_minimal", _d_cpy._serialize_minimal, **kwargs_fancy_timeit, ) row_data.update(info_serialize_min) _d_cpy = copy.deepcopy(d) # info_serialize_min_alt, result_serialize_min_alt = wrapped_timeit_fancy( # 'serialize_minimal_alt', _d_cpy._serialize_minimal_alt, **kwargs_fancy_timeit # ) # row_data.update(info_serialize_min_alt) _d_cpy = copy.deepcopy(d) info_serialize_cat, result_serialize_cat = wrapped_timeit_fancy( "serialize_minimal_soln_cat", _d_cpy._serialize_minimal_soln_cat, **kwargs_fancy_timeit, ) row_data.update(info_serialize_cat) _d_cpy = copy.deepcopy(d) row_data.update( wrapped_timeit_fancy( "load_legacy", lambda: MazeDataset._load_legacy(result_serialize), **kwargs_fancy_timeit, )[0], ) row_data.update( wrapped_timeit_fancy( "load_full", lambda: MazeDataset._load_full(result_serialize), **kwargs_fancy_timeit, )[0], ) row_data.update( wrapped_timeit_fancy( "load_minimal", lambda: MazeDataset._load_minimal(result_serialize_min), **kwargs_fancy_timeit, )[0], ) row_data.update( wrapped_timeit_fancy( "load_minimal_soln_cat", lambda: MazeDataset._load_minimal_soln_cat(result_serialize_cat), **kwargs_fancy_timeit, )[0], ) row_data.update( 
wrapped_timeit_fancy( "load_full", lambda: MazeDataset._load_full(result_serialize), **kwargs_fancy_timeit, )[0], ) row_data.update( wrapped_timeit_fancy( "load_minimal", lambda: MazeDataset._load_minimal(result_serialize_min), **kwargs_fancy_timeit, )[0], ) row_data.update( wrapped_timeit_fancy( "load_minimal_soln_cat", lambda: MazeDataset._load_minimal_soln_cat(result_serialize_cat), **kwargs_fancy_timeit, )[0], ) # saving and loading path_default: str = f"../data/{d.cfg.to_fname()}.zanj" path_min: str = f"../data/{d.cfg.to_fname()}_min.zanj" # default set_serialize_minimal_threshold(None) _d_cpy = copy.deepcopy(d) row_data.update( wrapped_timeit_fancy( "save", lambda: _d_cpy.save(file_path=path_default), **kwargs_fancy_timeit, )[0], ) _d_cpy = copy.deepcopy(d) # read_legacy set_serialize_minimal_threshold(-1) row_data.update( wrapped_timeit_fancy( "read_legacy", lambda: MazeDataset.read(file_path=path_default), **kwargs_fancy_timeit, )[0], ) # default read set_serialize_minimal_threshold(None) row_data.update( wrapped_timeit_fancy( "read", lambda: MazeDataset.read(file_path=path_default), **kwargs_fancy_timeit, )[0], ) # minimal set_serialize_minimal_threshold(0) _d_cpy = copy.deepcopy(d) row_data.update( wrapped_timeit_fancy( "save_minimal", lambda: _d_cpy.save(file_path=path_min), **kwargs_fancy_timeit, )[0], ) _d_cpy = copy.deepcopy(d) row_data.update( wrapped_timeit_fancy( "read_minimal", lambda: MazeDataset.read(file_path=path_min), **kwargs_fancy_timeit, )[0], ) # asserts # assert d == read_default # assert d == read_minimal # reset cfg? set_serialize_minimal_threshold(None) return row_data ``` ## Run Profiling ```python for i, d in enumerate(datasets): print(f"Profiling {i + 1}/{len(datasets)}:\t{d.cfg}") result = measure_dataset_speed(d) speeds_data.append(result) cols_short: str = str({k: v for k, v in result.items() if ":" not in k}) print(f"\t{cols_short}") print(f"\t{d.cfg!s}") ``` ### Results ```python SPEEDS: pd.DataFrame = pd.DataFrame(speeds_data) SPEEDS ``` ```python def compute_speedups(speeds: pd.DataFrame) -> pd.DataFrame: # for prefix in column_measurement_prefixes: # speeds[f'{prefix}_speedup'] = speeds[f'{prefix}_full'] / speeds[f'{prefix}_minimal'] speeds["serialize/speedup"] = speeds["serialize_full"] / speeds["serialize_minimal"] speeds["load/speedup"] = speeds["load_full"] / speeds["load_minimal"] speeds["save/speedup"] = speeds["save"] / speeds["save_minimal"] speeds["read/speedup"] = speeds["read"] / speeds["read_minimal"] return speeds SPEEDS = compute_speedups(SPEEDS) ``` ```python SPEEDS: pd.DataFrame = pd.DataFrame(speeds_data) # SPEEDS.loc[:,"load_legacy":"load_minimal_soln_cat:profiling"] SPEEDS.loc[:, "read_legacy":"read:profiling"] ``` ```python SPEEDS.columns ``` ```python def compute_speedups(speeds: pd.DataFrame) -> pd.DataFrame: # for prefix in column_measurement_prefixes: # speeds[f'{prefix}_speedup'] = speeds[f'{prefix}_full'] / speeds[f'{prefix}_minimal'] speeds["serialize/speedup"] = speeds["serialize_full"] / speeds["serialize_minimal"] speeds["load_minimal/speedup"] = speeds["load_legacy"] / speeds["load_minimal"] speeds["load/speedup"] = speeds["load_legacy"] / speeds["load_full"] speeds["save/speedup"] = speeds["save"] / speeds["save_minimal"] speeds["read_minimal/speedup"] = speeds["read_legacy"] / speeds["read_minimal"] speeds["read/speedup"] = speeds["read_legacy"] / speeds["read"] return speeds SPEEDS = compute_speedups(SPEEDS) ``` ```python SPEEDS[[c for c in SPEEDS.columns if ":" not in c]] ``` ```python def plot_speeds( speeds: 
pd.DataFrame, column_measurement_prefixes: Sequence[str] = ("serialize", "load", "save", "read"), ) -> None: n_measurements: int = len(column_measurement_prefixes) fig, axs = plt.subplots(2, n_measurements, figsize=(n_measurements * 5, 10)) unique_grid_ns: list[int] = speeds["grid_n"].unique().tolist() for i, prefix in enumerate(column_measurement_prefixes): print(f"Plotting {prefix} timings and speedups") for grid_n in unique_grid_ns: print(f"Plotting grid_n={grid_n}") # raw timings ax_timings = axs[0, i] speeds_masked = speeds[speeds["grid_n"] == grid_n].sort_values("n_mazes") x_n_mazes = speeds_masked["n_mazes"] # Plotting for col in speeds_masked.columns: if (prefix in col) and ("speedup" not in col) and (":" not in col): ax_timings.plot( x_n_mazes, speeds_masked[col], "x-", label=f"grid_n={grid_n}, {col}", ) # Setting multiple properties with `set` ax_timings.set( xscale="log", yscale="log", xlabel="Number of mazes", ylabel="Runtime [sec]", title=f"{prefix} timings", ) ax_timings.legend() # speedups ax_speedups = axs[1, i] col_name: str = ( f"{prefix}" if prefix in ("serialize", "save") else f"{prefix}_minimal" ) ax_speedups.plot( x_n_mazes, speeds_masked[f"{col_name}/speedup"], "x-", label=f"grid_n={grid_n}", ) # Setting multiple properties with `set` for ax_speedups ax_speedups.set( xscale="log", yscale="log", xlabel="Number of mazes", ylabel="Speedup", title=f"{col_name} speedups", ) ax_speedups.plot( x_n_mazes, speeds_masked[f"{prefix}/speedup"], "x-", label=f"grid_n={grid_n}", ) # Setting multiple properties with `set` for ax_speedups ax_speedups.set( xscale="log", yscale="log", xlabel="Number of mazes", ylabel="Speedup", title=f"{prefix} speedups", ) ax_speedups.legend() plot_speeds(SPEEDS) ``` Speedups plotted on the bottom set of axes all show the `_minimal` compared to the legacy performance. `serialize_full` and `save` are unchanged from the legacy version, so speedups are plotted relative to those vectors. ```python SPEEDS[["grid_n", "n_mazes", "serialize_minimal:profiling"]] ``` ```python SPEEDS["load_minimal:profiling"][len(SPEEDS) - 1].sort_stats("tottime").print_stats() ``` ```python ``` ``````{ end_of_file="notebooks/profile_dataset_save_read.ipynb" } ``````{ path="LICENSE.md" } Attribution-ShareAlike 4.0 International ======================================================================= Creative Commons Corporation ("Creative Commons") is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an "as-is" basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. Using Creative Commons Public Licenses Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 
Considerations for licensors: Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC- licensed material, or material used under an exception or limitation to copyright. More considerations for licensors: wiki.creativecommons.org/Considerations_for_licensors Considerations for the public: By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor's permission is not necessary for any reason--for example, because of any applicable exception or limitation to copyright--then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. More_considerations for the public: wiki.creativecommons.org/Considerations_for_licensees ======================================================================= Creative Commons Attribution-ShareAlike 4.0 International Public License By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. Section 1 -- Definitions. a. Adapted Material means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. b. Adapter's License means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. c. BY-SA Compatible License means a license listed at creativecommons.org/compatiblelicenses, approved by Creative Commons as essentially the equivalent of this Public License. d. 
Copyright and Similar Rights means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. e. Effective Technological Measures means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. f. Exceptions and Limitations means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. g. License Elements means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution and ShareAlike. h. Licensed Material means the artistic or literary work, database, or other material to which the Licensor applied this Public License. i. Licensed Rights means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. j. Licensor means the individual(s) or entity(ies) granting rights under this Public License. k. Share means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. l. Sui Generis Database Rights means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. m. You means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. Section 2 -- Scope. a. License grant. 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: a. reproduce and Share the Licensed Material, in whole or in part; and b. produce, reproduce, and Share Adapted Material. 2. Exceptions and Limitations. For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 3. Term. The term of this Public License is specified in Section 6(a). 4. Media and formats; technical modifications allowed. The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. 
For purposes of this Public License, simply making modifications authorized by this Section 2(a) (4) never produces Adapted Material. 5. Downstream recipients. a. Offer from the Licensor -- Licensed Material. Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. b. Additional offer from the Licensor -- Adapted Material. Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter's License You apply. c. No downstream restrictions. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 6. No endorsement. Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). b. Other rights. 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 2. Patent and trademark rights are not licensed under this Public License. 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties. Section 3 -- License Conditions. Your exercise of the Licensed Rights is expressly made subject to the following conditions. a. Attribution. 1. If You Share the Licensed Material (including in modified form), You must: a. retain the following if it is supplied by the Licensor with the Licensed Material: i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); ii. a copyright notice; iii. a notice that refers to this Public License; iv. a notice that refers to the disclaimer of warranties; v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; b. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and c. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. b. ShareAlike. 
In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 1. The Adapter's License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-SA Compatible License. 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. Section 4 -- Sui Generis Database Rights. Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database; b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. Section 5 -- Disclaimer of Warranties and Limitation of Liability. a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. Section 6 -- Term and Termination. a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 1. 
automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 2. upon express reinstatement by the Licensor. For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. Section 7 -- Other Terms and Conditions. a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. Section 8 -- Interpretation. a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. ======================================================================= Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark "Creative Commons" or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. Creative Commons may be contacted at creativecommons.org. ``````{ end_of_file="LICENSE.md" } ``````{ path="README.md" }
[Maze Dataset logo: "maze-dataset"]

[Links: PyPI · Docs · Examples · arXiv]

[Diagram]

[Badges: PyPI, Python Version, Checks, Coverage, code size, GitHub commit activity, GitHub closed issues, GitHub closed pull requests, PyPI - Downloads]

# `maze-dataset`

This package provides utilities for generating, filtering, solving, visualizing, and processing mazes for training or evaluating ML systems. It was primarily built for the [maze-transformer interpretability](https://github.com/understanding-search/maze-transformer) project. You can find our paper on it here: http://arxiv.org/abs/2309.10498

This package includes a variety of maze generation algorithms, including randomized depth-first search, Wilson's algorithm for uniform spanning trees, and percolation. Datasets can be filtered to select mazes of a certain length or complexity, remove duplicates, and satisfy custom properties (see the sketch at the end of this section). A variety of output formats for visualization and training ML models are provided.

|   |   |   |   |
|---|---|---|---|
| Maze generated via percolation | Maze generated via constrained randomized depth-first search | Maze with random heatmap | MazePlot with solution |

You can view and search through a wide variety of example mazes here: [`understanding-search.github.io/maze-dataset/examples/maze_examples`](https://understanding-search.github.io/maze-dataset/examples/maze_examples.html)

# Citing

If you use this code in your research, please cite [our paper](http://arxiv.org/abs/2309.10498):

```
@misc{maze-dataset,
    title={A Configurable Library for Generating and Manipulating Maze Datasets},
    author={Michael Igorevich Ivanitskiy and Rusheb Shah and Alex F. Spies and Tilman Räuker and Dan Valentine and Can Rager and Lucia Quirke and Chris Mathwin and Guillaume Corlouer and Cecilia Diniz Behn and Samy Wu Fung},
    year={2023},
    eprint={2309.10498},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={http://arxiv.org/abs/2309.10498}
}
```

# Installation

This package is [available on PyPI](https://pypi.org/project/maze-dataset/), and can be installed via

```
pip install maze-dataset
```

# Docs

The full hosted documentation is available at [https://understanding-search.github.io/maze-dataset/](https://understanding-search.github.io/maze-dataset/). Additionally, our [notebooks](https://understanding-search.github.io/maze-dataset/notebooks) serve as a good starting point for understanding the package.

# Usage

## Creating a dataset

To create a `MazeDataset`, which inherits from `torch.utils.data.Dataset`, you first create a `MazeDatasetConfig`:

```python
from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg: MazeDatasetConfig = MazeDatasetConfig(
    name="test",  # name is only for you to keep track of things
    grid_n=5,  # number of rows/columns in the lattice
    n_mazes=4,  # number of mazes to generate
    maze_ctor=LatticeMazeGenerators.gen_dfs,  # algorithm to generate the maze
    maze_ctor_kwargs=dict(do_forks=False),  # additional parameters to pass to the maze generation algorithm
)
```

and then pass this config to the `MazeDataset.from_config` method:

```python
dataset: MazeDataset = MazeDataset.from_config(cfg)
```

This method checks whether a dataset with a matching config hash already exists in the expected location on your filesystem and loads it if so; otherwise, it can generate the dataset on the fly.
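As a minimal sketch of how the loading/generation behavior and filtering can be controlled (this assumes the `local_base_path`, `load_local`, `do_generate`, and `save_local` keyword arguments of `from_config` and the `filter_by.path_length` filter described in the package documentation; exact names and defaults may vary between versions, and the path below is purely illustrative):

```python
from pathlib import Path

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg = MazeDatasetConfig(
    name="example",
    grid_n=5,
    n_mazes=16,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)

# explicitly control where the dataset is looked up / saved,
# and whether it may be generated if no matching file is found
dataset = MazeDataset.from_config(
    cfg,
    local_base_path=Path("data/maze_dataset"),  # hypothetical location, adjust to your project
    load_local=True,   # try to load an existing dataset with a matching config hash
    do_generate=True,  # fall back to generating the mazes if none is found
    save_local=True,   # save a newly generated dataset for reuse
)

# filters return a new, smaller dataset; e.g. keep only mazes whose
# solution path is at least 3 steps long
filtered = dataset.filter_by.path_length(min_length=3)
```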
## Conversions to useful formats

The elements of the dataset are [`SolvedMaze`](maze_dataset/maze/lattice_maze.py) objects:

```python
>>> m = dataset[0]
>>> type(m)
maze_dataset.maze.lattice_maze.SolvedMaze
```

These can be converted to a variety of formats:

```python
# visual representation as ascii art
m.as_ascii()

# RGB image, optionally without solution or endpoints, suitable for CNNs
m.as_pixels()

# text format for autoregressive transformers
from maze_dataset.tokenization import MazeTokenizerModular, TokenizationMode
m.as_tokens(maze_tokenizer=MazeTokenizerModular(
    tokenization_mode=TokenizationMode.AOTP_UT_rasterized, max_grid_size=100,
))

# advanced visualization with many features
from maze_dataset.plotting import MazePlot
MazePlot(m).plot()
```

[Textual and visual output formats]

# Development

We use this [makefile template](https://github.com/mivanit/python-project-makefile-template) with slight modifications for our development workflow.

- clone with `git clone https://github.com/understanding-search/maze-dataset`
- `make dep` to install all dependencies
- `make help` will print all available commands
- `make test` will run basic tests to ensure the package is working
- `make format` will run ruff to format and check the code

``````{ end_of_file="README.md" }
``````{ path="makefile" processed_with="makefile_recipes" }
# first/default target is help .PHONY: default default: help ... # this recipe is weird. we need it because: # - a one liner for getting the version with toml is unwieldy, and using regex is fragile # - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues # - trying to write to the file inside the `gen-version-info` recipe doesn't work, # shell eval happens before our `python -c ...` gets run and `cat` doesn't see the new file .PHONY: write-proj-version write-proj-version: ... # gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version # uses just `python` for everything except getting the python version. no echo here, because this is "private" .PHONY: gen-version-info gen-version-info: write-proj-version ... # getting commit log since the tag specified in $(LAST_VERSION_FILE) # will write to $(COMMIT_LOG_FILE) # when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process) # no echo here, because this is "private" .PHONY: gen-commit-log gen-commit-log: gen-version-info ... # force the version info to be read, printing it out # also force the commit log to be generated, and cat it out .PHONY: version version: gen-commit-log @echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)" ... .PHONY: setup setup: dep-check @echo "install and update via uv" ... .PHONY: dep-check-torch dep-check-torch: @echo "see if torch is installed, and which CUDA version and devices it sees" ... .PHONY: dep dep: @echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'" ... .PHONY: dep-check dep-check: @echo "Checking that exported requirements are up to date" ... .PHONY: dep-clean dep-clean: @echo "clean up lock files, .venv, and requirements files" ... # runs ruff and pycln to format the code .PHONY: format format: @echo "format the source code" ... # runs ruff and pycln to check if the code is formatted correctly .PHONY: format-check format-check: @echo "check if the source code is formatted correctly" ...
# runs type checks with mypy .PHONY: typing typing: clean @echo "running type checks" ... # generates a report of the mypy output .PHONY: typing-report typing-report: @echo "generate a report of the type check output -- errors per file" ... .PHONY: test-unit test-unit: @echo "run unit tests" ... .PHONY: tokenizer-hashes-save tokenizer-hashes-save: @echo "generate and save tokenizer hashes" ... .PHONY: tokenizer-hashes-test tokenizer-hashes-test: @echo "re-run tokenizer hashes and compare" ... .PHONY: tokenizer-test-long tokenizer-test-long: @echo "run tests on all tokenizers. can pass NUM_TOKENIZERS_TO_TEST. doesn't check fst" ... .PHONY: tokenizer-fst-gen tokenizer-fst-gen: @echo "generate and save tokenizer FSTs" ... .PHONY: tokenizer-fst-check tokenizer-fst-check: @echo "regen all tokenizers, check their names are in the fst" ... .PHONY: tokenizer-fst-check-small tokenizer-fst-check-small: @echo "regen all tokenizers, check 1000 random ones" ... .PHONY: test-notebooks-muutils-convert test-notebooks-muutils-convert: @echo "convert notebooks in $(NOTEBOOKS_DIR) using muutils.nbutils.convert_ipynb_to_script.py" ... .PHONY: test-notebooks-muutils test-notebooks-muutils: test-notebooks-muutils-convert @echo "run tests on converted notebooks in $(CONVERTED_NOTEBOOKS_TEMP_DIR) using muutils.nbutils.run_notebook_tests.py" ... .PHONY: test-notebooks-nbmake test-notebooks-nbmake: @echo "run tests on notebooks in $(NOTEBOOKS_DIR) using nbmake" ... .PHONY: test-notebooks test-notebooks: test-notebooks-muutils test-notebooks-nbmake @echo "run tests on notebooks in $(NOTEBOOKS_DIR) using both muutils and nbmake" ... .PHONY: test test: clean test-unit test-notebooks-muutils tokenizer-fst-check-small @echo "run all usual tests: unit, notebooks, and fst check (but not tokenizer-test-long)" ... .PHONY: test-cov test-cov: clean @echo "run all pytest tests in one for coverage, including tokenizers" ... .PHONY: test-all test-all: clean test-unit test-notebooks tokenizer-fst-check tokenizer-test-long @echo "run literally all tests: unit, notebooks both ways, tokenizers fst check, long tokenizer test" ... .PHONY: check check: clean format-check test typing @echo "run format check and test" ... .PHONY: check-all check-all: clean format-check test-all typing @echo "run format check and test-all (includes tokenizers)" ... # generates a whole tree of documentation in html format. # see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info .PHONY: docs-html docs-html: @echo "generate html docs" ... # instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`. # this is useful if you want to have a copy that you can grep/search, but those docs are much messier. # docs-combined will use pandoc to convert them to other formats. .PHONY: docs-md docs-md: @echo "generate combined (single-file) docs in markdown" ... # after running docs-md, this will convert the combined markdown file to other formats: # gfm (github-flavored markdown), plain text, and html # requires pandoc in path, pointed to by $(PANDOC) # pdf output would be nice but requires other deps .PHONY: docs-combined docs-combined: docs-md @echo "generate combined (single-file) docs in markdown and convert to other formats" ... 
# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge` # if `.coverage` is not found, will run tests first # also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs .PHONY: cov cov: @echo "generate coverage reports" ... # runs the coverage report, then the docs, then the combined docs .PHONY: docs docs: cov docs-html docs-combined todo lmcat @echo "generate all documentation and coverage reports" ... # removed all generated documentation files, but leaves everything in `$DOCS_RESOURCES_DIR` # and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean` # (templates, svg, css, make_docs.py script) # distinct from `make clean` .PHONY: docs-clean docs-clean: @echo "remove generated docs except resources" ... .PHONY: todo todo: @echo "get all TODO's from the code" ... .PHONY: lmcat-tree lmcat-tree: @echo "show in console the lmcat tree view" ... .PHONY: lmcat lmcat: @echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]" ... .PHONY: benchmark-speed-test benchmark-speed-test: @echo "test speed benchmarks" ... .PHONY: benchmark-speed benchmark-speed: @echo "run speed benchmarks" ... .PHONY: benchmark-success-test benchmark-success-test: @echo "test success benchmarks" ... .PHONY: benchmark-success benchmark-success: benchmark-success-test @echo "run success benchmarks" ... .PHONY: benchmark-test benchmark-test: benchmark-speed-test benchmark-success-test @echo "run all benchmarks tests" ... .PHONY: example-clean example-clean: @echo "clean up generated examples" ... .PHONY: example-gen example-gen: @echo "generate examples" ... .PHONY: regenerate-when-cfg-hashes-changed regenerate-when-cfg-hashes-changed: example-clean example-gen benchmark-success @echo "regenerate everything we need to when the process by which we hash configs might have changed -- like if you add a new attribute" ... # verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean # used before publishing .PHONY: verify-git verify-git: @echo "checking git status" ... .PHONY: build build: @echo "build the package" ... # gets the commit log, checks everything, builds, and then publishes with twine # will ask the user to confirm the new version number (and this allows for editing the tag info) # will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine .PHONY: publish publish: gen-commit-log check build verify-git version gen-version-info @echo "run all checks, build, and then publish" ... # cleans up temp files from formatter, type checking, tests, coverage # removes all built files # removes $(TESTS_TEMP_DIR) to remove temporary test files # recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files # distinct from `make docs-clean`, which only removes generated documentation files .PHONY: clean clean: @echo "clean up temporary files" ... # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .PHONY: clean-all clean-all: clean docs-clean dep-clean example-clean @echo "clean up all temporary files, dep files, venv, and generated docs" ... .PHONY: info info: gen-version-info @echo "# makefile variables" ... .PHONY: info-long info-long: info @echo "# other variables" ... # immediately print out the help targets, and then local variables (but those take a bit longer) .PHONY: help help: help-targets info @echo -n "" ... 
.PHONY: paper-setup paper-setup: $(PAPER_PATH_INARA)/README.md @echo "set up inara and packages for compiling the paper" ... .PHONY: paper-clean-dev paper-clean-dev: @echo "clean up ephemeral files related to the paper" ... .PHONY: paper-clean paper-clean: @echo "clean up paper files" ... .PHONY: paper-tex paper-tex: @echo "compile the paper to tex" ... .PHONY: paper-pdf paper-pdf: paper-tex @echo "compile the paper to pdf" ... .PHONY: paper-html paper-html: @echo "compile the paper to html" ... .PHONY: paper paper: paper-clean paper-pdf paper-html @echo "paper compiled to tex, pdf and html" ... ``````{ end_of_file="makefile" } ``````{ path="pyproject.toml" } [project] name = "maze-dataset" version = "1.3.2" description = "generating and working with datasets of mazes" authors = [ { name = "Michael Ivanitskiy", email = "mivanits@umich.edu" }, { name = "Aaron Sandoval", email = "aaron.sandoval10@gmail.com" }, { name = "Rusheb Shah", email = "rusheb.shah@gmail.com" }, { name = "Dan Valentine", email = "danvalentine256@gmail.com" }, { name = "Lucia Quirke", email = "luciaq@canva.com" }, { name = "Can Rager", email = "can.rager@posteo.de" }, { name = "Alex Spies", email = "alexfspies@gmail.com" }, { name = "Chris Mathwin", email = "cwmathwin@gmail.com" }, { name = "Tilman Rauker", email = "traeuker@googlemail.com" }, { name = "Guillaume Corlouer", email = "guillaume.corlouer@gmail.com" }, ] readme = "README.md" requires-python = ">=3.10" # source info # packages = [{include = "maze_dataset"}] # exclude = ["maze_dataset/tokenization/MazeTokenizerModular_hashes.npz"] # don't ship the hashes # informational metadata keywords = ["maze", "mazes", "labyrinth", "dataset", "procedural", "pathfinding", "tokenization"] dependencies = [ # custom packages "muutils>=0.8.3", "zanj>=0.5.0", # arrays and type hints "numpy", "jaxtyping>=0.2.19", # standard numerical "matplotlib>=3.7.0", # notebooks "jupyter>=1.0.0", "ipykernel>=6.22.0", # misc "tqdm>=4.65.0", ] [project.optional-dependencies] tokenization = [ "frozendict>=2.4.4", # storing valid tokenizers # doesn't appear to work on macos "rust_fst>=0.1.2; platform_system != 'darwin'", ] [dependency-groups] dev = [ # for benchmarking "pandas>=2.2.2", # test "pytest>=8.2.2", "pytest-xdist>=3.6.1", # for parallel all tokenizers tests "pytest-mock>=3.10.0", "nbmake>=1.5.5", # coverage "pytest-cov>=4.1.0", "coverage-badge>=1.1.0", # type checking "mypy>=1.0.1", "types-tqdm", "pandas-stubs", "types-psutil", # see https://github.com/understanding-search/maze-dataset/actions/runs/14327419830/job/40155509863 # docs 'pdoc>=14.6.0', "nbconvert>=7.16.4", # for notebooks # lmcat -- a custom library. 
not exactly docs, but lets an LLM see all the code "lmcat>=0.2.0; python_version >= '3.11'", # tomli since no tomlib in python < 3.11 "tomli>=2.1.0; python_version < '3.11'", # uploading "twine", ] lint = [ # lint "ruff>=0.4.8", ] benchmark = [ # only used in `estimate_dataset_fractions.ipynb` "pysr>=1.4.0", "seaborn", ] [project.urls] Homepage = "https://github.com/understanding-search/maze-dataset" Documentation = "https://understanding-search.github.io/maze-dataset/" Repository = "https://github.com/understanding-search/maze-dataset" Issues = "https://github.com/understanding-search/maze-dataset/issues" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] exclude = ["maze_dataset/tokenization/MazeTokenizerModular_hashes.npz"] [tool.pytest.ini_options] # Ignore numpy deprecation warnings triggered by muutils filterwarnings = [ # Warning from muutils: https://github.com/mivanit/muutils/issues/1 "ignore:`np\\.\\w*` is a deprecated alias for:DeprecationWarning", # Warning from matplotlib. Issue: https://github.com/matplotlib/matplotlib/issues/25244 "ignore:Deprecated call to `pkg_resources.declare_namespace:DeprecationWarning", # temporary fix for lots of deprecation warnings for old tokenizers "ignore::maze_dataset.token_utils.TokenizerPendingDeprecationWarning", ] testpaths = "tests" norecursedirs="maze_dataset/utils/test_helpers" [tool.mypy] # generate this exclude with `make typing-report` exclude = [ # high priority "tests/unit/processing/test_collect_gen_metadata.py", # 3 "tests/unit/generation/test_latticemaze.py", # 7 "maze_dataset/constants.py", # 9 # tokenization "maze_dataset/tokenization/modular/all_tokenizers.py", # 8 "maze_dataset/tokenization/maze_tokenizer_legacy.py", # 19 "maze_dataset/token_utils.py", # 21 "maze_dataset/tokenization/modular/maze_tokenizer_modular.py", # 12 "maze_dataset/tokenization/modular/elements.py", # 97 "maze_dataset/tokenization/modular/all_instances.py", # 17 # low priority "maze_dataset/plotting/plot_maze.py", # 11 "tests/all_tokenizers/test_all_tokenizers.py", # 12 "tests/unit/generation/test_maze_dataset.py", # 16 "tests/unit/tokenization/test_token_utils.py", # 16 "tests/unit/tokenization/test_tokenizer.py", # 45 "tests/unit/processing/test_get_forking_path_points.py", # 58 # extra low priority (test temp, generated from notebooks) "tests/_temp/*", ] check_untyped_defs = true [[tool.mypy.overrides]] module = "fire" ignore_missing_imports = true # ruff config [tool.ruff] exclude = ["__pycache__"] [tool.ruff.format] indent-style = "tab" skip-magic-trailing-comma = false [tool.ruff.lint] ignore = [ "TC002", # fine to normally import jaxtyping and others not in a TYPE_CHECKING block "F722", # doesn't like jaxtyping "W191", # we like tabs "D400", # missing-trailing-period "D415", # missing-terminal-punctuation "E501", # line-too-long "S101", # assert is fine "D403", # first-word-uncapitalized "D206", # docstring-tab-indentation "ERA001", # commented-out-code "T201", # print is fine lmao "C408", # calling dict() is fine "UP015", # we like specifying the mode even if it's the default "D300", # we like docstrings # boolean positional arguments are fine "FBT001", "FBT002", "FBT003", "PTH123", # opening files is fine "RET505", # else return is fine "FIX001", # FIXME comments are ok since `make todo` handles them "FIX002", # `make todo` will give us the TODO comments "FIX004", # same for `HACK` "PIE790", # be explicit about when we pass "EM101", # fine to have string literal exceptions "FURB129", # 
.readlines() is fine "SIM108", # ternary operators can be hard to read, choose on a case-by-case basis "PLR5501", # nested if else is fine, for readability "D203", # docstring right after the class "D213", # docstring on first line "NPY002", # legacy numpy generator is fine "D401", # dont care about imperative mood "RUF022", # don't want to sort __all__ lexicographically, sort by meaning "PLR0913", # sometimes you have to have a lot of args "B028", # fine to omit stacklevel on warnings "SLF001", # fine to access private vars "N802", # uppercase in func names is fine # warning: The following rule may cause conflicts when used with the formatter: `COM812`. To avoid unexpected behavior, we recommend disabling this rule, either by removing it from the `select` or `extend-select` configuration, or adding it to the `ignore` configuration. "COM812", "TC001", # don't force us to import things in type checking blocks # todos: "TD001", # we allow tags besides "TODO" "TD002", # dont care about author "TD003", # `make todo` will give us a table where we can create issues # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # TODO: no type hints on *args or **kwargs for now "ANN002", "ANN003", # TODO: more fine-grained exception classes "TRY003", # TODO: use extend instead of append? "PERF401", # HACK: need to be more specific about mypy ignores "PGH003", # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # only for old version compatibility "UP007", # `Optional` is ok, we might not want to use `|` for compatibility # old style hints `Tuple`, `List`, etc. are fine "UP006", "UP035", ] select = ["ALL"] # select = ["ICN001"] [tool.ruff.lint.per-file-ignores] "maze_dataset/generation/generators.py" = [ # none of this is for security "S311", # yes the generation functions are complicated "C901", ] "maze_dataset/tokenization/modular/*.py" = [ # TODO: lots of unused args in the tokenizer code "ARG002" ] "tests/*" = [ # dont need docstrings in test functions or modules "D100", "D101", "D102", "D103", "INP001", # dont need __init__ either # dont need type annotations in test functions "ANN001", "ANN201", "ANN202", "TRY003", # long exception messages in tests are fine "PLR2004", # magic variables are fine "C419", # unnecessary comprehensions are fine # uppercase is fine in tests (we write UT and AOTP a lot) "N802", "N806", # using os for path in tests is fine (not in main lib tho) "PTH100", "PTH118", ] "docs/resources/make_docs.py" = ["ALL"] # not our problem "docs/*" = [ "INP001", # scripts, not modules ] "test.ipynb" = ["ALL"] # this is just a test notebook "**/*.ipynb" = [ "D103", # dont need docstrings "PLR2004", # magic variables are fine "N806", # uppercase vars are fine ] [tool.lmcat] output = "docs/other/lmcat.txt" # changing this might mean it wont be accessible from the docs ignore_patterns = [ "docs/**", ".venv/**", ".git/**", ".meta/**", "uv.lock", ".ruff_cache/**", ".github/ISSUE_TEMPLATE/**", "_wip/**", "sweep.yaml", # there are... a lot of tests. 
we usually dont need to put these in lmcat "tests/**", "maze_dataset/tokenization/modular/MazeTokenizerModular_tested.fst", ] [tool.lmcat.glob_process] "[mM]akefile" = "makefile_recipes" "*.ipynb" = "ipynb_to_md" # ============================================================ [tool.makefile] # documentation configuration, for `make docs` and `make docs-clean` [tool.makefile.docs] # Output directory for generated documentation # MUST match DOCS_DIR in makefile output_dir = "docs" # List of files/directories in docs/ that should not be cleaned by `make docs-clean` # These paths are relative to output_dir no_clean = [ ".nojekyll", "assets", "benchmarks", "paper", "resources", "examples", # "resources/", # Templates, CSS, etc. this, or whatever is specified as DOCS_RESOURCES_DIR in makefile will always be preserved ] markdown_headings_increment = 2 warnings_ignore = [ "Error parsing type annotation FilterBy for maze_dataset", "Found 'coord_str_to_tuple' in maze_dataset.tokenization.__all__, but it does not resolve: Error importing maze_dataset.tokenization.coord_str_to_tuple", ] [tool.makefile.docs.notebooks] enabled = true source_path = "notebooks" output_path_relative = "notebooks" [tool.makefile.docs.notebooks.descriptions] "demo_dataset" = "Creating and filtering a dataset, and various output formats" "demo_generator" = "Exploring different maze generation algorithms and parameters" "demo_latticemaze" = "Working with LatticeMaze class, visualization and solving mazes" "demo_mazetokenizermodular" = "Using the modern MazeTokenizerModular system for tokenization" "demo_tokenization" = "Legacy tokenization with MazeTokenizer and TokenizationMode" "estimate_dataset_fractions" = "Estimating and predicting maze generation success rates" "forking_points" = "Identifying and working with decision points in maze solutions" "iterated_backfilling" = "Implementation of iterated backfilling as an algorithm for solving visual mazes" "profile_dataset_save_read" = "Profiling and optimizing dataset serialization performance" # Custom export configurations # affects `make dep` and related commands [tool.makefile.uv-exports] args = [ "--no-hashes" ] exports = [ # no groups, no extras, just the base dependencies { name = "base", groups = false, extras = false }, # all groups { name = "groups", groups = true, extras = false }, # only the lint group -- custom options for this { name = "lint", options = ["--only-group", "lint"] }, # # all groups and extras { name = "all", filename="requirements.txt", groups = true, extras=true }, # # all groups and extras, a different way { name = "all", groups = true, options = ["--all-extras"] }, ] # configures `make todo` [tool.makefile.inline-todo] search_dir = "." out_file_base = "docs/other/todo-inline" context_lines = 2 extensions = ["py", "md"] tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC", "DOCS", "TYPING"] exclude = [ "docs/**", ".venv/**", "scripts/get_todos.py", ] # branch to put in the url branch = "main" # Mapping of tags to GitHub issue labels [tool.makefile.inline-todo.tag_label_map] "BUG" = "bug" "TODO" = "enhancement" "DOC" = "documentation" ``````{ end_of_file="pyproject.toml" }