Coverage for src / bartz / jaxext / _autobatch.py: 97%
122 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
1# bartz/src/bartz/jaxext/_autobatch.py
2#
3# Copyright (c) 2025, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Implementation of `autobatch`."""
27import math
28from collections.abc import Callable
29from functools import wraps
30from warnings import warn
32from jax import eval_shape, jit
33from jax import numpy as jnp
34from jax.lax import scan
35from jax.tree import flatten as tree_flatten
36from jax.tree import map as tree_map
37from jax.tree import reduce as tree_reduce
38from jaxtyping import PyTree
def expand_axes(axes, tree):
    """Broadcast every axis in `axes` over the corresponding subtree of `tree`.

    `axes` may be a prefix of the structure of `tree`. Each of its entries
    (including `None`, which is treated as a leaf here) is replicated onto all
    the leaves of the matching subtree of `tree`.
    """

    def replicate(axis, subtree):
        # every leaf under this subtree gets the same axis value
        return tree_map(lambda _leaf: axis, subtree)

    def none_is_leaf(node):
        return node is None

    return tree_map(replicate, axes, tree, is_leaf=none_is_leaf)
def check_no_nones(axes, tree):
    """Assert that every leaf of `tree` has a non-`None` axis in `axes`."""

    def assert_has_axis(_leaf, axis):
        assert axis is not None

    tree_map(assert_has_axis, tree, axes)
def extract_size(axes, tree):
    """Return the common length of the batched axes of `tree`.

    Parameters
    ----------
    axes
        A tree of axes matching the structure of `tree`; `None` marks a leaf
        that is not batched.
    tree
        A tree of leaves exposing a `shape` attribute.

    Returns
    -------
    The size of the batch axis, which must be the same for all batched leaves.

    Raises
    ------
    ValueError
        If no leaf is batched, so there is no batch size to extract.
    """

    def get_size(x, axis):
        if axis is None:
            return None
        return x.shape[axis]

    sizes = tree_map(get_size, tree, axes)
    # `None` is an empty pytree, so unbatched leaves drop out when flattening.
    sizes, _ = tree_flatten(sizes)
    if not sizes:
        msg = 'no batched arrays found, cannot determine the batch size'
        raise ValueError(msg)
    assert all(s == sizes[0] for s in sizes)
    return sizes[0]
def sum_nbytes(tree):
    """Return the total number of bytes of all the leaves of `tree`.

    Leaves only need `shape` and `dtype` attributes, so both concrete arrays
    and the abstract results of `jax.eval_shape` are accepted. `None` entries
    are empty subtrees and contribute nothing.
    """

    def accumulate(total, leaf):
        nelem = math.prod(leaf.shape)
        return total + nelem * leaf.dtype.itemsize

    return tree_reduce(accumulate, tree, 0)
def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Meant for the case ``min_divisor ** 2 <= dividend``. It runs in
    O(sqrt(dividend)) steps:

    1. Divisors up to ``isqrt(dividend)`` are tried directly.
    2. Divisors above the square root are found through their cofactors,
       which are all ``<= isqrt(dividend)``.

    Returns `dividend` itself if no smaller divisor qualifies (e.g., when
    `dividend` is prime).
    """
    root = math.isqrt(dividend)  # exact, unlike int(math.sqrt(...)) on big ints
    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor
    # No divisor in [min_divisor, root]. Scanning cofactors downward yields the
    # divisors above `root` in increasing order, so the first admissible one
    # is the smallest.
    for cofactor in range(root, 0, -1):
        quotient = dividend // cofactor
        if dividend % cofactor == 0 and quotient >= min_divisor:
            return quotient
    return dividend
def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Meant for the case ``min_divisor ** 2 > dividend``. The search is over
    cofactors, of which there are fewer than ``sqrt(dividend)``. Returns
    `dividend` if `min_divisor` exceeds it.
    """
    # For a divisor d, d >= min_divisor iff its cofactor dividend // d is at
    # most dividend // min_divisor. So the largest admissible cofactor gives
    # the smallest admissible divisor.
    cofactor = dividend // min_divisor
    while cofactor > 0:
        if dividend % cofactor == 0:
            return dividend // cofactor
        cofactor -= 1
    return dividend
def next_divisor(dividend, min_divisor):
    """Return the smallest divisor of `dividend` that is >= `min_divisor`.

    Special cases:

    - If `dividend` is 0, `min_divisor` itself is returned.
    - If no divisor >= `min_divisor` exists, `dividend` is returned.

    The search strategy is chosen by comparing `min_divisor` to the square
    root of `dividend`.
    """
    if dividend == 0:
        return min_divisor
    if min_divisor * min_divisor > dividend:
        return next_divisor_large(dividend, min_divisor)
    return next_divisor_small(dividend, min_divisor)
def pull_nonbatched(axes, tree):
    """Separate the batched leaves of `tree` from the full original.

    Returns
    -------
    batched
        A copy of `tree` where each leaf whose axis is `None` is replaced by
        `None`.
    original
        `tree` itself, so that `push_nonbatched` can restore those leaves.
    """

    def keep_if_batched(leaf, axis):
        return None if axis is None else leaf

    batched = tree_map(keep_if_batched, tree, axes)
    return batched, tree
def push_nonbatched(axes, tree, original_tree):
    """Reinsert the unbatched leaves of `original_tree` into `tree`.

    For each leaf position, the value from `original_tree` is used if its
    axis is `None`; otherwise the value from `tree` is kept.
    """

    def restore(original_leaf, leaf, axis):
        if axis is not None:
            return leaf
        return original_leaf

    return tree_map(restore, original_tree, tree, axes)
def move_axes_out(axes, tree):
    """Move the axis given by `axes` to the front of each leaf of `tree`."""
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, axis, 0), tree, axes)
def move_axes_in(axes, tree):
    """Move the leading axis of each leaf of `tree` to the position in `axes`."""
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, 0, axis), tree, axes)
def batch(tree, nbatches):
    """Split the leading axis of every leaf into `nbatches` equal chunks.

    A leaf of shape ``(n, *rest)`` becomes ``(nbatches, n // nbatches, *rest)``.
    """

    def split_leading(leaf):
        length = leaf.shape[0]
        trailing = leaf.shape[1:]
        return leaf.reshape(nbatches, length // nbatches, *trailing)

    return tree_map(split_leading, tree)
def unbatch(tree):
    """Merge the first two axes of every leaf, undoing `batch`.

    A leaf of shape ``(nbatches, size, *rest)`` becomes
    ``(nbatches * size, *rest)``.
    """

    def merge_leading(leaf):
        nb = leaf.shape[0]
        per_batch = leaf.shape[1]
        return leaf.reshape(nb * per_batch, *leaf.shape[2:])

    return tree_map(merge_leading, tree)
def check_same(tree1, tree2):
    """Assert that corresponding leaves of two trees match in shape and dtype."""

    def compare(leaf1, leaf2):
        for attr in ('shape', 'dtype'):
            assert getattr(leaf1, attr) == getattr(leaf2, attr)

    tree_map(compare, tree1, tree2)
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments). Must be positive.
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if
    `return_nbatches`. The returned function is jitted, and the batches are
    processed sequentially with `jax.lax.scan`.

    Raises
    ------
    ValueError
        If `max_io_nbytes` is not positive.

    Notes
    -----
    The number of batches is the smallest divisor of the batch axis length
    that keeps each batch within `max_io_nbytes`. If even one element per
    batch exceeds the limit, one element per batch is used and a warning is
    emitted.
    """
    # Fail eagerly: a zero limit would otherwise raise ZeroDivisionError while
    # tracing, and a negative one would trigger a misleading assertion.
    if max_io_nbytes <= 0:
        msg = f'max_io_nbytes must be positive, got {max_io_nbytes}'
        raise ValueError(msg)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jit
    @wraps(func)
    def batched_func(*args):
        # Output shapes and dtypes, obtained by abstract evaluation of `func`.
        example_result = eval_shape(func, *args)

        # Broadcast the (possibly prefix) axis specifications to full trees.
        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        # Common length of all batched axes, inputs and outputs alike.
        size = extract_size((in_axes, out_axes), (args, example_result))

        # Blank out unbatched arguments; the originals are kept aside and
        # passed whole to every batch.
        args, nonbatched_args = pull_nonbatched(in_axes, args)

        # Minimum number of batches to respect the byte limit (ceiling
        # division), rounded up to a divisor of `size` so all batches are equal.
        total_nbytes = sum_nbytes((args, example_result))
        min_nbatches = total_nbytes // max_io_nbytes + bool(
            total_nbytes % max_io_nbytes
        )
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # Only reachable when already using one element per batch.
            assert size == nbatches
            msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            warn(msg)

        def loop(_, args):
            # `args` is one batch with the batch axis leading; restore the
            # original axis layout and the unbatched arguments, then call func.
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            # Put the batch axis first so scan stacks per-batch results on it.
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func