Coverage for src/lsqfitgp/copula/_beta.py: 100%
25 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_beta.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" JAX-compatible implementation of the beta distribution """
22import functools 1fabcde
24from scipy import special 1fabcde
25import jax 1fabcde
26from jax.scipy import special as jspecial 1fabcde
27from jax import numpy as jnp 1fabcde
29from .. import _jaxext 1fabcde
@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def betaincinv(a, b, y):
    """Inverse of the regularized incomplete beta function, usable from JAX.

    The value is computed by `scipy.special.betaincinv`, invoked through
    `_jaxext.pure_callback_ufunc`.

    Parameters
    ----------
    a, b : array_like
        Parameters of the incomplete beta function. They are declared
        non-differentiable (``nondiff_argnums=(0, 1)``).
    y : array_like
        Value of the incomplete beta function to invert.

    Returns
    -------
    x : array
        Result of `scipy.special.betaincinv(a, b, y)`, cast to the dtype
        returned by `_jaxext.float_type` for the dtypes of the inputs.
    """
    a, b, y = map(jnp.asarray, (a, b, y))
    out_dtype = _jaxext.float_type(a.dtype, b.dtype, y.dtype)

    def scipy_betaincinv(*operands):
        # Cast the scipy result to the dtype announced to the callback wrapper.
        return special.betaincinv(*operands).astype(out_dtype)

    return _jaxext.pure_callback_ufunc(scipy_betaincinv, out_dtype, a, b, y)
# Gradient of jax.scipy.special.betainc(a, b, x) with respect to the argument
# at position 2 (x), as built by `_jaxext.elementwise_grad`.
dIdx_ = _jaxext.elementwise_grad(jspecial.betainc, 2)


@betaincinv.defjvp
def betaincinv_jvp(a, b, primals, tangents):
    """JVP rule of `betaincinv` with respect to its last argument.

    Uses the inverse function rule: the tangent of ``x = betaincinv(a, b, y)``
    is the tangent of ``y`` divided by the derivative of ``betainc(a, b, x)``
    with respect to ``x``.

    Parameters
    ----------
    a, b :
        Non-differentiable arguments of `betaincinv`.
    primals : tuple
        One-element tuple holding ``y``.
    tangents : tuple
        One-element tuple holding the tangent of ``y``.

    Returns
    -------
    x, x_dot :
        Primal output and its tangent.
    """
    (y,) = primals
    (y_dot,) = tangents
    x = betaincinv(a, b, y)
    slope = dIdx_(a, b, x)
    return x, y_dot / slope
class beta:
    """Namespace for beta distribution functions implemented with JAX."""

    @staticmethod
    def ppf(q, a, b):
        """Return ``betaincinv(a, b, q)``, the quantile function at `q`.

        Parameters
        ----------
        q : array_like
            Probability at which to evaluate the quantile function.
        a, b : array_like
            Parameters passed to `betaincinv` as its first two arguments.
        """
        quantile = betaincinv(a, b, q)
        return quantile