Coverage for src/lsqfitgp/_special/_zeta.py: 100%

164 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_zeta.py 

2# 

3# Copyright (c) 2022, 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import collections 1feabcd

21import functools 1feabcd

22import math 1feabcd

23 

24import jax 1feabcd

25from jax import lax 1feabcd

26from jax import numpy as jnp 1feabcd

27from jax.scipy import special as jspecial 1feabcd

28 

29from .. import _jaxext 1feabcd

30from . import _gamma 1feabcd

31 

def hurwitz_zeta_series(m, x, a1, onlyeven=False, onlyodd=False, skipterm=None):
    """
    hurwitz zeta(s = m + x, a = 1 - a1) with integer m
    meant to be used with |x| ≤ 1/2, but no actual restriction
    assuming -S <= s <= 0 and |a1| <= 1/2 with S ~ some decade
    https://dlmf.nist.gov/25.11.E10

    Parameters
    ----------
    m, x, a1 : arrays
        Integer part and remainder of s, and the shift such that a = 1 - a1.
        They are indexed with ``[..., None]``, so they must be arrays (0-d is
        fine), not Python scalars.
    onlyeven, onlyodd : bool
        Keep only the even or only the odd terms of the series. If both are
        true, `onlyeven` wins.
    skipterm : array, optional
        Index of a series term to drop; terms with ``n == skipterm`` get a
        zero coefficient.

    Returns
    -------
    The (possibly tweaked) series summed over n, with the broadcast shape of
    the arguments.
    """

    # decide number of terms to sum
    t = _jaxext.float_type(m, x, a1)
    nmax = hze_nmax(t)
    n = jnp.arange(nmax + 1)

    # make arguments broadcastable with n
    x = x[..., None]
    m = m[..., None]
    a1 = a1[..., None]
    if skipterm is not None:
        skipterm = skipterm[..., None]

    # compute pochhammer symbol, factorial and power terms
    nm = n + m
    ns1 = nm - 1 + x  # = n + s - 1
    ns1_limit = jnp.where(ns1 == 0, 1, ns1)  # pochhammer zero cancels zeta pole
    ns1_limit = jnp.where(ns1 == 1, 0, ns1_limit)
    # TODO this == 1 worries me, maybe sometimes it's violated, use shift
    # the n = 0 ratio divides by zero, so it is overwritten with 1 before the
    # running product is taken
    ratios = (ns1_limit * a1 / n).at[..., 0].set(1)
    factor = jnp.cumprod(ratios, axis=-1, dtype=t)

    # handle tweaks to the series
    if onlyeven or onlyodd:
        sl = slice(None, None, 2) if onlyeven else slice(1, None, 2)
        n = n[sl]
        nm = nm[..., sl]
        ns1 = ns1[..., sl]
        factor = factor[..., sl]
    if skipterm is not None:
        factor = jnp.where(n == skipterm, 0, factor)

    # compute zeta term
    zet = zeta(x, nm)  # = zeta(n + s)
    zet_limit = jnp.where(ns1 == 0, 1, zet)  # pole cancelled by pochhammer

    # sum series as a (1, N) @ (N, 1) product at the highest matmul precision
    kw = dict(precision=lax.Precision.HIGHEST)
    series = jnp.matmul(factor[..., None, :], zet_limit[..., :, None], **kw)
    return series.squeeze((-2, -1))

81 

def hze_nmax(t):
    """
    Index of the last term summed by `hurwitz_zeta_series` for float type `t`.

    Returns the integer ceil(-log2(eps * minz)), where eps is the machine
    epsilon of `t`.
    """
    minz = 0.0037  # = min(2 gamma(s) / (2 pi)^s) for s <= 0
    # math.ceil already returns an int, no extra cast needed
    return math.ceil(-math.log2(jnp.finfo(t).eps * minz))

85 

86# @jax.jit 

def hurwitz_zeta(s, a):
    """
    Hurwitz zeta function zeta(s, a).

    For 0 <= a <= 1 and -S <= s <= 0 with S not too large

    `s` and `a` are converted to arrays with `jnp.asarray`.
    """
    s = jnp.asarray(s)
    a = jnp.asarray(a)

    cond = a < 1/2  # do a + 1 to bring a closer to 1
    a1 = jnp.where(cond, -a, 1. - a)
    zero = jnp.array(0)
    # named `out` rather than `zeta` to avoid shadowing the module-level zeta()
    out = hurwitz_zeta_series(zero, s, a1)  # https://dlmf.nist.gov/25.11.E10
    out += jnp.where(cond, a ** -s, 0)  # https://dlmf.nist.gov/25.11.E3
    return out

100 

101 # https://specialfunctions.juliamath.org/stable/functions_list/#SpecialFunctions.zeta 

102 

@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
# @functools.partial(jax.jit, static_argnums=(2,))
def periodic_zeta(x, s, imag=False):
    """
    compute F(x,s) = Li_s(e^2πix) for real s > 1, real x

    Parameters
    ----------
    x, s : array-like
        Converted with `jnp.asarray`. Only `x` is differentiable; `s` and
        `imag` are declared non-differentiable in the custom JVP.
    imag : bool
        Forwarded to the two implementations; the large-s one sums sines if
        true and cosines otherwise.

    Returns
    -------
    Elementwise, the small-s implementation where ``s < larges`` and the
    truncated series elsewhere. Both are evaluated on all the inputs.
    """

    x = jnp.asarray(x)
    s = jnp.asarray(s)

    # decide boundary for large/small s implementation
    t = _jaxext.float_type(x, s)
    eps = jnp.finfo(t).eps
    nmax = 50
    larges = math.ceil(-math.log(eps) / math.log(nmax))  # 1/nmax^s < eps

    z_smalls = periodic_zeta_smalls(x, s, imag)
    z_larges = periodic_zeta_larges(x, s, nmax, imag)

    return jnp.where(s < larges, z_smalls, z_larges)

123 

@periodic_zeta.defjvp
def periodic_zeta_jvp(s, imag, p, t):
    """
    JVP rule of `periodic_zeta` with respect to x.

    The non-differentiable arguments `s` and `imag` come first, then the
    1-tuples of primals `p` and tangents `t`. The derivative is expressed
    through `periodic_zeta` at s - 1 with the real/imaginary part swapped,
    times 2π and a sign (+1 if `imag`, -1 otherwise).
    """
    x, = p
    xt, = t
    primal = periodic_zeta(x, s, imag)
    sgn = 1 if imag else -1
    tangent = 2 * jnp.pi * sgn * periodic_zeta(x, s - 1, not imag) * xt
    return primal, tangent

132 

def standard_x(x):
    """
    bring x in [0, 1/2] by modulus and reflection

    Returns
    -------
    neg : boolean mask, true where x mod 1 was above 1/2 and got reflected
    x : the reduced value, 1 - (x mod 1) where `neg`, else x mod 1
    """
    # out-of-place modulus: `x %= 1` would overwrite the caller's buffer when
    # x is a mutable array such as a numpy.ndarray
    x = x % 1
    neg = x > 1/2
    return neg, jnp.where(neg, 1 - x, x)

138 

def periodic_zeta_larges(x, s, nmax, imag):
    """
    https://dlmf.nist.gov/25.13.E1

    Sum the first `nmax` terms sin(2πnx)/n^s (if `imag`) or cos(2πnx)/n^s
    (otherwise), n = 1, ..., nmax. `x` and `s` must be arrays.
    """

    t = _jaxext.float_type(x, s)
    s = s.astype(t)  # avoid n^s overflow with integer s
    n = jnp.arange(1, nmax + 1)
    # reduce n x to [0, 1/2] before evaluating the trigonometric function
    neg, nx = standard_x(n * x[..., None])
    func = jnp.sin if imag else jnp.cos
    terms = func(2 * jnp.pi * nx) / n ** s[..., None]
    if imag:
        # the sine is odd under the reflection done by standard_x
        terms *= jnp.where(neg, -1, 1)
    return jnp.sum(terms, -1)

151 

def periodic_zeta_smalls(x, s, imag):
    """
    https://dlmf.nist.gov/25.11.E10 and https://dlmf.nist.gov/25.11.E3 expanded
    into https://dlmf.nist.gov/25.13.E2

    Small-s implementation of `periodic_zeta`; `imag` selects the imaginary
    (true) or real (false) part.
    """
    neg, x = standard_x(x)  # x in [0, 1/2]

    eps = jnp.finfo(_jaxext.float_type(x, s)).eps
    s = jnp.where(s % 1, s, s * (1 + eps))  # avoid integer s

    s1 = 1 - s  # < 0
    q = -jnp.around(s1).astype(int)
    a = s1 + q
    # now s1 == -q + a with q integer >= 0 and |a| <= 1/2

    pipow = (2 * jnp.pi) ** -s1
    gam = _gamma.gamma(s1)
    func = sin_pi2 if imag else cos_pi2
    pha = func(-q, a)  # = sin or cos(π/2 s1), numerically accurate for small a
    hzs = 2 * hurwitz_zeta_series(-q, a, -x, onlyeven=not imag, onlyodd=imag, skipterm=q)
    # hzs = ζ(s1,1+x) -/+ ζ(s1,1-x) but without the x^q term in the series
    pdiff = zeta_series_power_diff(x, q, a)
    # pdiff accurately handles the sum of the external power x^-s1 due to
    # 25.11.E3 with the q-th term (cancellation) with even q for the real part
    # and odd q for the imaginary part
    cancelcond = jnp.logical_and(imag, q % 2 == 1)
    cancelcond |= jnp.logical_and(not imag, q % 2 == 0)
    power = jnp.where(cancelcond, pdiff, x ** -s1)
    hz = power + hzs  # = ζ(s1,x) -/+ ζ(s1,1-x)

    out = (pipow * gam * pha) * hz
    if imag:
        # the imaginary part is odd under the reflection done by standard_x
        out *= jnp.where(neg, -1, 1)
    return out

186 

def cos_pi2(n, x):
    """
    compute cos(π/2 (n + x)) for n integer, accurate for small x

    The integer part only selects between ±cos(π/2 x) and ±sin(π/2 x), so the
    trigonometric function is always evaluated at the small argument.
    """
    arg = -jnp.pi / 2 * x
    # odd n turns the cosine into a sine of the (negated) remainder
    cos = jnp.where(n % 2, jnp.sin(arg), jnp.cos(arg))
    # overall sign flips every two units of n
    return cos * jnp.where(n // 2 % 2, -1, 1)

192 

def sin_pi2(n, x):
    """ compute sin(π/2 (n + x)) for n integer, as cos(π/2 (n - 1 + x)) """
    return cos_pi2(n - 1, x)

195 

def zeta_series_power_diff(x, q, a):
    """
    compute x^q-a + (-1)^q * [q-th term of 2 * hurwitz_zeta_series(-q, a, x)]

    The result is assembled as x^q times a sum of small differences, each
    computed with a dedicated helper (expm1, gamma_incr, zeta_zero).
    """
    pint = x ** q
    pz = jnp.where(q, 0, jnp.where(a, -1, 0))  # * 0^q = 0^q-a - 0^q
    # pz replaces the expm1 expression at x = 0, where log(x) is not finite
    pdif = jnp.where(x, jnp.expm1(-a * jnp.log(x)), pz)  # * x^q = x^q-a - x^q
    gamincr = jnp.where(q, _gamma.gamma_incr(1 + q, -a), 0)
    # gamincr = Γ(1+q-a) / Γ(1+q)Γ(1-a) - 1
    zz = zeta_zero(a)  # = ζ(a) - ζ(0)
    qdif = 2 * (1 + gamincr) * zz - gamincr  # = (q-th term) - (q-th term)|_a=0
    return pint * (pdif + qdif)

208 

def zeta_zero(s):
    """
    Compute zeta(s) - zeta(0) for |s| < 1 accurately
    """

    # f(s) = zeta(s) - 1 / (s - 1)
    # I have the Taylor series of f(s)
    # zeta(s) - zeta(0) = f(s) + 1 / (s - 1) + 1/2 =
    #                   = f(s) + 1/(s-1) + 1 - 1 + 1/2 =
    #                   = f(s) - 1/2 + s/(s-1)

    t = _jaxext.float_type(s)
    # zeroing the constant term f(0) implements the "f(s) - 1/2" above
    coef = jnp.array(zeta_zero_coef, t).at[0].set(0)
    # factorials 0!, 1!, 2!, ... accumulated directly in floating point
    fact = jnp.cumprod(jnp.arange(coef.size).at[0].set(1), dtype=t)
    coef /= fact
    # polyval wants the highest degree first
    f = jnp.polyval(coef[::-1], s)
    return f + s / (s - 1)

226 

# Derivatives of zeta(s) - 1/(s-1) at s = 0, orders 0 to 16; used by zeta_zero
zeta_zero_coef = [ # = gen_zeta_zero_coef(17)
    0.5,
    0.08106146679532726,
    -0.006356455908584851,
    -0.004711166862254448,
    0.002896811986292041,
    -0.00023290755845472455,
    -0.0009368251300509295,
    0.0008498237650016692,
    -0.00023243173551155957,
    -0.00033058966361229646,
    0.0005432341157797085,
    -0.00037549317290726367,
    -1.960353628101392e-05,
    0.00040724123256303315,
    -0.0005704920132817777,
    0.0003939270789812044,
    8.345880582550168e-05,
]

246 

def gen_zeta_zero_coef(n): # pragma: no cover
    """
    Compute first n derivatives of zeta(s) - 1/(s-1) at s = 0

    Returns a list of n floats, orders 0 to n - 1, computed with mpmath at 32
    decimal digits of working precision. mpmath is imported only here.
    """
    import mpmath as mp
    with mp.workdps(32):
        func = lambda s: mp.zeta(s) - 1 / (s - 1)
        return [float(mp.diff(func, 0, k)) for k in range(n)]

255 

256# @jax.jit 

def zeta(s, n=0):
    """
    compute ζ(n + s) with integer n, accurate for even n < 0 and small s

    Uses `zeta_0_inf` where n + s >= 0 and `zeta_neg` elsewhere; both are
    evaluated on all the inputs and selected elementwise.
    """
    s = jnp.asarray(s)
    return jnp.where(n + s >= 0,
        zeta_0_inf(n + s),
        zeta_neg(s, n),
    )

264 

def zeta_neg(s, n):
    """
    compute ζ(n + s) for n + s < 0 with the reflection formula

    `n` is the integer part and `s` the remainder, kept separate so that the
    cosine factor can be evaluated accurately by `cos_pi2`.
    """
    # reflection formula https://dlmf.nist.gov/25.4.E1
    m = 1 - n
    x = -s
    # m + x = 1 - (n + s) = 1 - n - s
    mx = m + x  # > 1
    logpi = -mx * jnp.log(2 * jnp.pi)
    cos = cos_pi2(m, x)  # = cos(π/2 (m + x)) but accurate for small x
    loggam = jspecial.gammaln(mx)
    zet = zeta_0_inf(mx)

    # cancel zeta pole at 1
    cos = jnp.where(mx == 1, -jnp.pi / 2, cos)
    zet = jnp.where(mx == 1, 1, zet)

    # power of 2π and gamma are combined in log space
    return 2 * jnp.exp(logpi + loggam) * cos * zet

281 

282# Below I have my custom implementation of zeta. jax.scipy.special.zeta does not 

283# work in (0, 1) (last checked v0.4.34), but I don't remember if I actually 

284# need that interval, maybe I always use s > 1. 

285 

286########################################################################## 

287# The following is adapted from gsl/specfunc/zeta.c (GPL license) # 

288########################################################################## 

289 

# Chebyshev series: coefficients c on the interval [a, b], see cheb_eval_e
ChebSeries = collections.namedtuple('ChebSeries', 'c a b')

291 

def cheb_eval_e(cs, x):
    """
    Evaluate the Chebyshev series `cs` at `x`.

    `cs` has fields c (coefficients), a and b (interval ends). `x` is mapped
    linearly from [a, b] to [-1, 1], then the series is summed with a backward
    recurrence over the coefficients; the first coefficient enters halved.
    """
    d = 0.0
    dd = 0.0
    y = (2.0 * x - cs.a - cs.b) / (cs.b - cs.a)
    y2 = 2.0 * y

    # coefficients from the last down to index 1
    for c in cs.c[:0:-1]:
        d, dd = y2 * d - dd + c, d

    return y * d - dd + 0.5 * cs.c[0]

302 

# chebyshev fit for (s(t)-1)Zeta[s(t)]
# s(t)= (t+1)/2
# -1 <= t <= 1
zeta_xlt1_cs = ChebSeries(jnp.array([
    1.48018677156931561235192914649,
    0.25012062539889426471999938167,
    0.00991137502135360774243761467,
    -0.00012084759656676410329833091,
    -4.7585866367662556504652535281e-06,
    2.2229946694466391855561441361e-07,
    -2.2237496498030257121309056582e-09,
    -1.0173226513229028319420799028e-10,
    4.3756643450424558284466248449e-12,
    -6.2229632593100551465504090814e-14,
    -6.6116201003272207115277520305e-16,
    4.9477279533373912324518463830e-17,
    -1.0429819093456189719660003522e-18,
    6.9925216166580021051464412040e-21,
]), -1, 1)

322 

# chebyshev fit for (s(t)-1)Zeta[s(t)]
# s(t)= (19t+21)/2
# -1 <= t <= 1
zeta_xgt1_cs = ChebSeries(jnp.array([
    19.3918515726724119415911269006,
    9.1525329692510756181581271500,
    0.2427897658867379985365270155,
    -0.1339000688262027338316641329,
    0.0577827064065028595578410202,
    -0.0187625983754002298566409700,
    0.0039403014258320354840823803,
    -0.0000581508273158127963598882,
    -0.0003756148907214820704594549,
    0.0001892530548109214349092999,
    -0.0000549032199695513496115090,
    8.7086484008939038610413331863e-6,
    6.4609477924811889068410083425e-7,
    -9.6749773915059089205835337136e-7,
    3.6585400766767257736982342461e-7,
    -8.4592516427275164351876072573e-8,
    9.9956786144497936572288988883e-9,
    1.4260036420951118112457144842e-9,
    -1.1761968823382879195380320948e-9,
    3.7114575899785204664648987295e-10,
    -7.4756855194210961661210215325e-11,
    7.8536934209183700456512982968e-12,
    9.9827182259685539619810406271e-13,
    -7.5276687030192221587850302453e-13,
    2.1955026393964279988917878654e-13,
    -4.1934859852834647427576319246e-14,
    4.6341149635933550715779074274e-15,
    2.3742488509048340106830309402e-16,
    -2.7276516388124786119323824391e-16,
    7.8473570134636044722154797225e-17
]), -1, 1)

358 

def zeta_0_1(s):
    """
    Branch of `zeta_0_inf` used for s < 1.

    Evaluates the `zeta_xlt1_cs` Chebyshev fit of (s - 1) zeta(s) at
    t = 2 s - 1 and divides by s - 1.
    """
    return cheb_eval_e(zeta_xlt1_cs, 2.0 * s - 1.0) / (s - 1.0)

361 

def zeta_1_20(s):
    """
    Branch of `zeta_0_inf` used for 1 <= s < 20.

    Evaluates the `zeta_xgt1_cs` Chebyshev fit of (s - 1) zeta(s) at
    t = (2 s - 21) / 19 and divides by s - 1, so s = 1 divides by zero.
    """
    return cheb_eval_e(zeta_xgt1_cs, (2.0 * s - 21.0) / 19.0) / (s - 1.0)

364 

def zeta_20_inf(s):
    """
    Branch of `zeta_0_inf` used for s >= 20.

    Euler product 1 / prod(1 - p^-s) truncated to the primes 2, 3, 5, 7.
    """
    f2 = 1.0 - 2.0 ** -s
    f3 = 1.0 - 3.0 ** -s
    f5 = 1.0 - 5.0 ** -s
    f7 = 1.0 - 7.0 ** -s
    return 1.0 / (f2 * f3 * f5 * f7)

371 

def zeta_0_inf(s):
    """
    Riemann zeta for s >= 0, assembled from three implementations.

    Selects elementwise `zeta_20_inf` for s >= 20, `zeta_1_20` for
    1 <= s < 20 and `zeta_0_1` below 1. All three are evaluated on every
    input, the selection is done afterwards by `jnp.where`.
    """
    return jnp.where(s >= 20,
        zeta_20_inf(s),
        jnp.where(s >= 1,
            zeta_1_20(s),
            zeta_0_1(s),
        ),
    )

380 

381##########################################################################