Coverage for src/lsqfitgp/_GP/_compute.py: 100%
184 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_GP/_compute.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import warnings 1feabcd
21import math 1feabcd
23from jax import numpy as jnp 1feabcd
24import numpy 1feabcd
25import gvar 1feabcd
27from .. import _linalg 1feabcd
28from .. import _jaxext 1feabcd
30from . import _base 1feabcd
32class GPCompute(_base.GPBase): 1feabcd
def __init__(self, *, solver, solverkw):
    """
    Set up the decomposition cache and the decomposition factory.

    Parameters
    ----------
    solver : str
        Name of the decomposition algorithm, resolved with `_getdecomp`.
    solverkw : dict
        Keyword arguments forwarded to the decomposition every time it is
        instantiated, in addition to the ones given at call time.
    """
    # Maps a tuple of keys to the decomposition of their covariance matrix.
    self._decompcache = {}
    decomposition = self._getdecomp(solver)

    def make_decomposition(K, **kwargs):
        # A keyword present both in kwargs and in solverkw is an error,
        # like for any function call with a repeated keyword argument.
        return decomposition(K, **kwargs, **solverkw)

    self._decompclass = make_decomposition
def _clone(self):
    """
    Return a copy of the object made by the parent class `_clone`, with
    its own shallow copy of the decomposition cache and the same
    decomposition factory.
    """
    clone = super()._clone()
    clone._decompclass = self._decompclass
    # Shallow copy: the clone can add entries to its cache without touching
    # the cache of the original, while sharing the stored decompositions.
    clone._decompcache = self._decompcache.copy()
    return clone
def _solver(self, keys, ycov=None, *, covtransf=None, **kw):
    """
    Return a decomposition of the covariance matrix of the keys in ``keys``
    plus the matrix ``ycov``.

    Parameters
    ----------
    keys : iterable of keys
        Keys of the elements whose joint covariance matrix is decomposed.
        They are converted to a tuple; the order matters for the cache.
    ycov : array, optional
        Matrix added to the covariance matrix before decomposing it.
    covtransf : callable, optional
        If specified, it is applied to the matrix (after adding ``ycov``)
        and its output is decomposed instead.
    **kw :
        Additional keyword arguments are passed to the decomposition.

    Returns
    -------
    decomp :
        The object produced by ``self._decompclass``.

    Notes
    -----
    The cache is indexed by the keys alone, so it is used only for plain
    calls (no ``ycov``, no ``covtransf``, no ``**kw``). Otherwise a
    decomposition computed with some options could be returned to a call
    requesting different ones.
    """

    keys = tuple(keys)

    # A cached entry can stand only for "prior covariance of these keys,
    # default options", because nothing else is encoded in the cache key.
    cacheable = ycov is None and not covtransf and not kw

    if cacheable:
        cached = self._decompcache.get(keys)
        if cached is not None:
            return cached
        # TODO use frozenset(keys) instead of tuple(keys) to make cache
        # work when order changes, but I have to permute the decomposition
        # to make that work. Needs an ad-hoc class in _linalg. Make the
        # decompcache a dict subclass that accepts tuples of keys but uses
        # internally frozenset.

    # NOTE a Woodbury-based path for the case in which ycov is itself a
    # decomposition is currently un-implemented; ycov is always treated as
    # a matrix to be summed to the prior covariance.
    Kxx = self._assemblecovblocks(keys)
    if ycov is not None:
        Kxx = Kxx + ycov
    if covtransf:
        Kxx = covtransf(Kxx)
    decomp = self._decompclass(Kxx, **kw)

    if cacheable:
        self._decompcache[keys] = decomp

    return decomp
def _flatgiven(self, given, givencov):
    """
    Validate ``given`` and ``givencov`` and convert them to flat lists.

    Parameters
    ----------
    given : dict
        Maps keys of the process elements to array-like values. Each value
        must have the same shape as the corresponding element.
    givencov : None, dict or Decomposition
        If a dictionary, it is indexed with pairs of keys of ``given``.

    Returns
    -------
    ylist : list of 1D arrays
        The values of ``given``, each flattened.
    keylist : list
        The keys of ``given``, in the same order as ``ylist``.
    covblocks : None, Decomposition or list of lists of arrays
        ``givencov`` itself if it is None or a Decomposition, otherwise the
        blocks of ``givencov`` arranged in the order of ``keylist``, each
        reshaped to the concatenated shapes of the two flattened arrays.

    Raises
    ------
    TypeError
        If ``given`` or ``givencov`` have the wrong type, or an array in
        ``given`` has a non-numerical dtype.
    KeyError
        If a key of ``given`` is not an element of the process.
    ValueError
        If an array in ``given`` has the wrong shape.
    """

    if not hasattr(given, 'keys'):
        raise TypeError('`given` must be dict')
    gcblack = givencov is None or isinstance(givencov, _linalg.Decomposition)
    if not gcblack and not hasattr(givencov, 'keys'):
        raise TypeError('`givencov` must be None, dict or Decomposition')

    ylist = []
    keylist = []
    for key, values in given.items():
        if key not in self._elements:
            raise KeyError(key)

        if not isinstance(values, jnp.ndarray):
            # use numpy since there could be gvars
            values = numpy.asarray(values)
        shape = self._elements[key].shape
        if values.shape != shape:
            msg = 'given[{!r}] has shape {!r} different from shape {!r}'
            raise ValueError(msg.format(key, values.shape, shape))
        if values.dtype != object and not jnp.issubdtype(values.dtype, jnp.number):
            msg = 'given[{!r}] has non-numerical dtype {!r}'
            raise TypeError(msg.format(key, values.dtype))

        ylist.append(values.reshape(-1))
        keylist.append(key)

    # TODO error checking on the unpacking of givencov

    if gcblack:
        covblocks = givencov
    else:
        covblocks = [
            [
                jnp.asarray(givencov[rowkey, colkey]).reshape(yrow.shape + ycol.shape)
                for colkey, ycol in zip(keylist, ylist)
            ]
            for rowkey, yrow in zip(keylist, ylist)
        ]

    return ylist, keylist, covblocks
def pred(self, given, key=None, givencov=None, *, fromdata=None, raw=False, keepcorr=None):
    """

    Compute the posterior.

    The posterior can be computed for all the points or only for some of
    them, and starting either from data or from a posterior produced by a
    fit. The second option covers the case in which the Gaussian process
    was used in a fit together with other parameters.

    The output is a collection of gvars, either an array or a dictionary
    of arrays. They are properly correlated with the gvars returned by
    `prior` and with the input data/fit.

    The input is a dictionary of arrays, ``given``, whose keys are keys of
    the GP as added by `addx` or `addtransf`.

    Parameters
    ----------
    given : dictionary of arrays
        The data or fit result for some/all of the points in the GP.
        The arrays can contain either gvars or normal numbers, the latter
        being equivalent to zero-uncertainty gvars.
    key : None, key or list of keys, optional
        If None, compute the posterior for all points in the GP (also those
        used in ``given``). Otherwise only those specified by key.
    givencov : dictionary of arrays, optional
        Covariance matrix of ``given``. If not specified, the covariance
        is extracted from ``given`` with ``gvar.evalcov(given)``.
    fromdata : bool
        Mandatory. Specify if the contents of ``given`` are data or already
        a posterior.
    raw : bool, optional
        If True, instead of returning a collection of gvars, return
        the mean and the covariance. When the mean is a dictionary, the
        covariance is a dictionary whose keys are pairs of keys of the
        mean (the same format used by `gvar.evalcov`). Default False.
    keepcorr : bool, optional
        If True (default), the returned gvars are correlated with the
        prior and the data/fit. If False, they have the correct covariance
        between themselves, but are independent from all other preexisting
        gvars.

    Returns
    -------
    If raw=False (default):

    posterior : array or dictionary of arrays
        A collection of gvars representing the posterior.

    If raw=True:

    pmean : array or dictionary of arrays
        The mean of the posterior. Equivalent to ``gvar.mean(posterior)``.
    pcov : 2D array or dictionary of 2D arrays
        The covariance matrix of the posterior. If ``pmean`` is a
        dictionary, the keys of ``pcov`` are pairs of keys of ``pmean``.
        Equivalent to ``gvar.evalcov(posterior)``.

    Raises
    ------
    ValueError
        If ``fromdata`` is not specified, if both ``keepcorr`` and ``raw``
        are true, or if ``given`` contains gvars and ``givencov`` is
        specified too.

    """

    # TODO GP.pred(..., raw=True, onlyvariance=True) computes only the
    # variance (requires actually implementing diagquad at least in Chol and
    # Diag).

    # Normalize and validate the options.
    if fromdata is None:
        raise ValueError('you must specify if `given` is data or fit result')
    fromdata = bool(fromdata)
    raw = bool(raw)
    if keepcorr is None:
        keepcorr = not raw
    if keepcorr and raw:
        raise ValueError('both keepcorr=True and raw=True')
    # True when the result is computed from plain mean and covariance
    # instead of by propagation through the prior gvars.
    detached = raw or not keepcorr

    # Determine the output keys; a single key means a bare array output.
    if key is None:
        outkeys = list(self._elements)
        strip = False
    elif isinstance(key, list):
        outkeys = key
        strip = False
    else:
        outkeys = [key]
        strip = True
    outslices = self._slices(outkeys)

    ylist, inkeys, ycovblocks = self._flatgiven(given, givencov)
    y = self._concatenate(ylist)
    if y.dtype == object and ycovblocks is not None:
        raise ValueError('given may contain gvars but a separate covariance matrix has been provided')

    self._checkpos_keys(inkeys + outkeys)

    Kxxs = self._assemblecovblocks(inkeys, outkeys)

    # Covariance matrix of `given`. (A branch for ycovblocks being a
    # decomposition, for woodbury, is currently un-implemented.)
    if ycovblocks is not None:
        ycov = jnp.block(ycovblocks)
    elif (fromdata or detached) and y.dtype == object:
        ycov = gvar.evalcov(gvar.gvar(y))
        # TODO use evalcov_blocks
        # TODO I think this ignores the case in which we are using gvars
        # and they are correlated with the GP. I guess the correct thing
        # would be to sum the data gvars to the prior ones and use the
        # resulting covariance matrix, and write a note about possible
        # different results in this case when switching raw or keepcorr.
    else:
        ycov = None
    self._check_ycov(ycov)

    if detached or self._checkfinite:
        ymean = gvar.mean(y) if y.dtype == object else y
        self._check_ymean(ymean)

    if detached:

        Kxsxs = self._assemblecovblocks(outkeys)

        # With a fit result, the covariance of `given` does not enter the
        # matrix to be inverted.
        if fromdata:
            solver = self._solver(inkeys, ycov)
        else:
            solver = self._solver(inkeys)

        mean = solver.pinv_bilinear(Kxxs, ymean)
        cov = Kxsxs - solver.ginv_quad(Kxxs)

        if not fromdata and ycov is not None:
            # cov = Kxsxs - Kxsx Kxx^-1 (Kxx - ycov) Kxx^-1 Kxxs =
            #     = Kxsxs - Kxsx Kxx^-1 Kxxs + Kxsx Kxx^-1 ycov Kxx^-1 Kxxs
            A = solver.ginv_linear(Kxxs)
            # TODO do I need K⁺ here or is K⁻ fine?
            cov += A.T @ ycov @ A

    else: # (keepcorr and not raw)
        yp = self._concatenate([
            numpy.reshape(self._prior(k), -1) for k in inkeys
        ])
        ysp = self._concatenate([
            numpy.reshape(self._prior(k), -1) for k in outkeys
        ])

        if y.dtype != object and ycov is not None:
            y = gvar.gvar(y, ycov)
        else:
            y = numpy.asarray(y) # because y - yp fails if y is a jax array
        mat = ycov if fromdata else None
        flatout = ysp + self._solver(inkeys, mat).pinv_bilinear_robj(Kxxs, y - yp)

    if raw:
        if strip:
            (outkey,) = outkeys
            shape = self._elements[outkey].shape
            mean = mean.reshape(shape)
            cov = cov.reshape(shape + shape)
            return mean, cov

        meandict = {
            k: mean[slic].reshape(self._elements[k].shape)
            for k, slic in zip(outkeys, outslices)
        }
        covdict = {
            (row, col):
            cov[rowslice, colslice].reshape(self._elements[row].shape + self._elements[col].shape)
            for row, rowslice in zip(outkeys, outslices)
            for col, colslice in zip(outkeys, outslices)
        }
        return meandict, covdict

    if not keepcorr:
        flatout = gvar.gvar(mean, cov, fast=True)

    if strip:
        (outkey,) = outkeys
        return flatout.reshape(self._elements[outkey].shape)

    return {
        k: flatout[slic].reshape(self._elements[k].shape)
        for k, slic in zip(outkeys, outslices)
    }
def predfromfit(self, *args, **kw):
    """
    Shortcut for `pred` with ``fromdata=False``; all the arguments are
    passed through to `pred`.
    """
    posterior = self.pred(*args, fromdata=False, **kw)
    return posterior
def predfromdata(self, *args, **kw):
    """
    Shortcut for `pred` with ``fromdata=True``; all the arguments are
    passed through to `pred`.
    """
    posterior = self.pred(*args, fromdata=True, **kw)
    return posterior
def _prior_decomp(self, given, givencov=None, **kw):
    """
    Internal implementation of marginal_likelihood.

    Flatten ``given``, extract its mean and covariance matrix, and
    decompose the covariance matrix of the corresponding elements plus the
    one of ``given``. Keyword arguments are passed to `_solver`.

    Returns
    -------
    decomp :
        The output of `_solver`.
    ymean : array
        The flattened mean of ``given``.
    """
    ylist, inkeys, ycovblocks = self._flatgiven(given, givencov)
    y = self._concatenate(ylist)

    self._checkpos_keys(inkeys)

    # An object array is taken to contain gvars.
    has_gvars = y.dtype == object

    # Get mean.
    ymean = gvar.mean(y) if has_gvars else y
    self._check_ymean(ymean)

    # Get covariance matrix. (A branch for ycovblocks being a decomposition,
    # for woodbury, is currently un-implemented.)
    if ycovblocks is not None:
        ycov = jnp.block(ycovblocks)
        if has_gvars:
            warnings.warn('covariance matrix may have been specified both explicitly and with gvars; the explicit one will be used')
    elif has_gvars:
        ycov = gvar.evalcov(gvar.gvar(y))
    else:
        ycov = None
    self._check_ycov(ycov)

    return self._solver(inkeys, ycov, **kw), ymean
def _check_ymean(self, ymean):
    """
    Raise ValueError if finiteness checks are enabled
    (``self._checkfinite``) and ``ymean`` has non-finite entries. The check
    runs inside the `_jaxext.skipifabstract` context.
    """
    with _jaxext.skipifabstract():
        if not self._checkfinite:
            return
        if jnp.all(jnp.isfinite(ymean)):
            return
        raise ValueError('mean of `given` is not finite')
def _check_ycov(self, ycov):
    """
    Check the covariance matrix of ``given``.

    Nothing is checked if ``ycov`` is None or a Decomposition. Otherwise,
    inside the `_jaxext.skipifabstract` context, raise ValueError if
    ``self._checkfinite`` is set and the matrix has non-finite entries, or
    if ``self._checksym`` is set and the matrix is not close to its
    transpose.
    """
    if ycov is None:
        return
    if isinstance(ycov, _linalg.Decomposition):
        return
    with _jaxext.skipifabstract():
        nonfinite = self._checkfinite and not jnp.all(jnp.isfinite(ycov))
        if nonfinite:
            raise ValueError('covariance matrix of `given` is not finite')
        asymmetric = self._checksym and not jnp.allclose(ycov, ycov.T)
        if asymmetric:
            raise ValueError('covariance matrix of `given` is not symmetric')
def marginal_likelihood(self, given, givencov=None, **kw):
    """

    Compute the logarithm of the probability of the data.

    The probability is evaluated under the Gaussian prior and a Gaussian
    error model; it is also known as marginal likelihood. With :math:`y`
    the data and :math:`g` the Gaussian process, it reads

    .. math::
        \\log \\int p(y|g) p(g) \\mathrm{d} g.

    Differently from `pred`, this can not be computed from a fit result in
    place of data. If the Gaussian process was a latent variable in a fit,
    use the whole fit to compute the marginal likelihood. E.g. `lsqfit`
    always computes the logGBF (it's the same thing).

    The input is a dictionary of arrays, ``given``, whose contents
    represent the input data.

    Parameters
    ----------
    given : dictionary of arrays
        The data for some/all of the points in the GP. The arrays can
        contain either gvars or normal numbers, the latter being
        equivalent to zero-uncertainty gvars.
    givencov : dictionary of arrays, optional
        Covariance matrix of ``given``. If not specified, the covariance
        is extracted from ``given`` with ``gvar.evalcov(given)``.
    **kw :
        Additional keyword arguments are passed to the matrix decomposition.

    Returns
    -------
    logp : scalar
        The logarithm of the marginal likelihood.
    """
    decomp, ymean = self._prior_decomp(given, givencov, **kw)
    # Five outputs are expected; only the first (the value) is used.
    density = decomp.minus_log_normal_density(ymean, value=True)
    minus_logp, _, _, _, _ = density
    return -minus_logp
@staticmethod
def _getdecomp(solver):
    """
    Return the decomposition class registered under the name ``solver``.
    The only registered name is 'chol'; any other name raises KeyError.
    """
    registry = {
        'chol': _linalg.Chol,
    }
    return registry[solver]
@classmethod
def decompose(cls, posdefmatrix, solver='chol', **kw):
    """
    Decompose a nonnegative definite matrix.

    The result can be used to evaluate linear algebra expressions which
    involve the (pseudo)inverse of the matrix.

    Parameters
    ----------
    posdefmatrix : array
        A nonnegative definite nonempty symmetric square matrix. If the
        array is not square, it must have a shape of the kind (k, n, m,
        ..., k, n, m, ...) and is reshaped to (k * n * m * ..., k * n * m *
        ...).
    solver : str
        Algorithm used to decompose the matrix.

        'chol'
            Cholesky decomposition after regularizing the matrix with a
            Gershgorin estimate of the maximum eigenvalue.
    **kw :
        Additional options.

        epsrel, epsabs : positive float or 'auto'
            Specify the threshold for considering small the eigenvalues:

                eps = epsrel * maximum_eigenvalue + epsabs

            epsrel='auto' sets epsrel = matrix_size * float_epsilon,
            while epsabs='auto' sets epsabs = float_epsilon. Default is
            epsrel='auto', epsabs=0.

    Returns
    -------
    decomp : Decomposition
        An object representing the decomposition of the matrix. The
        available methods and properties are (K being the matrix):

        matrix():
            Return K.
        ginv():
            Compute K⁻.
        ginv_linear(X):
            Compute K⁻X.
        pinv_bilinear(A, r)
            Compute A'K⁺r.
        pinv_bilinear_robj(A, r)
            Compute A'K⁺r, and r can be an array of arbitrary objects.
        ginv_quad(A)
            Compute A'K⁻A.
        ginv_diagquad(A)
            Compute diag(A'K⁻A).
        correlate(x)
            Compute Zx such that K = ZZ', Z can be rectangular.
        back_correlate(X)
            Compute Z'X.
        pinv_correlate(x):
            Compute Z⁺x.
        minus_log_normal_density(r, ...)
            Compute a Normal density and its derivatives.
        eps
            The threshold below which eigenvalues are not calculable.
        n
            Number of rows/columns of K.
        m
            Number of columns of Z.

    Notes
    -----
    The decomposition operations are JAX-traceable, but they are not meant
    to be differentiated. The method `minus_log_normal_density` provides
    required derivatives with a custom implementation, given the derivatives
    of the inputs.

    """
    matrix = jnp.asarray(posdefmatrix)
    assert matrix.size > 0
    assert matrix.ndim % 2 == 0

    # The first half of the axes indexes the rows, the second half the
    # columns; the two halves must have the same shape.
    split = matrix.ndim // 2
    rowshape, colshape = matrix.shape[:split], matrix.shape[split:]
    assert rowshape == colshape

    side = math.prod(rowshape)
    square = matrix.reshape(side, side)
    factory = cls._getdecomp(solver)
    return factory(square, **kw)

# TODO extend the interface to use composite decompositions
# TODO accept a dict for covariance matrix
530 # TODO extend the interface to use composite decompositions
531 # TODO accept a dict for covariance matrix