Coverage for src/lsqfitgp/_special/_taylor.py: 100%
15 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_taylor.py
2#
3# Copyright (c) 2022, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1efabcd
22import jax 1efabcd
23from jax import numpy as jnp 1efabcd
24from jax.scipy import special as jspecial 1efabcd
@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2, 3))
def taylor(coefgen, args, n, m, x):
    """
    Evaluate the n-th derivative of a truncated Taylor series at x.

    The series is sum_k c_k x^k; differentiating it n times and keeping the
    m terms with powers n .. n + m - 1 gives
    sum_k c_k k! / (k - n)! x^(k - n).

    coefgen : function = start, end -> taylor coefficients for powers start:end
    args : tuple = additional arguments to coefgen
    n : int = derivation order
    m : int = number of coefficients used
    x : argument
    """
    coefs = coefgen(n, n + m, *args)
    powers = jnp.arange(n, n + m)
    # k! / (k - n)! is the factor produced by differentiating x^k n times;
    # it is evaluated in log space through gammaln
    log_falling = jspecial.gammaln(powers + 1) - jspecial.gammaln(powers - n + 1)
    scaled = coefs * jnp.exp(log_falling)
    # scaled[j] multiplies x^j; polyval expects the highest power first
    return jnp.polyval(jnp.flip(scaled), x)

@taylor.defjvp
def taylor_jvp(coefgen, args, n, m, primals, tangents):
    """
    JVP rule for `taylor`: the derivative of the n-th derivative is the
    (n + 1)-th derivative, evaluated with the same number m of coefficients
    (i.e. powers n + 1 .. n + m), scaled by the tangent of x.
    """
    (x,) = primals
    (x_dot,) = tangents
    value = taylor(coefgen, args, n, m, x)
    slope = taylor(coefgen, args, n + 1, m, x)
    return value, slope * x_dot