Coverage for src/bartz/jaxext/_autobatch.py: 97%

122 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/jaxext/_autobatch.py 

2# 

3# Copyright (c) 2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Implementation of `autobatch`.""" 

26 

27import math 1ab

28from collections.abc import Callable 1ab

29from functools import wraps 1ab

30from warnings import warn 1ab

31 

32from jax import eval_shape, jit 1ab

33from jax import numpy as jnp 1ab

34from jax.lax import scan 1ab

35from jax.tree import flatten as tree_flatten 1ab

36from jax.tree import map as tree_map 1ab

37from jax.tree import reduce as tree_reduce 1ab

38from jaxtyping import PyTree 1ab

39 

40 

def expand_axes(axes, tree):
    """
    Broadcast the prefix tree `axes` to the full pytree structure of `tree`.

    Each leaf of `axes` (a `None` counts as a leaf here) is replicated over
    every leaf of the corresponding subtree of `tree`.
    """

    def is_none(node):
        return node is None

    def fill_subtree(axis, subtree):
        # Replace every leaf below this prefix node with the same axis value.
        return tree_map(lambda _leaf: axis, subtree)

    return tree_map(fill_subtree, axes, tree, is_leaf=is_none)

48 

49 

def check_no_nones(axes, tree):
    """Assert that every leaf of `tree` is paired with a non-`None` axis."""

    def require_axis(_leaf, axis):
        assert axis is not None

    # Only the assertions matter; the mapped result is discarded.
    tree_map(require_axis, tree, axes)

55 

56 

def extract_size(axes, tree):
    """
    Return the common length of the batched axes of the leaves of `tree`.

    Leaves whose axis is `None` are ignored. All the remaining leaves are
    asserted to have the same length along their respective axis.
    """

    def axis_length(array, axis):
        return None if axis is None else array.shape[axis]

    # Flattening drops the `None` placeholders of the non-batched leaves.
    lengths, _ = tree_flatten(tree_map(axis_length, tree, axes))
    reference = lengths[0]
    assert all(length == reference for length in lengths)
    return reference

68 

69 

def sum_nbytes(tree):
    """
    Return the total number of bytes taken by the leaves of `tree`.

    The size of each leaf is computed from its `shape` and `dtype.itemsize`.
    """
    leaves, _ = tree_flatten(tree)
    return sum(math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in leaves)

75 

76 

def next_divisor_small(dividend, min_divisor):
    """
    Return the smallest divisor of `dividend` that is at least `min_divisor`.

    Divisors up to the integer square root of `dividend` are tried first in
    increasing order. If none is found, the search continues on the divisors
    above the square root by scanning their cofactors. `dividend` itself is
    returned when `min_divisor` is larger than `dividend`.
    """
    # `math.isqrt` is exact for arbitrarily large integers, while
    # `int(math.sqrt(...))` can be off by rounding and skip a valid divisor.
    root = math.isqrt(dividend)
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # Any remaining divisor d is above the square root, so its cofactor
    # dividend // d is below it. The cofactor is also below `min_divisor`
    # (otherwise the loop above would have returned it) and must not exceed
    # dividend // min_divisor to guarantee d >= min_divisor. The largest
    # valid cofactor yields the smallest divisor.
    max_cofactor = min(min_divisor - 1, dividend // min_divisor)
    for cofactor in range(max_cofactor, 0, -1):
        if dividend % cofactor == 0:
            return dividend // cofactor
    return dividend

82 

83 

def next_divisor_large(dividend, min_divisor):
    """
    Return the smallest divisor of `dividend` that is at least `min_divisor`.

    The search scans the cofactors `dividend // divisor` in decreasing order,
    starting from `dividend // min_divisor`. `dividend` is returned if no
    cofactor is available.
    """
    largest_cofactor = dividend // min_divisor
    candidates = (
        dividend // cofactor
        for cofactor in range(largest_cofactor, 0, -1)
        if dividend % cofactor == 0
    )
    return next(candidates, dividend)

90 

91 

def next_divisor(dividend, min_divisor):
    """
    Look for a divisor of `dividend` that is not smaller than `min_divisor`.

    A zero `dividend` yields `min_divisor` itself. Otherwise the search is
    delegated to `next_divisor_small` when `min_divisor` squared does not
    exceed `dividend`, and to `next_divisor_large` in the opposite case.
    """
    if dividend == 0:
        return min_divisor
    if min_divisor * min_divisor > dividend:
        search = next_divisor_large
    else:
        search = next_divisor_small
    return search(dividend, min_divisor)

98 

99 

def pull_nonbatched(axes, tree):
    """
    Blank out the leaves of `tree` whose axis is `None`.

    Returns
    -------
    A pair with a copy of `tree` where the leaves with `None` axis are
    replaced by `None`, and the unmodified input `tree`.
    """

    def keep_if_batched(leaf, axis):
        return leaf if axis is not None else None

    batched_only = tree_map(keep_if_batched, tree, axes)
    return batched_only, tree

108 

109 

def push_nonbatched(axes, tree, original_tree):
    """
    Fill the leaves of `tree` whose axis is `None` from `original_tree`.

    Leaves with a non-`None` axis are taken from `tree`, the other ones from
    `original_tree`; the result has the structure of `original_tree`.
    """

    def select(original_leaf, leaf, axis):
        return original_leaf if axis is None else leaf

    return tree_map(select, original_tree, tree, axes)

118 

119 

def move_axes_out(axes, tree):
    """Move the axis given in `axes` to position 0 in each leaf of `tree`."""
    return tree_map(
        lambda array, axis: jnp.moveaxis(array, axis, 0),
        tree,
        axes,
    )

125 

126 

def move_axes_in(axes, tree):
    """Move axis 0 to the position given in `axes` in each leaf of `tree`."""
    return tree_map(
        lambda array, axis: jnp.moveaxis(array, 0, axis),
        tree,
        axes,
    )

132 

133 

def batch(tree, nbatches):
    """
    Split the leading axis of each leaf of `tree` into `nbatches` chunks.

    A leaf with shape `(n, ...)` is reshaped to
    `(nbatches, n // nbatches, ...)`.
    """

    def split_leading_axis(array):
        shape = array.shape
        return array.reshape((nbatches, shape[0] // nbatches, *shape[1:]))

    return tree_map(split_leading_axis, tree)

139 

140 

def unbatch(tree):
    """
    Merge the two leading axes of each leaf of `tree` into a single one.

    A leaf with shape `(a, b, ...)` is reshaped to `(a * b, ...)`.
    """

    def merge_leading_axes(array):
        shape = array.shape
        merged = shape[0] * shape[1]
        return array.reshape((merged, *shape[2:]))

    return tree_map(merge_leading_axes, tree)

146 

147 

def check_same(tree1, tree2):
    """Assert that matching leaves of the two trees share shape and dtype."""

    def compare_leaves(leaf1, leaf2):
        assert leaf1.shape == leaf2.shape
        assert leaf1.dtype == leaf2.dtype

    # Only the assertions matter; the mapped result is discarded.
    tree_map(compare_leaves, tree1, tree2)

154 

155 

def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Wrap a function to evaluate it in batches with bounded input/output size.

    The batched arguments are split along the specified axes, `func` is
    applied to one chunk at a time within a `jax.lax.scan`, and the chunks of
    the result are joined back together.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, whose inputs and
        outputs are pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (not
        counting the arguments that are not batched).
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        giving the axis along which each array is batched. A `None` axis
        means the argument is not batched.
    out_axes
        The same for the outputs, but `None` is not allowed.
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A jitted function with the same signature as `func`, save for the return
    value if `return_nbatches`.

    Notes
    -----
    A warning is emitted if a batch exceeds `max_io_nbytes` even though the
    number of batches equals the size of the batched axis.
    """

    @jit
    @wraps(func)
    def batched_func(*args):
        # Shapes and dtypes of the output, obtained without running `func`.
        example_result = eval_shape(func, *args)

        # Expand the axes specifications to one axis per array.
        full_in_axes = expand_axes(in_axes, args)
        full_out_axes = expand_axes(out_axes, example_result)
        check_no_nones(full_out_axes, example_result)

        # Common length of all the batched axes.
        size = extract_size(
            (full_in_axes, full_out_axes), (args, example_result)
        )

        # Put aside the arguments that are passed whole to each batch.
        batched_args, all_args = pull_nonbatched(full_in_axes, args)

        # Least number of batches that respects the memory threshold
        # (rounded up division), then rounded to a divisor of `size`.
        total_nbytes = sum_nbytes((batched_args, example_result))
        quotient, remainder = divmod(total_nbytes, max_io_nbytes)
        min_nbatches = max(1, quotient + bool(remainder))
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def scan_body(_carry, chunk):
            # Restore the layout expected by `func`, evaluate it on the
            # chunk, and bring the batched axis of the outputs in front.
            chunk = move_axes_in(full_in_axes, chunk)
            chunk = push_nonbatched(full_in_axes, chunk, all_args)
            chunk_result = move_axes_out(full_out_axes, func(*chunk))
            return None, chunk_result

        # Leading axis -> (nbatches, size // nbatches), scanned over batches.
        chunks = batch(move_axes_out(full_in_axes, batched_args), nbatches)
        _, stacked_result = scan(scan_body, None, chunks)
        result = move_axes_in(full_out_axes, unbatch(stacked_result))

        check_same(example_result, result)

        return (result, nbatches) if return_nbatches else result

    return batched_func