Coverage for src/lsqfitgp/_kernels/_zeta.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_kernels/_zeta.py 

2# 

3# Copyright (c) 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1fabcde

21 

22from jax import numpy as jnp 1fabcde

23 

24from .. import _special 1fabcde

25from .. import _jaxext 1fabcde

26from .. import _Kernel 1fabcde

27 

def check_nu(nu):
    """
    Validate the smoothness parameter ``nu`` of the zeta kernels.

    The check requires ``0 <= nu < inf``, so a NaN ``nu`` fails as well. On
    failure an ``AssertionError`` carrying ``nu`` is raised (assert statement,
    so not under ``python -O``). The check runs inside
    ``_jaxext.skipifabstract()``, presumably so that it is skipped when ``nu``
    is a traced, non-concrete value — TODO confirm.
    """
    with _jaxext.skipifabstract():
        assert 0 <= nu and nu < jnp.inf, nu

31 

def zeta_derivable(*, nu):
    """
    Number of times a `Zeta` process is derivable: ``max(0, ceil(nu) - 1)``.

    ``nu`` is first validated with `check_nu`. The conversion to a Python int
    happens inside ``_jaxext.skipifabstract()``; if that context suppresses
    the error raised for a non-concrete ``nu`` (TODO confirm), the function
    falls through and returns ``None``.
    """
    check_nu(nu)
    with _jaxext.skipifabstract():
        order = jnp.ceil(nu) - 1
        return int(max(0, order))

36 

@_Kernel.crosskernel(bases=(_Kernel.AffineSpan, _Kernel.StationaryKernel,), maxdim=1, derivable=zeta_derivable)
def Zeta(delta, *, nu, **_):
    r"""
    
    Zeta kernel.
    
    .. math::
        k(\Delta)
        &= \frac{\Re F(\Delta, s)}{\zeta(s)} =
        \qquad (s = 1 + 2 \nu, \quad \nu \ge 0) \\
        &= \frac1{\zeta(s)} \sum_{k=1}^\infty
        \frac {\cos(2\pi k\Delta)} {k^s} = \\
        &= -(-1)^{s/2}
        \frac {(2\pi)^s} {2\, s!}
        \frac {\tilde B_s(\Delta)} {\zeta(s)}
        \quad \text{for even integer $s$,}
    
    where :math:`F` is the periodic zeta function and :math:`\tilde B_s` is
    the periodic Bernoulli polynomial.
    
    It is equivalent to fitting with a Fourier series of period 1 with
    independent priors on the coefficients with mean zero and variance
    :math:`1/(\zeta(s)k^s)` for the :math:`k`-th term. Analogously to
    :class:`Matern`, the process is :math:`\lceil\nu\rceil - 1` times
    derivable, and the highest derivative is continuous iff :math:`\nu\bmod 1
    \ge 1/2`.
    
    The :math:`k = 0` term is not included in the summation, so the mean of the
    process over one period is forced to be zero.
    
    For :math:`\nu = 0` (:math:`s = 1`) the kernel is the limit of the above,
    which is 1 at integer :math:`\Delta` and 0 elsewhere.
    
    Reference: Petrillo (2022).
    
    """
    check_nu(nu)
    s = 1 + 2 * nu

    # General case s > 1: real part of the periodic zeta function,
    # normalized to 1 at zero lag.
    normalized = _special.periodic_zeta(delta, s) / _special.zeta(s)

    # Degenerate case s = 1: 1 on the integer lattice, 0 off it.
    on_lattice = delta % 1 == 0
    limit = jnp.where(on_lattice, 1, 0)

    return jnp.where(s > 1, normalized, limit)

72 

73 # return -(-1) ** (s // 2) * _special.scaled_periodic_bernoulli(s, delta) / jspecial.zeta(s, 1) 

74 

# TODO use the Bernoulli version for even integer s, selected based on the
# type of the input so that the choice is static, because it is much more
# accurate

77 

78 # TODO ND version. The separable product is not equivalent I think. 

79 

80 # TODO the derivative w.r.t. nu is probably broken 

81 

@_Kernel.kernel(maxdim=1, derivable=False)
def ZetaFourier(k, q, *, nu, lloc, rloc, lscale, rscale, offset, ampl):
    # Covariance between Fourier series coefficients `k` and `q` of a `Zeta`
    # process, i.e. the 'fourier' transformation applied on both sides.
    # Index convention (see `fourier_doc`): 0 is the constant term, an even
    # index 2n is the cosine of order n, and an odd index 2n - 1 is the sine
    # of order n.
    #
    # lloc, rloc, lscale, rscale, offset, ampl are the dynamic keywords
    # merged in by `translkw`. Presumably they describe the affine
    # transformations tracked by `_Kernel.AffineSpan` — TODO confirm.
    check_nu(nu)
    s = 1 + 2 * nu
    lorder = jnp.ceil(k / 2)
    rorder = jnp.ceil(q / 2)
    lodd = k % 2
    rodd = q % 2

    # The constant term has order 0, where `order ** s` is 0 and the variance
    # would be infinite. That entry is replaced by `offset` below, but an
    # infinite intermediate inside a masked `jnp.where` branch still turns
    # reverse-mode gradients into NaN (0 * inf). Evaluate the variance at a
    # dummy order 1 there instead; forward values are unchanged.
    safe_lorder = jnp.where(lorder, lorder, 1)
    var = ampl / (safe_lorder ** s * _special.zeta(s))

    arg = 2 * jnp.pi * lorder * (lloc / lscale - rloc / rscale)

    # Coefficients of different order are uncorrelated. At equal order,
    # same-parity pairs (cos-cos, sin-sin) give the cosine term, or `offset`
    # for the constant term. Mixed pairs (cos-sin) give a signed sine term.
    return jnp.where(lorder == rorder,
        jnp.where(lodd == rodd,
            jnp.where(lorder, var * jnp.cos(arg), offset),
            var * jnp.sin(arg) * jnp.where(lodd, 1, -1),
        ),
        0,
    )

99 

def crosszeta_derivable(*, nu, **_):
    """
    Derivability of `CrossZetaFourier` as a ``(left, right)`` pair.

    The left argument (the Fourier index) is never derivable. The right one
    inherits the derivability of `Zeta` via `zeta_derivable`. Extra keyword
    arguments are ignored.
    """
    right = zeta_derivable(nu=nu)
    return 0, right

102 

@_Kernel.crosskernel(bases=(_Kernel.PreservedBySwap, _Kernel.CrossKernel), maxdim=1, derivable=crosszeta_derivable)
def CrossZetaFourier(k, y, *, nu, lloc, rloc, lscale, rscale, offset, ampl):
    # Cross-covariance between Fourier series coefficient `k` of a `Zeta`
    # process (left) and, presumably, the untransformed process at `y`
    # (right). Index convention (see `fourier_doc`): 0 is the constant term,
    # an even index 2n is the cosine of order n, and an odd index 2n - 1 is
    # the sine of order n.
    #
    # lloc, rloc, lscale, rscale, offset, ampl are the dynamic keywords
    # merged in by `translkw`. Presumably they describe the affine
    # transformations tracked by `_Kernel.AffineSpan` — TODO confirm.
    check_nu(nu)
    s = 1 + 2 * nu
    order = jnp.ceil(k / 2)
    odd = k % 2

    # The constant term has order 0, where `order ** s` is 0 and the variance
    # would be infinite. That entry is replaced by `offset` below, but an
    # infinite intermediate inside a masked `jnp.where` branch still turns
    # reverse-mode gradients (including w.r.t. `y`) into NaN (0 * inf).
    # Evaluate the variance at a dummy order 1 there instead; forward values
    # are unchanged.
    safe_order = jnp.where(order, order, 1)
    var = ampl / (safe_order ** s * _special.zeta(s))

    arg = 2 * jnp.pi * order * (lloc / lscale + (y - rloc) / rscale)

    # Odd index: sine coefficient. Even index: cosine coefficient, or
    # `offset` for the constant term.
    return jnp.where(odd,
        var * jnp.sin(arg),
        jnp.where(order, var * jnp.cos(arg), offset),
    )

115 

# Documentation of the 'fourier' linear operator family registered on `Zeta`
# via `Zeta.make_linop_family` (passed as `doc=`). It defines the transform
# whose covariances are computed by `ZetaFourier` and `CrossZetaFourier`:
# index 0 is the constant term, even k the cosine of order k/2, and odd k the
# sine of order (k+1)/2.
fourier_doc = r"""

Compute the Fourier series transform of the function.

.. math::

    T(f)(k) = \begin{cases}
        \frac2T \int_0^T \mathrm dx\, f(x)
        \cos\left(\frac{2\pi}T \frac k2 x\right)
        & \text{if $k$ is even} \\
        \frac2T \int_0^T \mathrm dx\, f(x)
        \sin\left(\frac{2\pi}T \frac{k+1}2 x\right)
        & \text{if $k$ is odd}
    \end{cases}

The period :math:`T` is 1.

"""

134 

def fourier_argparser(do):
    """
    Normalize the argument of the 'fourier' transformation.

    A truthy argument is returned unchanged; any falsy one (``False``, ``0``,
    ``None``, ...) is mapped to ``None``.
    """
    return do or None

137 

def translkw(*, dynkw, **initkw):
    """
    Build the keyword arguments for the 'fourier' transformation kernels.

    Used as ``translkw`` for the 'fourier' linop family on `Zeta`. Merges
    ``dynkw`` with the remaining keyword arguments ``initkw`` into a new dict.
    A key present in both raises ``TypeError`` instead of being silently
    overwritten.
    """
    combined = dict(**dynkw, **initkw)
    return combined

140 

# Register the 'fourier' transformation on Zeta. ZetaFourier is the kernel
# with the transformation on both arguments; CrossZetaFourier presumably
# covers the case with the transformation on one side only.
Zeta.make_linop_family(
    'fourier',
    ZetaFourier,
    CrossZetaFourier,
    translkw=translkw,
    doc=fourier_doc,
    argparser=fourier_argparser,
)

142 

143# TODO 

144# - test the transf with rescalings (what cross check can I do?) 

145# - track affine transf in CrossZetaFourier too 

# - make Zeta support non-symmetric affine ops (I think I need to define
#   CrossZeta, then subclass it as Zeta(CrossZeta, Kernel))

148# - consider renaming fourier to fourier_series when I rewrite transf system