Coverage for src/lsqfitgp/_GP/_compute.py: 100%
185 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-12 12:42 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-12 12:42 +0000
1# lsqfitgp/_GP/_compute.py
2#
3# Copyright (c) 2020, 2022, 2023, 2025, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import warnings 1feabcd
21import math 1feabcd
23from jax import numpy as jnp 1feabcd
24import numpy 1feabcd
25import gvar 1feabcd
27from .. import _linalg 1feabcd
28from .. import _jaxext 1feabcd
30from . import _base 1feabcd
32class GPCompute(_base.GPBase): 1feabcd
34 def __init__(self, *, solver, solverkw): 1feabcd
35 self._decompcache = {} # tuple of keys -> Decomposition 1feabcd
36 decomp = self._getdecomp(solver) 1feabcd
37 self._decompclass = lambda K, **kwargs: decomp(K, **kwargs, **solverkw) 1feabcd
def _clone(self):
    """
    Return a copy of the object as produced by the parent `_clone`, with
    the decomposition factory shared and the decomposition cache copied
    shallowly, so that later additions to one cache do not show up in the
    other.
    """
    clone = super()._clone()
    clone._decompcache = self._decompcache.copy()
    clone._decompclass = self._decompclass
    return clone
45 def _solver(self, keys, ycov=None, *, covtransf=None, **kw): 1feabcd
46 """
47 Return a decomposition of the covariance matrix of the keys in ``keys``
48 plus the matrix ycov. Keyword arguments are passed to the decomposition.
49 """
51 # TODO cache ignores **kw.
53 keys = tuple(keys) 1feabcd
55 # Check if decomposition is in cache.
56 if ycov is None: 1feabcd
57 cache = self._decompcache.get(keys) 1feabcd
58 if cache is not None: 1feabcd
59 return cache 1feabcd
60 # TODO use frozenset(keys) instead of tuple(keys) to make cache
61 # work when order changes, but I have to permute the decomposition
62 # to make that work. Needs an ad-hoc class in _linalg. Make the
63 # decompcache a dict subclass that accepts tuples of keys but uses
64 # internally frozenset.
66 # Compute decomposition. # woodbury, currently un-implemented
67 # if isinstance(ycov, _linalg.Decomposition):
68 # ancestors = []
69 # transfs = []
70 # for key in keys:
71 # elem = self._elements[key]
72 # nest = False
73 # if isinstance(elem, self._LinTransf):
74 # size = sum(self._elements[k].size for k in elem.keys)
75 # if size < elem.size:
76 # nest = True
77 # ancestors += list(elem.keys)
78 # transfs.append(jnp.concatenate(elem.matrices(self), 1))
79 # if not nest:
80 # ancestors.append(key)
81 # transfs.append(jnp.eye(elem.size))
82 # transf = jlinalg.block_diag(*transfs)
83 # cov = self._assemblecovblocks(ancestors)
84 # if covtransf:
85 # ycov, transf, cov = covtransf((ycov, transf, cov))
86 # covdec = self._decompclass(cov, **kw)
87 # # TODO obtain covdec from _solver recursively, to use cache?
88 # decomp = _linalg.Woodbury2(ycov, transf, covdec, self._decompclass, sign=1, **kw)
89 # else:
90 Kxx = self._assemblecovblocks(keys) 1feabcd
91 if ycov is not None: 1feabcd
92 Kxx = Kxx + ycov 1feabcd
93 if covtransf: 1feabcd
94 Kxx = covtransf(Kxx) 1feabcd
95 decomp = self._decompclass(Kxx, **kw) 1feabcd
97 # Cache decomposition.
98 if ycov is None: 1feabcd
99 self._decompcache[keys] = decomp 1feabcd
101 return decomp 1feabcd
103 def _flatgiven(self, given, givencov): 1feabcd
105 if not hasattr(given, 'keys'): 1feabcd
106 raise TypeError('`given` must be dict') 1abcd
107 gcblack = givencov is None or isinstance(givencov, _linalg.Decomposition) 1feabcd
108 if not gcblack and not hasattr(givencov, 'keys'): 1feabcd
109 raise TypeError('`givenconv` must be None, dict or Decomposition') 1abcd
111 ylist = [] 1feabcd
112 keylist = [] 1feabcd
113 for key, l in given.items(): 1feabcd
114 if key not in self._elements: 1feabcd
115 raise KeyError(key) 1abcd
117 if not isinstance(l, jnp.ndarray): 1feabcd
118 # use numpy since there could be gvars
119 l = numpy.asarray(l) 1feabcd
120 shape = self._elements[key].shape 1feabcd
121 if l.shape != shape: 1feabcd
122 msg = 'given[{!r}] has shape {!r} different from shape {!r}' 1abcd
123 raise ValueError(msg.format(key, l.shape, shape)) 1abcd
124 if l.dtype != object and not jnp.issubdtype(l.dtype, jnp.number): 1feabcd
125 msg = 'given[{!r}] has non-numerical dtype {!r}' 1abcd
126 raise TypeError(msg.format(key, l.dtype)) 1abcd
128 ylist.append(l.reshape(-1)) 1feabcd
129 keylist.append(key) 1feabcd
131 # TODO error checking on the unpacking of givencov
133 if gcblack: 1feabcd
134 covblocks = givencov 1feabcd
135 else:
136 covblocks = [ 1eabcd
137 [
138 jnp.asarray(givencov[keylist[i], keylist[j]]).reshape(ylist[i].shape + ylist[j].shape)
139 for j in range(len(keylist))
140 ]
141 for i in range(len(keylist))
142 ]
144 return ylist, keylist, covblocks 1feabcd
def pred(self, given, key=None, givencov=None, *, fromdata=None, raw=False, keepcorr=None):
    """

    Compute the posterior.

    The posterior can be computed for all the points in the GP or for a
    subset of them, conditioning either directly on data or on a posterior
    obtained from a fit. The second option is meant for when the Gaussian
    process was used in a fit together with other parameters.

    The output is a collection of gvars, either an array or a dictionary
    of arrays. They are properly correlated with the gvars returned by
    `prior` and with the input data/fit.

    The input is a dictionary of arrays, ``given``, whose keys correspond
    to the keys in the GP as added by `addx` or `addtransf`.

    Parameters
    ----------
    given : dictionary of arrays
        The data or fit result for some/all of the points in the GP.
        The arrays can contain either gvars or normal numbers, the latter
        being equivalent to zero-uncertainty gvars.
    key : None, key or list of keys, optional
        If None, compute the posterior for all points in the GP (also those
        used in ``given``). Otherwise only for those specified by key.
    givencov : dictionary of arrays, optional
        Covariance matrix of ``given``. If not specified, the covariance
        is extracted from ``given`` with ``gvar.evalcov(given)``.
    fromdata : bool
        Mandatory. Specify if the contents of ``given`` are data or already
        a posterior.
    raw : bool, optional
        If True, instead of returning a collection of gvars, return
        the mean and the covariance. When the mean is a dictionary, the
        covariance is a dictionary whose keys are pairs of keys of the
        mean (the same format used by `gvar.evalcov`). Default False.
    keepcorr : bool, optional
        If True (default), the returned gvars are correlated with the
        prior and the data/fit. If False, they have the correct covariance
        between themselves, but are independent from all other preexisting
        gvars.

    Returns
    -------
    If raw=False (default):

    posterior : array or dictionary of arrays
        A collection of gvars representing the posterior.

    If raw=True:

    pmean : array or dictionary of arrays
        The mean of the posterior. Equivalent to ``gvar.mean(posterior)``.
    pcov : 2D array or dictionary of 2D arrays
        The covariance matrix of the posterior. If ``pmean`` is a
        dictionary, the keys of ``pcov`` are pairs of keys of ``pmean``.
        Equivalent to ``gvar.evalcov(posterior)``.

    """

    # TODO GP.pred(..., raw=True, onlyvariance=True) computes only the
    # variance (requires actually implementing diagquad at least in Chol and
    # Diag).

    # Normalize and validate the flags.
    if fromdata is None:
        raise ValueError('you must specify if `given` is data or fit result')
    fromdata = bool(fromdata)
    raw = bool(raw)
    if keepcorr is None:
        keepcorr = not raw
    if keepcorr and raw:
        raise ValueError('both keepcorr=True and raw=True')

    # Determine the output keys; a single bare key means the output is
    # returned without the enclosing dictionary.
    if key is None:
        outkeys, strip = list(self._elements), False
    elif isinstance(key, list):
        outkeys, strip = key, False
    else:
        outkeys, strip = [key], True
    outslices = self._slices(outkeys)

    # Flatten the input.
    ylist, inkeys, ycovblocks = self._flatgiven(given, givencov)
    y = self._concatenate(ylist)
    has_gvars = y.dtype == object
    if has_gvars and ycovblocks is not None:
        raise ValueError('given may contain gvars but a separate covariance matrix has been provided')

    self._checkpos_keys(inkeys + outkeys)

    Kxxs = self._assemblecovblocks(inkeys, outkeys)

    # Covariance matrix of the input.
    # if isinstance(ycovblocks, _linalg.Decomposition): # woodbury, currently un-implemented
    #     ycov = ycovblocks
    # elif ...
    if ycovblocks is not None:
        ycov = jnp.block(ycovblocks)
    elif has_gvars and (fromdata or raw or not keepcorr):
        ycov = gvar.evalcov(gvar.gvar(y))
        # TODO use evalcov_blocks
        # TODO I think this ignores the case in which we are using gvars
        # and they are correlated with the GP. I guess the correct thing
        # would be to sum the data gvars to the prior ones and use the
        # resulting covariance matrix, and write a note about possible
        # different results in this case when switching raw or keepcorr.
    else:
        ycov = None
    self._check_ycov(ycov)

    # The output is detached from the preexisting gvars if it is returned
    # as plain arrays or if the correlations are explicitly dropped.
    standalone = raw or not keepcorr

    # Mean of the input.
    if standalone or self._checkfinite:
        ymean = gvar.mean(y) if has_gvars else y
        self._check_ymean(ymean)

    if standalone:

        Kxsxs = self._assemblecovblocks(outkeys)

        # The covariance of the input adds to the prior only for data.
        solver = self._solver(inkeys, ycov if fromdata else None)

        mean = solver.pinv_bilinear(Kxxs, ymean)
        cov = Kxsxs - solver.ginv_quad(Kxxs)

        if not fromdata and ycov is not None:
            # cov = Kxsxs - Kxsx Kxx^-1 (Kxx - ycov) Kxx^-1 Kxxs =
            #     = Kxsxs - Kxsx Kxx^-1 Kxxs + Kxsx Kxx^-1 ycov Kxx^-1 Kxxs
            # if isinstance(ycov, _linalg.Decomposition): # for woodbury, currently un-implemented
            #     ycov = ycov.matrix()
            A = solver.ginv_linear(Kxxs)
            # TODO do I need K⁺ here or is K⁻ fine?
            cov += A.T @ ycov @ A

    else: # (keepcorr and not raw)
        yplist = [numpy.reshape(self._prior(k), -1) for k in inkeys]
        ysplist = [numpy.reshape(self._prior(k), -1) for k in outkeys]
        yp = self._concatenate(yplist)
        ysp = self._concatenate(ysplist)

        if not has_gvars and ycov is not None:
            # if isinstance(ycov, _linalg.Decomposition): # for woodbury, currently un-implemented
            #     ycov = ycov.matrix()
            y = gvar.gvar(y, ycov)
        else:
            y = numpy.asarray(y) # because y - yp fails if y is a jax array
        mat = ycov if fromdata else None
        flatout = ysp + self._solver(inkeys, mat).pinv_bilinear_robj(Kxxs, y - yp)

    if raw:
        if strip:
            outkey, = outkeys
            shape = self._elements[outkey].shape
            return mean.reshape(shape), cov.reshape(2 * shape)

        meandict = {
            outkey: mean[slic].reshape(self._elements[outkey].shape)
            for outkey, slic in zip(outkeys, outslices)
        }

        covdict = {
            (row, col):
            cov[rowslice, colslice].reshape(self._elements[row].shape + self._elements[col].shape)
            for row, rowslice in zip(outkeys, outslices)
            for col, colslice in zip(outkeys, outslices)
        }

        return meandict, covdict

    if not keepcorr:

        ##### temporary fix for gplepage/gvar#49 #####
        cov = numpy.array(cov)
        ##############################################

        flatout = gvar.gvar(mean, cov, fast=True)

    if strip:
        outkey, = outkeys
        return flatout.reshape(self._elements[outkey].shape)

    return {
        outkey: flatout[slic].reshape(self._elements[outkey].shape)
        for outkey, slic in zip(outkeys, outslices)
    }
def predfromfit(self, *args, **kw):
    """
    Shorthand for `pred` with ``fromdata=False``, i.e., ``given`` is
    interpreted as the result of a fit. All arguments are forwarded.
    """
    posterior = self.pred(*args, fromdata=False, **kw)
    return posterior
def predfromdata(self, *args, **kw):
    """
    Shorthand for `pred` with ``fromdata=True``, i.e., ``given`` is
    interpreted as data. All arguments are forwarded.
    """
    posterior = self.pred(*args, fromdata=True, **kw)
    return posterior
352 def _prior_decomp(self, given, givencov=None, **kw): 1feabcd
353 """ Internal implementation of marginal_likelihood. Keyword arguments
354 are passed to _solver. """
355 ylist, inkeys, ycovblocks = self._flatgiven(given, givencov) 1feabcd
356 y = self._concatenate(ylist) 1feabcd
358 self._checkpos_keys(inkeys) 1feabcd
360 # Get mean.
361 if y.dtype == object: 1feabcd
362 ymean = gvar.mean(y) 1abcd
363 else:
364 ymean = y 1feabcd
365 self._check_ymean(ymean) 1feabcd
367 # Get covariance matrix.
368 # if isinstance(ycovblocks, _linalg.Decomposition): # for woodbury, currently un-implemented
369 # ycov = ycovblocks
370 # elif ...
371 if ycovblocks is not None: 1feabcd
372 ycov = jnp.block(ycovblocks) 1eabcd
373 if y.dtype == object: 1eabcd
374 warnings.warn(f'covariance matrix may have been specified both explicitly and with gvars; the explicit one will be used') 1abcd
375 elif y.dtype == object: 1feabcd
376 gvary = gvar.gvar(y) 1abcd
377 ycov = gvar.evalcov(gvary) 1abcd
378 else:
379 ycov = None 1feabcd
380 self._check_ycov(ycov) 1feabcd
382 decomp = self._solver(inkeys, ycov, **kw) 1feabcd
383 return decomp, ymean 1feabcd
def _check_ymean(self, ymean):
    """
    Raise ValueError if finiteness checks are enabled (`_checkfinite`) and
    `ymean` contains non-finite values. The check runs within
    `_jaxext.skipifabstract()`.
    """
    with _jaxext.skipifabstract():
        if not self._checkfinite:
            return
        if not jnp.all(jnp.isfinite(ymean)):
            raise ValueError('mean of `given` is not finite')
390 def _check_ycov(self, ycov): 1feabcd
391 if ycov is None or isinstance(ycov, _linalg.Decomposition): 1feabcd
392 return 1feabcd
393 with _jaxext.skipifabstract(): 1feabcd
394 if self._checkfinite and not jnp.all(jnp.isfinite(ycov)): 1feabcd
395 raise ValueError('covariance matrix of `given` is not finite') 1abcd
396 if self._checksym and not jnp.allclose(ycov, ycov.T): 1feabcd
397 raise ValueError('covariance matrix of `given` is not symmetric') 1abcd
def marginal_likelihood(self, given, givencov=None, **kw):
    """

    Compute the logarithm of the probability of the data.

    The probability is computed under the Gaussian prior and Gaussian error
    model. It is also called marginal likelihood. If :math:`y` is the data
    and :math:`g` is the Gaussian process, this is

    .. math::
        \\log \\int p(y|g) p(g) \\mathrm{d} g.

    Unlike `pred`, this can not be computed with a fit result in place of
    the data. If the Gaussian process was used as latent variable in a
    fit, use the whole fit to compute the marginal likelihood. E.g.
    `lsqfit` always computes the logGBF (it's the same thing).

    The input is an array or dictionary of arrays, ``given``. The contents
    of ``given`` represent the input data.

    Parameters
    ----------
    given : dictionary of arrays
        The data for some/all of the points in the GP. The arrays can
        contain either gvars or normal numbers, the latter being
        equivalent to zero-uncertainty gvars.
    givencov : dictionary of arrays, optional
        Covariance matrix of ``given``. If not specified, the covariance
        is extracted from ``given`` with ``gvar.evalcov(given)``.
    **kw :
        Additional keyword arguments are passed to the matrix decomposition.

    Returns
    -------
    logp : scalar
        The logarithm of the marginal likelihood.
    """
    decomp, ymean = self._prior_decomp(given, givencov, **kw)
    density = decomp.minus_log_normal_density(ymean, value=True)
    # only the value is needed, the other four outputs are discarded
    minus_logp, _, _, _, _ = density
    return -minus_logp
@staticmethod
def _getdecomp(solver):
    """
    Map the name of a solver to the corresponding decomposition class.
    Raises KeyError if the name is not recognized.
    """
    solvers = {
        'chol': _linalg.Chol,
    }
    return solvers[solver]
@classmethod
def decompose(cls, posdefmatrix, solver='chol', **kw):
    """
    Decompose a nonnegative definite matrix.

    The decomposition can be used to evaluate linear algebra expressions
    in which the (pseudo)inverse of the matrix appears.

    Parameters
    ----------
    posdefmatrix : array
        A nonnegative definite nonempty symmetric square matrix. If the
        array is not square, it must have a shape of the kind (k, n, m,
        ..., k, n, m, ...) and is reshaped to (k * n * m * ..., k * n * m *
        ...).
    solver : str
        Algorithm used to decompose the matrix.

        'chol'
            Cholesky decomposition after regularizing the matrix with a
            Gershgorin estimate of the maximum eigenvalue.
    **kw :
        Additional options.

        epsrel, epsabs : positive float or 'auto'
            Specify the threshold for considering small the eigenvalues:

                eps = epsrel * maximum_eigenvalue + epsabs

            epsrel='auto' sets epsrel = matrix_size * float_epsilon,
            while epsabs='auto' sets epsabs = float_epsilon. Default is
            epsrel='auto', epsabs=0.

    Returns
    -------
    decomp : Decomposition
        An object representing the decomposition of the matrix. The
        available methods and properties are (K being the matrix):

        matrix():
            Return K.
        ginv():
            Compute K⁻.
        ginv_linear(X):
            Compute K⁻X.
        pinv_bilinear(A, r)
            Compute A'K⁺r.
        pinv_bilinear_robj(A, r)
            Compute A'K⁺r, and r can be an array of arbitrary objects.
        ginv_quad(A)
            Compute A'K⁻A.
        ginv_diagquad(A)
            Compute diag(A'K⁻A).
        correlate(x)
            Compute Zx such that K = ZZ', Z can be rectangular.
        back_correlate(X)
            Compute Z'X.
        pinv_correlate(x):
            Compute Z⁺x.
        minus_log_normal_density(r, ...)
            Compute a Normal density and its derivatives.
        eps
            The threshold below which eigenvalues are not calculable.
        n
            Number of rows/columns of K.
        m
            Number of columns of Z.

    Notes
    -----
    The decomposition operations are JAX-traceable, but they are not meant
    to be differentiated. The method `minus_log_normal_density` provides
    required derivatives with a custom implementation, given the derivatives
    of the inputs.

    """
    matrix = jnp.asarray(posdefmatrix)
    assert matrix.size > 0
    assert matrix.ndim % 2 == 0

    # the first half of the axes indexes rows, the second half columns
    split = matrix.ndim // 2
    rowshape, colshape = matrix.shape[:split], matrix.shape[split:]
    assert rowshape == colshape

    side = math.prod(rowshape)
    matrix = matrix.reshape(side, side)
    decompcls = cls._getdecomp(solver)
    return decompcls(matrix, **kw)

    # TODO extend the interface to use composite decompositions
    # TODO accept a dict for covariance matrix
534 # TODO extend the interface to use composite decompositions
535 # TODO accept a dict for covariance matrix