Coverage for src/lsqfitgp/_GP/_base.py: 100%
53 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_GP/_base.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1efabcd
22import jax 1efabcd
23from jax import numpy as jnp 1efabcd
25from .. import _jaxext 1efabcd
26from .. import _utils 1efabcd
class GPBase:
    """
    Base class holding the state and helpers shared by GP objects.

    It stores two boolean flags (``_checkfinite`` and ``_checklin``) and can
    copy them onto a fresh instance (``_clone``). It also defines the
    ``DefaultProcess`` key and a numerical linearity self-check for
    transformations (``_checklinear``).
    """

    def __init__(self, *, checkfinite=True, checklin=True):
        """
        Parameters
        ----------
        checkfinite : bool, default True
            Stored, coerced with `bool`, as ``self._checkfinite``.
        checklin : bool, default True
            Stored, coerced with `bool`, as ``self._checklin``.
            Presumably this gates calls to `_checklinear` made elsewhere;
            verify against the subclasses.
        """
        self._checkfinite = bool(checkfinite)
        self._checklin = bool(checklin)

    def _clone(self):
        """
        Return a new, uninitialized instance of the same class that carries
        over the base flags.

        ``__init__`` is bypassed (``object.__new__``). Only the attributes
        defined by this base class are copied. Subclasses holding more state
        presumably extend this method; that is not visible here.
        """
        newself = object.__new__(self.__class__)
        newself._checkfinite = self._checkfinite
        newself._checklin = self._checklin
        return newself

    class _SingletonMeta(type):
        """ Metaclass whose classes print as their bare name. """

        def __repr__(cls):
            return cls.__name__

    class _Singleton(metaclass=_SingletonMeta):
        """ Base for classes used as unique keys: they are never instantiated. """

        def __new__(cls):
            raise NotImplementedError(f"{cls.__name__} can not be instantiated")

    class DefaultProcess(_Singleton):
        """ Key of the default process. """
        pass

    def _checklinear(self, func, inshapes, elementwise=False):
        """
        Numerically check that `func` is linear, and optionally elementwise.

        `func` is evaluated on fixed-seed standard normal inputs. Its JVP is
        computed using the inputs themselves as tangents; for a linear
        function the tangent output equals the primal output.

        The check runs inside ``_jaxext.skipifabstract()``, presumably so it
        is skipped when values are abstract (e.g. under ``jit``).

        Parameters
        ----------
        func : callable
            Called as ``func(*inputs)``; must return a single array.
        inshapes : sequence of shapes
            The shape of each positional input of `func`.
        elementwise : bool, default False
            If True, additionally check two things. First, the output has
            the broadcast shape of `inshapes`. Second, positions where all
            inputs are zero are zero in the output.

        Raises
        ------
        RuntimeError
            If `func` is found not to be linear or not to be elementwise.
        """

        # Make input arrays. The seed is fixed so the check is reproducible.
        rkey = jax.random.PRNGKey(202206091600)
        inp = []
        for shape in inshapes:
            rkey, subkey = jax.random.split(rkey)
            inp.append(jax.random.normal(subkey, shape))

        # Put zeros into the arrays to check they are preserved.
        if elementwise:
            outshape = jnp.broadcast_shapes(*inshapes)
            rkey, subkey = jax.random.split(rkey)
            zeros = jax.random.bernoulli(subkey, 0.5, outshape)
            # jnp.where broadcasts every input to the common shape. A boolean
            # mask index (a.at[zeros]) would raise for any input whose shape
            # is smaller than the broadcast shape.
            inp = [jnp.where(zeros, 0, a) for a in inp]

        # Compute JVP and check it is identical to the function itself.
        with _jaxext.skipifabstract():
            out0, out1 = jax.jvp(func, inp, inp)

            # A float0 tangent means the output is integer/bool. The tangent
            # is then zero by construction, so a linear function must output
            # zero as well.
            tangent_is_float0 = out1.dtype == jax.float0
            if tangent_is_float0:
                cond = jnp.allclose(out0, 0)
            else:
                cond = jnp.allclose(out0, out1)
            if not cond:
                raise RuntimeError('the transformation is not linear')

            # Check that the function is elementwise. This stays inside the
            # context because it needs out0/out1 computed above.
            if elementwise:
                # Short-circuit on the shape so that indexing with `zeros`
                # only happens when the shapes agree.
                ok = out0.shape == outshape and jnp.allclose(out0[zeros], 0)
                # float0 arrays can not be passed to jnp functions, and they
                # are identically zero anyway, so skip that comparison.
                if ok and not tangent_is_float0:
                    ok = jnp.allclose(out1[zeros], 0)
                if not ok:
                    raise RuntimeError('the transformation is not elementwise')
def newself(meth):
    """
    Turn an in-place method into one that works on, and returns, a copy.

    The returned wrapper first copies ``self`` via ``self._clone()``. It
    then calls `meth` on that copy, discarding `meth`'s return value, and
    finally returns the copy. The wrapper's docstring is `meth`'s docstring
    with a "Returns" section appended by ``_utils.append_to_docstring``.
    """

    # Return-value description appended to the wrapped method's docstring.
    doctail = """\
    Returns
    -------
    gp : GP
        A new GP object with the applied modifications.
    """

    @functools.wraps(meth)
    def newmeth(self, *args, **kw):
        clone = self._clone()
        meth(clone, *args, **kw)  # mutates the clone; result intentionally ignored
        return clone

    newmeth.__doc__ = _utils.append_to_docstring(meth.__doc__, doctail)
    return newmeth