Coverage for src/lsqfitgp/copula/_makedict.py: 100%
31 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/copula/_makedict.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" defines makedict """
22import gvar 1efabcd
24from . import _base 1efabcd
def makedict(variables, prefix='__copula_'):
    """
    Expand distributions in a dictionary.

    Parameters
    ----------
    variables : dict
        A dictionary representing a collection of probability distributions.
        If a value is an instance of `DistrBase`, the key is converted to mark
        a transformation and the value is replaced with new primary gvars.
    prefix : str
        A prefix to make the transformation names unique. It must not contain
        parentheses if `variables` contains any `DistrBase` value.

    Returns
    -------
    out : BufferDict
        The transformed dictionary. Recognizes the same keys as `variables`,
        but squashes the values through the transformation that sends a Normal
        to the desired distribution.

    Raises
    ------
    ValueError :
        If any `DistrBase` object appears under different keys, if `prefix`
        contains parentheses while a `DistrBase` value has to be expanded, or
        if two entries end up with the same key in the output.

    Examples
    --------

    Put a `Distr` into a `gvar.BufferDict`:

    >>> bd = lgp.copula.makedict({'x': lgp.copula.beta(1, 1)})
    >>> bd
    BufferDict({'__copula_beta{1, 1}(x)': 0.0(1.0)})
    >>> bd['x']
    0.50(40)
    >>> bd['__copula_beta{1, 1}(x)']
    0.0(1.0)

    You can also put an entire `Copula`:

    >>> bd = lgp.copula.makedict({
    ...     'x': lgp.copula.Copula({
    ...         'y': lgp.copula.gamma(2, 1 / 2),
    ...         'z': lgp.copula.invgamma(2, 2),
    ...     }),
    ... })
    >>> bd
    BufferDict({"__copula_{'y': gamma{2, 0.5}, 'z': invgamma{2, 2}}(x)": array([0.0(1.0), 0.0(1.0)], dtype=object)})
    >>> bd['x']
    {'y': 3.4(2.5), 'z': 1.19(90)}

    Other entries and transformations are left alone:

    >>> bd = lgp.copula.makedict({
    ...     'x': gvar.gvar(3, 0.2),
    ...     'log(y)': gvar.gvar(0, 1),
    ...     'z': lgp.copula.dirichlet(1.5, [1, 2, 3]),
    ... })
    >>> bd
    BufferDict({'x': 3.00(20), 'log(y)': 0.0(1.0), '__copula_dirichlet{1.5, [1, 2, 3], shape=3}(z)': array([0.0(1.0), 0.0(1.0), 0.0(1.0)], dtype=object)})
    >>> bd['z']
    array([0.06(20), 0.31(49), 0.63(51)], dtype=object)
    >>> bd['y']
    1.0(1.0)

    Since shared `DistrBase` objects represent statistical dependency, it is
    forbidden to have the same object appear under different keys, as that
    would make it impossible to take the dependencies into account:

    >>> x = lgp.copula.beta(1, 1)
    >>> y = lgp.copula.beta(1, x)
    >>> lgp.copula.makedict({'x': x, 'y': y})
    ValueError: cross-key occurrences of object(s):
    beta with id 10952248976: <x>, <y.1>

    """

    # For each DistrBase value, let its __repr__ fill a fresh per-key cache
    # (iterated below as object -> description), and merge all caches into a
    # single multi-valued dict: object -> descriptions, one per top-level key
    # under which the object was found.
    allobjects = {}
    for k, v in variables.items():
        if isinstance(v, _base.DistrBase):
            cache = {}
            v.__repr__(k, cache)
            for obj, descr in cache.items():
                allobjects.setdefault(obj, []).append(descr)

    # One report line per object that was seen under more than one key.
    multiple = ''.join(
        f'{obj.__class__.__name__} with id {id(obj)}: {", ".join(descrs)}\n'
        for obj, descrs in allobjects.items()
        if len(descrs) > 1
    )
    if multiple:
        raise ValueError(f'cross-key occurrences of object(s):\n{multiple}')

    out = {}
    for k, v in variables.items():
        if isinstance(v, _base.DistrBase):
            # Parentheses delimit the transformation name in the key, so they
            # are replaced in the description and rejected in the prefix.
            name = str(v._staticdescr).replace('(', '{').replace(')', '}')
            # gvar does not currently check presence of parentheses, see
            # https://github.com/gplepage/gvar/issues/39
            # Raise explicitly instead of asserting, such that the check is
            # not skipped when Python runs with -O.
            if '(' in prefix or ')' in prefix:
                raise ValueError(
                    f'prefix {prefix!r} must not contain parentheses'
                )
            name = prefix + name
            v.add_distribution(name)
            v = v.gvars()
            k = f'{name}({k})'
        # A generated key could coincide with a key given by the user; refuse
        # to silently overwrite an entry.
        if k in out:
            raise ValueError(
                f'key {k!r} appears more than once after expanding '
                'distributions'
            )
        out[k] = v
    return gvar.BufferDict(out)