Coverage for src/lsqfitgp/_special/_exp.py: 100%

16 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_exp.py 

2# 

3# Copyright (c) 2022, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import jax 1efabcd

21from jax import numpy as jnp 1efabcd

22 

23@jax.custom_jvp 1efabcd

24@jax.jit 1efabcd

25def expm1x(x): 1efabcd

26 r""" 

27 Compute accurately :math:`e^x - 1 - x = x^2/2 {}_1F_1(1, 3, x)`. 

28 """ 

29 n = 10 if x.dtype == jnp.float32 else 17 1abcd

30 k = jnp.arange(2, n + 1) 1abcd

31 f = jnp.cumprod(k) 1abcd

32 coef = jnp.array(1, x.dtype) / f[::-1] 1abcd

33 smallx = x * x * jnp.polyval(coef, x, unroll=n) 1abcd

34 return jnp.where(jnp.abs(x) < 1, smallx, jnp.expm1(x) - x) 1abcd

35 

36 # see also the GSL 

37 # https://www.gnu.org/software/gsl/doc/html/specfunc.html#relative-exponential-functions 

38 

@expm1x.defjvp
def _expm1x_jvp(p, t):
    """
    Forward-mode derivative rule registered for `expm1x`.

    `p` and `t` are one-element sequences holding the primal input and its
    tangent. Returns the pair ``(expm1x(x), expm1(x) * tangent)``, since
    the derivative of ``exp(x) - 1 - x`` is ``exp(x) - 1``.
    """
    (primal,) = p
    (tangent,) = t
    primal_out = expm1x(primal)
    # d/dx (e^x - 1 - x) = e^x - 1, evaluated accurately with expm1.
    tangent_out = jnp.expm1(primal) * tangent
    return primal_out, tangent_out