docs for maze-dataset v1.3.2

maze_dataset.tokenization.modular.element_base

provides the base _TokenizerElement class and related functionality for modular maze tokenization

see the code in maze_dataset.tokenization.modular.elements for examples of subclasses of _TokenizerElement
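
For orientation, a minimal subclass might look like the sketch below. The class name, field, and attribute key are hypothetical, invented for illustration; real subclasses live in `maze_dataset.tokenization.modular.elements`:

```python
# hypothetical sketch, not a real element from the library
from muutils.json_serialize import serializable_dataclass

from maze_dataset.tokenization.modular.element_base import _TokenizerElement


@serializable_dataclass(frozen=True, kw_only=True)
class _MyElement(_TokenizerElement):
	# fields must be finite-valued (e.g. `bool`), since `get_all_tokenizers()`
	# enumerates every possible instance
	use_fancy_tokens: bool = True

	@classmethod
	def attribute_key(cls) -> str:
		# the binding this element type would occupy in `MazeTokenizerModular`
		return "my_element"

	def is_valid(self, do_except: bool = False) -> bool:
		# the type hints fully constrain the instances here, so all are valid
		return True
```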


  1"""provides the base `_TokenizerElement` class and related functionality for modular maze tokenization
  2
  3see the code in `maze_dataset.tokenization.modular.elements` for examples of subclasses of `_TokenizerElement`
  4"""
  5
  6import abc
  7from typing import (
  8	Any,
  9	Callable,
 10	Literal,
 11	TypeVar,
 12)
 13
 14from muutils.json_serialize import (
 15	SerializableDataclass,
 16	serializable_dataclass,
 17	serializable_field,
 18)
 19from muutils.json_serialize.util import _FORMAT_KEY
 20from muutils.misc import flatten
 21from zanj.loading import load_item_recursive
 22
 23from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name
 24
 25# from maze_dataset import SolvedMaze
 26
 27
 28@serializable_dataclass(frozen=True, kw_only=True)
 29class _TokenizerElement(SerializableDataclass, abc.ABC):
 30	"""Superclass for tokenizer elements.
 31
 32	Subclasses contain modular functionality for maze tokenization.
 33
 34	# Development
 35	> [!TIP]
 36	> Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses
 37	> may only contain fields of type `utils.FiniteValued`.
 38	> Implementing a subclass with an `int` or `float`-typed field, for example, is not supported.
 39	> In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.
 40
 41	"""

	# TYPING: type hint `v` more specifically
	@staticmethod
	def _stringify(k: str, v: Any) -> str:  # noqa: ANN401
		# render a single field as part of `name`; bools are abbreviated to "T"/"F",
		# nested elements use their own `name`, tuples are expanded item by item
		if isinstance(v, bool):
			return f"{k}={str(v)[0]}"
		if isinstance(v, _TokenizerElement):
			return v.name
		if isinstance(v, tuple):
			return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}"
		else:
			return f"{k}={v}"

	@property
	def name(self) -> str:
		# human-readable name encoding the class and all field values;
		# also the input to `__hash__` below
		members_str: str = ", ".join(
			[self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"],
		)
		output: str = f"{type(self).__name__}({members_str})"
		if "." in output and output.index("(") > output.index("."):
			return "".join(output.split(".")[1:])
		else:
			return output

	def __str__(self) -> str:
		return self.name

	# TYPING: type hints for `__init_subclass__`?
	def __init_subclass__(cls, **kwargs):  # noqa: ANN204
		"""Hack: dataclass hashes don't include the class itself in the hash function inputs.

		This causes dataclasses with identical fields but different types to hash identically.
		This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`.
		To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value.
		So we type it as a singleton `Literal` type.
		muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`.
		Ignore Pylance complaining about the arg to `Literal` being an expression.
		"""
		super().__init_subclass__(**kwargs)
		# we are adding a new attr here intentionally
		cls._type_ = serializable_field(  # type: ignore[attr-defined]
			init=True,
			repr=False,
			default=repr(cls),
			assert_type=False,
		)
		cls.__annotations__["_type_"] = Literal[repr(cls)]

	def __hash__(self) -> int:
		"Stable hash to identify unique `MazeTokenizerModular` instances. Uses `name`."
		return _hash_tokenizer_name(self.name)

	@classmethod
	def _level_one_subclass(cls) -> type["_TokenizerElement"]:
		"""Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance."""
		return (
			set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop()
		)

	def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
		"""Returns a list of all `_TokenizerElement` instances contained in the subtree.

		Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
		which sit inside a `tuple` without further nesting.

		# Parameters
		- `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
		"""
		if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
			return list(
				flatten(
					[
						[el, *el.tokenizer_elements()]
						for el in self.__dict__.values()
						if isinstance(el, _TokenizerElement)
					],
				)
				if deep
				else filter(
					lambda x: isinstance(x, _TokenizerElement),
					self.__dict__.values(),
				),
			)
		else:
			non_tuple_elems: list[_TokenizerElement] = list(
				flatten(
					[
						[el, *el.tokenizer_elements()]
						for el in self.__dict__.values()
						if isinstance(el, _TokenizerElement)
					]
					if deep
					else filter(
						lambda x: isinstance(x, _TokenizerElement),
						self.__dict__.values(),
					),
				),
			)
			tuple_elems: list[_TokenizerElement] = list(
				flatten(
					[
						(
							[
								[tup_el, *tup_el.tokenizer_elements()]
								for tup_el in el
								if isinstance(tup_el, _TokenizerElement)
							]
							if deep
							else filter(lambda x: isinstance(x, _TokenizerElement), el)
						)
						for el in self.__dict__.values()
						if isinstance(el, tuple)
					],
				),
			)
			non_tuple_elems.extend(tuple_elems)
			return non_tuple_elems

	def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
		"""Returns a string representation of the tree of tokenizer elements contained in `self`.

		# Parameters
		- `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
		- `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
		"""
		name: str = "\t" * depth + (
			type(self).__name__
			if not abstract
			else type(self)._level_one_subclass().__name__
		)
		return (
			name
			+ "\n"
			+ "".join(
				el.tokenizer_element_tree(depth + 1, abstract)
				for el in self.tokenizer_elements(deep=False)
			)
		)

	def tokenizer_element_dict(self) -> dict:
		"""Returns a dictionary representation of the tree of tokenizer elements contained in `self`."""
		return {
			type(self).__name__: {
				key: (
					val.tokenizer_element_dict()
					if isinstance(val, _TokenizerElement)
					else (
						val
						if not isinstance(val, tuple)
						else [
							(
								el.tokenizer_element_dict()
								if isinstance(el, _TokenizerElement)
								else el
							)
							for el in val
						]
					)
				)
				for key, val in self.__dict__.items()
				if key != "_type_"
			},
		}

	@classmethod
	@abc.abstractmethod
	def attribute_key(cls) -> str:
		"""Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
		raise NotImplementedError

	def to_tokens(self, *args, **kwargs) -> list[str]:
		"""Converts a maze element into a list of tokens.

		Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
		Those subclasses which do produce tokens should override this method.
		"""
		raise NotImplementedError

	@abc.abstractmethod
	def is_valid(self, do_except: bool = False) -> bool:
		"""Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

		Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
		`is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
		If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

		# Types of Invalidity
		In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity using one of the types below.
		Invalidity types, in ascending order of severity:
		- Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
		E.g., `_TokenizerElement`s which are strictly worse than some alternative.
		- Duplicate: These tokenizers have identical tokenization behavior as some other valid tokenizers.
		- Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
		- Erroneous: These tokenizers might raise exceptions during use.

		# Development
		`is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
		When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

		## Nesting
		In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
		In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
		`MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

		## Types of Invalidity
		If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
		This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
		"""
		raise NotImplementedError


T = TypeVar("T", bound=_TokenizerElement)


def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
	"""Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
	if do_except:
		err_msg: str = (
			f"Class `{type(self).__name__ = }`, marked as unsupported, is not valid. "
			f"{type(self) = }, {self = }"
		)
		raise ValueError(err_msg)

	return False


# TYPING: better type hints for this function
def mark_as_unsupported(is_valid: Callable[[T, bool], bool]) -> Callable[[T], T]:
	"""Mark a `_TokenizerElement` as unsupported.

	Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.
	The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method.
	The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage.
	Previously, the space was too large, resulting in impractical runtimes.
	These decorators could be removed in future releases to expand the space of possible tokenizers.
	"""

	def wrapper(cls: T) -> T:
		# intentionally replacing the method here;
		# unclear why the type checker thinks `T`/`self` should not be an argument
		cls.is_valid = is_valid  # type: ignore[assignment, method-assign]
		return cls

	return wrapper


# TODO: why the noqa here? B024: "`__TokenizerElementNamespace` is an abstract base class, but it has no abstract methods or properties"
class __TokenizerElementNamespace(abc.ABC):  # noqa: B024
	"""ABC for namespaces

	# Properties
	- key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`.
	"""

	# HACK: this is not the right way of doing this lol
	key: str = NotImplementedError  # type: ignore[assignment]


def _load_tokenizer_element(
	data: dict[str, Any],
	namespace: type[__TokenizerElementNamespace],
) -> _TokenizerElement:
	"""Loads a `_TokenizerElement` stored via zanj."""
	key: str = namespace.key
	format_: str = data[key][_FORMAT_KEY]
	cls_name: str = format_.split("(")[0]
	cls: type[_TokenizerElement] = getattr(namespace, cls_name)
	kwargs: dict[str, Any] = {
		k: load_item_recursive(v, tuple()) for k, v in data[key].items()
	}
	if _FORMAT_KEY in kwargs:
		kwargs.pop(_FORMAT_KEY)
	return cls(**kwargs)
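
To make the naming and introspection helpers concrete: continuing the hypothetical `_MyElement` sketch from the top of this page, the expected behavior (under the assumptions of that sketch) is roughly:

```python
from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name

el = _MyElement(use_fancy_tokens=False)

print(el.name)                      # _MyElement(use_fancy_tokens=F)
print(str(el) == el.name)           # True -- __str__ delegates to name
print(el.tokenizer_elements())      # [] -- no nested _TokenizerElement fields
print(el.tokenizer_element_tree())  # "_MyElement\n" -- one line per element
print(hash(el) == _hash_tokenizer_name(el.name))  # True -- __hash__ uses name
```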

def mark_as_unsupported(is_valid: Callable[[~T, bool], bool]) -> Callable[[~T], ~T]:

Mark a _TokenizerElement as unsupported.

Classes marked with this decorator won't show up in get_all_tokenizers() and thus won't be tested. The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method. The decorator exists to prune the space of tokenizers returned by all_instances both for testing and usage. Previously, the space was too large, resulting in impractical runtimes. These decorators could be removed in future releases to expand the space of possible tokenizers.
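
A hedged usage sketch (the element class is hypothetical; the classes actually marked live in `maze_dataset.tokenization.modular.elements`):

```python
from muutils.json_serialize import serializable_dataclass

from maze_dataset.tokenization.modular.element_base import (
	_TokenizerElement,
	_unsupported_is_invalid,
	mark_as_unsupported,
)


# hypothetical element, pruned from the tokenizer space once marked
@mark_as_unsupported(_unsupported_is_invalid)
@serializable_dataclass(frozen=True, kw_only=True)
class _MyLegacyElement(_TokenizerElement):
	@classmethod
	def attribute_key(cls) -> str:
		return "my_legacy_element"

	def is_valid(self, do_except: bool = False) -> bool:
		return True  # replaced by `_unsupported_is_invalid` via the decorator
```

After decoration, `_MyLegacyElement.is_valid` is `_unsupported_is_invalid`, so any instance reports itself as invalid (raising `ValueError` when called with `do_except=True`) and is filtered out of the space of tokenizers.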