Coverage for src/lsqfitgp/_special/_zeta.py: 100%
164 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_zeta.py
2#
3# Copyright (c) 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import collections 1feabcd
21import functools 1feabcd
22import math 1feabcd
24import jax 1feabcd
25from jax import lax 1feabcd
26from jax import numpy as jnp 1feabcd
27from jax.scipy import special as jspecial 1feabcd
29from .. import _jaxext 1feabcd
30from . import _gamma 1feabcd
def hurwitz_zeta_series(m, x, a1, onlyeven=False, onlyodd=False, skipterm=None):
    """
    hurwitz zeta(s = m + x, a = 1 - a1) with integer m
    meant to be used with |x| ≤ 1/2, but no actual restriction
    assuming -S <= s <= 0 and |a1| <= 1/2 with S ~ some decade
    https://dlmf.nist.gov/25.11.E10

    Sums the Taylor series around a = 1,
        ζ(s, 1 - a1) = Σ_n (s)_n / n! ζ(n + s) a1^n,
    truncated at n = hze_nmax(dtype).

    Parameters
    ----------
    m, x, a1 : arrays
        Indexed with `[..., None]` below, so they must be arrays (not Python
        scalars). They are broadcast against the term index n.
    onlyeven, onlyodd : bool
        Keep only the terms with even (resp. odd) n. If both are true,
        onlyeven takes precedence.
    skipterm : array, optional
        If given, the term whose index n equals skipterm is set to zero.

    Returns
    -------
    The sum of the selected terms, with the broadcast shape of the inputs.
    """

    # decide number of terms to sum
    t = _jaxext.float_type(m, x, a1)
    nmax = hze_nmax(t)
    n = jnp.arange(nmax + 1)

    # make arguments broadcastable with n
    x = x[..., None]
    m = m[..., None]
    a1 = a1[..., None]
    if skipterm is not None:
        skipterm = skipterm[..., None]

    # compute pochhammer symbol, factorial and power terms
    nm = n + m
    ns1 = nm - 1 + x # = n + s - 1
    ns1_limit = jnp.where(ns1 == 0, 1, ns1) # pochhammer zero cancels zeta pole
    # The index right after the one with ns1 == 0 has ns1 == 1; zeroing it makes
    # every later cumulative product vanish, since (s)_n for larger n still
    # contains the vanishing factor.
    ns1_limit = jnp.where(ns1 == 1, 0, ns1_limit)
    # TODO this == 1 worries me, maybe sometimes it's violated, use shift
    # factor[n] = prod_{k=1..n} (k + s - 1) a1 / k = (s)_n a1^n / n!; index 0
    # is forced to 1 (it would otherwise divide by n = 0)
    factor = jnp.cumprod((ns1_limit * a1 / n).at[..., 0].set(1), -1, t)

    # handle tweaks to the series
    if onlyeven:
        sl = slice(None, None, 2)
    elif onlyodd:
        sl = slice(1, None, 2)
    if onlyeven or onlyodd:
        n = n[sl]
        nm = nm[..., sl]
        ns1 = ns1[..., sl]
        factor = factor[..., sl]
    if skipterm is not None:
        factor = jnp.where(n == skipterm, 0, factor)

    # compute zeta term
    zet = zeta(x, nm) # = zeta(n + s)
    zet_limit = jnp.where(ns1 == 0, 1, zet) # pole cancelled by pochhammer

    # sum series: batched dot product over the term axis, at full precision
    kw = dict(precision=lax.Precision.HIGHEST)
    series = jnp.matmul(factor[..., None, :], zet_limit[..., :, None], **kw)
    return series.squeeze((-2, -1))
def hze_nmax(t):
    """
    Return the index of the last term summed by hurwitz_zeta_series for
    floating point type `t`: the number of halvings needed to go below
    machine epsilon scaled by the minimum of 2 Γ(s) / (2π)^s over s <= 0.
    """
    min_ratio = 0.0037  # = min(2 gamma(s) / (2 pi)^s) for s <= 0
    target = jnp.finfo(t).eps * min_ratio
    halvings = -math.log2(target)
    return int(math.ceil(halvings))
# @jax.jit
def hurwitz_zeta(s, a):
    """
    Hurwitz zeta function ζ(s, a) for 0 <= a <= 1 and -S <= s <= 0, with S
    not too large.

    The Taylor series around a = 1 (https://dlmf.nist.gov/25.11.E10) is used.
    For a < 1/2, ζ(s, a + 1) is summed instead, because its second argument is
    closer to 1, and the term a^-s is added back
    (https://dlmf.nist.gov/25.11.E3).

    See also
    https://specialfunctions.juliamath.org/stable/functions_list/#SpecialFunctions.zeta
    """
    s = jnp.asarray(s)
    a = jnp.asarray(a)

    shifted = a < 1/2
    # series variable a1 = 1 - (second argument of ζ)
    a1 = jnp.where(shifted, -a, 1. - a)

    result = hurwitz_zeta_series(jnp.array(0), s, a1)
    result = result + jnp.where(shifted, a ** -s, 0)
    return result
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
# @functools.partial(jax.jit, static_argnums=(2,))
def periodic_zeta(x, s, imag=False):
    """
    Periodic zeta function F(x, s) = Li_s(e^{2πix}) for real s > 1, real x.

    Returns the real part of F, or the imaginary part if `imag` is true.
    Both the small-s and the large-s algorithms are evaluated, and the result
    is selected elementwise by comparing s with a threshold.
    """
    x = jnp.asarray(x)
    s = jnp.asarray(s)

    # Direct summation of `nterms` terms is used once 1/nterms^s < eps.
    eps = jnp.finfo(_jaxext.float_type(x, s)).eps
    nterms = 50
    threshold = math.ceil(-math.log(eps) / math.log(nterms))

    small_s = periodic_zeta_smalls(x, s, imag)
    large_s = periodic_zeta_larges(x, s, nterms, imag)
    return jnp.where(s < threshold, small_s, large_s)
@periodic_zeta.defjvp
def periodic_zeta_jvp(s, imag, p, t):
    """
    Forward-mode derivative of periodic_zeta with respect to x.

    d/dx Li_s(e^{2πix}) = 2πi Li_{s-1}(e^{2πix}). Hence the real part has
    derivative -2π Im Li_{s-1}, and the imaginary part has derivative
    +2π Re Li_{s-1}.
    """
    (x,) = p
    (x_dot,) = t
    value = periodic_zeta(x, s, imag)
    sign = 1 if imag else -1
    other_part = periodic_zeta(x, s - 1, not imag)
    return value, 2 * jnp.pi * sign * other_part * x_dot
def standard_x(x):
    """
    Bring x into [0, 1/2] using periodicity (mod 1) and reflection.

    Returns
    -------
    neg : boolean array
        True where x mod 1 > 1/2, i.e. where the reflection y = 1 - (x mod 1)
        was applied.
    y : array
        The reduced value in [0, 1/2].
    """
    # Use `x % 1`, not `x %= 1`: the augmented form would modify a caller's
    # NumPy array in place.
    x = x % 1
    neg = x > 1/2
    return neg, jnp.where(neg, 1 - x, x)
def periodic_zeta_larges(x, s, nmax, imag):
    """
    Real (or, if `imag`, imaginary) part of Li_s(e^{2πix}), computed by
    summing the first `nmax` terms of the Fourier series
    https://dlmf.nist.gov/25.13.E1
    """
    dtype = _jaxext.float_type(x, s)
    s = s.astype(dtype)  # with an integer s, k^s could overflow
    k = jnp.arange(1, nmax + 1)
    reflected, kx = standard_x(k * x[..., None])
    trig = jnp.sin if imag else jnp.cos
    terms = trig(2 * jnp.pi * kx) / k ** s[..., None]
    if imag:
        # sin is odd: undo the x -> 1 - x reflection done by standard_x
        terms = terms * jnp.where(reflected, -1, 1)
    return terms.sum(axis=-1)
def periodic_zeta_smalls(x, s, imag):
    """
    https://dlmf.nist.gov/25.11.E10 and https://dlmf.nist.gov/25.11.E3 expanded
    into https://dlmf.nist.gov/25.13.E2

    Real (or, if `imag`, imaginary) part of Li_s(e^{2πix}), via
        (2π)^-s1 Γ(s1) cos/sin(π/2 s1) [ζ(s1, x) ± ζ(s1, 1 - x)],  s1 = 1 - s.
    periodic_zeta uses this for s below its large-s threshold. The comments
    below assume s > 1, so that s1 < 0.
    """
    neg, x = standard_x(x) # x in [0, 1/2]

    # Nudge integer s by a relative eps so that no exactly-integer s reaches
    # the gamma and series computations below.
    eps = jnp.finfo(_jaxext.float_type(x, s)).eps
    s = jnp.where(s % 1, s, s * (1 + eps)) # avoid integer s

    s1 = 1 - s # < 0
    q = -jnp.around(s1).astype(int)
    a = s1 + q
    # now s1 == -q + a with q integer >= 0 and |a| <= 1/2

    pi = (2 * jnp.pi) ** -s1
    gam = _gamma.gamma(s1)
    func = sin_pi2 if imag else cos_pi2
    pha = func(-q, a) # = sin or cos(π/2 s1), numerically accurate for small a
    # With a1 = -x the series is ζ(s1, 1 + x). Doubling only its even terms
    # gives ζ(s1,1+x) + ζ(s1,1-x) (real part); doubling only its odd terms gives
    # ζ(s1,1+x) - ζ(s1,1-x) (imaginary part).
    hzs = 2 * hurwitz_zeta_series(-q, a, -x, onlyeven=not imag, onlyodd=imag, skipterm=q)
    # hzs = ζ(s1,1+x) -/+ ζ(s1,1-x) but without the x^q term in the series
    pdiff = zeta_series_power_diff(x, q, a)
    # pdiff accurately handles the sum of the external power x^-s1 due to
    # 25.11.E3 with the q-th term (cancellation) with even q for the real part
    # and odd q for the imaginary part
    cancelcond = jnp.logical_and(imag, q % 2 == 1)
    cancelcond |= jnp.logical_and(not imag, q % 2 == 0)
    # Where the q-th term was skipped but does not belong to this parity, only
    # the plain external power x^-s1 is added.
    power = jnp.where(cancelcond, pdiff, x ** -s1)
    hz = power + hzs # = ζ(s1,x) -/+ ζ(s1,1-x)

    out = (pi * gam * pha) * hz
    if imag:
        # the imaginary part is odd in x: undo the reflection of standard_x
        out *= jnp.where(neg, -1, 1)
    return out
def cos_pi2(n, x):
    """
    Compute cos(π/2 (n + x)) for integer n, accurately also for small x.

    Uses cos(π/2 (n + x)) = (-1)^(n // 2) * (cos(π/2 x) if n is even else
    -sin(π/2 x)). This avoids forming π/2 n + π/2 x and losing the small x.
    """
    half_angle = -jnp.pi / 2 * x
    odd = n % 2
    base = jnp.where(odd, jnp.sin(half_angle), jnp.cos(half_angle))
    sign = jnp.where(n // 2 % 2, -1, 1)
    return base * sign
def sin_pi2(n, x):
    """
    Compute sin(π/2 (n + x)) for integer n, accurately also for small x,
    using sin(θ) = cos(θ - π/2), i.e. shifting n down by one.
    """
    shifted_n = n - 1
    return cos_pi2(shifted_n, x)
def zeta_series_power_diff(x, q, a):
    """
    compute x^q-a + (-1)^q * [q-th term of 2 * hurwitz_zeta_series(-q, a, x)]

    periodic_zeta_smalls uses this to combine the external power x^-s1 =
    x^(q-a) with the q-th series term, which it skips via `skipterm`. At
    a = 0 the two contributions cancel. The result is therefore computed as
        x^q * [(x^-a - 1) + ((q-th term) - (q-th term)|_a=0)]
    so the cancellation happens analytically rather than in floating point.
    """
    pint = x ** q
    # Value of x^-a - 1 at x == 0. NOTE(review): for q == 0 this assumes a < 0
    # (true when s > 1), where 0^-a == 0; confirm for other callers.
    pz = jnp.where(q, 0, jnp.where(a, -1, 0)) # * 0^q = 0^q-a - 0^q
    pdif = jnp.where(x, jnp.expm1(-a * jnp.log(x)), pz) # * x^q = x^q-a - x^q
    gamincr = jnp.where(q, _gamma.gamma_incr(1 + q, -a), 0)
    # gamincr = Γ(1+q-a) / Γ(1+q)Γ(1-a) - 1
    zz = zeta_zero(a) # = ζ(a) - ζ(0)
    qdif = 2 * (1 + gamincr) * zz - gamincr # = (q-th term) - (q-th term)|_a=0
    return pint * (pdif + qdif)
def zeta_zero(s):
    """
    Compute ζ(s) - ζ(0) accurately for |s| < 1.

    Let f(s) = ζ(s) - 1/(s - 1). Its Taylor coefficients at 0 are derived from
    zeta_zero_coef (derivatives divided by k!), with f(0) = 1/2. Since
    ζ(0) = -1/2,
        ζ(s) - ζ(0) = f(s) + 1/(s - 1) + 1/2 = [f(s) - 1/2] + s/(s - 1).
    """
    dtype = _jaxext.float_type(s)
    # drop the constant term f(0) = 1/2, which is absorbed into s/(s - 1)
    derivs = jnp.array(zeta_zero_coef, dtype).at[0].set(0)
    orders = jnp.arange(derivs.size).at[0].set(1)
    factorials = jnp.cumprod(orders, dtype=dtype)  # 0!, 1!, 2!, ...
    taylor = derivs / factorials
    f_minus_half = jnp.polyval(taylor[::-1], s)
    return f_minus_half + s / (s - 1)
# Derivatives of order k = 0, ..., 16 at s = 0 of f(s) = zeta(s) - 1/(s - 1).
# zeta_zero divides them by k! to get the Taylor coefficients of f around 0.
zeta_zero_coef = [ # = gen_zeta_zero_coef(17)
    0.5,
    0.08106146679532726,
    -0.006356455908584851,
    -0.004711166862254448,
    0.002896811986292041,
    -0.00023290755845472455,
    -0.0009368251300509295,
    0.0008498237650016692,
    -0.00023243173551155957,
    -0.00033058966361229646,
    0.0005432341157797085,
    -0.00037549317290726367,
    -1.960353628101392e-05,
    0.00040724123256303315,
    -0.0005704920132817777,
    0.0003939270789812044,
    8.345880582550168e-05,
]
def gen_zeta_zero_coef(n): # pragma: no cover
    """
    Compute the derivatives of orders 0, 1, ..., n - 1 (the value plus the
    first n - 1 derivatives) of zeta(s) - 1/(s-1) at s = 0.

    Uses mpmath at 32 decimal digits and converts each result to float. This
    is used offline to regenerate the zeta_zero_coef table.
    """
    import mpmath as mp
    with mp.workdps(32):
        func = lambda s: mp.zeta(s) - 1 / (s - 1)
        return [float(mp.diff(func, 0, k)) for k in range(n)]
# @jax.jit
def zeta(s, n=0):
    """
    Compute the Riemann zeta ζ(n + s) with integer n; accurate for even n < 0
    and small s.

    Arguments n + s >= 0 go to zeta_0_inf, and negative ones go to the
    reflection formula in zeta_neg. Both are evaluated, and the result is
    selected elementwise.
    """
    s = jnp.asarray(s)
    total = n + s
    direct = zeta_0_inf(total)
    reflected = zeta_neg(s, n)
    return jnp.where(total >= 0, direct, reflected)
def zeta_neg(s, n):
    """
    Compute ζ(n + s) for n + s < 0 with the reflection formula
    https://dlmf.nist.gov/25.4.E1,
        ζ(1 - z) = 2 (2π)^-z cos(π/2 z) Γ(z) ζ(z),  z = 1 - n - s.
    """
    m = 1 - n
    x = -s
    z = m + x  # = 1 - (n + s), > 1 for n + s < 0
    log_power = -z * jnp.log(2 * jnp.pi)
    log_gamma = jspecial.gammaln(z)
    cosine = cos_pi2(m, x)  # cos(π/2 z), accurate for small x
    zeta_z = zeta_0_inf(z)

    # At z == 1 the zero of the cosine cancels the pole of ζ(z). Replace them
    # by their limiting coefficients: cos(π/2 z) ~ -π/2 (z - 1) and
    # ζ(z) ~ 1/(z - 1).
    at_pole = z == 1
    cosine = jnp.where(at_pole, -jnp.pi / 2, cosine)
    zeta_z = jnp.where(at_pole, 1, zeta_z)

    return 2 * jnp.exp(log_power + log_gamma) * cosine * zeta_z
282# Below I have my custom implementation of zeta. jax.scipy.special.zeta does not
283# work in (0, 1) (last checked v0.4.34), but I don't remember if I actually
284# need that interval, maybe I always use s > 1.
286##########################################################################
287# The following is adapted from gsl/specfunc/zeta.c (GPL license) #
288##########################################################################
# A Chebyshev series: coefficients `c` for the interval [a, b], as evaluated
# by cheb_eval_e (where c[0] enters with weight 1/2).
ChebSeries = collections.namedtuple('ChebSeries', 'c a b')
def cheb_eval_e(cs, x):
    """
    Evaluate the Chebyshev series `cs` at x with Clenshaw's recurrence. This
    ports GSL's cheb_eval_e without the error estimate.

    The series is c[0]/2 + sum_{k>=1} c[k] T_k(y), where y maps x affinely
    from [cs.a, cs.b] onto [-1, 1].
    """
    y = (2.0 * x - cs.a - cs.b) / (cs.b - cs.a)
    twice_y = 2.0 * y
    b_curr = 0.0
    b_next = 0.0
    for coef in cs.c[:0:-1]:  # c[N], c[N-1], ..., c[1]
        b_curr, b_next = twice_y * b_curr - b_next + coef, b_curr
    return y * b_curr - b_next + 0.5 * cs.c[0]
# chebyshev fit for (s(t)-1)Zeta[s(t)]
# s(t)= (t+1)/2
# -1 <= t <= 1
# i.e. 0 <= s <= 1; used by zeta_0_1
zeta_xlt1_cs = ChebSeries(jnp.array([
    1.48018677156931561235192914649,
    0.25012062539889426471999938167,
    0.00991137502135360774243761467,
    -0.00012084759656676410329833091,
    -4.7585866367662556504652535281e-06,
    2.2229946694466391855561441361e-07,
    -2.2237496498030257121309056582e-09,
    -1.0173226513229028319420799028e-10,
    4.3756643450424558284466248449e-12,
    -6.2229632593100551465504090814e-14,
    -6.6116201003272207115277520305e-16,
    4.9477279533373912324518463830e-17,
    -1.0429819093456189719660003522e-18,
    6.9925216166580021051464412040e-21,
]), -1, 1)
# chebyshev fit for (s(t)-1)Zeta[s(t)]
# s(t)= (19t+21)/2
# -1 <= t <= 1
# i.e. 1 <= s <= 20; used by zeta_1_20
zeta_xgt1_cs = ChebSeries(jnp.array([
    19.3918515726724119415911269006,
    9.1525329692510756181581271500,
    0.2427897658867379985365270155,
    -0.1339000688262027338316641329,
    0.0577827064065028595578410202,
    -0.0187625983754002298566409700,
    0.0039403014258320354840823803,
    -0.0000581508273158127963598882,
    -0.0003756148907214820704594549,
    0.0001892530548109214349092999,
    -0.0000549032199695513496115090,
    8.7086484008939038610413331863e-6,
    6.4609477924811889068410083425e-7,
    -9.6749773915059089205835337136e-7,
    3.6585400766767257736982342461e-7,
    -8.4592516427275164351876072573e-8,
    9.9956786144497936572288988883e-9,
    1.4260036420951118112457144842e-9,
    -1.1761968823382879195380320948e-9,
    3.7114575899785204664648987295e-10,
    -7.4756855194210961661210215325e-11,
    7.8536934209183700456512982968e-12,
    9.9827182259685539619810406271e-13,
    -7.5276687030192221587850302453e-13,
    2.1955026393964279988917878654e-13,
    -4.1934859852834647427576319246e-14,
    4.6341149635933550715779074274e-15,
    2.3742488509048340106830309402e-16,
    -2.7276516388124786119323824391e-16,
    7.8473570134636044722154797225e-17
]), -1, 1)
def zeta_0_1(s):
    """
    Compute the Riemann zeta ζ(s) from the Chebyshev fit of (s - 1) ζ(s) on
    0 <= s <= 1.
    """
    t = 2.0 * s - 1.0  # maps s in [0, 1] onto t in [-1, 1]
    scaled = cheb_eval_e(zeta_xlt1_cs, t)
    return scaled / (s - 1.0)
def zeta_1_20(s):
    """
    Compute the Riemann zeta ζ(s) from the Chebyshev fit of (s - 1) ζ(s) on
    1 <= s <= 20.
    """
    t = (2.0 * s - 21.0) / 19.0  # maps s in [1, 20] onto t in [-1, 1]
    scaled = cheb_eval_e(zeta_xgt1_cs, t)
    return scaled / (s - 1.0)
def zeta_20_inf(s):
    """
    Compute the Riemann zeta ζ(s) for large s (used for s >= 20) with the
    Euler product truncated at the prime 7:
        ζ(s) ≈ 1 / prod_{p in 2, 3, 5, 7} (1 - p^-s).
    The first omitted factor differs from 1 by 11^-s, which is about 1e-21
    at s = 20.
    """
    f2 = 1.0 - 2.0 ** -s
    f3 = 1.0 - 3.0 ** -s
    f5 = 1.0 - 5.0 ** -s
    f7 = 1.0 - 7.0 ** -s
    return 1.0 / (f2 * f3 * f5 * f7)
def zeta_0_inf(s):
    """
    Compute the Riemann zeta ζ(s) for s >= 0, dispatching elementwise:
    s >= 20 uses the Euler product, 1 <= s < 20 and s < 1 use Chebyshev fits.
    All three are evaluated; jnp.where selects the result.
    """
    large = zeta_20_inf(s)
    medium = zeta_1_20(s)
    small = zeta_0_1(s)
    below_20 = jnp.where(s >= 1, medium, small)
    return jnp.where(s >= 20, large, below_20)
381##########################################################################