Coverage for tests/unit/tokenization/test_all_instances.py: 94%
54 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-27 23:43 -0600
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-27 23:43 -0600
1import abc
2from dataclasses import dataclass, field
3from typing import Callable, Iterable, Literal
5import pytest
6from muutils.misc import IsDataclass, dataclass_set_equals
8from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances
11# Test classes
@dataclass(frozen=False)
class DC1:
    """Mutable dataclass fixture with two boolean fields.

    Being non-frozen (with the default ``eq=True``), instances are not
    hashable, so they cannot be collected into a plain ``set``.
    """

    x: bool
    y: bool = field(default=False)
@dataclass(frozen=True)
class DC2:
    """Frozen (immutable, hashable) dataclass fixture with two boolean fields."""

    x: bool
    y: bool = field(default=False)
@dataclass(frozen=True)
class DC3:
    """Frozen dataclass fixture whose only field is a nested ``DC2``.

    The field default is produced by a factory that builds
    ``DC2(False, False)`` each time it is needed.
    """

    x: DC2 = field(default_factory=lambda: DC2(x=False, y=False))
@dataclass(frozen=True)
class DC4:
    """Frozen dataclass fixture combining a nested ``DC2`` with a boolean flag."""

    x: DC2
    y: bool = field(default=False)
@dataclass(frozen=True)
class DC5:
    """Frozen dataclass fixture with a single ``int`` field.

    In the parametrized cases of this module, ``all_instances(DC5)`` is
    expected to raise ``TypeError``.
    """

    # The only field; `int` itself is also listed as a TypeError case.
    x: int
@dataclass(frozen=True)
class DC6:
    """Frozen dataclass fixture nesting a ``DC5`` next to a boolean flag.

    In the parametrized cases of this module, ``all_instances(DC6)`` is
    expected to raise ``TypeError``.
    """

    x: DC5
    y: bool = field(default=False)
@dataclass(frozen=True)
class DC7(abc.ABC):
    """Abstract frozen dataclass fixture with one boolean field.

    It cannot be instantiated directly because ``foo`` is abstract; the
    parametrized cases for ``DC7`` list instances of its concrete
    subclasses as the expected result.
    """

    x: bool

    @abc.abstractmethod
    def foo(self):
        """Abstract hook with no behavior; subclasses must override it."""
@dataclass(frozen=True)
class DC8(DC7):
    """Concrete ``DC7`` subclass that gives the inherited ``x`` a default of ``False``."""

    x: bool = field(default=False)

    def foo(self):
        """Concrete no-op implementation of the abstract ``foo``."""
        return None
@dataclass(frozen=True)
class DC9(DC7):
    """Concrete ``DC7`` subclass that adds a second boolean field ``y`` (default ``True``)."""

    y: bool = field(default=True)

    def foo(self):
        """Concrete no-op implementation of the abstract ``foo``."""
        return None
@pytest.mark.parametrize(
    ("type_", "validation_funcs", "result"),
    [
        pytest.param(
            type_,
            vfs,
            result,
            id=f"{type_}-vfs[{len(vfs) if vfs is not None else 'None'}]",
        )
        # Each case is (type to enumerate, validation funcs or None,
        # expected instances or expected exception type).
        for type_, vfs, result in [
            (
                DC1,
                None,
                [
                    DC1(False, False),
                    DC1(False, True),
                    DC1(True, False),
                    DC1(True, True),
                ],
            ),
            (
                DC2,
                None,
                [
                    DC2(False, False),
                    DC2(False, True),
                    DC2(True, False),
                    DC2(True, True),
                ],
            ),
            (
                DC2,
                {DC2: lambda dc: dc.x ^ dc.y},
                [
                    DC2(False, True),
                    DC2(True, False),
                ],
            ),
            (
                DC1 | DC2,
                {DC2: lambda dc: dc.x ^ dc.y},
                [
                    DC2(False, True),
                    DC2(True, False),
                    DC1(False, False),
                    DC1(False, True),
                    DC1(True, False),
                    DC1(True, True),
                ],
            ),
            (
                DC1 | DC2,
                {
                    DC1: lambda dc: dc.x == dc.y,
                    DC2: lambda dc: dc.x ^ dc.y,
                },
                [
                    DC2(False, True),
                    DC2(True, False),
                    DC1(False, False),
                    DC1(True, True),
                ],
            ),
            (
                DC3,
                None,
                [
                    DC3(DC2(False, False)),
                    DC3(DC2(False, True)),
                    DC3(DC2(True, False)),
                    DC3(DC2(True, True)),
                ],
            ),
            (
                DC4,
                None,
                [
                    DC4(DC2(False, False), True),
                    DC4(DC2(False, True), True),
                    DC4(DC2(True, False), True),
                    DC4(DC2(True, True), True),
                    DC4(DC2(False, False), False),
                    DC4(DC2(False, True), False),
                    DC4(DC2(True, False), False),
                    DC4(DC2(True, True), False),
                ],
            ),
            (
                DC4,
                {DC2: lambda dc: dc.x ^ dc.y},
                [
                    DC4(DC2(False, True), True),
                    DC4(DC2(True, False), True),
                    DC4(DC2(False, True), False),
                    DC4(DC2(True, False), False),
                ],
            ),
            (DC5, None, TypeError),
            (DC6, None, TypeError),
            (bool, None, [True, False]),
            (bool, {bool: lambda x: x}, [True]),
            (bool, {bool: lambda x: not x}, [False]),
            (int, None, TypeError),
            (str, None, TypeError),
            (Literal[0, 1, 2], None, [0, 1, 2]),
            (Literal[0, 1, 2], {int: lambda x: x % 2 == 0}, [0, 2]),
            (bool | Literal[0, 1, 2], {}, [0, 1, 2, True, False]),
            (bool | Literal[0, 1, 2], {bool: lambda x: x}, [0, 1, 2, True]),
            (bool | Literal[0, 1, 2], {int: lambda x: x % 2}, [1, True]),
            (
                tuple[bool],
                None,
                [
                    (True,),
                    (False,),
                ],
            ),
            (
                tuple[bool, bool],
                None,
                [
                    (True, True),
                    (True, False),
                    (False, True),
                    (False, False),
                ],
            ),
            (
                tuple[bool, bool],
                {bool: lambda x: x},
                [
                    (True, True),
                ],
            ),
            (
                DC8,
                None,
                [
                    DC8(False),
                    DC8(True),
                ],
            ),
            (
                DC7,
                None,
                [
                    DC8(False),
                    DC8(True),
                    DC9(False, False),
                    DC9(False, True),
                    DC9(True, False),
                    DC9(True, True),
                ],
            ),
            (
                tuple[DC7],
                None,
                [
                    (DC8(False),),
                    (DC8(True),),
                    (DC9(False, False),),
                    (DC9(False, True),),
                    (DC9(True, False),),
                    (DC9(True, True),),
                ],
            ),
            (
                tuple[DC7],
                {DC9: lambda dc: dc.x == dc.y},
                [
                    (DC8(False),),
                    (DC8(True),),
                    (DC9(False, False),),
                    (DC9(True, True),),
                ],
            ),
            (
                tuple[DC8, DC8],
                None,
                [
                    (DC8(False), DC8(False)),
                    (DC8(False), DC8(True)),
                    (DC8(True), DC8(False)),
                    (DC8(True), DC8(True)),
                ],
            ),
            (
                tuple[DC7, bool],
                None,
                [
                    (DC8(False), True),
                    (DC8(True), True),
                    (DC9(False, False), True),
                    (DC9(False, True), True),
                    (DC9(True, False), True),
                    (DC9(True, True), True),
                    (DC8(False), False),
                    (DC8(True), False),
                    (DC9(False, False), False),
                    (DC9(False, True), False),
                    (DC9(True, False), False),
                    (DC9(True, True), False),
                ],
            ),
        ]
    ],
)
def test_all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None,
    result: type[Exception] | Iterable[FiniteValued],
):
    """Check `all_instances(type_, validation_funcs)` against `result`.

    `result` is either an exception class, in which case consuming the output of
    `all_instances` must raise it, or a collection of expected values, which is
    compared to the produced values without regard to order.
    """
    # Error case: the exception must surface while the output is being consumed.
    if isinstance(result, type) and issubclass(result, Exception):
        with pytest.raises(result):
            list(all_instances(type_, validation_funcs))
        return

    # Dataclass case: compare via `dataclass_set_equals` rather than a plain set.
    if hasattr(type_, "__dataclass_fields__"):
        produced = all_instances(type_, validation_funcs)
        # TYPING: error: Argument 2 to "dataclass_set_equals" has incompatible type "Iterable[FiniteValued]"; expected "Iterable[IsDataclass]" [arg-type]
        assert dataclass_set_equals(produced, result)  # type: ignore[arg-type]
        return

    # General case: due to nesting, the values may mix dataclasses with other
    # types, so the two groups are compared separately.
    def is_dc(value) -> bool:
        return isinstance(value, IsDataclass)

    out = list(all_instances(type_, validation_funcs))

    actual_dcs = [value for value in out if is_dc(value)]
    expected_dcs = [value for value in result if is_dc(value)]
    assert dataclass_set_equals(actual_dcs, expected_dcs)  # type: ignore[arg-type]

    actual_other = {value for value in out if not is_dc(value)}
    expected_other = {value for value in result if not is_dc(value)}
    assert actual_other == expected_other