Coverage for src/lsqfitgp/_gvarext/_ufunc.py: 94%
46 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_gvarext/_ufunc.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1efabcd
21import string 1efabcd
23import jax 1efabcd
24import gvar 1efabcd
25from jax import numpy as jnp 1efabcd
27from .. import _signature 1efabcd
29from ._jacobian import jacobian, from_jacobian 1efabcd
30from ._tabulate import tabulate_together 1efabcd
def gvar_gufunc(func, *, signature=None):
    """

    Make a one-argument, jax-traceable generalized ufunc accept gvars.

    Parameters
    ----------
    func : callable
        A function mapping one array to one array. It must be a generalized
        ufunc, and differentiable once with `jax`.
    signature : str, optional
        The signature of the generalized ufunc. If omitted, `func` is taken
        to be scalar to scalar (an ordinary ufunc).

    Returns
    -------
    decorated_func : callable
        A function that accepts gvars in addition to numerical arrays. Inputs
        that are a `gvar.GVar` or have object dtype are propagated through
        the jacobian of `func`; any other input is passed to `func` as is.

    See also
    --------
    numpy.vectorize

    """

    # fall back to the elementwise signature, then read the core dimensions;
    # the tuple unpacking enforces exactly one input and one output
    sig = _signature.Signature('()->()' if signature is None else signature)
    (core_in,) = sig.incores
    (core_out,) = sig.outcores

    # the jacobian of func has core dimensions (output core) + (input core)
    jac_sig = _signature.Signature.from_tuples(
        [core_in], [core_out + core_in],
    )
    deriv = jnp.vectorize(jax.jacfwd(func), signature=jac_sig.signature)

    # einsum labels: one letter per output core axis, then one per input core
    # axis, then a single letter for the trailing axis of the input jacobian
    n_core_in = len(core_in)
    n_core_out = len(core_out)
    letters = string.ascii_letters
    split = n_core_out + n_core_in
    batch = '...'
    out_letters = letters[:n_core_out]
    in_letters = letters[n_core_out:split]
    gvar_letter = letters[split]

    # chain rule: contract the input core axes of the jacobian of func with
    # those of the jacobian of the input
    lhs_func = batch + out_letters + in_letters
    lhs_input = batch + in_letters + gvar_letter
    rhs = batch + out_letters + gvar_letter
    formula = f'{lhs_func},{lhs_input}->{rhs}'

    def _apply_to_gvars(x):
        # split the gvars into mean values and jacobian
        mean = gvar.mean(x)
        x_jac, indices = jacobian(x)

        # evaluate the function and its derivative at the mean
        value = func(mean)
        func_jac = deriv(mean)

        # the leading (non-core) axes of the two jacobians must agree
        n_batch = func_jac.ndim - n_core_out - n_core_in
        x_batch_shape = x_jac.shape[:x_jac.ndim - 1 - n_core_in]
        assert func_jac.shape[:n_batch] == x_batch_shape

        # propagate the jacobian and rebuild the output from mean + jacobian
        propagated = jnp.einsum(formula, func_jac, x_jac)
        return from_jacobian(value, propagated, indices)

    @functools.wraps(func)
    def decorated_func(x):
        if isinstance(x, gvar.GVar):
            # single gvar in: unwrap a 0-d result with .item()
            result = _apply_to_gvars(x)
            return result if result.ndim else result.item()
        if getattr(x, 'dtype', None) == object:
            # object-dtype input takes the same gvar path
            return _apply_to_gvars(x)
        # anything else goes straight to the wrapped function
        return func(x)

    return decorated_func
117 # TODO add more than one argument or output. Possibly without taking
118 # derivatives when it's not a gvar, i.e., merge the wrappers and cycle over
119 # args. Also implement excluded => note that jnp.vectorize only supports
120 # positional arguments, excluded takes in only indices, not names