Coverage for src/lsqfitgp/copula/_base.py: 96%
59 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_base.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" define DistrBase """
22import abc 1feabcd
23import functools 1feabcd
24import collections 1feabcd
26import gvar 1feabcd
27import numpy 1feabcd
class DistrBase(metaclass=abc.ABCMeta):
    r"""

    Abstract base class to represent (trees of) probability distributions.

    Attributes
    ----------
    in_shape : tuple of int
        The core shape of the input array to `partial_invfcn`.
    shape : (tree of) tuple of int
        The core shape of the output array of `partial_invfcn`.
    dtype : (tree of) dtype
        The dtype of the output array of `partial_invfcn`.
    distrshape : (tree of) tuple of int
        The sub-core shape of the output array of `partial_invfcn` that
        represents the atomic shape of the distribution.

    Methods
    -------
    partial_invfcn :
        Transformation function from a (multivariate) Normal variable to the
        target random variable.
    add_distribution :
        Register the distribution for usage with `gvar.BufferDict`.
    gvars :
        Return an array of gvars with the appropriate shape for usage with
        `gvar.BufferDict`.

    See also
    --------
    Distr, Copula

    """

    def __init_subclass__(cls, **kw):
        """ give each subclass its own registry, mapping the names passed to
        `add_distribution` to the static description of the instance that
        registered them """
        super().__init_subclass__(**kw)
        cls._named = {}

    # Placeholders, to be overwritten with actual values by concrete
    # distributions.
    in_shape = NotImplemented
    shape = NotImplemented
    dtype = NotImplemented
    distrshape = NotImplemented

    def partial_invfcn(self, x):
        """

        Map independent Normal variables to the desired distribution.

        This function is a generalized ufunc. It is jax traceable and
        differentiable one time. It supports arrays of gvars as input. The
        attributes `in_shape` and `shape` give the core shapes.

        Parameters
        ----------
        x : ``(..., *in_shape)`` array
            An array of values representing draws of i.i.d. Normal variates.

        Returns
        -------
        y : (tree of) ``(..., *shape)`` array
            An array of values representing draws of the desired distribution.

        """
        return self._partial_invfcn(x)

    @abc.abstractmethod
    def _partial_invfcn(self, x):
        """ implementation of `partial_invfcn`, to be provided by
        subclasses """
        pass

    def _is_same_family(self, invfcn):
        """ return True if `invfcn` is a bound method of an object whose class
        is exactly the class of self """
        # A callable without `__self__` falls back to None, whose class never
        # matches, so the result is False.
        return getattr(invfcn, '__self__', None).__class__ is self.__class__

    def add_distribution(self, name):
        """

        Register the distribution for usage with `gvar.BufferDict`.

        Parameters
        ----------
        name : str
            The name to use for the distribution. It must be globally unique,
            and it should not contain parentheses. To redefine a distribution
            with the same name, use `gvar.BufferDict.del_distribution` first.
            However, it is allowed to reuse the name if the distribution family,
            shape and parameters are identical to those used for the existing
            definition.

        Raises
        ------
        ValueError
            If `name` is already registered and the existing definition does
            not belong to the same class, is not known to this class, or has
            a different static description.

        See also
        --------
        gvar.BufferDict.add_distribution, gvar.BufferDict.del_distribution

        """

        if gvar.BufferDict.has_distribution(name):
            invfcn = gvar.BufferDict.invfcn[name]
            if not self._is_same_family(invfcn):
                raise ValueError(f'distribution {name} already defined')
            try:
                existing = self._named[name]
            except KeyError:
                # The name was registered with a method of this class without
                # passing through `add_distribution`, so there is no
                # description to compare against: treat it as a conflict
                # instead of leaking a bare KeyError.
                raise ValueError(
                    f'distribution {name} already defined') from None
            if existing != self._staticdescr:
                raise ValueError('Attempt to overwrite existing'
                    f' {self.__class__.__name__} distribution with name {name}')
                # cls._named is not updated by
                # gvar.BufferDict.del_distribution, but it is not a problem

        else:
            gvar.BufferDict.add_distribution(name, self.partial_invfcn)
            self._named[name] = self._staticdescr

    def gvars(self):
        """

        Return an array of gvars intended as value in a `gvar.BufferDict`.

        Returns
        -------
        gvars : array of gvars
            An array of i.i.d. standard Normal primary gvars with shape
            `in_shape`.

        """

        return gvar.gvar(numpy.zeros(self.in_shape), numpy.ones(self.in_shape))

    @abc.abstractmethod
    def __repr__(self, path='', cache=None):
        """ produce a representation where no object appears more than once,
        later appearances are replaced by a user-friendly identifier

        This base implementation is only a helper: if self is already in
        `cache` it returns the stored identifier string, otherwise it records
        ``'<path>'`` for self and returns the cache dictionary (a new one if
        `cache` is None). """
        if cache is None:
            cache = {}
        if self in cache:
            return cache[self]
        cache[self] = f'<{path}>'
        return cache

    class _Path(collections.namedtuple('Path', ['path'])): pass

    @abc.abstractmethod
    def _compute_staticdescr(self, path, cache):
        """ compute static description of self, can be compared

        This base implementation returns the `_Path` stored in `cache` if
        self was already visited, otherwise it stores ``_Path(path)`` for self
        and returns None. """
        if self in cache:
            return cache[self]
        cache[self] = self._Path(path)

    @functools.cached_property
    def _staticdescr(self):
        """ static description of self, computed once starting from an empty
        path and an empty cache """
        return self._compute_staticdescr([], {})

    @abc.abstractmethod
    def _compute_in_size(self, cache):
        """ compute input size to partial_invfcn, without double counting

        This base implementation returns 0 if self is already in `cache`,
        otherwise it adds self to `cache` and returns None. """
        if self in cache:
            return 0
        cache.add(self)

    @abc.abstractmethod
    def _partial_invfcn_internal(self, x, i, cache):
        """ helper for the implementation of the transformation

        This base implementation checks that `x` is 1-dimensional and, if
        self is already in `cache`, returns the cached value together with
        the unchanged `i`; otherwise it returns None. """
        assert x.ndim == 1
        if self in cache:
            return cache[self], i