Coverage for src/lsqfitgp/copula/_copula.py: 90%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_copula.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" defines Copula """ 

21 

22import functools 1efabcd

23import pprint 1efabcd

24 

25import jax 1efabcd

26from jax import tree_util 1efabcd

27from jax import numpy as jnp 1efabcd

28import gvar 1efabcd

29 

30from .. import _array 1efabcd

31from .. import _gvarext 1efabcd

32from . import _base 1efabcd

33 

34class Copula(_base.DistrBase): 1efabcd

35 """ 

36  

37 Represents a tree of probability distributions. 

38 

39 By "tree" we mean an arbitrarily nested structure of containers 

40 (e.g., `dict` or `list`) where each leaf node is either a `Distr` object or 

41 another `Copula`. 

42 

43 The function of a `Copula` is to take into account the relationships 

44 between all the `Distr` objects when defining the function `partial_invfcn` 

45 that maps Normal variates to the desired random variable. The same `Distr` 

46 object appearing in multiple places will not be sampled more than once. 

47 

48 This class inherits the functionality defined in `DistrBase` without 

49 additions. The attributes `shape`, `dtype`, `distrshape`, which represent 

50 properties of the output of `partial_invfcn`, are trees matching the 

51 structure of the tree of variables. The same holds for the output of 

52 `partial_invfcn`. `in_shape` instead is an ordinary array shape, indicating 

53 that the input Normal variables are a raveled single array. 

54 

55 Parameters 

56 ---------- 

57 variables : tree of Distr or Copula 

58 The tree of distributions to wrap. The containers are copied, so 

59 subsequent modifications will not be reflected in the `Copula`. 

60 

61 See also 

62 -------- 

63 DistrBase, Distr, jax.tree_util 

64 

65 Examples 

66 -------- 

67 

68 Define a model in a dictionary one piece at a time, then wrap it as a 

69 `Copula`: 

70 

71 >>> m = {} 

72 >>> m['a'] = lgp.copula.halfnorm(1.5) 

73 >>> m['b'] = lgp.copula.halfcauchy(2) 

74 >>> m['c'] = [ 

75 ... lgp.copula.uniform(m['a'], m['b']), 

76 ... lgp.copula.uniform(m['b'], m['a']), 

77 ... ] 

78 >>> cop = lgp.copula.Copula(m) 

79 >>> cop 

80 Copula({'a': halfnorm(1.5), 

81 'b': halfcauchy(2), 

82 'c': [uniform(<a>, <b>), uniform(<b>, <a>)]}) 

83 

84 Notice how, when showing the object on the REPL, multiple appearances of 

85 the same variables are replaced by identifiers derived from the dictionary 

86 keys. 

87 

88 The model may then be extended to create a variant. This does not affect 

89 `cop`: 

90 

91 >>> m['d'] = lgp.copula.invgamma(m['c'][0], m['c'][1]) 

92 >>> cop2 = lgp.copula.Copula(m) 

93 >>> cop2 

94 Copula({'a': halfnorm(1.5), 

95 'b': halfcauchy(2), 

96 'c': [uniform(<a>, <b>), uniform(<b>, <a>)], 

97 'd': invgamma(<c.0>, <c.1>)}) 

98 >>> cop 

99 Copula({'a': halfnorm(1.5), 

100 'b': halfcauchy(2), 

101 'c': [uniform(<a>, <b>), uniform(<b>, <a>)]}) 

102 

103 """ 

104 

105 @staticmethod 1efabcd

106 def _tree_path_str(path): 1efabcd

107 """ format a jax pytree key path as a compact, readable string """ 

108 def parsekey(key): 1abcd

109 if hasattr(key, 'key'): 109 ↛ 111line 109 didn't jump to line 111 because the condition on line 109 was always true1abcd

110 return key.key 1abcd

111 elif hasattr(key, 'idx'): 

112 return key.idx 

113 else: 

114 return key 

115 def keystr(key): 1abcd

116 key = parsekey(key) 1abcd

117 return str(key).replace('.', r'\.') 1abcd

118 return '.'.join(map(keystr, path)) 1abcd

119 

120 @classmethod 1efabcd

121 def _jaxext_dict_sorting(cls, pytree): 1efabcd

122 """ replace dicts in pytree with a custom dict subclass so that their 

123 insertion order is maintained, see 

124 https://github.com/google/jax/issues/4085 """ 

125 

126 def is_dict(obj): 1abcd

127 return obj.__class__ is dict 1abcd

128 

129 def patch_dict(obj): 1abcd

130 if is_dict(obj): 1abcd

131 return tree_util.tree_map(patch_dict, cls._Dict(obj)) 1abcd

132 else: 

133 return obj 1abcd

134 

135 return tree_util.tree_map(patch_dict, pytree, is_leaf=is_dict) 1abcd

136 

137 @tree_util.register_pytree_with_keys_class 1efabcd

138 class _Dict(dict): 1efabcd

139 

140 def tree_flatten_with_keys(self): 1efabcd

141 treedef = dict.fromkeys(self) 1abcd

142 keys_values = [(tree_util.DictKey(k), v) for k, v in self.items()] 1abcd

143 return keys_values, treedef 1abcd

144 

145 @classmethod 1efabcd

146 def tree_unflatten(cls, treedef, values): 1efabcd

147 return cls(zip(treedef, values)) 1abcd

148 

149 def __init__(self, variables): 1efabcd

150 variables = self._jaxext_dict_sorting(variables) 1abcd

151 def check_type(path, obj): 1abcd

152 if not isinstance(obj, _base.DistrBase): 152 ↛ 153line 152 didn't jump to line 153 because the condition on line 152 was never true1abcd

153 raise TypeError(f'only Distr or Copula objects can be ' 

154 f'contained in a Copula, found {obj!r} at ' 

155 f'<{self._tree_path_str(path)}>') 

156 return obj 1abcd

157 self._variables = tree_util.tree_map_with_path(check_type, variables) 1abcd

158 cache = set() 1abcd

159 self.in_shape = self._compute_in_size(cache), 1abcd

160 self._ancestor_count = len(cache) - 1 1abcd

161 self.shape = self._map_getattr('shape') 1abcd

162 self.distrshape = self._map_getattr('distrshape') 1abcd

163 self.dtype = self._map_getattr('dtype') 1abcd

164 

165 def _compute_in_size(self, cache): 1efabcd

166 if (out := super()._compute_in_size(cache)) is not None: 166 ↛ 167line 166 didn't jump to line 167 because the condition on line 166 was never true1abcd

167 return out 

168 def accumulate(in_size, obj): 1abcd

169 return in_size + obj._compute_in_size(cache) 1abcd

170 return tree_util.tree_reduce(accumulate, self._variables, 0) 1abcd

171 

172 def _map_getattr(self, attr): 1efabcd

173 def get_attr(obj): 1abcd

174 if isinstance(obj, __class__): 1abcd

175 return obj._map_getattr(attr) 1abcd

176 else: 

177 return getattr(obj, attr) 1abcd

178 return tree_util.tree_map(get_attr, self._variables) 1abcd

179 

180 def _partial_invfcn_internal(self, x, i, cache): 1efabcd

181 if (out := super()._partial_invfcn_internal(x, i, cache)) is not None: 181 ↛ 182line 181 didn't jump to line 182 because the condition on line 181 was never true1abcd

182 return out 

183 

184 distributions, treedef = tree_util.tree_flatten(self._variables) 1abcd

185 outputs = [] 1abcd

186 for distr in distributions: 1abcd

187 out, i = distr._partial_invfcn_internal(x, i, cache) 1abcd

188 outputs.append(out) 1abcd

189 out = tree_util.tree_unflatten(treedef, outputs) 1abcd

190 

191 cache[self] = out 1abcd

192 return out, i 1abcd

193 

194 @functools.cached_property 1efabcd

195 def _partial_invfcn(self): 1efabcd

196 

197 # non vectorized version, check core shapes and call recursive impl 

198 # @jax.jit 

199 def partial_invfcn_0(x): 1abcd

200 assert x.shape == self.in_shape 1abcd

201 cache = {} 1abcd

202 y, i = self._partial_invfcn_internal(x, 0, cache) 1abcd

203 assert i == x.size 1abcd

204 assert len(cache) == 1 + self._ancestor_count 1abcd

205 return y 1abcd

206 partial_invfcn_0_deriv = jax.jacfwd(partial_invfcn_0) 1abcd

207 

208 # add 1-axis vectorization 

209 partial_invfcn_1 = jax.vmap(partial_invfcn_0) 1abcd

210 partial_invfcn_1_deriv = jax.vmap(partial_invfcn_0_deriv) 1abcd

211 

212 # add gvar support 

213 def partial_invfcn_2(x): 1abcd

214 

215 if x.dtype == object: 1abcd

216 

217 # unpack the gvars 

218 in_mean = gvar.mean(x) 1abcd

219 in_jac, indices = _gvarext.jacobian(x) 1abcd

220 

221 # apply function 

222 out_mean = partial_invfcn_1(in_mean) 1abcd

223 jac = partial_invfcn_1_deriv(in_mean) 1abcd

224 

225 # concatenate derivatives and repack as gvars 

226 def contract_and_pack(out_mean, jac): 1abcd

227 # indices: 

228 # b = broadcast 

229 # i = input 

230 # ... = output 

231 # g = gvar indices 

232 out_jac = jnp.einsum('b...i,big->b...g', jac, in_jac) 1abcd

233 return _gvarext.from_jacobian(out_mean, out_jac, indices) 1abcd

234 

235 return tree_util.tree_map(contract_and_pack, out_mean, jac) 1abcd

236 

237 else: 

238 return partial_invfcn_1(x) 1abcd

239 

240 # add full vectorization 

241 def partial_invfcn_3(x): 1abcd

242 x = _array.asarray(x) 1abcd

243 assert x.shape[-1:] == self.in_shape 1abcd

244 head = x.shape[:-1] 1abcd

245 x = x.reshape((-1,) + self.in_shape) 1abcd

246 y = partial_invfcn_2(x) 1abcd

247 def reshape_y(y, shape): 1abcd

248 assert y.shape[1:] == shape 1abcd

249 y = y.reshape(head + shape) 1abcd

250 if y.dtype == object and not y.ndim: 1abcd

251 y = y.item() 1abcd

252 return y 1abcd

253 return tree_util.tree_map(reshape_y, y, self.shape) 1abcd

254 

255 return partial_invfcn_3 1abcd

256 

257 def __repr__(self, path='', cache=None): 1efabcd

258 

259 if isinstance(cache := super().__repr__(path, cache), str): 259 ↛ 260line 259 didn't jump to line 260 because the condition on line 259 was never true1abcd

260 return cache 

261 

262 def subrepr(k, obj): 1abcd

263 if isinstance(obj, _base.DistrBase): 263 ↛ 267line 263 didn't jump to line 267 because the condition on line 263 was always true1abcd

264 k = self._tree_path_str(k) 1abcd

265 return obj.__repr__('.'.join((path, k)).lstrip('.'), cache) 1abcd

266 else: 

267 return repr(obj) 

268 

269 class NoQuotesRepr: 1abcd

270 def __init__(self, s): 1abcd

271 self.s = s 1abcd

272 def __repr__(self): 1abcd

273 return self.s 1abcd

274 

275 out = tree_util.tree_map_with_path(subrepr, self._variables) 1abcd

276 out = tree_util.tree_map(NoQuotesRepr, out) 1abcd

277 out = pprint.pformat(out, sort_dicts=False) 1abcd

278 return f'{self.__class__.__name__}({out})' 1abcd

279 

280 def _compute_staticdescr(self, path, cache): 1efabcd

281 def compute(key, x): 1abcd

282 return x._compute_staticdescr(path + [key], cache) 1abcd

283 return tree_util.tree_map_with_path(compute, self._variables) 1abcd