maze_dataset.tokenization.modular.all_instances
`all_instances`, `FiniteValued`, and related code for tokenizers
"`all_instances`, `FiniteValued`, and related code for tokenizers"

import enum
import itertools
import typing
from dataclasses import Field  # noqa: TC003
from functools import cache, wraps
from types import UnionType
from typing import (
    Callable,
    Generator,
    Iterable,
    Literal,
    TypeVar,
    get_args,
    get_origin,
)

try:
    import frozendict
except ImportError as e:
    raise ImportError(
        "You need to install the `frozendict` package to use `all_instances` -- try installing `maze_dataset[tokenization]`"
    ) from e
from muutils.misc import IsDataclass, flatten, is_abstract

FiniteValued = TypeVar("FiniteValued", bound=bool | IsDataclass | enum.Enum)
"""
# `FiniteValued`
The details of this type are not possible to fully define via the Python 3.10 typing library.
This custom generic type is a generic domain of many types which have a finite, discrete, and well-defined range space.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.

# `FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`

## Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only.
This list is not comprehensive.
While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite,
they are considered unbounded types in this context.
- No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded

## Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values
- `bool`: has 2 possible values
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition.
This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.

## Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the cartesian product of their field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.

## Superclass Types
Superclass types don't directly contain data members like container types.
Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., bool | Literal[2, 3]
"""


def _apply_validation_func(
    type_: FiniteValued,
    vals: Generator[FiniteValued, None, None],
    validation_funcs: (
        frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]] | None
    ) = None,
) -> Generator[FiniteValued, None, None]:
    """Helper function for `all_instances`.

    Filters `vals` according to `validation_funcs`.
    If `type_` is a regular type, searches in MRO order in `validation_funcs` and applies the first match, if any.
    Handles generic types supported by `all_instances` with special `if` clauses.

    # Parameters
    - `type_: FiniteValued`: A type
    - `vals: Generator[FiniteValued, None, None]`: Instances of `type_`
    - `validation_funcs: dict`: Collection of types mapped to filtering validation functions
    """
    if validation_funcs is None:
        return vals
    if type_ in validation_funcs:  # Only possible catch of UnionTypes
        # TYPING: Incompatible return value type (got "filter[FiniteValued]", expected "Generator[FiniteValued, None, None]") [return-value]
        return filter(validation_funcs[type_], vals)
    elif hasattr(
        type_,
        "__mro__",
    ):  # Generic types like UnionType, Literal don't have `__mro__`
        for superclass in type_.__mro__:
            if superclass not in validation_funcs:
                continue
            # TYPING: error: Incompatible types in assignment (expression has type "filter[FiniteValued]", variable has type "Generator[FiniteValued, None, None]") [assignment]
            vals = filter(validation_funcs[superclass], vals)
            break  # Only the first validation function hit in the mro is applied
    elif get_origin(type_) == Literal:
        return flatten(
            (
                _apply_validation_func(type(v), [v], validation_funcs)
                for v in get_args(type_)
            ),
            levels_to_flatten=1,
        )
    return vals


# TYPING: some better type hints would be nice here
def _all_instances_wrapper(f: Callable) -> Callable:
    """Converts dicts to frozendicts to allow caching and applies `_apply_validation_func`."""

    @wraps(f)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        @cache
        def cached_wrapper(  # noqa: ANN202
            type_: type,
            all_instances_func: Callable,
            validation_funcs: (
                frozendict.frozendict[FiniteValued, Callable[[FiniteValued], bool]]
                | None
            ),
        ):
            return _apply_validation_func(
                type_,
                all_instances_func(type_, validation_funcs),
                validation_funcs,
            )

        validation_funcs: frozendict.frozendict
        # TODO: what is this magic value here exactly?
        if len(args) >= 2 and args[1] is not None:  # noqa: PLR2004
            validation_funcs = frozendict.frozendict(args[1])
        elif "validation_funcs" in kwargs and kwargs["validation_funcs"] is not None:
            validation_funcs = frozendict.frozendict(kwargs["validation_funcs"])
        else:
            validation_funcs = None
        return cached_wrapper(args[0], f, validation_funcs)

    return wrapper


class UnsupportedAllInstancesError(TypeError):
    """Raised when `all_instances` is called on an unsupported type

    either has unbounded possible values or is not supported (Enum is not supported)
    """

    def __init__(self, type_: type) -> None:
        "constructs an error message with the type and mro of the type"
        msg: str = f"Type {type_} is not supported by `all_instances`. See docstring for details. {type_.__mro__ = }"
        super().__init__(msg)


@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
    """Returns all possible values of an instance of `type_` if finite instances exist.

    Uses type hinting to construct the possible values.
    All nested elements of `type_` must themselves be typed.
    Do not use with types whose members contain circular references.
    Function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.

    # Parameters
    - `type_: FiniteValued`
        A finite-valued type. See docstring on `FiniteValued` for full details.
    - `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
        A mapping of types to auxiliary functions to validate instances of that type.
        This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide.
        See `validation_funcs` Details section below.
        (default: `None`)

    ## Supported `type_` Values
    See docstring on `FiniteValued` for full details.
    `type_` may be:
    - `FiniteValued`
    - A finite-valued, fixed-length Generic tuple type.
    E.g., `tuple[bool]`, `tuple[bool, MyEnum]` are OK.
    `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
    - Nested versions of any of the types in this list
    - A `UnionType` of any of the types in this list

    ## `validation_funcs` Details
    - `validation_funcs` is applied after all instances have been generated according to type hints.
      - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
    - `validation_funcs` is passed down for all recursive calls of `all_instances`.
      - This allows for improved performance through maximal pruning of the exponential tree.
    - `validation_funcs` supports subclass checking.
      - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
      - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
      - If no superclass of `type_` is found, then no filter is applied.

    # Raises:
    - `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.
    """
    if type_ == bool:  # noqa: E721
        yield from [True, False]
    elif hasattr(type_, "__dataclass_fields__"):
        if is_abstract(type_):
            # Abstract dataclass: call `all_instances` on each subclass
            yield from flatten(
                (
                    all_instances(sub, validation_funcs)
                    for sub in type_.__subclasses__()
                ),
                levels_to_flatten=1,
            )
        else:
            # Concrete dataclass: construct dataclass instances with all possible combinations of fields
            # `__dataclass_fields__` maps each field name to its `Field` object
            fields: dict[str, Field] = type_.__dataclass_fields__
            fields_to_types: dict[str, type] = {f: fields[f].type for f in fields}
            all_arg_sequences: Iterable = itertools.product(
                *[
                    all_instances(arg_type, validation_funcs)
                    for arg_type in fields_to_types.values()
                ],
            )
            yield from (
                type_(
                    **dict(zip(fields_to_types.keys(), args, strict=False)),
                )
                for args in all_arg_sequences
            )
    else:
        type_origin = get_origin(type_)
        if type_origin == tuple:  # noqa: E721
            # Only matches Generic type tuple since regular tuple is not finite-valued
            # Generic tuple: Similar to concrete dataclass. Construct all possible combinations of tuple fields.
            yield from (
                tuple(combo)
                for combo in itertools.product(
                    *(
                        all_instances(tup_item, validation_funcs)
                        for tup_item in get_args(type_)
                    ),
                )
            )
        elif type_origin in (UnionType, typing.Union):
            # Union: call `all_instances` for each type in the Union
            yield from flatten(
                [all_instances(sub, validation_funcs) for sub in get_args(type_)],
                levels_to_flatten=1,
            )
        elif type_origin is Literal:
            # Literal: return all Literal arguments
            yield from get_args(type_)
        else:
            raise UnsupportedAllInstancesError(type_)
FiniteValued
The details of this type cannot be fully expressed with the Python 3.10 typing library.
This custom generic type covers the many types that have a finite, discrete, and well-defined range of values.
`FiniteValued` defines the domain of supported types for the `all_instances` function, since that function relies heavily on static typing.
These types may be nested in an arbitrarily deep tree via Container Types and Superclass Types (see below).
The leaves of the tree must always be Primitive Types.
`FiniteValued` Subtypes
*: Indicates that this subtype is not yet supported by `all_instances`
Non-`FiniteValued` (Unbounded) Types
These are NOT valid subtypes, and are listed for illustrative purposes only. This list is not comprehensive. While the finite and discrete nature of digital computers means that the cardinality of these types is technically finite, they are considered unbounded types in this context. No Container subtype may contain any of these unbounded subtypes.
- `int`
- `float`
- `str`
- `list`
- `set`: Set types without a `FiniteValued` argument are unbounded
- `tuple`: Tuple types without a fixed length are unbounded
Primitive Types
Primitive types are non-nested types which resolve directly to a concrete range of values.
- `bool`: Has 2 possible values.
- *`enum.Enum`: The range of a concrete `Enum` subclass is its set of enum members.
- `typing.Literal`: Every type constructed using `Literal` has a finite set of possible literal values in its definition. This is the preferred way to include limited ranges of non-`FiniteValued` types such as `int` or `str` in a `FiniteValued` hierarchy.
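The two supported primitive cases can be illustrated with the standard library alone. The following is a minimal sketch, not the library's implementation: `primitive_range` and the `Direction` alias are hypothetical names introduced here for illustration.

```python
from typing import Literal, get_args, get_origin

# Hypothetical alias for illustration; not part of maze_dataset.
Direction = Literal["north", "south", "east", "west"]


def primitive_range(type_):
    """Return the values of a primitive finite-valued type (bool or Literal)."""
    if type_ is bool:
        return [True, False]
    if get_origin(type_) is Literal:
        # The arguments of a Literal type are exactly its permitted values.
        return list(get_args(type_))
    raise TypeError(f"{type_!r} is not a supported primitive type")


print(primitive_range(bool))
print(primitive_range(Direction))
```

Using `Literal` this way lets a small, explicit set of `str` or `int` values take part in enumeration even though `str` and `int` themselves are unbounded.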
Container Types
Container types are types which contain zero or more fields of `FiniteValued` type.
The range of a container type is the Cartesian product of the ranges of its field types, except for `set[FiniteValued]`.
- `tuple[FiniteValued]`: Tuples of fixed length whose elements are each `FiniteValued`.
- `IsDataclass`: Concrete dataclasses whose fields are `FiniteValued`.
- *Standard concrete class: Regular classes could be supported just like dataclasses if all their data members are `FiniteValued`-typed.
- *`set[FiniteValued]`: Sets of fixed length of a `FiniteValued` type.
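The Cartesian-product rule for containers can be sketched for a flat dataclass as follows. This is an illustrative sketch under simplifying assumptions (fields are only `bool` or `Literal`, no nesting); `Cell`, `leaf_range`, and `container_range` are hypothetical names, not part of the library.

```python
import itertools
from dataclasses import dataclass, fields
from typing import Literal, get_args


@dataclass(frozen=True)
class Cell:  # hypothetical example dataclass
    visited: bool
    kind: Literal["wall", "open", "exit"]


def leaf_range(type_):
    """Values of a leaf type; this sketch only handles bool and Literal."""
    return [True, False] if type_ is bool else list(get_args(type_))


def container_range(cls):
    """All instances of a flat dataclass: the product of its field ranges."""
    names = [f.name for f in fields(cls)]
    ranges = [leaf_range(f.type) for f in fields(cls)]
    return [cls(**dict(zip(names, combo))) for combo in itertools.product(*ranges)]


cells = container_range(Cell)
print(len(cells))  # 2 values for `visited` times 3 values for `kind`
```

The number of instances is the product of the field range sizes, which is why deeply nested containers grow exponentially.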
Superclass Types
Superclass types don't directly contain data members like container types. Their range is the union of the ranges of their subtypes.
- Abstract dataclasses: Abstract dataclasses whose subclasses are all `FiniteValued` superclass or container types
- *`IsDataclass`: Concrete dataclasses which also have their own subclasses.
- *Standard abstract classes: Abstract classes whose subclasses are all `FiniteValued` superclass or container types
- `UnionType`: Any union of `FiniteValued` types, e.g., `bool | Literal[2, 3]`
class UnsupportedAllInstancesError(TypeError):
Raised when `all_instances` is called on an unsupported type.
The type either has unbounded possible values or is not yet supported (for example, `enum.Enum`).
UnsupportedAllInstancesError(type_: type)
Constructs an error message containing the type and its MRO.
Inherited Members
- builtins.BaseException
- with_traceback
- add_note
- args
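Because the error derives from `TypeError`, callers can handle it with an ordinary `except TypeError` clause. The sketch below uses a hypothetical stand-in class, `UnsupportedTypeError`, and a hypothetical `require_bool` helper so that it runs without the library installed; only the message pattern (an f-string with the `=` debug specifier) mirrors the source shown above.

```python
class UnsupportedTypeError(TypeError):  # hypothetical stand-in for illustration
    def __init__(self, type_: type) -> None:
        # The `=` specifier renders both the expression text and its value.
        super().__init__(f"Type {type_} is not supported. {type_.__mro__ = }")


def require_bool(type_):
    """Toy enumerator that supports only `bool`."""
    if type_ is not bool:
        raise UnsupportedTypeError(type_)
    return [True, False]


try:
    require_bool(int)
except TypeError as err:  # the subclass is caught by the base-class handler
    message = str(err)

print(message)
```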
@_all_instances_wrapper
def all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None = None,
) -> Generator[FiniteValued, None, None]:
Returns all possible values of an instance of `type_` if finite instances exist.
Uses type hinting to construct the possible values.
All nested elements of `type_` must themselves be typed.
Do not use with types whose members contain circular references.
The function is susceptible to infinite recursion if `type_` is a dataclass whose member tree includes another instance of `type_`.
Parameters
- `type_: FiniteValued`
  A finite-valued type. See docstring on `FiniteValued` for full details.
- `validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None`
  A mapping of types to auxiliary functions to validate instances of that type. This optional argument can provide an additional, more precise layer of validation for the instances generated beyond what type hinting alone can provide. See the `validation_funcs` Details section below. (default: `None`)
Supported `type_` Values
See docstring on `FiniteValued` for full details.
`type_` may be:
- `FiniteValued`
- A finite-valued, fixed-length Generic tuple type. E.g., `tuple[bool]`, `tuple[bool, MyEnum]` are OK. `tuple[bool, ...]` is NOT supported, since the length of the tuple is not fixed.
- Nested versions of any of the types in this list
- A `UnionType` of any of the types in this list
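The distinction between fixed-length and variable-length tuple types can be checked through `typing.get_args`, which reports `Ellipsis` for the `...` form. The following is a minimal sketch assuming only `bool` and nested tuple elements; `finite_range` is a hypothetical name, and the explicit `Ellipsis` check is this sketch's own way of rejecting variable-length tuples, not necessarily how the library does it.

```python
import itertools
from typing import get_args, get_origin


def finite_range(type_):
    """Enumerate bool or a fixed-length tuple type built from supported elements."""
    if type_ is bool:
        return [True, False]
    if get_origin(type_) is not tuple:
        # Bare `tuple` has no origin and is treated as unbounded.
        raise TypeError(f"{type_!r} is not a supported type in this sketch")
    args = get_args(type_)
    if Ellipsis in args:
        raise TypeError(f"{type_!r} has no fixed length")
    # One range per tuple position; the product yields every combination.
    return list(itertools.product(*(finite_range(a) for a in args)))


print(finite_range(tuple[bool, bool]))
```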
`validation_funcs` Details
- `validation_funcs` is applied after all instances have been generated according to type hints.
  - If `type_` is in `validation_funcs`, then the list of instances is filtered by `validation_funcs[type_](instance)`.
- `validation_funcs` is passed down for all recursive calls of `all_instances`.
  - This allows for improved performance through maximal pruning of the exponential tree.
- `validation_funcs` supports subclass checking.
  - If `type_` is not found in `validation_funcs`, then the search is performed iteratively in mro order.
  - If a superclass of `type_` is found while searching in mro order, that validation function is applied and the list is returned.
  - If no superclass of `type_` is found, then no filter is applied.
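The lookup order described above (exact type first, then the first match along the MRO, otherwise no filter) can be sketched with plain dictionaries. `pick_validator`, `Shape`, and `Square` are hypothetical names used only for this illustration; the sketch does not call the library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Shape:  # hypothetical base class
    sides: int


@dataclass(frozen=True)
class Square(Shape):  # hypothetical subclass with no validator of its own
    pass


def pick_validator(type_, validation_funcs):
    """Return the validator for `type_`, falling back to the first MRO match."""
    if type_ in validation_funcs:
        return validation_funcs[type_]
    # Generic aliases such as unions may lack `__mro__`; treat them as no match.
    for superclass in getattr(type_, "__mro__", ()):
        if superclass in validation_funcs:
            return validation_funcs[superclass]
    return None


validators = {Shape: lambda s: s.sides >= 3}
candidates = [Square(2), Square(4)]
check = pick_validator(Square, validators)
kept = [c for c in candidates if check is None or check(c)]
print(kept)
```

Here the validator registered for `Shape` is applied to `Square` instances because `Shape` is the first class in `Square.__mro__` that has an entry.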
Raises:
- `UnsupportedAllInstancesError`: If `type_` is not supported by `all_instances`.