Coverage for src/lsqfitgp/_Kernel/_stationary.py: 100%
27 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_Kernel/_stationary.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20from jax import numpy as jnp 1feabcd
22from .. import _jaxext 1feabcd
24from . import _util 1feabcd
25from . import _crosskernel 1feabcd
26from . import _kernel 1feabcd
class CrossStationaryKernel(_crosskernel.CrossKernel):
    """
    Subclass of `CrossKernel` for stationary kernels.

    A stationary kernel depends only on the difference, dimension by dimension,
    between its two arguments.

    Parameters
    ----------
    core : callable
        A function taking one positional argument ``delta = x - y`` and optional
        keyword arguments.
    input : {'signed', 'posabs', 'abs'}, default 'signed'
        Selects how the difference is preprocessed before it reaches `core`.
        With ``'signed'``, `core` receives the bare difference. With
        ``'abs'``, it receives the absolute value of the difference. With
        ``'posabs'``, it receives the absolute value plus a small positive
        number, such that the difference of equal points is not zero.
    **kw
        Additional keyword arguments are passed to the `CrossKernel`
        constructor.

    Raises
    ------
    KeyError
        If `input` is not one of the options listed above.

    """

    def __new__(cls, core, *, input='signed', **kw):

        # choose the function applied to each pair of matching input fields
        if input == 'signed':
            def dist(x, y):
                return x - y
        elif input == 'abs':
            def dist(x, y):
                return jnp.abs(x - y)
        elif input == 'posabs':
            def dist(x, y):
                return _softabs(x - y)
        else:
            raise KeyError(input)

        def newcore(x, y, **core_kw):
            # reduce the two inputs to their (possibly transformed) difference
            # and hand it over to the user-provided core
            delta = _util.ufunc_recurse_dtype(dist, x, y)
            return core(delta, **core_kw)

        return super().__new__(cls, newcore, **kw)

    # TODO this class requires that, on both inputs, the thing depends only on
    # the distance. However, when transforming a stationary kernel, I often have
    # the property only on either side. Making a left/right class interferes
    # with _swap, because it is not clean to switch classes in the ancestors,
    # so there would be a class for a single side with a property.
    #
    # Maybe the elegant way to keep left/right properties separated is having
    # separate classes for the two processes.
    #
    # Drop the whole generic hierarchy. Keep a single class Kernel. Have two
    # attributes left and right which are instances of Process. Linops are
    # applied as methods on left and right, which operate on a partial
    # evaluation of the core. Kernel.__getattr__ sees if the missing attribute
    # is a method defined on both left and right and calls both of them; this
    # allows potential overrides for the joint transformation.
    #
    # The class logic operates on left/right. Linops look at one process at a
    # time. Algops are defined on the whole kernel but need somehow to indicate
    # how to change the classes of processes. Maybe a list of class-preserving
    # algops in the process.
class StationaryKernel(CrossStationaryKernel, _kernel.Kernel):
    """
    Stationary kernel class, combining `CrossStationaryKernel` with `Kernel`.

    It adds no attributes or methods of its own.
    """
# Make the 'dim' transformation preserve the class StationaryKernel and
# upwards (presumably `intermediates=True` is what extends this to the
# classes in between -- confirm against `inherit_transf`). The other
# transformations are added by IsotropicKernel.
StationaryKernel.inherit_transf('dim', intermediates=True)
def _eps(x):
    """
    Return the floating point machine epsilon to be used with `x`.

    If the dtype of `x` is inexact, the epsilon of that dtype is returned.
    Otherwise the epsilon is the one `jnp.finfo` reports for an array
    created by `jnp.empty`.
    """
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        # no floating point type to read from x, use a freshly created array
        return jnp.finfo(jnp.empty(())).eps
    # pass the dtype instead of the array itself, because finfo(x) does not
    # work in numpy 1.20
    return jnp.finfo(x.dtype).eps
def _softabs(x):
    """
    Return the absolute value of `x` increased by the epsilon `_eps(x)`.
    """
    magnitude = jnp.abs(x)
    return magnitude + _eps(x)