Coverage for src/lsqfitgp/_special/_bernoulli.py: 64%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_bernoulli.py 

2# 

3# Copyright (c) 2022, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21 

22import numpy 1efabcd

23from scipy import special 1efabcd

24import jax 1efabcd

25from jax import numpy as jnp 1efabcd

26from jax.scipy import special as jspecial 1efabcd

27 

def periodic_bernoulli(n, x):
    """
    Evaluate the Bernoulli polynomial of order `n` at the fractional part
    of `x`, i.e. the 1-periodic extension of B_n.

    Parameters
    ----------
    n : int-like
        Order of the polynomial, converted with ``int``.
    x : array_like
        Evaluation points.

    Returns
    -------
    out : jax array
        B_n(x mod 1), with the same shape as `x`.
    """
    # TODO to make this jittable, hardcode size to 60 and truncate by writing
    # zeros
    order = int(n)

    # B_n(t) = sum_k binom(n, k) B_k t^(n - k): pairing the Bernoulli numbers
    # B_0..B_n with the reversed binomials gives the highest-power-first
    # coefficient list that polyval expects
    bern_numbers = special.bernoulli(order)
    binomials = special.binom(order, numpy.arange(order + 1))
    coeffs = binomials[::-1] * bern_numbers

    # reduce to [0, 1/2] through the symmetry B_n(1 - t) = (-1)^n B_n(t)
    frac = x % 1
    in_left_half = frac < 0.5
    t = jnp.where(in_left_half, frac, 1 - frac)
    values = jnp.polyval(coeffs, t)

    if order % 2 == 0:
        return values
    # odd order: undo the reflection with a sign flip on the right half
    return values * jnp.where(in_left_half, 1, -1)

43 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def scaled_periodic_bernoulli(n, x):
    """
    Periodic Bernoulli polynomial of order `n` multiplied by (2π)^n / (2 n!).

    For n >= 1 this equals -sum_{k>=1} Re(e^{2πikx} / i^n) / k^n, so the
    result stays of order 1 for every n. For even n >= 2 the value at x = 0
    is (-1)^(n/2 + 1) ζ(n), e.g. ζ(2) for n = 2 and -ζ(4) for n = 4. For
    n = 0 the result is the constant 1/2.

    `n` is a static (non-differentiable) argument. For n < 60 the polynomial
    is evaluated exactly. For n >= 60 only the leading Fourier term,
    -Re(e^{2πix} / i^n), is returned.
    """
    two_pi = 2 * jnp.pi
    use_polynomial = n < 60

    # (2π)^n / n! is formed in log space so that large n does not overflow
    scale = jnp.exp(n * jnp.log(two_pi) - jspecial.gammaln(n + 1)) / 2
    # when n >= 60 the order-1 evaluation is only a placeholder, discarded
    # by the final where
    exact = scale * periodic_bernoulli(n if use_polynomial else 1, x)

    # leading term -Re e^(2πix) / i^n, spelled out with sin/cos and an
    # explicit sign, because 1j ** n is very inaccurate
    theta = two_pi * x
    flip = jnp.where((n // 2) % 2, 1, -1)
    trig = jnp.sin(theta) if n % 2 else jnp.cos(theta)
    asymptotic = flip * trig

    return jnp.where(use_polynomial, exact, asymptotic)

58 

@scaled_periodic_bernoulli.defjvp
def _scaled_periodic_bernoulli_jvp(n, primals, tangents):
    """
    JVP rule of `scaled_periodic_bernoulli` with respect to `x`.

    Write S_n for scaled_periodic_bernoulli(n, ·). Since B_n' = n B_{n-1}
    and S_n = (2π)^n / (2 n!) B_n, the factor n is absorbed by the change
    of normalization, which gives S_n' = 2π S_{n-1}.

    S_0 is the constant 1/2, so its tangent is zero. That case is handled
    explicitly because recursing to order -1 would reach
    periodic_bernoulli(-1, x), where scipy.special.bernoulli raises.
    """
    x, = primals
    xt, = tangents
    primal = scaled_periodic_bernoulli(n, x)
    if n == 0:
        # constant function: zero tangent, same shape/dtype as the primal
        tangent = jnp.zeros_like(primal)
    else:
        tangent = 2 * jnp.pi * scaled_periodic_bernoulli(n - 1, x) * xt
    return primal, tangent