Coverage for src/lsqfitgp/copula/_gamma.py: 100%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_gamma.py 

2# 

3# Copyright (c) 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" 

21JAX-compatible implementation of the gamma and related distributions 

22""" 

23 

24import functools 1feabcd

25 

26from scipy import special 1feabcd

27import jax 1feabcd

28from jax.scipy import special as jspecial 1feabcd

29from jax import numpy as jnp 1feabcd

30import numpy 1feabcd

31 

32from .. import _jaxext 1feabcd

33 

34def _castto(func, type): 1feabcd

35 @functools.wraps(func) 1eabcd

36 def newfunc(*args, **kw): 1eabcd

37 return func(*args, **kw).astype(type) 1eabcd

38 return newfunc 1eabcd

39 

@jax.custom_jvp
def gammainccinv(a, y):
    """
    Inverse of the regularized upper incomplete gamma function Q(a, x) w.r.t. x.

    Returns x such that Q(a, x) = y. The value is computed by
    `scipy.special.gammainccinv` through `_jaxext.pure_callback_ufunc`, with
    the result dtype given by `_jaxext.float_type` of the input dtypes.
    Derivatives come from the rule attached with `gammainccinv.defjvp`.
    """
    a, y = jnp.asarray(a), jnp.asarray(y)
    out_dtype = _jaxext.float_type(a.dtype, y.dtype)
    host_func = _castto(special.gammainccinv, out_dtype)
    return _jaxext.pure_callback_ufunc(host_func, out_dtype, a, y)

47 

# Elementwise partial derivatives of the regularized upper incomplete gamma
# function Q(a, x) = jspecial.gammaincc(a, x): argument 0 is a, argument 1 is x.
dQ_da, dQ_dx = (
    _jaxext.elementwise_grad(jspecial.gammaincc, argnum)
    for argnum in (0, 1)
)

50 

@gammainccinv.defjvp
def gammainccinv_jvp(primals, tangents):
    """
    JVP of `gammainccinv` via the implicit function theorem.

    x = Q⁻¹(a, y) satisfies Q(a, x) = y, hence
        dx = (dy - ∂Q/∂a da) / (∂Q/∂x),
    with both partial derivatives evaluated at (a, x).
    """
    a, y = primals
    a_dot, y_dot = tangents

    x = gammainccinv(a, y)

    inv_slope = 1 / dQ_dx(a, x)  # dx/dy
    x_dot = inv_slope * y_dot

    # Only a floating `a` carries a usable tangent.
    # modern jax would be: getattr(at, 'dtype', jnp.float64) != jax.float0
    a_is_float = jnp.issubdtype(jnp.asarray(a).dtype, jnp.floating)
    if a_is_float:
        dx_da = -inv_slope * dQ_da(a, x)
        x_dot = x_dot + dx_da * a_dot

    return x, x_dot

68 

@jax.custom_jvp
def gammaincinv(a, y):
    """
    Inverse of the regularized lower incomplete gamma function P(a, x) w.r.t. x.

    Returns x such that P(a, x) = y. The value is computed by
    `scipy.special.gammaincinv` through `_jaxext.pure_callback_ufunc`, with
    the result dtype given by `_jaxext.float_type` of the input dtypes.
    Derivatives come from the rule attached with `gammaincinv.defjvp`.
    """
    a, y = jnp.asarray(a), jnp.asarray(y)
    out_dtype = _jaxext.float_type(a.dtype, y.dtype)
    host_func = _castto(special.gammaincinv, out_dtype)
    return _jaxext.pure_callback_ufunc(host_func, out_dtype, a, y)

76 

# Elementwise partial derivatives of the regularized lower incomplete gamma
# function P(a, x) = jspecial.gammainc(a, x): argument 0 is a, argument 1 is x.
dP_da, dP_dx = (
    _jaxext.elementwise_grad(jspecial.gammainc, argnum)
    for argnum in (0, 1)
)

79 

@gammaincinv.defjvp
def gammaincinv_jvp(primals, tangents):
    """
    JVP of `gammaincinv` via the implicit function theorem.

    x = P⁻¹(a, y) satisfies P(a, x) = y, hence
        dx = (dy - ∂P/∂a da) / (∂P/∂x),
    with both partial derivatives evaluated at (a, x).
    """
    a, y = primals
    a_dot, y_dot = tangents

    x = gammaincinv(a, y)

    inv_slope = 1 / dP_dx(a, x)  # dx/dy
    x_dot = inv_slope * y_dot

    # Only a floating `a` carries a usable tangent.
    # modern jax would be: getattr(at, 'dtype', jnp.float64) != jax.float0
    a_is_float = jnp.issubdtype(jnp.asarray(a).dtype, jnp.floating)
    if a_is_float:
        dx_da = -inv_slope * dP_da(a, x)
        x_dot = x_dot + dx_da * a_dot

    return x, x_dot

97 

98def _gammaisf_normcdf_large_neg_x(x, a): 1feabcd

99 logphi = lambda x: -1/2 * jnp.log(2 * jnp.pi) - 1/2 * jnp.square(x) - jnp.log(-x) 1eabcd

100 logq = logphi(x) 1eabcd

101 loggammaa = jspecial.gammaln(a) 1eabcd

102 f = lambda y: (a - 1) * jnp.log(y) - y - loggammaa - logq 1eabcd

103 f1 = lambda y: (a - 1) / y - 1 1eabcd

104 y0 = -logq 1eabcd

105 y1 = y0 - ((a - 1) * jnp.log(y0) - loggammaa) / ((a - 1) / y0 - 1) 1eabcd

106 return y1 1eabcd

107 

108 # TODO Improve the accuracy. I tried adding one Newton step more, but it 

109 # does not improve the accuracy. I probably have to add terms to the 

110 # approximations of Phi and Q. I could try first special.erfcx for Phi. 

111 

112 # x -> -∞, q -> 0+, y -> ∞ 

113 # q = Φ(x) ≈ -1/√2π exp(-x²/2)/x 

114 # q = Q(a, y) ≈ y^(a-1) e^-y / Γ(a) 

115 # gamma.isf(q, a) = Q⁻¹(a, q) 

116 # log q = -1/2 log 2π - x²/2 - log(-x) (1) 

117 # log q = (a - 1) log y - y - log Γ(a) (2) 

118 # f(y) = (a - 1) log y - y - log Γ(a) - log(q) 

119 # = 0 by (2) 

120 # f'(y) = (a - 1) / y - 1 

121 # y_0 = -log q by considering y -> ∞ 

122 # y_1 = y_0 - f(y_0) / f'(y_0) Newton step 

123 

124def _loggammaisf_normcdf_large_neg_x(x, a): 1feabcd

125 logphi = lambda x: -1/2 * jnp.log(2 * jnp.pi) - 1/2 * jnp.square(x) - jnp.log(-x) 1abcd

126 logq = logphi(x) 1abcd

127 loggammaa = jspecial.gammaln(a) 1abcd

128 g = lambda logy: (a - 1) * logy - jnp.exp(logy) - loggammaa - logq 1abcd

129 g1 = lambda logy: (a - 1) - jnp.exp(logy) 1abcd

130 logy0 = jnp.log(-logq) 1abcd

131 logy1 = logy0 - ((a - 1) * logy0 - loggammaa) / ((a - 1) + logq) 1abcd

132 return logy1 1abcd

133 

class gamma:
    """Gamma distribution with shape `a` and unit scale (JAX-compatible)."""

    @staticmethod
    def ppf(q, a):
        """Quantile function: the x with P(a, x) = q."""
        lower_inverse = gammaincinv(a, q)
        return lower_inverse

    @staticmethod
    def isf(q, a):
        """Inverse survival function: the x with Q(a, x) = q."""
        upper_inverse = gammainccinv(a, q)
        return upper_inverse

143 

class invgamma:
    """
    Inverse-gamma distribution with shape `a` and unit scale (JAX-compatible).

    If G is gamma-distributed with shape `a`, then X = 1/G, so
    P(X <= x) = P(G >= 1/x) = Q(a, 1/x).
    """

    @staticmethod
    def ppf(q, a):
        """Quantile function: solve Q(a, 1/x) = q for x."""
        reciprocal_x = gammainccinv(a, q)
        return 1 / reciprocal_x

    @staticmethod
    def isf(q, a):
        """Inverse survival function: solve P(a, 1/x) = q for x."""
        reciprocal_x = gammaincinv(a, q)
        return 1 / reciprocal_x

    @staticmethod
    def logpdf(x, a):
        """Log density: -(a + 1) log x - 1/x - log Γ(a)."""
        power_term = -(a + 1) * jnp.log(x)
        return power_term - 1 / x - jspecial.gammaln(a)

    @staticmethod
    def cdf(x, a):
        """Cumulative distribution function Q(a, 1/x)."""
        inverse_x = 1 / x
        return jspecial.gammaincc(a, inverse_x)

161 

class loggamma:
    """
    Log-gamma distribution with shape `c`: the law of log(G), where G is
    gamma-distributed with shape `c` and unit scale (JAX-compatible).
    """

    @staticmethod
    def ppf(q, c):
        """
        Quantile function: log P⁻¹(c, q).

        When P⁻¹(c, q) underflows below the smallest normal float, it uses the
        small-g expansion P(c, g) ≈ g^c / Γ(c + 1).
        """
        # scipy code:
        # g = sc.gammaincinv(c, q)
        # return _lazywhere(g < _XMIN, (g, q, c),
        #     lambda g, q, c: (np.log(q) + sc.gammaln(c+1))/c,
        #     f2=lambda g, q, c: np.log(g))
        g = gammaincinv(c, q)
        small = g < jnp.finfo(g.dtype).tiny
        tail = (jnp.log(q) + jspecial.gammaln(c + 1)) / c
        # jnp.where evaluates both branches. Feeding log a dummy 1 where the
        # tail branch is selected avoids log(0) when g underflows to 0, which
        # would otherwise give a NaN gradient (0 * inf) in reverse mode.
        safe_g = jnp.where(small, 1, g)
        return jnp.where(small, tail, jnp.log(safe_g))

    @staticmethod
    def isf(q, c):
        """
        Inverse survival function: log Q⁻¹(c, q).

        When Q⁻¹(c, q) underflows below the smallest normal float, it uses the
        small-g expansion P(c, g) = 1 - q ≈ g^c / Γ(c + 1).
        """
        # scipy code:
        # g = sc.gammainccinv(c, q)
        # return _lazywhere(g < _XMIN, (g, q, c),
        #     lambda g, q, c: (np.log1p(-q) + sc.gammaln(c+1))/c,
        #     f2=lambda g, q, c: np.log(g))
        g = gammainccinv(c, q)
        small = g < jnp.finfo(g.dtype).tiny
        tail = (jnp.log1p(-q) + jspecial.gammaln(c + 1)) / c
        # Double-where: see ppf; keeps log away from g == 0.
        safe_g = jnp.where(small, 1, g)
        return jnp.where(small, tail, jnp.log(safe_g))