Coverage for maze_dataset/generation/registration.py: 93%

43 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-03 22:20 -0700

1"""Registration system for custom maze generators""" 

2 

3import inspect 

4from typing import TypeVar, Union, get_args, get_origin 

5 

6from maze_dataset.constants import Coord, CoordTup 

7from maze_dataset.generation.generators import ( 

8 GENERATORS_MAP, 

9 LatticeMazeGenerators, 

10 MazeGeneratorFunc, 

11) 

12from maze_dataset.maze import LatticeMaze 

13 

14F_MazeGeneratorFunc = TypeVar("F_MazeGeneratorFunc", bound=MazeGeneratorFunc) 

15 

16 

class MazeGeneratorRegistrationError(TypeError):
    """Raised when a maze generator function cannot be registered.

    This is a subclass of `TypeError`, so code that catches `TypeError`
    also catches this exception.
    """

21 

22 

def _check_grid_shape_annotation(
    annotation: Union[type, str],
    func_name: str,
) -> None:
    """Verify that a `grid_shape` annotation looks like a coordinate type

    The annotation is accepted when it compares equal to `Coord` or `CoordTup`,
    when its string form is `tuple[int, int]` (compared case-insensitively),
    or when its string form contains `ndarray` or `Int8`.

    # Parameters:
     - `annotation : Union[type, str]`
        type annotation of the `grid_shape` parameter
     - `func_name : str`
        name of the function being validated, used only in the error message

    # Raises:
     - `MazeGeneratorRegistrationError` : if none of the accepted forms match
    """
    # NOTE: this local must keep its name -- the f-string `=` specifier below
    # echoes the variable name into the error message
    annotation_str: str = str(annotation)

    # all checks are evaluated up front (no short-circuiting between them)
    accepted_forms: tuple = (
        annotation == Coord,
        annotation == CoordTup,
        annotation_str.lower() == "tuple[int, int]",
        # TODO: the substring checks below are loose, a stricter comparison would be more robust
        "ndarray" in annotation_str,
        "Int8" in annotation_str,
    )
    if any(accepted_forms):
        return

    raise MazeGeneratorRegistrationError(
        f"Maze generator function '{func_name}' first parameter 'grid_shape' "
        f"must be typed as 'Coord | CoordTup' or compatible type, "
        f"got {annotation = }, {annotation_str = }"
    )

53 

54 

def validate_MazeGeneratorFunc(
    func: F_MazeGeneratorFunc,
) -> None:
    """validate the signature of a maze generator function

    return `None` if valid, otherwise raises `MazeGeneratorRegistrationError`
    (which is a subclass of `TypeError`)

    The checks, in order:
     1. the first parameter is named `grid_shape`
     2. `grid_shape` has a type annotation
     3. that annotation (or, for a union, every member of it) passes
        `_check_grid_shape_annotation`
     4. the return annotation is present and equal to `LatticeMaze`

    # Parameters:
     - `func : MazeGeneratorFunc`
        function to validate

    # Raises:
     - `MazeGeneratorRegistrationError` : type error describing the issue with the function signature
    """
    # local import: only needed to recognise PEP 604 unions (`X | Y`) below
    import types

    func_name: str = func.__name__
    sig: inspect.Signature = inspect.signature(func)
    # NOTE: `params` must keep its name -- the f-string `=` specifier echoes it
    params: list[str] = list(sig.parameters.keys())

    if not params or params[0] != "grid_shape":
        err_msg = (
            f"Maze generator function '{func_name}' must have 'grid_shape' "
            "as its first parameter. Please ensure the function signature starts with 'grid_shape: Coord | CoordTup'. "
            f"{params = }"
        )
        raise MazeGeneratorRegistrationError(err_msg)

    # The first parameter must carry a type annotation
    first_param_annotation = sig.parameters["grid_shape"].annotation
    if first_param_annotation is inspect.Parameter.empty:
        err_msg = (
            f"Maze generator function '{func_name}' must have a type annotation for 'grid_shape'. "
            "Please add `grid_shape: Coord | CoordTup` to the function signature."
        )
        # previously this message was built but never raised, so the caller got
        # a confusing complaint about `inspect._empty` from the check below
        raise MazeGeneratorRegistrationError(err_msg)

    # Origins that mark a union annotation: `typing.Union[...]` reports `Union`,
    # while `X | Y` reports `types.UnionType` (absent before Python 3.10, hence
    # the `getattr` fallback).
    union_origins: tuple = (
        Union,
        type(Union[int, str]),
        getattr(types, "UnionType", Union),
    )
    if get_origin(first_param_annotation) in union_origins:
        # every member of the union must look like Coord or CoordTup
        for arg in get_args(first_param_annotation):
            _check_grid_shape_annotation(arg, func_name)
    else:
        # a single (non-union) annotation
        _check_grid_shape_annotation(first_param_annotation, func_name)

    # Check return type annotation - must be present and correct
    if sig.return_annotation is inspect.Signature.empty:
        err_msg = (
            f"Maze generator function '{func_name}' must have a return type annotation "
            "of LatticeMaze. Please add `-> LatticeMaze` to the function signature."
        )
        raise MazeGeneratorRegistrationError(err_msg)

    if sig.return_annotation != LatticeMaze:
        err_msg = (
            f"Maze generator function '{func_name}' must return LatticeMaze, "
            f"got return type annotation: {sig.return_annotation}. "
            "Please ensure the function returns a valid LatticeMaze instance."
        )
        raise MazeGeneratorRegistrationError(err_msg)

114 

115 

def register_maze_generator(func: F_MazeGeneratorFunc) -> F_MazeGeneratorFunc:
    """Decorator that registers a user-defined maze generator function.

    After the signature passes `validate_MazeGeneratorFunc`, the function is:
    1. stored in `GENERATORS_MAP` under its `__name__`
    2. attached to `LatticeMazeGenerators` as a static method (for compatibility)

    # NOTE:
    In general, you should avoid using this decorator! Prefer adding your function
    to the `GENERATORS_MAP` dictionary and to `LatticeMazeGenerators` (as a static
    method) directly. If you write a new generator, please open a pull request:
    https://github.com/understanding-search/maze-dataset/pulls

    # Usage:
    ```python
    @register_maze_generator
    def gen_my_custom(
        grid_shape: Coord | CoordTup,
        **kwargs,  # this can be anything you like
    ) -> LatticeMaze:
        # Your custom maze generation logic here
        connection_list = ...  # Create your maze structure

        # Important: `generation_meta` should either set `fully_connected=True`
        # or, for a maze that is not fully connected, include `visited_cells`
        return LatticeMaze(
            connection_list=connection_list,
            generation_meta=dict(
                func_name="gen_my_custom",
                grid_shape=np.array(grid_shape),
                fully_connected=True,  # or provide visited_cells
            ),
        )
    ```

    To create **or load** a maze dataset that uses your custom generator, the
    module in which the generator is registered must be imported first.

    # Parameters:
     - `func : MazeGeneratorFunc`
        the generator function to register

    # Returns:
     - the same function object that was passed in, unchanged

    # Raises:
     - `MazeGeneratorRegistrationError` : if the signature fails validation
     - `ValueError` : if `GENERATORS_MAP` already has an entry with this name
    """
    # validation runs first, so nothing is mutated for an invalid function
    validate_MazeGeneratorFunc(func)

    generator_name: str = func.__name__

    # refuse to overwrite an existing entry
    if generator_name in GENERATORS_MAP:
        raise ValueError(
            f"Generator with name '{generator_name}' already exists in GENERATORS_MAP. "
            "Please choose a different name."
        )

    # record in the map, then mirror onto the generators class
    GENERATORS_MAP[generator_name] = func
    setattr(LatticeMazeGenerators, generator_name, staticmethod(func))

    return func

175 

176 return func