Coverage for src/lsqfitgp/_special/_bernoulli.py: 64%
37 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_bernoulli.py
2#
3# Copyright (c) 2022, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1efabcd
22import numpy 1efabcd
23from scipy import special 1efabcd
24import jax 1efabcd
25from jax import numpy as jnp 1efabcd
26from jax.scipy import special as jspecial 1efabcd
def periodic_bernoulli(n, x):
    """
    Evaluate the periodic Bernoulli polynomial of order ``n`` at ``x``.

    The periodic version is ``B_n(x - floor(x))``, i.e. the ordinary
    Bernoulli polynomial applied to the fractional part of ``x``.

    Parameters
    ----------
    n : int-like
        Order of the polynomial. It is converted with ``int``, so it must be
        a concrete (non-traced) value; ``scipy.special.bernoulli`` rejects
        negative orders.
    x : array_like
        Evaluation points.

    Returns
    -------
    jax array with the broadcast shape of ``x``.
    """
    # TODO to make this jittable, hardcode size to 60 and truncate by writing
    # zeros
    order = int(n)
    degrees = numpy.arange(order + 1)
    # B_n(t) = sum_k binom(n, k) B_k t^(n - k); polyval wants the coefficient
    # of the highest power first, which is the k = 0 term, so index k holds
    # binom(n, k) * B_k (binom is symmetric, the reversal is only formal)
    poly = special.binom(order, degrees[::-1]) * special.bernoulli(order)
    frac = x % 1
    lower_half = frac < 0.5
    # Reflection identity B_n(1 - t) = (-1)^n B_n(t): the polynomial is only
    # ever evaluated on [0, 1/2], and the sign is restored below for odd n.
    reflected = jnp.where(lower_half, frac, 1 - frac)
    value = jnp.polyval(poly, reflected)
    if order % 2:
        value = value * jnp.where(lower_half, 1, -1)
    return value
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def scaled_periodic_bernoulli(n, x):
    """
    Periodic Bernoulli polynomial multiplied by (2π)^n / (2 n!).

    With this normalization the Fourier series of the function is
    ``-sum_{k>=1} cos(2πkx - nπ/2) / k^n`` (n >= 1). Consequently, for even
    ``n >= 2`` the value at ``x = 0`` is ``-cos(nπ/2) ζ(n)``, i.e.
    ``(-1)^(n/2 + 1) ζ(n)``: the magnitude is ζ(n), the sign alternates.
    As ``n -> ∞`` the function tends to the leading Fourier term
    ``-Re e^(2πix) / i^n``.

    Parameters
    ----------
    n : int
        Order. It is used in Python control flow, so it must be concrete
        (not traced).
    x : array_like
        Evaluation points; the function has period 1.

    Returns
    -------
    jax array with the broadcast shape of ``x``.

    For ``n < 60`` the exact (scaled) polynomial is evaluated; for larger
    ``n`` only the leading Fourier term is used.
    """
    tau = 2 * jnp.pi
    if n < 60:
        # computed in log space to keep (2π)^n / n! well conditioned
        lognorm = n * jnp.log(tau) - jspecial.gammaln(n + 1)
        norm = jnp.exp(lognorm) / 2
        return norm * periodic_bernoulli(n, x)
    # n -> ∞: -Re e^2πix / i^n
    # don't use 1j ** n because it is very inaccurate; instead pick the
    # trigonometric function and sign from n mod 4 with integer arithmetic
    arg = tau * x
    sign = 1 if (n // 2) % 2 else -1
    trig = jnp.sin if n % 2 else jnp.cos
    return sign * trig(arg)
@scaled_periodic_bernoulli.defjvp
def _scaled_periodic_bernoulli_jvp(n, primals, tangents):
    """
    JVP rule for ``scaled_periodic_bernoulli`` with respect to ``x``.

    If S_n denotes the scaled function, d/dx B_n = n B_{n-1} together with the
    (2π)^n / (2 n!) scaling gives d/dx S_n(x) = 2π S_{n-1}(x). The order-0
    function is constant, so its derivative is zero; it is handled
    explicitly because order -1 does not exist (``scipy.special.bernoulli``
    raises on negative orders).
    """
    x, = primals
    xt, = tangents
    primal = scaled_periodic_bernoulli(n, x)
    if n == 0:
        tangent = jnp.zeros_like(primal)
    else:
        # recursing through the custom_jvp function keeps higher-order
        # derivatives available
        tangent = 2 * jnp.pi * scaled_periodic_bernoulli(n - 1, x) * xt
    return primal, tangent