Coverage for src/lsqfitgp/_Kernel/_stationary.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_stationary.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20from jax import numpy as jnp 1feabcd

21 

22from .. import _jaxext 1feabcd

23 

24from . import _util 1feabcd

25from . import _crosskernel 1feabcd

26from . import _kernel 1feabcd

27 

class CrossStationaryKernel(_crosskernel.CrossKernel):
    """

    Subclass of `CrossKernel` for stationary kernels.

    A stationary kernel depends only on the difference, dimension by dimension,
    between its two arguments.

    Parameters
    ----------
    core : callable
        A function taking one positional argument ``delta = x - y`` and optional
        keyword arguments.
    input : {'signed', 'posabs', 'abs'}, default 'signed'
        If ``'signed'``, `core` is passed the bare difference. If
        ``'posabs'``, `core` is passed the absolute value of the difference
        plus the machine epsilon, so that the difference of equal points is a
        small positive number instead of zero. If ``'abs'``, `core` is passed
        the absolute value of the difference.
    **kw
        Additional keyword arguments are passed to the `CrossKernel`
        constructor.

    Raises
    ------
    KeyError
        If `input` is not one of the recognized values.

    """

    def __new__(cls, core, *, input='signed', **kw):

        # Choose how the two arguments are combined into what `core` receives.
        if input == 'posabs':
            # |x - y| + eps: never exactly zero, even for equal points
            dist = lambda x, y: _softabs(x - y)
        elif input == 'signed':
            dist = lambda x, y: x - y
        elif input == 'abs':
            dist = lambda x, y: jnp.abs(x - y)
        else:
            # KeyError (not ValueError) is kept for compatibility with callers
            raise KeyError(input)

        # Named `core_kw` to avoid shadowing the constructor's `kw`, which is
        # forwarded to `CrossKernel` below.
        def newcore(x, y, **core_kw):
            # NOTE(review): presumably `ufunc_recurse_dtype` applies `dist`
            # field by field when x/y have structured dtypes; see _util.
            q = _util.ufunc_recurse_dtype(dist, x, y)
            return core(q, **core_kw)

        return super().__new__(cls, newcore, **kw)

    # TODO this class requires that, on both inputs, the kernel depends only on
    # the difference. However, when transforming a stationary kernel, I often
    # have the property only on one side. Making a left/right class interferes
    # with _swap, because it is not clean to switch classes in the ancestors,
    # so there would be a class for a single side with a property.
    #
    # Maybe the elegant way to keep left/right properties separated is having
    # separate classes for the two processes.
    #
    # Drop the whole generic hierarchy. Keep a single class Kernel. Have two
    # attributes left and right which are instances of Process. Linops are
    # applied as methods on left and right, which operate on a partial
    # evaluation of the core. Kernel.__getattr__ sees if the missing attribute
    # is a method defined on both left and right and calls both of them; this
    # allows potential overrides for the joint transformation.
    #
    # The class logic operates on left/right. Linops look at one process at a
    # time. Algops are defined on the whole kernel but need somehow to indicate
    # how to change the classes of processes. Maybe a list of class-preserving
    # algops in the process.

89 

class StationaryKernel(CrossStationaryKernel, _kernel.Kernel):
    """
    Stationary kernel: combines `CrossStationaryKernel` (dependence on the
    difference of the arguments only) with `Kernel`.
    """

92 

# Make the 'dim' transformation preserve the class StationaryKernel and its
# ancestors; other class-preserving transformations are added by
# IsotropicKernel.
# NOTE(review): `intermediates=True` presumably extends the inheritance to the
# intermediate classes in the hierarchy (e.g. CrossStationaryKernel); confirm
# against CrossKernel.inherit_transf.
StationaryKernel.inherit_transf('dim', intermediates=True)

96 

97def _eps(x): 1feabcd

98 if jnp.issubdtype(x.dtype, jnp.inexact): 1abcd

99 return jnp.finfo(x.dtype).eps 1abcd

100 # finfo(x) does not work in numpy 1.20 

101 else: 

102 return jnp.finfo(jnp.empty(())).eps 1abcd

103 

def _softabs(x):
    """
    Absolute value of `x` shifted up by the machine epsilon given by `_eps`,
    so that zero inputs map to a small positive number instead of zero.
    """
    shift = _eps(x)
    return jnp.abs(x) + shift