Coverage for src/lsqfitgp/_special/_expint.py: 100%
101 statements
coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_special/_expint.py
#
# Copyright (c) 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
import functools

from scipy import special
import jax
from jax import numpy as jnp
from jax.scipy import special as jspecial

from . import _gamma
from . import _taylor
from .. import _jaxext
@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def expn_imag(n, x):
    """
    Compute E_n(-ix), n integer >= 2, x real >= 0
    """

    # expn_imag_smallx loses accuracy due to cancellation between two terms
    # ~ x^(n-2), while the result is ~ x^(-1), thus the relative error is
    # ~ x^(n-2) / x^(-1) = x^(n-1)
    #
    # error of expn_imag_smallx: eps z^(n-1) E_1(z) / Gamma(n) ~
    # ~ eps z^(n-2) / Gamma(n)
    #
    # error of expn_asymp: e^(-z)/z (n)_nt e^z/z^(nt-1) E_(n+nt)(z) =
    # = (n)_nt / z^nt E_(n+nt)(z) ~
    # ~ (n)_nt / z^(nt+1)
    #
    # set the errors equal:
    # eps z^(n-2) / Gamma(n) = (n)_nt / z^(nt+1) -->
    # --> z = (Gamma(n + nt) / eps)^(1/(n + nt - 1))
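    #
    # e.g. in float64 (eps ~ 2.2e-16) with n = 2, nt = 20, this gives a
    # crossover at z = (21!/eps)^(1/21) ~ 48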

    # TODO improve accuracy at large n, it is probably sufficient to use
    # something like softmin(1/(n-1), 1/x) e^-ix, where the softmin scale
    # increases with n (how?)

    x = jnp.asarray(x)
    with jax.ensure_compile_time_eval():
        n = jnp.asarray(n)
        dt = _jaxext.float_type(n, x)
        if dt == jnp.float32:
            nt = jnp.array(10, 'i4') # TODO optimize to raise maximum n
        else:
            nt = 20 # TODO optimize to raise maximum n
        eps = jnp.finfo(dt).eps
        knee = (special.gamma(n + nt) / eps) ** (1 / (n + nt - 1))
    small = expn_imag_smallx(n, x)
    large = expn_asymp(n, -1j * x, nt)
    return jnp.where(x < knee, small, large)
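
# Rough usage check, a sketch assuming mpmath is available (not part of the
# original module): compare expn_imag against mpmath's generalized
# exponential integral across both branches.
def _demo_expn_imag(): # pragma: no cover
    import mpmath
    x = jnp.linspace(0.5, 100., 8)
    ours = expn_imag(3, x)
    ref = jnp.array([complex(mpmath.expint(3, -1j * float(v))) for v in x])
    print(jnp.max(jnp.abs(ours - ref) / jnp.abs(ref)))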

@expn_imag.defjvp
def expn_imag_jvp(n, primals, tangents):

    # DLMF 8.19.13
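    # E_n'(z) = -E_(n-1)(z) with z = -ix here, so by the chain rule
    # d/dx E_n(-ix) = (-i) * (-E_(n-1)(-ix)) = i E_(n-1)(-ix)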

    x, = primals
    xt, = tangents
    return expn_imag(n, x), xt * 1j * expn_imag(n - 1, x)

def expn_imag_smallx(n, x):

    # DLMF 8.19.7
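    # which for z = -ix reads
    #   E_n(-ix) = ((ix)^(n-1) E_1(-ix) + e^(ix) sum_{k=0}^{n-2} (n-2-k)! (ix)^k) / (n-1)!
    # part1/part2 below are the two addends; polyval takes coefficients from
    # the highest power down, so fact[:-1] = [0!, ..., (n-2)!] already puts
    # (n-2-k)! at power k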

    n, x = jnp.asarray(n), jnp.asarray(x)
    k = jnp.arange(n)
    fact = jnp.cumprod(k.at[0].set(1), dtype=_jaxext.float_type(n, x))
    n_1fact = fact[-1]
    ix = 1j * x
    E_1 = exp1_imag(x) # E_1(-ix)
    E_1 = jnp.where(x, E_1, 0) # Re E_1(-ix) ~ log(x) for x -> 0
    part1 = ix ** (n - 1) * E_1
    coefs = fact[:-1][(...,) + (None,) * ix.ndim]
    part2 = jnp.exp(ix) * jnp.polyval(coefs, ix)
    return (part1 + part2) / n_1fact

# TODO to make this work with jit n, since the maximum n is something
# like 30, I can always compute all the terms and set some of them to zero

def expn_asymp_coefgen(s, e, n):
    k = jnp.arange(s, e, dtype=n.dtype)
    return (-1) ** k * _gamma.poch(n, k)

def expn_asymp(n, z, nt):
    """
    Compute E_n(z) for large |z|, |arg z| < 3/2 π. ``nt`` is the number of terms
    used in the asymptotic series.
    """

    # DLMF 8.20.2
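    #   E_n(z) = e^(-z)/z * sum_{k=0}^{nt-1} (-1)^k (n)_k / z^k + remainder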

    invz = 1 / z
    return jnp.exp(-z) * invz * _taylor.taylor(expn_asymp_coefgen, (n,), 0, nt, invz)
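
# The lists below hold coefficients of rational approximations for Si, Cin
# and the auxiliary functions f, g used by exp1_imag (see the reference in
# its docstring).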

_si_num = [
    1,
    -4.54393409816329991e-2, # x^2
    1.15457225751016682e-3, # x^4
    -1.41018536821330254e-5, # x^6
    9.43280809438713025e-8, # x^8
    -3.53201978997168357e-10, # x^10
    7.08240282274875911e-13, # x^12
    -6.05338212010422477e-16, # x^14
]

_si_denom = [
    1,
    1.01162145739225565e-2, # x^2
    4.99175116169755106e-5, # x^4
    1.55654986308745614e-7, # x^6
    3.28067571055789734e-10, # x^8
    4.5049097575386581e-13, # x^10
    3.21107051193712168e-16, # x^12
]

_ci_num = [
    -0.25,
    7.51851524438898291e-3, # x^2
    -1.27528342240267686e-4, # x^4
    1.05297363846239184e-6, # x^6
    -4.68889508144848019e-9, # x^8
    1.06480802891189243e-11, # x^10
    -9.93728488857585407e-15, # x^12
]

_ci_denom = [
    1,
    1.1592605689110735e-2, # x^2
    6.72126800814254432e-5, # x^4
    2.55533277086129636e-7, # x^6
    6.97071295760958946e-10, # x^8
    1.38536352772778619e-12, # x^10
    1.89106054713059759e-15, # x^12
    1.39759616731376855e-18, # x^14
]

_f_num = [
    1,
    7.44437068161936700618e2, # x^-2
    1.96396372895146869801e5, # x^-4
    2.37750310125431834034e7, # x^-6
    1.43073403821274636888e9, # x^-8
    4.33736238870432522765e10, # x^-10
    6.40533830574022022911e11, # x^-12
    4.20968180571076940208e12, # x^-14
    1.00795182980368574617e13, # x^-16
    4.94816688199951963482e12, # x^-18
    -4.94701168645415959931e11, # x^-20
]

_f_denom = [
    1,
    7.46437068161927678031e2, # x^-2
    1.97865247031583951450e5, # x^-4
    2.41535670165126845144e7, # x^-6
    1.47478952192985464958e9, # x^-8
    4.58595115847765779830e10, # x^-10
    7.08501308149515401563e11, # x^-12
    5.06084464593475076774e12, # x^-14
    1.43468549171581016479e13, # x^-16
    1.11535493509914254097e13, # x^-18
]

_g_num = [
    1,
    8.1359520115168615e2, # x^-2
    2.35239181626478200e5, # x^-4
    3.12557570795778731e7, # x^-6
    2.06297595146763354e9, # x^-8
    6.83052205423625007e10, # x^-10
    1.09049528450362786e12, # x^-12
    7.57664583257834349e12, # x^-14
    1.81004487464664575e13, # x^-16
    6.43291613143049485e12, # x^-18
    -1.36517137670871689e12, # x^-20
]

_g_denom = [
    1,
    8.19595201151451564e2, # x^-2
    2.40036752835578777e5, # x^-4
    3.26026661647090822e7, # x^-6
    2.23355543278099360e9, # x^-8
    7.87465017341829930e10, # x^-10
    1.39866710696414565e12, # x^-12
    1.17164723371736605e13, # x^-14
    4.01839087307656620e13, # x^-16
    3.99653257887490811e13, # x^-18
]

def _si_smallx(x):
    """ Compute Si(x) = int_0^x dt sin t / t, for x < 4 """
    x2 = jnp.square(x)
    dtype = _jaxext.float_type(x)
    num = jnp.polyval(jnp.array(_si_num[::-1], dtype), x2)
    denom = jnp.polyval(jnp.array(_si_denom[::-1], dtype), x2)
    return x * num / denom

def _minus_cin_smallx(x):
    """ Compute -Cin(x) = int_0^x dt (cos t - 1) / t, for x < 4 """
    x2 = jnp.square(x)
    dtype = _jaxext.float_type(x)
    num = jnp.polyval(jnp.array(_ci_num[::-1], dtype), x2)
    denom = jnp.polyval(jnp.array(_ci_denom[::-1], dtype), x2)
    return x2 * num / denom

def _ci_smallx(x):
    """ Compute Ci(x) = -int_x^oo dt cos t / t, for x < 4 """
    gamma = 0.57721566490153286060
    return gamma + jnp.log(x) + _minus_cin_smallx(x)
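
# Rough check, a sketch (not part of the original module): the small-x Si
# and Ci against scipy.special.sici on the domain where they are used.
def _demo_sici_smallx(): # pragma: no cover
    x = jnp.linspace(1e-3, 4., 9)
    si_ref, ci_ref = special.sici(x)
    print(jnp.max(jnp.abs(_si_smallx(x) - si_ref)))
    print(jnp.max(jnp.abs(_ci_smallx(x) - ci_ref)))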

def _f_largex(x):
    """ Compute f(x) = int_0^oo dt sin t / (x + t), for x > 4 """
    x2 = 1 / jnp.square(x)
    dtype = _jaxext.float_type(x)
    num = jnp.polyval(jnp.array(_f_num[::-1], dtype), x2)
    denom = jnp.polyval(jnp.array(_f_denom[::-1], dtype), x2)
    return num / denom / x

def _g_largex(x):
    """ Compute g(x) = int_0^oo dt cos t / (x + t), for x > 4 """
    x2 = 1 / jnp.square(x)
    dtype = _jaxext.float_type(x)
    num = jnp.polyval(jnp.array(_g_num[::-1], dtype), x2)
    denom = jnp.polyval(jnp.array(_g_denom[::-1], dtype), x2)
    return x2 * num / denom

def _exp1_imag_smallx(x):
    """ Compute E_1(-ix), for x < 4 """
    return -_ci_smallx(x) + 1j * (jnp.pi / 2 - _si_smallx(x))

def _exp1_imag_largex(x):
    """ Compute E_1(-ix), for x > 4 """
    s = jnp.sin(x)
    c = jnp.cos(x)
    f = _f_largex(x)
    g = _g_largex(x)
    real = -f * s + g * c
    imag = f * c + g * s
    return real + 1j * imag # e^ix (g + if)
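
# Rough check, a sketch (not part of the original module): the two E_1(-ix)
# branches should agree closely around the switch point x = 4.
def _demo_exp1_imag_branches(): # pragma: no cover
    x = jnp.linspace(3.5, 4.5, 11)
    print(jnp.max(jnp.abs(_exp1_imag_smallx(x) - _exp1_imag_largex(x))))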

@jax.jit
def exp1_imag(x):
    """
    Compute E_1(-ix) = int_1^oo dt e^ixt / t, for x > 0
    Reference: Rowe et al. (2015, app. B)
    """
    return jnp.where(x < 4, _exp1_imag_smallx(x), _exp1_imag_largex(x))

# TODO This is 40x faster than special.exp1(-1j * x) and 2x than
# special.sici(x), and since the jit has to run (I'm guessing) through both
# branches of jnp.where, a C/Cython implementation would be 4x faster. Maybe
# PR it to scipy for sici, after checking the accuracy against mpmath and
# the actual C performance.

# Do Padé approximants work for complex functions?
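
# Rough accuracy check, a sketch (not part of the original module): the
# comparison against scipy's complex-capable exp1 mentioned in the TODO above.
def _demo_exp1_imag(): # pragma: no cover
    x = jnp.linspace(0.1, 100., 1000)
    ref = special.exp1(-1j * x)
    print(jnp.max(jnp.abs(exp1_imag(x) - ref)))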

@jax.custom_jvp
def ci(x):
    return -exp1_imag(x).real

@ci.defjvp
def _ci_jvp(primals, tangents):
    x, = primals
    xt, = tangents
    return ci(x), xt * jnp.cos(x) / x
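
# Rough check, a sketch (not part of the original module): ci against scipy,
# and the derivative rule d Ci(x)/dx = cos(x)/x implemented by the JVP above.
def _demo_ci(): # pragma: no cover
    x = jnp.linspace(0.5, 10., 7)
    _, ci_ref = special.sici(x)
    print(jnp.max(jnp.abs(ci(x) - ci_ref)))
    print(jax.grad(ci)(2.), jnp.cos(2.) / 2.)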