Coverage for src/lsqfitgp/_patch_gvar.py: 100%
17 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_patch_gvar.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" modifications to the global state of gvar """
22import functools 1abcdef
24import gvar 1abcdef
25from jax import numpy as jnp 1abcdef
26from jax.scipy import special as jspecial 1abcdef
# Names of the elementwise functions that this module wraps: each name is
# looked up in `gvar` (the default implementation) and in `jax.numpy` /
# `jax.scipy.special` (the implementation used for jax arrays).
gvar_ufuncs = '''
    sin cos tan
    exp log sqrt fabs
    sinh cosh tanh
    arcsin arccos arctan arctan2
    arcsinh arccosh arctanh
    square erf
'''.split()
def _patch_ufuncs(names):
    """Make the gvar functions listed in `names` also accept jax arrays.

    Each ``gvar.<name>`` is replaced by a `functools.singledispatch` wrapper.
    The wrapper's default implementation is the original gvar function. The
    same-named function from `jax.numpy` (or, if absent there,
    `jax.scipy.special`) is registered for `jnp.ndarray` arguments.

    If neither jax module provides the name, nothing is registered and the
    wrapper always falls through to the original gvar function.

    Parameters
    ----------
    names : iterable of str
        Names of functions that exist in the `gvar` namespace.
    """
    for fname in names:
        fgvar = getattr(gvar, fname)
        fjax = getattr(jnp, fname, None)
        if fjax is None:
            fjax = getattr(jspecial, fname, None)
        # singledispatch selects the implementation from the type of the
        # first positional argument only (relevant for e.g. arctan2).
        fboth = functools.singledispatch(fgvar)
        # Only register a real callable. Registering a sentinel such as
        # NotImplemented would make calls on jax arrays fail with
        # "'NotImplementedType' object is not callable".
        if fjax is not None:
            fboth.register(jnp.ndarray, fjax)
        setattr(gvar, fname, fboth)


# Run inside a function so the loop variables do not leak into the module.
_patch_ufuncs(gvar_ufuncs)
# Reset the BufferDict transformations so they support jax arrays. Each
# distribution is re-registered with the gvar functions as they are bound at
# this point in the module, instead of the definitions gvar originally
# registered. Every distribution is deleted before any is re-added.
for _dist_name in ('log', 'sqrt', 'erfinv'):
    gvar.BufferDict.del_distribution(_dist_name)
for _dist_name, _dist_fcn in (
    ('log', gvar.exp),
    ('sqrt', gvar.square),
    ('erfinv', gvar.erf),
):
    gvar.BufferDict.add_distribution(_dist_name, _dist_fcn)
# Do not leave loop variables behind in the module namespace.
del _dist_name, _dist_fcn