Coverage for maze_dataset/dataset/filters.py: 45%

111 statements  


1"filtering `MazeDataset`s" 

2 

3import copy 

4import functools 

5import typing 

6from collections import Counter, defaultdict 

7 

8import numpy as np 

9 

10from maze_dataset.constants import CoordTup 

11from maze_dataset.dataset.dataset import ( 

12 DatasetFilterFunc, 

13 register_dataset_filter, 

14 register_filter_namespace_for_dataset, 

15) 

16from maze_dataset.dataset.maze_dataset import MazeDataset 

17from maze_dataset.maze import SolvedMaze 

18 

19 

def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
    """register a maze filter, casting it to operate over the whole list of mazes

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate when operating over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            ),
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
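
# Illustrative sketch (not part of the original file): `register_maze_filter`
# turns a per-maze predicate into a dataset-level filter. A hypothetical
# predicate `is_short`, defined inside a namespace class such as
# `MazeDatasetFilters` below, would be registered and applied like this;
# the wrapped filter then takes the dataset as its first argument:
#
#     @register_maze_filter
#     @staticmethod
#     def is_short(maze: SolvedMaze, max_length: int) -> bool:
#         return len(maze.solution) <= max_length
#
#     # short_dataset = MazeDatasetFilters.is_short(dataset, max_length=10)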

@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
    "namespace for filters for `MazeDataset`s"

    @register_maze_filter
    @staticmethod
    def path_length(maze: SolvedMaze, min_length: int) -> bool:
        """filter out mazes with a solution length less than `min_length`"""
        return len(maze.solution) >= min_length

    @register_maze_filter
    @staticmethod
    def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
        """filter out mazes where the start and end pos are less than `min_distance` apart in manhattan distance (ignoring walls)"""
        # the ord-1 norm of the difference vector is the manhattan distance
        return bool(np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance)
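
    # Illustrative sketch (not from the original file): filters in this
    # namespace are applied through the dataset's filter accessor (assuming
    # the `filter_by` accessor set up by
    # `register_filter_namespace_for_dataset`):
    #
    #     dataset = dataset.filter_by.path_length(min_length=3)
    #     dataset = dataset.filter_by.start_end_distance(min_distance=2)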

    @register_dataset_filter
    @staticmethod
    def cut_percentile_shortest(
        dataset: MazeDataset,
        percentile: float = 10.0,
    ) -> MazeDataset:
        """cut the shortest `percentile` of mazes from the dataset

        `percentile` is on a 0-100 scale, not 0-1, as this is what `np.percentile` expects
        """
        lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
        cutoff: int = int(np.percentile(lengths, percentile))

        filtered_mazes: list[SolvedMaze] = [
            m for m in dataset if len(m.solution) > cutoff
        ]
        new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)

        return copy.deepcopy(new_dataset)
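
    # Worked example (illustrative, not from the original file): with solution
    # lengths [2, 3, 5, 8, 13] and `percentile=10.0`, `np.percentile` returns
    # 2.4 (linear interpolation), so `cutoff = int(2.4) = 2` and only mazes
    # with solution length strictly greater than 2 survive: [3, 5, 8, 13].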

    @register_dataset_filter
    @staticmethod
    def truncate_count(
        dataset: MazeDataset,
        max_count: int,
    ) -> MazeDataset:
        """truncate the dataset to at most `max_count` mazes"""
        new_dataset: MazeDataset = MazeDataset(
            cfg=dataset.cfg,
            mazes=dataset.mazes[:max_count],
        )
        return copy.deepcopy(new_dataset)

    @register_dataset_filter
    @staticmethod
    def remove_duplicates(
        dataset: MazeDataset,
        minimum_difference_connection_list: int | None = 1,
        minimum_difference_solution: int | None = 1,
        _max_dataset_len_threshold: int = 1000,
    ) -> MazeDataset:
        """remove duplicates from a dataset, keeping the **LAST** unique maze

        set either minimum difference to `None` to disable that check

        if you want to avoid mazes which have more overlap, set the minimum difference to be greater

        Gotchas:
        - if two mazes are of different sizes, they will never be considered duplicates
        - if two solutions are of different lengths, they will never be considered duplicates

        TODO: check for overlap?
        """
        if len(dataset) > _max_dataset_len_threshold:
            raise ValueError(
                "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n"
                "if you know what you're doing, change `_max_dataset_len_threshold`",
            )

        unique_mazes: list[SolvedMaze] = list()

        maze_a: SolvedMaze
        maze_b: SolvedMaze
        for i, maze_a in enumerate(dataset.mazes):
            a_unique: bool = True
            for maze_b in dataset.mazes[i + 1 :]:
                # after all that nesting, more nesting to perform checks
                if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
                    maze_a.connection_list.shape == maze_b.connection_list.shape
                ):
                    if (
                        np.sum(maze_a.connection_list != maze_b.connection_list)
                        <= minimum_difference_connection_list
                    ):
                        a_unique = False
                        break

                if (minimum_difference_solution is not None) and (  # noqa: SIM102
                    maze_a.solution.shape == maze_b.solution.shape
                ):
                    if (
                        np.sum(maze_a.solution != maze_b.solution)
                        <= minimum_difference_solution
                    ):
                        a_unique = False
                        break

            if a_unique:
                unique_mazes.append(maze_a)

        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )
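
    # Illustrative example (not from the original file): each maze is compared
    # only against *later* mazes, so the earlier member of a duplicate pair is
    # dropped and the LAST occurrence survives. With
    # `minimum_difference_solution=1`, two same-shape solutions differing in at
    # most one array entry also count as duplicates:
    #
    #     # mazes = [A, A2, B], where A and A2 have identical connection lists
    #     # dataset.filter_by.remove_duplicates() keeps [A2, B]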

    @register_dataset_filter
    @staticmethod
    def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
        """remove exact duplicates from a dataset, keeping the first occurrence of each maze"""
        # `dict.fromkeys` deduplicates while preserving insertion order,
        # relying on `SolvedMaze` being hashable
        unique_mazes = list(dict.fromkeys(dataset.mazes))
        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )

    @register_dataset_filter
    @staticmethod
    def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
        """strip the generation meta from the dataset"""
        new_dataset: MazeDataset = copy.deepcopy(dataset)
        for maze in new_dataset:
            # hacky because it's a frozen dataclass
            maze.__dict__["generation_meta"] = None
        return new_dataset

    @register_dataset_filter
    @staticmethod
    # yes, this function is complicated, hence the noqa
    def collect_generation_meta(  # noqa: C901, PLR0912
        dataset: MazeDataset,
        clear_in_mazes: bool = True,
        inplace: bool = True,
        allow_fail: bool = False,
    ) -> MazeDataset:
        """collect the generation metadata from each maze into dataset-level metadata (saves space)

        # Parameters:
        - `dataset : MazeDataset`
        - `clear_in_mazes : bool`
            whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
            (defaults to `True`)
        - `inplace : bool`
            whether to modify the dataset in place or return a new one
            (defaults to `True`)
        - `allow_fail : bool`
            whether to allow the collection to fail if the generation meta is not present in a maze
            (defaults to `False`)

        # Returns:
        - `MazeDataset`
            the dataset with the generation metadata collected

        # Raises:
        - `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
        - `ValueError` : if we have other problems converting the generation metadata
        - `TypeError` : if the generation meta on a maze is of an unexpected type
        """
        # if the generation meta is already collected, don't collect it again, do nothing
        if dataset.generation_metadata_collected is not None:
            return dataset
        else:
            assert dataset[0].generation_meta is not None, (
                "generation meta is not collected and original is not present"
            )

        new_dataset: MazeDataset
        if inplace:
            new_dataset = dataset
        else:
            new_dataset = copy.deepcopy(dataset)

        gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
            defaultdict(Counter)
        )
        for maze in new_dataset:
            if maze.generation_meta is None:
                if allow_fail:
                    # stop collecting entirely if a maze is missing its generation meta
                    break
                raise ValueError(
                    "generation meta is not present in a maze, cannot collect generation meta",
                )
            for key, value in maze.generation_meta.items():
                if isinstance(value, (bool, int, float, str)):  # noqa: UP038
                    gen_meta_lists[key][value] += 1

                elif isinstance(value, set):
                    # special case for visited_cells
                    gen_meta_lists[key].update(value)

                elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
                    if isinstance(value, list):
                        # TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
                        try:
                            value = np.array(value)  # noqa: PLW2901
                        except ValueError as convert_to_np_err:
                            err_msg = (
                                f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
                                "\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                            )
                            raise ValueError(err_msg) from convert_to_np_err

                    if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
                        # assume it's a single coordinate
                        gen_meta_lists[key][tuple(value)] += 1
                    # magic value is fine here
                    elif (len(value.shape) == 2) and (  # noqa: PLR2004
                        value.shape[1] == maze.lattice_dim
                    ):
                        # assume it's a list of coordinates
                        gen_meta_lists[key].update([tuple(v) for v in value])
                    else:
                        err_msg = (
                            f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
                            "expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
                        )
                        raise ValueError(err_msg)
                else:
                    err_msg = (
                        f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
                        "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                    )
                    raise TypeError(err_msg)

            # clear the data
            if clear_in_mazes:
                # hacky because it's a frozen dataclass
                maze.__dict__["generation_meta"] = None

        new_dataset.generation_metadata_collected = {
            key: dict(value) for key, value in gen_meta_lists.items()
        }

        return new_dataset
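
# Illustrative sketch (not from the original file; key names are hypothetical):
# `collect_generation_meta` tallies each maze's `generation_meta` dict into
# dataset-level counters, then (by default) clears the per-maze copies:
#
#     dataset = MazeDatasetFilters.collect_generation_meta(dataset)
#     # dataset.generation_metadata_collected might look like:
#     # {"func_name": {"gen_dfs": 100}, "visited_cells": {(0, 0): 100, ...}}
#     # and each maze.generation_meta is now None (clear_in_mazes=True)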