Coverage for maze_dataset/tokenization/modular/element_base.py: 89%
79 statements
1"""provides the base `_TokenizerElement` class and related functionality for modular maze tokenization
3see the code in `maze_dataset.tokenization.modular.elements` for examples of subclasses of `_TokenizerElement`
4"""
6import abc
7from typing import (
8 Any,
9 Callable,
10 Literal,
11 TypeVar,
12)
14from muutils.json_serialize import (
15 SerializableDataclass,
16 serializable_dataclass,
17 serializable_field,
18)
19from muutils.json_serialize.util import _FORMAT_KEY
20from muutils.misc import flatten
21from zanj.loading import load_item_recursive
23from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name
25# from maze_dataset import SolvedMaze


@serializable_dataclass(frozen=True, kw_only=True)
class _TokenizerElement(SerializableDataclass, abc.ABC):
	"""Superclass for tokenizer elements.

	Subclasses contain modular functionality for maze tokenization.

	# Development
	> [!TIP]
	> Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses
	> may only contain fields of type `utils.FiniteValued`.
	> Implementing a subclass with an `int` or `float`-typed field, for example, is not supported.
	> In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.
	"""

	# TYPING: type hint `v` more specifically
	@staticmethod
	def _stringify(k: str, v: Any) -> str:  # noqa: ANN401
		if isinstance(v, bool):
			return f"{k}={str(v)[0]}"
		if isinstance(v, _TokenizerElement):
			return v.name
		if isinstance(v, tuple):
			return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}"
		else:
			return f"{k}={v}"

	@property
	def name(self) -> str:
		members_str: str = ", ".join(
			[self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"],
		)
		output: str = f"{type(self).__name__}({members_str})"
		if "." in output and output.index("(") > output.index("."):
			return "".join(output.split(".")[1:])
		else:
			return output

	def __str__(self) -> str:
		return self.name
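
	# Example (hypothetical subclass and field values, for illustration only):
	# for a subclass named `StepSequence` with fields `pre=False, intra=True`,
	# `.name` would be "StepSequence(pre=F, intra=T)": bools shorten to "T"/"F",
	# nested `_TokenizerElement` fields contribute their own `.name`, and any
	# namespace prefix before the first "(" is stripped.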

	# TYPING: type hints for `__init_subclass__`?
	def __init_subclass__(cls, **kwargs):  # noqa: ANN204
		"""Hack: dataclass hashes don't include the class itself in the hash function inputs.

		This causes dataclasses with identical fields but different types to hash identically.
		This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`.
		To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value.
		So we type it as a singleton `Literal` type.
		muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`.
		Ignore Pylance complaining about the arg to `Literal` being an expression.
		"""
		super().__init_subclass__(**kwargs)
		# we are adding a new attr here intentionally
		cls._type_ = serializable_field(  # type: ignore[attr-defined]
			init=True,
			repr=False,
			default=repr(cls),
			assert_type=False,
		)
		cls.__annotations__["_type_"] = Literal[repr(cls)]
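
	# A minimal sketch of the collision this hack prevents (hypothetical
	# subclasses `A` and `B`, for illustration only): two frozen dataclasses
	# with identical fields produce identical field tuples, so instances of
	# `A` and `B` would hash identically despite having different types.
	# with the `_type_` field, the field tuples themselves differ:
	#
	#     @serializable_dataclass(frozen=True, kw_only=True)
	#     class A(_TokenizerElement):
	#         x: bool = True
	#
	#     @serializable_dataclass(frozen=True, kw_only=True)
	#     class B(_TokenizerElement):
	#         x: bool = True
	#
	#     # A()._type_ != B()._type_, since each defaults to `repr(cls)`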

	def __hash__(self) -> int:
		"Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
		return _hash_tokenizer_name(self.name)

	@classmethod
	def _level_one_subclass(cls) -> type["_TokenizerElement"]:
		"""Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance."""
		return (
			set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop()
		)
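
	# Example (assuming the concrete elements defined in
	# `maze_dataset.tokenization.modular.elements`): for an instance of a
	# class like `CoordTokenizers.UT`, which subclasses the level-one abstract
	# class `CoordTokenizers._CoordTokenizer`, `_level_one_subclass()` returns
	# `_CoordTokenizer` rather than `UT` itself.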

	def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
		"""Returns a list of all `_TokenizerElement` instances contained in the subtree.

		Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
		which sit inside a `tuple` without further nesting.

		# Parameters
		- `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
		"""
		if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
			return list(
				flatten(
					[
						[el, *el.tokenizer_elements()]
						for el in self.__dict__.values()
						if isinstance(el, _TokenizerElement)
					],
				)
				if deep
				else filter(
					lambda x: isinstance(x, _TokenizerElement),
					self.__dict__.values(),
				),
			)
		else:
			non_tuple_elems: list[_TokenizerElement] = list(
				flatten(
					[
						[el, *el.tokenizer_elements()]
						for el in self.__dict__.values()
						if isinstance(el, _TokenizerElement)
					]
					if deep
					else filter(
						lambda x: isinstance(x, _TokenizerElement),
						self.__dict__.values(),
					),
				),
			)
			tuple_elems: list[_TokenizerElement] = list(
				flatten(
					[
						(
							[
								[tup_el, *tup_el.tokenizer_elements()]
								for tup_el in el
								if isinstance(tup_el, _TokenizerElement)
							]
							if deep
							else filter(lambda x: isinstance(x, _TokenizerElement), el)
						)
						for el in self.__dict__.values()
						if isinstance(el, tuple)
					],
				),
			)
			non_tuple_elems.extend(tuple_elems)
			return non_tuple_elems
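
	# Example (hypothetical nesting, for illustration only): for an element
	# `a` holding a child element `b`, which in turn holds `c`,
	# `a.tokenizer_elements()` returns `[b, c]`, while
	# `a.tokenizer_elements(deep=False)` returns only `[b]`.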

	def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
		"""Returns a string representation of the tree of tokenizer elements contained in `self`.

		# Parameters
		- `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
		"""
		name: str = "\t" * depth + (
			type(self).__name__
			if not abstract
			else type(self)._level_one_subclass().__name__
		)
		return (
			name
			+ "\n"
			+ "".join(
				el.tokenizer_element_tree(depth + 1, abstract)
				for el in self.tokenizer_elements(deep=False)
			)
		)
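
	# Example output shape (hypothetical element names, for illustration only):
	#
	#     AOTP
	#     \tUT
	#     \tStepSequence
	#     \t\tCoordinates
	#
	# each line is one element, indented with one tab per level of nesting.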

	def tokenizer_element_dict(self) -> dict:
		"""Returns a dictionary representation of the tree of tokenizer elements contained in `self`."""
		return {
			type(self).__name__: {
				key: (
					val.tokenizer_element_dict()
					if isinstance(val, _TokenizerElement)
					else (
						val
						if not isinstance(val, tuple)
						else [
							(
								el.tokenizer_element_dict()
								if isinstance(el, _TokenizerElement)
								else el
							)
							for el in val
						]
					)
				)
				for key, val in self.__dict__.items()
				if key != "_type_"
			},
		}
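
	# Example output shape (hypothetical names, for illustration only):
	#
	#     {"SomeElement": {"flag": True, "child": {"ChildElement": {...}}}}
	#
	# keys are concrete class names; `_type_` fields are omitted, and tuples
	# of elements become lists of nested dicts.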

	@classmethod
	@abc.abstractmethod
	def attribute_key(cls) -> str:
		"""Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
		raise NotImplementedError

	def to_tokens(self, *args, **kwargs) -> list[str]:
		"""Converts a maze element into a list of tokens.

		Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
		Those subclasses which do produce tokens should override this method.
		"""
		raise NotImplementedError

	@abc.abstractmethod
	def is_valid(self, do_except: bool = False) -> bool:
		"""Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

		Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
		`is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
		If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

		# Types of Invalidity
		In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity using one of the types below.
		Invalidity types, in ascending order of severity:
		- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
		  E.g., `_TokenizerElement`s which are strictly worse than some alternative.
		- Duplicate: These tokenizers have tokenization behavior identical to some other valid tokenizer.
		- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
		- Erroneous: These tokenizers might raise exceptions during use.

		# Development
		`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
		When adding new subclasses or data members, the developer should check whether any such blanket statement of validity still holds and update it as necessary.

		## Nesting
		In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
		In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
		`MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

		## Types of Invalidity
		If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
		This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
		"""
		raise NotImplementedError
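
	# A sketch of a nontrivial override (hypothetical subclass and field
	# names, for illustration only), following the comment-classification
	# convention described in the docstring above:
	#
	#     def is_valid(self, do_except: bool = False) -> bool:
	#         if self.pre and not self.intra:
	#             # Duplicate: `pre`-only delimiters tokenize identically to
	#             # the `intra`-only scheme
	#             if do_except:
	#                 err_msg: str = f"invalid element: {self.name}"
	#                 raise ValueError(err_msg)
	#             return False
	#         return True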
253T = TypeVar("T", bound=_TokenizerElement)


def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
	if do_except:
		err_msg: str = (
			f"Class `{type(self).__name__ = }`, marked as unsupported, is not valid. "
			f"{type(self) = }, {self = }"
		)
		raise ValueError(err_msg)

	return False


# TYPING: better type hints for this function
def mark_as_unsupported(is_valid: Callable[[T, bool], bool]) -> Callable[[T], T]:
	"""mark a _TokenizerElement as unsupported.

	Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.
	The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method.
	The decorator exists to prune the space of tokenizers returned by `all_instances`, both for testing and for usage.
	Previously, the space was too large, resulting in impractical runtimes.
	These decorators could be removed in future releases to expand the space of possible tokenizers.
	"""

	def wrapper(cls: T) -> T:
		# intentionally modifying the method here
		# idk why it thinks `T`/`self` should not be an argument
		cls.is_valid = is_valid  # type: ignore[assignment, method-assign]
		return cls

	return wrapper
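
# Illustrative usage (hypothetical subclass name, for illustration only):
#
#     @mark_as_unsupported(_unsupported_is_invalid)
#     class _SomeUnsupportedTokenizer(_TokenizerElement):
#         ...
#
# instances of the decorated class always fail `is_valid`, so they are pruned
# from the tokenizer space enumerated via `all_instances`.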


# TODO: why noqa here? `B024: __TokenizerElementNamespace is an abstract base class, but it has no abstract methods or properties`
class __TokenizerElementNamespace(abc.ABC):  # noqa: B024
	"""ABC for namespaces

	# Properties
	- key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`.
	"""

	# HACK: this is not the right way of doing this lol
	key: str = NotImplementedError  # type: ignore[assignment]


def _load_tokenizer_element(
	data: dict[str, Any],
	namespace: type[__TokenizerElementNamespace],
) -> _TokenizerElement:
	"""Loads a `TokenizerElement` stored via zanj."""
	key: str = namespace.key
	format_: str = data[key][_FORMAT_KEY]
	cls_name: str = format_.split("(")[0]
	cls: type[_TokenizerElement] = getattr(namespace, cls_name)
	kwargs: dict[str, Any] = {
		k: load_item_recursive(data[key][k], tuple()) for k in data[key]
	}
	if _FORMAT_KEY in kwargs:
		kwargs.pop(_FORMAT_KEY)
	return cls(**kwargs)
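
# Illustrative shape of the `data` dict this expects (hypothetical namespace
# and values, for illustration only), assuming `namespace.key == "coord_tokenizer"`:
#
#     data = {
#         "coord_tokenizer": {
#             _FORMAT_KEY: "UT()",
#             # ...serialized fields of the element, if any...
#         },
#     }
#     elem = _load_tokenizer_element(data, CoordTokenizers)
#
# the `_FORMAT_KEY` entry names the concrete class within `namespace`; the
# remaining entries are loaded recursively and passed as constructor kwargs.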