Coverage for src/lsqfitgp/_jaxext/__init__.py: 90%
97 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_jaxext/__init__.py
2#
3# Copyright (c) 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import traceback 1feabcd
21import functools 1feabcd
23import jax 1feabcd
24from jax import numpy as jnp 1feabcd
26from ._batcher import batchufunc 1feabcd
27from ._fasthash import fasthash64, fasthash32 1feabcd
def makejaxufunc(ufunc, *derivs, excluded=None, floatcast=False):
    """
    Wrap a numpy ufunc to add jax support.

    Parameters
    ----------
    ufunc : callable
        Elementwise function following numpy broadcasting and type promotion
        rules. Keyword arguments not supported.
    derivs : sequence of callable
        Derivatives of the function w.r.t. each positional argument, with the
        same signature as `ufunc`. Pass None to indicate a missing derivative;
        such arguments are declared non-differentiable to `jax.custom_jvp`.
        There must be as many derivatives as the arguments to `ufunc`.
    excluded : sequence of int, optional
        The indices of arguments that are not broadcasted.
    floatcast : bool, default False
        If True, cast all arguments to float before calling the ufunc.

    Return
    ------
    func : callable
        Wrapped `ufunc`. Supports jit, but the calculation is performed on cpu.
    """

    # Arguments without a derivative are passed to custom_jvp as static.
    nondiff_argnums = [
        position for position, deriv in enumerate(derivs) if deriv is None
    ]

    @functools.wraps(ufunc)
    @functools.partial(jax.custom_jvp, nondiff_argnums=nondiff_argnums)
    def func(*args):
        arrays = [jnp.asarray(arg) for arg in args]
        if floatcast:
            target = float_type(*arrays)
            arrays = [arr.astype(target) for arr in arrays]
        return pure_callback_ufunc(
            ufunc, jnp.result_type(*arrays), *arrays, excluded=excluded,
        )

    # Derivatives of the differentiable arguments, in argument order; they
    # pair up one-to-one with the tangents custom_jvp hands to the rule.
    diff_derivs = [deriv for deriv in derivs if deriv is not None]

    @func.defjvp
    def _jvp_rule(*allargs):
        # custom_jvp calls the rule as (*nondiff_args, primals, tangents).
        *static_args, primals, tangents = allargs

        # Re-interleave static and differentiable arguments in their original
        # positions, using `derivs` as the map of which is which.
        static_iter = iter(static_args)
        primal_iter = iter(primals)
        args = [
            next(static_iter) if deriv is None else next(primal_iter)
            for deriv in derivs
        ]

        primal_out = func(*args)
        tangent_out = sum([
            deriv(*args) * tangent
            for tangent, deriv in zip(tangents, diff_derivs)
        ])
        return primal_out, tangent_out

    return func
def elementwise_grad(fun, argnum=0):
    """
    Derivative of an elementwise function w.r.t. one positional argument.

    The returned function pushes a tangent of ones (with the shape and dtype
    of the selected argument) through `fun` with `jax.jvp`. When `fun` acts
    elementwise on that argument, this yields the elementwise derivative.

    Parameters
    ----------
    fun : callable
        The function to differentiate.
    argnum : int, default 0
        Non-negative index of the positional argument to differentiate.

    Returns
    -------
    funderiv : callable
        Function with the same signature as `fun` returning the derivative.
    """
    assert int(argnum) == argnum and argnum >= 0, argnum

    @functools.wraps(fun)
    def funderiv(*args, **kw):
        before = args[:argnum]
        after = args[argnum + 1:]

        def fun_of_primal(value):
            return fun(*before, value, *after, **kw)

        primal = args[argnum]
        tangent = jnp.ones(
            getattr(primal, 'shape', ()),
            getattr(primal, 'dtype', type(primal)),
        )
        _, tangent_out = jax.jvp(fun_of_primal, (primal,), (tangent,))
        return tangent_out

    return funderiv
class skipifabstract:
    """
    Context manager to try to do all operations eagerly even during jit, and
    skip entirely if it is not possible.

    On entry, if both `ENSURE_COMPILE_TIME_EVAL` and `ENABLED` are true, it
    enters `jax.ensure_compile_time_eval()`. On exit, when `ENABLED` is true,
    it suppresses `jax.errors.ConcretizationTypeError` (and subclasses) and
    `jax.errors.TracerArrayConversionError`, so the rest of the ``with`` body
    is skipped when it needs concrete values from a tracer.
    """
    # I feared this would be slow because of the slow jax exception handling,
    # but %timeit suggests it isn't

    ENSURE_COMPILE_TIME_EVAL = True
    ENABLED = True

    def __enter__(self):
        # Record the manager actually entered (or None), so that __exit__
        # pairs with it even if the class flags are changed inside the body.
        self.mgr = None
        if self.ENSURE_COMPILE_TIME_EVAL and self.ENABLED:
            self.mgr = jax.ensure_compile_time_eval()
            self.mgr.__enter__()

    def __exit__(self, exc_type, exc_value, tb):
        mgr = getattr(self, 'mgr', None)
        self.mgr = None
        suppressed = None
        if mgr is not None:
            # Always restore jax's configuration if we changed it on entry.
            suppressed = mgr.__exit__(exc_type, exc_value, tb)
        if suppressed:
            return True
        if not self.ENABLED:
            return None
        if exc_type is None:
            return None

        if issubclass(exc_type, (
            jax.errors.ConcretizationTypeError,
            jax.errors.TracerArrayConversionError,
            # TODO why isn't this a subclass of the former like
            # TracerBoolConversionError? Open an issue
        )):
            return True

        # Only IndexErrors raised from these specific jax-internal frames are
        # ignored; any other IndexError propagates. (The original code built
        # a one-element tuple here, which is always truthy, and thus
        # swallowed every IndexError.)
        if exc_type is IndexError:
            frames = traceback.extract_tb(tb)
            if frames and frames[-1].name in ('arg_info_pytree', '_origin_msg'):  # pragma: no cover
                # TODO this ignores a jax internal bug I don't understand,
                # appears in examples/pdf4.py
                return True
        return None
def float_type(*args):
    """
    Floating point dtype jax uses for a transcendental function of `args`.

    Parameters
    ----------
    *args : array_like or dtype
        Anything accepted by `jnp.result_type`.

    Returns
    -------
    dtype : numpy.dtype
        The dtype of `jnp.sin` applied to an empty array of the common type
        of `args`.
    """
    # numpy does this with common_type, but that supports only arrays, not
    # dtypes in the input. jnp.common_type is not defined.
    common = jnp.result_type(*args)
    probe = jnp.empty(0, common)
    return jnp.sin(probe).dtype
def is_jax_type(dtype):
    """
    Check whether jax can create arrays with the given dtype.

    Parameters
    ----------
    dtype : dtype-like
        Anything accepted by `jnp.dtype`.

    Returns
    -------
    supported : bool
        True if `jnp.empty` accepts the dtype. False if jax rejects it as
        unsupported.

    Raises
    ------
    TypeError
        Re-raised if `jnp.empty` fails with a different TypeError.
    """
    dtype = jnp.dtype(dtype)
    try:
        jnp.empty(0, dtype)
    except TypeError as exc:
        # Only jax's "unsupported dtype" error means "not a jax type".
        if 'JAX only supports number and bool dtypes' not in str(exc):
            raise
        return False
    return True
@functools.lru_cache(maxsize=None)
def _pure_callback_vectorized_kwargs():
    """
    Keyword items requesting vectorized batching from `jax.pure_callback`.

    jax >= 0.4.34 accepts ``vmap_method`` and deprecates ``vectorized``;
    ``vmap_method='legacy_vectorized'`` is the replacement for
    ``vectorized=True``. Older versions only know ``vectorized``. Returned
    as a tuple of pairs so the cached value cannot be mutated.
    """
    import inspect
    try:
        params = inspect.signature(jax.pure_callback).parameters
    except (TypeError, ValueError):
        params = {}
    if 'vmap_method' in params:
        return (('vmap_method', 'legacy_vectorized'),)
    return (('vectorized', True),)

def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs, see
    https://github.com/google/jax/issues/17187.

    Parameters
    ----------
    callback : callable
        Ufunc-like function, called by jax on numpy arrays.
    dtype : dtype
        dtype of the result.
    *args : arrays
        Arguments to `callback`; they must have ``shape`` and ``ndim``
        attributes (e.g. jax arrays). Those not in `excluded` are broadcast
        together to get the result shape, and padded with leading size-1
        axes to the same number of dimensions.
    excluded : sequence of int, optional
        Indices of arguments that are neither broadcast nor padded.
    **kwargs :
        Passed to `jax.pure_callback`. If neither ``vectorized`` nor
        ``vmap_method`` is given, vectorized batching is requested in the
        form supported by the installed jax version.

    Returns
    -------
    result : jax array
        The output of `callback`, with the broadcast shape and `dtype`.
    """
    if excluded is None:
        excluded = ()
    shape = jnp.broadcast_shapes(*(
        a.shape
        for i, a in enumerate(args)
        if i not in excluded
    ))
    ndim = len(shape)
    # Give every broadcast argument the same rank, so that batched axes
    # added by vmap line up the way ufunc broadcasting expects.
    padded_args = [
        a if i in excluded
        else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
        for i, a in enumerate(args)
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)
    # Let an explicit caller choice of batching mode take precedence;
    # otherwise avoid the deprecated `vectorized` argument when possible.
    if 'vectorized' not in kwargs and 'vmap_method' not in kwargs:
        kwargs = {**dict(_pure_callback_vectorized_kwargs()), **kwargs}
    return jax.pure_callback(callback, result, *padded_args, **kwargs)

# TODO when jax solves this, check version and piggyback on original if new
def _derivative_limit_exceeded(current, n):
    """Default exception factory for `limit_derivatives`."""
    return ValueError(f'took {current} derivatives > limit {n}')

def limit_derivatives(x, n, error_func=None):
    """
    Limit the number of derivatives that goes through a value.

    Parameters
    ----------
    x : array_like
        The value.
    n : int
        The maximum number of derivatives allowed. Must be an integer.
    error_func : callable, optional
        A function that takes (derivatives taken, n) and returns an exception.
        By default, a ValueError is produced.

    Return
    ------
    x : array_like
        The value, unchanged.
    """
    assert n == int(n)
    if error_func is None:
        error_func = _derivative_limit_exceeded
    return _limit_derivatives_impl(0, n, error_func, x)

@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2))
def _limit_derivatives_impl(current, limit, func, x):
    # `current` is the derivative order reached at this point of the trace.
    if current <= limit:
        return x
    raise func(current, limit)

@_limit_derivatives_impl.defjvp
def _limit_derivatives_impl_jvp(current, limit, func, primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Evaluate the primal with the counter bumped: differentiating this rule
    # once more re-enters the check with one more derivative order.
    primal_out = _limit_derivatives_impl(current + 1, limit, func, x)
    return primal_out, x_dot