Coverage for tests/unit/tokenization/test_all_instances.py: 94%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-27 23:43 -0600

1import abc 

2from dataclasses import dataclass, field 

3from typing import Callable, Iterable, Literal 

4 

5import pytest 

6from muutils.misc import IsDataclass, dataclass_set_equals 

7 

8from maze_dataset.tokenization.modular.all_instances import FiniteValued, all_instances 

9 

10 

# Test fixtures: dataclasses whose instances `all_instances` is asked to enumerate.
@dataclass()
class DC1:
    """Mutable dataclass holding two ``bool`` fields; ``y`` defaults to ``False``.

    Not frozen (and ``eq`` is left at its default), so instances are unhashable.
    """

    x: bool
    y: bool = field(default=False)

16 

17 

@dataclass(frozen=True, eq=True)
class DC2:
    """Frozen counterpart of ``DC1``: two ``bool`` fields, ``y`` defaulting to ``False``.

    Being frozen with ``eq=True``, instances are immutable and hashable.
    """

    x: bool
    y: bool = field(default=False)

22 

23 

@dataclass(frozen=True)
class DC3:
    """Frozen dataclass wrapping a single nested ``DC2``.

    The default comes from a factory, so every instance built without an
    argument receives its own fresh ``DC2(False, False)``.
    """

    x: DC2 = field(default_factory=lambda: DC2(x=False, y=False))

27 

28 

@dataclass(frozen=True)
class DC4:
    """Frozen dataclass pairing a required nested ``DC2`` with a ``bool`` flag.

    ``y`` defaults to ``False``; ``x`` has no default.
    """

    x: DC2
    y: bool = field(default=False)

33 

34 

@dataclass(frozen=True)
class DC5:
    """Frozen dataclass with a single ``int`` field.

    ``int`` has unboundedly many values, so the parametrized cases expect
    ``all_instances(DC5)`` to raise ``TypeError``.
    """

    x: int

38 

39 

@dataclass(frozen=True)
class DC6:
    """Frozen dataclass nesting a ``DC5`` alongside a ``bool`` flag.

    Since the nested ``DC5`` carries an ``int`` field, the parametrized cases
    expect ``all_instances(DC6)`` to raise ``TypeError`` as well.
    """

    x: DC5
    y: bool = field(default=False)

44 

45 

@dataclass(frozen=True)
class DC7(abc.ABC):
    """Abstract frozen dataclass base declaring one ``bool`` field ``x``.

    ``foo`` is abstract, so ``DC7`` cannot be instantiated directly. The
    parametrized cases expect ``all_instances(DC7)`` to yield instances of its
    concrete subclasses instead.
    """

    x: bool

    @abc.abstractmethod
    def foo(self):
        """Abstract hook with no behavior; its presence makes ``DC7`` abstract."""

53 

54 

@dataclass(frozen=True)
class DC8(DC7):
    """Concrete ``DC7`` subclass that gives the inherited ``x`` a default of ``False``."""

    x: bool = field(default=False)

    def foo(self):
        """No-op implementation of the abstract ``DC7.foo``."""
        return None

61 

62 

@dataclass(frozen=True)
class DC9(DC7):
    """Concrete ``DC7`` subclass adding a second ``bool`` field ``y`` (default ``True``).

    Field order is the inherited, required ``x`` followed by ``y``.
    """

    y: bool = field(default=True)

    def foo(self):
        """No-op implementation of the abstract ``DC7.foo``."""
        return None

69 

70 

def _typed_key(value):
    """Return a hashable key for ``value`` that also records its exact type.

    ``bool`` is a subclass of ``int`` and ``True == 1`` / ``False == 0`` hash
    equally, so plain set equality would treat ``{0, 1, 2, True, False}`` and
    ``{0, 1, 2}`` as identical. Tagging each value with ``type(value)`` -- and
    recursing into tuples, whose elements may mix ``bool`` and ``int`` --
    keeps those cases distinct. Values must be hashable.
    """
    if isinstance(value, tuple):
        return (type(value), tuple(_typed_key(item) for item in value))
    return (type(value), value)


@pytest.mark.parametrize(
    ("type_", "validation_funcs", "result"),
    [
        pytest.param(
            type_,
            vfs,
            result,
            id=f"{type_}-vfs[{len(vfs) if vfs is not None else 'None'}]",
        )
        for type_, vfs, result in (
            [
                (
                    DC1,
                    None,
                    [
                        DC1(False, False),
                        DC1(False, True),
                        DC1(True, False),
                        DC1(True, True),
                    ],
                ),
                (
                    DC2,
                    None,
                    [
                        DC2(False, False),
                        DC2(False, True),
                        DC2(True, False),
                        DC2(True, True),
                    ],
                ),
                (
                    DC2,
                    {DC2: lambda dc: dc.x ^ dc.y},
                    [
                        DC2(False, True),
                        DC2(True, False),
                    ],
                ),
                (
                    DC1 | DC2,
                    {DC2: lambda dc: dc.x ^ dc.y},
                    [
                        DC2(False, True),
                        DC2(True, False),
                        DC1(False, False),
                        DC1(False, True),
                        DC1(True, False),
                        DC1(True, True),
                    ],
                ),
                (
                    DC1 | DC2,
                    {
                        DC1: lambda dc: dc.x == dc.y,
                        DC2: lambda dc: dc.x ^ dc.y,
                    },
                    [
                        DC2(False, True),
                        DC2(True, False),
                        DC1(False, False),
                        DC1(True, True),
                    ],
                ),
                (
                    DC3,
                    None,
                    [
                        DC3(DC2(False, False)),
                        DC3(DC2(False, True)),
                        DC3(DC2(True, False)),
                        DC3(DC2(True, True)),
                    ],
                ),
                (
                    DC4,
                    None,
                    [
                        DC4(DC2(False, False), True),
                        DC4(DC2(False, True), True),
                        DC4(DC2(True, False), True),
                        DC4(DC2(True, True), True),
                        DC4(DC2(False, False), False),
                        DC4(DC2(False, True), False),
                        DC4(DC2(True, False), False),
                        DC4(DC2(True, True), False),
                    ],
                ),
                (
                    DC4,
                    {DC2: lambda dc: dc.x ^ dc.y},
                    [
                        DC4(DC2(False, True), True),
                        DC4(DC2(True, False), True),
                        DC4(DC2(False, True), False),
                        DC4(DC2(True, False), False),
                    ],
                ),
                (DC5, None, TypeError),
                (DC6, None, TypeError),
                (bool, None, [True, False]),
                (bool, {bool: lambda x: x}, [True]),
                (bool, {bool: lambda x: not x}, [False]),
                (int, None, TypeError),
                (str, None, TypeError),
                (Literal[0, 1, 2], None, [0, 1, 2]),
                (Literal[0, 1, 2], {int: lambda x: x % 2 == 0}, [0, 2]),
                (bool | Literal[0, 1, 2], dict(), [0, 1, 2, True, False]),
                (bool | Literal[0, 1, 2], {bool: lambda x: x}, [0, 1, 2, True]),
                (bool | Literal[0, 1, 2], {int: lambda x: x % 2}, [1, True]),
                (
                    tuple[bool],
                    None,
                    [
                        (True,),
                        (False,),
                    ],
                ),
                (
                    tuple[bool, bool],
                    None,
                    [
                        (True, True),
                        (True, False),
                        (False, True),
                        (False, False),
                    ],
                ),
                (
                    tuple[bool, bool],
                    {bool: lambda x: x},
                    [
                        (True, True),
                    ],
                ),
                (
                    DC8,
                    None,
                    [
                        DC8(False),
                        DC8(True),
                    ],
                ),
                (
                    DC7,
                    None,
                    [
                        DC8(False),
                        DC8(True),
                        DC9(False, False),
                        DC9(False, True),
                        DC9(True, False),
                        DC9(True, True),
                    ],
                ),
                (
                    tuple[DC7],
                    None,
                    [
                        (DC8(False),),
                        (DC8(True),),
                        (DC9(False, False),),
                        (DC9(False, True),),
                        (DC9(True, False),),
                        (DC9(True, True),),
                    ],
                ),
                (
                    tuple[DC7],
                    {DC9: lambda dc: dc.x == dc.y},
                    [
                        (DC8(False),),
                        (DC8(True),),
                        (DC9(False, False),),
                        (DC9(True, True),),
                    ],
                ),
                (
                    tuple[DC8, DC8],
                    None,
                    [
                        (DC8(False), DC8(False)),
                        (DC8(False), DC8(True)),
                        (DC8(True), DC8(False)),
                        (DC8(True), DC8(True)),
                    ],
                ),
                (
                    tuple[DC7, bool],
                    None,
                    [
                        (DC8(False), True),
                        (DC8(True), True),
                        (DC9(False, False), True),
                        (DC9(False, True), True),
                        (DC9(True, False), True),
                        (DC9(True, True), True),
                        (DC8(False), False),
                        (DC8(True), False),
                        (DC9(False, False), False),
                        (DC9(False, True), False),
                        (DC9(True, False), False),
                        (DC9(True, True), False),
                    ],
                ),
            ]
        )
    ],
)
def test_all_instances(
    type_: FiniteValued,
    validation_funcs: dict[FiniteValued, Callable[[FiniteValued], bool]] | None,
    result: type[Exception] | Iterable[FiniteValued],
):
    """Check ``all_instances`` output against hand-written expectations.

    ``result`` is either an exception type that enumeration must raise, or the
    expected instances in any order. Dataclass instances are compared with
    ``dataclass_set_equals``. All other values are compared as sets of
    type-tagged keys, so ``True``/``1`` and ``False``/``0`` are not conflated.
    """
    if isinstance(result, type) and issubclass(result, Exception):
        # `list` forces the generator so the expected error actually surfaces.
        with pytest.raises(result):
            list(all_instances(type_, validation_funcs))
    elif hasattr(type_, "__dataclass_fields__"):
        # TYPING: `result` is Iterable[FiniteValued], but dataclass_set_equals expects Iterable[IsDataclass].
        assert dataclass_set_equals(all_instances(type_, validation_funcs), result)  # type: ignore[arg-type]
    else:
        # General case: due to nesting, results may mix dataclasses with other values.
        out = list(all_instances(type_, validation_funcs))
        assert dataclass_set_equals(
            # TYPING: filter's predicate is not a TypeGuard[IsDataclass].
            filter(lambda x: isinstance(x, IsDataclass), out),  # type: ignore[arg-type]
            filter(lambda x: isinstance(x, IsDataclass), result),  # type: ignore[arg-type]
        )
        # Type-tagged keys keep bools distinct from equal ints (True vs 1, False vs 0).
        actual_other = {_typed_key(x) for x in out if not isinstance(x, IsDataclass)}
        expected_other = {
            _typed_key(x) for x in result if not isinstance(x, IsDataclass)
        }
        assert actual_other == expected_other