Coverage for src/lsqfitgp/_Kernel/_isotropic.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_isotropic.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import sys 1efabcd

21 

22from jax import numpy as jnp 1efabcd

23 

24from .. import _jaxext 1efabcd

25 

26from . import _util 1efabcd

27from . import _crosskernel 1efabcd

28from . import _kernel 1efabcd

29from . import _stationary 1efabcd

30 

class CrossIsotropicKernel(_stationary.CrossStationaryKernel):
    """

    Subclass of `CrossStationaryKernel` for isotropic kernels.

    An isotropic kernel depends only on the Euclidean distance between its two
    arguments.

    Parameters
    ----------
    core : callable
        The function computing the kernel, which may also accept keyword
        arguments. What it receives positionally depends on `input`: with
        ``input='squared'`` it takes one argument ``r2``, the squared distance
        between x and y; with ``'abs'`` or ``'posabs'`` it takes the distance
        (not squared); with ``'raw'`` it takes the two points ``x, y``
        separately.
    input : {'squared', 'abs', 'posabs', 'raw'}, default 'squared'
        If ``'squared'``, `core` is passed the squared distance. If ``'abs'``,
        it is passed the distance (not squared). If ``'posabs'``, it is passed
        the distance, and the distance of equal points is a small number instead
        of zero. If ``'raw'``, `core` is passed both points separately like
        non-stationary kernels.
    **kw
        Additional keyword arguments are passed to the `CrossKernel`
        constructor.

    Raises
    ------
    KeyError
        If `input` is not one of the values listed above.

    Notes
    -----
    The ``input='posabs'`` option will cause problems with second derivatives in
    more than one dimension.

    """

    def __new__(cls, core, *, input='squared', **kw):
        if input == 'raw':
            # `core` already has the two-point signature of a generic kernel,
            # so it is handed to `CrossKernel` unchanged.
            return _crosskernel.CrossKernel.__new__(cls, core, **kw)

        # `dist` computes the squared difference of two (sub)arrays; it is
        # combined into a total by `_util.sum_recurse_dtype` below (presumably
        # summing over the fields of structured inputs -- see `_util`).
        if input in ('squared', 'abs'):
            dist = lambda x, y: jnp.square(x - y)
        elif input == 'posabs':
            # `_softabs` is used so that equal points do not give exactly
            # zero distance (see the `input` parameter docs).
            dist = lambda x, y: jnp.square(_stationary._softabs(x - y))
        else:
            # Keep `KeyError` for backward compatibility, but say what is
            # accepted instead of echoing only the bad value.
            raise KeyError(
                f'unknown input={input!r}, expected one of '
                "'squared', 'abs', 'posabs', 'raw'"
            )

        # 'abs' and 'posabs' want the distance, 'squared' wants its square.
        if input in ('posabs', 'abs'):
            transf = jnp.sqrt
        else:
            transf = lambda ss: ss

        def newcore(x, y, **kwargs):
            ss = _util.sum_recurse_dtype(dist, x, y)
            return core(transf(ss), **kwargs)

        return _crosskernel.CrossKernel.__new__(cls, newcore, **kw)

    # TODO add a `distance` parameter to pick arbitrary distances, but since the
    # distance definition can not be changed arbitrarily, it may be better to
    # keep this class for the 2-norm and eventually add another if needed.

    # TODO it is not efficient that the distance is computed separately for
    # each kernel in a kernel expression, but probably it would be difficult
    # to support everything without bugs while also computing the distance once.
    # A possible way is adding a keyword argument to the _kernel member
    # that kernels use to memoize things, the first IsotropicKernel that gets
    # called puts the distance there. Possible name: _cache.

93 

class IsotropicKernel(CrossIsotropicKernel, _stationary.StationaryKernel):
    # Combines the isotropic cross-kernel machinery with `StationaryKernel`.
    # Deliberately left without a docstring, as in the original.
    pass

# Make IsotropicKernel keep its class through algebraic operations and the
# listed transformations (`intermediates=True` is passed to each call, as
# the library API requires here). The call order matches the original.
IsotropicKernel.inherit_all_algops(intermediates=True)
for _transfname in (
    'rescale',
    'loc',
    'scale',
    'maxdim',
    'derivable',
    'normalize',
    'cond',
):
    IsotropicKernel.inherit_transf(_transfname, intermediates=True)
del _transfname

105 

class CrossConstant(CrossIsotropicKernel):
    # Marker subclass with no behavior of its own; it is registered on
    # `_crosskernel` at the end of this module.
    pass

class Constant(CrossConstant, IsotropicKernel):
    # Non-cross counterpart of `CrossConstant`; also registered on
    # `_crosskernel` at the end of this module.
    pass

111 

def zero(x, y):
    """
    Kernel function returning 0 for every pair of points.

    The result is ``0.`` broadcast to the common broadcast shape of
    ``x.shape`` and ``y.shape``.
    """
    common_shape = jnp.broadcast_shapes(x.shape, y.shape)
    return jnp.broadcast_to(0., common_shape)

114 

class Zero(IsotropicKernel):
    """
    Represents a kernel that unconditionally yields zero.
    """

    def __new__(cls):
        # The module-level `zero` takes both points (x, y), hence 'raw'.
        return super(Zero, cls).__new__(cls, zero, input='raw')

122 

# Expose these classes as attributes of the `_crosskernel` module, assigned
# in the same order as before. Presumably this is done here rather than via
# an import in `_crosskernel` to avoid a circular import -- TODO confirm.
for _attrname, _klass in (
    ('IsotropicKernel', IsotropicKernel),
    ('CrossIsotropicKernel', CrossIsotropicKernel),
    ('Constant', Constant),
    ('CrossConstant', CrossConstant),
):
    setattr(_crosskernel, _attrname, _klass)
del _attrname, _klass