Coverage for src/lsqfitgp/copula/_copulas.py: 100%
89 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_copulas.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" predefined distributions """
22import functools 1feabcd
23import collections 1feabcd
25from jax.scipy import special as jspecial 1feabcd
26import jax 1feabcd
27from jax import numpy as jnp 1feabcd
29from .. import _jaxext 1feabcd
30from .. import _array 1feabcd
31from . import _beta, _gamma 1feabcd
32from . import _distr 1feabcd
def _normcdf(x):
    """
    Standard Normal cumulative distribution function Φ(x).

    The argument is converted to a JAX array, cast to the floating point type
    returned by `_jaxext.float_type`, and passed to `jax.scipy.special.ndtr`.

    To compute the upper tail 1 - Φ(x) accurately, call ``_normcdf(-x)``
    rather than ``jax.scipy.stats.norm.sf``. In some older jax versions ``sf``
    was implemented as ``1 - cdf(x)`` instead of ``cdf(-x)``, which defeats its
    numerical purpose. See https://github.com/google/jax/issues/17199.
    """
    arr = jnp.asarray(x)
    float_dtype = _jaxext.float_type(arr)
    return jspecial.ndtr(arr.astype(float_dtype))
class beta(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Beta_distribution
    """

    @staticmethod
    def invfcn(x, alpha, beta):
        # Probability integral transform: if x is standard Normal, u = Φ(x) is
        # uniform on (0, 1), and the Beta quantile maps it to Beta(alpha, beta).
        # NOTE(review): unlike the gamma family, there is no separate branch
        # using the survival function at Φ(-x). For large positive x, Φ(x)
        # rounds to 1, which limits resolution in the upper tail. Confirm
        # whether that matters for the intended use.
        u = _normcdf(x)
        return _beta.beta.ppf(u, a=alpha, b=beta)
class dirichlet(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Dirichlet_distribution
    """

    signature = '(n),(n)->(n)'

    @classmethod
    def invfcn(cls, x, alpha):
        # Map each component to the log of a Gamma(alpha_i) variable. The
        # normalized vector y_i / sum_j y_j is then computed in log space as
        # exp(ln y_i - logsumexp_j ln y_j), which avoids overflow/underflow
        # of the y_i themselves.
        log_components = loggamma.invfcn(x, alpha)
        log_total = jspecial.logsumexp(log_components, axis=-1, keepdims=True)
        return jnp.exp(log_components - log_total)

    # Disabled alternative for tiny alpha:
    #
    # @classmethod
    # def _invfcn_tiny_alpha(cls, x, alpha):
    #     q = _normcdf(x)
    #     lnq = jnp.log(q)
    #     lny = lnq / alpha
    #     lnnorm = jspecial.logsumexp(lny, axis=-1, keepdims=True)
    #     return jnp.exp(lny - lnnorm)

    # Derivation of the tiny-alpha approximation, for a -> 0, using
    # Γ(a) ~ 1 / a, γ(a, x) = lower incomplete gamma, and dropping e^-t:
    #
    #     gamma.cdf(x, a) = P(a, x)
    #                     = γ(a, x) / Γ(a)
    #                     ≈ int_0^x dt e^-t t^(a - 1) / (1 / a)
    #                     ≈ a [t^a / a]_0^x
    #                     = a x^a / a
    #                     = x^a
    #
    #     gamma.ppf(q, a) = P^-1(a, q)
    #                     ≈ q^(1/a)
class gamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Gamma_distribution
    """

    @staticmethod
    def _boundary(x):
        # Per-dtype threshold on x at and above which the dedicated large-x
        # tail routine is used. Other dtypes raise KeyError.
        thresholds = {
            jnp.dtype(jnp.float32): 12,
            jnp.dtype(jnp.float64): 37,
        }
        return thresholds[x.dtype]

    @classmethod
    def invfcn(cls, x, alpha, beta):
        # Map x to Gamma(alpha, rate=beta): a standard Gamma(alpha) variable
        # divided by beta.
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        cut = cls._boundary(x)

        def lower_tail(z, a):
            # x < 0: Φ(x) <= 1/2, so invert the CDF directly.
            return _gamma.gamma.ppf(_normcdf(z), a)

        def upper_tail(z, a):
            # 0 <= x < cut: invert the survival function at Φ(-x) instead of
            # the CDF at Φ(x), which would approach 1.
            return _gamma.gamma.isf(_normcdf(-z), a)

        def far_upper_tail(z, a):
            # x >= cut: delegate to the routine that, by its name, computes
            # gamma.isf(Φ(y)) for large negative y (here y = -x).
            return _gamma._gammaisf_normcdf_large_neg_x(-z, a)

        # `_piecewise_multiarg` picks, per element, the FIRST condition that
        # holds (argmax over the stacked masks returns the first True). The
        # conditions are therefore tested left to right, and the x < 0 branch
        # is taken for negative x even though x < cut also holds there.
        standard = _piecewise_multiarg(
            [x < 0, x < cut, x >= cut],
            [lower_tail, upper_tail, far_upper_tail],
            x, alpha,
        )
        return standard / beta
class loggamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Gamma_distribution, `scipy.stats.loggamma`

    This is the distribution of the logarithm of a Gamma variable. The naming
    convention is the opposite of lognorm, which is the distribution of the
    exponential of a Normal variable.
    """

    @staticmethod
    def _boundary(x):
        # Reuse the per-dtype threshold of the Gamma distribution.
        return gamma._boundary(x)

    @classmethod
    def invfcn(cls, x, alpha):
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        cut = cls._boundary(x)

        def lower_tail(z, a):
            # x < 0: invert the CDF at Φ(x) <= 1/2.
            return _gamma.loggamma.ppf(_normcdf(z), a)

        def upper_tail(z, a):
            # 0 <= x < cut: invert the survival function at Φ(-x).
            return _gamma.loggamma.isf(_normcdf(-z), a)

        def far_upper_tail(z, a):
            # x >= cut: dedicated tail routine evaluated at y = -x.
            return _gamma._loggammaisf_normcdf_large_neg_x(-z, a)

        # The first condition that holds is the one used, per element,
        # because argmax in `_piecewise_multiarg` returns the first True.
        return _piecewise_multiarg(
            [x < 0, x < cut, x >= cut],
            [lower_tail, upper_tail, far_upper_tail],
            x, alpha,
        )

    # TODO scipy.stats.gamma has inaccurate logsf instead of using loggamma.sf,
    # open an issue
class invgamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Inverse-gamma_distribution
    """

    @staticmethod
    def _boundary(x):
        # Mirror image of the Gamma threshold. It is negative because the
        # lower tail of 1/G corresponds to the upper tail of G.
        return -gamma._boundary(x)

    @classmethod
    def invfcn(cls, x, alpha, beta):
        # Map x to InvGamma(alpha) and multiply by the scale beta.
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        cut = cls._boundary(x)

        def far_lower_tail(z, a):
            # x < cut: invgamma.ppf(Φ(x)) = 1 / gamma.isf(Φ(x)), using the
            # dedicated routine for large negative x.
            return 1 / _gamma._gammaisf_normcdf_large_neg_x(z, a)

        def lower_tail(z, a):
            # cut <= x < 0: invert the CDF at Φ(x) <= 1/2.
            return _gamma.invgamma.ppf(_normcdf(z), a)

        def upper_tail(z, a):
            # x >= 0: invert the survival function at Φ(-x) <= 1/2.
            return _gamma.invgamma.isf(_normcdf(-z), a)

        # The first condition that holds selects the branch, per element.
        standard = _piecewise_multiarg(
            [x < cut, x < 0, x >= 0],
            [far_lower_tail, lower_tail, upper_tail],
            x, alpha,
        )
        return beta * standard
def _piecewise_multiarg(conds, functions, *operands):
    """
    Elementwise evaluate the function paired with the first true condition.

    Parameters
    ----------
    conds : sequence of boolean arrays
        Conditions, stacked along a new last axis.
    functions : sequence of callables
        ``functions[i]`` is applied where ``conds[i]`` is the first true
        condition. Each is called on scalar elements of `operands`.
    *operands : arrays
        Arguments passed to the selected function, broadcast together.

    Returns
    -------
    Array with the elementwise result of the selected function.
    """
    # argmax returns the first occurrence of the maximum, so it yields the
    # index of the first True. Where no condition holds, all entries equal
    # False and the index is 0, i.e., the first function.
    which = jnp.argmax(jnp.stack(conds, axis=-1), axis=-1)
    return _vectorized_switch(which, functions, *operands)
@functools.partial(jnp.vectorize, excluded=(1,))
def _vectorized_switch(index, branches, *operands):
    """
    Elementwise `jax.lax.switch`: for each broadcast element, call
    ``branches[index]`` on the corresponding scalar elements of `operands`.
    """
    # Positional argument 1 (`branches`) is excluded from vectorization, so
    # every element sees the same sequence of callables.
    chosen = jax.lax.switch(index, branches, *operands)
    return chosen
class halfcauchy(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Cauchy_distribution, `scipy.stats.halfcauchy`
    """

    @staticmethod
    def _ppf(p):
        # F(y) = (2/π) arctan(y)  -->  F⁻¹(p) = tan(π p / 2)
        angle = jnp.pi * p / 2
        return jnp.tan(angle)

    @staticmethod
    def _isf(p):
        # S(y) = 1 - F(y)  -->  S⁻¹(p) = tan(π (1 - p) / 2) = 1 / tan(π p / 2)
        angle = jnp.pi * p / 2
        return 1 / jnp.tan(angle)

    @classmethod
    def invfcn(cls, x, gamma):
        # For x < 0 invert the CDF at Φ(x). Otherwise invert the survival
        # function at Φ(-x). Either way the selected probability is at most
        # 1/2, so it never rounds to 1. `gamma` is the scale.
        from_cdf = cls._ppf(_normcdf(x))
        from_sf = cls._isf(_normcdf(-x))
        return gamma * jnp.where(x < 0, from_cdf, from_sf)
class halfnorm(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Half-normal_distribution
    """

    @staticmethod
    def _ppf(p):
        # F(x) = 2 Φ(x) - 1 = erf(x / √2)
        # --> F⁻¹(p) = √2 erf⁻¹(p)
        # This equals Φ⁻¹((1 + p) / 2), but forming (1 + p) / 2 rounds p away
        # once p is below machine epsilon, so small quantiles would collapse to
        # exactly 0. erf⁻¹ acts on p directly and keeps relative precision.
        return 2 ** 0.5 * jspecial.erfinv(p)

    @staticmethod
    def _isf(p):
        # Φ(-x) = 1 - Φ(x)
        # --> Φ⁻¹(1 - p) = -Φ⁻¹(p)
        # S(x) = 1 - F(x)
        # --> S⁻¹(p) = F⁻¹(1 - p)
        #            = Φ⁻¹((2 - p) / 2)
        #            = Φ⁻¹(1 - p / 2)
        #            = -Φ⁻¹(p / 2)
        return -jspecial.ndtri(p / 2)

    @classmethod
    def invfcn(cls, x, sigma):
        """
        Map `x` to a half-normal variable with scale `sigma`.

        For x < 0 the CDF is inverted at Φ(x). Otherwise the survival function
        is inverted at Φ(-x). In both cases the selected probability is at
        most 1/2.
        """
        neg = x < 0
        # Give each branch only the inputs it is responsible for (0 elsewhere).
        # Otherwise the discarded CDF branch is evaluated at Φ(x) = 1 for large
        # positive x. Its value and derivative are infinite there, and
        # jnp.where turns 0 * inf into a NaN gradient.
        x_lower = jnp.where(neg, x, 0)
        x_upper = jnp.where(neg, 0, x)
        return sigma * jnp.where(neg,
            cls._ppf(_normcdf(x_lower)),
            cls._isf(_normcdf(-x_upper)),
        )
class uniform(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Continuous_uniform_distribution
    """

    @staticmethod
    def invfcn(x, a, b):
        # Φ(x) lies in [0, 1]. Rescale it affinely onto [a, b].
        width = b - a
        return a + width * _normcdf(x)
class lognorm(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Log-normal_distribution
    """

    @staticmethod
    def invfcn(x, mu, sigma):
        # If x is standard Normal, log Y = mu + sigma x is Normal(mu, sigma²),
        # so Y = exp(log Y) is log-normal.
        log_y = mu + sigma * x
        return jnp.exp(log_y)