Coverage for src/lsqfitgp/_special/_exp.py: 100%
16 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_exp.py
2#
3# Copyright (c) 2022, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import jax 1efabcd
21from jax import numpy as jnp 1efabcd
23@jax.custom_jvp 1efabcd
24@jax.jit 1efabcd
25def expm1x(x): 1efabcd
26 r"""
27 Compute accurately :math:`e^x - 1 - x = x^2/2 {}_1F_1(1, 3, x)`.
28 """
29 n = 10 if x.dtype == jnp.float32 else 17 1abcd
30 k = jnp.arange(2, n + 1) 1abcd
31 f = jnp.cumprod(k) 1abcd
32 coef = jnp.array(1, x.dtype) / f[::-1] 1abcd
33 smallx = x * x * jnp.polyval(coef, x, unroll=n) 1abcd
34 return jnp.where(jnp.abs(x) < 1, smallx, jnp.expm1(x) - x) 1abcd
36 # see also the GSL
37 # https://www.gnu.org/software/gsl/doc/html/specfunc.html#relative-exponential-functions
@expm1x.defjvp
def _expm1x_jvp(p, t):
    """
    JVP rule for `expm1x`: d/dx (e^x - 1 - x) = e^x - 1 = expm1(x).

    `p` and `t` are the one-element tuples of primals and tangents; returns
    the primal output and the tangent output.
    """
    (primal,) = p
    (tangent,) = t
    value = expm1x(primal)
    derivative = jnp.expm1(primal)
    return value, derivative * tangent