Coverage for src/lsqfitgp/_special/_taylor.py: 100%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_taylor.py 

2# 

3# Copyright (c) 2022, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21 

22import jax 1efabcd

23from jax import numpy as jnp 1efabcd

24from jax.scipy import special as jspecial 1efabcd

25 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 3))
def taylor(coefgen, args, n, m, x):
    """
    Evaluate the n-th derivative of a truncated Taylor series at x.

    coefgen : function = start, end, *args -> taylor coefficients for powers
        start:end
    args : tuple = additional arguments to coefgen
    n : int = derivative order
    m : int = number of coefficients used
    x : argument

    Returns sum_{k=n}^{n+m-1} c_k * k! / (k - n)! * x^(k - n), where c_k are
    the coefficients produced by coefgen(n, n + m, *args).
    """
    c = coefgen(n, n + m, *args)
    k = jnp.arange(n, n + m)
    # k! / (k - n)! is the factor x^k picks up when differentiated n times;
    # it is computed here through log-gamma.
    c = c * jnp.exp(jspecial.gammaln(1 + k) - jspecial.gammaln(1 + k - n))
    # c is ordered by increasing power, polyval expects highest power first.
    return jnp.polyval(c[::-1], x)


@taylor.defjvp
def taylor_jvp(coefgen, args, n, m, primals, tangents):
    """
    JVP rule for `taylor` with respect to x, the only differentiable argument.

    The tangent is the same series evaluated at derivative order n + 1, still
    using m coefficients, multiplied by the input tangent.
    """
    x, = primals
    xt, = tangents
    primal_out = taylor(coefgen, args, n, m, x)
    tangent_out = taylor(coefgen, args, n + 1, m, x) * xt
    return primal_out, tangent_out