Coverage for src/lsqfitgp/_special/_gamma.py: 100%
23 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_special/_gamma.py
2#
3# Copyright (c) 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20from jax import numpy as jnp 1feabcd
21from jax.scipy import special as jspecial 1feabcd
23from .. import _jaxext 1feabcd
def sgngamma(x):
    """
    Sign of the gamma function, as +1 or -1.

    Γ(x) is positive for x > 0. On the negative axis it alternates sign
    between consecutive integers: it is positive where x mod 2 lies in
    [0, 1) and negative where it lies in [1, 2). At the poles (nonpositive
    integers) the value is whatever the same rule gives.
    """
    on_positive_axis = x > 0
    in_positive_strip = x % 2 < 1
    return jnp.where(on_positive_axis | in_positive_strip, 1, -1)
def gamma(x):
    """
    Gamma function Γ(x), evaluated as sgn Γ(x) * exp(log |Γ(x)|).
    """
    # jspecial.gammaln yields log |Γ(x)|, so the sign is restored separately
    abs_gamma = jnp.exp(jspecial.gammaln(x))
    return sgngamma(x) * abs_gamma
def poch(x, k):
    """
    Pochhammer symbol (rising factorial) (x)_k = Γ(x + k) / Γ(x).

    The magnitude is computed in log space (DLMF 5.2.5) as
    exp(log |Γ(x + k)| - log |Γ(x)|). Because `jspecial.gammaln` discards
    the sign of Γ, the signs of the two gamma functions are reapplied with
    `sgngamma`.
    """
    magnitude = jnp.exp(jspecial.gammaln(x + k) - jspecial.gammaln(x))  # DLMF 5.2.5
    # Without this factor the result would be |Γ(x + k) / Γ(x)|, which is
    # wrong for negative arguments, e.g. poch(-0.5, 1) must be -0.5, not 0.5.
    sign = sgngamma(x + k) * sgngamma(x)
    return sign * magnitude
    # TODO does not handle properly special cases with x and/or k nonpositive
    # integers
def gamma_incr(x, e):
    """
    Compute Γ(x+e) / (Γ(x)Γ(1+e)) - 1 accurately for x >= 2 and |e| < 1/2

    The result has the broadcast shape of `x` and `e`. The number of series
    terms is chosen for x = 2 (see below), so accuracy is only targeted in
    the stated domain.
    """

    # G(x + e) / G(x)G(1+e) - 1 =
    # = expm1(log(G(x + e) / G(x)G(1+e))) =
    # = expm1(log G(x + e) - log G(x) - log G(1 + e))
    #
    # log G(x + e) - log G(x) is expanded as a Taylor series in e:
    # sum_{k>=0} psi^(k)(x) e^(k+1) / (k+1)!

    t = _jaxext.float_type(x, e)
    n = 23 if t == jnp.float64 else 10
    # n such that 1/2^n 1/n! d^n/dx^n log G(x) |_x=2 < eps

    # Leading axis of length n indexes the series terms; trailing singleton
    # axes let it broadcast against x and e. jnp.ndim, unlike the .ndim
    # attribute, also works on Python scalars.
    k = jnp.arange(n).reshape((n,) + (1,) * max(jnp.ndim(x), jnp.ndim(e)))
    coef = jspecial.polygamma(k, x)
    fact = jnp.cumprod(1 + k, 0, t)  # (k + 1)!
    coef /= fact
    # polyval wants the highest-degree coefficient first; the extra factor e
    # shifts the powers to e^(k+1)
    log_ratio = e * jnp.polyval(coef[::-1], e)

    # NOTE: as with gammaln1, rewriting log G(x + e) - log G(x) as
    # log G(1+x+e) - log G(1+x) - log1p(e/x) was tried to increase
    # accuracy, but it deteriorates it instead.
    return jnp.expm1(log_ratio - gammaln1(e))
def gammaln1(x):
    """
    Compute log Γ(1 + x) accurately for |x| <= 1/2.

    Evaluates the truncated Maclaurin series x * sum_k c_k x^k, where c_k
    are the first terms of the precomputed table `_gammaln1_coef_1`.
    """
    dtype = _jaxext.float_type(x)
    n_terms = 48  # found by trial and error
    # jnp.polyval expects the highest-degree coefficient first
    highest_first = jnp.array(_gammaln1_coef_1[:n_terms][::-1], dtype)
    series = jnp.polyval(highest_first, x)
    return x * series

    # Writing this as log Γ(2+x) - log1p(x) was expected to be more
    # accurate, but it isn't; probably there is a cancellation for some
    # values of x.
# Maclaurin coefficients of log Γ(1 + x), as used by gammaln1:
# log Γ(1 + x) = x * sum_k _gammaln1_coef_1[k] * x**k.
# Entry k is polygamma(k, 1) / (k + 1)!, i.e. the coefficient of x**(k + 1)
# (entry 0 is -Euler's gamma). The magnitudes decay only like ~1/(k + 1),
# so the series is only useful for small |x|.
_gammaln1_coef_1 = [ # = _gen_gammaln1_coef(53, 1)
    -0.5772156649015329,
    0.8224670334241132,
    -0.40068563438653143,
    0.27058080842778454,
    -0.20738555102867398,
    0.1695571769974082,
    -0.1440498967688461,
    0.12550966952474304,
    -0.11133426586956469,
    0.1000994575127818,
    -0.09095401714582904,
    0.083353840546109,
    -0.0769325164113522,
    0.07143294629536133,
    -0.06666870588242046,
    0.06250095514121304,
    -0.058823978658684585,
    0.055555767627403614,
    -0.05263167937961666,
    0.05000004769810169,
    -0.047619070330142226,
    0.04545455629320467,
    -0.04347826605304026,
    0.04166666915034121,
    -0.04000000119214014,
    0.03846153903467518,
    -0.037037037312989324,
    0.035714285847333355,
    -0.034482758684919304,
    0.03333333336437758,
    -0.03225806453115042,
    0.03125000000727597,
    -0.030303030306558044,
    0.029411764707594344,
    -0.02857142857226011,
    0.027777777778181998,
    -0.027027027027223673,
    0.02631578947377995,
    -0.025641025641072283,
    0.025000000000022737,
    -0.024390243902450117,
    0.023809523809529224,
    -0.023255813953491015,
    0.02272727272727402,
    -0.022222222222222855,
    0.021739130434782917,
    -0.021276595744681003,
    0.02083333333333341,
    -0.02040816326530616,
    0.020000000000000018,
    -0.019607843137254912,
    0.019230769230769235,
    -0.01886792452830189,
]
125def _gen_gammaln1_coef(n, x): # pragma: no cover 1feabcd
126 """ compute Taylor coefficients of log Γ(x) """
127 import mpmath as mp
128 with mp.workdps(32):
129 return [float(mp.polygamma(k, x) / mp.fac(k + 1)) for k in range(n)]