maze_dataset.dataset.filters
filtering `MazeDataset`s
1"filtering `MazeDataset`s" 2 3import copy 4import functools 5import typing 6from collections import Counter, defaultdict 7 8import numpy as np 9 10from maze_dataset.constants import CoordTup 11from maze_dataset.dataset.dataset import ( 12 DatasetFilterFunc, 13 register_dataset_filter, 14 register_filter_namespace_for_dataset, 15) 16from maze_dataset.dataset.maze_dataset import MazeDataset 17from maze_dataset.maze import SolvedMaze 18 19 20def register_maze_filter( 21 method: typing.Callable[[SolvedMaze, typing.Any], bool], 22) -> DatasetFilterFunc: 23 """register a maze filter, casting it to operate over the whole list of mazes 24 25 method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset` 26 27 this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays 28 """ 29 30 @functools.wraps(method) 31 def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset: 32 # copy and filter 33 new_dataset: MazeDataset = copy.deepcopy( 34 MazeDataset( 35 cfg=dataset.cfg, 36 mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)], 37 ), 38 ) 39 # update the config 40 new_dataset.cfg.applied_filters.append( 41 dict(name=method.__name__, args=args, kwargs=kwargs), 42 ) 43 new_dataset.update_self_config() 44 return new_dataset 45 46 return wrapper 47 48 49@register_filter_namespace_for_dataset(MazeDataset) 50class MazeDatasetFilters: 51 "namespace for filters for `MazeDataset`s" 52 53 @register_maze_filter 54 @staticmethod 55 def path_length(maze: SolvedMaze, min_length: int) -> bool: 56 """filter out mazes with a solution length less than `min_length`""" 57 return len(maze.solution) >= min_length 58 59 @register_maze_filter 60 @staticmethod 61 def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool: 62 """filter out datasets where the start and end pos are less than `min_distance` apart on the manhattan distance (ignoring walls)""" 63 return bool( 64 (np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all() 65 ) 66 67 @register_dataset_filter 68 @staticmethod 69 def cut_percentile_shortest( 70 dataset: MazeDataset, 71 percentile: float = 10.0, 72 ) -> MazeDataset: 73 """cut the shortest `percentile` of mazes from the dataset 74 75 `percentile` is 1-100, not 0-1, as this is what `np.percentile` expects 76 """ 77 lengths: np.ndarray = np.array([len(m.solution) for m in dataset]) 78 cutoff: int = int(np.percentile(lengths, percentile)) 79 80 filtered_mazes: list[SolvedMaze] = [ 81 m for m in dataset if len(m.solution) > cutoff 82 ] 83 new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes) 84 85 return copy.deepcopy(new_dataset) 86 87 @register_dataset_filter 88 @staticmethod 89 def truncate_count( 90 dataset: MazeDataset, 91 max_count: int, 92 ) -> MazeDataset: 93 """truncate the dataset to be at most `max_count` mazes""" 94 new_dataset: MazeDataset = MazeDataset( 95 cfg=dataset.cfg, 96 mazes=dataset.mazes[:max_count], 97 ) 98 return copy.deepcopy(new_dataset) 99 100 @register_dataset_filter 101 @staticmethod 102 def remove_duplicates( 103 dataset: MazeDataset, 104 minimum_difference_connection_list: int | None = 1, 105 minimum_difference_solution: int | None = 1, 106 _max_dataset_len_threshold: int = 1000, 107 ) -> MazeDataset: 108 """remove duplicates from a dataset, keeping the **LAST** unique maze 109 110 set minimum either minimum difference to `None` to disable checking 111 112 if you want to avoid mazes which have more 
overlap, set the minimum difference to be greater 113 114 Gotchas: 115 - if two mazes are of different sizes, they will never be considered duplicates 116 - if two solutions are of different lengths, they will never be considered duplicates 117 118 TODO: check for overlap? 119 """ 120 if len(dataset) > _max_dataset_len_threshold: 121 raise ValueError( 122 "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n", 123 "if you know what you're doing, change `_max_dataset_len_threshold`", 124 ) 125 126 unique_mazes: list[SolvedMaze] = list() 127 128 maze_a: SolvedMaze 129 maze_b: SolvedMaze 130 for i, maze_a in enumerate(dataset.mazes): 131 a_unique: bool = True 132 for maze_b in dataset.mazes[i + 1 :]: 133 # after all that nesting, more nesting to perform checks 134 if (minimum_difference_connection_list is not None) and ( # noqa: SIM102 135 maze_a.connection_list.shape == maze_b.connection_list.shape 136 ): 137 if ( 138 np.sum(maze_a.connection_list != maze_b.connection_list) 139 <= minimum_difference_connection_list 140 ): 141 a_unique = False 142 break 143 144 if (minimum_difference_solution is not None) and ( # noqa: SIM102 145 maze_a.solution.shape == maze_b.solution.shape 146 ): 147 if ( 148 np.sum(maze_a.solution != maze_b.solution) 149 <= minimum_difference_solution 150 ): 151 a_unique = False 152 break 153 154 if a_unique: 155 unique_mazes.append(maze_a) 156 157 return copy.deepcopy( 158 MazeDataset( 159 cfg=dataset.cfg, 160 mazes=unique_mazes, 161 generation_metadata_collected=dataset.generation_metadata_collected, 162 ), 163 ) 164 165 @register_dataset_filter 166 @staticmethod 167 def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset: 168 """remove duplicates from a dataset""" 169 unique_mazes = list(dict.fromkeys(dataset.mazes)) 170 return copy.deepcopy( 171 MazeDataset( 172 cfg=dataset.cfg, 173 mazes=unique_mazes, 174 generation_metadata_collected=dataset.generation_metadata_collected, 175 ), 176 ) 177 178 @register_dataset_filter 179 @staticmethod 180 def strip_generation_meta(dataset: MazeDataset) -> MazeDataset: 181 """strip the generation meta from the dataset""" 182 new_dataset: MazeDataset = copy.deepcopy(dataset) 183 for maze in new_dataset: 184 # hacky because it's a frozen dataclass 185 maze.__dict__["generation_meta"] = None 186 return new_dataset 187 188 @register_dataset_filter 189 @staticmethod 190 # yes, this function is complicated hence the noqa 191 def collect_generation_meta( # noqa: C901, PLR0912 192 dataset: MazeDataset, 193 clear_in_mazes: bool = True, 194 inplace: bool = True, 195 allow_fail: bool = False, 196 ) -> MazeDataset: 197 """collect the generation metadata from each maze into a dataset-level metadata (saves space) 198 199 # Parameters: 200 - `dataset : MazeDataset` 201 - `clear_in_mazes : bool` 202 whether to clear the generation meta in the mazes after collecting it, keep it there if `False` 203 (defaults to `True`) 204 - `inplace : bool` 205 whether to modify the dataset in place or return a new one 206 (defaults to `True`) 207 - `allow_fail : bool` 208 whether to allow the collection to fail if the generation meta is not present in a maze 209 (defaults to `False`) 210 211 # Returns: 212 - `MazeDataset` 213 the dataset with the generation metadata collected 214 215 # Raises: 216 - `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False` 217 - `ValueError` : if we have other problems converting the generation metadata 218 - `TypeError` : if the generation 
meta on a maze is of an unexpected type 219 """ 220 if dataset.generation_metadata_collected is not None: 221 return dataset 222 else: 223 assert dataset[0].generation_meta is not None, ( 224 "generation meta is not collected and original is not present" 225 ) 226 # if the generation meta is already collected, don't collect it again, do nothing 227 228 new_dataset: MazeDataset 229 if inplace: 230 new_dataset = dataset 231 else: 232 new_dataset = copy.deepcopy(dataset) 233 234 gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = ( 235 defaultdict(Counter) 236 ) 237 for maze in new_dataset: 238 if maze.generation_meta is None: 239 if allow_fail: 240 break 241 raise ValueError( 242 "generation meta is not present in a maze, cannot collect generation meta", 243 ) 244 for key, value in maze.generation_meta.items(): 245 if isinstance(value, (bool, int, float, str)): # noqa: UP038 246 gen_meta_lists[key][value] += 1 247 248 elif isinstance(value, set): 249 # special case for visited_cells 250 gen_meta_lists[key].update(value) 251 252 elif isinstance(value, (list, np.ndarray)): # noqa: UP038 253 if isinstance(value, list): 254 # TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901) 255 try: 256 value = np.array(value) # noqa: PLW2901 257 except ValueError as convert_to_np_err: 258 err_msg = ( 259 f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'" 260 "\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords" 261 ) 262 raise ValueError(err_msg) from convert_to_np_err 263 264 if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim): 265 # assume its a single coordinate 266 gen_meta_lists[key][tuple(value)] += 1 267 # magic value is fine here 268 elif (len(value.shape) == 2) and ( # noqa: PLR2004 269 value.shape[1] == maze.lattice_dim 270 ): 271 # assume its a list of coordinates 272 gen_meta_lists[key].update([tuple(v) for v in value]) 273 else: 274 err_msg = ( 275 f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n" 276 "expected either a coord of shape (2,) or a list of coords of shape (n, 2)" 277 ) 278 raise ValueError(err_msg) 279 else: 280 err_msg = ( 281 f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n" 282 "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords" 283 ) 284 raise TypeError(err_msg) 285 286 # clear the data 287 if clear_in_mazes: 288 # hacky because it's a frozen dataclass 289 maze.__dict__["generation_meta"] = None 290 291 new_dataset.generation_metadata_collected = { 292 key: dict(value) for key, value in gen_meta_lists.items() 293 } 294 295 return new_dataset
def register_maze_filter(
    method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
    """register a maze filter, casting it to operate over the whole list of mazes

    method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

    this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
    """

    @functools.wraps(method)
    def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
        # copy and filter
        new_dataset: MazeDataset = copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
            ),
        )
        # update the config
        new_dataset.cfg.applied_filters.append(
            dict(name=method.__name__, args=args, kwargs=kwargs),
        )
        new_dataset.update_self_config()
        return new_dataset

    return wrapper
register a maze filter, casting it to operate over the whole list of mazes

method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
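As a sketch of how a custom filter might be registered: the namespace class `MyFilters` and the predicate below are hypothetical, and whether a second namespace can coexist with the built-in `MazeDatasetFilters` registration is an assumption here.

from maze_dataset.dataset.dataset import register_filter_namespace_for_dataset
from maze_dataset.dataset.filters import register_maze_filter
from maze_dataset.dataset.maze_dataset import MazeDataset
from maze_dataset.maze import SolvedMaze


@register_filter_namespace_for_dataset(MazeDataset)
class MyFilters:
    "hypothetical namespace holding one custom per-maze filter"

    @register_maze_filter
    @staticmethod
    def solution_starts_at(maze: SolvedMaze, coord: tuple[int, int]) -> bool:
        # keep only mazes whose solution starts at the given coordinate
        return tuple(maze.start_pos) == coord

The wrapper produced by `register_maze_filter` takes the dataset as its first argument, so the filter can also be called directly as `MyFilters.solution_starts_at(dataset, coord=(0, 0))`.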
@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
    "namespace for filters for `MazeDataset`s"
namespace for filters for `MazeDataset`s
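These filters are meant to be applied through the dataset itself; a usage sketch, assuming the namespace registration exposes them as `dataset.filter_by` (the config values here are illustrative):

from maze_dataset import MazeDataset, MazeDatasetConfig
from maze_dataset.generation import LatticeMazeGenerators

cfg = MazeDatasetConfig(
    name="demo",
    grid_n=5,
    n_mazes=32,
    maze_ctor=LatticeMazeGenerators.gen_dfs,
)
dataset = MazeDataset.from_config(cfg)

# each filter returns a new MazeDataset, so calls chain naturally
filtered = dataset.filter_by.path_length(min_length=3)
filtered = filtered.filter_by.cut_percentile_shortest(percentile=10.0)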
    @register_maze_filter
    @staticmethod
    def path_length(maze: SolvedMaze, min_length: int) -> bool:
        """filter out mazes with a solution length less than `min_length`"""
        return len(maze.solution) >= min_length
filter out mazes with a solution length less than `min_length`
    @register_maze_filter
    @staticmethod
    def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
        """filter out mazes where the start and end pos are less than `min_distance` apart by Manhattan distance (ignoring walls)"""
        return bool(
            (np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all()
        )
filter out mazes where the start and end pos are less than `min_distance` apart by Manhattan distance (ignoring walls)
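For reference, the 1-norm used here is the Manhattan distance; a standalone check with made-up coordinates:

import numpy as np

start_pos = np.array([0, 0])
end_pos = np.array([2, 3])
# |2 - 0| + |3 - 0| == 5
assert np.linalg.norm(start_pos - end_pos, 1) == 5.0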
    @register_dataset_filter
    @staticmethod
    def cut_percentile_shortest(
        dataset: MazeDataset,
        percentile: float = 10.0,
    ) -> MazeDataset:
        """cut the shortest `percentile` of mazes from the dataset

        `percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
        """
        lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
        cutoff: int = int(np.percentile(lengths, percentile))

        filtered_mazes: list[SolvedMaze] = [
            m for m in dataset if len(m.solution) > cutoff
        ]
        new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)

        return copy.deepcopy(new_dataset)
cut the shortest `percentile` of mazes from the dataset

`percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
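A quick standalone illustration of the cutoff computation, with made-up solution lengths:

import numpy as np

lengths = np.array([2, 3, 5, 8, 13, 21, 34, 55, 89, 144])
# percentile is on the 1-100 scale: np.percentile(lengths, 10.0) ~= 2.9,
# truncated by int() to 2, so only mazes with solutions longer than 2 are kept
cutoff = int(np.percentile(lengths, 10.0))
kept = lengths[lengths > cutoff]
assert cutoff == 2 and kept.min() == 3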
    @register_dataset_filter
    @staticmethod
    def truncate_count(
        dataset: MazeDataset,
        max_count: int,
    ) -> MazeDataset:
        """truncate the dataset to be at most `max_count` mazes"""
        new_dataset: MazeDataset = MazeDataset(
            cfg=dataset.cfg,
            mazes=dataset.mazes[:max_count],
        )
        return copy.deepcopy(new_dataset)
truncate the dataset to be at most `max_count` mazes
    @register_dataset_filter
    @staticmethod
    def remove_duplicates(
        dataset: MazeDataset,
        minimum_difference_connection_list: int | None = 1,
        minimum_difference_solution: int | None = 1,
        _max_dataset_len_threshold: int = 1000,
    ) -> MazeDataset:
        """remove duplicates from a dataset, keeping the **LAST** unique maze

        set either minimum difference to `None` to disable that check

        if you want to avoid mazes which have more overlap, set the minimum difference to be greater

        Gotchas:
        - if two mazes are of different sizes, they will never be considered duplicates
        - if two solutions are of different lengths, they will never be considered duplicates

        TODO: check for overlap?
        """
        if len(dataset) > _max_dataset_len_threshold:
            raise ValueError(
                "this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n"
                "if you know what you're doing, change `_max_dataset_len_threshold`",
            )

        unique_mazes: list[SolvedMaze] = list()

        maze_a: SolvedMaze
        maze_b: SolvedMaze
        for i, maze_a in enumerate(dataset.mazes):
            a_unique: bool = True
            for maze_b in dataset.mazes[i + 1 :]:
                # after all that nesting, more nesting to perform checks
                if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
                    maze_a.connection_list.shape == maze_b.connection_list.shape
                ):
                    if (
                        np.sum(maze_a.connection_list != maze_b.connection_list)
                        <= minimum_difference_connection_list
                    ):
                        a_unique = False
                        break

                if (minimum_difference_solution is not None) and (  # noqa: SIM102
                    maze_a.solution.shape == maze_b.solution.shape
                ):
                    if (
                        np.sum(maze_a.solution != maze_b.solution)
                        <= minimum_difference_solution
                    ):
                        a_unique = False
                        break

            if a_unique:
                unique_mazes.append(maze_a)

        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )
remove duplicates from a dataset, keeping the **LAST** unique maze

set either minimum difference to `None` to disable that check

if you want to avoid mazes which have more overlap, set the minimum difference to be greater

Gotchas:
- if two mazes are of different sizes, they will never be considered duplicates
- if two solutions are of different lengths, they will never be considered duplicates

TODO: check for overlap?
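The duplicate test counts element-wise disagreements between same-shaped arrays; a standalone sketch with toy boolean arrays standing in for connection lists:

import numpy as np

conn_a = np.array([[True, False], [False, True]])
conn_b = np.array([[True, False], [True, True]])

# exactly one cell differs, so with minimum_difference_connection_list=1
# these two mazes would be treated as duplicates (difference <= 1)
assert int(np.sum(conn_a != conn_b)) == 1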
    @register_dataset_filter
    @staticmethod
    def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
        """remove duplicates from a dataset, keeping the first occurrence of each maze"""
        unique_mazes = list(dict.fromkeys(dataset.mazes))
        return copy.deepcopy(
            MazeDataset(
                cfg=dataset.cfg,
                mazes=unique_mazes,
                generation_metadata_collected=dataset.generation_metadata_collected,
            ),
        )
remove duplicates from a dataset, keeping the first occurrence of each maze
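The implementation relies on `dict.fromkeys`, the standard order-preserving dedup idiom; it requires the mazes to be hashable and keeps the first occurrence of each key:

items = ["a", "b", "a", "c", "b"]
assert list(dict.fromkeys(items)) == ["a", "b", "c"]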
    @register_dataset_filter
    @staticmethod
    def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
        """strip the generation meta from the dataset"""
        new_dataset: MazeDataset = copy.deepcopy(dataset)
        for maze in new_dataset:
            # hacky because it's a frozen dataclass
            maze.__dict__["generation_meta"] = None
        return new_dataset
strip the generation meta from the dataset
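The `maze.__dict__[...] = None` assignment in the source above works because frozen dataclasses only block `__setattr__`, not direct writes to the instance `__dict__`; a standalone illustration with a hypothetical class:

from dataclasses import dataclass


@dataclass(frozen=True)
class Point:
    x: int


p = Point(x=1)
# p.x = 2 would raise FrozenInstanceError; writing to __dict__ bypasses the check
p.__dict__["x"] = 2
assert p.x == 2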
    @register_dataset_filter
    @staticmethod
    # yes, this function is complicated, hence the noqa
    def collect_generation_meta(  # noqa: C901, PLR0912
        dataset: MazeDataset,
        clear_in_mazes: bool = True,
        inplace: bool = True,
        allow_fail: bool = False,
    ) -> MazeDataset:
        """collect the generation metadata from each maze into a dataset-level metadata (saves space)

        # Parameters:
        - `dataset : MazeDataset`
        - `clear_in_mazes : bool`
            whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
            (defaults to `True`)
        - `inplace : bool`
            whether to modify the dataset in place or return a new one
            (defaults to `True`)
        - `allow_fail : bool`
            whether to allow the collection to fail if the generation meta is not present in a maze
            (defaults to `False`)

        # Returns:
        - `MazeDataset`
            the dataset with the generation metadata collected

        # Raises:
        - `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
        - `ValueError` : if we have other problems converting the generation metadata
        - `TypeError` : if the generation meta on a maze is of an unexpected type
        """
        # if the generation meta is already collected, don't collect it again, do nothing
        if dataset.generation_metadata_collected is not None:
            return dataset
        else:
            assert dataset[0].generation_meta is not None, (
                "generation meta is not collected and original is not present"
            )

        new_dataset: MazeDataset
        if inplace:
            new_dataset = dataset
        else:
            new_dataset = copy.deepcopy(dataset)

        gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
            defaultdict(Counter)
        )
        for maze in new_dataset:
            if maze.generation_meta is None:
                if allow_fail:
                    break
                raise ValueError(
                    "generation meta is not present in a maze, cannot collect generation meta",
                )
            for key, value in maze.generation_meta.items():
                if isinstance(value, (bool, int, float, str)):  # noqa: UP038
                    gen_meta_lists[key][value] += 1

                elif isinstance(value, set):
                    # special case for visited_cells
                    gen_meta_lists[key].update(value)

                elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
                    if isinstance(value, list):
                        # TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
                        try:
                            value = np.array(value)  # noqa: PLW2901
                        except ValueError as convert_to_np_err:
                            err_msg = (
                                f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
                                "\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                            )
                            raise ValueError(err_msg) from convert_to_np_err

                    if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
                        # assume it's a single coordinate
                        gen_meta_lists[key][tuple(value)] += 1
                    # magic value is fine here
                    elif (len(value.shape) == 2) and (  # noqa: PLR2004
                        value.shape[1] == maze.lattice_dim
                    ):
                        # assume it's a list of coordinates
                        gen_meta_lists[key].update([tuple(v) for v in value])
                    else:
                        err_msg = (
                            f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
                            "expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
                        )
                        raise ValueError(err_msg)
                else:
                    err_msg = (
                        f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
                        "expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
                    )
                    raise TypeError(err_msg)

            # clear the data
            if clear_in_mazes:
                # hacky because it's a frozen dataclass
                maze.__dict__["generation_meta"] = None

        new_dataset.generation_metadata_collected = {
            key: dict(value) for key, value in gen_meta_lists.items()
        }

        return new_dataset
collect the generation metadata from each maze into a dataset-level metadata (saves space)

Parameters:
- `dataset : MazeDataset`
- `clear_in_mazes : bool`
  whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
  (defaults to `True`)
- `inplace : bool`
  whether to modify the dataset in place or return a new one
  (defaults to `True`)
- `allow_fail : bool`
  whether to allow the collection to fail if the generation meta is not present in a maze
  (defaults to `False`)

Returns:
- `MazeDataset`
  the dataset with the generation metadata collected

Raises:
- `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
- `ValueError` : if we have other problems converting the generation metadata
- `TypeError` : if the generation meta on a maze is of an unexpected type
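A minimal standalone sketch of the aggregation pattern, using made-up metadata dicts in place of real mazes:

from collections import Counter, defaultdict

fake_generation_meta = [
    {"func_name": "gen_dfs", "start_coord": (0, 0)},
    {"func_name": "gen_dfs", "start_coord": (1, 2)},
]

# one Counter per metadata key, tallying how often each value occurs
gen_meta_lists: defaultdict = defaultdict(Counter)
for meta in fake_generation_meta:
    for key, value in meta.items():
        gen_meta_lists[key][value] += 1

collected = {key: dict(counter) for key, counter in gen_meta_lists.items()}
# {'func_name': {'gen_dfs': 2}, 'start_coord': {(0, 0): 1, (1, 2): 1}}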