Coverage for src/lsqfitgp/_jaxext/__init__.py: 90%

97 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_jaxext/__init__.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import traceback 1feabcd

21import functools 1feabcd

22 

23import jax 1feabcd

24from jax import numpy as jnp 1feabcd

25 

26from ._batcher import batchufunc 1feabcd

27from ._fasthash import fasthash64, fasthash32 1feabcd

28 

def makejaxufunc(ufunc, *derivs, excluded=None, floatcast=False):
    """
    Make a numpy ufunc usable inside jax code.

    The returned function is a `jax.custom_jvp` whose primal evaluation runs
    `ufunc` through a host callback, and whose derivative is assembled from
    the user-supplied partial derivatives.

    Parameters
    ----------
    ufunc : callable
        Elementwise function following numpy broadcasting and type promotion
        rules. Keyword arguments are not supported.
    derivs : sequence of callable
        Partial derivatives of the function w.r.t. each positional argument,
        each with the same signature as `ufunc`. Use None for an argument
        that has no derivative; such arguments are declared
        non-differentiable to jax. There must be one entry per argument of
        `ufunc`.
    excluded : sequence of int, optional
        Indices of the arguments that do not take part in broadcasting.
    floatcast : bool, default False
        If True, convert every argument to a common floating dtype before
        calling `ufunc`.

    Returns
    -------
    func : callable
        The wrapped `ufunc`. It can be jitted, but the computation itself is
        executed on cpu.
    """

    # Positions of the arguments jax must not differentiate.
    nondiff_argnums = [
        position for position, deriv in enumerate(derivs) if deriv is None
    ]
    # The available derivatives, aligned with the differentiable arguments.
    available_derivs = [deriv for deriv in derivs if deriv is not None]

    @functools.wraps(ufunc)
    @functools.partial(jax.custom_jvp, nondiff_argnums=nondiff_argnums)
    def func(*args):
        arrays = [jnp.asarray(arg) for arg in args]
        if floatcast:
            target = float_type(*arrays)
            arrays = [array.astype(target) for array in arrays]
        out_dtype = jnp.result_type(*arrays)
        return pure_callback_ufunc(ufunc, out_dtype, *arrays, excluded=excluded)

    @func.defjvp
    def func_jvp(*jvp_args):
        # jax passes the non-differentiable arguments first, followed by the
        # tuple of primals and the tuple of tangents of the other arguments.
        *fixed, primals, tangents = jvp_args

        # Rebuild the original positional argument list by interleaving the
        # two groups according to which derivatives are missing.
        fixed_iter = iter(fixed)
        primal_iter = iter(primals)
        args = []
        for deriv in derivs:
            source = fixed_iter if deriv is None else primal_iter
            args.append(next(source))

        result = func(*args)
        # Chain rule: sum of partial derivative times tangent over the
        # differentiable arguments.
        terms = [
            deriv(*args) * tangent
            for tangent, deriv in zip(tangents, available_derivs)
        ]
        return result, sum(terms)

    return func

84 

def elementwise_grad(fun, argnum=0):
    """
    Make a function that computes the derivative of `fun` w.r.t. one argument.

    The derivative is obtained with a single forward-mode pass (`jax.jvp`)
    seeded with a tangent of ones shaped like the chosen argument. The result
    is therefore the elementwise derivative when each output element depends
    only on the corresponding element of that argument.

    Parameters
    ----------
    fun : callable
        The function to differentiate. Keyword arguments passed to the
        returned function are forwarded to `fun`.
    argnum : int, default 0
        Index of the positional argument to differentiate with respect to.
        Must be a non-negative integer.

    Returns
    -------
    funderiv : callable
        Function with the same call signature as `fun` that returns the
        output tangent of `jax.jvp`.
    """
    assert int(argnum) == argnum and argnum >= 0, argnum

    @functools.wraps(fun)
    def funderiv(*args, **kw):
        before = args[:argnum]
        after = args[argnum + 1:]

        def partial_fun(x):
            # `fun` seen as a function of the selected argument only.
            return fun(*before, x, *after, **kw)

        primal = args[argnum]
        # Python scalars lack shape/dtype; use a 0-d tangent of their type.
        seed = jnp.ones(
            getattr(primal, 'shape', ()),
            getattr(primal, 'dtype', type(primal)),
        )
        _, derivative = jax.jvp(partial_fun, (primal,), (seed,))
        return derivative

    return funderiv

101 

class skipifabstract:
    """
    Context manager to try to do all operations eagerly even during jit, and
    skip entirely if it is not possible.

    On entry, if both `ENABLED` and `ENSURE_COMPILE_TIME_EVAL` are true,
    `jax.ensure_compile_time_eval` is entered. On exit, the following
    exceptions are suppressed, which skips the rest of the `with` body:
    `jax.errors.ConcretizationTypeError` (and subclasses) and
    `jax.errors.TracerArrayConversionError`. Any exception the inner context
    manager chooses to suppress is also suppressed. Every other exception
    propagates, except one specific IndexError raised inside jax internals
    (see `__exit__`). When `ENABLED` is false the manager does nothing.
    """
    # I feared this would be slow because of the slow jax exception handling,
    # but %timeit suggests it isn't

    ENSURE_COMPILE_TIME_EVAL = True  # enter jax.ensure_compile_time_eval
    ENABLED = True  # if False, the context manager is a no-op

    def __enter__(self):
        if self.ENSURE_COMPILE_TIME_EVAL and self.ENABLED:
            self.mgr = jax.ensure_compile_time_eval()
            self.mgr.__enter__()

    def __exit__(self, exc_type, exc_value, tb):
        if not self.ENABLED:
            return
        suppressed_by_mgr = None
        if self.ENSURE_COMPILE_TIME_EVAL:
            suppressed_by_mgr = self.mgr.__exit__(exc_type, exc_value, tb)
        ignorable_error = (
            exc_type is not None
            and issubclass(exc_type, (
                jax.errors.ConcretizationTypeError,
                jax.errors.TracerArrayConversionError,
                # TODO why isn't this a subclass of the former like
                # TracerBoolConversionError? Open an issue
            ))
        )
        if suppressed_by_mgr or ignorable_error:
            return True

        # Only an IndexError whose innermost frame is one of these jax
        # internals is ignored. The membership test must be a plain bool, not
        # a one-element tuple, or every IndexError would be swallowed.
        weird_cond = (
            exc_type is IndexError
            and traceback.extract_tb(tb)[-1].name
            in ('arg_info_pytree', '_origin_msg')
        )
        if weird_cond:  # pragma: no cover
            # TODO this ignores a jax internal bug I don't understand, appears
            # in examples/pdf4.py
            return True

143 

def float_type(*args):
    """
    Return the inexact dtype jax would use for a result built from `args`.

    The arguments (arrays, scalars or dtypes) are combined with
    `jnp.result_type`. The floating dtype is then read off the output of
    `jnp.sin` applied to an empty array of that combined type.
    """
    # numpy does this with common_type, but that supports only arrays, not
    # dtypes in the input. jnp.common_type is not defined.
    combined = jnp.result_type(*args)
    probe = jnp.empty(0, combined)
    return jnp.sin(probe).dtype

149 

def is_jax_type(dtype):
    """
    Check whether jax can create arrays of the given dtype.

    Returns True if ``jnp.empty(0, dtype)`` succeeds. Returns False if it
    raises a TypeError whose message says that JAX only supports number and
    bool dtypes. Any other TypeError is re-raised.
    """
    dtype = jnp.dtype(dtype)
    try:
        jnp.empty(0, dtype)
    except TypeError as exc:
        if 'JAX only supports number and bool dtypes' not in str(exc):
            raise
        return False
    return True

159 

def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs, see
    https://github.com/google/jax/issues/17187.

    The output shape is the broadcast of the shapes of all arguments whose
    index is not in `excluded`. Each of those arguments is given leading
    singleton axes so that its rank equals that of the output before being
    passed on. Excluded arguments are forwarded unchanged. The callback is
    invoked with ``vectorized=True``, and extra keyword arguments are
    forwarded to `jax.pure_callback`.

    Parameters
    ----------
    callback : callable
        Elementwise function to call on the host.
    dtype : dtype
        Dtype of the result.
    *args : arrays
        Arguments of `callback`. They must have `shape` and `ndim`.
    excluded : sequence of int, optional
        Indices of the arguments that do not take part in broadcasting.
    """
    skip = () if excluded is None else excluded

    out_shape = jnp.broadcast_shapes(*[
        arg.shape for position, arg in enumerate(args) if position not in skip
    ])
    rank = len(out_shape)

    padded = []
    for position, arg in enumerate(args):
        if position not in skip:
            # Prepend singleton axes up to the rank of the output.
            arg = jnp.expand_dims(arg, tuple(range(rank - arg.ndim)))
        padded.append(arg)

    out_spec = jax.ShapeDtypeStruct(out_shape, dtype)
    # TODO: once jax fixes the issue linked above, check the jax version and
    # delegate to the original jax.pure_callback when it is new enough.
    return jax.pure_callback(
        callback, out_spec, *padded, vectorized=True, **kwargs)

180 

def limit_derivatives(x, n, error_func=None):
    """
    Cap how many derivatives can be taken through a value.

    Parameters
    ----------
    x : array_like
        The value to pass through.
    n : int
        Maximum number of derivatives allowed. Must be an integer.
    error_func : callable, optional
        Called as ``error_func(current, n)`` with the number of derivatives
        taken and the limit. It must return the exception to raise. By
        default a ValueError is returned.

    Returns
    -------
    x : array_like
        The same value, unchanged.
    """
    assert n == int(n)

    def default_error(current, n):
        return ValueError(f'took {current} derivatives > limit {n}')

    make_error = default_error if error_func is None else error_func
    # Start counting from zero derivatives taken.
    return _limit_derivatives_impl(0, n, make_error, x)

204 

205@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2)) 1feabcd

206def _limit_derivatives_impl(current, limit, func, x): 1feabcd

207 if current > limit: 1feabcd

208 raise func(current, limit) 1abcd

209 return x 1feabcd

210 

211@_limit_derivatives_impl.defjvp 1feabcd

212def _limit_derivatives_impl_jvp(current, limit, func, primals, tangents): 1feabcd

213 x, = primals 1feabcd

214 xdot, = tangents 1feabcd

215 return _limit_derivatives_impl(current + 1, limit, func, x), xdot 1feabcd