Coverage for src/lsqfitgp/_kernels/_zeta.py: 100%
47 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_kernels/_zeta.py
2#
3# Copyright (c) 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1fabcde
22from jax import numpy as jnp 1fabcde
24from .. import _special 1fabcde
25from .. import _jaxext 1fabcde
26from .. import _Kernel 1fabcde
def check_nu(nu):
    """
    Validate the smoothness parameter of the Zeta kernels.

    Requires ``0 <= nu < inf``; on failure an ``AssertionError`` carrying
    ``nu`` is raised. The check runs inside ``_jaxext.skipifabstract()``,
    presumably so that it is skipped when ``nu`` is a traced (abstract) JAX
    value that cannot be compared concretely -- confirm in ``_jaxext``.
    """
    with _jaxext.skipifabstract():
        # NOTE(review): `assert` is stripped under `python -O`, which would
        # silently disable this validation.
        assert nu >= 0 and nu < jnp.inf, nu
def zeta_derivable(*, nu):
    """
    Return how many times a Zeta process with parameter ``nu`` is derivable.

    The result is ``ceil(nu) - 1``, clipped below at 0, as a Python ``int``.
    ``nu`` is validated with ``check_nu`` first. The computation runs inside
    ``_jaxext.skipifabstract()``; if that context suppresses an error raised
    by an abstract ``nu`` (presumably its purpose -- confirm), the function
    falls through and returns ``None``.
    """
    check_nu(nu)
    with _jaxext.skipifabstract():
        highest = jnp.ceil(nu) - 1
        return int(highest) if highest > 0 else 0
@_Kernel.crosskernel(bases=(_Kernel.AffineSpan, _Kernel.StationaryKernel,), maxdim=1, derivable=zeta_derivable)
def Zeta(delta, *, nu, **_):
    r"""

    Zeta kernel.

    .. math::
        k(\Delta)
        &= \frac{\Re F(\Delta, s)}{\zeta(s)}
        \qquad (s = 1 + 2 \nu, \quad \nu \ge 0) \\
        &= \frac1{\zeta(s)} \sum_{k=1}^\infty
        \frac {\cos(2\pi k\Delta)} {k^s} \\
        &= -(-1)^{s/2}
        \frac {(2\pi)^s} {2 \cdot s!}
        \frac {\tilde B_s(\Delta)} {\zeta(s)}
        \quad \text{for even integer $s$.}

    Here :math:`F(\Delta, s) = \sum_{k=1}^\infty e^{2\pi i k\Delta} / k^s` is
    the periodic zeta function, and :math:`\tilde B_s` is the periodic
    extension (with period 1) of the Bernoulli polynomial :math:`B_s`.

    It is equivalent to fitting with a Fourier series of period 1 with
    independent priors on the coefficients with mean zero and variance
    :math:`1/(\zeta(s)k^s)` for the :math:`k`-th term. Analogously to
    :class:`Matern`, the process is :math:`\lceil\nu\rceil - 1` times
    derivable, and the highest derivative is continuous iff :math:`\nu\bmod 1
    \ge 1/2`.

    The :math:`k = 0` term is not included in the summation, so the mean of the
    process over one period is forced to be zero.

    For :math:`\nu = 0` (:math:`s = 1`) the :math:`s \to 1` limit is used:
    :math:`k(\Delta) = 1` for integer :math:`\Delta` and 0 otherwise.

    Reference: Petrillo (2022).

    """
    check_nu(nu)
    s = 1 + 2 * nu
    nupos = _special.periodic_zeta(delta, s) / _special.zeta(s)
    # For s = 1 (nu = 0) zeta(s) diverges, so take the s -> 1 limit of the
    # ratio: Re F(delta, s) = zeta(s) at integer delta (ratio 1), while it
    # stays finite elsewhere (ratio 0).
    nuzero = jnp.where(delta % 1, 0, 1)
    return jnp.where(s > 1, nupos, nuzero)

    # Closed form for even integer s (unused; `jspecial` would be
    # jax.scipy.special, which is not imported in this module):
    # return -(-1) ** (s // 2) * _special.scaled_periodic_bernoulli(s, delta) / jspecial.zeta(s, 1)
# TODO use the Bernoulli version for even integer s, based on the type of
# the input such that it's static, because it is much more accurate

# TODO ND version. The separable product is not equivalent I think.

# TODO the derivative w.r.t. nu is probably broken
@_Kernel.kernel(maxdim=1, derivable=False)
def ZetaFourier(k, q, *, nu, lloc, rloc, lscale, rscale, offset, ampl):
    """
    Covariance between Fourier-series coefficients ``k`` and ``q`` of a Zeta
    process.

    This is registered as the ``fourier`` transformation of ``Zeta`` applied
    to both arguments (see the ``make_linop_family`` call in this module).
    Coefficient index 0 is the constant term, odd indices are sine terms and
    even indices are cosine terms, with frequency ``ceil(index / 2)``.
    Coefficients of different frequency are uncorrelated. The constant term
    gets variance ``offset``. Every other term gets
    ``ampl / (freq ** s * zeta(s))`` with ``s = 1 + 2 * nu``, rotated by a
    phase that depends on ``lloc / lscale - rloc / rscale``.
    """
    check_nu(nu)
    s = 1 + 2 * nu

    # Frequency and sine/cosine flag of each index.
    left_freq = jnp.ceil(k / 2)
    right_freq = jnp.ceil(q / 2)
    left_is_sin = k % 2
    right_is_sin = q % 2

    variance = ampl / (left_freq ** s * _special.zeta(s))
    phase = 2 * jnp.pi * left_freq * (lloc / lscale - rloc / rscale)

    # Same kind (cos-cos or sin-sin): cos of the phase; frequency 0 is the
    # constant term, whose variance is the offset.
    same_kind = jnp.where(left_freq, variance * jnp.cos(phase), offset)
    # Mixed kind (sin-cos or cos-sin): sin of the phase, sign set by which
    # side holds the sine.
    mixed_kind = variance * jnp.sin(phase) * jnp.where(left_is_sin, 1, -1)

    same_freq = jnp.where(left_is_sin == right_is_sin, same_kind, mixed_kind)
    return jnp.where(left_freq == right_freq, same_freq, 0)
def crosszeta_derivable(*, nu, **_):
    """
    Derivability pair for ``CrossZetaFourier``.

    The first entry is 0, presumably for the Fourier-index argument. The
    second entry is ``zeta_derivable(nu=nu)``, presumably for the spatial
    argument -- confirm the (left, right) convention in ``_Kernel``.
    Extra keyword arguments are ignored.
    """
    spatial = zeta_derivable(nu=nu)
    return 0, spatial
@_Kernel.crosskernel(bases=(_Kernel.PreservedBySwap, _Kernel.CrossKernel), maxdim=1, derivable=crosszeta_derivable)
def CrossZetaFourier(k, y, *, nu, lloc, rloc, lscale, rscale, offset, ampl):
    """
    Cross-covariance between Fourier coefficient ``k`` and the Zeta process
    at ``y`` (presumably a point in the original input space).

    Index 0 gives the constant ``offset``. An odd index gives a sine term and
    an even index a cosine term, both with frequency ``ceil(k / 2)`` and
    amplitude ``ampl / (freq ** s * zeta(s))``, where ``s = 1 + 2 * nu``.
    The phase depends on ``lloc / lscale + (y - rloc) / rscale``.
    """
    check_nu(nu)
    s = 1 + 2 * nu

    freq = jnp.ceil(k / 2)
    is_sin = k % 2

    variance = ampl / (freq ** s * _special.zeta(s))
    phase = 2 * jnp.pi * freq * (lloc / lscale + (y - rloc) / rscale)

    # Even index: cosine term, except frequency 0 which is the constant term.
    cos_or_const = jnp.where(freq, variance * jnp.cos(phase), offset)
    return jnp.where(is_sin, variance * jnp.sin(phase), cos_or_const)
# Documentation for the ``fourier`` linear-operator family. It is passed as
# ``doc=`` to ``Zeta.make_linop_family`` in this module. The index
# convention described here (even k -> cosine, odd k -> sine) is the one
# implemented by ZetaFourier and CrossZetaFourier.
fourier_doc = r"""

Compute the Fourier series transform of the function.

.. math::

    T(f)(k) = \begin{cases}
        \frac2T \int_0^T \mathrm dx\, f(x)
        \cos\left(\frac{2\pi}T \frac k2 x\right)
        & \text{if $k$ is even} \\
        \frac2T \int_0^T \mathrm dx\, f(x)
        \sin\left(\frac{2\pi}T \frac{k+1}2 x\right)
        & \text{if $k$ is odd}
    \end{cases}

The period :math:`T` is 1.

"""
def fourier_argparser(do):
    """
    Normalize the argument given to the ``fourier`` operator.

    A truthy ``do`` is passed through unchanged; any falsy value is mapped to
    ``None`` (presumably meaning "do not apply" -- confirm in ``_Kernel``).
    """
    return do or None
def translkw(*, dynkw, **initkw):
    """
    Merge the dynamic keyword arguments with the remaining keyword arguments.

    Returns a new dict with the entries of ``dynkw`` and ``initkw``. A key
    present in both raises ``TypeError``, because ``dict()`` refuses
    duplicate keyword arguments; keys must be strings for the same reason.
    """
    merged = dict(**dynkw, **initkw)
    return merged
# Register the ``fourier`` transformation on Zeta. ZetaFourier is the kernel
# with the transformation applied to both arguments, and CrossZetaFourier is
# the cross-kernel with it applied to one argument.
Zeta.make_linop_family(
    'fourier',
    ZetaFourier,
    CrossZetaFourier,
    translkw=translkw,
    doc=fourier_doc,
    argparser=fourier_argparser,
)
# TODO
# - test the transformation with rescalings (what cross check can I do?)
# - track affine transformations in CrossZetaFourier too
# - make Zeta support non-symmetric affine ops (I think I need to define
#   CrossZeta, then subclass it as Zeta(CrossZeta, Kernel))
# - consider renaming fourier to fourier_series when I rewrite the
#   transformation system