Coverage for maze_dataset/testing_utils.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-11 00:49 -0600

1"""Shared utilities for tests only. 

2 

3Do not import into any module outside of the tests directory 

4""" 

5 

6import itertools 

7from typing import Final, NamedTuple, Sequence 

8 

9import frozendict 

10import numpy as np 

11 

12from maze_dataset import ( 

13 CoordArray, 

14 LatticeMaze, 

15 LatticeMazeGenerators, 

16 MazeDataset, 

17 MazeDatasetConfig, 

18 SolvedMaze, 

19 TargetedLatticeMaze, 

20) 

21from maze_dataset.tokenization import ( 

22 MazeTokenizer, 

23 MazeTokenizerModular, 

24 TokenizationMode, 

25) 

26 

27GRID_N: Final[int] = 5 

28N_MAZES: Final[int] = 5 

29CFG: Final[MazeDatasetConfig] = MazeDatasetConfig( 

30 name="test", 

31 grid_n=GRID_N, 

32 n_mazes=N_MAZES, 

33 maze_ctor=LatticeMazeGenerators.gen_dfs, 

34) 

35MAZE_DATASET: Final[MazeDataset] = MazeDataset.from_config( 

36 CFG, 

37 do_download=False, 

38 load_local=False, 

39 do_generate=True, 

40 save_local=False, 

41 verbose=True, 

42 gen_parallel=False, 

43) 

44LATTICE_MAZES: Final[tuple[LatticeMaze, ...]] = tuple( 

45 LatticeMazeGenerators.gen_dfs(np.array([GRID_N, GRID_N])) for _ in range(N_MAZES) 

46) 

47_PATHS = tuple(maze.generate_random_path() for maze in LATTICE_MAZES) 

48TARGETED_MAZES: Final[tuple[TargetedLatticeMaze, ...]] = tuple( 

49 TargetedLatticeMaze.from_lattice_maze(maze, path[0], path[-1]) 

50 for maze, path in zip(LATTICE_MAZES, _PATHS, strict=False) 

51) 

52# MIXED_MAZES alternates the maze types, so you can slice a contiguous subset and still get all types 

53MIXED_MAZES: Final[tuple[LatticeMaze | TargetedLatticeMaze | SolvedMaze, ...]] = tuple( 

54 x 

55 for x in itertools.chain.from_iterable( 

56 itertools.zip_longest(MAZE_DATASET.mazes, TARGETED_MAZES, LATTICE_MAZES), 

57 ) 

58) 

59 

60 

61class MANUAL_MAZE(NamedTuple): # noqa: N801 

62 """A named tuple for manual maze definitions""" 

63 

64 tokens: str 

65 ascii: Sequence[str] 

66 straightaway_footprints: CoordArray 

67 

68 

69ASCII_MAZES: Final[frozendict.frozendict[str, MANUAL_MAZE]] = frozendict.frozendict( 

70 small_3x3=MANUAL_MAZE( 

71 tokens="<ADJLIST_START> (2,0) <--> (2,1) ; (0,0) <--> (0,1) ; (0,0) <--> (1,0) ; (0,2) <--> (1,2) ; (1,0) <--> (2,0) ; (0,2) <--> (0,1) ; (2,2) <--> (2,1) ; (1,1) <--> (2,1) ; <ADJLIST_END> <ORIGIN_START> (0,0) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (0,0) (1,0) (2,0) (2,1) <PATH_END>", 

72 ascii=( 

73 "#######", 

74 "#S #", 

75 "#X### #", 

76 "#X# # #", 

77 "#X# ###", 

78 "#XXE #", 

79 "#######", 

80 ), 

81 straightaway_footprints=np.array( 

82 [ 

83 [0, 0], 

84 [2, 0], 

85 [2, 1], 

86 ], 

87 ), 

88 ), 

89 big_10x10=MANUAL_MAZE( 

90 tokens="<ADJLIST_START> (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) <--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; <ADJLIST_END> <ORIGIN_START> (6,2) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) <PATH_END>", 

91 ascii=( 

92 "#####################", 

93 "# # # #", 

94 "# # # # ### # # #####", 

95 "# # # # # # #", 

96 "# ####### ##### # # #", 

97 "# #E # # # #", 

98 "###X# ########### # #", 

99 "#XXX# # # #", 

100 "#X##### ########### #", 

101 "#X# # # #", 

102 "#X# ######### ### # #", 

103 "#X# # # # #", 

104 "#X######### # # ### #", 

105 "#XXXXS# # # #", 

106 "# ########### #######", 

107 "# # # # #", 

108 "# # ####### ### # ###", 

109 "# # # # # #", 

110 "# # # ####### ##### #", 

111 "# # #", 

112 "#####################", 

113 ), 

114 straightaway_footprints=np.array( 

115 [ 

116 [6, 2], 

117 [6, 0], 

118 [3, 0], 

119 [3, 1], 

120 [2, 1], 

121 ], 

122 ), 

123 ), 

124 longer_10x10=MANUAL_MAZE( 

125 tokens="<ADJLIST_START> (8,2) <--> (8,3) ; (3,7) <--> (3,6) ; (6,7) <--> (6,8) ; (4,6) <--> (5,6) ; (9,5) <--> (9,4) ; (3,3) <--> (3,4) ; (5,1) <--> (4,1) ; (2,6) <--> (2,7) ; (8,5) <--> (8,4) ; (1,9) <--> (2,9) ; (4,1) <--> (4,2) ; (0,8) <--> (0,7) ; (5,4) <--> (5,3) ; (6,3) <--> (6,4) ; (5,0) <--> (4,0) ; (5,3) <--> (5,2) ; (3,1) <--> (2,1) ; (9,1) <--> (9,0) ; (3,5) <--> (3,6) ; (5,5) <--> (6,5) ; (7,1) <--> (7,2) ; (0,1) <--> (1,1) ; (7,8) <--> (8,8) ; (3,9) <--> (4,9) ; (4,6) <--> (4,7) ; (0,6) <--> (0,7) ; (3,4) <--> (3,5) ; (6,0) <--> (5,0) ; (7,7) <--> (7,6) ; (1,6) <--> (0,6) ; (6,1) <--> (6,0) ; (8,6) <--> (8,7) ; (9,9) <--> (9,8) ; (1,8) <--> (1,9) ; (2,1) <--> (2,2) ; (9,2) <--> (9,3) ; (5,9) <--> (6,9) ; (3,2) <--> (2,2) ; (0,8) <--> (0,9) ; (5,6) <--> (5,7) ; (2,3) <--> (2,4) ; (4,5) <--> (4,4) ; (8,9) <--> (8,8) ; (9,6) <--> (8,6) ; (3,7) <--> (3,8) ; (8,0) <--> (7,0) ; (6,1) <--> (6,2) ; (0,1) <--> (0,0) ; (7,3) <--> (7,4) ; (9,4) <--> (9,3) ; (9,6) <--> (9,5) ; (8,7) <--> (7,7) ; (5,2) <--> (5,1) ; (0,0) <--> (1,0) ; (7,2) <--> (7,3) ; (2,5) <--> (2,6) ; (4,9) <--> (5,9) ; (5,5) <--> (5,4) ; (5,6) <--> (6,6) ; (7,8) <--> (7,9) ; (1,7) <--> (2,7) ; (4,6) <--> (4,5) ; (1,1) <--> (1,2) ; (3,1) <--> (3,0) ; (1,5) <--> (1,6) ; (8,3) <--> (8,4) ; (9,9) <--> (8,9) ; (8,5) <--> (7,5) ; (1,4) <--> (2,4) ; (3,0) <--> (4,0) ; (3,3) <--> (4,3) ; (6,9) <--> (6,8) ; (1,0) <--> (2,0) ; (6,0) <--> (7,0) ; (8,0) <--> (9,0) ; (2,3) <--> (2,2) ; (2,8) <--> (3,8) ; (5,7) <--> (6,7) ; (1,3) <--> (0,3) ; (9,7) <--> (9,8) ; (7,5) <--> (7,4) ; (1,8) <--> (2,8) ; (6,5) <--> (6,4) ; (0,2) <--> (1,2) ; (0,7) <--> (1,7) ; (0,3) <--> (0,2) ; (4,3) <--> (4,2) ; (5,8) <--> (4,8) ; (9,1) <--> (8,1) ; (9,2) <--> (8,2) ; (1,3) <--> (1,4) ; (2,9) <--> (3,9) ; (4,8) <--> (4,7) ; (0,5) <--> (0,4) ; (8,1) <--> (7,1) ; (0,3) <--> (0,4) ; (9,7) <--> (9,6) ; (7,6) <--> (6,6) ; (1,5) <--> (0,5) ; <ADJLIST_END> <ORIGIN_START> (6,2) <ORIGIN_END> <TARGET_START> (2,1) <TARGET_END> <PATH_START> (6,2) (6,1) (6,0) (5,0) (4,0) (3,0) (3,1) (2,1) (2,2) (2,3) (2,4) (1,4) (1,3) (0,3) (0,4) (0,5) (1,5) (1,6) (0,6) (0,7) (0,8) <PATH_END>", 

126 ascii=( 

127 "#####################", 

128 "# # XXXXX#XXXXE #", 

129 "# # # #X###X#X# #####", 

130 "# # #XXX#XXX# # #", 

131 "# #######X##### # # #", 

132 "# #XXXXXXX# # # #", 

133 "###X# ########### # #", 

134 "#XXX# # # #", 

135 "#X##### ########### #", 

136 "#X# # # #", 

137 "#X# ######### ### # #", 

138 "#X# # # # #", 

139 "#X######### # # ### #", 

140 "#XXXXS# # # #", 

141 "# ########### #######", 

142 "# # # # #", 

143 "# # ####### ### # ###", 

144 "# # # # # #", 

145 "# # # ####### ##### #", 

146 "# # #", 

147 "#####################", 

148 ), 

149 straightaway_footprints=np.array( 

150 [ 

151 [6, 2], 

152 [6, 0], 

153 [3, 0], 

154 [3, 1], 

155 [2, 1], 

156 [2, 4], 

157 [1, 4], 

158 [1, 3], 

159 [0, 3], 

160 [0, 5], 

161 [1, 5], 

162 [1, 6], 

163 [0, 6], 

164 [0, 8], 

165 ], 

166 ), 

167 ), 

168) 

169 

170# A list of legacy `MazeTokenizer`s and their `MazeTokenizerModular` equivalents. 

171# Used for unit tests where both versions are supported 

172LEGACY_AND_EQUIVALENT_TOKENIZERS: list[MazeTokenizer | MazeTokenizerModular] = [ 

173 *[ 

174 MazeTokenizer(tokenization_mode=tok_mode, max_grid_size=20) 

175 for tok_mode in TokenizationMode 

176 ], 

177 *[MazeTokenizerModular.from_legacy(tok_mode) for tok_mode in TokenizationMode], 

178]