Coverage for src/lsqfitgp/_jaxext/_batcher.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_jaxext/_batcher.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1feabcd

21import math 1feabcd

22 

23from jax import lax 1feabcd

24from jax import numpy as jnp 1feabcd

25import numpy 1feabcd

26 

def batchufunc(func, *, maxnbytes):
    """

    Make a batched version of a universal function.

    The function is modified to process its inputs in chunks.

    Parameters
    ----------
    func : callable
        A jax-traceable universal function. All positional arguments are assumed
        to be arrays which are broadcasted to determine the shape.
    maxnbytes : number
        The maximum number of bytes in each input chunk over all input arrays
        after broadcasting.

    Returns
    -------
    batched_func : callable
        The batched version of `func`. Keyword arguments are passed as-is to
        `func` on every call, whether or not the input actually gets split.

    """

    maxnbytes = int(maxnbytes)
    assert maxnbytes > 0

    @functools.wraps(func)
    def batched_func(*args, **kw):

        shape = jnp.broadcast_shapes(*(arg.shape for arg in args))
        # scalars and empty arrays: nothing to split
        if not shape or any(size == 0 for size in shape):
            return func(*args, **kw)

        # byte cost of one slice along the leading axis, counting every input
        # as if it were broadcast to the full shape
        rowsize = math.prod(shape[1:])
        rownbytes = rowsize * sum(arg.dtype.itemsize for arg in args)
        totalnbytes = shape[0] * rownbytes
        if totalnbytes <= maxnbytes:
            return func(*args, **kw)

        # Pad every input to the full rank, then separate inputs that are
        # broadcast along the leading axis (size 1, "short") from those that
        # vary along it ("long"). Short inputs are squeezed and carried
        # unchanged through the scan; only long inputs get sliced.
        args = [
            arg.reshape((1,) * (len(shape) - arg.ndim) + arg.shape)
            for arg in args
        ]
        short_args_idx = [i for i, arg in enumerate(args) if arg.shape[0] == 1]
        long_args_idx = [i for i, arg in enumerate(args) if arg.shape[0] > 1]
        short_args = [args[i].squeeze(0) for i in short_args_idx]
        long_args = [args[i] for i in long_args_idx]

        def combine_args(short_args, long_args):
            """Reassemble the positional arguments in their original order."""
            args = [None] * (len(short_args) + len(long_args))
            for i, arg in zip(short_args_idx, short_args):
                args[i] = arg
            for i, arg in zip(long_args_idx, long_args):
                args[i] = arg
            return args

        if rownbytes <= maxnbytes:
            # batch over leading axis

            batchsize = maxnbytes // rownbytes
            nbatches = shape[0] // batchsize
            batchedsize = nbatches * batchsize

            sliced_args = [arg[:batchedsize] for arg in long_args]
            batched_args = [
                arg.reshape((nbatches, batchsize) + arg.shape[1:])
                for arg in sliced_args
            ]

            def scan_loop_body(short_args, batched_args):
                assert all(arg.ndim == len(shape) - 1 for arg in short_args)
                assert all(arg.ndim == len(shape) for arg in batched_args)
                args = combine_args(short_args, batched_args)
                out = func(*args, **kw)
                assert out.shape == (batchsize,) + shape[1:]
                return short_args, out

            _, out = lax.scan(
                scan_loop_body, short_args, batched_args, length=nbatches,
            )
            assert out.shape == (nbatches, batchsize) + shape[1:]
            out = out.reshape((batchedsize,) + shape[1:])

            # leftover rows that do not fill a whole batch; skipped when the
            # leading axis divides evenly, to avoid calling func on empty arrays
            if batchedsize < shape[0]:
                remainder_args = [arg[batchedsize:] for arg in long_args]
                args = combine_args(short_args, remainder_args)
                remainder = func(*args, **kw)
                assert remainder.shape == (shape[0] - batchedsize,) + shape[1:]
                out = jnp.concatenate([out, remainder])

        else:
            # cycle over leading axis, recurse

            def scan_loop_body(short_args, long_args):
                args = combine_args(short_args, long_args)
                assert all(arg.ndim == len(shape) - 1 for arg in args)
                out = batched_func(*args, **kw)
                assert out.shape == shape[1:]
                return short_args, out

            # `length` is required: when shape[0] == 1 every input is "short",
            # `long_args` is empty, and scan cannot infer the iteration count
            _, out = lax.scan(
                scan_loop_body, short_args, long_args, length=shape[0],
            )

        assert out.shape == shape
        return out

    return batched_func