Coverage for src/bartz/jaxext/_autobatch.py: 97%
122 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/jaxext/_autobatch.py
2#
3# Copyright (c) 2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of `autobatch`."""
27import math 1ab
28from collections.abc import Callable 1ab
29from functools import wraps 1ab
30from warnings import warn 1ab
32from jax import eval_shape, jit 1ab
33from jax import numpy as jnp 1ab
34from jax.lax import scan 1ab
35from jax.tree import flatten as tree_flatten 1ab
36from jax.tree import map as tree_map 1ab
37from jax.tree import reduce as tree_reduce 1ab
38from jaxtyping import PyTree 1ab
def expand_axes(axes, tree):
    """Expand `axes` such that they match the pytreedef of `tree`.

    `axes` may be a prefix of `tree`: each entry of `axes` (including a
    `None` entry, which is treated as a leaf) is replicated over every leaf
    of the subtree of `tree` it corresponds to.
    """

    def is_none(node):
        return node is None

    def broadcast(axis, subtree):
        # One copy of `axis` for each leaf found under `subtree`.
        return tree_map(lambda _leaf: axis, subtree)

    return tree_map(broadcast, axes, tree, is_leaf=is_none)
def check_no_nones(axes, tree):
    """Check that no leaf of `tree` is associated with a `None` axis.

    Parameters
    ----------
    axes
        A tree of axes with the same structure as `tree`.
    tree
        The tree whose leaves determine where an axis is looked up.

    Raises
    ------
    AssertionError
        If the axis paired with any leaf of `tree` is `None`.
    """

    def check_not_none(_, axis):
        # Raise explicitly instead of using a bare `assert`, so that the
        # check is not stripped by `python -O` and carries a message. The
        # exception type is unchanged.
        if axis is None:
            msg = 'found a None axis where non-batching is not allowed'
            raise AssertionError(msg)

    tree_map(check_not_none, tree, axes)
def extract_size(axes, tree):
    """Return the common length of the leaves of `tree` along their `axes`.

    Leaves paired with a `None` axis are ignored. All remaining lengths are
    asserted to be equal, and that shared length is returned.
    """

    def length_along(leaf, axis):
        # A `None` result is dropped when the tree is flattened below.
        return None if axis is None else leaf.shape[axis]

    lengths, _ = tree_flatten(tree_map(length_along, tree, axes))
    assert all(length == lengths[0] for length in lengths)
    return lengths[0]
def sum_nbytes(tree):
    """Return the total number of bytes taken by the leaves of `tree`.

    The size of each leaf is computed from its `shape` and `dtype.itemsize`.
    """
    leaves, _ = tree_flatten(tree)
    return sum(math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in leaves)
def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Intended for the case where `min_divisor` is not larger than the square
    root of `dividend`, so that scanning candidate divisors upward from
    `min_divisor` is cheap.

    Parameters
    ----------
    dividend
        The non-negative integer to divide.
    min_divisor
        The lower bound on the divisor.

    Returns
    -------
    The smallest divisor of `dividend` not below `min_divisor`, or `dividend`
    itself if there is none.
    """
    # Exact integer square root; avoids float rounding on large dividends.
    root = math.isqrt(dividend)

    # Scan the divisors that do not exceed the square root.
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # No divisor was found up to the square root. Every divisor above the
    # square root is `dividend // q` for some divisor `q` below it, so the
    # smallest such divisor corresponds to the largest valid `q`. The bound
    # `dividend // min_divisor` keeps the result >= `min_divisor`.
    for codivisor in range(min(root, dividend // min_divisor), 0, -1):
        if dividend % codivisor == 0:
            return dividend // codivisor

    return dividend
def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    The search runs over the complementary divisors `q = dividend // divisor`,
    from the largest admissible one downward. If there is no admissible `q`,
    `dividend` is returned.
    """
    largest_codivisor = dividend // min_divisor
    matches = (
        dividend // codivisor
        for codivisor in range(largest_codivisor, 0, -1)
        if dividend % codivisor == 0
    )
    return next(matches, dividend)
def next_divisor(dividend, min_divisor):
    """Find a divisor of `dividend` that is not smaller than `min_divisor`.

    A zero `dividend` yields `min_divisor` unchanged. Otherwise the search is
    delegated to `next_divisor_small` when `min_divisor` squared does not
    exceed `dividend`, and to `next_divisor_large` in the remaining case.
    """
    if dividend == 0:
        return min_divisor
    if min_divisor * min_divisor > dividend:
        search = next_divisor_large
    else:
        search = next_divisor_small
    return search(dividend, min_divisor)
def pull_nonbatched(axes, tree):
    """Blank out the leaves of `tree` whose axis is `None`.

    Returns a pair: a copy of `tree` where each leaf with a `None` axis is
    replaced by `None`, and the original `tree` unchanged.
    """

    def keep_if_batched(leaf, axis):
        return None if axis is None else leaf

    batched_only = tree_map(keep_if_batched, tree, axes)
    return batched_only, tree
def push_nonbatched(axes, tree, original_tree):
    """Put back the leaves with a `None` axis, taking them from `original_tree`.

    For each leaf position of `original_tree`, the result holds the leaf of
    `original_tree` if the axis is `None`, and the entry of `tree` otherwise.
    """

    def restore(original_leaf, leaf, axis):
        return original_leaf if axis is None else leaf

    return tree_map(restore, original_tree, tree, axes)
def move_axes_out(axes, tree):
    """Move the axis given in `axes` to the first position for each leaf of `tree`."""
    return tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, source=axis, destination=0),
        tree,
        axes,
    )
def move_axes_in(axes, tree):
    """Move the first axis of each leaf of `tree` to the position given in `axes`."""
    return tree_map(
        lambda leaf, axis: jnp.moveaxis(leaf, source=0, destination=axis),
        tree,
        axes,
    )
def batch(tree, nbatches):
    """Split the first axis of each leaf of `tree` into `nbatches` chunks.

    A leaf of shape `(n, ...)` is reshaped to `(nbatches, n // nbatches, ...)`.
    """

    def split_leading(leaf):
        chunk_length = leaf.shape[0] // nbatches
        new_shape = (nbatches, chunk_length, *leaf.shape[1:])
        return leaf.reshape(new_shape)

    return tree_map(split_leading, tree)
def unbatch(tree):
    """Merge the first two axes of each leaf of `tree` into a single axis.

    A leaf of shape `(a, b, ...)` is reshaped to `(a * b, ...)`.
    """

    def merge_leading(leaf):
        merged_length = leaf.shape[0] * leaf.shape[1]
        return leaf.reshape((merged_length, *leaf.shape[2:]))

    return tree_map(merge_leading, tree)
def check_same(tree1, tree2):
    """Assert that matching leaves of the two trees agree in shape and dtype."""

    def compare_leaves(leaf1, leaf2):
        shape1, shape2 = leaf1.shape, leaf2.shape
        assert shape1 == shape2
        dtype1, dtype2 = leaf1.dtype, leaf2.dtype
        assert dtype1 == dtype2

    tree_map(compare_leaves, tree1, tree2)
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Wrap a function to evaluate it in batches that fit under a size threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.

    Notes
    -----
    The number of batches is a divisor of the length of the batched axes. A
    warning is emitted if a batch still exceeds `max_io_nbytes` when the
    number of batches equals that length.
    """
    # Keep the user-provided specifications under distinct names, because the
    # wrapper below works with their expanded versions.
    in_axes_spec = in_axes
    out_axes_spec = out_axes

    @jit
    @wraps(func)
    def batched_func(*args):
        # Shapes and dtypes of the output, obtained without running `func`.
        output_template = eval_shape(func, *args)

        # One axis per array, for both inputs and outputs.
        full_in_axes = expand_axes(in_axes_spec, args)
        full_out_axes = expand_axes(out_axes_spec, output_template)
        check_no_nones(full_out_axes, output_template)

        # Common length of all the batched axes.
        size = extract_size(
            (full_in_axes, full_out_axes), (args, output_template)
        )

        # Set aside the arguments that are passed whole to each batch.
        batched_args, nonbatched_args = pull_nonbatched(full_in_axes, args)

        # Smallest number of batches that respects the threshold (ceiling
        # division, at least 1), then rounded to a divisor of `size`.
        total_nbytes = sum_nbytes((batched_args, output_template))
        quotient, remainder = divmod(total_nbytes, max_io_nbytes)
        min_nbatches = max(1, quotient + bool(remainder))
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            warn(msg)

        def loop(_, chunk):
            # Restore the original axis layout and the unbatched arguments,
            # evaluate `func`, and bring the batched axis of the output first.
            chunk = move_axes_in(full_in_axes, chunk)
            call_args = push_nonbatched(full_in_axes, chunk, nonbatched_args)
            chunk_result = func(*call_args)
            return None, move_axes_out(full_out_axes, chunk_result)

        # Batched axis first, then split into `nbatches` chunks to scan over.
        stacked_args = batch(move_axes_out(full_in_axes, batched_args), nbatches)
        _, stacked_result = scan(loop, None, stacked_args)

        # Undo the chunking and put the output axes back in place.
        result = move_axes_in(full_out_axes, unbatch(stacked_result))

        check_same(output_template, result)

        return (result, nbatches) if return_nbatches else result

    return batched_func