Coverage for maze_dataset/dataset/maze_dataset_config.py: 24%

118 statements  

coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"implements `MazeDatasetConfig` which is used to generate or load a dataset" 

2 

3import hashlib 

4import importlib.metadata 

5import json 

6import typing 

7import warnings 

8from typing import Callable 

9 

10import numpy as np 

11from jaxtyping import Float 

12from muutils.json_serialize import ( 

13 serializable_dataclass, 

14 serializable_field, 

15) 

16from muutils.json_serialize.util import ( 

17 safe_getsource, 

18 string_as_lines, 

19) 

20from muutils.misc import sanitize_fname, shorten_numerical_to_str 

21 

22from maze_dataset.constants import Coord, CoordTup 

23from maze_dataset.dataset.dataset import ( 

24 GPTDatasetConfig, 

25) 

26from maze_dataset.dataset.success_predict_math import cfg_success_predict_fn 

27from maze_dataset.generation.generators import _GENERATORS_PERCOLATED, GENERATORS_MAP 

28 

29SERIALIZE_MINIMAL_THRESHOLD: int | None = 100 

30"""If `n_mazes>=SERIALIZE_MINIMAL_THRESHOLD`, then the MazeDataset will use `serialize_minimal`. 

31Setting to None means that `serialize_minimal` will never be used. 

32Set to -1 to make calls to `read` use `MazeDataset._load_legacy`. Used for profiling only.""" 

MAZEDATASETCONFIG_FNAME_HASH_LENGTH: int = 5
"length, in characters, of the hash in the fname of a `MazeDatasetConfig`"

_PercolationSuccessArray = Float[
    np.ndarray,
    "p/grid_n/deadends/endpoints_not_equal/generator_func=5",
]


class NoPercolationInConfigError(ValueError):
    """raised when trying to predict the success fraction of a config that doesn't have percolation"""

    pass


class SuccessChanceTooSmallError(ValueError):
    """raised when the success fraction is below the threshold in `MazeDatasetConfig.success_fraction_compensate`"""

    pass

def set_serialize_minimal_threshold(threshold: int | None) -> None:
    "set the global SERIALIZE_MINIMAL_THRESHOLD"
    global SERIALIZE_MINIMAL_THRESHOLD  # noqa: PLW0603
    SERIALIZE_MINIMAL_THRESHOLD = threshold
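# Example (illustrative, not part of the original module): raising the threshold so that
# only datasets with at least 10_000 mazes use `serialize_minimal`, or disabling it entirely.
# >>> set_serialize_minimal_threshold(10_000)
# >>> set_serialize_minimal_threshold(None)  # `serialize_minimal` will never be used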

def _load_maze_ctor(maze_ctor_serialized: str | dict) -> Callable:
    "get the maze constructor from `GENERATORS_MAP`"
    if isinstance(maze_ctor_serialized, dict):
        # this is both the new and the old version of the serialization
        return GENERATORS_MAP[maze_ctor_serialized["__name__"]]
    elif isinstance(maze_ctor_serialized, str):
        # this is a string format that was used for a while; we have since switched back to the dict format
        warnings.warn(
            "you are loading an old model/config in `_load_maze_ctor()`!!! this should not be happening, please report: "
            "https://github.com/understanding-search/maze-dataset/issues/new",
        )
        return GENERATORS_MAP[maze_ctor_serialized]
    else:
        err_msg: str = f"maze_ctor_serialized is of type {type(maze_ctor_serialized) = }, expected str or dict\n{maze_ctor_serialized = }"
        raise TypeError(err_msg)
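# Example (illustrative sketch, not from the original source): the dict form is what
# `MazeDatasetConfig_base.maze_ctor` serializes to, so loading looks up the generator
# by its `__name__` in `GENERATORS_MAP` (the extra keys are ignored).
# >>> ctor = _load_maze_ctor({"__name__": "gen_dfs", "__module__": "...", "__doc__": [], "source_code": "..."})
# >>> ctor is GENERATORS_MAP["gen_dfs"]
# True
# >>> _load_maze_ctor("gen_dfs")  # legacy string form, emits a warning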

EndpointKwargsType = dict[
    typing.Literal[
        "allowed_start",
        "allowed_end",
        "deadend_start",
        "deadend_end",
        "endpoints_not_equal",
        "except_on_no_valid_endpoint",
    ],
    bool | None | list[tuple[int, int]],
]
"""type hint for `MazeDatasetConfig.endpoint_kwargs`

- `except_on_no_valid_endpoint : bool` (default: `True`)
    some conditions (requiring dead ends when a maze is very open, or no path existing between the given start and end) can cause maze generation to fail.
    if `except_on_no_valid_endpoint` is `True`, maze generation will raise an error when it fails to produce a valid maze.
    if `False`, the resulting dataset may contain fewer mazes than requested.
    if you are generating large datasets, consider using `MazeDatasetConfig.success_fraction_compensate()`,
    which uses a pysr-created function to roughly estimate the success fraction of the dataset.
- `allowed_start : list[tuple[int, int]]` (default: `None`)
    list of allowed starting position coordinates
- `allowed_end : list[tuple[int, int]]` (default: `None`)
    list of allowed ending position coordinates
- `deadend_start : bool` (default: `False`)
    if `True`, the starting position must be a dead end
- `deadend_end : bool` (default: `False`)
    if `True`, the ending position must be a dead end
- `endpoints_not_equal : bool` (default: `True`)
    if `True`, the starting and ending positions must be different
"""

def _load_endpoint_kwargs(data: dict) -> EndpointKwargsType:
    if data.get("endpoint_kwargs") is None:
        return dict()

    else:
        return {
            k: (
                # bools and Nones are fine
                v
                if (isinstance(v, bool) or v is None)
                # assume it's a CoordList
                else [tuple(x) for x in v]  # muutils/zanj saves tuples as lists
            )
            for k, v in data["endpoint_kwargs"].items()
        }
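# Example (illustrative, not from the original source): coordinate lists saved as
# lists-of-lists by muutils/zanj come back as lists of tuples, while bools pass through.
# >>> _load_endpoint_kwargs({"endpoint_kwargs": {"allowed_start": [[0, 0], [1, 2]], "deadend_end": True}})
# {'allowed_start': [(0, 0), (1, 2)], 'deadend_end': True}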

# not private because we need this to show up in docs
@serializable_dataclass(kw_only=True, properties_to_serialize=["grid_shape"])
class MazeDatasetConfig_base(GPTDatasetConfig):  # noqa: N801
    """base config -- we serialize, dump to json, and hash this to get the fname. all actual variables we want to be hashed are here"""

    # NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only

    grid_n: int = serializable_field()  # type: ignore[misc]

    # not comparing n_mazes is done primarily to avoid conflicts which happen during `from_config` when we have applied filters
    n_mazes: int = serializable_field(compare=False)  # type: ignore[misc]

    maze_ctor: Callable = serializable_field(
        default=GENERATORS_MAP["gen_dfs"],
        serialization_fn=lambda gen_func: {
            "__name__": gen_func.__name__,
            "__module__": gen_func.__module__,
            # NOTE: this was causing hashing issues on 3.13 vs older versions because the
            # `__doc__` attribute treats whitespace differently across versions,
            # so we strip each line of it here.
            # see:
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080746?pr=53
            # https://github.com/understanding-search/maze-dataset/actions/runs/14028046497/job/39270080742?pr=53
            # https://www.diffchecker.com/tqIMSevy/
            # update: we also need to filter out empty lines.
            "__doc__": [
                line.strip()
                for line in string_as_lines(gen_func.__doc__)
                if line.strip()
            ],
            "source_code": safe_getsource(gen_func),
        },
        loading_fn=lambda data: _load_maze_ctor(data["maze_ctor"]),
        assert_type=False,  # TODO: check the type here once muutils supports checking Callable signatures
    )

    maze_ctor_kwargs: dict = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=lambda data: (
            dict()
            if data.get("maze_ctor_kwargs", None)
            is None  # this should handle the backwards compatibility
            else data["maze_ctor_kwargs"]
        ),
    )

    endpoint_kwargs: EndpointKwargsType = serializable_field(
        default_factory=dict,
        serialization_fn=lambda kwargs: kwargs,
        loading_fn=_load_endpoint_kwargs,
        assert_type=False,
    )

    # NOTE: this part is very hacky. the way muutils works is that it iterates over the *keys in the serialized data*,
    # so we need to save a `None` here or the `fname` field won't be loaded.
    # this is a total mess, and very confusing, and entirely my fault
    _fname_loaded: str | None = serializable_field(
        default=None,
        compare=False,
        serialization_fn=lambda _: None,
        loading_fn=lambda data: data.get("fname", None),
    )

    @property
    def grid_shape(self) -> CoordTup:
        """return the shape of the grid as a tuple"""
        return (self.grid_n, self.grid_n)

    @property
    def grid_shape_np(self) -> Coord:
        """return the shape of the grid as a numpy array"""
        return np.array(self.grid_shape)

    @property
    def max_grid_n(self) -> int:
        """return the maximum of the grid shape"""
        return max(self.grid_shape)

    def _serialize_base(
        self, applied_filters__skip__collect_generation_meta: bool = True
    ) -> dict:
        """serialize the base config for use in `stable_hash_cfg()` and `to_fname()`

        - note that `_fname_loaded` will always be `None` to avoid infinite recursion
        - note that we **do not** by default include information about metadata collection here,
          since otherwise loading a dataset that we minified by collecting the metadata would be impossible.
          but for comparing things, we do store it when serializing properly by setting
          `applied_filters__skip__collect_generation_meta=False`
        """
        serialized: dict = MazeDatasetConfig_base.serialize(self)
        if applied_filters__skip__collect_generation_meta:
            serialized["applied_filters"] = [
                x
                for x in serialized["applied_filters"]
                if x.get("name", None) != "collect_generation_meta"
            ]
        return serialized

    def _stable_str_dump(self) -> str:
        return json.dumps(
            self._serialize_base(),
            sort_keys=True,
            indent=None,
        )

    def stable_hash_cfg(self) -> int:
        """return a stable hash of the config"""
        return int.from_bytes(
            hashlib.md5(  # noqa: S324
                bytes(self._stable_str_dump(), "ascii")
            ).digest(),
            "big",
        )

    def to_fname(self) -> str:
        """return a unique identifier (valid as a filename) for this config"""
        n_mazes_str: str = shorten_numerical_to_str(self.n_mazes)
        maze_ctor_name: str = self.maze_ctor.__name__.removeprefix("gen_")
        hash_id: int = self.stable_hash_cfg() % 10**MAZEDATASETCONFIG_FNAME_HASH_LENGTH
        return sanitize_fname(
            f"{self.name}-g{self.grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}",
        )
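    # Example (illustrative, not from the original source): the fname follows the pattern
    # "{name}-g{grid_n}-n{n_mazes_str}-a_{maze_ctor_name}-h{hash_id}", so a config with
    # name="test", grid_n=5, n_mazes=10_000 and the default gen_dfs generator would yield
    # something along the lines of "test-g5-n10k-a_dfs-h12345" (the shortened count string
    # and hash digits depend on `shorten_numerical_to_str` and the md5 of the serialized config).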

# NOTE: type: ignore[misc] is because it tells us non-default attributes aren't allowed after ones with defaults, but everything is kw_only
@serializable_dataclass(kw_only=True, methods_no_override=["serialize"])
class MazeDatasetConfig(MazeDatasetConfig_base):  # type: ignore[misc]
    """config object which is passed to `MazeDataset.from_config` to generate or load a dataset

    # Parameters:
    - `name : str`
        name of the dataset -- this can be anything, but should be filesystem safe since we use it in the `fname`
    - `grid_n : int`
        grid size of the maze (number of rows/columns)
    - `n_mazes : int`
        number of mazes to request. For some combinations of `endpoint_kwargs` and `maze_ctor`, not all mazes might successfully generate.
        see `EndpointKwargsType` for more details.
    - `maze_ctor : Callable`
        maze generator function. This should be a function that takes a grid size and returns a maze.
        This will usually be one of the functions in `LatticeMazeGenerators`.
    - `maze_ctor_kwargs : dict`
        keyword arguments to pass to the maze generator function. Specific to the `maze_ctor` you are using.
    - `endpoint_kwargs : EndpointKwargsType`
        keyword arguments passed to `LatticeMaze.generate_random_path()`. see `EndpointKwargsType` for more info.
    - `applied_filters : list[dict]`
        list of filters that have been applied to the dataset. We recommend applying filters to datasets directly,
        but these are stored with the config in case you want to re-generate the dataset with the same filters.

    """

    @property
    def config_version(self) -> str:
        """return the version of the config. added in maze_dataset v1.3.0, previous versions had no config version"""
        return "1.0"

    @property
    def versions(self) -> dict:
        """return the versions of the config and the maze_dataset"""
        return dict(
            config=self.config_version,
            maze_dataset=importlib.metadata.version("maze_dataset"),
        )

    def serialize(self) -> dict:
        "serialize the MazeDatasetConfig with all fields and fname"
        return {
            **self._serialize_base(
                applied_filters__skip__collect_generation_meta=False
            ),
            "fname": self.to_fname(),
            "versions": self.versions,
        }

    def summary(self) -> dict:
        """return a summary of the config"""
        # run the parent summary to make sure it doesn't error
        super_summary: dict = super().summary()
        assert super_summary
        self_ser: dict = self.serialize()
        return dict(
            name=self.name,
            fname=self.to_fname(),
            sdc_hash=self.stable_hash_cfg(),
            seed=self.seed,
            seq_len_min=self.seq_len_min,
            seq_len_max=self.seq_len_max,
            applied_filters=self.applied_filters,
            grid_n=self_ser["grid_n"],
            n_mazes=self_ser["n_mazes"],
            maze_ctor_name=self_ser["maze_ctor"]["__name__"],
            maze_ctor_kwargs=self_ser["maze_ctor_kwargs"],
            endpoint_kwargs=self_ser["endpoint_kwargs"],
        )
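    # Example (illustrative, not from the original source): `summary()` flattens the
    # serialized config into a single dict, roughly of the form
    # >>> MazeDatasetConfig(name="test", grid_n=5, n_mazes=10).summary()
    # {'name': 'test', 'fname': 'test-g5-n10-a_dfs-h...', 'sdc_hash': ..., 'seed': ...,
    #  'seq_len_min': ..., 'seq_len_max': ..., 'applied_filters': [], 'grid_n': 5,
    #  'n_mazes': 10, 'maze_ctor_name': 'gen_dfs', 'maze_ctor_kwargs': {}, 'endpoint_kwargs': {}}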

    def _to_ps_array(self) -> _PercolationSuccessArray:
        """Convert this config to a [p, grid_n, deadends, endpoints_not_equal, generator_func] vector.

        used in predicting the success rate
        """
        try:
            assert self.maze_ctor.__name__ in _GENERATORS_PERCOLATED, (
                f"generator not supported, must be a percolation generator\n{self.maze_ctor.__name__ = }, {_GENERATORS_PERCOLATED = }"
            )
            assert "p" in self.maze_ctor_kwargs, (
                f"maze_ctor_kwargs must have a 'p' (percolation value) key: {self.maze_ctor_kwargs = }"
            )
            assert not self.endpoint_kwargs.get("except_on_no_valid_endpoint", True), (
                f"except_on_no_valid_endpoint must be False, or else if any maze fails to generate, the whole dataset will fail: {self.endpoint_kwargs = }"
            )
        except AssertionError as e:
            err_msg: str = f"invalid config for percolation success prediction: {self.summary() = }"
            raise NoPercolationInConfigError(
                err_msg,
            ) from e

        endpoints_unique_flag: int = int(
            # we are pretty sure it will be an int or bool here
            self.endpoint_kwargs.get("endpoints_not_equal", True),  # type: ignore[arg-type]
        )

        # adjustment for bknutson0
        if not (
            self.endpoint_kwargs.get("deadend_start", False)
            and self.endpoint_kwargs.get("deadend_end", False)
        ):
            # we didn't train on this, but if either endpoint is not required to be in a dead end
            # then requiring the endpoints to be unique does not really affect the success rate
            # (except for very small percolation values, pure percolation generation)
            endpoints_unique_flag = 0

        return np.array(
            [
                float(self.maze_ctor_kwargs["p"]),
                float(self.grid_n),
                float(
                    int(
                        self.endpoint_kwargs.get("deadend_start", False)  # type: ignore[arg-type]
                        or self.endpoint_kwargs.get("deadend_end", False),
                    ),
                ),
                float(endpoints_unique_flag),
                float(_GENERATORS_PERCOLATED.index(self.maze_ctor.__name__)),
            ],
            dtype=np.float64,
        )
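    # Example (illustrative, not from the original source): for a percolation config with
    # p=0.4, grid_n=8, both endpoints required to be dead ends, endpoints_not_equal=True,
    # and a generator sitting at index 1 of `_GENERATORS_PERCOLATED` (assumed for illustration),
    # the vector would look like
    # >>> cfg._to_ps_array()
    # array([0.4, 8. , 1. , 1. , 1. ])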

    @classmethod
    def _from_ps_array(
        cls,
        arr: _PercolationSuccessArray,
        name: str = "predict",
        n_mazes: int = 100,
        **kwargs,
    ) -> "MazeDatasetConfig":
        """Reconstruct a config from an array [p, grid_n, deadends, endpoints_not_equal, generator_func] and other config parameters.

        # Returns:
        - `MazeDatasetConfig`
            Config corresponding to `arr`
        """
        return cls(
            name=name,
            grid_n=int(arr[1]),
            n_mazes=n_mazes,
            maze_ctor=GENERATORS_MAP[_GENERATORS_PERCOLATED[int(arr[4])]],
            maze_ctor_kwargs={"p": float(arr[0])},
            endpoint_kwargs=dict(
                deadend_start=bool(arr[2]),
                deadend_end=bool(arr[2]),
                endpoints_not_equal=bool(arr[3]),
                except_on_no_valid_endpoint=False,
            ),
            **kwargs,
        )
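    # Example (illustrative, not from the original source): `_from_ps_array` is roughly the
    # inverse of `_to_ps_array`, so a round-trip should preserve the fields the array encodes
    # (p, grid_n, deadend flags, endpoints_not_equal, generator index).
    # >>> cfg2 = MazeDatasetConfig._from_ps_array(cfg._to_ps_array(), name=cfg.name, n_mazes=cfg.n_mazes)
    # >>> np.allclose(cfg2._to_ps_array(), cfg._to_ps_array())
    # True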

    def success_fraction_estimate(
        self,
        except_if_all_success_expected: bool = False,
    ) -> float:
        """Estimate the success fraction of this config.

        only valid when the generator is a percolation generator,
        and endpoints are enforced to be dead ends

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `except_if_all_success_expected : bool`
            if `True`, raise `NoPercolationInConfigError` when the config has no percolation and so is not expected to fail.
            if `False`, such configs simply return a success fraction of `1.0`.

        # Returns:
        - `float`
            estimated success fraction

        # Raises:
        - `NoPercolationInConfigError` : if the config is not expected to fail (no percolation), and `except_if_all_success_expected` is `True`
        """
        try:
            return cfg_success_predict_fn(self)

        except NoPercolationInConfigError as e:
            if except_if_all_success_expected:
                raise e  # noqa: TRY201
            return 1.0
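    # Example (illustrative, not from the original source): for a non-percolation config
    # (e.g. plain `gen_dfs`), every maze is expected to generate successfully, so the
    # estimate falls back to 1.0 unless `except_if_all_success_expected=True`.
    # >>> MazeDatasetConfig(name="test", grid_n=5, n_mazes=10).success_fraction_estimate()
    # 1.0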

    def success_fraction_compensate(
        self,
        safety_margin: float = 1.2,
        except_if_all_success_expected: bool = False,
        epsilon: float = 1e-2,
    ) -> "MazeDatasetConfig":
        """return a new `MazeDatasetConfig` like this one, with `n_mazes` adjusted to compensate for the success fraction

        calls `MazeDatasetConfig.success_fraction_estimate()` to get the success fraction, and then
        computes the new number of mazes as `n_mazes = int(n_mazes * safety_margin / success_fraction) + 1`

        more information on where this comes from can be found in
        - `cfg_success_predict_fn()` from `maze_dataset.dataset.success_predict_math`
        - `estimate_dataset_fractions.ipynb`
        - `maze_dataset.benchmarks.sweep_fit`

        # Parameters:
        - `safety_margin : float`
            safety margin to apply to the success fraction estimate
            (defaults to `1.2`, or 20% more mazes than estimated)
        - `except_if_all_success_expected : bool`
            passed to `MazeDatasetConfig.success_fraction_estimate()`.
            if `True`, raise `NoPercolationInConfigError` when the config is not expected to fail at all.
            if `False`, such configs get an estimated success fraction of `1.0`, which means `safety_margin`
            is still applied and you may generate more mazes than needed.
            (defaults to `False`)
        - `epsilon : float`
            raise `SuccessChanceTooSmallError` if the success fraction is below this threshold
            (defaults to `1e-2`)

        # Returns:
        - `MazeDatasetConfig`
            new config with adjusted `n_mazes`

        # Raises:
        - `SuccessChanceTooSmallError` : if the computed success fraction is below `epsilon`
        """
        # compute and check the success fraction
        success_fraction: float = self.success_fraction_estimate(
            except_if_all_success_expected=except_if_all_success_expected,
        )
        if success_fraction < epsilon:
            err_msg: str = (
                f"{success_fraction = } is below the threshold of {epsilon = }"
            )
            raise SuccessChanceTooSmallError(
                err_msg,
            )

        # compute the new number of mazes
        n_mazes: int = self.n_mazes
        new_n_mazes: int = int((n_mazes * safety_margin) / success_fraction) + 1

        # put it in a new config and return
        cfg_dict: dict = self.serialize()
        cfg_dict["n_mazes"] = new_n_mazes
        return MazeDatasetConfig.load(cfg_dict)
496 return MazeDatasetConfig.load(cfg_dict)