docs for maze-dataset v1.3.2
View Source on GitHub

maze_dataset.dataset.filters

filtering MazeDatasets


  1"filtering `MazeDataset`s"
  2
  3import copy
  4import functools
  5import typing
  6from collections import Counter, defaultdict
  7
  8import numpy as np
  9
 10from maze_dataset.constants import CoordTup
 11from maze_dataset.dataset.dataset import (
 12	DatasetFilterFunc,
 13	register_dataset_filter,
 14	register_filter_namespace_for_dataset,
 15)
 16from maze_dataset.dataset.maze_dataset import MazeDataset
 17from maze_dataset.maze import SolvedMaze
 18
 19
 20def register_maze_filter(
 21	method: typing.Callable[[SolvedMaze, typing.Any], bool],
 22) -> DatasetFilterFunc:
 23	"""register a maze filter, casting it to operate over the whole list of mazes
 24
 25	method should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`
 26
 27	this is a more restricted version of `register_dataset_filter` that removes the need for boilerplate for operating over the arrays
 28	"""
 29
 30	@functools.wraps(method)
 31	def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
 32		# copy and filter
 33		new_dataset: MazeDataset = copy.deepcopy(
 34			MazeDataset(
 35				cfg=dataset.cfg,
 36				mazes=[m for m in dataset.mazes if method(m, *args, **kwargs)],
 37			),
 38		)
 39		# update the config
 40		new_dataset.cfg.applied_filters.append(
 41			dict(name=method.__name__, args=args, kwargs=kwargs),
 42		)
 43		new_dataset.update_self_config()
 44		return new_dataset
 45
 46	return wrapper
 47
 48
 49@register_filter_namespace_for_dataset(MazeDataset)
 50class MazeDatasetFilters:
 51	"namespace for filters for `MazeDataset`s"
 52
 53	@register_maze_filter
 54	@staticmethod
 55	def path_length(maze: SolvedMaze, min_length: int) -> bool:
 56		"""filter out mazes with a solution length less than `min_length`"""
 57		return len(maze.solution) >= min_length
 58
 59	@register_maze_filter
 60	@staticmethod
 61	def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
 62		"""filter out datasets where the start and end pos are less than `min_distance` apart on the manhattan distance (ignoring walls)"""
 63		return bool(
 64			(np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all()
 65		)
 66
 67	@register_dataset_filter
 68	@staticmethod
 69	def cut_percentile_shortest(
 70		dataset: MazeDataset,
 71		percentile: float = 10.0,
 72	) -> MazeDataset:
 73		"""cut the shortest `percentile` of mazes from the dataset
 74
 75		`percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
 76		"""
 77		lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
 78		cutoff: int = int(np.percentile(lengths, percentile))
 79
 80		filtered_mazes: list[SolvedMaze] = [
 81			m for m in dataset if len(m.solution) > cutoff
 82		]
 83		new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)
 84
 85		return copy.deepcopy(new_dataset)
 86
 87	@register_dataset_filter
 88	@staticmethod
 89	def truncate_count(
 90		dataset: MazeDataset,
 91		max_count: int,
 92	) -> MazeDataset:
 93		"""truncate the dataset to be at most `max_count` mazes"""
 94		new_dataset: MazeDataset = MazeDataset(
 95			cfg=dataset.cfg,
 96			mazes=dataset.mazes[:max_count],
 97		)
 98		return copy.deepcopy(new_dataset)
 99
100	@register_dataset_filter
101	@staticmethod
102	def remove_duplicates(
103		dataset: MazeDataset,
104		minimum_difference_connection_list: int | None = 1,
105		minimum_difference_solution: int | None = 1,
106		_max_dataset_len_threshold: int = 1000,
107	) -> MazeDataset:
108		"""remove duplicates from a dataset, keeping the **LAST** unique maze
109
110		set minimum either minimum difference to `None` to disable checking
111
112		if you want to avoid mazes which have more overlap, set the minimum difference to be greater
113
114		Gotchas:
115		- if two mazes are of different sizes, they will never be considered duplicates
116		- if two solutions are of different lengths, they will never be considered duplicates
117
118		TODO: check for overlap?
119		"""
120		if len(dataset) > _max_dataset_len_threshold:
121			raise ValueError(
122				"this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n",
123				"if you know what you're doing, change `_max_dataset_len_threshold`",
124			)
125
126		unique_mazes: list[SolvedMaze] = list()
127
128		maze_a: SolvedMaze
129		maze_b: SolvedMaze
130		for i, maze_a in enumerate(dataset.mazes):
131			a_unique: bool = True
132			for maze_b in dataset.mazes[i + 1 :]:
133				# after all that nesting, more nesting to perform checks
134				if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
135					maze_a.connection_list.shape == maze_b.connection_list.shape
136				):
137					if (
138						np.sum(maze_a.connection_list != maze_b.connection_list)
139						<= minimum_difference_connection_list
140					):
141						a_unique = False
142						break
143
144				if (minimum_difference_solution is not None) and (  # noqa: SIM102
145					maze_a.solution.shape == maze_b.solution.shape
146				):
147					if (
148						np.sum(maze_a.solution != maze_b.solution)
149						<= minimum_difference_solution
150					):
151						a_unique = False
152						break
153
154			if a_unique:
155				unique_mazes.append(maze_a)
156
157		return copy.deepcopy(
158			MazeDataset(
159				cfg=dataset.cfg,
160				mazes=unique_mazes,
161				generation_metadata_collected=dataset.generation_metadata_collected,
162			),
163		)
164
165	@register_dataset_filter
166	@staticmethod
167	def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
168		"""remove duplicates from a dataset"""
169		unique_mazes = list(dict.fromkeys(dataset.mazes))
170		return copy.deepcopy(
171			MazeDataset(
172				cfg=dataset.cfg,
173				mazes=unique_mazes,
174				generation_metadata_collected=dataset.generation_metadata_collected,
175			),
176		)
177
178	@register_dataset_filter
179	@staticmethod
180	def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
181		"""strip the generation meta from the dataset"""
182		new_dataset: MazeDataset = copy.deepcopy(dataset)
183		for maze in new_dataset:
184			# hacky because it's a frozen dataclass
185			maze.__dict__["generation_meta"] = None
186		return new_dataset
187
188	@register_dataset_filter
189	@staticmethod
190	# yes, this function is complicated hence the noqa
191	def collect_generation_meta(  # noqa: C901, PLR0912
192		dataset: MazeDataset,
193		clear_in_mazes: bool = True,
194		inplace: bool = True,
195		allow_fail: bool = False,
196	) -> MazeDataset:
197		"""collect the generation metadata from each maze into a dataset-level metadata (saves space)
198
199		# Parameters:
200		- `dataset : MazeDataset`
201		- `clear_in_mazes : bool`
202			whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
203			(defaults to `True`)
204		- `inplace : bool`
205			whether to modify the dataset in place or return a new one
206			(defaults to `True`)
207		- `allow_fail : bool`
208			whether to allow the collection to fail if the generation meta is not present in a maze
209			(defaults to `False`)
210
211		# Returns:
212		- `MazeDataset`
213			the dataset with the generation metadata collected
214
215		# Raises:
216		- `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
217		- `ValueError` : if we have other problems converting the generation metadata
218		- `TypeError` : if the generation meta on a maze is of an unexpected type
219		"""
220		if dataset.generation_metadata_collected is not None:
221			return dataset
222		else:
223			assert dataset[0].generation_meta is not None, (
224				"generation meta is not collected and original is not present"
225			)
226		# if the generation meta is already collected, don't collect it again, do nothing
227
228		new_dataset: MazeDataset
229		if inplace:
230			new_dataset = dataset
231		else:
232			new_dataset = copy.deepcopy(dataset)
233
234		gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
235			defaultdict(Counter)
236		)
237		for maze in new_dataset:
238			if maze.generation_meta is None:
239				if allow_fail:
240					break
241				raise ValueError(
242					"generation meta is not present in a maze, cannot collect generation meta",
243				)
244			for key, value in maze.generation_meta.items():
245				if isinstance(value, (bool, int, float, str)):  # noqa: UP038
246					gen_meta_lists[key][value] += 1
247
248				elif isinstance(value, set):
249					# special case for visited_cells
250					gen_meta_lists[key].update(value)
251
252				elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
253					if isinstance(value, list):
254						# TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
255						try:
256							value = np.array(value)  # noqa: PLW2901
257						except ValueError as convert_to_np_err:
258							err_msg = (
259								f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
260								"\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
261							)
262							raise ValueError(err_msg) from convert_to_np_err
263
264					if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
265						# assume its a single coordinate
266						gen_meta_lists[key][tuple(value)] += 1
267					# magic value is fine here
268					elif (len(value.shape) == 2) and (  # noqa: PLR2004
269						value.shape[1] == maze.lattice_dim
270					):
271						# assume its a list of coordinates
272						gen_meta_lists[key].update([tuple(v) for v in value])
273					else:
274						err_msg = (
275							f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
276							"expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
277						)
278						raise ValueError(err_msg)
279				else:
280					err_msg = (
281						f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
282						"expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
283					)
284					raise TypeError(err_msg)
285
286			# clear the data
287			if clear_in_mazes:
288				# hacky because it's a frozen dataclass
289				maze.__dict__["generation_meta"] = None
290
291		new_dataset.generation_metadata_collected = {
292			key: dict(value) for key, value in gen_meta_lists.items()
293		}
294
295		return new_dataset

def register_maze_filter( method: Callable[[maze_dataset.SolvedMaze, Any], bool]) -> Callable[Concatenate[~T_Dataset, ~P_FilterKwargs], ~T_Dataset]:
def register_maze_filter(
	method: typing.Callable[[SolvedMaze, typing.Any], bool],
) -> DatasetFilterFunc:
	"""lift a per-maze predicate into a filter that operates on a whole `MazeDataset`

	`method` should be a staticmethod of a namespace class registered with `register_filter_namespace_for_dataset`

	this is a more restricted version of `register_dataset_filter`: the predicate only decides
	whether a single maze is kept, and the boilerplate of rebuilding the dataset lives here

	# Parameters:
	- `method : typing.Callable[[SolvedMaze, typing.Any], bool]`
		predicate called as `method(maze, *args, **kwargs)`; a maze is kept when the result is truthy

	# Returns:
	- `DatasetFilterFunc`
		wrapper taking `(dataset, *args, **kwargs)` and returning a new, deep-copied `MazeDataset`
		whose config has this filter appended to `applied_filters`
	"""

	@functools.wraps(method)
	def wrapper(dataset: MazeDataset, *args, **kwargs) -> MazeDataset:
		# select the mazes accepted by the predicate
		kept_mazes: list[SolvedMaze] = []
		for maze in dataset.mazes:
			if method(maze, *args, **kwargs):
				kept_mazes.append(maze)

		# build the filtered dataset and deep copy it before touching its config
		filtered: MazeDataset = copy.deepcopy(
			MazeDataset(cfg=dataset.cfg, mazes=kept_mazes),
		)

		# record which filter was applied, and with which arguments
		filter_record = dict(name=method.__name__, args=args, kwargs=kwargs)
		filtered.cfg.applied_filters.append(filter_record)
		filtered.update_self_config()
		return filtered

	return wrapper

register a maze filter, casting it to operate over the whole list of mazes

method should be a staticmethod of a namespace class registered with register_filter_namespace_for_dataset

this is a more restricted version of register_dataset_filter that removes the need for boilerplate for operating over the arrays

@register_filter_namespace_for_dataset(MazeDataset)
class MazeDatasetFilters:
 50@register_filter_namespace_for_dataset(MazeDataset)
 51class MazeDatasetFilters:
 52	"namespace for filters for `MazeDataset`s"
 53
 54	@register_maze_filter
 55	@staticmethod
 56	def path_length(maze: SolvedMaze, min_length: int) -> bool:
 57		"""filter out mazes with a solution length less than `min_length`"""
 58		return len(maze.solution) >= min_length
 59
 60	@register_maze_filter
 61	@staticmethod
 62	def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
 63		"""filter out datasets where the start and end pos are less than `min_distance` apart on the manhattan distance (ignoring walls)"""
 64		return bool(
 65			(np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all()
 66		)
 67
 68	@register_dataset_filter
 69	@staticmethod
 70	def cut_percentile_shortest(
 71		dataset: MazeDataset,
 72		percentile: float = 10.0,
 73	) -> MazeDataset:
 74		"""cut the shortest `percentile` of mazes from the dataset
 75
 76		`percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
 77		"""
 78		lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
 79		cutoff: int = int(np.percentile(lengths, percentile))
 80
 81		filtered_mazes: list[SolvedMaze] = [
 82			m for m in dataset if len(m.solution) > cutoff
 83		]
 84		new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)
 85
 86		return copy.deepcopy(new_dataset)
 87
 88	@register_dataset_filter
 89	@staticmethod
 90	def truncate_count(
 91		dataset: MazeDataset,
 92		max_count: int,
 93	) -> MazeDataset:
 94		"""truncate the dataset to be at most `max_count` mazes"""
 95		new_dataset: MazeDataset = MazeDataset(
 96			cfg=dataset.cfg,
 97			mazes=dataset.mazes[:max_count],
 98		)
 99		return copy.deepcopy(new_dataset)
100
101	@register_dataset_filter
102	@staticmethod
103	def remove_duplicates(
104		dataset: MazeDataset,
105		minimum_difference_connection_list: int | None = 1,
106		minimum_difference_solution: int | None = 1,
107		_max_dataset_len_threshold: int = 1000,
108	) -> MazeDataset:
109		"""remove duplicates from a dataset, keeping the **LAST** unique maze
110
111		set minimum either minimum difference to `None` to disable checking
112
113		if you want to avoid mazes which have more overlap, set the minimum difference to be greater
114
115		Gotchas:
116		- if two mazes are of different sizes, they will never be considered duplicates
117		- if two solutions are of different lengths, they will never be considered duplicates
118
119		TODO: check for overlap?
120		"""
121		if len(dataset) > _max_dataset_len_threshold:
122			raise ValueError(
123				"this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n",
124				"if you know what you're doing, change `_max_dataset_len_threshold`",
125			)
126
127		unique_mazes: list[SolvedMaze] = list()
128
129		maze_a: SolvedMaze
130		maze_b: SolvedMaze
131		for i, maze_a in enumerate(dataset.mazes):
132			a_unique: bool = True
133			for maze_b in dataset.mazes[i + 1 :]:
134				# after all that nesting, more nesting to perform checks
135				if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
136					maze_a.connection_list.shape == maze_b.connection_list.shape
137				):
138					if (
139						np.sum(maze_a.connection_list != maze_b.connection_list)
140						<= minimum_difference_connection_list
141					):
142						a_unique = False
143						break
144
145				if (minimum_difference_solution is not None) and (  # noqa: SIM102
146					maze_a.solution.shape == maze_b.solution.shape
147				):
148					if (
149						np.sum(maze_a.solution != maze_b.solution)
150						<= minimum_difference_solution
151					):
152						a_unique = False
153						break
154
155			if a_unique:
156				unique_mazes.append(maze_a)
157
158		return copy.deepcopy(
159			MazeDataset(
160				cfg=dataset.cfg,
161				mazes=unique_mazes,
162				generation_metadata_collected=dataset.generation_metadata_collected,
163			),
164		)
165
166	@register_dataset_filter
167	@staticmethod
168	def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
169		"""remove duplicates from a dataset"""
170		unique_mazes = list(dict.fromkeys(dataset.mazes))
171		return copy.deepcopy(
172			MazeDataset(
173				cfg=dataset.cfg,
174				mazes=unique_mazes,
175				generation_metadata_collected=dataset.generation_metadata_collected,
176			),
177		)
178
179	@register_dataset_filter
180	@staticmethod
181	def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
182		"""strip the generation meta from the dataset"""
183		new_dataset: MazeDataset = copy.deepcopy(dataset)
184		for maze in new_dataset:
185			# hacky because it's a frozen dataclass
186			maze.__dict__["generation_meta"] = None
187		return new_dataset
188
189	@register_dataset_filter
190	@staticmethod
191	# yes, this function is complicated hence the noqa
192	def collect_generation_meta(  # noqa: C901, PLR0912
193		dataset: MazeDataset,
194		clear_in_mazes: bool = True,
195		inplace: bool = True,
196		allow_fail: bool = False,
197	) -> MazeDataset:
198		"""collect the generation metadata from each maze into a dataset-level metadata (saves space)
199
200		# Parameters:
201		- `dataset : MazeDataset`
202		- `clear_in_mazes : bool`
203			whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
204			(defaults to `True`)
205		- `inplace : bool`
206			whether to modify the dataset in place or return a new one
207			(defaults to `True`)
208		- `allow_fail : bool`
209			whether to allow the collection to fail if the generation meta is not present in a maze
210			(defaults to `False`)
211
212		# Returns:
213		- `MazeDataset`
214			the dataset with the generation metadata collected
215
216		# Raises:
217		- `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
218		- `ValueError` : if we have other problems converting the generation metadata
219		- `TypeError` : if the generation meta on a maze is of an unexpected type
220		"""
221		if dataset.generation_metadata_collected is not None:
222			return dataset
223		else:
224			assert dataset[0].generation_meta is not None, (
225				"generation meta is not collected and original is not present"
226			)
227		# if the generation meta is already collected, don't collect it again, do nothing
228
229		new_dataset: MazeDataset
230		if inplace:
231			new_dataset = dataset
232		else:
233			new_dataset = copy.deepcopy(dataset)
234
235		gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
236			defaultdict(Counter)
237		)
238		for maze in new_dataset:
239			if maze.generation_meta is None:
240				if allow_fail:
241					break
242				raise ValueError(
243					"generation meta is not present in a maze, cannot collect generation meta",
244				)
245			for key, value in maze.generation_meta.items():
246				if isinstance(value, (bool, int, float, str)):  # noqa: UP038
247					gen_meta_lists[key][value] += 1
248
249				elif isinstance(value, set):
250					# special case for visited_cells
251					gen_meta_lists[key].update(value)
252
253				elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
254					if isinstance(value, list):
255						# TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
256						try:
257							value = np.array(value)  # noqa: PLW2901
258						except ValueError as convert_to_np_err:
259							err_msg = (
260								f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
261								"\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
262							)
263							raise ValueError(err_msg) from convert_to_np_err
264
265					if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
266						# assume its a single coordinate
267						gen_meta_lists[key][tuple(value)] += 1
268					# magic value is fine here
269					elif (len(value.shape) == 2) and (  # noqa: PLR2004
270						value.shape[1] == maze.lattice_dim
271					):
272						# assume its a list of coordinates
273						gen_meta_lists[key].update([tuple(v) for v in value])
274					else:
275						err_msg = (
276							f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
277							"expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
278						)
279						raise ValueError(err_msg)
280				else:
281					err_msg = (
282						f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
283						"expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
284					)
285					raise TypeError(err_msg)
286
287			# clear the data
288			if clear_in_mazes:
289				# hacky because it's a frozen dataclass
290				maze.__dict__["generation_meta"] = None
291
292		new_dataset.generation_metadata_collected = {
293			key: dict(value) for key, value in gen_meta_lists.items()
294		}
295
296		return new_dataset

namespace for filters for MazeDatasets

@register_maze_filter
@staticmethod
def path_length(maze: maze_dataset.SolvedMaze, min_length: int) -> bool:
54	@register_maze_filter
55	@staticmethod
56	def path_length(maze: SolvedMaze, min_length: int) -> bool:
57		"""filter out mazes with a solution length less than `min_length`"""
58		return len(maze.solution) >= min_length

filter out mazes with a solution length less than min_length

@register_maze_filter
@staticmethod
def start_end_distance( maze: maze_dataset.SolvedMaze, min_distance: int) -> bool:
60	@register_maze_filter
61	@staticmethod
62	def start_end_distance(maze: SolvedMaze, min_distance: int) -> bool:
63		"""filter out datasets where the start and end pos are less than `min_distance` apart on the manhattan distance (ignoring walls)"""
64		return bool(
65			(np.linalg.norm(maze.start_pos - maze.end_pos, 1) >= min_distance).all()
66		)

filter out mazes where the start and end positions are less than min_distance apart in manhattan distance (ignoring walls)

@register_dataset_filter
@staticmethod
def cut_percentile_shortest( dataset: maze_dataset.MazeDataset, percentile: float = 10.0) -> maze_dataset.MazeDataset:
68	@register_dataset_filter
69	@staticmethod
70	def cut_percentile_shortest(
71		dataset: MazeDataset,
72		percentile: float = 10.0,
73	) -> MazeDataset:
74		"""cut the shortest `percentile` of mazes from the dataset
75
76		`percentile` is 1-100, not 0-1, as this is what `np.percentile` expects
77		"""
78		lengths: np.ndarray = np.array([len(m.solution) for m in dataset])
79		cutoff: int = int(np.percentile(lengths, percentile))
80
81		filtered_mazes: list[SolvedMaze] = [
82			m for m in dataset if len(m.solution) > cutoff
83		]
84		new_dataset: MazeDataset = MazeDataset(cfg=dataset.cfg, mazes=filtered_mazes)
85
86		return copy.deepcopy(new_dataset)

cut the shortest percentile of mazes from the dataset

percentile is given on a 0-100 scale, not 0-1, as this is what np.percentile expects

@register_dataset_filter
@staticmethod
def truncate_count( dataset: maze_dataset.MazeDataset, max_count: int) -> maze_dataset.MazeDataset:
88	@register_dataset_filter
89	@staticmethod
90	def truncate_count(
91		dataset: MazeDataset,
92		max_count: int,
93	) -> MazeDataset:
94		"""truncate the dataset to be at most `max_count` mazes"""
95		new_dataset: MazeDataset = MazeDataset(
96			cfg=dataset.cfg,
97			mazes=dataset.mazes[:max_count],
98		)
99		return copy.deepcopy(new_dataset)

truncate the dataset to be at most max_count mazes

@register_dataset_filter
@staticmethod
def remove_duplicates( dataset: maze_dataset.MazeDataset, minimum_difference_connection_list: int | None = 1, minimum_difference_solution: int | None = 1, _max_dataset_len_threshold: int = 1000) -> maze_dataset.MazeDataset:
101	@register_dataset_filter
102	@staticmethod
103	def remove_duplicates(
104		dataset: MazeDataset,
105		minimum_difference_connection_list: int | None = 1,
106		minimum_difference_solution: int | None = 1,
107		_max_dataset_len_threshold: int = 1000,
108	) -> MazeDataset:
109		"""remove duplicates from a dataset, keeping the **LAST** unique maze
110
111		set minimum either minimum difference to `None` to disable checking
112
113		if you want to avoid mazes which have more overlap, set the minimum difference to be greater
114
115		Gotchas:
116		- if two mazes are of different sizes, they will never be considered duplicates
117		- if two solutions are of different lengths, they will never be considered duplicates
118
119		TODO: check for overlap?
120		"""
121		if len(dataset) > _max_dataset_len_threshold:
122			raise ValueError(
123				"this method is currently very slow for large datasets, consider using `remove_duplicates_fast` instead\n",
124				"if you know what you're doing, change `_max_dataset_len_threshold`",
125			)
126
127		unique_mazes: list[SolvedMaze] = list()
128
129		maze_a: SolvedMaze
130		maze_b: SolvedMaze
131		for i, maze_a in enumerate(dataset.mazes):
132			a_unique: bool = True
133			for maze_b in dataset.mazes[i + 1 :]:
134				# after all that nesting, more nesting to perform checks
135				if (minimum_difference_connection_list is not None) and (  # noqa: SIM102
136					maze_a.connection_list.shape == maze_b.connection_list.shape
137				):
138					if (
139						np.sum(maze_a.connection_list != maze_b.connection_list)
140						<= minimum_difference_connection_list
141					):
142						a_unique = False
143						break
144
145				if (minimum_difference_solution is not None) and (  # noqa: SIM102
146					maze_a.solution.shape == maze_b.solution.shape
147				):
148					if (
149						np.sum(maze_a.solution != maze_b.solution)
150						<= minimum_difference_solution
151					):
152						a_unique = False
153						break
154
155			if a_unique:
156				unique_mazes.append(maze_a)
157
158		return copy.deepcopy(
159			MazeDataset(
160				cfg=dataset.cfg,
161				mazes=unique_mazes,
162				generation_metadata_collected=dataset.generation_metadata_collected,
163			),
164		)

remove duplicates from a dataset, keeping the LAST unique maze

set either minimum difference to None to disable the corresponding check

if you want to avoid mazes which have more overlap, set the minimum difference to be greater

Gotchas:

  • if two mazes are of different sizes, they will never be considered duplicates
  • if two solutions are of different lengths, they will never be considered duplicates

TODO: check for overlap?

@register_dataset_filter
@staticmethod
def remove_duplicates_fast( dataset: maze_dataset.MazeDataset) -> maze_dataset.MazeDataset:
166	@register_dataset_filter
167	@staticmethod
168	def remove_duplicates_fast(dataset: MazeDataset) -> MazeDataset:
169		"""remove duplicates from a dataset"""
170		unique_mazes = list(dict.fromkeys(dataset.mazes))
171		return copy.deepcopy(
172			MazeDataset(
173				cfg=dataset.cfg,
174				mazes=unique_mazes,
175				generation_metadata_collected=dataset.generation_metadata_collected,
176			),
177		)

remove duplicates from a dataset

@register_dataset_filter
@staticmethod
def strip_generation_meta( dataset: maze_dataset.MazeDataset) -> maze_dataset.MazeDataset:
179	@register_dataset_filter
180	@staticmethod
181	def strip_generation_meta(dataset: MazeDataset) -> MazeDataset:
182		"""strip the generation meta from the dataset"""
183		new_dataset: MazeDataset = copy.deepcopy(dataset)
184		for maze in new_dataset:
185			# hacky because it's a frozen dataclass
186			maze.__dict__["generation_meta"] = None
187		return new_dataset

strip the generation meta from the dataset

@register_dataset_filter
@staticmethod
def collect_generation_meta( dataset: maze_dataset.MazeDataset, clear_in_mazes: bool = True, inplace: bool = True, allow_fail: bool = False) -> maze_dataset.MazeDataset:
189	@register_dataset_filter
190	@staticmethod
191	# yes, this function is complicated hence the noqa
192	def collect_generation_meta(  # noqa: C901, PLR0912
193		dataset: MazeDataset,
194		clear_in_mazes: bool = True,
195		inplace: bool = True,
196		allow_fail: bool = False,
197	) -> MazeDataset:
198		"""collect the generation metadata from each maze into a dataset-level metadata (saves space)
199
200		# Parameters:
201		- `dataset : MazeDataset`
202		- `clear_in_mazes : bool`
203			whether to clear the generation meta in the mazes after collecting it, keep it there if `False`
204			(defaults to `True`)
205		- `inplace : bool`
206			whether to modify the dataset in place or return a new one
207			(defaults to `True`)
208		- `allow_fail : bool`
209			whether to allow the collection to fail if the generation meta is not present in a maze
210			(defaults to `False`)
211
212		# Returns:
213		- `MazeDataset`
214			the dataset with the generation metadata collected
215
216		# Raises:
217		- `ValueError` : if the generation meta is not present in a maze and `allow_fail` is `False`
218		- `ValueError` : if we have other problems converting the generation metadata
219		- `TypeError` : if the generation meta on a maze is of an unexpected type
220		"""
221		if dataset.generation_metadata_collected is not None:
222			return dataset
223		else:
224			assert dataset[0].generation_meta is not None, (
225				"generation meta is not collected and original is not present"
226			)
227		# if the generation meta is already collected, don't collect it again, do nothing
228
229		new_dataset: MazeDataset
230		if inplace:
231			new_dataset = dataset
232		else:
233			new_dataset = copy.deepcopy(dataset)
234
235		gen_meta_lists: dict[bool | int | float | str | CoordTup, Counter] = (
236			defaultdict(Counter)
237		)
238		for maze in new_dataset:
239			if maze.generation_meta is None:
240				if allow_fail:
241					break
242				raise ValueError(
243					"generation meta is not present in a maze, cannot collect generation meta",
244				)
245			for key, value in maze.generation_meta.items():
246				if isinstance(value, (bool, int, float, str)):  # noqa: UP038
247					gen_meta_lists[key][value] += 1
248
249				elif isinstance(value, set):
250					# special case for visited_cells
251					gen_meta_lists[key].update(value)
252
253				elif isinstance(value, (list, np.ndarray)):  # noqa: UP038
254					if isinstance(value, list):
255						# TODO: `for` loop variable `value` overwritten by assignment target (Ruff PLW2901)
256						try:
257							value = np.array(value)  # noqa: PLW2901
258						except ValueError as convert_to_np_err:
259							err_msg = (
260								f"Cannot collect generation meta for {key} as it is a list of type '{type(value[0]) = !s}'"
261								"\nexpected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
262							)
263							raise ValueError(err_msg) from convert_to_np_err
264
265					if (len(value.shape) == 1) and (value.shape[0] == maze.lattice_dim):
266						# assume its a single coordinate
267						gen_meta_lists[key][tuple(value)] += 1
268					# magic value is fine here
269					elif (len(value.shape) == 2) and (  # noqa: PLR2004
270						value.shape[1] == maze.lattice_dim
271					):
272						# assume its a list of coordinates
273						gen_meta_lists[key].update([tuple(v) for v in value])
274					else:
275						err_msg = (
276							f"Cannot collect generation meta for {key} as it is an ndarray of shape {value.shape}\n"
277							"expected either a coord of shape (2,) or a list of coords of shape (n, 2)"
278						)
279						raise ValueError(err_msg)
280				else:
281					err_msg = (
282						f"Cannot collect generation meta for {key} as it is of type '{type(value)!s}'\n"
283						"expected either a basic type (bool, int, float, str), a numpy coord, or a numpy array of coords"
284					)
285					raise TypeError(err_msg)
286
287			# clear the data
288			if clear_in_mazes:
289				# hacky because it's a frozen dataclass
290				maze.__dict__["generation_meta"] = None
291
292		new_dataset.generation_metadata_collected = {
293			key: dict(value) for key, value in gen_meta_lists.items()
294		}
295
296		return new_dataset

collect the generation metadata from each maze into a dataset-level metadata (saves space)

Parameters:

  • dataset : MazeDataset
  • clear_in_mazes : bool whether to clear the generation meta in the mazes after collecting it, keep it there if False (defaults to True)
  • inplace : bool whether to modify the dataset in place or return a new one (defaults to True)
  • allow_fail : bool whether to allow the collection to fail if the generation meta is not present in a maze (defaults to False)

Returns:

  • MazeDataset the dataset with the generation metadata collected

Raises:

  • ValueError : if the generation meta is not present in a maze and allow_fail is False
  • ValueError : if we have other problems converting the generation metadata
  • TypeError : if the generation meta on a maze is of an unexpected type