Coverage for src/lsqfitgp/copula/_gamma.py: 100%
103 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_gamma.py
2#
3# Copyright (c) 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20"""
21JAX-compatible implementation of the gamma and related distributions
22"""
24import functools 1feabcd
26from scipy import special 1feabcd
27import jax 1feabcd
28from jax.scipy import special as jspecial 1feabcd
29from jax import numpy as jnp 1feabcd
30import numpy 1feabcd
32from .. import _jaxext 1feabcd
34def _castto(func, type): 1feabcd
35 @functools.wraps(func) 1eabcd
36 def newfunc(*args, **kw): 1eabcd
37 return func(*args, **kw).astype(type) 1eabcd
38 return newfunc 1eabcd
@jax.custom_jvp
def gammainccinv(a, y):
    """
    Inverse of the regularized upper incomplete gamma function Q(a, x) with
    respect to x: returns x such that Q(a, x) = y.

    The value is computed by `scipy.special.gammainccinv` through the
    `_jaxext.pure_callback_ufunc` callback wrapper. The output dtype is the
    float type chosen by `_jaxext.float_type` from the dtypes of `a` and `y`.
    Derivatives are supplied by the custom JVP rule `gammainccinv_jvp`.
    """
    a, y = map(jnp.asarray, (a, y))
    out_dtype = _jaxext.float_type(a.dtype, y.dtype)
    host_func = _castto(special.gammainccinv, out_dtype)
    return _jaxext.pure_callback_ufunc(host_func, out_dtype, a, y)
# Derivative functions of the regularized upper incomplete gamma Q(a, x),
# with respect to argument 0 (a) and argument 1 (x), built with
# _jaxext.elementwise_grad; used by the JVP rule of gammainccinv.
dQ_da, dQ_dx = (
    _jaxext.elementwise_grad(jspecial.gammaincc, argnum) for argnum in (0, 1)
)
@gammainccinv.defjvp
def gammainccinv_jvp(primals, tangents):
    """
    JVP rule of `gammainccinv`, obtained by implicit differentiation.

    With x = Q^-1(a, y), i.e. Q(a, x) = y:

        dx/dy = 1 / (dQ/dx)(a, x)
        dx/da = -(dQ/da)(a, x) / (dQ/dx)(a, x)

    The `a` contribution is added only when `a` has a floating dtype.
    """
    a, y = primals
    a_dot, y_dot = tangents

    x = gammainccinv(a, y)

    inv_dQ_dx = 1 / dQ_dx(a, x)  # this is dx/dy
    x_dot = inv_dQ_dx * y_dot

    # A non-floating `a` has no usable tangent, so its term is skipped.
    # modern jax would be: getattr(at, 'dtype', jnp.float64) != jax.float0
    a_is_float = jnp.issubdtype(jnp.asarray(a).dtype, jnp.floating)
    if a_is_float:
        dx_da = -inv_dQ_dx * dQ_da(a, x)
        x_dot += dx_da * a_dot

    return x, x_dot
@jax.custom_jvp
def gammaincinv(a, y):
    """
    Inverse of the regularized lower incomplete gamma function P(a, x) with
    respect to x: returns x such that P(a, x) = y.

    The value is computed by `scipy.special.gammaincinv` through the
    `_jaxext.pure_callback_ufunc` callback wrapper. The output dtype is the
    float type chosen by `_jaxext.float_type` from the dtypes of `a` and `y`.
    Derivatives are supplied by the custom JVP rule `gammaincinv_jvp`.
    """
    a, y = map(jnp.asarray, (a, y))
    out_dtype = _jaxext.float_type(a.dtype, y.dtype)
    host_func = _castto(special.gammaincinv, out_dtype)
    return _jaxext.pure_callback_ufunc(host_func, out_dtype, a, y)
# Derivative functions of the regularized lower incomplete gamma P(a, x),
# with respect to argument 0 (a) and argument 1 (x), built with
# _jaxext.elementwise_grad; used by the JVP rule of gammaincinv.
dP_da, dP_dx = (
    _jaxext.elementwise_grad(jspecial.gammainc, argnum) for argnum in (0, 1)
)
@gammaincinv.defjvp
def gammaincinv_jvp(primals, tangents):
    """
    JVP rule of `gammaincinv`, obtained by implicit differentiation.

    With x = P^-1(a, y), i.e. P(a, x) = y:

        dx/dy = 1 / (dP/dx)(a, x)
        dx/da = -(dP/da)(a, x) / (dP/dx)(a, x)

    The `a` contribution is added only when `a` has a floating dtype.
    """
    a, y = primals
    a_dot, y_dot = tangents

    x = gammaincinv(a, y)

    inv_dP_dx = 1 / dP_dx(a, x)  # this is dx/dy
    x_dot = inv_dP_dx * y_dot

    # A non-floating `a` has no usable tangent, so its term is skipped.
    # modern jax would be: getattr(at, 'dtype', jnp.float64) != jax.float0
    a_is_float = jnp.issubdtype(jnp.asarray(a).dtype, jnp.floating)
    if a_is_float:
        dx_da = -inv_dP_dx * dP_da(a, x)
        x_dot += dx_da * a_dot

    return x, x_dot
98def _gammaisf_normcdf_large_neg_x(x, a): 1feabcd
99 logphi = lambda x: -1/2 * jnp.log(2 * jnp.pi) - 1/2 * jnp.square(x) - jnp.log(-x) 1eabcd
100 logq = logphi(x) 1eabcd
101 loggammaa = jspecial.gammaln(a) 1eabcd
102 f = lambda y: (a - 1) * jnp.log(y) - y - loggammaa - logq 1eabcd
103 f1 = lambda y: (a - 1) / y - 1 1eabcd
104 y0 = -logq 1eabcd
105 y1 = y0 - ((a - 1) * jnp.log(y0) - loggammaa) / ((a - 1) / y0 - 1) 1eabcd
106 return y1 1eabcd
108 # TODO Improve the accuracy. I tried adding one Newton step more, but it
109 # does not improve the accuracy. I probably have to add terms to the
110 # approximations of Phi and Q. I could try first special.erfcx for Phi.
112 # x -> -∞, q -> 0+, y -> ∞
113 # q = Φ(x) ≈ -1/√2π exp(-x²/2)/x
114 # q = Q(a, y) ≈ y^(a-1) e^-y / Γ(a)
115 # gamma.isf(q, a) = Q⁻¹(a, q)
116 # log q = -1/2 log 2π - x²/2 - log(-x) (1)
117 # log q = (a - 1) log y - y - log Γ(a) (2)
118 # f(y) = (a - 1) log y - y - log Γ(a) - log(q)
119 # = 0 by (2)
120 # f'(y) = (a - 1) / y - 1
121 # y_0 = -log q by considering y -> ∞
122 # y_1 = y_0 - f(y_0) / f'(y_0) Newton step
124def _loggammaisf_normcdf_large_neg_x(x, a): 1feabcd
125 logphi = lambda x: -1/2 * jnp.log(2 * jnp.pi) - 1/2 * jnp.square(x) - jnp.log(-x) 1abcd
126 logq = logphi(x) 1abcd
127 loggammaa = jspecial.gammaln(a) 1abcd
128 g = lambda logy: (a - 1) * logy - jnp.exp(logy) - loggammaa - logq 1abcd
129 g1 = lambda logy: (a - 1) - jnp.exp(logy) 1abcd
130 logy0 = jnp.log(-logq) 1abcd
131 logy1 = logy0 - ((a - 1) * logy0 - loggammaa) / ((a - 1) + logq) 1abcd
132 return logy1 1abcd
class gamma:
    """
    Gamma distribution with shape `a` and unit scale, expressed through the
    inverses of the regularized incomplete gamma functions.
    """

    @staticmethod
    def ppf(q, a):
        """Quantile function: the x such that P(a, x) = q."""
        lower_quantile = gammaincinv(a, q)
        return lower_quantile

    @staticmethod
    def isf(q, a):
        """Inverse survival function: the x such that Q(a, x) = q."""
        upper_quantile = gammainccinv(a, q)
        return upper_quantile
class invgamma:
    """
    Inverse-gamma distribution with shape `a` and unit scale, i.e. the law of
    1/Y with Y ~ Gamma(a).
    """

    @staticmethod
    def ppf(q, a):
        """Quantile: CDF(x) = Q(a, 1/x) = q, hence x = 1 / Q^-1(a, q)."""
        y = gammainccinv(a, q)
        return 1 / y

    @staticmethod
    def isf(q, a):
        """Inverse survival function: 1 - CDF(x) = P(a, 1/x) = q."""
        y = gammaincinv(a, q)
        return 1 / y

    @staticmethod
    def logpdf(x, a):
        """Log density: -(a + 1) log x - 1/x - log Gamma(a)."""
        log_x = jnp.log(x)
        return -(a + 1) * log_x - 1 / x - jspecial.gammaln(a)

    @staticmethod
    def cdf(x, a):
        """Cumulative distribution function: Q(a, 1/x)."""
        inv_x = 1 / x
        return jspecial.gammaincc(a, inv_x)
class loggamma:
    """
    Log-gamma distribution with shape `c`: the law of log(Y) with
    Y ~ Gamma(c) of unit scale, following the scipy implementation quoted in
    the methods below.
    """

    @staticmethod
    def _log_or_fallback(g, fallback):
        """
        Return log(g), or `fallback` where `g` is below the smallest normal
        float of its dtype.

        `jnp.where` evaluates both branches. If `g` underflows to zero,
        reverse-mode differentiation of ``jnp.log(g)`` computes
        ``cotangent / g = 0 / 0 = nan`` even though that branch is discarded.
        Feeding `log` the safe value 1 in the discarded positions avoids this
        without changing the selected values or their forward derivatives.
        """
        small = g < jnp.finfo(g.dtype).tiny
        safe_g = jnp.where(small, 1, g)
        return jnp.where(small, fallback, jnp.log(safe_g))

    @staticmethod
    def ppf(q, c):
        """
        Quantile function: log(P^-1(c, q)).

        Where P^-1(c, q) underflows, use P(c, g) ≈ g^c / Gamma(c + 1) for
        small g, i.e. log g ≈ (log q + log Gamma(c + 1)) / c.
        """
        # scipy code:
        # g = sc.gammaincinv(c, q)
        # return _lazywhere(g < _XMIN, (g, q, c),
        #                   lambda g, q, c: (np.log(q) + sc.gammaln(c+1))/c,
        #                   f2=lambda g, q, c: np.log(g))
        g = gammaincinv(c, q)
        fallback = (jnp.log(q) + jspecial.gammaln(c + 1)) / c
        return loggamma._log_or_fallback(g, fallback)

    @staticmethod
    def isf(q, c):
        """
        Inverse survival function: log(Q^-1(c, q)).

        Where Q^-1(c, q) underflows, use Q(c, g) = 1 - P(c, g) with
        P(c, g) ≈ g^c / Gamma(c + 1), i.e.
        log g ≈ (log(1 - q) + log Gamma(c + 1)) / c.
        """
        # scipy code:
        # g = sc.gammainccinv(c, q)
        # return _lazywhere(g < _XMIN, (g, q, c),
        #                   lambda g, q, c: (np.log1p(-q) + sc.gammaln(c+1))/c,
        #                   f2=lambda g, q, c: np.log(g))
        g = gammainccinv(c, q)
        fallback = (jnp.log1p(-q) + jspecial.gammaln(c + 1)) / c
        return loggamma._log_or_fallback(g, fallback)