Coverage for src / bartz / jaxext / scipy / special.py: 95%
59 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
1# bartz/src/bartz/jaxext/scipy/special.py
2#
3# Copyright (c) 2025, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Mockup of the :external:py:mod:`scipy.special` module."""
27from functools import wraps
29from jax import ShapeDtypeStruct, jit, pure_callback
30from jax import numpy as jnp
31from scipy.special import gammainccinv as scipy_gammainccinv
34def _float_type(*args):
35 """Determine the jax floating point result type given operands/types."""
36 t = jnp.result_type(*args) 1zAB
37 return jnp.sin(jnp.empty(0, t)).dtype 1zAB
40def _castto(func, dtype):
41 @wraps(func) 1zAB
42 def newfunc(*args, **kw): 1zAB
43 return func(*args, **kw).astype(dtype) 1CDEFzAGHIJKLMNOPQRSTUVWXYZ0123456789!#$%'()*+,-./:;=B
45 return newfunc 1zAB
@jit
def gammainccinv(a, y):
    """Survival function inverse of the Gamma(a, 1) distribution.

    The computation is delegated to scipy's `gammainccinv` through
    `pure_callback`. The output has the broadcast shape of `a` and `y` and the
    floating point type determined by `_float_type` from their dtypes.
    """
    out_dtype = _float_type(a.dtype, y.dtype)
    out_shape = jnp.broadcast_shapes(a.shape, y.shape)
    # The scipy function result is cast to the dtype declared to jax.
    host_func = _castto(scipy_gammainccinv, out_dtype)
    result_spec = ShapeDtypeStruct(out_shape, out_dtype)
    return pure_callback(host_func, result_spec, a, y, vmap_method='expand_dims')
58################# COPIED AND ADAPTED FROM JAX ##################
59# Copyright 2018 The JAX Authors.
60#
61# Licensed under the Apache License, Version 2.0 (the "License");
62# you may not use this file except in compliance with the License.
63# You may obtain a copy of the License at
64#
65# https://www.apache.org/licenses/LICENSE-2.0
66#
67# Unless required by applicable law or agreed to in writing, software
68# distributed under the License is distributed on an "AS IS" BASIS,
69# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70# See the License for the specific language governing permissions and
71# limitations under the License.
73import numpy as np
74from jax import debug_infs, lax
def ndtri(p):
    """Compute the inverse of the CDF of the Normal distribution function.

    This is a patch of `jax.scipy.special.ndtri`.

    Parameters
    ----------
    p
        Values at which to evaluate the inverse CDF. Its dtype, as determined
        by `jax.lax.dtype`, must be float32 or float64.

    Returns
    -------
    The result of `_ndtri` applied to `p`.

    Raises
    ------
    TypeError
        If the dtype of `p` is neither float32 nor float64.
    """
    dtype = lax.dtype(p)
    if dtype not in (jnp.float32, jnp.float64):
        # The message names the actual parameter (`p`); the supported types
        # are listed in the docstring above.
        msg = f'p.dtype={dtype} is not supported, see docstring for supported types.'
        raise TypeError(msg)
    return _ndtri(p)
89def _ndtri(p):
90 # Constants used in piece-wise rational approximations. Taken from the cephes
91 # library:
92 # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
93 p0 = list( 1abcdefghijklmnopqrstuvwx
94 reversed(
95 [
96 -5.99633501014107895267e1,
97 9.80010754185999661536e1,
98 -5.66762857469070293439e1,
99 1.39312609387279679503e1,
100 -1.23916583867381258016e0,
101 ]
102 )
103 )
104 q0 = list( 1abcdefghijklmnopqrstuvwx
105 reversed(
106 [
107 1.0,
108 1.95448858338141759834e0,
109 4.67627912898881538453e0,
110 8.63602421390890590575e1,
111 -2.25462687854119370527e2,
112 2.00260212380060660359e2,
113 -8.20372256168333339912e1,
114 1.59056225126211695515e1,
115 -1.18331621121330003142e0,
116 ]
117 )
118 )
119 p1 = list( 1abcdefghijklmnopqrstuvwx
120 reversed(
121 [
122 4.05544892305962419923e0,
123 3.15251094599893866154e1,
124 5.71628192246421288162e1,
125 4.40805073893200834700e1,
126 1.46849561928858024014e1,
127 2.18663306850790267539e0,
128 -1.40256079171354495875e-1,
129 -3.50424626827848203418e-2,
130 -8.57456785154685413611e-4,
131 ]
132 )
133 )
134 q1 = list( 1abcdefghijklmnopqrstuvwx
135 reversed(
136 [
137 1.0,
138 1.57799883256466749731e1,
139 4.53907635128879210584e1,
140 4.13172038254672030440e1,
141 1.50425385692907503408e1,
142 2.50464946208309415979e0,
143 -1.42182922854787788574e-1,
144 -3.80806407691578277194e-2,
145 -9.33259480895457427372e-4,
146 ]
147 )
148 )
149 p2 = list( 1abcdefghijklmnopqrstuvwx
150 reversed(
151 [
152 3.23774891776946035970e0,
153 6.91522889068984211695e0,
154 3.93881025292474443415e0,
155 1.33303460815807542389e0,
156 2.01485389549179081538e-1,
157 1.23716634817820021358e-2,
158 3.01581553508235416007e-4,
159 2.65806974686737550832e-6,
160 6.23974539184983293730e-9,
161 ]
162 )
163 )
164 q2 = list( 1abcdefghijklmnopqrstuvwx
165 reversed(
166 [
167 1.0,
168 6.02427039364742014255e0,
169 3.67983563856160859403e0,
170 1.37702099489081330271e0,
171 2.16236993594496635890e-1,
172 1.34204006088543189037e-2,
173 3.28014464682127739104e-4,
174 2.89247864745380683936e-6,
175 6.79019408009981274425e-9,
176 ]
177 )
178 )
180 dtype = lax.dtype(p).type 1abcdefghijklmnopqrstuvwx
181 shape = jnp.shape(p) 1abcdefghijklmnopqrstuvwx
183 def _create_polynomial(var, coeffs): 1abcdefghijklmnopqrstuvwx
184 """Compute n_th order polynomial via Horner's method."""
185 coeffs = np.array(coeffs, dtype) 1abcdefghijklmnopqrstuvwx
186 if not coeffs.size: 1abcdefghijklmnopqrstuvwx
187 return jnp.zeros_like(var) 1abcdefghijklmnopqrstuvwx
188 return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var 1abcdefghijklmnopqrstuvwx
190 maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p) 1abcdefghijklmnopqrstuvwx
191 # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
192 # later on. The result from the computation when p == 0 is not used so any
193 # number that doesn't result in NaNs is fine.
194 sanitized_mcp = jnp.where( 1abcdefghijklmnopqrstuvwx
195 maybe_complement_p == dtype(0.0),
196 jnp.full(shape, dtype(0.5)),
197 maybe_complement_p,
198 )
200 # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
201 w = sanitized_mcp - dtype(0.5) 1abcdefghijklmnopqrstuvwx
202 ww = lax.square(w) 1abcdefghijklmnopqrstuvwx
203 x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0)) 1abcdefghijklmnopqrstuvwx
204 x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi)) 1abcdefghijklmnopqrstuvwx
206 # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
207 # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
208 # arrays based on whether p < exp(-32).
209 z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp)) 1abcdefghijklmnopqrstuvwx
210 first_term = z - lax.log(z) / z 1abcdefghijklmnopqrstuvwx
211 second_term_small_p = ( 1abcdefghijklmnopqrstuvwx
212 _create_polynomial(dtype(1.0) / z, p2)
213 / _create_polynomial(dtype(1.0) / z, q2)
214 / z
215 )
216 second_term_otherwise = ( 1abcdefghijklmnopqrstuvwx
217 _create_polynomial(dtype(1.0) / z, p1)
218 / _create_polynomial(dtype(1.0) / z, q1)
219 / z
220 )
221 x_for_small_p = first_term - second_term_small_p 1abcdefghijklmnopqrstuvwx
222 x_otherwise = first_term - second_term_otherwise 1abcdefghijklmnopqrstuvwx
224 x = jnp.where( 1abcdefghijklmnopqrstuvwx
225 sanitized_mcp > dtype(np.exp(-2.0)),
226 x_for_big_p,
227 jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise),
228 )
230 x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x) 1abcdefghijklmnopqrstuvwx
231 with debug_infs(False): 1abcdefghijklmnopqrstuvwx
232 infinity = jnp.full(shape, dtype(np.inf)) 1abcdefghijklmnopqrstuvwx
233 neg_infinity = -infinity 1abcdefghijklmnopqrstuvwx
234 return jnp.where( 1abcdefghijklmnopqrstuvwx
235 p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x)
236 )
239################################################################