Coverage for src/lsqfitgp/_special/_gamma.py: 100%

23 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_special/_gamma.py 

2# 

3# Copyright (c) 2022, 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20from jax import numpy as jnp 1feabcd

21from jax.scipy import special as jspecial 1feabcd

22 

23from .. import _jaxext 1feabcd

24 

def sgngamma(x):
    """
    Sign of the gamma function Γ(x), returned as +1 or -1.

    Γ(x) is positive for x > 0. For negative x it is positive where
    x mod 2 (floored modulo) is < 1, i.e. on [-2, -1), [-4, -3), ..., and
    negative elsewhere. At the poles (nonpositive integers) this returns +1
    for even and -1 for odd integers. NaN maps to -1.
    """
    above_zero = x > 0
    in_positive_strip = x % 2 < 1
    return jnp.where(above_zero | in_positive_strip, 1, -1)

27 

def gamma(x):
    """
    Gamma function Γ(x) for real x, including negative non-integer values.

    The magnitude is exp(gammaln(x)), where gammaln gives log|Γ(x)|. The
    sign is supplied by sgngamma.
    """
    log_abs = jspecial.gammaln(x)
    return sgngamma(x) * jnp.exp(log_abs)

30 

def poch(x, k):
    """
    Pochhammer symbol (rising factorial) (x)_k = Γ(x + k) / Γ(x).

    The magnitude is exp(log|Γ(x + k)| - log|Γ(x)|) (DLMF 5.2.5). gammaln
    only gives the log of the absolute value, so the sign is restored
    explicitly with sgngamma. Otherwise negative non-integer arguments would
    come out with the wrong sign, e.g. poch(-0.5, 1) must be -0.5, not 0.5.
    Where x and x + k are both positive, both signs are +1.
    """
    log_ratio = jspecial.gammaln(x + k) - jspecial.gammaln(x)
    # TODO does not handle properly the case where both x and x + k are
    # nonpositive integers: gammaln is +inf at both, so the difference is nan
    return sgngamma(x + k) * sgngamma(x) * jnp.exp(log_ratio)

35 

def gamma_incr(x, e):
    """
    Compute Γ(x+e) / (Γ(x)Γ(1+e)) - 1 accurately for x >= 2 and |e| < 1/2

    x and e are broadcast together; the result has their broadcast shape.
    Their rank is read with jnp.ndim, so inputs without an .ndim attribute
    (e.g. Python scalars) do not break the computation here.

    Method:
        G(x + e) / G(x)G(1+e) - 1 =
        = expm1(log(G(x + e) / G(x)G(1+e))) =
        = expm1(log G(x + e) - log G(x) - log G(1 + e))
    log G(x + e) - log G(x) is evaluated with its Taylor series in e,
    sum_{j=0}^{n-1} ψ^(j)(x) / (j+1)! e^(j+1), and log G(1 + e) with
    gammaln1.

    Like gammaln1, I thought that writing as log Γ(1+x+e) - log Γ(1+x) +
    - log1p(e/x) would increase accuracy, instead it deteriorates.
    """
    t = _jaxext.float_type(x, e)
    # n such that 1/2^n 1/n! d^n/dx^n log G(x) |_x=2 < eps
    n = 23 if t == jnp.float64 else 10

    # the leading axis indexes the series term; the trailing singleton axes
    # let it broadcast against x and e. jnp.ndim, unlike the .ndim
    # attribute, also accepts Python scalars and lists.
    ndim = max(jnp.ndim(x), jnp.ndim(e))
    k = jnp.arange(n).reshape((n,) + (1,) * ndim)

    # coef[j] = ψ^(j)(x) / (j + 1)!, the factorial accumulated in dtype t
    coef = jspecial.polygamma(k, x) / jnp.cumprod(1 + k, 0, t)

    # sum_j coef[j] e^(j+1); polyval wants the highest power first
    lgamma_diff = e * jnp.polyval(coef[::-1], e)
    return jnp.expm1(lgamma_diff - gammaln1(e))

57 

def gammaln1(x):
    """
    Compute log Γ(1 + x) accurately for |x| <= 1/2.

    Evaluates the Maclaurin series log Γ(1 + x) = sum_k c_k x^(k + 1), with
    the coefficients c_k taken from _gammaln1_coef_1 and truncated to the
    first 48 terms.
    """
    dtype = _jaxext.float_type(x)
    nterms = 48  # found by trial and error
    # polyval expects the coefficient of the highest power first
    highest_first = list(reversed(_gammaln1_coef_1[:nterms]))
    coef = jnp.array(highest_first, dtype)
    # I thought that writing this as log Γ(2+x) - log1p(x) would be more
    # accurate but it isn't, probably there's a cancellation for some values
    # of x
    return x * jnp.polyval(coef, x)

68 

# Maclaurin coefficients of log Γ(1 + x) = sum_k c_k x^(k + 1), where
# c_k = ψ^(k)(1) / (k + 1)! and ψ^(k) is the polygamma function of order k.
# c_0 = -γ (Euler-Mascheroni constant), c_1 = π²/12; the entries alternate in
# sign and their magnitudes approach 1/(k + 1). Generated in extended
# precision by _gen_gammaln1_coef; gammaln1 uses only the first 48 entries.
_gammaln1_coef_1 = [ # = _gen_gammaln1_coef(53, 1)
    -0.5772156649015329,
    0.8224670334241132,
    -0.40068563438653143,
    0.27058080842778454,
    -0.20738555102867398,
    0.1695571769974082,
    -0.1440498967688461,
    0.12550966952474304,
    -0.11133426586956469,
    0.1000994575127818,
    -0.09095401714582904,
    0.083353840546109,
    -0.0769325164113522,
    0.07143294629536133,
    -0.06666870588242046,
    0.06250095514121304,
    -0.058823978658684585,
    0.055555767627403614,
    -0.05263167937961666,
    0.05000004769810169,
    -0.047619070330142226,
    0.04545455629320467,
    -0.04347826605304026,
    0.04166666915034121,
    -0.04000000119214014,
    0.03846153903467518,
    -0.037037037312989324,
    0.035714285847333355,
    -0.034482758684919304,
    0.03333333336437758,
    -0.03225806453115042,
    0.03125000000727597,
    -0.030303030306558044,
    0.029411764707594344,
    -0.02857142857226011,
    0.027777777778181998,
    -0.027027027027223673,
    0.02631578947377995,
    -0.025641025641072283,
    0.025000000000022737,
    -0.024390243902450117,
    0.023809523809529224,
    -0.023255813953491015,
    0.02272727272727402,
    -0.022222222222222855,
    0.021739130434782917,
    -0.021276595744681003,
    0.02083333333333341,
    -0.02040816326530616,
    0.020000000000000018,
    -0.019607843137254912,
    0.019230769230769235,
    -0.01886792452830189,
]

124 

def _gen_gammaln1_coef(n, x): # pragma: no cover
    """
    Compute the first n Taylor coefficients of log Γ(x + e) - log Γ(x) in e.

    Entry k is ψ^(k)(x) / (k + 1)!, the coefficient of e^(k + 1). It is
    evaluated with mpmath at 32 decimal digits of working precision and then
    rounded to a Python float. With x = 1 these are the coefficients of
    log Γ(1 + e) stored in _gammaln1_coef_1.
    """
    import mpmath as mp
    coefs = []
    with mp.workdps(32):
        for k in range(n):
            term = mp.polygamma(k, x) / mp.fac(k + 1)
            coefs.append(float(term))
    return coefs