Coverage for maze_dataset/tokenization/modular/all_instances.py: 97%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"`all_instances`, `FiniteValued`, and related code for tokenizers" 

2 

3import enum 

4import itertools 

5import typing 

6from dataclasses import Field # noqa: TC003 

7from functools import cache, wraps 

8from types import UnionType 

9from typing import ( 

10 Callable, 

11 Generator, 

12 Iterable, 

13 Literal, 

14 TypeVar, 

15 get_args, 

16 get_origin, 

17) 

18 

19try: 

20 import frozendict 

21except ImportError as e: 

22 raise ImportError( 

23 "You need to install the `frozendict` package to use `all_instances` -- try installing `maze_dataset[tokenization]`" 

24 ) from e 

25from muutils.misc import IsDataclass, flatten, is_abstract 

26 

27FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum) 

28""" 

29# `FiniteValued` 

30The details of this type are not possible to fully define via the Python 3.10 typing library. 

31This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space. 

32`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing. 

33These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below). 

34The leaves of the tree must always be Primitive Types. 

35 

36# `FiniteValued` Subtypes 

37*: Indicates that this subtype is not yet supported by `all_instances` 

38 

39## Non-`FiniteValued` (Unbounded) Types 

40These are NOT valid subtypes, and are listed for illustrative purposes only. 

41This list is not comprehensive. 

42While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite, 

43they are considered unbounded types in this context. 

44- No Container subtype may contain any of these unbounded subtypes. 

45- `int` 

46- `float` 

47- `str` 

48- `list` 

49- `set`: Set types without a `FiniteValued` argument are unbounded 

50- `tuple`: Tuple types without a fixed length are unbounded 

51 

52## Primitive Types 

53Primitive types are non-nested types which resolve directly to a concrete range of values 

54- `bool`: has 2 possible values 

55- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members 

56- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition. 

57This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy. 

58 

59## Container Types 

60Container types are types which contain zero or more fields of `FiniteValued` type. 

61The range of a container type is the cartesian product of the ranges of its field types, except for `set[FiniteValued]`. 

62- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`. 

63- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`. 

64- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed. 

65- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type. 

66 

67## Superclass Types 

68Superclass types don't directly contain data members like container types. 

69Their range is the union of the ranges of their subtypes. 

70- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types 

71- *`IsDataclass`: Concrete dataclasses which also have their own subclasses. 

72- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types 

73- `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3] 

74""" 

75 

76 

def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Helper function for `all_instances`.

    Filters `vals` according to `validation_funcs`.
    If `type_` is a regular type, searches in MRO order in `validation_funcs` and applies the first match, if any.
    Handles generic types supported by `all_instances` with special `if` clauses.

    # Parameters
    - `type_: FiniteValued`: A type
    - `vals: Generator[FiniteValued, None, None]`: Instances of `type_`
    - `validation_funcs: dict`: Collection of types mapped to filtering validation functions

    # Returns
    An iterable over the instances that pass the applicable validation function.
    `vals` itself is returned unchanged when `validation_funcs` is `None` or
    when no entry of `validation_funcs` applies to `type_`.
    """
    if validation_funcs is None:
        return vals

    # Exact match on `type_` itself. This is the only way a `UnionType` can be matched,
    # since unions have no `__mro__` to search.
    if type_ in validation_funcs:
        exact_check = validation_funcs[type_]
        # A generator expression (rather than `filter`) keeps the declared return type honest.
        return (val for val in vals if exact_check(val))

    # Generic types like UnionType, Literal don't have `__mro__`
    if hasattr(type_, "__mro__"):
        for superclass in type_.__mro__:
            if superclass in validation_funcs:
                # Only the first validation function hit in the mro is applied
                inherited_check = validation_funcs[superclass]
                return (val for val in vals if inherited_check(val))
        return vals

    if get_origin(type_) == Literal:
        # Each literal value is validated on its own, keyed by its runtime type.
        # The candidates are taken from the Literal's arguments, not from `vals`.
        return flatten(
            (
                _apply_validation_func(type(literal_val), [literal_val], validation_funcs)
                for literal_val in get_args(type_)
            ),
            levels_to_flatten=1,
        )

    return vals

119 

120 

121# TYPING: some better type hints would be nice here 

def _all_instances_wrapper(f: Callable) -> Callable:
    """Converts dicts to frozendicts to allow caching and applies `_apply_validation_func`.

    # Parameters
    - `f: Callable`: The function to wrap. It is called as `f(type_, validation_funcs)`,
      where `validation_funcs` is either a `frozendict` or `None`.

    # Returns
    A wrapper accepting `(type_, validation_funcs=None)`, each positionally or by keyword.
    The wrapper returns the output of `f`, filtered by `_apply_validation_func`.

    # Raises
    - `TypeError`: If the wrapper is called without a `type_` argument.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        # NOTE(review): `cached_wrapper` is re-created on every call to `wrapper`,
        # so this cache starts empty each time and gives no memoization across calls.
        # It does still require `type_` and `validation_funcs` to be hashable.
        @cache
        def cached_wrapper(  # noqa: ANN202
            type_: type,
            all_instances_func: Callable,
            validation_funcs: (
                frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]]
                | None
            ),
        ):
            return _apply_validation_func(
                type_,
                all_instances_func(type_, validation_funcs),
                validation_funcs,
            )

        # `type_` is the first parameter of the wrapped function; accept it by keyword too.
        if args:
            type_ = args[0]
        elif "type_" in kwargs:
            type_ = kwargs["type_"]
        else:
            msg = f"{f.__name__}() missing 1 required positional argument: 'type_'"
            raise TypeError(msg)

        validation_funcs: frozendict.frozendict | None
        # `validation_funcs` is the second parameter, so it was passed positionally
        # only if at least 2 positional arguments were given.
        if len(args) >= 2 and args[1] is not None:  # noqa: PLR2004
            validation_funcs = frozendict.frozendict(args[1])
        elif "validation_funcs" in kwargs and kwargs["validation_funcs"] is not None:
            validation_funcs = frozendict.frozendict(kwargs["validation_funcs"])
        else:
            validation_funcs = None
        return cached_wrapper(type_, f, validation_funcs)

    return wrapper

153 

154 

class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    either has unbounded possible values or is not supported (Enum is not supported)
    """

    def __init__(self, type_: type) -> None:
        """Constructs an error message with the type and mro of the type.

        # Parameters
        - `type_: type`: The offending type. It may lack an `__mro__` attribute,
          in which case the mro is reported as `None`.
        """
        # Not every unsupported input has `__mro__`. Reading it unconditionally
        # would raise `AttributeError` and mask this error.
        mro = getattr(type_, "__mro__", None)
        msg: str = (
            f"Type {type_} is not supported by `all_instances`. "
            f"See docstring for details. type_.__mro__ = {mro!r}"
        )
        super().__init__(msg)

165 

166 

@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `FiniteValued`
    - A finite-valued, fixed-length Generic tuple type.
      E.g., `tuple[bool]`, `tuple[bool, MyEnum]` are OK.
      `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    For concrete dataclasses, only fields that are constructor parameters are enumerated.
    `ClassVar`/`InitVar` pseudo-fields and fields declared with `init=False` are skipped.

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
        - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
        - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
        - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
      Because this is a generator function, the error surfaces when the result is first iterated.
    """
    # Local import: the module only imports `Field` from `dataclasses` at top level.
    import dataclasses

    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(subclass, validation_funcs)
                    for subclass in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible combinations of fields.
            # `dataclasses.fields` omits `ClassVar`/`InitVar` pseudo-fields; `init=False`
            # fields are dropped as well since the constructor does not accept them.
            init_fields: tuple[Field, ...] = tuple(
                fld for fld in dataclasses.fields(type_) if fld.init
            )
            field_names: list[str] = [fld.name for fld in init_fields]
            all_arg_sequences: Iterable = itertools.product(
                *[all_instances(fld.type, validation_funcs) for fld in init_fields],
            )
            yield from (
                type_(**dict(zip(field_names, arg_values, strict=True)))
                for arg_values in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(item_type, validation_funcs)
                        for item_type in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [
                    all_instances(member_type, validation_funcs)
                    for member_type in get_args(type_)
                ],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)