Coverage for maze_dataset/tokenization/modular/all_instances.py: 97%
64 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-09 12:48 -0600
1"`all_instances`, `FiniteValued`, and related code for tokenizers"
3import enum
4import itertools
5import typing
6from dataclasses import Field # noqa: TC003
7from functools import cache, wraps
8from types import UnionType
9from typing import (
10 Callable,
11 Generator,
12 Iterable,
13 Literal,
14 TypeVar,
15 get_args,
16 get_origin,
17)
19try:
20 import frozendict
21except ImportError as e:
22 raise ImportError(
23 "You need to install the `frozendict` package to use `all_instances` -- try installing `maze_dataset[tokenization]`"
24 ) from e
25from muutils.misc import IsDataclass, flatten, is_abstract
FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values.
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
  This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of the ranges of its field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete classes: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed size of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract (non-dataclass) classes whose subclasses are all `FiniteValued` superclass or container types
- Unions: Any union of `FiniteValued` types, e.g., `bool | Literal[2, 3]`.
  Both `types.UnionType` (`A | B` of plain classes) and `typing.Union` are accepted.
"""
def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Filter the instances `vals` of `type_` using the matching entry of `validation_funcs`.

    Helper for `all_instances`. The validator is chosen in this order:
    1. `type_` itself, as an exact key. This is the only lookup that can match
       a union, because union types have no `__mro__`.
    2. For types that have an `__mro__`, the first class in MRO order that is a
       key. Only that single validator is applied.
    3. For `Literal[...]` types, each literal value is filtered on its own,
       using the validator found for that value's type. `vals` is not consulted
       in this case.
    If nothing matches, or `validation_funcs` is `None`, `vals` is returned as-is.

    # Parameters
    - `type_: FiniteValued`: the type whose instances are being filtered
    - `vals: Generator[FiniteValued, None, None]`: instances of `type_`
    - `validation_funcs: frozendict | None`: types mapped to predicates; an
      instance is kept when its predicate returns a truthy value

    # Returns
    An iterable of the surviving instances. Despite the annotation, this may be
    a `filter` object, the result of `flatten`, or `vals` itself.
    """
    if validation_funcs is None:
        return vals

    if type_ in validation_funcs:
        # TYPING: this is a `filter[FiniteValued]`, not a `Generator`
        return filter(validation_funcs[type_], vals)

    if hasattr(type_, "__mro__"):
        # Generic forms such as unions and `Literal` have no `__mro__`.
        first_hit = next(
            (klass for klass in type_.__mro__ if klass in validation_funcs),
            None,
        )
        if first_hit is None:
            return vals
        # TYPING: this is a `filter[FiniteValued]`, not a `Generator`
        return filter(validation_funcs[first_hit], vals)

    if get_origin(type_) == Literal:
        per_literal = (
            _apply_validation_func(type(literal), [literal], validation_funcs)
            for literal in get_args(type_)
        )
        return flatten(per_literal, levels_to_flatten=1)

    return vals
# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
    """Normalize the arguments of `f` and filter its output with `_apply_validation_func`.

    The returned wrapper:
    - accepts `type_` (the first parameter of `f`) positionally or by keyword;
    - accepts `validation_funcs` (the second parameter of `f`) positionally or
      by keyword, where a positional `None` falls back to the keyword form;
    - converts a non-`None` `validation_funcs` mapping into a
      `frozendict.frozendict`, so every recursive call receives the same
      immutable mapping;
    - calls `f(type_, validation_funcs)` and passes the result through
      `_apply_validation_func`.

    Any other keyword arguments are ignored.

    # Raises
    - `TypeError`: if `type_` is supplied neither positionally nor by keyword
    """

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        # No memoization here on purpose. `f` returns a single-use generator,
        # so caching its result across calls would hand later callers an
        # already-exhausted iterator.
        if args:
            type_ = args[0]
        elif "type_" in kwargs:
            type_ = kwargs["type_"]
        else:
            msg = "missing required argument: 'type_'"
            raise TypeError(msg)

        # Index 1 is the position of `validation_funcs` in the wrapped signature.
        raw_validation_funcs = args[1] if len(args) > 1 else None
        if raw_validation_funcs is None:
            raw_validation_funcs = kwargs.get("validation_funcs")

        validation_funcs: frozendict.frozendict | None = (
            None
            if raw_validation_funcs is None
            else frozendict.frozendict(raw_validation_funcs)
        )

        return _apply_validation_func(
            type_,
            f(type_, validation_funcs),
            validation_funcs,
        )

    return wrapper
class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type.

    The type either has unbounded possible values or is not (yet) supported;
    for example, `enum.Enum` subclasses are not supported.
    """

    def __init__(self, type_: type) -> None:
        """Build an error message naming `type_` and, when it has one, its MRO.

        Generic aliases such as `typing.List[int]`, special forms, and non-type
        objects have no `__mro__`. For those the MRO is reported as `None`,
        instead of raising `AttributeError` while the error is being built.
        """
        mro = getattr(type_, "__mro__", None)
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. type_.__mro__ = {mro!r}"
        super().__init__(msg)
@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Yield all possible instances of `type_`, provided `type_` has finitely many.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    # Returns
    - `Generator[FiniteValued, None, None]`
        Lazily yields every (validated) instance of `type_`.

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `bool`
    - A `Literal[...]` type: its literal arguments are yielded.
    - A concrete dataclass whose constructor fields are all supported types.
        Only the fields reported by `dataclasses.fields` with `init=True` are enumerated.
        Fields declared with `init=False` are left to the dataclass itself,
        `ClassVar` annotations are ignored, and `InitVar` parameters are not supplied.
    - An abstract dataclass: the instances of each of its direct subclasses are yielded, recursively.
    - A finite-valued, fixed-length Generic tuple type.
        E.g., `tuple[bool]`, `tuple[bool, MyLiteral]` are OK.
        `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A union (`types.UnionType` or `typing.Union`) of any of the types in this list

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
    - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
        - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
    - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
    - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
    - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` (or any type nested in it) is not supported by `all_instances`.
        Because this is a generator, the error surfaces while iterating, not when the function is called.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct instances from every combination of
            # constructor-field values.
            from dataclasses import fields as dataclass_fields

            # `__dataclass_fields__` also holds `ClassVar`/`InitVar` pseudo-fields,
            # and `init=False` fields are rejected by `__init__`. Only real fields
            # that the constructor accepts are enumerated.
            init_fields: list[Field] = [
                fld for fld in dataclass_fields(type_) if fld.init
            ]
            field_names: list[str] = [fld.name for fld in init_fields]
            all_arg_sequences: Iterable = itertools.product(
                *[all_instances(fld.type, validation_funcs) for fld in init_fields],
            )
            yield from (
                type_(**dict(zip(field_names, args, strict=True)))
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(tup_item, validation_funcs)
                        for tup_item in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)