Coverage for src/lsqfitgp/copula/_copula.py: 90%
134 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_copula.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" defines Copula """
22import functools 1efabcd
23import pprint 1efabcd
25import jax 1efabcd
26from jax import tree_util 1efabcd
27from jax import numpy as jnp 1efabcd
28import gvar 1efabcd
30from .. import _array 1efabcd
31from .. import _gvarext 1efabcd
32from . import _base 1efabcd
class Copula(_base.DistrBase):
    """

    Represents a tree of probability distributions.

    By "tree" it is intended an arbitrarily nested structure of containers
    (e.g., `dict` or `list`) where each leaf node is either a `Distr` object or
    another `Copula`.

    The function of a `Copula` is to take into account the relationships
    between all the `Distr` objects when defining the function `partial_invfcn`
    that maps Normal variates to the desired random variable. The same `Distr`
    object appearing in multiple places will not be sampled more than once.

    This class inherits the functionality defined in `DistrBase` without
    additions. The attributes `shape`, `dtype`, `distrshape`, which represent
    properties of the output of `partial_invfcn`, are trees matching the
    structure of the tree of variables. The same holds for the output of
    `partial_invfcn`. `in_shape` instead is an ordinary array shape, indicating
    that the input Normal variables are a raveled single array.

    Parameters
    ----------
    variables : tree of Distr or Copula
        The tree of distributions to wrap. The containers are copied, so
        successive modifications will not be reflected on the `Copula`. The
        insertion order of plain `dict` containers, at any nesting level, is
        preserved.

    See also
    --------
    DistrBase, Distr, jax.tree_util

    Examples
    --------

    Define a model into a dictionary one piece at a time, then wrap it as a
    `Copula`:

    >>> m = {}
    >>> m['a'] = lgp.copula.halfnorm(1.5)
    >>> m['b'] = lgp.copula.halfcauchy(2)
    >>> m['c'] = [
    ...     lgp.copula.uniform(m['a'], m['b']),
    ...     lgp.copula.uniform(m['b'], m['a']),
    ... ]
    >>> cop = lgp.copula.Copula(m)
    >>> cop
    Copula({'a': halfnorm(1.5),
     'b': halfcauchy(2),
     'c': [uniform(<a>, <b>), uniform(<b>, <a>)]})

    Notice how, when showing the object on the REPL, multiple appearances of
    the same variables are replaced by identifiers derived from the dictionary
    keys.

    The model may then be extended to create a variant. This does not affect
    `cop`:

    >>> m['d'] = lgp.copula.invgamma(m['c'][0], m['c'][1])
    >>> cop2 = lgp.copula.Copula(m)
    >>> cop2
    Copula({'a': halfnorm(1.5),
     'b': halfcauchy(2),
     'c': [uniform(<a>, <b>), uniform(<b>, <a>)],
     'd': invgamma(<c.0>, <c.1>)})
    >>> cop
    Copula({'a': halfnorm(1.5),
     'b': halfcauchy(2),
     'c': [uniform(<a>, <b>), uniform(<b>, <a>)]})

    """

    @staticmethod
    def _tree_path_str(path):
        """ format a jax pytree key path as a compact, readable string

        Each key entry is reduced to its dict key, sequence index or
        attribute name; literal dots inside a key are escaped with a
        backslash, and the entries are joined with dots. """
        def parsekey(key):
            if hasattr(key, 'key'):     # DictKey, FlattenedIndexKey
                return key.key
            elif hasattr(key, 'idx'):   # SequenceKey
                return key.idx
            elif hasattr(key, 'name'):  # GetAttrKey (e.g., namedtuple fields)
                return key.name
            else:
                return key
        def keystr(key):
            key = parsekey(key)
            # escape dots so they are not confused with the path separator
            return str(key).replace('.', r'\.')
        return '.'.join(map(keystr, path))

    @classmethod
    def _jaxext_dict_sorting(cls, pytree):
        """ replace dicts in pytree with a custom dict subclass such that their
        insertion order is maintained, see
        https://github.com/google/jax/issues/4085 """

        def is_dict(obj):
            # exact class check: dict subclasses, including `_Dict` itself,
            # are left untouched
            return obj.__class__ is dict

        def patch_dict(obj):
            if is_dict(obj):
                # `is_leaf=is_dict` is needed here too: otherwise dicts nested
                # inside this one would be flattened by jax in sorted key
                # order and rebuilt as plain dicts, losing insertion order
                return tree_util.tree_map(
                    patch_dict, cls._Dict(obj), is_leaf=is_dict)
            else:
                return obj

        return tree_util.tree_map(patch_dict, pytree, is_leaf=is_dict)

    @tree_util.register_pytree_with_keys_class
    class _Dict(dict):
        """ dict registered as a jax pytree that flattens its items in
        insertion order rather than sorted key order """

        def tree_flatten_with_keys(self):
            # jax stores the auxiliary data in the treedef and requires it to
            # be hashable; a tuple of keys also makes treedef equality
            # sensitive to key order, matching the order of the children
            treedef = tuple(self)
            keys_values = [(tree_util.DictKey(k), v) for k, v in self.items()]
            return keys_values, treedef

        @classmethod
        def tree_unflatten(cls, treedef, values):
            return cls(zip(treedef, values))

    def __init__(self, variables):
        variables = self._jaxext_dict_sorting(variables)

        def check_type(path, obj):
            # every leaf of the tree must be a Distr or a Copula
            if not isinstance(obj, _base.DistrBase):
                raise TypeError(f'only Distr or Copula objects can be '
                    f'contained in a Copula, found {obj!r} at '
                    f'<{self._tree_path_str(path)}>')
            return obj

        # tree_map rebuilds every container, so later changes to the caller's
        # containers are not reflected here
        self._variables = tree_util.tree_map_with_path(check_type, variables)

        # `cache` is filled by `_compute_in_size` (base implementation not
        # visible here; presumably it records each distinct node visited,
        # self included, so that shared nodes are counted once)
        cache = set()
        self.in_shape = (self._compute_in_size(cache),)
        self._ancestor_count = len(cache) - 1

        self.shape = self._map_getattr('shape')
        self.distrshape = self._map_getattr('distrshape')
        self.dtype = self._map_getattr('dtype')

    def _compute_in_size(self, cache):
        """ total number of input Normal variables needed by the tree """
        # the base class returns a non-None value when it short-circuits,
        # presumably because this node was already visited
        if (out := super()._compute_in_size(cache)) is not None:
            return out
        def accumulate(in_size, obj):
            return in_size + obj._compute_in_size(cache)
        return tree_util.tree_reduce(accumulate, self._variables, 0)

    def _map_getattr(self, attr):
        """ tree of the attribute `attr` of each leaf, recursing into
        nested Copulas """
        def get_attr(obj):
            if isinstance(obj, __class__):
                return obj._map_getattr(attr)
            else:
                return getattr(obj, attr)
        return tree_util.tree_map(get_attr, self._variables)

    def _partial_invfcn_internal(self, x, i, cache):
        """ map the input variables starting at position `i` of `x` through
        each leaf in flattening order; return the output tree and the
        position after the last consumed input """
        if (out := super()._partial_invfcn_internal(x, i, cache)) is not None:
            return out

        distributions, treedef = tree_util.tree_flatten(self._variables)
        outputs = []
        for distr in distributions:
            out, i = distr._partial_invfcn_internal(x, i, cache)
            outputs.append(out)
        out = tree_util.tree_unflatten(treedef, outputs)

        # record the result so repeated references to this node reuse it
        cache[self] = out
        return out, i

    @functools.cached_property
    def _partial_invfcn(self):
        """ build the function mapping an array of Normal variates, with
        arbitrary leading dimensions and last axis of length `in_shape[0]`,
        to the tree of outputs; object (gvar) arrays are also supported """

        # non vectorized version, check core shapes and call recursive impl
        # @jax.jit
        def partial_invfcn_0(x):
            assert x.shape == self.in_shape
            cache = {}
            y, i = self._partial_invfcn_internal(x, 0, cache)
            # all inputs consumed, and every node evaluated exactly once
            assert i == x.size
            assert len(cache) == 1 + self._ancestor_count
            return y
        partial_invfcn_0_deriv = jax.jacfwd(partial_invfcn_0)

        # add 1-axis vectorization
        partial_invfcn_1 = jax.vmap(partial_invfcn_0)
        partial_invfcn_1_deriv = jax.vmap(partial_invfcn_0_deriv)

        # add gvar support
        def partial_invfcn_2(x):

            if x.dtype == object:

                # unpack the gvars
                in_mean = gvar.mean(x)
                in_jac, indices = _gvarext.jacobian(x)

                # apply function
                out_mean = partial_invfcn_1(in_mean)
                jac = partial_invfcn_1_deriv(in_mean)

                # concatenate derivatives and repack as gvars
                def contract_and_pack(out_mean, jac):
                    # indices:
                    # b = broadcast
                    # i = input
                    # ... = output
                    # g = gvar indices
                    # (the einsum expects in_jac laid out as (b, i, g))
                    out_jac = jnp.einsum('b...i,big->b...g', jac, in_jac)
                    return _gvarext.from_jacobian(out_mean, out_jac, indices)

                return tree_util.tree_map(contract_and_pack, out_mean, jac)

            else:
                return partial_invfcn_1(x)

        # add full vectorization
        def partial_invfcn_3(x):
            x = _array.asarray(x)
            assert x.shape[-1:] == self.in_shape
            # flatten all leading dimensions into a single batch axis
            head = x.shape[:-1]
            x = x.reshape((-1,) + self.in_shape)
            y = partial_invfcn_2(x)
            def reshape_y(y, shape):
                assert y.shape[1:] == shape
                # restore the original leading dimensions
                y = y.reshape(head + shape)
                # unwrap 0-d object arrays into a bare element
                if y.dtype == object and not y.ndim:
                    y = y.item()
                return y
            return tree_util.tree_map(reshape_y, y, self.shape)

        return partial_invfcn_3

    def __repr__(self, path='', cache=None):
        """ pretty-printed tree of the wrapped distributions; `path` is the
        identifier of this node within an enclosing tree """

        # the base class returns a str to short-circuit (presumably a
        # reference to an already shown object), else the cache to use
        if isinstance(cache := super().__repr__(path, cache), str):
            return cache

        def subrepr(k, obj):
            if isinstance(obj, _base.DistrBase):
                k = self._tree_path_str(k)
                return obj.__repr__('.'.join((path, k)).lstrip('.'), cache)
            else:
                return repr(obj)

        # wrapper whose repr is the string itself, so pprint does not quote it
        class NoQuotesRepr:
            def __init__(self, s):
                self.s = s
            def __repr__(self):
                return self.s

        out = tree_util.tree_map_with_path(subrepr, self._variables)
        out = tree_util.tree_map(NoQuotesRepr, out)
        out = pprint.pformat(out, sort_dicts=False)
        return f'{self.__class__.__name__}({out})'

    def _compute_staticdescr(self, path, cache):
        """ tree of the static descriptions of the wrapped distributions """
        # `key` is the full key path of the child within this Copula;
        # `path` is presumably a list, since it is concatenated with one
        def compute(key, x):
            return x._compute_staticdescr(path + [key], cache)
        return tree_util.tree_map_with_path(compute, self._variables)