Coverage for src/lsqfitgp/_Kernel/_isotropic.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_Kernel/_isotropic.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import sys 1efabcd
22from jax import numpy as jnp 1efabcd
24from .. import _jaxext 1efabcd
26from . import _util 1efabcd
27from . import _crosskernel 1efabcd
28from . import _kernel 1efabcd
29from . import _stationary 1efabcd
class CrossIsotropicKernel(_stationary.CrossStationaryKernel):
    """
    Subclass of `CrossStationaryKernel` for isotropic kernels.

    An isotropic kernel depends on its two arguments only through the
    Euclidean distance between them.

    Parameters
    ----------
    core : callable
        Function computing the kernel. Unless ``input='raw'``, it takes a
        single positional argument derived from the distance between x and y
        (see `input`), plus optional keyword arguments.
    input : {'squared', 'abs', 'posabs', 'raw'}, default 'squared'
        Selects what `core` receives. ``'squared'``: the squared distance.
        ``'abs'``: the distance itself (not squared). ``'posabs'``: the
        distance, except that for equal points it is a small number instead
        of zero. ``'raw'``: the two points separately, as for non-stationary
        kernels.
    **kw
        Further keyword arguments, forwarded to the `CrossKernel`
        constructor.

    Raises
    ------
    KeyError
        If `input` is not one of the accepted values.

    Notes
    -----
    With ``input='posabs'``, second derivatives cause problems in more than
    one dimension.
    """

    def __new__(cls, core, *, input='squared', **kw):
        # 'raw': no distance processing at all, `core` sees x and y as-is.
        if input == 'raw':
            return _crosskernel.CrossKernel.__new__(cls, core, **kw)

        # Squared difference of the two points. For 'posabs' the difference
        # is first passed through `_stationary._softabs`, so that equal points
        # do not produce an exact zero distance.
        if input in ('squared', 'abs'):
            def sqdiff(x, y):
                return jnp.square(x - y)
        elif input == 'posabs':
            def sqdiff(x, y):
                return jnp.square(_stationary._softabs(x - y))
        else:
            raise KeyError(input)

        # 'abs' and 'posabs' hand `core` the distance instead of its square.
        root = input in ('posabs', 'abs')

        def newcore(x, y, **kwargs):
            # `_util.sum_recurse_dtype` applies `sqdiff` to x, y and sums the
            # pieces (presumably over the fields of structured dtypes --
            # confirm in `_util`), giving the squared Euclidean distance.
            r2 = _util.sum_recurse_dtype(sqdiff, x, y)
            return core(jnp.sqrt(r2) if root else r2, **kwargs)

        return _crosskernel.CrossKernel.__new__(cls, newcore, **kw)

    # TODO add a `distance` parameter to pick arbitrary distances, but since the
    # distance definition can not be changed arbitrarily, it may be better to
    # keep this class for the 2-norm and eventually add another if needed.

    # TODO it is not efficient that the distance is computed separately for
    # each kernel in a kernel expression, but probably it would be difficult
    # to support everything without bugs while also computing the distance once.
    # A possible way is adding a keyword argument to the _kernel member
    # that kernels use to memoize things, the first IsotropicKernel that gets
    # called puts the distance there. Possible name: _cache.
class IsotropicKernel(CrossIsotropicKernel, _stationary.StationaryKernel):
    """
    Isotropic kernel: a `CrossIsotropicKernel` that is also a
    `StationaryKernel`; it adds no behavior beyond its two bases.
    """
IsotropicKernel.inherit_all_algops(intermediates=True)

# Register the named transformations as inherited by IsotropicKernel, with
# `intermediates=True` (exact semantics are defined in `_crosskernel` --
# NOTE(review): presumably the results stay IsotropicKernel; verify there).
# The registration order is kept the same as the explicit calls it replaces.
for _transf_name in (
    'rescale',
    'loc',
    'scale',
    'maxdim',
    'derivable',
    'normalize',
    'cond',
):
    IsotropicKernel.inherit_transf(_transf_name, intermediates=True)
del _transf_name
class CrossConstant(CrossIsotropicKernel):
    """
    Subclass of `CrossIsotropicKernel` named ``CrossConstant``; it adds no
    behavior of its own and is also published as
    ``_crosskernel.CrossConstant``.
    """
class Constant(CrossConstant, IsotropicKernel):
    """
    Non-cross counterpart of `CrossConstant`, also an `IsotropicKernel`; it
    adds no behavior of its own and is also published as
    ``_crosskernel.Constant``.
    """
def zero(x, y):
    """
    Return zeros with the broadcast shape of `x` and `y`.

    The fill value is the Python float ``0.``, so the result keeps JAX's
    weak float typing rather than a fixed dtype.
    """
    out_shape = jnp.broadcast_shapes(x.shape, y.shape)
    return jnp.broadcast_to(0., out_shape)
class Zero(IsotropicKernel):
    """
    Kernel that evaluates to zero whatever its inputs.

    Built from the module-level `zero` function with ``input='raw'``, so the
    two points reach it unprocessed and the output takes their broadcast
    shape.
    """

    def __new__(cls):
        return super().__new__(cls, core=zero, input='raw')
# Publish the classes defined here as attributes of `_crosskernel`, in the
# same order as before.
for _export_name, _export_cls in (
    ('IsotropicKernel', IsotropicKernel),
    ('CrossIsotropicKernel', CrossIsotropicKernel),
    ('Constant', Constant),
    ('CrossConstant', CrossConstant),
):
    setattr(_crosskernel, _export_name, _export_cls)
del _export_name, _export_cls