Coverage for maze_dataset/dataset/maze_dataset.py: 44% (178 statements)

1"""`MazeDatasetConfig` is where you decide what your dataset should look like, then pass it to `MazeDataset.from_config` to generate or load the dataset. 

2 

3see [demo_dataset notebook](../../notebooks/demo_dataset) 

4 

5""" 

6 

7import copy 

8import json 

9import multiprocessing 

10import typing 

11import warnings 

12from pathlib import Path 

13from typing import Literal, Optional, cast, overload 

14 

15import numpy as np 

16import tqdm 

17from jaxtyping import Int 

18from muutils.json_serialize import ( 

19 json_serialize, 

20) 

21from muutils.json_serialize.util import ( 

22 _FORMAT_KEY, 

23 JSONdict, 

24) 

25from muutils.misc import stable_hash 

26from zanj import ZANJ 

27from zanj.loading import LoaderHandler, load_item_recursive, register_loader_handler 

28 

29from maze_dataset.constants import CoordArray 

30from maze_dataset.dataset.dataset import ( 

31 GPTDataset, 

32) 

33from maze_dataset.dataset.maze_dataset_config import ( 

34 SERIALIZE_MINIMAL_THRESHOLD, 

35 EndpointKwargsType, 

36 MazeDatasetConfig, 

37) 

38from maze_dataset.generation.seed import GLOBAL_SEED 

39from maze_dataset.maze import LatticeMaze, SolvedMaze 

40 

41_GLOBAL_WORKER_CONFIG: MazeDatasetConfig 

42 

43 

def _generate_maze_helper(index: int) -> Optional[SolvedMaze]:  # noqa: ARG001
	"""Helper function for generating mazes in parallel.

	> [!CAUTION]
	> don't use this unless generating in parallel!
	"""
	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0602
	# TODO: don't use this unless generating in parallel!
	maze: LatticeMaze = _GLOBAL_WORKER_CONFIG.maze_ctor(
		grid_shape=_GLOBAL_WORKER_CONFIG.grid_shape_np,
		**_GLOBAL_WORKER_CONFIG.maze_ctor_kwargs,
	)

	endpoint_kwargs: EndpointKwargsType = _GLOBAL_WORKER_CONFIG.endpoint_kwargs.copy()

	# Generate the solution
	# mypy doesn't realize EndpointKwargsType has only string keys: `Keywords must be strings [misc]`
	# TYPING: error: No overload variant of "generate_random_path" of "LatticeMaze" matches argument type "dict[Literal['allowed_start', 'allowed_end', 'deadend_start', 'deadend_end', 'endpoints_not_equal', 'except_on_no_valid_endpoint'], bool | list[tuple[int, int]] | None]"  [call-overload]
	solution: Optional[CoordArray] = maze.generate_random_path(**endpoint_kwargs)  # type: ignore[misc, call-overload]

	# Validate the solution
	if (
		solution is None
		or len(solution) == 0
		or not isinstance(solution, np.ndarray)
		# magic value is fine here
		or len(solution.shape) != 2  # noqa: PLR2004
	):
		return None  # Return None if the solution is invalid

	return SolvedMaze.from_lattice_maze(
		lattice_maze=maze,
		solution=solution,
	)

def _maze_gen_init_worker(config: MazeDatasetConfig) -> None:
	"""special worker helper

	> [!CAUTION]
	> this makes the generation depend both on whether parallelism is used, and on the number of processes. this is bad!

	"""
	# TODO: don't use globals here!
	global _GLOBAL_WORKER_CONFIG  # noqa: PLW0603
	_GLOBAL_WORKER_CONFIG = config

	process_id: tuple[int, ...] = multiprocessing.current_process()._identity
	if len(process_id) == 0:
		# no multiprocessing, seed was already set
		pass
	elif len(process_id) == 1:
		# multiprocessing, adjust seed based on process id
		# only set numpy seed, since we do not use other random gens
		np.random.seed(
			# if the seed is None, fall back to the global seed; either way,
			# offset by the process id so each worker gets a distinct stream
			# (without the parentheses, a configured seed would never be offset)
			(_GLOBAL_WORKER_CONFIG.seed or GLOBAL_SEED) + process_id[0],
		)
	else:
		err_msg = (
			f"unexpected process id: {process_id = }\n{multiprocessing.Process() = }"
		)
		raise ValueError(
			err_msg,
		)
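# Note (illustrative numbers, contingent on the parenthesized seeding above):
# with cfg.seed = 42 in a worker whose process_id is (3,), numpy is seeded
# with 42 + 3 = 45, so each worker in the pool draws a distinct random stream.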

class MazeDataset(GPTDataset[MazeDatasetConfig]):
	"""a maze dataset class. This is a collection of solved mazes, and should be initialized via `MazeDataset.from_config`"""

	def __init__(
		self,
		cfg: MazeDatasetConfig,
		mazes: typing.Sequence[SolvedMaze],
		generation_metadata_collected: dict | None = None,
	) -> None:
		"""initialize a maze dataset from a config and a list of solved mazes"""
		super().__init__()
		self.cfg: MazeDatasetConfig = cfg
		self.mazes: list[SolvedMaze] = list(mazes)
		self.generation_metadata_collected: dict | None = generation_metadata_collected

	# TYPING: error: Return type "MazeDataset" of "from_config" incompatible with return type "T_Dataset" in supertype "GPTDataset"  [override]
	@classmethod
	def from_config(  # type: ignore[override]
		cls,
		# TYPING: error: Argument 1 of "from_config" is incompatible with supertype "GPTDataset"; supertype defines the argument type as "T_DatasetConfig"  [override]
		cfg: MazeDatasetConfig,  # type: ignore[override]
		do_generate: bool = True,
		load_local: bool = True,
		save_local: bool = True,
		zanj: ZANJ | None = None,
		do_download: bool = True,
		local_base_path: Path = Path("data/maze_dataset"),
		except_on_config_mismatch: bool = True,
		allow_generation_metadata_filter_mismatch: bool = True,
		verbose: bool = False,
		**kwargs,
	) -> "MazeDataset":
		"""create a maze dataset from a config

		priority of loading:
		1. load from local
		2. download
		3. generate

		(see the usage sketch after this method)
		"""
		return cast(
			MazeDataset,
			super().from_config(
				cfg=cfg,
				do_generate=do_generate,
				load_local=load_local,
				save_local=save_local,
				zanj=zanj,
				do_download=do_download,
				local_base_path=local_base_path,
				except_on_config_mismatch=except_on_config_mismatch,
				allow_generation_metadata_filter_mismatch=allow_generation_metadata_filter_mismatch,
				verbose=verbose,
				**kwargs,
			),
		)
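	# Usage sketch: all flags below are real parameters of `from_config` above;
	# this combination forces fresh generation, skipping the cache entirely:
	#
	#   dataset = MazeDataset.from_config(
	#       cfg,
	#       load_local=False,   # skip priority 1 (local load)
	#       do_download=False,  # skip priority 2 (download)
	#       do_generate=True,   # fall through to priority 3 (generate)
	#       save_local=False,   # don't write the result to disk
	#   )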

	def data_hash(self) -> int:
		"""return a hash of the data"""
		return stable_hash(str(tuple([x.serialize() for x in self.mazes])))

	def __getitem__(self, i: int) -> SolvedMaze:
		"""get a maze by index"""
		return self.mazes[i]

	def __iter__(self) -> typing.Iterator[SolvedMaze]:
		"""iterate over the mazes"""
		return iter(self.mazes)

	def __deepcopy__(self, memo) -> "MazeDataset":  # noqa: ANN001
		"""deepcopy the dataset

		FIX: this isn't actually a deepcopy, I think?
		"""
		return MazeDataset.load(self._serialize_full())

	# TYPING: get type hints on the tokenizer here
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[False] = False,
	) -> list[list[str]]: ...
	@overload
	def as_tokens(
		self,
		maze_tokenizer,  # noqa: ANN001
		limit: int | None = None,
		join_tokens_individual_maze: Literal[True] = True,
	) -> list[str]: ...
	def as_tokens(
		self,
		maze_tokenizer,  # TODO: MazeTokenizer
		limit: int | None = None,
		join_tokens_individual_maze: bool = False,
	) -> list[list[str]] | list[str]:
		"""return the dataset as tokens according to the passed `maze_tokenizer`

		the `maze_tokenizer` should be either a `MazeTokenizer` or a `MazeTokenizerModular`

		if `join_tokens_individual_maze` is True, then the tokens of each maze are
		joined with a space, and the result is a list of strings.
		i.e.:

		>>> dataset.as_tokens(join_tokens_individual_maze=False)
		[["a", "b", "c"], ["d", "e", "f"]]
		>>> dataset.as_tokens(join_tokens_individual_maze=True)
		["a b c", "d e f"]
		"""
		output: list[list[str]] = [
			maze.as_tokens(maze_tokenizer) for maze in self.mazes[:limit]
		]
		if join_tokens_individual_maze:
			return [" ".join(tokens) for tokens in output]
		else:
			return output
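	# Usage sketch: assumes `MazeTokenizer` from `maze_dataset.tokenization`;
	# the default-constructed tokenizer and the limit of 10 are illustrative:
	#
	#   from maze_dataset.tokenization import MazeTokenizer
	#   token_strings: list[str] = dataset.as_tokens(
	#       MazeTokenizer(),
	#       limit=10,
	#       join_tokens_individual_maze=True,
	#   )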

	def __len__(self) -> int:
		"""return the number of mazes in the dataset"""
		return len(self.mazes)

	def __eq__(self, other: object) -> bool:
		"""compare two datasets"""
		if not isinstance(other, MazeDataset):
			raise NotImplementedError(
				"can only compare with other MazeDataset objects",
			)
		# TODO: compare hashes of data instead of the data itself?
		return self.cfg == other.cfg and self.mazes == other.mazes

	def assert_equal(self, other: "MazeDataset") -> None:
		"""assert that two datasets are equal"""
		assert isinstance(other, MazeDataset)
		assert self.cfg == other.cfg, f"{self.cfg.diff(other.cfg) = }"
		assert self.mazes == other.mazes, f"{self.mazes = }, {other.mazes = }"

	@classmethod
	def generate(
		cls,
		cfg: MazeDatasetConfig,
		gen_parallel: bool = False,
		pool_kwargs: dict | None = None,
		verbose: bool = False,
		# TODO: what to do when unexpected kwargs are passed?
		**kwargs,  # noqa: ARG003
	) -> "MazeDataset":
		"""Generate a maze dataset given a config and some generation parameters"""
		# Copy the config to avoid modifying the original
		cfg_cpy: MazeDatasetConfig = MazeDatasetConfig.load(
			json.loads(json.dumps(cfg.serialize())),
		)

		if pool_kwargs is None:
			pool_kwargs = dict()
		maze_indexes: Int[np.ndarray, " maze_index"] = np.arange(cfg_cpy.n_mazes)  # type: ignore[assignment]

		solved_mazes: list[SolvedMaze | None]
		# Configure tqdm for progress bar
		tqdm_kwargs: dict = dict(
			total=cfg_cpy.n_mazes,
			unit="maze",
			desc="generating & solving mazes",
			disable=not verbose,
		)
		# TODO: don't use the global unless generating in parallel!
		if gen_parallel:
			with multiprocessing.Pool(
				**pool_kwargs,
				initializer=_maze_gen_init_worker,
				initargs=(cfg_cpy,),
			) as pool:
				solved_mazes = list(
					tqdm.tqdm(
						pool.imap(_generate_maze_helper, maze_indexes),
						**tqdm_kwargs,
					),
				)

		else:
			_maze_gen_init_worker(cfg_cpy)
			solved_mazes = list(
				tqdm.tqdm(
					map(
						# TYPING: error: Argument 1 to "map" has incompatible type "Callable[[int], SolvedMaze | None]"; expected "Callable[[str], SolvedMaze | None]"  [arg-type]
						# why does it think tolist() returns a string?
						_generate_maze_helper,  # type: ignore[arg-type]
						maze_indexes.tolist(),
					),
					**tqdm_kwargs,
				),
			)

		# Filter out None values explicitly after ensuring all results are collected
		solved_mazes_: list[SolvedMaze] = [
			maze for maze in solved_mazes if maze is not None
		]
		# solved_mazes_ = list(filter(lambda x: x is not None, solved_mazes))

		# Update the config with the actual number of mazes
		cfg_cpy.n_mazes = len(solved_mazes_)

		dataset: MazeDataset = cls(
			cfg=cfg_cpy,
			mazes=solved_mazes_,
		)

		dataset.update_self_config()  # Call `update_self_config()` to ensure the dataset's config reflects changes

		np.random.seed(cfg_cpy.seed)  # Reset the seed to the value in the config copy

		return dataset
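	# Usage sketch: `gen_parallel` and `pool_kwargs` are real parameters above;
	# `pool_kwargs` is forwarded to `multiprocessing.Pool`, so `processes=4` is
	# just an illustrative worker count:
	#
	#   dataset = MazeDataset.generate(
	#       cfg,
	#       gen_parallel=True,
	#       pool_kwargs=dict(processes=4),
	#       verbose=True,
	#   )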

	@classmethod
	def download(cls, cfg: MazeDatasetConfig, **kwargs) -> "MazeDataset":
		"(not implemented yet!) download a maze dataset from the internet"
		raise NotImplementedError("not implemented yet")

	@classmethod
	def load(cls: "type[MazeDataset]", data: JSONdict) -> "MazeDataset":
		"""load from zanj/json"""
		if data[_FORMAT_KEY] == "MazeDataset:minimal":
			return cls._load_minimal(data)
		elif data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat":
			return cls._load_minimal_soln_cat(data)
		elif data[_FORMAT_KEY] == "MazeDataset":
			if (
				SERIALIZE_MINIMAL_THRESHOLD == -1
			):  # Allow access to `_load_legacy` for profiling
				return cls._load_legacy(data)
			return cls._load_full(data)
		else:
			err_msg: str = f"`_FORMAT_KEY` string {data[_FORMAT_KEY] = } is not a recognized `MazeDataset` format. ({_FORMAT_KEY = })"
			raise KeyError(
				err_msg,
			)

	@classmethod
	def _load_full(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			mazes=load_item_recursive(data["mazes"], tuple()),
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
		)

	@classmethod
	def _load_minimal(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset:minimal"
		return cls(
			cfg=MazeDatasetConfig.load(data["cfg"]),  # type: ignore[arg-type]
			generation_metadata_collected=data["generation_metadata_collected"],  # type: ignore[arg-type]
			mazes=[
				SolvedMaze(
					clist,
					soln[:slen, ...],
				)
				for clist, slen, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					load_item_recursive(data["maze_solution_lengths"], tuple()),
					load_item_recursive(data["maze_solutions"], tuple()),
					strict=False,
					# load_item_recursive(data["maze_endpoints"], tuple()),
				)
			],
		)

	@classmethod
	def _load_minimal_soln_cat(cls, data: JSONdict) -> "MazeDataset":
		assert data[_FORMAT_KEY] == "MazeDataset:minimal_soln_cat"

		maze_solution_lengths = load_item_recursive(
			data["maze_solution_lengths"],
			tuple(),
		)
		maze_solutions_concat = load_item_recursive(
			data["maze_solutions_concat"],
			tuple(),
		)
		maze_solutions = np.split(
			maze_solutions_concat,
			np.cumsum(maze_solution_lengths)[:-1],
			axis=0,
		)

		return cls(
			cfg=load_item_recursive(data["cfg"], tuple()),
			generation_metadata_collected=load_item_recursive(
				data["generation_metadata_collected"],
				tuple(),
			),
			mazes=[
				SolvedMaze(
					connection_list=clist,
					solution=soln,
				)
				for clist, soln in zip(
					load_item_recursive(data["maze_connection_lists"], tuple()),
					# load_item_recursive(data["maze_endpoints"], tuple()),
					maze_solutions,
					strict=False,
				)
			],
		)

	@classmethod
	def _load_legacy(cls, data: JSONdict) -> "MazeDataset":
		"""Legacy `load` method from <0.5.2. Used exclusively for profiling comparison."""
		assert data[_FORMAT_KEY] == "MazeDataset"
		return cls(
			**{
				key: load_item_recursive(data[key], tuple())
				for key in ["cfg", "mazes", "generation_metadata_collected"]
			},
		)

	def serialize(self) -> JSONdict:
		"""serialize to zanj/json"""
		if (
			SERIALIZE_MINIMAL_THRESHOLD is not None
			and len(self) >= SERIALIZE_MINIMAL_THRESHOLD
		):
			return self._serialize_minimal()
		return self._serialize_full()

	def _serialize_full(self) -> JSONdict:
		return {
			_FORMAT_KEY: "MazeDataset",
			"cfg": json_serialize(self.cfg),
			"fname": self.cfg.to_fname(),
			"mazes": json_serialize(self.mazes),
			"generation_metadata_collected": json_serialize(
				self.generation_metadata_collected,
			),
		}

	def _serialize_minimal(self) -> JSONdict:
		"alternate serialization where metadata is collected and mazes are stored in concatenated form"
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		max_solution_len: int = max(m.solution.shape[0] for m in filtered_meta.mazes)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n

		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		# maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
		maze_solution_lengths: np.ndarray = np.empty((n_mazes,), dtype=np.int32)
		maze_solutions: np.ndarray = np.empty(
			(n_mazes, max_solution_len, 2),
			dtype=np.int8,
		)

		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			# maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			maze_solution_lengths[idx] = maze.solution.shape[0]
			maze_solutions[idx, : maze.solution.shape[0]] = maze.solution

		return {
			_FORMAT_KEY: "MazeDataset:minimal",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			# "maze_endpoints": maze_endpoints,
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions": maze_solutions,  # type: ignore[dict-item]
		}
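	# Round-trip sketch: `_serialize_minimal` pads every solution out to the
	# longest one and records the true lengths, which `_load_minimal` uses to
	# slice the padding back off. Direct calls shown for illustration; in
	# practice this happens through `serialize`/`load` via ZANJ:
	#
	#   data = dataset._serialize_minimal()
	#   restored = MazeDataset._load_minimal(data)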

	def _serialize_minimal_soln_cat(self: "MazeDataset") -> JSONdict:
		"alternate serialization where metadata is collected, and mazes and their solutions are stored in concatenated form"
		filtered_meta: MazeDataset
		if self.generation_metadata_collected is None:
			filtered_meta = self.filter_by.collect_generation_meta()
		else:
			filtered_meta = self

		maze_solution_lengths: np.ndarray = np.array(
			[m.solution.shape[0] for m in filtered_meta.mazes],
			dtype=np.int32,
		)
		n_mazes: int = len(filtered_meta.mazes)
		grid_n: int = filtered_meta.cfg.grid_n
		total_solution_len: int = np.sum(maze_solution_lengths)

		maze_connection_lists: np.ndarray = np.empty(
			(n_mazes, 2, grid_n, grid_n),
			dtype=np.bool_,
		)
		maze_endpoints: np.ndarray = np.empty((n_mazes, 2, 2), dtype=np.int8)
		maze_solutions_concat: np.ndarray = np.empty(
			(total_solution_len, 2),
			dtype=np.int8,
		)

		solutions_running_idx: int = 0
		for idx, maze in enumerate(filtered_meta.mazes):
			maze_connection_lists[idx] = maze.connection_list
			maze_endpoints[idx] = np.array([maze.start_pos, maze.end_pos])
			soln_len: int = maze.solution.shape[0]
			maze_solution_lengths[idx] = soln_len
			maze_solutions_concat[
				solutions_running_idx : solutions_running_idx + soln_len
			] = maze.solution
			solutions_running_idx += soln_len

		return {
			_FORMAT_KEY: "MazeDataset:minimal_soln_cat",
			"cfg": json_serialize(filtered_meta.cfg),
			"fname": filtered_meta.cfg.to_fname(),
			"generation_metadata_collected": json_serialize(
				filtered_meta.generation_metadata_collected,
			),
			"maze_connection_lists": maze_connection_lists,  # type: ignore[dict-item]
			"maze_endpoints": maze_endpoints,  # type: ignore[dict-item]
			"maze_solution_lengths": maze_solution_lengths,  # type: ignore[dict-item]
			"maze_solutions_concat": maze_solutions_concat,  # type: ignore[dict-item]
		}
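	# Layout note (worked example, numbers illustrative): solution lengths
	# [3, 2, 4] give split indices np.cumsum(lengths)[:-1] == [3, 5], so the
	# (9, 2) concatenated array splits back into (3, 2), (2, 2), and (4, 2)
	# solutions in `_load_minimal_soln_cat`.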

	def update_self_config(self) -> None:
		"""update the config to match the current state of the dataset (number of mazes, such as after filtering)"""
		if self.cfg.n_mazes != len(self.mazes):
			warnings.warn(
				f"updating config n_mazes from {self.cfg.n_mazes} to {len(self.mazes)}",
			)
			self.cfg.n_mazes = len(self.mazes)

	def custom_maze_filter(
		self,
		method: typing.Callable[[SolvedMaze], bool],
		**kwargs,
	) -> "MazeDataset":
		"""filter the dataset using a custom method"""
		output: MazeDataset = MazeDataset(
			cfg=copy.deepcopy(self.cfg),
			mazes=[m for m in self.mazes if method(m, **kwargs)],
		)
		output.cfg.applied_filters.append(
			{
				"name": f"__custom__:{method.__name__}",
				"kwargs": kwargs,
			},
		)
		output.update_self_config()
		return output
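	# Usage sketch (the predicate and threshold are illustrative; any
	# `Callable[[SolvedMaze], bool]` works):
	#
	#   long_mazes: MazeDataset = dataset.custom_maze_filter(
	#       lambda m: m.solution.shape[0] >= 5,
	#   )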

MazeDatasetConfig._dataset_class = property(  # type: ignore[method-assign, assignment]
	lambda self: MazeDataset,  # noqa: ARG005
)

# register things with zanj
register_loader_handler(
	LoaderHandler(
		check=lambda json_item, path=None, z=None: (  # type: ignore[misc] # noqa: ARG005
			isinstance(json_item, typing.Mapping)
			and _FORMAT_KEY in json_item
			and json_item[_FORMAT_KEY].startswith("MazeDataset")
		),
		load=lambda json_item, path=None, z=None: MazeDataset.load(json_item),  # type: ignore[misc] # noqa: ARG005
		uid="MazeDataset",
		source_pckg="maze_dataset.generation.maze_dataset",
		desc="MazeDataset",
	),
)
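# With the handler registered, a ZANJ file containing any "MazeDataset*" format
# is routed back through `MazeDataset.load`. Sketch (the path is illustrative,
# and `ZANJ.save`/`ZANJ.read` are assumed from the `zanj` package):
#
#   zanj = ZANJ()
#   zanj.save(dataset, "data/demo.zanj")
#   loaded: MazeDataset = zanj.read("data/demo.zanj")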

# TODO: the code below is for doing some smarter collecting and type checking. Probably will delete.
"""
collect either the type at the field, or the shape of the field if it is an array
metadata_types: dict[str, set[type | tuple]] = dict()
for maze in new_dataset:
	for key, value in maze.generation_meta.items():
		if key not in metadata_types:
			metadata_types[key] = set()

		if isinstance(value, np.ndarray):
			metadata_types[key].add(value.shape)
		else:
			metadata_types[key].add(type(value))

# figure out what to do for each field
metadata_actions: dict[str, typing.Callable] = dict()
for key, key_type in metadata_types.items():
	if all(isinstance(kt, tuple) for kt in key_type):
		if all(kt == (2,) for kt in key_type):
			# it's all coords, do a statcounter on those coords
			metadata_actions[key] = lambda vals: Counter(tuple(x) for x in vals)
		elif all(
			(len(kt) == 2) and (kt[1] == 2)
			for kt in key_type
		):
			# it's a list of coords, do a statcounter on those coords
			metadata_actions[key] = lambda vals: Counter(
				tuple(x) for x in np.concatenate(vals)
			)
		else:
			# it's a list of something else, do a counter on those
			# TODO: throw except here?
			metadata_actions[key] = Counter

	elif all(kt in (bool, int, float) for kt in key_type):
		# statcounter for numeric types
		metadata_actions[key] = StatCounter
	elif all(kt == str for kt in key_type):
		# counter for string types
		metadata_actions[key] = Counter
	else:
		# counter for everything else
		# TODO: throw except here?
		metadata_actions[key] = Counter
"""

632"""