Coverage for src/lsqfitgp/_GP/_elements.py: 100%

424 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_GP/_elements.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import abc 1feabcd

21import functools 1feabcd

22import warnings 1feabcd

23import math 1feabcd

24 

25import gvar 1feabcd

26import numpy 1feabcd

27from scipy import sparse 1feabcd

28import jax 1feabcd

29from jax import numpy as jnp 1feabcd

30 

31from .. import _Deriv 1feabcd

32from .. import _array 1feabcd

33from .. import _jaxext 1feabcd

34from .. import _gvarext 1feabcd

35from .. import _linalg 1feabcd

36 

37from . import _base 1feabcd

38 

39class GPElements(_base.GPBase): 1feabcd

40 

41 def __init__(self, *, checkpos, checksym, posepsfac, halfmatrix): 1feabcd

42 self._elements = dict() # key -> _Element 1feabcd

43 self._covblocks = dict() # (key, key) -> matrix (2d flattened) 1feabcd

44 self._priordict = {} # key -> gvar array (shaped) 1feabcd

45 self._checkpositive = bool(checkpos) 1feabcd

46 self._posepsfac = float(posepsfac) 1feabcd

47 self._checksym = bool(checksym) 1feabcd

48 self._halfmatrix = bool(halfmatrix) 1feabcd

49 self._dtype = None 1feabcd

50 assert not (halfmatrix and checksym) 1feabcd

51 

52 def _clone(self): 1feabcd

53 newself = super()._clone() 1feabcd

54 newself._elements = self._elements.copy() 1feabcd

55 newself._covblocks = self._covblocks.copy() 1feabcd

56 newself._priordict = self._priordict.copy() 1feabcd

57 newself._checkpositive = self._checkpositive 1feabcd

58 newself._posepsfac = self._posepsfac 1feabcd

59 newself._checksym = self._checksym 1feabcd

60 newself._halfmatrix = self._halfmatrix 1feabcd

61 newself._dtype = self._dtype 1feabcd

62 return newself 1feabcd

63 

64 @staticmethod 1feabcd

65 def _concatenate(alist): 1feabcd

66 """ 

67 Decides to use numpy.concatenate or jnp.concatenate depending on the 

68 input to support gvars. 

69 """ 

70 if any(a.dtype == object for a in alist): 1feabcd

71 return numpy.concatenate(alist) 1feabcd

72 else: 

73 return jnp.concatenate(alist) 1feabcd

74 

75 @staticmethod 1feabcd

76 def _triu_indices_and_back(n): 1feabcd

77 """ 

78 Return indices to get the upper triangular part of a matrix, and indices 

79 to convert a flat array of upper triangular elements to a symmetric 

80 matrix. 

81 """ 

82 ix, iy = jnp.triu_indices(n) 1abcd

83 q = jnp.empty((n, n), ix.dtype) 1abcd

84 a = jnp.arange(ix.size) 1abcd

85 q = q.at[ix, iy].set(a) 1abcd

86 q = q.at[iy, ix].set(a) 1abcd

87 return ix, iy, q 1abcd

88 

89 class _Element(abc.ABC): 1feabcd

90 """ 

91 Abstract class for an object holding information associated to a key in 

92 a GP object. 

93 """ 

94 

95 @property 1feabcd

96 @abc.abstractmethod 1feabcd

97 def shape(self): # pragma: no cover 1feabcd

98 """Output shape""" 

99 pass 

100 

101 @property 1feabcd

102 def size(self): 1feabcd

103 return math.prod(self.shape) 1feabcd

104 

105 class _Points(_Element): 1feabcd

106 """Points where the process is evaluated""" 

107 

108 def __init__(self, x, deriv, proc): 1feabcd

109 assert isinstance(x, (numpy.ndarray, jnp.ndarray, _array.StructuredArray)) 1feabcd

110 assert isinstance(deriv, _Deriv.Deriv) 1feabcd

111 self.x = x 1feabcd

112 self.deriv = deriv 1feabcd

113 self.proc = proc 1feabcd

114 

115 @property 1feabcd

116 def shape(self): 1feabcd

117 return self.x.shape 1feabcd

118 

119 class _LinTransf(_Element): 1feabcd

120 """Linear transformation of other _Element objects""" 

121 

122 shape = None 1feabcd

123 

124 def __init__(self, transf, keys, shape): 1feabcd

125 self.transf = transf 1feabcd

126 self.keys = keys 1feabcd

127 self.shape = shape 1feabcd

128 

129 def matrices(self, gp): 1feabcd

130 """ 

131 Matrix coefficients of the transformation (with flattened inputs 

132 and output) 

133 """ 

134 elems = [gp._elements[key] for key in self.keys] 1abcd

135 matrices = [] 1abcd

136 transf = jax.vmap(self.transf, 0, 0) 1abcd

137 for i, elem in enumerate(elems): 1abcd

138 inputs = [ 1abcd

139 jnp.eye(elem.size).reshape((elem.size,) + elem.shape) 

140 if j == i else 

141 jnp.zeros((elem.size,) + ej.shape) 

142 for j, ej in enumerate(elems) 

143 ] 

144 output = transf(*inputs).reshape(elem.size, self.size).T 1abcd

145 matrices.append(output) 1abcd

146 return matrices 1abcd

147 

148 class _Cov(_Element): 1feabcd

149 """User-provided covariance matrix block(s)""" 

150 

151 shape = None 1feabcd

152 

153 def __init__(self, blocks, shape): 1feabcd

154 """ blocks = dict (key, key) -> matrix """ 

155 self.blocks = blocks 1eabcd

156 self.shape = shape 1eabcd

157 

158 @_base.newself 1feabcd

159 def addx(self, x, key=None, *, deriv=0, proc=_base.GPBase.DefaultProcess): 1feabcd

160 """ 

161  

162 Add points where the Gaussian process is evaluated. 

163  

164 The GP object keeps the various x arrays in a dictionary. If ``x`` is an 

165 array, you have to specify its dictionary key with the ``key`` parameter. 

166 Otherwise, you can directly pass a dictionary for ``x``. 

167  

168 To specify that on the given ``x`` a derivative of the process instead of 

169 the process itself should be evaluated, use the parameter ``deriv``. 

170  

171 `addx` may or may not copy the input arrays. 

172  

173 Parameters 

174 ---------- 

175 x : array or dictionary of arrays 

176 The points to be added. 

177 key : hashable 

178 If ``x`` is an array, the dictionary key under which ``x`` is added. 

179 Can not be specified if ``x`` is a dictionary. 

180 deriv : Deriv-like 

181 Derivative specification. A `Deriv` object or something that 

182 can be converted to `Deriv`. 

183 proc : hashable 

184 The process to be evaluated on the points. If not specified, use 

185 the default process. 

186  

187 """ 

188 

189 # TODO after I implement block solving, add per-key covariance matrix 

190 # flags. 

191 

192 # TODO add `copy` parameter, default False, to copy the input arrays 

193 # if they are numpy arrays. 

194 

195 # this interface does not allow adding a single dictionary as x element 

196 # unless it's wrapped as a 0d numpy array, but this is for the best 

197 

198 deriv = _Deriv.Deriv(deriv) 1feabcd

199 

200 if proc not in self._procs: 1feabcd

201 raise KeyError(f'process named {proc!r} not found') 1abcd

202 

203 if hasattr(x, 'keys'): 1feabcd

204 if key is not None: 1abcd

205 raise ValueError('can not specify key if x is a dictionary') 1abcd

206 if None in x: 1abcd

207 raise ValueError('None key in x not allowed') 1abcd

208 else: 

209 if key is None: 1feabcd

210 raise ValueError('x is not dictionary but key is None') 1abcd

211 x = {key: x} 1feabcd

212 

213 for key in x: 1feabcd

214 if key in self._elements: 1feabcd

215 raise KeyError('key {!r} already in GP'.format(key)) 1abcd

216 

217 gx = x[key] 1feabcd

218 

219 # Convert to JAX array, numpy array or StructuredArray. 

220 # convert eagerly to jax to avoid problems with tracing. 

221 gx = _array._asarray_jaxifpossible(gx) 1feabcd

222 

223 # Check dtype is compatible with previous arrays. 

224 # TODO since we never concatenate arrays we could allow a less 

225 # strict compatibility. In principle we could allow really anything 

226 # as long as the kernel eats it, but this probably would let bugs 

227 # through without being really ever useful. What would make sense 

228 # is checking the dtype structure matches recursively and check 

229 # concrete dtypes of fields can be casted. 

230 # TODO result_type is too lax. Examples: str, float -> str, 

231 # object, float -> object. I should use something like the 

232 # ordering function in updowncast.py. 

233 if self._dtype is not None: 1feabcd

234 try: 1feabcd

235 self._dtype = numpy.result_type(self._dtype, gx.dtype) 1feabcd

236 # do not use jnp.result_type, it does not support 

237 # structured types 

238 except TypeError: 1abcd

239 msg = 'x[{!r}].dtype = {!r} not compatible with {!r}' 1abcd

240 msg = msg.format(key, gx.dtype, self._dtype) 1abcd

241 raise TypeError(msg) 1abcd

242 else: 

243 self._dtype = gx.dtype 1feabcd

244 

245 # Check that the derivative specifications are compatible with the 

246 # array data type. 

247 if gx.dtype.names is None: 1feabcd

248 if not deriv.implicit: 1feabcd

249 raise ValueError('x has no fields but derivative has') 1abcd

250 else: 

251 for dim in deriv: 1feabcd

252 if dim not in gx.dtype.names: 1feabcd

253 raise ValueError(f'deriv field {dim!r} not in x') 1abcd

254 

255 self._elements[key] = self._Points(gx, deriv, proc) 1feabcd

256 

257 def _get_x_dtype(self): 1feabcd

258 """ Get the data type of x points """ 

259 return self._dtype 1feabcd

260 

261 def addtransf(self, tensors, key, *, axes=1): 1feabcd

262 """ 

263  

264 Apply a linear transformation to already specified process points. The 

265 result of the transformation is represented by a new key. 

266  

267 Parameters 

268 ---------- 

269 tensors : dict 

270 Dictionary mapping keys of the GP to arrays/scalars. Each array is 

271 matrix-multiplied with the process array represented by its key, 

272 while scalars are just multiplied. Finally, the keys are summed 

273 over. 

274 key : hashable 

275 A new key under which the transformation is placed. 

276 axes : int 

277 Number of axes to be summed over for matrix multiplication, 

278 referring to trailing axes for tensors in ` tensors``, and to 

279 heading axes for process points. Default 1. 

280  

281 Returns 

282 ------- 

283 gp : GP 

284 A new GP object with the applied modifications. 

285  

286 Notes 

287 ----- 

288 The multiplication between the tensors and the process is done with 

289 np.tensordot with, by default, 1-axis contraction. For >2d arrays this 

290 is different from numpy's matrix multiplication, which would act on the 

291 second-to-last dimension of the second array. 

292  

293 """ 

294 # Note: it may seem nice that when an array has less axes than `axes`, 

295 # the summation would be restricted only on the existing axes. However 

296 # this brings about the ambiguous case where only one of the factors has 

297 # not enough axes. How many axes do you sum over on the other? 

298 

299 # Check axes. 

300 assert isinstance(axes, int) and axes >= 0, axes 1eabcd

301 

302 # Check key. 

303 if key is None: 1eabcd

304 raise ValueError('key can not be None') 1abcd

305 if key in self._elements: 1eabcd

306 raise KeyError(f'key {key!r} already in GP') 1abcd

307 

308 # Check keys. 

309 for k in tensors: 1eabcd

310 if k not in self._elements: 1eabcd

311 raise KeyError(k) 1abcd

312 

313 # Check tensors and convert them to jax arrays. 

314 if len(tensors) == 0: 1eabcd

315 raise ValueError('empty tensors, undetermined output shape') 1abcd

316 tens = {} 1eabcd

317 for k, t in tensors.items(): 1eabcd

318 t = jnp.asarray(t) 1eabcd

319 # no need to check dtype since jax supports only numerical arrays 

320 with _jaxext.skipifabstract(): 1eabcd

321 if self._checkfinite and not jnp.all(jnp.isfinite(t)): 1eabcd

322 raise ValueError(f'tensors[{k!r}] contains infs/nans') 1abcd

323 rshape = self._elements[k].shape 1eabcd

324 if t.shape and t.shape[t.ndim - axes:] != rshape[:axes]: 1eabcd

325 raise ValueError(f'tensors[{k!r}].shape = {t.shape!r} can not be multiplied with shape {rshape!r} with {axes}-axes contraction') 1abcd

326 tens[k] = t 1eabcd

327 

328 # Check shapes broadcast correctly. 

329 arrays = tens.values() 1eabcd

330 elements = (self._elements[k] for k in tens) 1eabcd

331 shapes = ( 1eabcd

332 t.shape[:t.ndim - axes] + e.shape[axes:] if t.shape else e.shape 

333 for t, e in zip(arrays, elements) 

334 ) 

335 try: 1eabcd

336 shape = jnp.broadcast_shapes(*shapes) 1eabcd

337 except ValueError: 1abcd

338 msg = 'can not broadcast tensors with shapes [' 1abcd

339 msg += ', '.join(repr(t.shape) for t in arrays) 1abcd

340 msg += '] contracted with arrays with shapes [' 1abcd

341 msg += ', '.join(repr(e.shape) for e in elements) + ']' 1abcd

342 raise ValueError(msg) 1abcd

343 

344 # Define linear transformation. 

345 def equiv_lintransf(*args): 1eabcd

346 assert len(args) == len(tens) 1eabcd

347 out = None 1eabcd

348 for a, (k, t) in zip(args, tens.items()): 1eabcd

349 if t.shape: 1eabcd

350 b = jnp.tensordot(t, a, axes) 1eabcd

351 else: 

352 b = t * a 1eabcd

353 if out is None: 1eabcd

354 out = b 1eabcd

355 else: 

356 out = out + b 1eabcd

357 return out 1eabcd

358 keys = list(tens.keys()) 1eabcd

359 return self.addlintransf(equiv_lintransf, keys, key, checklin=False) 1eabcd

360 

361 @_base.newself 1feabcd

362 def addlintransf(self, transf, keys, key, *, checklin=None): 1feabcd

363 """ 

364  

365 Define a finite linear transformation of the evaluated process. 

366  

367 Parameters 

368 ---------- 

369 transf : callable 

370 A function with signature ``f(array1, array2, ...) -> array`` which 

371 computes the linear transformation. The function must be 

372 jax-traceable, i.e., use jax.numpy instead of numpy. 

373 keys : sequence 

374 Keys of parts of the process to be passed as inputs to the 

375 transformation. 

376 key : hashable 

377 The key of the newly defined points. 

378 checklin : bool 

379 If True (default), check that the given function is linear in its 

380 inputs. The default can be overridden at initialization of the GP 

381 object. Note that an affine function (x -> a + bx) is not linear. 

382  

383 Raises 

384 ------ 

385 RuntimeError : 

386 The transformation seems not to be linear. To disable the linearity 

387 check, initialize the GP with ``checklin=False``. 

388  

389 """ 

390 

391 # TODO elementwise operations can be applied more efficiently to 

392 # primary gvars (tipical case), so the method could use an option 

393 # `elementwise`. What is the reliable way to check it is indeed 

394 # elementwise with a single random vector? Zero items of the tangent 

395 # at random with p=0.5 and check they stay zero? (And of course check 

396 # the shape is preserved.) 

397 

398 # Check key. 

399 if key is None: 1feabcd

400 raise ValueError('key can not be None') 1abcd

401 if key in self._elements: 1feabcd

402 raise KeyError(f'key {key!r} already in GP') 1abcd

403 

404 # Check keys. 

405 for k in keys: 1feabcd

406 if k not in self._elements: 1feabcd

407 raise KeyError(k) 1abcd

408 

409 # Determine shape. 

410 class ArrayMockup: 1feabcd

411 def __init__(self, elem): 1feabcd

412 self.shape = elem.shape 1feabcd

413 self.dtype = float 1feabcd

414 inp = [ArrayMockup(self._elements[k]) for k in keys] 1feabcd

415 out = jax.eval_shape(transf, *inp) 1feabcd

416 shape = out.shape 1feabcd

417 

418 # Check that the transformation is linear. 

419 if checklin is None: 1feabcd

420 checklin = self._checklin 1feabcd

421 if checklin: 1feabcd

422 shapes = [self._elements[k].shape for k in keys] 1feabcd

423 self._checklinear(transf, shapes) 1feabcd

424 

425 self._elements[key] = self._LinTransf(transf, keys, shape) 1feabcd

426 

427 @_base.newself 1feabcd

428 def addcov(self, covblocks, key=None, *, decomps=None): 1feabcd

429 """ 

430  

431 Add user-defined prior covariance matrix blocks. 

432  

433 Covariance matrices defined with `addcov` represent arbitrary 

434 finite-dimensional zero-mean Gaussian variables, assumed independent 

435 from all other variables in the GP object. 

436  

437 Parameters 

438 ---------- 

439 covblocks : array or dictionary of arrays 

440 If an array: a covariance matrix (or tensor) to be added under key 

441 ``key``. If a dictionary: a mapping from pairs of keys to the 

442 corresponding covariance matrix blocks. A missing off-diagonal 

443 block in the dictionary is interpreted as a matrix of zeros, 

444 unless the corresponding transposed block is specified. 

445 key : hashable 

446 If ``covblocks`` is an array, the dictionary key under which 

447 ``covblocks`` is added. Can not be specified if ``covblocks`` is a 

448 dictionary. 

449 decomps : Decomposition or dict of Decompositions 

450 Pre-computed decompositions of (not necessarily all) diagonal 

451 blocks, as produced by `decompose`. The keys are single 

452 GP keys and not pairs like in ``covblocks``. 

453  

454 Raises 

455 ------ 

456 KeyError : 

457 A key is already used in the GP. 

458 ValueError : 

459 ``covblocks`` and/or ``key`` and ``decomps`` are malformed or 

460 inconsistent. 

461 TypeError : 

462 Wrong type of ``covblocks`` or ``decomps``. 

463  

464 """ 

465 

466 # TODO maybe allow passing only the lower/upper triangular part for 

467 # the diagonal blocks, like I meta-allow for out of diagonal blocks? 

468 

469 # TODO with multiple blocks and a single decomp, the decomp could be 

470 # interpreted as the decomposition of the whole block matrix. 

471 

472 # Check type of `covblocks` and standardize it to dictionary. 

473 if hasattr(covblocks, 'keys'): 1eabcd

474 if key is not None: 1abcd

475 raise ValueError('can not specify key if covblocks is a dictionary') 1abcd

476 if None in covblocks: 1abcd

477 raise ValueError('None key in covblocks not allowed') 1abcd

478 if decomps is not None and not hasattr(decomps, 'keys'): 1abcd

479 raise TypeError('covblocks is dictionary but decomps is not') 1abcd

480 else: 

481 if key is None: 1eabcd

482 raise ValueError('covblocks is not dictionary but key is None') 1abcd

483 covblocks = {(key, key): covblocks} 1eabcd

484 if decomps is not None: 1eabcd

485 decomps = {key: decomps} 1abcd

486 

487 if decomps is None: 1eabcd

488 decomps = {} 1eabcd

489 

490 # Convert blocks to jax arrays and determine shapes from diagonal 

491 # blocks. 

492 shapes = {} 1eabcd

493 preblocks = {} 1eabcd

494 for keys, block in covblocks.items(): 1eabcd

495 # TODO maybe check that keys is a 2-tuple 

496 for key in keys: 1eabcd

497 if key in self._elements: 1eabcd

498 raise KeyError(f'key {key!r} already in GP') 1abcd

499 xkey, ykey = keys 1eabcd

500 if block is None: 1eabcd

501 raise TypeError(f'block {keys!r} is None') 1abcd

502 # because jnp.asarray(None) interprets None as nan 

503 # (see jax issue #14506) 

504 block = jnp.asarray(block) 1eabcd

505 

506 if xkey == ykey: 1eabcd

507 

508 if block.ndim % 2 == 1: 1eabcd

509 raise ValueError(f'diagonal block {key!r} has odd number of axes') 1abcd

510 

511 half = block.ndim // 2 1eabcd

512 head = block.shape[:half] 1eabcd

513 tail = block.shape[half:] 1eabcd

514 if head != tail: 1eabcd

515 raise ValueError(f'shape {block.shape!r} of diagonal block {key!r} is not symmetric') 1abcd

516 shapes[xkey] = head 1eabcd

517 

518 with _jaxext.skipifabstract(): 1eabcd

519 if self._checksym and not jnp.allclose(block, block.T): 1eabcd

520 raise ValueError(f'diagonal block {key!r} is not symmetric') 1abcd

521 

522 preblocks[keys] = block 1eabcd

523 

524 # Check decomps is consistent with covblocks. 

525 for key, dec in decomps.items(): 1eabcd

526 if key not in shapes: 1abcd

527 raise KeyError(f'key {key!r} in decomps not found in diagonal blocks') 1abcd

528 if not isinstance(dec, _linalg.Decomposition): 1abcd

529 raise TypeError(f'decomps[{key!r}] = {dec!r} is not a decomposition') 1abcd

530 n = math.prod(shapes[key]) 1abcd

531 if dec.n != n: 1abcd

532 raise ValueError(f'decomposition matrix size {dec.n} != diagonal block size {n} for key {key!r}') 1abcd

533 

534 # Reshape blocks to square matrices and check that the shapes of out of 

535 # diagonal blocks match those of diagonal ones. 

536 blocks = {} 1eabcd

537 for keys, block in preblocks.items(): 1eabcd

538 with _jaxext.skipifabstract(): 1eabcd

539 if self._checkfinite and not jnp.all(jnp.isfinite(block)): 1eabcd

540 raise ValueError(f'block {keys!r} not finite') 1abcd

541 xkey, ykey = keys 1eabcd

542 if xkey == ykey: 1eabcd

543 size = math.prod(shapes[xkey]) 1eabcd

544 blocks[keys] = block.reshape((size, size)) 1eabcd

545 else: 

546 for key in keys: 1abcd

547 if key not in shapes: 1abcd

548 raise KeyError(f'key {key!r} from off-diagonal block {keys!r} not found in diagonal blocks') 1abcd

549 eshape = shapes[xkey] + shapes[ykey] 1abcd

550 if block.shape != eshape: 1abcd

551 raise ValueError(f'shape {block.shape!r} of block {keys!r} is not {eshape!r} as expected from diagonal blocks') 1abcd

552 xsize = math.prod(shapes[xkey]) 1abcd

553 ysize = math.prod(shapes[ykey]) 1abcd

554 block = block.reshape((xsize, ysize)) 1abcd

555 blocks[keys] = block 1abcd

556 revkeys = keys[::-1] 1abcd

557 blockT = preblocks.get(revkeys) 1abcd

558 if blockT is None: 1abcd

559 blocks[revkeys] = block.T 1abcd

560 

561 # Check symmetry of out of diagonal blocks. 

562 if self._checksym: 1eabcd

563 with _jaxext.skipifabstract(): 1abcd

564 for keys, block in blocks.items(): 1abcd

565 xkey, ykey = keys 1abcd

566 if xkey != ykey: 1abcd

567 blockT = blocks[ykey, xkey] 1abcd

568 if not jnp.allclose(block.T, blockT): 1abcd

569 raise ValueError(f'block {keys!r} is not the transpose of block {revkeys!r}') 1abcd

570 

571 # Create _Cov objects. 

572 for key, shape in shapes.items(): 1eabcd

573 self._elements[key] = self._Cov(blocks, shape) 1eabcd

574 decomp = decomps.get(key) 1eabcd

575 if decomp is not None: 1eabcd

576 self._decompcache[key,] = decomp 1abcd

577 

578 def _makecovblock_points(self, xkey, ykey): 1feabcd

579 x = self._elements[xkey] 1feabcd

580 y = self._elements[ykey] 1feabcd

581 

582 assert isinstance(x, self._Points) 1feabcd

583 assert isinstance(y, self._Points) 1feabcd

584 

585 kernel = self._crosskernel(x.proc, y.proc) 1feabcd

586 if kernel is self._zerokernel: 1feabcd

587 # TODO handle zero cov block efficiently 

588 return jnp.zeros((x.size, y.size)) 1feabcd

589 

590 kernel = kernel.linop('diff', x.deriv, y.deriv) 1feabcd

591 

592 if x is y and not self._checksym and self._halfmatrix: 1feabcd

593 ix, iy, back = self._triu_indices_and_back(x.size) 1abcd

594 flat = x.x.reshape(-1) 1abcd

595 ax = flat[ix] 1abcd

596 ay = flat[iy] 1abcd

597 halfcov = kernel(ax, ay) 1abcd

598 cov = halfcov[back] 1abcd

599 # TODO to avoid inefficiencies like in BART, maybe _Kernel should 

600 # have a method outer(x) that by default simply does self(x[None, 

601 # :], x[:, None]) but can be overwritten. This halfmatrix impl could 

602 # be moved there with an option outer(x, *, half=False). To carry 

603 # over custom implementations of outer, there should be a callable 

604 # attribute _outer, optionally set at initialization, that is 

605 # transformed by kernel operations. 

606 else: 

607 ax = x.x.reshape(-1)[:, None] 1feabcd

608 ay = y.x.reshape(-1)[None, :] 1feabcd

609 cov = kernel(ax, ay) 1feabcd

610 

611 return cov 1feabcd

612 

613 def _makecovblock_lintransf_any(self, xkey, ykey): 1feabcd

614 x = self._elements[xkey] 1feabcd

615 y = self._elements[ykey] 1feabcd

616 assert isinstance(x, self._LinTransf) 1feabcd

617 

618 # Gather covariance matrices to be transformed. 

619 covs = [] 1feabcd

620 for k in x.keys: 1feabcd

621 elem = self._elements[k] 1feabcd

622 cov = self._covblock(k, ykey) 1feabcd

623 assert cov.shape == (elem.size, y.size) 1feabcd

624 cov = cov.reshape(elem.shape + (y.size,)) 1feabcd

625 covs.append(cov) 1feabcd

626 

627 # Apply transformation. 

628 t = jax.vmap(x.transf, -1, -1) 1feabcd

629 cov = t(*covs) 1feabcd

630 assert cov.shape == x.shape + (y.size,) 1feabcd

631 return cov.reshape((x.size, y.size)) # don't leave out the ()! 1feabcd

632 # the () probably was an obscure autograd bug, I don't think it will 

633 # be a problem again with jax 

634 

635 def _makecovblock(self, xkey, ykey): 1feabcd

636 x = self._elements[xkey] 1feabcd

637 y = self._elements[ykey] 1feabcd

638 if isinstance(x, self._Points) and isinstance(y, self._Points): 1feabcd

639 cov = self._makecovblock_points(xkey, ykey) 1feabcd

640 elif isinstance(x, self._LinTransf): 1feabcd

641 cov = self._makecovblock_lintransf_any(xkey, ykey) 1feabcd

642 elif isinstance(y, self._LinTransf): 1feabcd

643 cov = self._makecovblock_lintransf_any(ykey, xkey) 1feabcd

644 cov = cov.T 1feabcd

645 elif isinstance(x, self._Cov) and isinstance(y, self._Cov) and x.blocks is y.blocks and (xkey, ykey) in x.blocks: 1eabcd

646 cov = x.blocks[xkey, ykey] 1eabcd

647 else: 

648 # TODO handle zero cov block efficiently 

649 cov = jnp.zeros((x.size, y.size)) 1eabcd

650 

651 with _jaxext.skipifabstract(): 1feabcd

652 if self._checkfinite and not jnp.all(jnp.isfinite(cov)): 1feabcd

653 raise RuntimeError(f'covariance block {(xkey, ykey)!r} is not finite') 1abcd

654 if self._checksym and xkey == ykey and not jnp.allclose(cov, cov.T): 1feabcd

655 raise RuntimeError(f'covariance block {(xkey, ykey)!r} is not symmetric') 1abcd

656 

657 return cov 1feabcd

658 

659 def _covblock(self, row, col): 1feabcd

660 

661 if (row, col) not in self._covblocks: 1feabcd

662 block = self._makecovblock(row, col) 1feabcd

663 if row != col: 1feabcd

664 if self._checksym: 1feabcd

665 with _jaxext.skipifabstract(): 1feabcd

666 blockT = self._makecovblock(col, row) 1feabcd

667 if not jnp.allclose(block.T, blockT): 1feabcd

668 msg = 'covariance block {!r} is not symmetric' 1abcd

669 raise RuntimeError(msg.format((row, col))) 1abcd

670 self._covblocks[col, row] = block.T 1feabcd

671 self._covblocks[row, col] = block 1feabcd

672 

673 return self._covblocks[row, col] 1feabcd

674 

675 def _assemblecovblocks(self, rowkeys, colkeys=None): 1feabcd

676 if colkeys is None: 1feabcd

677 colkeys = rowkeys 1feabcd

678 blocks = [ 1feabcd

679 [self._covblock(row, col) for col in colkeys] 

680 for row in rowkeys 

681 ] 

682 return jnp.block(blocks) 1feabcd

683 

684 def _checkpos(self, cov): 1feabcd

685 with _jaxext.skipifabstract(): 1feabcd

686 # eigv = jnp.linalg.eigvalsh(cov) 

687 # mineigv, maxeigv = jnp.min(eigv), jnp.max(eigv) 

688 with warnings.catch_warnings(): 1feabcd

689 warnings.filterwarnings('ignore', r'Exited at iteration .+? with accuracies') 1feabcd

690 warnings.filterwarnings('ignore', r'Exited postprocessing with accuracies') 1feabcd

691 X = numpy.random.randn(len(cov), 1) 1feabcd

692 A = numpy.asarray(cov) 1feabcd

693 (mineigv,), _ = sparse.linalg.lobpcg(A, X, largest=False) 1feabcd

694 (maxeigv,), _ = sparse.linalg.lobpcg(A, X, largest=True) 1feabcd

695 assert mineigv <= maxeigv 1feabcd

696 if mineigv < 0: 1feabcd

697 bound = -len(cov) * jnp.finfo(cov.dtype).eps * maxeigv * self._posepsfac 1abcd

698 if mineigv < bound: 1abcd

699 msg = 'covariance matrix is not positive definite: ' 1abcd

700 msg += 'mineigv = {:.4g} < {:.4g}'.format(mineigv, bound) 1abcd

701 raise numpy.linalg.LinAlgError(msg) 1abcd

702 

703 _checkpos_cache = functools.cached_property(lambda self: []) 1feabcd

704 def _checkpos_keys(self, keys): 1feabcd

705 # TODO go back to ancestors of _LinTransf? 

706 if not self._checkpositive: 1feabcd

707 return 1feabcd

708 keys = set(keys) 1feabcd

709 for prev_keys in self._checkpos_cache: 1feabcd

710 if keys.issubset(prev_keys): 1feabcd

711 return 1feabcd

712 cov = self._assemblecovblocks(list(keys)) 1feabcd

713 self._checkpos(cov) 1feabcd

714 self._checkpos_cache.append(keys) 1feabcd

715 

716 def _priorpointscov(self, key): 1feabcd

717 

718 x = self._elements[key] 1feabcd

719 classes = (self._Points, self._Cov) 1feabcd

720 assert isinstance(x, classes) 1feabcd

721 mean = numpy.zeros(x.size) 1feabcd

722 cov = self._covblock(key, key).astype(float) 1feabcd

723 assert cov.shape == 2 * mean.shape, cov.shape 1feabcd

724 

725 # get preexisting primary gvars to be correlated with the new ones 

726 preitems = [ 1feabcd

727 k 

728 for k, px in self._elements.items() 

729 if isinstance(px, classes) 

730 and k in self._priordict 

731 ] 

732 if preitems: 1feabcd

733 prex = numpy.concatenate([ 1feabcd

734 numpy.reshape(self._priordict[k], -1) 

735 for k in preitems 

736 ]) 

737 precov = numpy.concatenate([ 1feabcd

738 self._covblock(k, key).astype(float) 

739 for k in preitems 

740 ]) 

741 g = gvar.gvar(mean, cov, prex, precov, fast=True) 1feabcd

742 else: 

743 g = gvar.gvar(mean, cov, fast=True) 1feabcd

744 

745 return g.reshape(x.shape) 1feabcd

746 

747 def _priorlintransf(self, key): 1feabcd

748 x = self._elements[key] 1feabcd

749 assert isinstance(x, self._LinTransf) 1feabcd

750 

751 # Gather all gvars to be transformed. 

752 elems = [ 1feabcd

753 self._prior(k).reshape(-1) 

754 for k in x.keys 

755 ] 

756 g = numpy.concatenate(elems) 1feabcd

757 

758 # Extract jacobian and split it. 

759 slices = self._slices(x.keys) 1feabcd

760 jac, indices = _gvarext.jacobian(g) 1feabcd

761 jacs = [ 1feabcd

762 jac[s].reshape(self._elements[k].shape + indices.shape) 

763 for s, k in zip(slices, x.keys) 

764 ] 

765 # TODO the jacobian can be extracted much more efficiently when the 

766 # elements are _Points or _Cov, since in that case the gvars are primary 

767 # and contiguous within each block, so each jacobian is the identity + a 

768 # range. Then write a function _gvarext.merge_jacobians to combine 

769 # them, which also can be optimized knowing the indices are 

770 # non-overlapping ranges. 

771 

772 # Apply transformation. 

773 t = jax.vmap(x.transf, -1, -1) 1feabcd

774 outjac = t(*jacs) 1feabcd

775 assert outjac.shape == x.shape + indices.shape 1feabcd

776 

777 # Rebuild gvars. 

778 outg = _gvarext.from_jacobian(numpy.zeros(x.shape), outjac, indices) 1feabcd

779 return outg 1feabcd

780 

781 def _prior(self, key): 1feabcd

782 prior = self._priordict.get(key, None) 1feabcd

783 if prior is None: 1feabcd

784 x = self._elements[key] 1feabcd

785 if isinstance(x, (self._Points, self._Cov)): 1feabcd

786 prior = self._priorpointscov(key) 1feabcd

787 elif isinstance(x, self._LinTransf): 1feabcd

788 prior = self._priorlintransf(key) 1feabcd

789 else: # pragma: no cover 

790 raise TypeError(type(x)) 

791 self._priordict[key] = prior 1feabcd

792 return prior 1feabcd

793 

794 def prior(self, key=None, *, raw=False): 1feabcd

795 """ 

796  

797 Return an array or a dictionary of arrays of gvars representing the 

798 prior for the Gaussian process. The returned object is not unique but 

799 the gvars stored inside are, so all the correlations are kept between 

800 objects returned by different calls to `prior`. 

801  

802 Calling without arguments returns the complete prior as a dictionary. 

803 If you specify ``key``, only the array for the requested key is returned. 

804  

805 Parameters 

806 ---------- 

807 key : None, key or list of keys 

808 Key(s) corresponding to one passed to `addx` or `addtransf`. None 

809 for all keys. 

810 raw : bool 

811 If True, instead of returning a collection of gvars return 

812 their covariance matrix as would be returned by `gvar.evalcov`. 

813 Default False. 

814  

815 Returns 

816 ------- 

817 If raw=False (default): 

818  

819 prior : np.ndarray or dict 

820 A collection of gvars representing the prior. 

821  

822 If raw=True: 

823  

824 cov : np.ndarray or dict 

825 The covariance matrix of the prior. 

826 """ 

827 raw = bool(raw) 1feabcd

828 

829 if key is None: 1feabcd

830 outkeys = list(self._elements) 1feabcd

831 elif isinstance(key, list): 1feabcd

832 outkeys = key 1eabcd

833 else: 

834 outkeys = None 1feabcd

835 

836 self._checkpos_keys([key] if outkeys is None else outkeys) 1feabcd

837 

838 if raw and outkeys is not None: 1feabcd

839 return { 1eabcd

840 (row, col): 

841 self._covblock(row, col).reshape( 

842 self._elements[row].shape + 

843 self._elements[col].shape 

844 ) 

845 for row in outkeys 

846 for col in outkeys 

847 } 

848 elif raw: 1feabcd

849 return self._covblock(key, key).reshape(2 * self._elements[key].shape) 1abcd

850 elif outkeys is not None: 1feabcd

851 return {key: self._prior(key) for key in outkeys} 1feabcd

852 else: 

853 return self._prior(key) 1feabcd

854 

855 def _slices(self, keylist): 1feabcd

856 """ 

857 Return list of slices for the positions of flattened arrays 

858 corresponding to keys in ``keylist`` into their concatenation. 

859 """ 

860 sizes = [self._elements[key].size for key in keylist] 1feabcd

861 stops = numpy.pad(numpy.cumsum(sizes), (1, 0)) 1feabcd

862 return [slice(stops[i - 1], stops[i]) for i in range(1, len(stops))] 1feabcd