# maze_dataset/dataset/filters.py
1"filtering `MazeDataset`s"
3import copy
4import functools
5import typing
6from collections import Counter, defaultdict
8import numpy as np
10from maze_dataset.constants import CoordTup
11from maze_dataset.dataset.dataset import (
12 DatasetFilterFunc,
13 register_dataset_filter,
14 register_filter_namespace_for_dataset,
15)
16from maze_dataset.dataset.maze_dataset import MazeDataset
17from maze_dataset.maze import SolvedMaze


def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
    """register a maze filter, casting it to operate over the whole list of mazes

    `method` should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate when operating over the whole list of mazes
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            ),
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
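
# A minimal sketch of defining a per-maze filter with `register_maze_filter`,
# following the namespace-class pattern used below; `has_even_length` is a
# hypothetical example, not part of the library:
#
#     @register_maze_filter
#     @staticmethod
#     def has_even_length(maze: SolvedMaze) -> bool:
#         return len(maze.solution) % 2 == 0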

@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
    "namespace for filters for `MazeDataset`s"

    @register_maze_filter
    @staticmethod
    def path_length(maze: SolvedMaze, min_length: int) -> bool:
        """filter out mazes with a solution length less than `min_length`"""
        return len(maze.solution) >= min_length
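
    # Usage sketch, assuming the `filter_by` property that
    # `register_filter_namespace_for_dataset` attaches to `MazeDataset`:
    #
    #     filtered: MazeDataset = dataset.filter_by.path_length(min_length=5)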

    @register_maze_filter
    @staticmethod
    def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
        """filter out mazes where the start and end pos are less than `min_distance` apart in Manhattan distance (ignoring walls)"""
        return bool(
            (np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all()
        )
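
    # For example, start (0, 0) and end (2, 3) are 5 apart in Manhattan distance,
    # so a filter with `min_distance=6` would drop that maze:
    #
    #     np.linalg.norm(np.array([0, 0]) - np.array([2, 3]), 1)  # -> 5.0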

    @register_dataset_filter
    @staticmethod
    def cut_percentile_shortest(
        dataset: MazeDataset,
        percentile: float = 10.0,
    ) -> MazeDataset:
        """cut the shortest `percentile` of mazes from the dataset

        `percentile` is on a 0-100 scale, not 0-1, as this is what `np.percentile` expects
        """
        lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
        cutoff: int = int(np.percentile(lengths, percentile))

        filtered_mazes: list[SolvedMaze] = [
            m for m in dataset if len(m.solution) > cutoff
        ]
        new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)

        return copy.deepcopy(new_dataset)
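
    # Usage sketch: drop roughly the shortest 10% of mazes; note the strict
    # `> cutoff` comparison, so mazes exactly at the cutoff length are also cut:
    #
    #     trimmed: MazeDataset = dataset.filter_by.cut_percentile_shortest(percentile=10.0)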

    @register_dataset_filter
    @staticmethod
    def truncate_count(
        dataset: MazeDataset,
        max_count: int,
    ) -> MazeDataset:
        """truncate the dataset to be at most `max_count` mazes"""
        new_dataset: MazeDataset = MazeDataset(
            cfg=dataset.cfg,
            mazes=dataset.mazes[:max_count],
        )
        return copy.deepcopy(new_dataset)
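
    # Usage sketch: keep at most the first 100 mazes:
    #
    #     small: MazeDataset = dataset.filter_by.truncate_count(max_count=100)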

    @register_dataset_filter
    @staticmethod
    def remove_duplicates(
        dataset: MazeDataset,
        minimum_difference_connection_list: int | None = 1,
        minimum_difference_solution: int | None = 1,
        _max_dataset_len_threshold: int = 1000,
    ) -> MazeDataset:
        """remove duplicates from a dataset, keeping the **LAST** unique maze

        set either minimum difference to `None` to disable that check

        if you want to filter out mazes which have more overlap, set the minimum difference to be greater

        Gotchas:
        - if two mazes are of different sizes, they will never be considered duplicates
        - if two solutions are of different lengths, they will never be considered duplicates

        TODO: check for overlap?
        """
        if len(dataset) > _max_dataset_len_threshold:
            raise ValueError(
                "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n"
                "if you know what you're doing, change `_max_dataset_len_threshold`",
            )

        unique_mazes: list[SolvedMaze] = list()

        maze_a: SolvedMaze
        maze_b: SolvedMaze
        for i, maze_a in enumerate(dataset.mazes):
            a_unique: bool = True
            for maze_b in dataset.mazes[i + 1 :]:
                # after all that nesting, more nesting to perform checks
                if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
                    maze_a.connection_list.shape == maze_b.connection_list.shape
                ):
                    if (
                        np.sum(maze_a.connection_list != maze_b.connection_list)
                        <= minimum_difference_connection_list
                    ):
                        a_unique = False
                        break

                if (minimum_difference_solution is not None) and (  # noqa: SIM102
                    maze_a.solution.shape == maze_b.solution.shape
                ):
                    if (
                        np.sum(maze_a.solution != maze_b.solution)
                        <= minimum_difference_solution
                    ):
                        a_unique = False
                        break

            if a_unique:
                unique_mazes.append(maze_a)

        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )
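
    # Usage sketch: treat mazes as duplicates when their connection lists differ
    # in at most one entry, and skip the solution comparison entirely:
    #
    #     deduped: MazeDataset = dataset.filter_by.remove_duplicates(
    #         minimum_difference_connection_list=1,
    #         minimum_difference_solution=None,
    #     )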

    @register_dataset_filter
    @staticmethod
    def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
        """remove exact duplicates from a dataset, keeping the first occurrence of each maze"""
        unique_mazes = list(dict.fromkeys(dataset.mazes))
        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )
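
    # `dict.fromkeys` deduplicates while preserving insertion order (this relies
    # on `SolvedMaze` being hashable), e.g.:
    #
    #     list(dict.fromkeys([3, 1, 3, 2, 1]))  # -> [3, 1, 2]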

    @register_dataset_filter
    @staticmethod
    def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
        """strip the generation meta from the dataset"""
        new_dataset: MazeDataset = copy.deepcopy(dataset)
        for maze in new_dataset:
            # hacky because it's a frozen dataclass
            maze.__dict__["generation_meta"] = None
        return new_dataset

    @register_dataset_filter
    @staticmethod
    # yes, this function is complicated hence the noqa
    def collect_generation_meta(  # noqa: C901, PLR0912
        dataset: MazeDataset,
        clear_in_mazes: bool = True,
        inplace: bool = True,
        allow_fail: bool = False,
    ) -> MazeDataset:
        """collect the generation metadata from each maze into dataset-level metadata (saves space)

        # Parameters:
        - `dataset : MazeDataset`
        - `clear_in_mazes : bool`
            whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
            (defaults to `True`)
        - `inplace : bool`
            whether to modify the dataset in place or return a new one
            (defaults to `True`)
        - `allow_fail : bool`
            whether to allow the collection to fail if the generation meta is not present in a maze
            (defaults to `False`)

        # Returns:
        - `MazeDataset`
            the dataset with the generation metadata collected

        # Raises:
        - `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
        - `ValueError` : if we have other problems converting the generation metadata
        - `TypeError` : if the generation meta on a maze is of an unexpected type
        """
        # if the generation meta is already collected, don't collect it again, do nothing
        if dataset.generation_metadata_collected is not None:
            return dataset
        else:
            assert dataset[0].generation_meta is not None, (
                "generation meta is not collected and original is not present"
            )

        new_dataset: MazeDataset
        if inplace:
            new_dataset = dataset
        else:
            new_dataset = copy.deepcopy(dataset)

        gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
            defaultdict(Counter)
        )
        for maze in new_dataset:
            if maze.generation_meta is None:
                if allow_fail:
                    break
                raise ValueError(
                    "generation meta is not present in a maze, cannot collect generation meta",
                )
            for key, value in maze.generation_meta.items():
                if isinstance(value, (bool, int, float, str)):  # noqa: UP038
                    gen_meta_lists[key][value] += 1

                elif isinstance(value, set):
                    # special case for visited_cells
                    gen_meta_lists[key].update(value)

                elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
                    if isinstance(value, list):
                        # TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
                        try:
                            value = np.array(value)  # noqa: PLW2901
                        except ValueError as convert_to_np_err:
                            err_msg = (
                                f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
                                "\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                            )
                            raise ValueError(err_msg) from convert_to_np_err

                    if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
                        # assume it's a single coordinate
                        gen_meta_lists[key][tuple(value)] += 1
                    # magic value is fine here
                    elif (len(value.shape) == 2) and (  # noqa: PLR2004
                        value.shape[1] == maze.lattice_dim
                    ):
                        # assume it's a list of coordinates
                        gen_meta_lists[key].update([tuple(v) for v in value])
                    else:
                        err_msg = (
                            f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
                            "expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
                        )
                        raise ValueError(err_msg)
                else:
                    err_msg = (
                        f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
                        "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                    )
                    raise TypeError(err_msg)

            # clear the data
            if clear_in_mazes:
                # hacky because it's a frozen dataclass
                maze.__dict__["generation_meta"] = None

        new_dataset.generation_metadata_collected = {
            key: dict(value) for key, value in gen_meta_lists.items()
        }

        return new_dataset
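
    # Usage sketch: collapse per-maze generation metadata into dataset-level
    # counters, clearing it from the individual mazes to save space
    # (the key and value names below are illustrative, not guaranteed):
    #
    #     dataset = dataset.filter_by.collect_generation_meta(clear_in_mazes=True)
    #     dataset.generation_metadata_collected  # e.g. {"func_name": {"gen_dfs": 100}, ...}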