Coverage for src/lsqfitgp/_jaxext/_batcher.py: 100%
61 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_jaxext/_batcher.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1feabcd
21import math 1feabcd
23from jax import lax 1feabcd
24from jax import numpy as jnp 1feabcd
25import numpy 1feabcd
def batchufunc(func, *, maxnbytes):
    """
    Make a batched version of a universal function.

    The function is modified to process its inputs in chunks, so that each
    call of `func` sees at most `maxnbytes` bytes of (broadcast) input.

    Parameters
    ----------
    func : callable
        A jax-traceable universal function. All positional arguments are
        assumed to be arrays which are broadcast together to determine the
        shape.
    maxnbytes : number
        The maximum number of bytes in each input chunk over all input arrays
        after broadcasting. It is converted to int and must be positive.

    Returns
    -------
    batched_func : callable
        The batched version of `func`. Keyword arguments are passed as-is to
        every call of `func`, including the unbatched fast paths.

    Raises
    ------
    ValueError
        If `maxnbytes` is not positive.
    """

    maxnbytes = int(maxnbytes)
    if maxnbytes <= 0:
        raise ValueError(f'maxnbytes must be positive, got {maxnbytes}')

    @functools.wraps(func)
    def batched_func(*args, **kw):

        shape = jnp.broadcast_shapes(*(arg.shape for arg in args))

        # no leading axis to batch over, or nothing to compute
        if not shape or any(size == 0 for size in shape):
            return func(*args, **kw)

        # bytes of one slice along the leading axis, summed over all inputs
        # as if each one were broadcast to the full shape
        rowsize = math.prod(shape[1:])
        rownbytes = rowsize * sum(arg.dtype.itemsize for arg in args)
        totalnbytes = shape[0] * rownbytes
        if totalnbytes <= maxnbytes:
            return func(*args, **kw)

        # Left-pad every arg with unit axes to len(shape) dimensions, then
        # split the args into "long" ones, which must be sliced along the
        # leading axis, and "short" ones, which have a unit leading axis.
        # Short args are broadcast along the leading axis, so they are passed
        # whole (without the unit axis) to every chunk.
        args = [
            arg.reshape((1,) * (len(shape) - arg.ndim) + arg.shape)
            for arg in args
        ]
        short_args_idx = [i for i, arg in enumerate(args) if arg.shape[0] == 1]
        long_args_idx = [i for i, arg in enumerate(args) if arg.shape[0] > 1]
        short_args = [args[i].squeeze(0) for i in short_args_idx]
        long_args = [args[i] for i in long_args_idx]

        def combine_args(short_args, long_args):
            # reassemble the positional arguments in their original order
            args = [None] * (len(short_args) + len(long_args))
            for i, arg in zip(short_args_idx, short_args):
                args[i] = arg
            for i, arg in zip(long_args_idx, long_args):
                args[i] = arg
            return args

        if rownbytes <= maxnbytes:
            # Batch over the leading axis: scan over `nbatches` chunks of
            # `batchsize` rows each, then process the leftover rows (possibly
            # zero of them) with a single direct call.

            batchsize = maxnbytes // rownbytes
            nbatches = shape[0] // batchsize
            batchedsize = nbatches * batchsize

            sliced_args = [arg[:batchedsize] for arg in long_args]
            batched_args = [
                arg.reshape((nbatches, batchsize) + arg.shape[1:])
                for arg in sliced_args
            ]

            def scan_loop_body(short_args, batched_args):
                assert all(arg.ndim == len(shape) - 1 for arg in short_args)
                assert all(arg.ndim == len(shape) for arg in batched_args)
                args = combine_args(short_args, batched_args)
                out = func(*args, **kw)
                assert out.shape == (batchsize,) + shape[1:]
                # short args are constant across iterations: carry unchanged
                return short_args, out

            _, out = lax.scan(
                scan_loop_body, short_args, batched_args, length=nbatches)
            assert out.shape == (nbatches, batchsize) + shape[1:]
            out = out.reshape((batchedsize,) + shape[1:])

            remainder_args = [arg[batchedsize:] for arg in long_args]
            args = combine_args(short_args, remainder_args)
            remainder = func(*args, **kw)
            assert remainder.shape == (shape[0] - batchedsize,) + shape[1:]

            out = jnp.concatenate([out, remainder])

        else:
            # A single row is already too large: cycle over the leading axis
            # one row at a time and recurse on each row.

            def scan_loop_body(short_args, long_args):
                args = combine_args(short_args, long_args)
                assert all(arg.ndim == len(shape) - 1 for arg in args)
                out = batched_func(*args, **kw)
                assert out.shape == shape[1:]
                return short_args, out

            # `length` is required: when shape[0] == 1 every arg is short,
            # long_args is empty and scan could not infer the trip count.
            _, out = lax.scan(
                scan_loop_body, short_args, long_args, length=shape[0])

        assert out.shape == shape
        return out

    return batched_func