Coverage for src / bartz / jaxext / _autobatch.py: 97%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-18 15:24 +0000

1# bartz/src/bartz/jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 

28from collections.abc import Callable 

29from functools import wraps 

30from warnings import warn 

31 

32from jax import eval_shape, jit 

33from jax import numpy as jnp 

34from jax.lax import scan 

35from jax.tree import flatten as tree_flatten 

36from jax.tree import map as tree_map 

37from jax.tree import reduce as tree_reduce 

38from jaxtyping import PyTree 

39 

40 

def expand_axes(axes, tree):
    """Broadcast the prefix tree `axes` over the full structure of `tree`.

    Every leaf of `axes` (``None`` included, which is forced to count as a leaf)
    is replicated onto each leaf of the corresponding subtree of `tree`, so the
    result has one axis entry per leaf of `tree`.
    """
    return tree_map(
        lambda axis, subtree: tree_map(lambda _leaf: axis, subtree),
        axes,
        tree,
        # without this, a None axis would be read as an empty subtree
        is_leaf=lambda node: node is None,
    )

48 

49 

def check_no_nones(axes, tree):
    """Assert that every leaf of `tree` is paired with a non-``None`` axis.

    `axes` is walked in lockstep with `tree`, so it must hold one entry per leaf
    of `tree` (as produced by `expand_axes`). Raises AssertionError on the first
    ``None`` axis found.
    """

    def _require_axis(_leaf, axis):
        assert axis is not None

    tree_map(_require_axis, tree, axes)

55 

56 

def extract_size(axes, tree):
    """Return the common length of the batched axes of `tree`.

    Leaves whose axis is ``None`` are ignored. All remaining leaves must have the
    same length along their axis (checked with an assertion). If no leaf is
    batched, indexing the empty list of lengths raises IndexError.
    """

    def _axis_length(leaf, axis):
        return None if axis is None else leaf.shape[axis]

    # None lengths vanish on flattening because None is an empty pytree
    lengths, _ = tree_flatten(tree_map(_axis_length, tree, axes))
    first = lengths[0]
    assert all(length == first for length in lengths)
    return first

68 

69 

def sum_nbytes(tree):
    """Return the total byte count of all array-like leaves of `tree`.

    Each leaf contributes ``prod(shape) * dtype.itemsize``; ``None`` entries are
    empty pytrees and contribute nothing. An empty tree gives 0.
    """
    leaves, _ = tree_flatten(tree)
    return sum(math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in leaves)

75 

76 

def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is at least `min_divisor`.

    This is the branch used when ``min_divisor**2 <= dividend``, but it is
    correct for any ``min_divisor >= 1``. It returns `dividend` itself when
    ``min_divisor > dividend``.

    The search is O(sqrt(dividend)). First it scans divisors in
    ``[min_divisor, isqrt(dividend)]`` upwards. If none is found, the answer is
    ``dividend // c`` for the largest divisor ``c`` of `dividend` that is both
    ``<= isqrt(dividend)`` and ``<= dividend // min_divisor``.
    """
    # integer square root: exact for large ints, unlike int(math.sqrt(...))
    root = math.isqrt(dividend)
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor
    # The answer is above `root`; pair it with its co-divisor below `root`.
    # Scanning co-divisors downwards yields increasing divisors, so the first
    # hit is the smallest one satisfying dividend // cofactor >= min_divisor.
    for cofactor in range(min(root, dividend // min_divisor), 0, -1):
        if dividend % cofactor == 0:
            return dividend // cofactor
    return dividend

82 

83 

def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is at least `min_divisor`.

    This variant walks co-divisors ``c`` downward from ``dividend // min_divisor``
    and returns ``dividend // c`` for the first ``c`` that divides `dividend`.
    That is cheap when `min_divisor` is large compared to ``sqrt(dividend)``.

    If there is no candidate co-divisor (the range is empty, e.g. when
    ``min_divisor > dividend > 0``), it returns `dividend` itself.
    """
    cofactor_limit = dividend // min_divisor
    cofactor = next(
        (c for c in range(cofactor_limit, 0, -1) if dividend % c == 0),
        None,
    )
    if cofactor is None:
        return dividend
    return dividend // cofactor

90 

91 

def next_divisor(dividend, min_divisor):
    """Return a divisor of `dividend` that is at least `min_divisor`.

    A zero `dividend` is divisible by any positive integer, so `min_divisor` is
    returned as is. Otherwise the search is delegated to `next_divisor_small`
    when ``min_divisor**2 <= dividend``, and to `next_divisor_large` when it
    is greater.
    """
    if dividend == 0:
        return min_divisor
    search = next_divisor_small if min_divisor**2 <= dividend else next_divisor_large
    return search(dividend, min_divisor)

98 

99 

def pull_nonbatched(axes, tree):
    """Separate the batched leaves of `tree` from the rest.

    Returns a pair ``(batched, tree)``:
    - `batched` has the structure of `tree`, but every leaf whose axis is
      ``None`` is replaced by ``None``.
    - The second element is the input `tree` itself, unchanged. It is kept so
      `push_nonbatched` can later restore the unbatched leaves.
    """
    batched = tree_map(
        lambda leaf, axis: leaf if axis is not None else None, tree, axes
    )
    return batched, tree

108 

109 

def push_nonbatched(axes, tree, original_tree):
    """Fill the unbatched slots of `tree` back in from `original_tree`.

    Positions whose axis is ``None`` take the leaf from `original_tree`. All
    other positions keep the leaf from `tree`. This is the inverse of
    `pull_nonbatched`.
    """

    def _restore(original_leaf, leaf, axis):
        if axis is not None:
            return leaf
        return original_leaf

    return tree_map(_restore, original_tree, tree, axes)

118 

119 

def move_axes_out(axes, tree):
    """Move the batch axis of every leaf of `tree` to position 0.

    `axes` gives, per leaf, the current position of the batch axis.
    """
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, axis, 0), tree, axes)

125 

126 

def move_axes_in(axes, tree):
    """Move axis 0 of every leaf of `tree` to the position given in `axes`.

    This is the inverse of `move_axes_out`.
    """
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, 0, axis), tree, axes)

132 

133 

def batch(tree, nbatches):
    """Split the leading axis of every leaf into ``(nbatches, length // nbatches)``.

    The leading axis length is expected to be a multiple of `nbatches`;
    otherwise the reshape fails.
    """

    def _split_leading(leaf):
        rows = leaf.shape[0]
        return leaf.reshape(nbatches, rows // nbatches, *leaf.shape[1:])

    return tree_map(_split_leading, tree)

139 

140 

def unbatch(tree):
    """Merge the first two axes of every leaf of `tree` (the inverse of `batch`)."""

    def _merge_leading(leaf):
        outer, inner = leaf.shape[0], leaf.shape[1]
        return leaf.reshape(outer * inner, *leaf.shape[2:])

    return tree_map(_merge_leading, tree)

146 

147 

def check_same(tree1, tree2):
    """Assert that `tree1` and `tree2` have matching leaf shapes and dtypes.

    The trees are mapped together with `tree_map`, so their structures must be
    compatible as well.
    """

    def _compare(left, right):
        assert (left.shape, left.dtype) == (right.shape, right.dtype)

    tree_map(_compare, tree1, tree2)

154 

155 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A jitted function with the same signature as `func`, save for the return
    value if `return_nbatches`.

    Notes
    -----
    The number of batches is a divisor of the batched axis length that is at
    least ``ceil(total_io_nbytes / max_io_nbytes)``, where `total_io_nbytes`
    counts batched inputs and all outputs. If even single-element batches exceed
    `max_io_nbytes`, a warning is emitted and the computation proceeds anyway.
    Batches are evaluated sequentially with `jax.lax.scan`.
    """
    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jit
    @wraps(func)
    def batched_func(*args):
        # shapes/dtypes of the full, unbatched output, computed without running func
        example_result = eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        size = extract_size((in_axes, out_axes), (args, example_result))

        # `args` keeps only batched leaves (None elsewhere); `nonbatched_args`
        # is the full original argument tree, used to refill the None slots
        args, nonbatched_args = pull_nonbatched(in_axes, args)

        total_nbytes = sum_nbytes((args, example_result))
        # ceil(total_nbytes / max_io_nbytes), at least 1
        min_nbatches = total_nbytes // max_io_nbytes + bool(
            total_nbytes % max_io_nbytes
        )
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # only possible when batches already hold a single element each
            assert size == nbatches
            msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            warn(msg)

        def loop(_, args):
            # restore the caller's axis layout and the unbatched arguments
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        # explicit length: scan cannot infer it when no input leaf is batched
        _, result = scan(loop, None, args, length=nbatches)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func