1"""provides the base `_TokenizerElement` class and related functionality for modular maze tokenization 2 3see the code in `maze_dataset.tokenization.modular.elements` for examples of subclasses of `_TokenizerElement` 4""" 5 6import abc 7from typing import ( 8 Any, 9 Callable, 10 Literal, 11 TypeVar, 12) 13 14from muutils.json_serialize import ( 15 SerializableDataclass, 16 serializable_dataclass, 17 serializable_field, 18) 19from muutils.json_serialize.util import _FORMAT_KEY 20from muutils.misc import flatten 21from zanj.loading import load_item_recursive 22 23from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name 24 25# from maze_dataset import SolvedMaze 26 27 28@serializable_dataclass(frozen=True, kw_only=True) 29class _TokenizerElement(SerializableDataclass, abc.ABC): 30 """Superclass for tokenizer elements. 31 32 Subclasses contain modular functionality for maze tokenization. 33 34 # Development 35 > [!TIP] 36 > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses 37 > may only contain fields of type `utils.FiniteValued`. 38 > Implementing a subclass with an `int` or `float`-typed field, for example, is not supported. 39 > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated. 40 41 """ 42 43 # TYPING: type hint `v` more specifically 44 @staticmethod 45 def _stringify(k: str, v: Any) -> str: # noqa: ANN401 46 if isinstance(v, bool): 47 return f"{k}={str(v)[0]}" 48 if isinstance(v, _TokenizerElement): 49 return v.name 50 if isinstance(v, tuple): 51 return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}" 52 else: 53 return f"{k}={v}" 54 55 @property 56 def name(self) -> str: 57 members_str: str = ", ".join( 58 [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"], 59 ) 60 output: str = f"{type(self).__name__}({members_str})" 61 if "." in output and output.index("(") > output.index("."): 62 return "".join(output.split(".")[1:]) 63 else: 64 return output 65 66 def __str__(self) -> str: 67 return self.name 68 69 # TYPING: type hints for `__init_subclass__`? 70 def __init_subclass__(cls, **kwargs): # noqa: ANN204 71 """Hack: dataclass hashes don't include the class itself in the hash function inputs. 72 73 This causes dataclasses with identical fields but different types to hash identically. 74 This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`. 75 To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value. 76 So we type it as a singleton `Literal` type. 77 muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`. 78 Ignore Pylance complaining about the arg to `Literal` being an expression. 79 """ 80 super().__init_subclass__(**kwargs) 81 # we are adding a new attr here intentionally 82 cls._type_ = serializable_field( # type: ignore[attr-defined] 83 init=True, 84 repr=False, 85 default=repr(cls), 86 assert_type=False, 87 ) 88 cls.__annotations__["_type_"] = Literal[repr(cls)] 89 90 def __hash__(self) -> int: 91 "Stable hash to identify unique `MazeTokenizerModular` instances. 
uses name" 92 return _hash_tokenizer_name(self.name) 93 94 @classmethod 95 def _level_one_subclass(cls) -> type["_TokenizerElement"]: 96 """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance.""" 97 return ( 98 set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop() 99 ) 100 101 def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]: 102 """Returns a list of all `_TokenizerElement` instances contained in the subtree. 103 104 Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or 105 which sit inside a `tuple` without further nesting. 106 107 # Parameters 108 - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer. 109 """ 110 if not any(type(el) == tuple for el in self.__dict__.values()): # noqa: E721 111 return list( 112 flatten( 113 [ 114 [el, *el.tokenizer_elements()] 115 for el in self.__dict__.values() 116 if isinstance(el, _TokenizerElement) 117 ], 118 ) 119 if deep 120 else filter( 121 lambda x: isinstance(x, _TokenizerElement), 122 self.__dict__.values(), 123 ), 124 ) 125 else: 126 non_tuple_elems: list[_TokenizerElement] = list( 127 flatten( 128 [ 129 [el, *el.tokenizer_elements()] 130 for el in self.__dict__.values() 131 if isinstance(el, _TokenizerElement) 132 ] 133 if deep 134 else filter( 135 lambda x: isinstance(x, _TokenizerElement), 136 self.__dict__.values(), 137 ), 138 ), 139 ) 140 tuple_elems: list[_TokenizerElement] = list( 141 flatten( 142 [ 143 ( 144 [ 145 [tup_el, *tup_el.tokenizer_elements()] 146 for tup_el in el 147 if isinstance(tup_el, _TokenizerElement) 148 ] 149 if deep 150 else filter(lambda x: isinstance(x, _TokenizerElement), el) 151 ) 152 for el in self.__dict__.values() 153 if isinstance(el, tuple) 154 ], 155 ), 156 ) 157 non_tuple_elems.extend(tuple_elems) 158 return non_tuple_elems 159 160 def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str: 161 """Returns a string representation of the tree of tokenizer elements contained in `self`. 162 163 # Parameters 164 - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify. 165 - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance. 166 """ 167 name: str = "\t" * depth + ( 168 type(self).__name__ 169 if not abstract 170 else type(self)._level_one_subclass().__name__ 171 ) 172 return ( 173 name 174 + "\n" 175 + "".join( 176 el.tokenizer_element_tree(depth + 1, abstract) 177 for el in self.tokenizer_elements(deep=False) 178 ) 179 ) 180 181 def tokenizer_element_dict(self) -> dict: 182 """Returns a dictionary representation of the tree of tokenizer elements contained in `self`.""" 183 return { 184 type(self).__name__: { 185 key: ( 186 val.tokenizer_element_dict() 187 if isinstance(val, _TokenizerElement) 188 else ( 189 val 190 if not isinstance(val, tuple) 191 else [ 192 ( 193 el.tokenizer_element_dict() 194 if isinstance(el, _TokenizerElement) 195 else el 196 ) 197 for el in val 198 ] 199 ) 200 ) 201 for key, val in self.__dict__.items() 202 if key != "_type_" 203 }, 204 } 205 206 @classmethod 207 @abc.abstractmethod 208 def attribute_key(cls) -> str: 209 """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`.""" 210 raise NotImplementedError 211 212 def to_tokens(self, *args, **kwargs) -> list[str]: 213 """Converts a maze element into a list of tokens. 
    @abc.abstractmethod
    def is_valid(self, do_except: bool = False) -> bool:
        """Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

        Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
        `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
        If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

        # Types of Invalidity
        In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below.
        Invalidity types, in ascending order of invalidity:
        - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
        E.g., `_TokenizerElement`s which are strictly worse than some alternative.
        - Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
        - Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
        - Erroneous: These tokenizers might raise exceptions during use.

        # Development
        `is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
        When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

        ## Nesting
        In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
        In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
        `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

        ## Types of Invalidity
        If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
        This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
        """
        raise NotImplementedError


T = TypeVar("T", bound=_TokenizerElement)


def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }`, marked as unsupported, is not valid. "
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False


# TYPING: better type hints for this function
def mark_as_unsupported(is_valid: Callable[[T, bool], bool]) -> Callable[[T], T]:
    """mark a `_TokenizerElement` as unsupported.

    Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.
    The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method.
    The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage.
    Previously, the space was too large, resulting in impractical runtimes.
    These decorators could be removed in future releases to expand the space of possible tokenizers.
    """

    def wrapper(cls: T) -> T:
        # intentionally replacing the method here
        # (type checkers complain because they don't recognize `T`/`self` as a valid first argument)
        cls.is_valid = is_valid  # type: ignore[assignment, method-assign]
        return cls

    return wrapper
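# Hedged example of applying the decorator (the subclass `_MyElement` is
# hypothetical; the pattern follows directly from the signatures above):
#
#     @mark_as_unsupported(_unsupported_is_invalid)
#     @serializable_dataclass(frozen=True, kw_only=True)
#     class _MyElement(_TokenizerElement):
#         # ... `attribute_key` and other members omitted
#
#     # `_MyElement(...).is_valid()` now returns `False` (or raises when called
#     # with `do_except=True`), so the class is pruned from `get_all_tokenizers()`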
# TODO: why noqa here? `B024 `__TokenizerElementNamespace` is an abstract base class, but it has no abstract methods or properties`
class __TokenizerElementNamespace(abc.ABC):  # noqa: B024
    """ABC for namespaces

    # Properties
    - key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`.
    """

    # HACK: this is not the right way of doing this lol
    key: str = NotImplementedError  # type: ignore[assignment]


def _load_tokenizer_element(
    data: dict[str, Any],
    namespace: type[__TokenizerElementNamespace],
) -> _TokenizerElement:
    """Loads a `_TokenizerElement` stored via zanj."""
    key: str = namespace.key
    format_: str = data[key][_FORMAT_KEY]
    cls_name: str = format_.split("(")[0]
    cls: type[_TokenizerElement] = getattr(namespace, cls_name)
    kwargs: dict[str, Any] = {
        k: load_item_recursive(data[key][k], tuple()) for k in data[key]
    }
    if _FORMAT_KEY in kwargs:
        kwargs.pop(_FORMAT_KEY)
    return cls(**kwargs)
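# Hedged sketch of the payload shape `_load_tokenizer_element` expects
# (`SomeElement` and `SomeNamespace` are illustrative; the real keys come from
# `namespace.key` and muutils' `_FORMAT_KEY`):
#
#     data = {
#         SomeNamespace.key: {
#             _FORMAT_KEY: "SomeElement(...)",  # class name is taken from before the "("
#             # ...serialized fields, each loaded via `load_item_recursive`...
#         },
#     }
#     elem = _load_tokenizer_element(data, SomeNamespace)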