Coverage for src/lsqfitgp/_special/_expint.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_expint.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21 

22from scipy import special 1efabcd

23import jax 1efabcd

24from jax import numpy as jnp 1efabcd

25from jax.scipy import special as jspecial 1efabcd

26 

27from . import _gamma 1efabcd

28from . import _taylor 1efabcd

29from .. import _jaxext 1efabcd

30 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def expn_imag(n, x):
    """
    Generalized exponential integral on the negative imaginary axis.

    Compute E_n(-ix), n integer >= 2, x real >= 0.

    ``n`` is used at trace time (it feeds scipy's gamma and array shapes), so
    it must be a concrete value; it is also declared non-differentiable.
    """
    # How the switch point ("knee") between the two algorithms is chosen:
    #
    # expn_imag_smallx loses accuracy due to cancellation between two terms
    # ~ x^(n-2), while the result ~ x^-1, so the relative error grows like
    # x^-1 / x^(n-2) = x^-(n-1). More precisely:
    #
    #   error of expn_imag_smallx: eps z^(n-1) E_1(z) / Gamma(n)
    #                              ~ eps z^(n-2) / Gamma(n)
    #
    #   error of expn_asymp:       e^-z/z (n)_nt e^z/z^(nt-1) E_(n+nt)(z)
    #                              = (n)_nt / z^nt E_(n+nt)(z)
    #                              ~ (n)_nt / z^(nt+1)
    #
    # Equating the two errors:
    #   eps z^(n-2) / Gamma(n) = (n)_nt / z^(nt+1)
    #   --> z = (Gamma(n + nt) / eps)^(1/(n+nt-1))
    #
    # TODO improve accuracy at large n; it is probably sufficient to use
    # something like softmin(1/(n-1), 1/x) e^-ix, where the softmin scale
    # increases with n (how?)

    x = jnp.asarray(x)

    with jax.ensure_compile_time_eval():
        n = jnp.asarray(n)
        dtype = _jaxext.float_type(n, x)
        # number of terms of the asymptotic series
        # (TODO optimize to raise maximum n)
        nt = jnp.array(10, 'i4') if dtype == jnp.float32 else 20
        machine_eps = jnp.finfo(dtype).eps
        exponent = 1 / (n + nt - 1)
        knee = (special.gamma(n + nt) / machine_eps) ** exponent

    use_series = x < knee
    return jnp.where(
        use_series,
        expn_imag_smallx(n, x),
        expn_asymp(n, -1j * x, nt),
    )

69 

@expn_imag.defjvp
def expn_imag_jvp(n, primals, tangents):
    """
    Forward-mode derivative of ``expn_imag`` with respect to ``x``.

    DLMF 8.19.13 gives d/dz E_n(z) = -E_{n-1}(z); with z = -ix the chain
    rule yields d/dx E_n(-ix) = i E_{n-1}(-ix).
    """
    (x,) = primals
    (x_dot,) = tangents
    value = expn_imag(n, x)
    derivative = x_dot * 1j * expn_imag(n - 1, x)
    return value, derivative

78 

def expn_imag_smallx(n, x):
    """
    Compute E_n(-ix) with the finite series of DLMF 8.19.7 (small x).

    With z = -ix (so -z = ix), DLMF 8.19.7 reads

        E_n(z) = [(-z)^(n-1) E_1(z)
                  + e^-z sum_{k=0}^{n-2} (n-k-2)! (-z)^k] / (n-1)!
    """
    n = jnp.asarray(n)
    x = jnp.asarray(x)
    dtype = _jaxext.float_type(n, x)

    # 0!, 1!, ..., (n-1)! as the running product of [1, 1, 2, ..., n-1]
    factorials = jnp.cumprod(jnp.arange(n).at[0].set(1), dtype=dtype)

    minus_z = 1j * x

    # E_1(-ix); its real part behaves like -log(x) as x -> 0, so it is
    # replaced by 0 at x == 0 to avoid 0 * inf = nan in the product below
    e1 = exp1_imag(x)
    e1 = jnp.where(x, e1, 0)
    head = minus_z ** (n - 1) * e1

    # polyval takes the highest degree first: [0!, ..., (n-2)!] makes the
    # coefficient of (ix)^k equal to (n-2-k)!. The trailing singleton axes
    # let the coefficients broadcast against an x of any rank.
    poly_coefs = factorials[:-1][(...,) + (None,) * minus_z.ndim]
    tail = jnp.exp(minus_z) * jnp.polyval(poly_coefs, minus_z)

    return (head + tail) / factorials[-1]

94 

# TODO: to support a traced (jit) n: since the maximum usable n is about 30,
# always compute all the terms and zero out the ones that are not needed.

97 

def expn_asymp_coefgen(s, e, n):
    """
    Coefficients ``s`` to ``e - 1`` of the asymptotic series of E_n.

    Coefficient k is (-1)^k (n)_k, with (n)_k the Pochhammer symbol
    (DLMF 8.20.2). The index array takes its dtype from ``n``.
    """
    order = jnp.arange(s, e, dtype=n.dtype)
    sign = (-1) ** order
    return sign * _gamma.poch(n, order)

101 

def expn_asymp(n, z, nt):
    """
    Compute E_n(z) for large |z|, |arg z| < 3/2 π. ``nt`` is the number of terms
    used in the asymptotic series.

    Implements DLMF 8.20.2: E_n(z) ~ e^-z / z sum_k (-1)^k (n)_k / z^k.
    """
    w = 1 / z  # expansion variable of the series
    series = _taylor.taylor(expn_asymp_coefgen, (n,), 0, nt, w)
    return jnp.exp(-z) * w * series

112 

# Coefficient tables of the rational approximations used by the sine/cosine
# integral helpers below. Each list is in ASCENDING powers of its variable;
# the helpers reverse it with [::-1] before calling jnp.polyval, which wants
# the highest degree first. Presumably taken from Rowe et al. (2015, app. B),
# the reference cited by exp1_imag — NOTE(review): confirm against the paper.

# Si(x) ~= x * P(x^2) / Q(x^2), used by _si_smallx for x < 4.
_si_num = [
    1,
    -4.54393409816329991e-2, # x^2
    1.15457225751016682e-3, # x^4
    -1.41018536821330254e-5, # x^6
    9.43280809438713025e-8, # x^8
    -3.53201978997168357e-10, # x^10
    7.08240282274875911e-13, # x^12
    -6.05338212010422477e-16, # x^14
]

_si_denom = [
    1,
    1.01162145739225565e-2, # x^2
    4.99175116169755106e-5, # x^4
    1.55654986308745614e-7, # x^6
    3.28067571055789734e-10, # x^8
    4.5049097575386581e-13, # x^10
    3.21107051193712168e-16, # x^12
]

# -Cin(x) ~= x^2 * P(x^2) / Q(x^2), used by _minus_cin_smallx for x < 4
# (the leading -0.25 matches the series -Cin(x) = -x^2/4 + ...).
_ci_num = [
    -0.25,
    7.51851524438898291e-3, # x^2
    -1.27528342240267686e-4, # x^4
    1.05297363846239184e-6, # x^6
    -4.68889508144848019e-9, # x^8
    1.06480802891189243e-11, # x^10
    -9.93728488857585407e-15, # x^12
]

_ci_denom = [
    1,
    1.1592605689110735e-2, # x^2
    6.72126800814254432e-5, # x^4
    2.55533277086129636e-7, # x^6
    6.97071295760958946e-10, # x^8
    1.38536352772778619e-12, # x^10
    1.89106054713059759e-15, # x^12
    1.39759616731376855e-18, # x^14
]

# Auxiliary function f(x) ~= (1/x) * P(x^-2) / Q(x^-2), used by _f_largex
# for x >= 4.
_f_num = [
    1,
    7.44437068161936700618e2, # x^-2
    1.96396372895146869801e5, # x^-4
    2.37750310125431834034e7, # x^-6
    1.43073403821274636888e9, # x^-8
    4.33736238870432522765e10, # x^-10
    6.40533830574022022911e11, # x^-12
    4.20968180571076940208e12, # x^-14
    1.00795182980368574617e13, # x^-16
    4.94816688199951963482e12, # x^-18
    -4.94701168645415959931e11, # x^-20
]

_f_denom = [
    1,
    7.46437068161927678031e2, # x^-2
    1.97865247031583951450e5, # x^-4
    2.41535670165126845144e7, # x^-6
    1.47478952192985464958e9, # x^-8
    4.58595115847765779830e10, # x^-10
    7.08501308149515401563e11, # x^-12
    5.06084464593475076774e12, # x^-14
    1.43468549171581016479e13, # x^-16
    1.11535493509914254097e13, # x^-18
]

# Auxiliary function g(x) ~= (1/x^2) * P(x^-2) / Q(x^-2), used by _g_largex
# for x >= 4.
_g_num = [
    1,
    8.1359520115168615e2, # x^-2
    2.35239181626478200e5, # x^-4
    3.12557570795778731e7, # x^-6
    2.06297595146763354e9, # x^-8
    6.83052205423625007e10, # x^-10
    1.09049528450362786e12, # x^-12
    7.57664583257834349e12, # x^-14
    1.81004487464664575e13, # x^-16
    6.43291613143049485e12, # x^-18
    -1.36517137670871689e12, # x^-20
]

_g_denom = [
    1,
    8.19595201151451564e2, # x^-2
    2.40036752835578777e5, # x^-4
    3.26026661647090822e7, # x^-6
    2.23355543278099360e9, # x^-8
    7.87465017341829930e10, # x^-10
    1.39866710696414565e12, # x^-12
    1.17164723371736605e13, # x^-14
    4.01839087307656620e13, # x^-16
    3.99653257887490811e13, # x^-18
]

208 

def _si_smallx(x):
    """ Sine integral Si(x) = int_0^x dt sin t / t, for x < 4.

    Rational approximation in x^2: Si(x) ~= x * P(x^2) / Q(x^2).
    """
    dtype = _jaxext.float_type(x)
    u = jnp.square(x)
    p = jnp.polyval(jnp.array(_si_num[::-1], dtype), u)
    q = jnp.polyval(jnp.array(_si_denom[::-1], dtype), u)
    return x * p / q

216 

def _minus_cin_smallx(x):
    """ Minus the entire cosine integral, -Cin(x) = int_0^x dt (cos t - 1) / t,
    for x < 4.

    Rational approximation in x^2: -Cin(x) ~= x^2 * P(x^2) / Q(x^2).
    """
    dtype = _jaxext.float_type(x)
    u = jnp.square(x)
    p = jnp.polyval(jnp.array(_ci_num[::-1], dtype), u)
    q = jnp.polyval(jnp.array(_ci_denom[::-1], dtype), u)
    return u * p / q

224 

def _ci_smallx(x):
    """ Cosine integral Ci(x) = -int_x^oo dt cos t / t, for x < 4.

    Uses Ci(x) = gamma + log(x) - Cin(x), with gamma the Euler-Mascheroni
    constant.
    """
    euler_mascheroni = 0.57721566490153286060
    return euler_mascheroni + jnp.log(x) + _minus_cin_smallx(x)

229 

def _f_largex(x):
    """ Auxiliary function f(x) = int_0^oo dt sin t / (x + t), for x >= 4.

    Rational approximation in 1/x^2: f(x) ~= P(x^-2) / Q(x^-2) / x.
    """
    dtype = _jaxext.float_type(x)
    inv_x2 = 1 / jnp.square(x)
    ratio = (
        jnp.polyval(jnp.array(_f_num[::-1], dtype), inv_x2)
        / jnp.polyval(jnp.array(_f_denom[::-1], dtype), inv_x2)
    )
    return ratio / x

237 

def _g_largex(x):
    """ Auxiliary function g(x) = int_0^oo dt cos t / (x + t), for x >= 4.

    Rational approximation in 1/x^2: g(x) ~= x^-2 * P(x^-2) / Q(x^-2).
    """
    dtype = _jaxext.float_type(x)
    inv_x2 = 1 / jnp.square(x)
    p = jnp.polyval(jnp.array(_g_num[::-1], dtype), inv_x2)
    q = jnp.polyval(jnp.array(_g_denom[::-1], dtype), inv_x2)
    return inv_x2 * p / q

245 

def _exp1_imag_smallx(x):
    """ E_1(-ix) for x < 4, as -Ci(x) + i (pi/2 - Si(x)). """
    re_part = -_ci_smallx(x)
    im_part = jnp.pi / 2 - _si_smallx(x)
    return re_part + 1j * im_part

249 

def _exp1_imag_largex(x):
    """ E_1(-ix) for x >= 4, as e^ix (g(x) + i f(x)).

    Expanding the product: real part g cos x - f sin x (= -Ci(x)),
    imaginary part f cos x + g sin x (= pi/2 - Si(x)).
    """
    sin_x, cos_x = jnp.sin(x), jnp.cos(x)
    f, g = _f_largex(x), _g_largex(x)
    re_part = -f * sin_x + g * cos_x
    im_part = f * cos_x + g * sin_x
    return re_part + 1j * im_part

259 

@jax.jit
def exp1_imag(x):
    """
    Exponential integral on the negative imaginary axis.

    Compute E_1(-ix) = int_1^oo dt e^ixt / t, for x > 0.
    Reference: Rowe et al. (2015, app. B)

    Rational approximations of Si/Ci are used below x = 4, and of the
    auxiliary functions f/g from x = 4 upward.
    """
    # TODO This is 40x faster than special.exp1(-1j * x) and 2x than
    # special.sici(x), and since the jit has to run (I'm guessing) through both
    # branches of jnp.where, a C/Cython implementation would be 4x faster. Maybe
    # PR it to scipy for sici, after checking the accuracy against mpmath and
    # the actual C performance.
    #
    # Do Padé approximants work for complex functions?
    below_switch = x < 4
    return jnp.where(below_switch, _exp1_imag_smallx(x), _exp1_imag_largex(x))

275 

@jax.custom_jvp
def ci(x):
    """ Cosine integral Ci(x) = -int_x^oo dt cos t / t, for x > 0.

    Computed as -Re E_1(-ix); the derivative is supplied by a custom JVP.
    """
    e1 = exp1_imag(x)
    return -jnp.real(e1)

279 

@ci.defjvp
def _ci_jvp(primals, tangents):
    """ JVP of ``ci`` from Ci'(x) = cos(x) / x. """
    (x,), (x_dot,) = primals, tangents
    return ci(x), x_dot * jnp.cos(x) / x