Coverage for src/lsqfitgp/_GP/_compute.py: 100%

184 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_GP/_compute.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import warnings 1feabcd

21import math 1feabcd

22 

23from jax import numpy as jnp 1feabcd

24import numpy 1feabcd

25import gvar 1feabcd

26 

27from .. import _linalg 1feabcd

28from .. import _jaxext 1feabcd

29 

30from . import _base 1feabcd

31 

class GPCompute(_base.GPBase):
    """
    Mixin with the numerical part of the Gaussian process: posterior
    (`pred`, `predfromfit`, `predfromdata`), marginal likelihood
    (`marginal_likelihood`) and matrix decomposition (`decompose`).

    The methods rely on attributes and helpers that are not defined in this
    class (``_elements``, ``_slices``, ``_assemblecovblocks``,
    ``_concatenate``, ``_checkpos_keys``, ``_prior``, ``_checkfinite``,
    ``_checksym``); they are expected to be provided elsewhere in the class
    hierarchy.
    """

    def __init__(self, *, solver, solverkw):
        """
        Set up the decomposition cache and the decomposition factory.

        Parameters
        ----------
        solver : str
            Name of the decomposition, looked up with `_getdecomp`.
        solverkw : dict
            Keyword arguments appended to every call to the decomposition.
        """
        # NOTE(review): this does not chain to super().__init__();
        # presumably the concrete class initialises each base explicitly --
        # confirm.
        self._decompcache = {} # tuple of keys -> Decomposition
        decomp = self._getdecomp(solver)
        self._decompclass = lambda K, **kwargs: decomp(K, **kwargs, **solverkw)

    def _clone(self):
        """
        Return a clone of the object, obtained from the parent `_clone`, with
        a shallow copy of the decomposition cache and the same decomposition
        factory.
        """
        newself = super()._clone()
        newself._decompcache = self._decompcache.copy()
        newself._decompclass = self._decompclass
        return newself

    def _solver(self, keys, ycov=None, *, covtransf=None, **kw):
        """
        Return a decomposition of the covariance matrix of the keys in ``keys``
        plus the matrix ycov. Keyword arguments are passed to the decomposition.

        Parameters
        ----------
        keys : iterable of keys
            Keys of the elements whose prior covariance is decomposed.
        ycov : array or None
            Matrix added to the prior covariance. The result is cached only
            if this is None.
        covtransf : callable, optional
            If specified (and truthy), it is applied to the matrix just
            before decomposing it.
        **kw :
            Passed to the decomposition.

        Notes
        -----
        The cache is indexed only by the tuple of keys: neither ``covtransf``
        nor ``**kw`` take part in the lookup.
        """

        # TODO cache ignores **kw.

        keys = tuple(keys)

        # Check if decomposition is in cache.
        if ycov is None:
            cache = self._decompcache.get(keys)
            if cache is not None:
                return cache
            # TODO use frozenset(keys) instead of tuple(keys) to make cache
            # work when order changes, but I have to permute the decomposition
            # to make that work. Needs an ad-hoc class in _linalg. Make the
            # decompcache a dict subclass that accepts tuples of keys but uses
            # internally frozenset.

        # Compute decomposition. # woodbury, currently un-implemented
        # if isinstance(ycov, _linalg.Decomposition):
        #     ancestors = []
        #     transfs = []
        #     for key in keys:
        #         elem = self._elements[key]
        #         nest = False
        #         if isinstance(elem, self._LinTransf):
        #             size = sum(self._elements[k].size for k in elem.keys)
        #             if size < elem.size:
        #                 nest = True
        #                 ancestors += list(elem.keys)
        #                 transfs.append(jnp.concatenate(elem.matrices(self), 1))
        #         if not nest:
        #             ancestors.append(key)
        #             transfs.append(jnp.eye(elem.size))
        #     transf = jlinalg.block_diag(*transfs)
        #     cov = self._assemblecovblocks(ancestors)
        #     if covtransf:
        #         ycov, transf, cov = covtransf((ycov, transf, cov))
        #     covdec = self._decompclass(cov, **kw)
        #     # TODO obtain covdec from _solver recursively, to use cache?
        #     decomp = _linalg.Woodbury2(ycov, transf, covdec, self._decompclass, sign=1, **kw)
        # else:
        Kxx = self._assemblecovblocks(keys)
        if ycov is not None:
            Kxx = Kxx + ycov
        if covtransf:
            Kxx = covtransf(Kxx)
        decomp = self._decompclass(Kxx, **kw)

        # Cache decomposition.
        if ycov is None:
            self._decompcache[keys] = decomp

        return decomp

    def _flatgiven(self, given, givencov):
        """
        Validate and flatten the user-provided data.

        Parameters
        ----------
        given : dict
            Maps keys of the GP to arrays with the same shape as the
            corresponding element.
        givencov : None, dict or Decomposition
            If a dictionary, it is indexed with pairs of keys of ``given``.

        Returns
        -------
        ylist : list of 1D arrays
            The flattened arrays in ``given``, in iteration order.
        keylist : list
            The keys of ``given``, in the same order.
        covblocks : None, Decomposition or list of lists of arrays
            ``givencov`` itself if it is None or a Decomposition, otherwise
            the blocks ``givencov[row, col]`` reshaped to 2D and arranged
            like ``keylist`` x ``keylist``.

        Raises
        ------
        TypeError
            If ``given`` or ``givencov`` have the wrong type, or an array has
            a non-numerical dtype.
        KeyError
            If a key of ``given`` is not in the GP.
        ValueError
            If an array has a shape different from the one of its element.
        """

        if not hasattr(given, 'keys'):
            raise TypeError('`given` must be dict')
        gcblack = givencov is None or isinstance(givencov, _linalg.Decomposition)
        if not gcblack and not hasattr(givencov, 'keys'):
            raise TypeError('`givencov` must be None, dict or Decomposition')

        ylist = []
        keylist = []
        for key, values in given.items():
            if key not in self._elements:
                raise KeyError(key)

            if not isinstance(values, jnp.ndarray):
                # use numpy since there could be gvars
                values = numpy.asarray(values)
            shape = self._elements[key].shape
            if values.shape != shape:
                msg = 'given[{!r}] has shape {!r} different from shape {!r}'
                raise ValueError(msg.format(key, values.shape, shape))
            # object arrays are let through because they may contain gvars
            if values.dtype != object and not jnp.issubdtype(values.dtype, jnp.number):
                msg = 'given[{!r}] has non-numerical dtype {!r}'
                raise TypeError(msg.format(key, values.dtype))

            ylist.append(values.reshape(-1))
            keylist.append(key)

        # TODO error checking on the unpacking of givencov

        if gcblack:
            covblocks = givencov
        else:
            covblocks = [
                [
                    jnp.asarray(givencov[rowkey, colkey]).reshape(rowy.shape + coly.shape)
                    for colkey, coly in zip(keylist, ylist)
                ]
                for rowkey, rowy in zip(keylist, ylist)
            ]

        return ylist, keylist, covblocks

    def pred(self, given, key=None, givencov=None, *, fromdata=None, raw=False, keepcorr=None):
        """

        Compute the posterior.

        The posterior can be computed either for all points or for a subset,
        and either directly from data or from a posterior obtained with a fit.
        The latter case is for when the Gaussian process was used in a fit with
        other parameters.

        The output is a collection of gvars, either an array or a dictionary
        of arrays. They are properly correlated with gvars returned by
        `prior` and with the input data/fit.

        The input is a dictionary of arrays, ``given``, with keys corresponding
        to the keys in the GP as added by `addx` or `addtransf`.

        Parameters
        ----------
        given : dictionary of arrays
            The data or fit result for some/all of the points in the GP.
            The arrays can contain either gvars or normal numbers, the latter
            being equivalent to zero-uncertainty gvars.
        key : None, key or list of keys, optional
            If None, compute the posterior for all points in the GP (also those
            used in ``given``). Otherwise only those specified by key. A list
            gives a dictionary as output, any other non-None value is taken
            as a single key and gives a bare array.
        givencov : dictionary of arrays, optional
            Covariance matrix of ``given``. If not specified, the covariance
            is extracted from ``given`` with ``gvar.evalcov(given)``. It can
            not be specified if ``given`` contains gvars.
        fromdata : bool
            Mandatory. Specify if the contents of ``given`` are data or already
            a posterior.
        raw : bool, optional
            If True, instead of returning a collection of gvars, return
            the mean and the covariance. When the mean is a dictionary, the
            covariance is a dictionary whose keys are pairs of keys of the
            mean (the same format used by `gvar.evalcov`). Default False.
        keepcorr : bool, optional
            If True, the returned gvars are correlated with the prior and the
            data/fit. If False, they have the correct covariance between
            themselves, but are independent from all other preexisting
            gvars. The default is ``not raw``.

        Returns
        -------
        If raw=False (default):

        posterior : array or dictionary of arrays
            A collection of gvars representing the posterior.

        If raw=True:

        pmean : array or dictionary of arrays
            The mean of the posterior. Equivalent to ``gvar.mean(posterior)``.
        pcov : 2D array or dictionary of 2D arrays
            The covariance matrix of the posterior. If ``pmean`` is a
            dictionary, the keys of ``pcov`` are pairs of keys of ``pmean``.
            Equivalent to ``gvar.evalcov(posterior)``.

        Raises
        ------
        ValueError
            If ``fromdata`` is not specified, if both ``keepcorr`` and ``raw``
            are True, or if ``given`` contains gvars and ``givencov`` is
            specified too.

        """

        # TODO GP.pred(..., raw=True, onlyvariance=True) computes only the
        # variance (requires actually implementing diagquad at least in Chol and
        # Diag).

        if fromdata is None:
            raise ValueError('you must specify if `given` is data or fit result')
        fromdata = bool(fromdata)
        raw = bool(raw)
        if keepcorr is None:
            keepcorr = not raw
        if keepcorr and raw:
            raise ValueError('both keepcorr=True and raw=True')

        # strip = the output is unpacked from the dictionary
        strip = False
        if key is None:
            outkeys = list(self._elements)
        elif isinstance(key, list):
            outkeys = key
        else:
            outkeys = [key]
            strip = True
        outslices = self._slices(outkeys)

        ylist, inkeys, ycovblocks = self._flatgiven(given, givencov)
        y = self._concatenate(ylist)
        hasgvars = y.dtype == object # object arrays may contain gvars
        if hasgvars and ycovblocks is not None:
            raise ValueError('given may contain gvars but a separate covariance matrix has been provided')

        self._checkpos_keys(inkeys + outkeys)

        Kxxs = self._assemblecovblocks(inkeys, outkeys)

        # if isinstance(ycovblocks, _linalg.Decomposition): # woodbury, currently un-implemented
        #     ycov = ycovblocks
        # elif ...
        if ycovblocks is not None:
            ycov = jnp.block(ycovblocks)
        elif (fromdata or raw or not keepcorr) and hasgvars:
            ycov = gvar.evalcov(gvar.gvar(y))
            # TODO use evalcov_blocks
            # TODO I think this ignores the case in which we are using gvars
            # and they are correlated with the GP. I guess the correct thing
            # would be to sum the data gvars to the prior ones and use the
            # resulting covariance matrix, and write a note about possible
            # different results in this case when switching raw or keepcorr.
        else:
            ycov = None
        self._check_ycov(ycov)

        if raw or not keepcorr or self._checkfinite:
            ymean = gvar.mean(y) if hasgvars else y
            self._check_ymean(ymean)

        if raw or not keepcorr:

            Kxsxs = self._assemblecovblocks(outkeys)

            # the error on `given` enters the decomposed matrix only if
            # `given` is data
            solver = self._solver(inkeys, ycov if fromdata else None)

            mean = solver.pinv_bilinear(Kxxs, ymean)
            cov = Kxsxs - solver.ginv_quad(Kxxs)

            if not fromdata and ycov is not None:
                # cov = Kxsxs - Kxsx Kxx^-1 (Kxx - ycov) Kxx^-1 Kxxs =
                #     = Kxsxs - Kxsx Kxx^-1 Kxxs + Kxsx Kxx^-1 ycov Kxx^-1 Kxxs
                # if isinstance(ycov, _linalg.Decomposition): # for woodbury, currently un-implemented
                #     ycov = ycov.matrix()
                A = solver.ginv_linear(Kxxs)
                # TODO do I need K⁺ here or is K⁻ fine?
                cov += A.T @ ycov @ A

        else: # (keepcorr and not raw)
            yplist = [numpy.reshape(self._prior(k), -1) for k in inkeys]
            ysplist = [numpy.reshape(self._prior(k), -1) for k in outkeys]
            yp = self._concatenate(yplist)
            ysp = self._concatenate(ysplist)

            if not hasgvars and ycov is not None:
                # if isinstance(ycov, _linalg.Decomposition): # for woodbury, currently un-implemented
                #     ycov = ycov.matrix()
                y = gvar.gvar(y, ycov)
            else:
                y = numpy.asarray(y) # because y - yp fails if y is a jax array
            mat = ycov if fromdata else None
            flatout = ysp + self._solver(inkeys, mat).pinv_bilinear_robj(Kxxs, y - yp)

        if raw:
            if strip:
                outkey, = outkeys
                shape = self._elements[outkey].shape
                return mean.reshape(shape), cov.reshape(2 * shape)

            meandict = {
                outkey: mean[slic].reshape(self._elements[outkey].shape)
                for outkey, slic in zip(outkeys, outslices)
            }

            covdict = {
                (row, col):
                cov[rowslice, colslice].reshape(self._elements[row].shape + self._elements[col].shape)
                for row, rowslice in zip(outkeys, outslices)
                for col, colslice in zip(outkeys, outslices)
            }

            return meandict, covdict

        if not keepcorr:
            flatout = gvar.gvar(mean, cov, fast=True)

        if strip:
            outkey, = outkeys
            return flatout.reshape(self._elements[outkey].shape)

        return {
            outkey: flatout[slic].reshape(self._elements[outkey].shape)
            for outkey, slic in zip(outkeys, outslices)
        }

    def predfromfit(self, *args, **kw):
        """
        Like `pred` with ``fromdata=False``.
        """
        return self.pred(*args, fromdata=False, **kw)

    def predfromdata(self, *args, **kw):
        """
        Like `pred` with ``fromdata=True``.
        """
        return self.pred(*args, fromdata=True, **kw)

    def _prior_decomp(self, given, givencov=None, **kw):
        """ Internal implementation of marginal_likelihood. Keyword arguments
        are passed to _solver.

        Returns
        -------
        decomp : Decomposition
            Decomposition of the prior covariance matrix of the keys in
            ``given`` plus the covariance matrix of ``given``.
        ymean : array
            The flattened mean of ``given``.
        """
        ylist, inkeys, ycovblocks = self._flatgiven(given, givencov)
        y = self._concatenate(ylist)

        self._checkpos_keys(inkeys)

        hasgvars = y.dtype == object # object arrays may contain gvars

        # Get mean.
        ymean = gvar.mean(y) if hasgvars else y
        self._check_ymean(ymean)

        # Get covariance matrix.
        # if isinstance(ycovblocks, _linalg.Decomposition): # for woodbury, currently un-implemented
        #     ycov = ycovblocks
        # elif ...
        if ycovblocks is not None:
            ycov = jnp.block(ycovblocks)
            if hasgvars:
                warnings.warn('covariance matrix may have been specified both explicitly and with gvars; the explicit one will be used')
        elif hasgvars:
            ycov = gvar.evalcov(gvar.gvar(y))
        else:
            ycov = None
        self._check_ycov(ycov)

        decomp = self._solver(inkeys, ycov, **kw)
        return decomp, ymean

    def _check_ymean(self, ymean):
        """
        Raise ValueError if ``_checkfinite`` is set and ``ymean`` contains
        non-finite values.
        """
        # presumably the context manager skips the check when the values are
        # abstract JAX tracers; confirm in _jaxext
        with _jaxext.skipifabstract():
            if self._checkfinite and not jnp.all(jnp.isfinite(ymean)):
                raise ValueError('mean of `given` is not finite')

    def _check_ycov(self, ycov):
        """
        Raise ValueError if ``_checkfinite`` is set and ``ycov`` contains
        non-finite values, or if ``_checksym`` is set and ``ycov`` is not
        close to its transpose. Nothing is checked if ``ycov`` is None or a
        Decomposition.
        """
        if ycov is None or isinstance(ycov, _linalg.Decomposition):
            return
        # presumably the context manager skips the checks when the values are
        # abstract JAX tracers; confirm in _jaxext
        with _jaxext.skipifabstract():
            if self._checkfinite and not jnp.all(jnp.isfinite(ycov)):
                raise ValueError('covariance matrix of `given` is not finite')
            if self._checksym and not jnp.allclose(ycov, ycov.T):
                raise ValueError('covariance matrix of `given` is not symmetric')

    def marginal_likelihood(self, given, givencov=None, **kw):
        """

        Compute the logarithm of the probability of the data.

        The probability is computed under the Gaussian prior and Gaussian error
        model. It is also called marginal likelihood. If :math:`y` is the data
        and :math:`g` is the Gaussian process, this is

        .. math::
            \\log \\int p(y|g) p(g) \\mathrm{d} g.

        Unlike `pred`, you can't compute this with a fit result instead of
        data. If you used the Gaussian process as latent variable in a fit,
        use the whole fit to compute the marginal likelihood. E.g. `lsqfit`
        always computes the logGBF (it's the same thing).

        The input is a dictionary of arrays, ``given``. The contents of
        ``given`` represent the input data.

        Parameters
        ----------
        given : dictionary of arrays
            The data for some/all of the points in the GP. The arrays can
            contain either gvars or normal numbers, the latter being
            equivalent to zero-uncertainty gvars.
        givencov : dictionary of arrays, optional
            Covariance matrix of ``given``. If not specified, the covariance
            is extracted from ``given`` with ``gvar.evalcov(given)``. If
            ``given`` contains gvars and this is specified too, a warning is
            emitted and this one is used.
        **kw :
            Additional keyword arguments are passed to the matrix decomposition.

        Returns
        -------
        logp : scalar
            The logarithm of the marginal likelihood.
        """
        decomp, ymean = self._prior_decomp(given, givencov, **kw)
        # only the first of the five outputs (the value) is needed
        mll, _, _, _, _ = decomp.minus_log_normal_density(ymean, value=True)
        return -mll

    @staticmethod
    def _getdecomp(solver):
        """
        Return the decomposition class with name ``solver``. Raise KeyError
        if the name is unknown.
        """
        return {
            'chol': _linalg.Chol,
        }[solver]

    @classmethod
    def decompose(cls, posdefmatrix, solver='chol', **kw):
        """
        Decompose a nonnegative definite matrix.

        The decomposition can be used to calculate linear algebra expressions
        where the (pseudo)inverse of the matrix appears.

        Parameters
        ----------
        posdefmatrix : array
            A nonnegative definite nonempty symmetric square matrix. If the
            array is not square, it must have a shape of the kind (k, n, m,
            ..., k, n, m, ...) and is reshaped to (k * n * m * ..., k * n * m *
            ...).
        solver : str
            Algorithm used to decompose the matrix.

            'chol'
                Cholesky decomposition after regularizing the matrix with a
                Gershgorin estimate of the maximum eigenvalue.
        **kw :
            Additional options.

            epsrel, epsabs : positive float or 'auto'
                Specify the threshold for considering small the eigenvalues:

                    eps = epsrel * maximum_eigenvalue + epsabs

                epsrel='auto' sets epsrel = matrix_size * float_epsilon,
                while epsabs='auto' sets epsabs = float_epsilon. Default is
                epsrel='auto', epsabs=0.

        Returns
        -------
        decomp : Decomposition
            An object representing the decomposition of the matrix. The
            available methods and properties are (K being the matrix):

            matrix():
                Return K.
            ginv():
                Compute K⁻.
            ginv_linear(X):
                Compute K⁻X.
            pinv_bilinear(A, r)
                Compute A'K⁺r.
            pinv_bilinear_robj(A, r)
                Compute A'K⁺r, and r can be an array of arbitrary objects.
            ginv_quad(A)
                Compute A'K⁻A.
            ginv_diagquad(A)
                Compute diag(A'K⁻A).
            correlate(x)
                Compute Zx such that K = ZZ', Z can be rectangular.
            back_correlate(X)
                Compute Z'X.
            pinv_correlate(x):
                Compute Z⁺x.
            minus_log_normal_density(r, ...)
                Compute a Normal density and its derivatives.
            eps
                The threshold below which eigenvalues are not calculable.
            n
                Number of rows/columns of K.
            m
                Number of columns of Z.

        Raises
        ------
        AssertionError
            If the array is empty or its shape is not made of two equal
            halves.
        KeyError
            If ``solver`` is not a known algorithm.

        Notes
        -----
        The decomposition operations are JAX-traceable, but they are not meant
        to be differentiated. The method `minus_log_normal_density` provides
        required derivatives with a custom implementation, given the derivatives
        of the inputs.

        """
        # NOTE(review): these checks are assert statements, so they are not
        # executed under `python -O`; AssertionError is kept for backward
        # compatibility with existing callers.
        m = jnp.asarray(posdefmatrix)
        assert m.size > 0, 'posdefmatrix must be nonempty'
        assert m.ndim % 2 == 0, 'posdefmatrix must have an even number of axes'
        half = m.ndim // 2
        head = m.shape[:half]
        tail = m.shape[half:]
        assert head == tail, 'the two halves of the shape of posdefmatrix must be equal'
        n = math.prod(head)
        m = m.reshape(n, n)
        decompcls = cls._getdecomp(solver)
        return decompcls(m, **kw)

    # TODO extend the interface to use composite decompositions
    # TODO accept a dict for covariance matrix