Coverage for maze_dataset/generation/registration.py: 93%
43 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 22:20 -0700
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-03 22:20 -0700
1"""Registration system for custom maze generators"""
3import inspect
4from typing import TypeVar, Union, get_args, get_origin
6from maze_dataset.constants import Coord, CoordTup
7from maze_dataset.generation.generators import (
8 GENERATORS_MAP,
9 LatticeMazeGenerators,
10 MazeGeneratorFunc,
11)
12from maze_dataset.maze import LatticeMaze
# Type variable bound to `MazeGeneratorFunc`: `register_maze_generator` is typed
# as taking and returning `F_MazeGeneratorFunc`, so the decorated function keeps
# its own precise type instead of being widened to `MazeGeneratorFunc`.
F_MazeGeneratorFunc = TypeVar("F_MazeGeneratorFunc", bound=MazeGeneratorFunc)
class MazeGeneratorRegistrationError(TypeError):
    """Raised when a function cannot be registered as a maze generator.

    Signature validation raises this when the function's parameters or type
    annotations do not match what a maze generator must look like. It derives
    from `TypeError`, so callers catching `TypeError` will also catch it.
    """
def _check_grid_shape_annotation(
    annotation: Union[type, str],
    func_name: str,
) -> None:
    """Check if the annotation for grid_shape is valid

    The annotation is accepted if any of the following holds:
    - it is `Coord` or `CoordTup`
    - it is spelled like `tuple[int, int]`, ignoring case, whitespace and a
      leading `typing.` (so `typing.Tuple[int, int]` and the string
      `"tuple[int,int]"` are accepted too)
    - its string form mentions `ndarray` or `Int8`

    # Parameters:
    - `annotation : Union[type, str]`
        the type annotation of the grid_shape parameter (a single type, not a union --
        callers split unions and check each member separately)
    - `func_name : str`
        name of the generator function, used only in the error message

    # Raises:
    - `MazeGeneratorRegistrationError` if the annotation is not compatible
    """
    annotation_str: str = str(annotation)

    # normalized spelling for the tuple check: `str(typing.Tuple[int, int])` is
    # "typing.Tuple[int, int]", and string annotations may use arbitrary spacing
    tuple_form: str = annotation_str.lower().replace(" ", "")
    if tuple_form.startswith("typing."):
        tuple_form = tuple_form[len("typing.") :]

    # cheap string checks come first; `or` short-circuits
    if not (
        tuple_form == "tuple[int,int]"
        # TODO: these are pretty loose checks, would be better to make this more robust
        or "ndarray" in annotation_str
        or "Int8" in annotation_str
        or annotation == Coord
        or annotation == CoordTup
    ):
        err_msg = (
            f"Maze generator function '{func_name}' first parameter 'grid_shape' "
            f"must be typed as 'Coord | CoordTup' or compatible type, "
            f"got {annotation = }, {annotation_str = }"
        )
        raise MazeGeneratorRegistrationError(err_msg)
def _signature_resolving_annotations(func) -> inspect.Signature:
    """Return `inspect.signature(func)`, with string annotations evaluated when possible.

    String annotations arise from `from __future__ import annotations` or from
    explicitly quoted annotations. If they cannot be evaluated (unsupported
    `eval_str` on Python < 3.10, or names not available at runtime), the raw,
    unevaluated signature is returned instead.
    """
    try:
        return inspect.signature(func, eval_str=True)
    except (TypeError, NameError, AttributeError, SyntaxError):
        # fall back to the unevaluated signature; genuine errors from
        # `inspect.signature` itself (e.g. non-callable) are re-raised below
        pass
    return inspect.signature(func)


def validate_MazeGeneratorFunc(
    func: F_MazeGeneratorFunc,
) -> None:
    """validate the signature of a maze generator function

    return `None` if valid, otherwise raises `MazeGeneratorRegistrationError`
    (which is a subclass of `TypeError`)

    Checks, in order:
    1. the first parameter is named `grid_shape`
    2. `grid_shape` has a type annotation compatible with `Coord | CoordTup`;
       for a `Union[...]` or `X | Y` annotation every member is checked
    3. the return annotation is exactly `LatticeMaze`

    String annotations are evaluated against the function's globals when that
    succeeds, so functions defined under `from __future__ import annotations`
    are validated the same way as eagerly annotated ones.

    # Parameters:
    - `func : MazeGeneratorFunc`
        function to validate

    # Raises:
    - `MazeGeneratorRegistrationError` : type error describing the issue with the function signature
    """
    import types  # only needed to recognise PEP 604 `X | Y` unions

    func_name: str = func.__name__
    sig: inspect.Signature = _signature_resolving_annotations(func)
    params: list[str] = list(sig.parameters.keys())

    if not params or params[0] != "grid_shape":
        err_msg = (
            f"Maze generator function '{func_name}' must have 'grid_shape' "
            "as its first parameter. Please ensure the function signature starts with 'grid_shape: Coord | CoordTup'. "
            f"{params = }"
        )
        raise MazeGeneratorRegistrationError(err_msg)

    # Check first parameter type annotation -- it must be present
    first_param_annotation = sig.parameters["grid_shape"].annotation
    if first_param_annotation is inspect.Parameter.empty:
        err_msg = (
            f"Maze generator function '{func_name}' must have a type annotation for 'grid_shape'. "
            "Please add `grid_shape: Coord | CoordTup` to the function signature."
        )
        raise MazeGeneratorRegistrationError(err_msg)

    # `typing.Union[A, B]` and (Python 3.10+) `A | B` are both unions
    union_origins: tuple = (Union, type(Union[int, str]))
    if hasattr(types, "UnionType"):
        union_origins += (types.UnionType,)

    if get_origin(first_param_annotation) in union_origins:
        # Check all of the args look like Coord or CoordTup
        for arg in get_args(first_param_annotation):
            _check_grid_shape_annotation(arg, func_name)
    else:
        # Check if the annotation is a single type
        _check_grid_shape_annotation(first_param_annotation, func_name)

    # Check return type annotation - must be present and correct
    if sig.return_annotation is inspect.Signature.empty:
        err_msg = (
            f"Maze generator function '{func_name}' must have a return type annotation "
            "of LatticeMaze. Please add `-> LatticeMaze` to the function signature."
        )
        raise MazeGeneratorRegistrationError(err_msg)

    if sig.return_annotation != LatticeMaze:
        err_msg = (
            f"Maze generator function '{func_name}' must return LatticeMaze, "
            f"got return type annotation: {sig.return_annotation}. "
            "Please ensure the function returns a valid LatticeMaze instance."
        )
        raise MazeGeneratorRegistrationError(err_msg)
def register_maze_generator(func: F_MazeGeneratorFunc) -> F_MazeGeneratorFunc:
    """Decorator to register a custom maze generator function.

    Lets users plug in their own maze generation functions without editing the
    core library. After validation, the decorated function is:
    1. stored in `GENERATORS_MAP` under its `__name__`
    2. attached to `LatticeMazeGenerators` as a static method of the same name
       (for compatibility)

    # NOTE:
    Prefer not to use this decorator! Instead, add your function directly to the
    GENERATORS_MAP dictionary and to `LatticeMazeGenerators` as a static method.
    If you add a new function, please make a pull request!
    https://github.com/understanding-search/maze-dataset/pulls

    # Usage:
    ```python
    import numpy as np

    @register_maze_generator
    def gen_my_custom(
        grid_shape: Coord | CoordTup,
        **kwargs,  # this can be anything you like
    ) -> LatticeMaze:
        # Your custom maze generation logic here
        connection_list = ...  # Create your maze structure

        # Important: If your maze is not fully connected, you must include
        # `visited_cells` in `generation_meta`, or mark it as `fully_connected=True`
        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_my_custom",
                grid_shape=np.array(grid_shape),
                fully_connected=True,  # or provide visited_cells
            ),
        )
    ```

    To create **or load** a maze dataset that uses a custom generator, the
    module that registers the generator must have been imported first.

    # Returns:
    The decorated function, unchanged.

    # Raises:
    - `MazeGeneratorRegistrationError` if the signature is invalid
    - `ValueError` if a generator with the same name is already registered
    """
    # reject functions whose signature does not look like a maze generator
    validate_MazeGeneratorFunc(func)

    name: str = func.__name__

    # never silently replace an already-registered generator
    if name in GENERATORS_MAP:
        raise ValueError(
            f"Generator with name '{name}' already exists in GENERATORS_MAP. "
            "Please choose a different name."
        )

    GENERATORS_MAP[name] = func
    # also reachable as `LatticeMazeGenerators.<name>`, for compatibility
    setattr(LatticeMazeGenerators, name, staticmethod(func))

    return func