Coverage for src/lsqfitgp/copula/_beta.py: 100%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_beta.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" JAX-compatible implementation of the beta distribution """ 

21 

22import functools 1fabcde

23 

24from scipy import special 1fabcde

25import jax 1fabcde

26from jax.scipy import special as jspecial 1fabcde

27from jax import numpy as jnp 1fabcde

28 

29from .. import _jaxext 1fabcde

30 

@functools.partial(jax.custom_jvp, nondiff_argnums=(0, 1))
def betaincinv(a, b, y):
    """Inverse of the regularized incomplete beta function in its last argument.

    Evaluates ``scipy.special.betaincinv(a, b, y)`` on the host through
    ``_jaxext.pure_callback_ufunc``. Presumably this is a ``jax.pure_callback``
    wrapper that broadcasts elementwise; verify against ``_jaxext``.

    Parameters
    ----------
    a, b : array_like
        Shape parameters. They are declared as ``nondiff_argnums``, so the
        custom JVP rule only propagates tangents with respect to ``y``.
    y : array_like
        Target value(s) of the regularized incomplete beta function.

    Returns
    -------
    x : jax array
        Value(s) with dtype ``_jaxext.float_type(a.dtype, b.dtype, y.dtype)``.
        That dtype is presumably the common floating type of the inputs
        (TODO confirm).
    """
    a, b, y = map(jnp.asarray, (a, b, y))
    out_dtype = _jaxext.float_type(a.dtype, b.dtype, y.dtype)

    def host_betaincinv(*arrays):
        # Runs outside JAX tracing. Cast explicitly so the callback result
        # matches the dtype declared to the callback wrapper.
        return special.betaincinv(*arrays).astype(out_dtype)

    return _jaxext.pure_callback_ufunc(host_betaincinv, out_dtype, a, b, y)

41 

# Derivative of jax.scipy.special.betainc with respect to argument index 2 (x).
# Assumes _jaxext.elementwise_grad(f, i) returns an elementwise derivative of
# f in its i-th positional argument; verify against _jaxext.
dIdx_ = _jaxext.elementwise_grad(jspecial.betainc, 2)


@betaincinv.defjvp
def betaincinv_jvp(a, b, primals, tangents):
    """Forward-mode rule for ``betaincinv`` with respect to ``y``.

    ``a`` and ``b`` arrive first because they are ``nondiff_argnums`` of
    ``betaincinv``. By the inverse function theorem, the derivative of
    x = I^{-1}(y) is dx/dy = 1 / (dI/dx evaluated at x).
    """
    (y,) = primals
    (y_dot,) = tangents
    x = betaincinv(a, b, y)
    slope = dIdx_(a, b, x)  # dI/dx at the inverted point
    return x, y_dot / slope

51 

class beta:
    """Minimal JAX-compatible counterpart of ``scipy.stats.beta``.

    Only the percent point (quantile) function is provided.
    """

    @staticmethod
    def ppf(q, a, b):
        """Quantile function of the Beta(a, b) distribution.

        Returns the x such that the regularized incomplete beta function
        I_x(a, b) equals ``q``. This is computed with ``betaincinv``, which
        supports forward-mode differentiation with respect to ``q`` only.
        """
        quantile = betaincinv(a, b, q)
        return quantile