Coverage for maze_dataset/tokenization/modular/element_base.py: 89%

79 statements  

coverage.py v7.6.12, created at 2025-04-09 12:48 -0600

1"""provides the base `_TokenizerElement` class and related functionality for modular maze tokenization 

2 

3see the code in `maze_dataset.tokenization.modular.elements` for examples of subclasses of `_TokenizerElement` 

4""" 

5 

6import abc 

7from typing import ( 

8 Any, 

9 Callable, 

10 Literal, 

11 TypeVar, 

12) 

13 

14from muutils.json_serialize import ( 

15 SerializableDataclass, 

16 serializable_dataclass, 

17 serializable_field, 

18) 

19from muutils.json_serialize.util import _FORMAT_KEY 

20from muutils.misc import flatten 

21from zanj.loading import load_item_recursive 

22 

23from maze_dataset.tokenization.modular.hashing import _hash_tokenizer_name 

24 

25# from maze_dataset import SolvedMaze 

26 

27 

@serializable_dataclass(frozen=True, kw_only=True)
class _TokenizerElement(SerializableDataclass, abc.ABC):
    """Superclass for tokenizer elements.

    Subclasses contain modular functionality for maze tokenization.

    # Development
    > [!TIP]
    > Due to the functionality of `get_all_tokenizers()`, `_TokenizerElement` subclasses
    > may only contain fields of type `utils.FiniteValued`.
    > Implementing a subclass with an `int` or `float`-typed field, for example, is not supported.
    > In the event that adding such fields is deemed necessary, `get_all_tokenizers()` must be updated.

    """

    # TYPING: type hint `v` more specifically
    @staticmethod
    def _stringify(k: str, v: Any) -> str:  # noqa: ANN401
        if isinstance(v, bool):
            return f"{k}={str(v)[0]}"
        if isinstance(v, _TokenizerElement):
            return v.name
        if isinstance(v, tuple):
            return f"{k}={''.join(['(', *[str(x) + ', ' for x in v], ')'])}"
        else:
            return f"{k}={v}"

    @property
    def name(self) -> str:
        members_str: str = ", ".join(
            [self._stringify(k, v) for k, v in self.__dict__.items() if k != "_type_"],
        )
        output: str = f"{type(self).__name__}({members_str})"
        if "." in output and output.index("(") > output.index("."):
            return "".join(output.split(".")[1:])
        else:
            return output

    def __str__(self) -> str:
        return self.name
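
    # Illustrative sketch (hypothetical subclass, not part of this module): `name`
    # flattens the element's fields into a deterministic string which `__hash__` and
    # serialization rely on; `_stringify` shortens bools to "T"/"F":
    #
    #     @serializable_dataclass(frozen=True, kw_only=True)
    #     class _DemoElement(_TokenizerElement):
    #         post: bool = serializable_field(default=True)
    #
    #         @classmethod
    #         def attribute_key(cls) -> str:
    #             return "demo_element"
    #
    #         def is_valid(self, do_except: bool = False) -> bool:
    #             return True
    #
    #     _DemoElement(post=True).name   # -> "_DemoElement(post=T)"
    #     str(_DemoElement(post=False))  # -> "_DemoElement(post=F)"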

    # TYPING: type hints for `__init_subclass__`?
    def __init_subclass__(cls, **kwargs):  # noqa: ANN204
        """Hack: dataclass hashes don't include the class itself in the hash function inputs.

        This causes dataclasses with identical fields but different types to hash identically.
        This hack circumvents this by adding a slightly hidden field to every subclass with a value of `repr(cls)`.
        To maintain compatibility with `all_instances`, the static type of the new field can only have 1 possible value.
        So we type it as a singleton `Literal` type.
        muutils 0.6.1 doesn't support `Literal` type validation, so `assert_type=False`.
        Ignore Pylance complaining about the arg to `Literal` being an expression.
        """
        super().__init_subclass__(**kwargs)
        # we are adding a new attr here intentionally
        cls._type_ = serializable_field(  # type: ignore[attr-defined]
            init=True,
            repr=False,
            default=repr(cls),
            assert_type=False,
        )
        cls.__annotations__["_type_"] = Literal[repr(cls)]

    def __hash__(self) -> int:
        "Stable hash to identify unique `MazeTokenizerModular` instances. uses name"
        return _hash_tokenizer_name(self.name)
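
    # Illustrative sketch (plain dataclasses, not part of this module): the default
    # frozen-dataclass hash is derived from field values only, so two distinct
    # classes with identical fields collide; the `_type_` field added above breaks
    # the tie because its default, `repr(cls)`, differs for every subclass.
    #
    #     from dataclasses import dataclass
    #
    #     @dataclass(frozen=True)
    #     class A:
    #         x: int = 1
    #
    #     @dataclass(frozen=True)
    #     class B:
    #         x: int = 1
    #
    #     assert hash(A()) == hash(B())  # the collision that `_type_` is meant to prevent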

    @classmethod
    def _level_one_subclass(cls) -> type["_TokenizerElement"]:
        """Returns the immediate subclass of `_TokenizerElement` of which `cls` is an instance."""
        return (
            set(cls.__mro__).intersection(set(_TokenizerElement.__subclasses__())).pop()
        )

    def tokenizer_elements(self, deep: bool = True) -> list["_TokenizerElement"]:
        """Returns a list of all `_TokenizerElement` instances contained in the subtree.

        Currently only detects `_TokenizerElement` instances which are either direct attributes of another instance or
        which sit inside a `tuple` without further nesting.

        # Parameters
        - `deep: bool`: Whether to return elements nested arbitrarily deeply or just a single layer.
        """
        if not any(type(el) == tuple for el in self.__dict__.values()):  # noqa: E721
            return list(
                flatten(
                    [
                        [el, *el.tokenizer_elements()]
                        for el in self.__dict__.values()
                        if isinstance(el, _TokenizerElement)
                    ],
                )
                if deep
                else filter(
                    lambda x: isinstance(x, _TokenizerElement),
                    self.__dict__.values(),
                ),
            )
        else:
            non_tuple_elems: list[_TokenizerElement] = list(
                flatten(
                    [
                        [el, *el.tokenizer_elements()]
                        for el in self.__dict__.values()
                        if isinstance(el, _TokenizerElement)
                    ]
                    if deep
                    else filter(
                        lambda x: isinstance(x, _TokenizerElement),
                        self.__dict__.values(),
                    ),
                ),
            )
            tuple_elems: list[_TokenizerElement] = list(
                flatten(
                    [
                        (
                            [
                                [tup_el, *tup_el.tokenizer_elements()]
                                for tup_el in el
                                if isinstance(tup_el, _TokenizerElement)
                            ]
                            if deep
                            else filter(lambda x: isinstance(x, _TokenizerElement), el)
                        )
                        for el in self.__dict__.values()
                        if isinstance(el, tuple)
                    ],
                ),
            )
            non_tuple_elems.extend(tuple_elems)
            return non_tuple_elems

    def tokenizer_element_tree(self, depth: int = 0, abstract: bool = False) -> str:
        """Returns a string representation of the tree of tokenizer elements contained in `self`.

        # Parameters
        - `depth: int`: Current depth in the tree. Used internally for recursion, no need to specify.
        - `abstract: bool`: Whether to print the name of the abstract base class or the concrete class for each `_TokenizerElement` instance.
        """
        name: str = "\t" * depth + (
            type(self).__name__
            if not abstract
            else type(self)._level_one_subclass().__name__
        )
        return (
            name
            + "\n"
            + "".join(
                el.tokenizer_element_tree(depth + 1, abstract)
                for el in self.tokenizer_elements(deep=False)
            )
        )

    def tokenizer_element_dict(self) -> dict:
        """Returns a dictionary representation of the tree of tokenizer elements contained in `self`."""
        return {
            type(self).__name__: {
                key: (
                    val.tokenizer_element_dict()
                    if isinstance(val, _TokenizerElement)
                    else (
                        val
                        if not isinstance(val, tuple)
                        else [
                            (
                                el.tokenizer_element_dict()
                                if isinstance(el, _TokenizerElement)
                                else el
                            )
                            for el in val
                        ]
                    )
                )
                for key, val in self.__dict__.items()
                if key != "_type_"
            },
        }
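
    # Illustrative sketch (hypothetical nesting; real elements live in
    # `maze_dataset.tokenization.modular.elements`): given an element whose fields
    # are themselves `_TokenizerElement`s, the three introspection methods above
    # yield a flat list, a tab-indented tree string, and a nested dict, respectively:
    #
    #     outer = _OuterElement(inner=_InnerElement())
    #     outer.tokenizer_elements()        # -> [_InnerElement(...)]
    #     print(outer.tokenizer_element_tree())
    #     # _OuterElement
    #     #     _InnerElement
    #     outer.tokenizer_element_dict()    # -> {"_OuterElement": {"inner": {"_InnerElement": {...}}}}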

    @classmethod
    @abc.abstractmethod
    def attribute_key(cls) -> str:
        """Returns the binding used in `MazeTokenizerModular` for that type of `_TokenizerElement`."""
        raise NotImplementedError

    def to_tokens(self, *args, **kwargs) -> list[str]:
        """Converts a maze element into a list of tokens.

        Not all `_TokenizerElement` subclasses produce tokens, so this is not an abstract method.
        Those subclasses which do produce tokens should override this method.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_valid(self, do_except: bool = False) -> bool:
        """Returns whether `self` contains data members capable of producing an overall valid `MazeTokenizerModular`.

        Some `_TokenizerElement` instances may be created which are not useful despite obeying data member type hints.
        `is_valid` allows for more precise detection of invalid `_TokenizerElement`s beyond type hinting alone.
        If type hints are sufficient to constrain the possible instances of some subclass, then this method may simply `return True` for that subclass.

        # Types of Invalidity
        In nontrivial implementations of this method, each conditional clause should contain a comment classifying the reason for invalidity and one of the types below.
        Invalidity types, in ascending order of invalidity:
        - Uninteresting: These tokenizers might be used to train functional models, but the schemes are not interesting to study.
          E.g., `_TokenizerElement`s which are strictly worse than some alternative.
        - Duplicate: These tokenizers have identical tokenization behavior to some other valid tokenizer.
        - Untrainable: Training functional models using these tokenizers would be (nearly) impossible.
        - Erroneous: These tokenizers might raise exceptions during use.

        # Development
        `is_valid` is implemented to always return `True` in some abstract classes where all currently possible subclass instances are valid.
        When adding new subclasses or data members, the developer should check if any such blanket statement of validity still holds and update it as necessary.

        ## Nesting
        In general, when implementing this method, there is no need to recursively call `is_valid` on nested `_TokenizerElement`s contained in the class.
        In other words, failures of `is_valid` need not bubble up to the top of the nested `_TokenizerElement` tree.
        `MazeTokenizerModular.is_valid` calls `is_valid` on each of its `_TokenizerElement`s individually, so failure at any level will be detected.

        ## Types of Invalidity
        If it's judged to be useful, the types of invalidity could be implemented with an Enum or similar rather than only living in comments.
        This could be used to create more or less stringent filters on the valid `_TokenizerElement` instances.
        """
        raise NotImplementedError
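
# Illustrative sketch (hypothetical subclass, not part of this module): a nontrivial
# `is_valid` implementation tags each rejecting clause with one of the invalidity
# types listed in the docstring above:
#
#     @serializable_dataclass(frozen=True, kw_only=True)
#     class _DemoPathTokenizer(_TokenizerElement):
#         pre: bool = serializable_field(default=False)
#         post: bool = serializable_field(default=True)
#
#         @classmethod
#         def attribute_key(cls) -> str:
#             return "demo_path_tokenizer"
#
#         def is_valid(self, do_except: bool = False) -> bool:
#             if self.pre and not self.post:
#                 # Uninteresting: strictly worse than the `post`-only variant
#                 return False
#             return True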

T = TypeVar("T", bound=_TokenizerElement)


def _unsupported_is_invalid(self, do_except: bool = False) -> bool:  # noqa: ANN001
    """Default implementation of `is_valid` for `mark_as_unsupported`-decorated classes"""
    if do_except:
        err_msg: str = (
            f"Class `{type(self).__name__ = }`, marked as unsupported, is not valid. "
            f"{type(self) = }, {self = }"
        )
        raise ValueError(err_msg)

    return False


# TYPING: better type hints for this function
def mark_as_unsupported(is_valid: Callable[[T, bool], bool]) -> Callable[[T], T]:
    """Mark a `_TokenizerElement` subclass as unsupported.

    Classes marked with this decorator won't show up in `get_all_tokenizers()` and thus won't be tested.
    The classes marked in release 1.0.0 did work reliably before being marked, but they can't be instantiated since the decorator adds an abstract method.
    The decorator exists to prune the space of tokenizers returned by `all_instances` both for testing and usage.
    Previously, the space was too large, resulting in impractical runtimes.
    These decorators could be removed in future releases to expand the space of possible tokenizers.
    """

    def wrapper(cls: T) -> T:
        # intentionally modifying the method here
        # unclear why the type checker thinks `T`/`self` should not be an argument
        cls.is_valid = is_valid  # type: ignore[assignment, method-assign]
        return cls

    return wrapper
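
# Illustrative sketch (hypothetical class): the decorator takes the replacement
# `is_valid` implementation, typically `_unsupported_is_invalid`, and assigns it
# onto the class; instances then report `is_valid() == False`, so the class is
# pruned from the space returned by `all_instances` / `get_all_tokenizers()`.
#
#     @mark_as_unsupported(_unsupported_is_invalid)
#     @serializable_dataclass(frozen=True, kw_only=True)
#     class _LegacyElement(_TokenizerElement):
#         ...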

# TODO: why noqa here? `B024 `__TokenizerElementNamespace` is an abstract base class, but it has no abstract methods or properties`
class __TokenizerElementNamespace(abc.ABC):  # noqa: B024
    """ABC for namespaces

    # Properties
    - key: The binding used in `MazeTokenizerModular` for instances of the classes contained within that `__TokenizerElementNamespace`.
    """

    # HACK: this is not the right way of doing this lol
    key: str = NotImplementedError  # type: ignore[assignment]
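
# Illustrative sketch (hypothetical namespace; the concrete ones in
# `maze_dataset.tokenization.modular.elements` follow this pattern): a namespace
# groups the subclasses of one abstract element and sets `key` to the
# corresponding `MazeTokenizerModular` binding:
#
#     class DemoTokenizers(__TokenizerElementNamespace):
#         key = "demo_tokenizer"
#
#         @serializable_dataclass(frozen=True, kw_only=True)
#         class _DemoTokenizer(_TokenizerElement):
#             ...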

def _load_tokenizer_element(
    data: dict[str, Any],
    namespace: type[__TokenizerElementNamespace],
) -> _TokenizerElement:
    """Loads a `_TokenizerElement` stored via zanj."""
    key: str = namespace.key
    format_: str = data[key][_FORMAT_KEY]
    cls_name: str = format_.split("(")[0]
    cls: type[_TokenizerElement] = getattr(namespace, cls_name)
    kwargs: dict[str, Any] = {
        k: load_item_recursive(data[key][k], tuple()) for k, v in data[key].items()
    }
    if _FORMAT_KEY in kwargs:
        kwargs.pop(_FORMAT_KEY)
    return cls(**kwargs)
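
# Illustrative sketch (reusing the hypothetical `DemoTokenizers` namespace from the
# sketch above; the exact `_FORMAT_KEY` string shape is an assumption):
# `_load_tokenizer_element` expects the serialized element to sit under the
# namespace's `key`, with the concrete class name recoverable from the format entry:
#
#     data = {
#         DemoTokenizers.key: {
#             _FORMAT_KEY: "_DemoTokenizer(SerializableDataclass)",  # assumed format string
#             "pre": False,
#         },
#     }
#     elem = _load_tokenizer_element(data, DemoTokenizers)  # -> _DemoTokenizer(pre=False)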