Coverage for src/lsqfitgp/copula/_distr.py: 95%
213 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_distr.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" define Distr and distribution """
22import abc 1feabcd
23import functools 1feabcd
24import collections 1feabcd
25import numbers 1feabcd
26import inspect 1feabcd
27import types 1feabcd
28import math 1feabcd
30import gvar 1feabcd
31import numpy 1feabcd
32import jax 1feabcd
33from jax import numpy as jnp 1feabcd
35from .. import _gvarext 1feabcd
36from .. import _array 1feabcd
37from .. import _signature 1feabcd
38from . import _base 1feabcd
######### The following 5 functions are adapted from numpy.lib.mixins #########
def _disables_array_ufunc(obj):
    """Tell whether `obj` opts out of ufunc dispatch via ``__array_ufunc__ = None``.

    Objects without an ``__array_ufunc__`` attribute, or with any value other
    than None, do not opt out.
    """
    try:
        hook = obj.__array_ufunc__
    except AttributeError:
        return False
    return hook is None
def _binary_method(ufunc, name):
    """Make a forward binary special method (like ``__add__``) from `ufunc`.

    The returned method evaluates to ``NotImplemented`` when the other operand
    sets ``__array_ufunc__ = None``, and to ``ufunc(self, other)`` otherwise.
    Its ``__name__`` is ``'__<name>__'``.
    """
    def func(self, other):
        opted_out = _disables_array_ufunc(other)
        return NotImplemented if opted_out else ufunc(self, other)
    func.__name__ = f'__{name}__'
    return func
def _reflected_binary_method(ufunc, name):
    """Make a reflected binary special method (like ``__radd__``) from `ufunc`.

    The returned method evaluates to ``NotImplemented`` when the other operand
    sets ``__array_ufunc__ = None``, and to ``ufunc(other, self)`` otherwise.
    Its ``__name__`` is ``'__r<name>__'``.
    """
    def func(self, other):
        if not _disables_array_ufunc(other):
            # operands swapped: `other` is the left-hand side of the expression
            return ufunc(other, self)
        return NotImplemented
    func.__name__ = f'__r{name}__'
    return func
def _numeric_methods(ufunc, name):
    """Return the ``(forward, reflected)`` pair of binary special methods
    backed by `ufunc`, e.g. ``(__add__, __radd__)`` for ``name='add'``."""
    forward = _binary_method(ufunc, name)
    reflected = _reflected_binary_method(ufunc, name)
    return forward, reflected
def _unary_method(ufunc, name):
    """Make a unary special method (like ``__neg__``) that returns
    ``ufunc(self)``; its ``__name__`` is ``'__<name>__'``."""
    def func(self):
        result = ufunc(self)
        return result
    func.__name__ = f'__{name}__'
    return func
###############################################################################
class Distr(_base.DistrBase):
    r"""

    Abstract base class to represent probability distributions.

    A `Distr` object represents a probability distribution of a variable in
    :math:`\mathbb R^n`, and provides a transformation function from a
    (multivariate) Normal variable to the target random variable.

    The main functionality is defined in `DistrBase`. The additional attributes
    and methods `params`, `signature`, and `invfcn` are not intended for common
    usage.

    Parameters
    ----------
    *params : tuple of scalar, array or Distr
        The parameters of the distribution. If the parameters have leading axes
        other than those required, the distribution is repeated i.i.d.
        over those axes. If a parameter is an instance of `Distr` itself, it
        is a random parameter and its distribution is accounted for.
    shape : int or tuple of int
        The shape of the array of i.i.d. variables to be represented, scalar by
        default. If the variable is multivariate, this shape is prepended as
        leading axes of the array. This shape broadcasts with the non-core
        shapes of the parameters.
    name : str, optional
        If specified, the distribution is defined for usage with
        `gvar.BufferDict` using `gvar.BufferDict.add_distribution`, and for
        convenience the constructor returns an array of gvars with the
        appropriate shape instead of the `Distr` object. See `add_distribution`.

    Returns
    -------
    If `name` is None (default):

    distr : Distr
        An object representing the distribution.

    Otherwise:

    gvars : array of gvars
        An array of primary gvars that can be set as value in a
        `gvar.BufferDict` under a key that uses the just defined name.

    Attributes
    ----------
    params : tuple
        The parameters as passed to the constructor.
    signature : Signature
        An object representing the signature of `invfcn`. This is a class
        attribute.

    Methods
    -------
    invfcn : classmethod
        Transformation function from a (multivariate) Normal variable to the
        target random variable.

    Examples
    --------

    Use directly with `gvar.BufferDict` by setting `name`:

    >>> copula = gvar.BufferDict({
    ...     'A(x)': lgp.copula.beta(1, 1, name='A'),
    ...     'B(y)': lgp.copula.beta(3, 5, name='B'),
    ... })
    >>> copula['x']
    0.50(40)
    >>> copula['y']
    0.36(18)

    Corresponding "unrolled" usage:

    >>> A = lgp.copula.beta(1, 1)
    >>> B = lgp.copula.beta(3, 5)
    >>> A.add_distribution('A')
    >>> B.add_distribution('B')
    >>> copula = gvar.BufferDict({
    ...     'A(x)': A.gvars(),
    ...     'B(y)': B.gvars(),
    ... })

    Notice that, although the name used for `add_distribution` must be globally
    unique, for convenience it is permitted to redefine the same distribution
    family with the same parameters, even from another `Distr` instance.

    To generate sensible names automatically and avoid repeating them, use
    `makedict`:

    >>> lgp.copula.makedict({
    ...     'x': lgp.copula.beta(1, 1),
    ...     'y': lgp.copula.beta(3, 5),
    ... })
    BufferDict({'__copula_beta{1, 1}(x)': 0.0(1.0), '__copula_beta{3, 5}(y)': 0.0(1.0)})

    Define a distribution with a random parameter:

    >>> X = lgp.copula.halfnorm(np.sqrt(lgp.copula.invgamma(1, 1)))
    >>> X
    halfnorm(sqrt(invgamma(1, 1)))

    Now `X` represents the model

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma &\sim \mathrm{HalfNorm}(\sigma).

    In general it is possible to transform a `Distr` with `numpy` ufuncs and
    continuous arithmetic operations.

    Repeated usage of `Distr` instances for random parameters will share
    those parameters in the distributions. The following code:

    >>> sigma2 = lgp.copula.invgamma(1, 1)
    >>> X = lgp.copula.halfnorm(np.sqrt(sigma2))
    >>> Y = lgp.copula.halfcauchy(np.sqrt(sigma2))

    corresponds to the model

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma &\sim \mathrm{HalfNorm}(\sigma), \\
        Y \mid \sigma &\sim \mathrm{HalfCauchy}(\sigma),

    with the same parameter :math:`\sigma^2` shared between the two
    distributions. However, if the distributions are now put into a
    `gvar.BufferDict`, with

    >>> sigma2.add_distribution('distr_sigma2')
    >>> X.add_distribution('distr_X')
    >>> Y.add_distribution('distr_Y')
    >>> bd = gvar.BufferDict({
    ...     'distr_sigma2(sigma2)': sigma2.gvars(),
    ...     'distr_X(X)': X.gvars(),
    ...     'distr_Y(Y)': Y.gvars(),
    ... })

    then this relationship breaks down; the model represented by the dictionary
    `bd` is

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma_X &\sim \mathrm{HalfNorm}(\sigma_X), \quad
            & \sigma_X^2 &\sim \mathrm{InvGamma}(1, 1), \\
        Y \mid \sigma_Y &\sim \mathrm{HalfCauchy}(\sigma_Y), \quad
            & \sigma_Y^2 &\sim \mathrm{InvGamma}(1, 1),

    with separate, independent parameters :math:`\sigma,\sigma_X,\sigma_Y`,
    because each dictionary entry is evaluated separately. Indeed, trying to do
    this with `makedict` will raise an error:

    >>> bd = lgp.copula.makedict({'sigma2': sigma2, 'X': X, 'Y': Y})
    ValueError: cross-key occurrences of object(s):
    invgamma with id 6201535248: <sigma2>, <X.0.0>, <Y.0.0>

    To use all the distributions at once while preserving the relationships,
    put them into a container of choice and wrap it as a `Copula` object:

    >>> sigmaXY = lgp.copula.Copula({'sigma2': sigma2, 'X': X, 'Y': Y})

    The `Copula` provides a `partial_invfcn` function that maps Normal variables
    to a structure of the desired variates, with the same layout as the
    structure it was built from. The whole `Copula` can be used in
    `gvar.BufferDict`:

    >>> bd = lgp.copula.makedict({'sigmaXY': sigmaXY})
    >>> bd
    BufferDict({"__copula_{'sigma2': invgamma{1, 1}, 'X': halfnorm{sqrt{_Path{path=[{DictKey{key='sigma2'},}]}}}, 'Y': halfcauchy{sqrt{_Path{path=[{DictKey{key='sigma2'},}]}}}}(sigmaXY)": array([0.0(1.0), 0.0(1.0), 0.0(1.0)], dtype=object)})
    >>> bd['sigmaXY']
    {'sigma2': 1.4(1.7), 'X': 0.81(89), 'Y': 1.2(1.7)}
    >>> gvar.corr(bd['sigmaXY']['X'], bd['sigmaXY']['Y'])
    0.21950577757757836

    Although the actual dictionary value is a flat array, getting the unwrapped
    key reproduces the original structure.

    To apply arbitrary transformations, use `invfcn` manually:

    >>> @functools.partial(lgp.gvar_gufunc, signature='(n)->(n)')
    ... @functools.partial(jnp.vectorize, signature='(n)->(n)')
    ... def model_invfcn(normal_params):
    ...     sigma2 = lgp.copula.invgamma.invfcn(normal_params[0], 1, 1)
    ...     sigma = jnp.sqrt(sigma2)
    ...     X = lgp.copula.halfnorm.invfcn(normal_params[1], sigma)
    ...     Y = lgp.copula.halfcauchy.invfcn(normal_params[2], sigma)
    ...     return jnp.stack([sigma, X, Y])

    The `jax.numpy.vectorize` decorator makes `model_invfcn` support
    broadcasting on additional input axes, while `gvar_gufunc` makes it accept
    gvars as input.

    See also
    --------
    DistrBase, Copula, gvar.BufferDict.uniform

    Notes
    -----
    Concrete subclasses must define `invfcn`. They must also set the class
    attribute `signature` to the numpy gufunc signature string of `invfcn`,
    unless `invfcn` maps scalars to a scalar and has only plain positional
    parameters; in that case the signature is inferred from the number of
    parameters. `invfcn` must be vectorized. If the class attribute `dtype` is
    not set, it defaults to float64 as canonicalized by jax.

    Only single-output, non-generalized ufuncs can be applied to a `Distr`.
    Others, such as `numpy.divmod` or `numpy.matmul`, are declined and raise
    `TypeError`.

    """

    @classmethod
    @abc.abstractmethod
    def invfcn(cls, x, *params):
        r"""

        Normal to desired distribution transformation.

        Maps a (multivariate) Normal variable to a variable with the desired
        marginal distribution. In symbols: :math:`y = F^{-1}(\Phi(x))`. This
        function is a generalized ufunc, jax traceable, vmappable one time, and
        differentiable one time. The signature is accessible through the
        class attribute `signature`.

        Parameters
        ----------
        x : array_like
            The input Normal variable.
        *params : array_like
            The parameters of the distribution.

        Returns
        -------
        y : array_like
            The output variable with the desired marginal distribution.

        """
        pass

    def _get_x_core_shape(self, *preprocessed_params):
        # Evaluate the signature with `x` omitted (None) and read back the core
        # shape of the first input, i.e. of `x`. Presumably the signature can
        # resolve it from the parameters alone (UFunc overrides this method);
        # TODO confirm in _signature.
        sig = self.signature.eval(None, *preprocessed_params)
        return sig.core_in_shapes[0]

    def _eval_shapes(self, shape):
        # Compute and store the shapes of this distribution:
        # - `_in_shape_1`: shape of the Normal input `x` consumed by this node
        #   alone, as resolved by the signature evaluation;
        # - `distrshape`: core output shape;
        # - `shape`: full output shape;
        # and, via `_compute_in_shape`, `in_shape` and `_ancestor_count`.

        # check number of parameters
        if self.signature.nin != 1 + len(self.params):
            raise TypeError(f'{self.__class__.__name__} distribution has '
                f'{self.signature.nin - 1} parameters, but {len(self.params)} '
                'parameters were passed to the constructor')

        # convert shape to tuple
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        else:
            shape = tuple(shape)

        # make sure parameters have a shape (Distr parameters are kept as they
        # are if they expose `shape`)
        array_params = [
            p if hasattr(p, 'shape') else jnp.asarray(p)
            for p in self.params
        ]

        # parse signature of cls.invfcn, using a shape-only placeholder for x
        # whose leading axes are the requested `shape`
        x_core_shape = self._get_x_core_shape(*array_params)
        x = jax.ShapeDtypeStruct(shape + x_core_shape, 'd')
        sig = self.signature.eval(x, *array_params)
        self._in_shape_1 = sig.in_shapes[0]
        self.distrshape, = sig.core_out_shapes
        self.shape, = sig.out_shapes

        self._compute_in_shape()

    def _compute_in_shape(self):
        # Count the Normal variables needed to sample this distribution: those
        # of this node plus those of its random (Distr) parameters, visited
        # with one shared `cache`. `in_shape` is () when a single variable
        # suffices, else the 1-d shape (in_size,).
        in_size = math.prod(self._in_shape_1)
        cache = set()
        for p in self.params:
            if isinstance(p, __class__):
                in_size += p._compute_in_size(cache)
        if in_size == 1:
            self.in_shape = ()
        else:
            self.in_shape = in_size,
        # number of entries recorded in `cache` during the traversal;
        # presumably the distinct ancestor nodes (recorded by
        # DistrBase._compute_in_size) -- verify in _base
        self._ancestor_count = len(cache)

    def _compute_in_size(self, cache):
        # Recursive counterpart of `_compute_in_shape` for parameter nodes.
        # When the base class returns a value (presumably because this node is
        # already in `cache`, so shared parameters are counted once), use it.
        if (out := super()._compute_in_size(cache)) is not None:
            return out
        in_size = math.prod(self._in_shape_1)
        for p in self.params:
            if isinstance(p, __class__):
                in_size += p._compute_in_size(cache)
        return in_size

    def _partial_invfcn_internal(self, x, i, cache):
        # Map the flat Normal array `x`, starting at offset `i`, to a sample of
        # this distribution. Random parameters are evaluated first, in order
        # and depth-first, each consuming its own slice of `x`; then this node
        # consumes the next prod(_in_shape_1) entries. Return the sample and
        # the offset following the consumed entries. The sample is stored in
        # `cache` under this node; the base class may short-circuit with a
        # previously computed result (presumably for shared nodes).
        if (out := super()._partial_invfcn_internal(x, i, cache)) is not None:
            return out

        concrete_params = []
        for p in self.params:

            if isinstance(p, __class__):
                p, i = p._partial_invfcn_internal(x, i, cache)
            else:
                p = jnp.asarray(p)

            concrete_params.append(p)

        in_size = math.prod(self._in_shape_1)
        assert i + in_size <= x.size
        last = x[i:i + in_size].reshape(self._in_shape_1)

        y = self.invfcn(last, *concrete_params)
        # guard against an invfcn whose output disagrees with the shape and
        # dtype computed at construction
        if y.shape != self.shape or y.dtype != self.dtype:
            raise ValueError(f'{self.__class__.__name__}.invfcn returned '
                f'array with shape {y.shape} and dtype {y.dtype}, while '
                f'{self.shape} and {self.dtype} were expected')

        cache[self] = y
        return y, i + in_size

    @functools.cached_property
    def _partial_invfcn(self):
        # Build, once per instance, the function mapping Normal variables of
        # shape `in_shape` to samples of shape `shape`. It broadcasts over
        # additional leading axes (jnp.vectorize) and accepts gvars
        # (gvar_gufunc).

        # determine signature
        shapestr = lambda shape: ','.join(map(str, shape))
        signature = f'({shapestr(self.in_shape)})->({shapestr(self.shape)})'

        # wrap to support gvars
        @functools.partial(_gvarext.gvar_gufunc, signature=signature)
        # @jax.jit
        @functools.partial(jnp.vectorize, signature=signature)
        def _partial_invfcn(x):
            assert x.shape == self.in_shape
            # the traversal slices a 1-d array, so promote a scalar to (1,)
            if not self.in_shape:
                x = x[None]
            cache = {}
            y, i = self._partial_invfcn_internal(x, 0, cache)
            # all input entries must have been consumed, and the cache must
            # hold this node plus the `_ancestor_count` counted at construction
            assert i == x.size
            assert len(cache) == 1 + self._ancestor_count
            return y

        return _partial_invfcn

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)

        # check and/or set signature attribute (the gufunc signature of invfcn)
        if not hasattr(cls, 'signature'):
            # infer an all-scalar signature with one '()' per parameter of
            # invfcn; only possible with plain positional parameters
            sig = inspect.signature(cls.invfcn)
            if not all(
                p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
                for p in sig.parameters.values()
            ):
                raise ValueError('can not automatically infer signature of '
                    f'{cls.__qualname__}.invfcn')
            cls.signature = ','.join(['()'] * len(sig.parameters)) + '->()'
        if not isinstance(cls.signature, _signature.Signature):
            cls.signature = _signature.Signature(cls.signature)
        cls.signature.check_nargs(cls.invfcn)

        # set dtype to float if not specified
        if getattr(cls, 'dtype', NotImplemented) is NotImplemented:
            cls.dtype = jax.dtypes.canonicalize_dtype(jnp.float64)

        # set __signature__ to take positional parameters from invfcn, so that
        # introspection of the class shows the distribution parameters
        sig = inspect.signature(cls.invfcn)
        pos_params = list(sig.parameters.values())[1:]  # drop `x`
        sig = inspect.signature(cls.__new__)
        # parameters of __new__ after the first (`cls`) that can be passed by
        # keyword, e.g. `shape` and `name`
        key_params = [
            p for i, p in enumerate(sig.parameters.values())
            if p.kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
            and i > 0
        ]
        cls.__signature__ = inspect.Signature(pos_params + key_params)

    def __new__(cls, *params, shape=(), name=None):

        self = super().__new__(cls)
        self.params = params
        self._eval_shapes(shape)

        # with a name, register the distribution and return the gvars in
        # place of the object (see the class docstring)
        if name is None:
            return self
        else:
            self.add_distribution(name)
            return self.gvars()

    class _Descr(collections.namedtuple('Distr', 'family shape params')):
        """ static representation of a Distr object """

        def __repr__(self):
            args = list(map(repr, self.params))
            if len(self.shape) == 1:
                args += [f'shape={self.shape[0]}']
            elif self.shape:
                args += [f'shape={self.shape}']
            arglist = ', '.join(args)
            return f'{self.family.__name__}({arglist})'

    def _compute_staticdescr(self, path, cache):
        # Build a `_Descr` (family, shape, params) for this distribution.
        # Random parameters are described recursively with their index
        # appended to `path`; constant ones are converted to Python
        # scalars/lists. The base class may return an existing description
        # first (presumably for nodes already seen).
        if (obj := super()._compute_staticdescr(path, cache)) is not None:
            return obj

        params = []
        for i, p in enumerate(self.params):
            if isinstance(p, __class__):
                p = p._compute_staticdescr(path + [i], cache)
            else:
                p = numpy.asarray(p).tolist()
            params.append(p)

        return self._Descr(self.__class__, self.shape, tuple(params))

    def _shapestr(self, shape):
        # format a shape tuple as e.g. '[2,3]' or '[2]'; '' for scalars
        if shape:
            return (str(shape)
                .replace(',)', ')')
                .replace('(' , '[')
                .replace(')' , ']')
                .replace(' ', '')
            )
        else:
            return ''

    def __repr__(self, path='', cache=None):

        # The base class either returns a finished string (presumably for an
        # object already printed, cf. `<X.0.0>` in the class docstring) or
        # the cache to thread through the recursion.
        if isinstance(cache := super().__repr__(path, cache), str):
            return cache

        args = []
        for i, p in enumerate(self.params):

            if isinstance(p, __class__):
                # nested distributions get a dotted path such as '0.1'
                p = p.__repr__('.'.join((path, str(i))).lstrip('.'), cache)
            elif hasattr(p, 'shape'):
                p = f'Array{self._shapestr(p.shape)}'
            else:
                p = repr(p)
            args.append(p)

        if len(self.shape) == 1:
            args += [f'shape={self.shape[0]}']
        elif self.shape:
            args += [f'shape={self.shape}']

        return f'{self.__class__.__name__}({", ".join(args)})'

    def __array_ufunc__(self, ufunc, method, *inputs, **kw):
        # Only plain calls of single-output, non-generalized ufuncs without
        # keyword arguments are supported. Multi-output ufuncs (e.g.
        # numpy.divmod, numpy.modf) must be declined: the generated
        # distribution declares one output, so its invfcn would return a tuple
        # and fail only later, at sampling time.
        if method != '__call__' or kw or ufunc.signature or ufunc.nout != 1:
            # TODO jax 0.4.15 should introduce ufunc methods
            return NotImplemented
        ufunc_class = UFunc.make_subclass(ufunc)
        return ufunc_class(*inputs)

    # TODO make this work with gufuncs. See comment in _signature.py.
    # matmul in particular.

    # binary arithmetic operations, dispatched through __array_ufunc__
    __add__, __radd__ = _numeric_methods(numpy.add, 'add')
    __sub__, __rsub__ = _numeric_methods(numpy.subtract, 'sub')
    __mul__, __rmul__ = _numeric_methods(numpy.multiply, 'mul')
    # __matmul__, __rmatmul__ = _numeric_methods(numpy.matmul, 'matmul')
    __truediv__, __rtruediv__ = _numeric_methods(numpy.divide, 'truediv')
    __mod__, __rmod__ = _numeric_methods(numpy.remainder, 'mod')
    # numpy.divmod has two outputs, which __array_ufunc__ declines, so divmod()
    # on a Distr raises TypeError
    __divmod__, __rdivmod__ = _numeric_methods(numpy.divmod, 'divmod')
    __pow__, __rpow__ = _numeric_methods(numpy.power, 'pow')

    # continuous unary operations
    __neg__ = _unary_method(numpy.negative, 'neg')
    __pos__ = _unary_method(numpy.positive, 'pos')
    __abs__ = _unary_method(numpy.absolute, 'abs')

    # TODO add __getitem__ and __array_function__
class UFunc:
    """ mixin for the Distr subclasses that wrap a numpy ufunc applied to
    distributions and/or constants """

    def __new__(cls, *args):
        # Positional arguments only: without **kw in this signature, keyword
        # arguments such as `shape` or `name` are rejected with TypeError.
        return super().__new__(cls, *args)

    @classmethod
    def invfcn(cls, x, *args):
        # `x` has core shape (0,) and carries no randomness; the result is
        # the wrapped ufunc applied to the (concrete) operands
        return cls._ufunc(*args)

    def _get_x_core_shape(self, *_):
        # the input Normal variable is empty: a ufunc node needs none
        return (0,)

    @classmethod
    @functools.lru_cache(maxsize=None) # functools.cache not available in 3.8
    def make_subclass(cls, ufunc):
        # one Distr subclass per ufunc, named after it, calling the jax.numpy
        # function of the same name on an empty x plus `nin` scalar operands
        def populate(namespace):
            namespace['_ufunc'] = getattr(jnp, ufunc.__name__)
            namespace['signature'] = '(0)' + ',()' * ufunc.nin + '->()'
        return types.new_class(ufunc.__name__, (__class__, Distr), exec_body=populate)
def distribution(invfcn, signature=None, dtype=None):
    r"""

    Decorator to define a distribution from a transformation function.

    Parameters
    ----------
    invfcn : function
        The transformation function from a (multivariate) standard Normal
        variable to the target random variable. The signature must be
        ``invfcn(x, *params)``. It must be jax-traceable. It does not need to
        be vectorized.
    signature : str, optional
        The signature of `invfcn`, as a numpy signature string. If not
        specified, `invfcn` is assumed to take and output scalars.
    dtype : dtype, optional
        The dtype of the output of `invfcn`. If not specified, it is assumed to
        be floating point.

    Returns
    -------
    cls : Distr
        The new distribution class. It is named after `invfcn`, and takes its
        docstring and module from `invfcn`.

    Examples
    --------

    >>> @lgp.copula.distribution
    ... def uniform(x, a, b):
    ...     return a + (b - a) * jax.scipy.stats.norm.cdf(x)

    >>> @functools.partial(lgp.copula.distribution, signature='(n,m)->(n,n)')
    ... def wishart(x):
    ...     " this parametrization is terrible, do not use "
    ...     return x @ x.T

    """

    def exec_body(ns):
        # Without an explicit __module__, type() takes it from the calling
        # frame, which is inside `types.new_class`, so the class would claim
        # to live in the `types` module. Attribute it to the module that
        # defines `invfcn` instead.
        module = getattr(invfcn, '__module__', None)
        if module is not None:
            ns['__module__'] = module
        # expose the documentation of the transformation on the class
        ns['__doc__'] = invfcn.__doc__
        if signature is not None:
            ns['signature'] = signature
        if dtype is not None:
            ns['dtype'] = dtype
        # vectorize so that `invfcn` need not handle broadcasting itself
        ns['invfcn'] = staticmethod(jnp.vectorize(invfcn, signature=signature))

    return types.new_class(invfcn.__name__, (Distr,), exec_body=exec_body)