Coverage for src/lsqfitgp/copula/_copulas.py: 100%

89 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_copulas.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" predefined distributions """ 

21 

22import functools 1feabcd

23import collections 1feabcd

24 

25from jax.scipy import special as jspecial 1feabcd

26import jax 1feabcd

27from jax import numpy as jnp 1feabcd

28 

29from .. import _jaxext 1feabcd

30from .. import _array 1feabcd

31from . import _beta, _gamma 1feabcd

32from . import _distr 1feabcd

33 

def _normcdf(x):
    """Evaluate the standard normal CDF at `x` with `jspecial.ndtr`.

    `x` is first converted to a jax array and cast to the type returned by
    `_jaxext.float_type` for that array.
    """
    values = jnp.asarray(x)
    return jspecial.ndtr(values.astype(_jaxext.float_type(values)))

38 

39 # In jax < 0.?.?, jax.scipy.stats.norm.sf is implemented as 1 - cdf(x) 

40 # instead of cdf(-x), defeating the purpose of numerical accuracy. Use 

41 # _normcdf(-x) instead. See https://github.com/google/jax/issues/17199 

42 

class beta(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Beta_distribution
    """

    @staticmethod
    def invfcn(x, alpha, beta):
        """Send `x` through the normal CDF, then through `_beta.beta.ppf`.

        `alpha` and `beta` are forwarded to the ppf as `a` and `b`.
        """
        prob = _normcdf(x)
        return _beta.beta.ppf(prob, a=alpha, b=beta)

51 

class dirichlet(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Dirichlet_distribution
    """

    signature = '(n),(n)->(n)'

    @classmethod
    def invfcn(cls, x, alpha):
        """Transform `x` with `loggamma.invfcn` and normalize along the last axis.

        The normalization is done in log space: the log-sum-exp over the last
        axis is subtracted before exponentiating, so the output sums to one
        along that axis.
        """
        log_components = loggamma.invfcn(x, alpha)
        log_total = jspecial.logsumexp(log_components, axis=-1, keepdims=True)
        return jnp.exp(log_components - log_total)

64 

65 # @classmethod 

66 # def _invfcn_tiny_alpha(cls, x, alpha): 

67 # q = _normcdf(x) 

68 # lnq = jnp.log(q) 

69 # lny = lnq / alpha 

70 # lnnorm = jspecial.logsumexp(lny, axis=-1, keepdims=True) 

71 # return jnp.exp(lny - lnnorm) 

72 

73 # For a -> 0: 

74 # 

75 # gamma.cdf(x, a) = P(a, x) 

76 # = gamma(a, x) / Gamma(a) 

77 # = int_0^x dt e^-t t^(a - 1) / (1 / a) 

78 # = a [t^a / a]_0^x 

79 # = a x^a / a 

80 # = x^a 

81 # 

82 # gamma.ppf(q, a) = P^-1(a, q) 

83 # = q^(1/a)

84 

class gamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Gamma_distribution
    """

    @staticmethod
    def _boundary(x):
        """Return the branch threshold for the dtype of `x`.

        Only float32 and float64 are listed; any other dtype raises KeyError.
        """
        # NOTE(review): presumably these sit near the point where
        # _normcdf(-x) underflows in each precision -- confirm.
        thresholds = {
            jnp.dtype(jnp.float32): 12,
            jnp.dtype(jnp.float64): 37,
        }
        return thresholds[x.dtype]

    @classmethod
    def invfcn(cls, x, alpha, beta):
        """Map `x` to a gamma variate with shape `alpha`, divided by `beta`.

        One of three evaluation branches is selected elementwise depending on
        where `x` falls relative to 0 and to `cls._boundary(x)`.
        """
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        boundary = cls._boundary(x)

        def negative(x, a):
            # x < 0: go through the ppf of the normal cdf
            return _gamma.gamma.ppf(_normcdf(x), a)

        def moderate(x, a):
            # 0 <= x < boundary: use the isf of the upper tail probability
            return _gamma.gamma.isf(_normcdf(-x), a)

        def large(x, a):
            # x >= boundary: dedicated helper, called with the negated input
            return _gamma._gammaisf_normcdf_large_neg_x(-x, a)

        # NOTE(review): an earlier TODO here worried that the x < 0 branch is
        # never selected because x < boundary is satisfied too. With the
        # argmax-based _piecewise_multiarg in this module, the first true
        # condition is the one selected, so x < 0 does pick the first
        # function. Confirm against the tests, then drop this note.
        conditions = [x < 0, x < boundary, x >= boundary]
        unscaled = _piecewise_multiarg(
            conditions,
            [negative, moderate, large],
            x, alpha,
        )
        return unscaled / beta

116 

class loggamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Gamma_distribution, `scipy.stats.loggamma`

    This is the distribution of the logarithm of a Gamma variable. The naming
    convention is the opposite of lognorm, which is the distribution of the
    exponential of a Normal variable.
    """

    @staticmethod
    def _boundary(x):
        """Return the same dtype-dependent threshold as `gamma._boundary`."""
        return gamma._boundary(x)

    @classmethod
    def invfcn(cls, x, alpha):
        """Map `x` to a log-gamma variate with shape `alpha`.

        One of three evaluation branches is selected elementwise depending on
        where `x` falls relative to 0 and to `cls._boundary(x)`.
        """
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        boundary = cls._boundary(x)

        def negative(x, alpha):
            # x < 0: go through the ppf of the normal cdf
            return _gamma.loggamma.ppf(_normcdf(x), alpha)

        def moderate(x, alpha):
            # 0 <= x < boundary: use the isf of the upper tail probability
            return _gamma.loggamma.isf(_normcdf(-x), alpha)

        def large(x, alpha):
            # x >= boundary: dedicated helper, called with the negated input
            return _gamma._loggammaisf_normcdf_large_neg_x(-x, alpha)

        # The conditions overlap; _piecewise_multiarg takes the first true one.
        conditions = [x < 0, x < boundary, x >= boundary]
        return _piecewise_multiarg(
            conditions,
            [negative, moderate, large],
            x, alpha,
        )

144 

145 # TODO scipy.stats.gamma has inaccurate logsf instead of using loggamma.sf, 

146 # open an issue 

147 

class invgamma(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Inverse-gamma_distribution
    """

    @staticmethod
    def _boundary(x):
        """Return the negated `gamma._boundary` threshold for the dtype of `x`."""
        return -gamma._boundary(x)

    @classmethod
    def invfcn(cls, x, alpha, beta):
        """Map `x` to an inverse-gamma variate with shape `alpha`, times `beta`.

        One of three evaluation branches is selected elementwise depending on
        where `x` falls relative to `cls._boundary(x)` (a negative number) and
        to 0.
        """
        x = jnp.asarray(x)
        x = x.astype(_jaxext.float_type(x))
        boundary = cls._boundary(x)

        def very_negative(x, a):
            # x < boundary: reciprocal of the gamma large-negative-x helper
            return 1 / _gamma._gammaisf_normcdf_large_neg_x(x, a)

        def negative(x, a):
            # boundary <= x < 0: go through the ppf of the normal cdf
            return _gamma.invgamma.ppf(_normcdf(x), a)

        def nonnegative(x, a):
            # x >= 0: use the isf of the upper tail probability
            return _gamma.invgamma.isf(_normcdf(-x), a)

        # The first two conditions overlap; _piecewise_multiarg takes the
        # first true one.
        conditions = [x < boundary, x < 0, x >= 0]
        unscaled = _piecewise_multiarg(
            conditions,
            [very_negative, negative, nonnegative],
            x, alpha,
        )
        return beta * unscaled

171 

def _piecewise_multiarg(conds, functions, *operands):
    """Apply, elementwise, the function matching the first true condition.

    `conds` is a sequence of boolean arrays and `functions` a sequence of
    callables, each taking the `operands`. The conditions are stacked on a
    new trailing axis and `argmax` over that axis gives the branch index, so
    when several conditions hold the earliest one is used.

    NOTE(review): when no condition holds (e.g. NaN input to comparisons) the
    argmax of an all-false row should be 0, i.e. the first function is
    applied -- confirm this is the intended fallback.
    """
    branch = jnp.argmax(jnp.stack(conds, axis=-1), axis=-1)
    return _vectorized_switch(branch, functions, *operands)

176 

177@functools.partial(jnp.vectorize, excluded=(1,)) 1feabcd

178def _vectorized_switch(index, branches, *operands): 1feabcd

179 return jax.lax.switch(index, branches, *operands) 1eabcd

180 

class halfcauchy(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Cauchy_distribution, `scipy.stats.halfcauchy`
    """

    @staticmethod
    def _ppf(p):
        """Quantile function: tan(pi p / 2)."""
        angle = jnp.pi * p / 2
        return jnp.tan(angle)

    @staticmethod
    def _isf(p):
        """Inverse survival function: 1 / tan(pi p / 2)."""
        angle = jnp.pi * p / 2
        return 1 / jnp.tan(angle)

    @classmethod
    def invfcn(cls, x, gamma):
        """Map `x` to a half-Cauchy variate scaled by `gamma`.

        Negative `x` uses the ppf of the normal cdf; the other values use the
        isf of the normal cdf evaluated at `-x`. Both expressions are computed
        for every element and `jnp.where` picks between them.
        """
        from_lower_tail = cls._ppf(_normcdf(x))
        from_upper_tail = cls._isf(_normcdf(-x))
        return gamma * jnp.where(x < 0, from_lower_tail, from_upper_tail)

200 

class halfnorm(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Half-normal_distribution
    """

    @staticmethod
    def _ppf(p):
        """Quantile function of the half-normal, via `jspecial.ndtri`."""
        # F(x) = 2 Φ(x) - 1
        # --> F⁻¹(p) = Φ⁻¹((1 + p) / 2)
        normal_prob = (1 + p) / 2
        return jspecial.ndtri(normal_prob)

    @staticmethod
    def _isf(p):
        """Inverse survival function of the half-normal, via `jspecial.ndtri`."""
        # Φ(-x) = 1 - Φ(x)
        # --> Φ⁻¹(1 - p) = -Φ⁻¹(p)
        # S(x) = 1 - F(x)
        # --> S⁻¹(p) = F⁻¹(1 - p)
        #            = Φ⁻¹((2 - p) / 2)
        #            = Φ⁻¹(1 - p / 2)
        #            = -Φ⁻¹(p / 2)
        normal_prob = p / 2
        return -jspecial.ndtri(normal_prob)

    @classmethod
    def invfcn(cls, x, sigma):
        """Map `x` to a half-normal variate scaled by `sigma`.

        Negative `x` uses the ppf of the normal cdf; the other values use the
        isf of the normal cdf evaluated at `-x`. Both expressions are computed
        for every element and `jnp.where` picks between them.
        """
        from_lower_tail = cls._ppf(_normcdf(x))
        from_upper_tail = cls._isf(_normcdf(-x))
        return sigma * jnp.where(x < 0, from_lower_tail, from_upper_tail)

229 

class uniform(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Continuous_uniform_distribution
    """

    @staticmethod
    def invfcn(x, a, b):
        """Map `x` to the interval from `a` to `b` through the normal CDF."""
        fraction = _normcdf(x)
        return a + (b - a) * fraction

238 

class lognorm(_distr.Distr):
    """
    https://en.wikipedia.org/wiki/Log-normal_distribution
    """

    @staticmethod
    def invfcn(x, mu, sigma):
        """Return exp(mu + sigma * x)."""
        log_value = mu + sigma * x
        return jnp.exp(log_value)

246 return jnp.exp(mu + sigma * x) 1eabcd