Coverage for src/lsqfitgp/_GP/_elements.py: 100%
424 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_GP/_elements.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import abc 1feabcd
21import functools 1feabcd
22import warnings 1feabcd
23import math 1feabcd
25import gvar 1feabcd
26import numpy 1feabcd
27from scipy import sparse 1feabcd
28import jax 1feabcd
29from jax import numpy as jnp 1feabcd
31from .. import _Deriv 1feabcd
32from .. import _array 1feabcd
33from .. import _jaxext 1feabcd
34from .. import _gvarext 1feabcd
35from .. import _linalg 1feabcd
37from . import _base 1feabcd
class GPElements(_base.GPBase):
    """
    Layer of the GP class hierarchy that manages the keyed elements of the
    process (evaluation points, linear transformations, user-provided
    covariance blocks), builds their covariance blocks on demand, and
    produces the corresponding gvar priors.
    """

    def __init__(self, *, checkpos, checksym, posepsfac, halfmatrix):
        self._elements = dict() # key -> _Element
        self._covblocks = dict() # (key, key) -> matrix (2d flattened)
        self._priordict = {} # key -> gvar array (shaped)
        self._checkpositive = bool(checkpos)
        self._posepsfac = float(posepsfac)
        self._checksym = bool(checksym)
        self._halfmatrix = bool(halfmatrix)
        self._dtype = None
        # computing only half of the matrix makes it symmetric by
        # construction, so the two options are mutually exclusive
        assert not (halfmatrix and checksym)

    def _clone(self):
        newself = super()._clone()
        # shallow copies: the stored elements/blocks/priors are shared, but
        # adding new keys to the clone does not affect the original
        newself._elements = self._elements.copy()
        newself._covblocks = self._covblocks.copy()
        newself._priordict = self._priordict.copy()
        newself._checkpositive = self._checkpositive
        newself._posepsfac = self._posepsfac
        newself._checksym = self._checksym
        newself._halfmatrix = self._halfmatrix
        newself._dtype = self._dtype
        return newself

    @staticmethod
    def _concatenate(alist):
        """
        Decides to use numpy.concatenate or jnp.concatenate depending on the
        input to support gvars.
        """
        if any(a.dtype == object for a in alist):
            return numpy.concatenate(alist)
        else:
            return jnp.concatenate(alist)

    @staticmethod
    def _triu_indices_and_back(n):
        """
        Return indices to get the upper triangular part of a matrix, and indices
        to convert a flat array of upper triangular elements to a symmetric
        matrix.
        """
        ix, iy = jnp.triu_indices(n)
        q = jnp.empty((n, n), ix.dtype)
        a = jnp.arange(ix.size)
        # write the position in the flat triangular array into both (i, j)
        # and (j, i), so that flat[q] is the full symmetric matrix
        q = q.at[ix, iy].set(a)
        q = q.at[iy, ix].set(a)
        return ix, iy, q

    class _Element(abc.ABC):
        """
        Abstract class for an object holding information associated to a key in
        a GP object.
        """

        @property
        @abc.abstractmethod
        def shape(self): # pragma: no cover
            """Output shape"""
            pass

        @property
        def size(self):
            """Number of scalar items, i.e., the product of `shape`"""
            return math.prod(self.shape)

    class _Points(_Element):
        """Points where the process is evaluated"""

        def __init__(self, x, deriv, proc):
            assert isinstance(x, (numpy.ndarray, jnp.ndarray, _array.StructuredArray))
            assert isinstance(deriv, _Deriv.Deriv)
            self.x = x
            self.deriv = deriv
            self.proc = proc

        @property
        def shape(self):
            return self.x.shape

    class _LinTransf(_Element):
        """Linear transformation of other _Element objects"""

        shape = None

        def __init__(self, transf, keys, shape):
            self.transf = transf
            self.keys = keys
            self.shape = shape

        def matrices(self, gp):
            """
            Matrix coefficients of the transformation (with flattened inputs
            and output)
            """
            elems = [gp._elements[key] for key in self.keys]
            matrices = []
            transf = jax.vmap(self.transf, 0, 0)
            for i, elem in enumerate(elems):
                # feed the identity to input i and zeros to the others, so
                # the output is the i-th block of the jacobian
                inputs = [
                    jnp.eye(elem.size).reshape((elem.size,) + elem.shape)
                    if j == i else
                    jnp.zeros((elem.size,) + ej.shape)
                    for j, ej in enumerate(elems)
                ]
                output = transf(*inputs).reshape(elem.size, self.size).T
                matrices.append(output)
            return matrices

    class _Cov(_Element):
        """User-provided covariance matrix block(s)"""

        shape = None

        def __init__(self, blocks, shape):
            """ blocks = dict (key, key) -> matrix """
            self.blocks = blocks
            self.shape = shape

    @_base.newself
    def addx(self, x, key=None, *, deriv=0, proc=_base.GPBase.DefaultProcess):
        """

        Add points where the Gaussian process is evaluated.

        The GP object keeps the various x arrays in a dictionary. If ``x`` is an
        array, you have to specify its dictionary key with the ``key`` parameter.
        Otherwise, you can directly pass a dictionary for ``x``.

        To specify that on the given ``x`` a derivative of the process instead of
        the process itself should be evaluated, use the parameter ``deriv``.

        `addx` may or may not copy the input arrays.

        Parameters
        ----------
        x : array or dictionary of arrays
            The points to be added.
        key : hashable
            If ``x`` is an array, the dictionary key under which ``x`` is added.
            Can not be specified if ``x`` is a dictionary.
        deriv : Deriv-like
            Derivative specification. A `Deriv` object or something that
            can be converted to `Deriv`.
        proc : hashable
            The process to be evaluated on the points. If not specified, use
            the default process.

        Raises
        ------
        KeyError :
            ``proc`` is not a known process, or a key is already used in the
            GP.
        ValueError :
            ``x`` and ``key`` are inconsistent, or ``deriv`` refers to fields
            not present in ``x``.
        TypeError :
            The data type of ``x`` is not compatible with previous arrays.

        """

        # TODO after I implement block solving, add per-key covariance matrix
        # flags.

        # TODO add `copy` parameter, default False, to copy the input arrays
        # if they are numpy arrays.

        # this interface does not allow adding a single dictionary as x element
        # unless it's wrapped as a 0d numpy array, but this is for the best

        deriv = _Deriv.Deriv(deriv)

        if proc not in self._procs:
            raise KeyError(f'process named {proc!r} not found')

        if hasattr(x, 'keys'):
            if key is not None:
                raise ValueError('can not specify key if x is a dictionary')
            if None in x:
                raise ValueError('None key in x not allowed')
        else:
            if key is None:
                raise ValueError('x is not dictionary but key is None')
            x = {key: x}

        for key in x:
            if key in self._elements:
                raise KeyError('key {!r} already in GP'.format(key))

            gx = x[key]

            # Convert to JAX array, numpy array or StructuredArray.
            # convert eagerly to jax to avoid problems with tracing.
            gx = _array._asarray_jaxifpossible(gx)

            # Check dtype is compatible with previous arrays.
            # TODO since we never concatenate arrays we could allow a less
            # strict compatibility. In principle we could allow really anything
            # as long as the kernel eats it, but this probably would let bugs
            # through without being really ever useful. What would make sense
            # is checking the dtype structure matches recursively and check
            # concrete dtypes of fields can be casted.
            # TODO result_type is too lax. Examples: str, float -> str,
            # object, float -> object. I should use something like the
            # ordering function in updowncast.py.
            if self._dtype is not None:
                try:
                    self._dtype = numpy.result_type(self._dtype, gx.dtype)
                    # do not use jnp.result_type, it does not support
                    # structured types
                except TypeError:
                    msg = 'x[{!r}].dtype = {!r} not compatible with {!r}'
                    msg = msg.format(key, gx.dtype, self._dtype)
                    raise TypeError(msg)
            else:
                self._dtype = gx.dtype

            # Check that the derivative specifications are compatible with the
            # array data type.
            if gx.dtype.names is None:
                if not deriv.implicit:
                    raise ValueError('x has no fields but derivative has')
            else:
                for dim in deriv:
                    if dim not in gx.dtype.names:
                        raise ValueError(f'deriv field {dim!r} not in x')

            self._elements[key] = self._Points(gx, deriv, proc)

    def _get_x_dtype(self):
        """ Get the data type of x points """
        return self._dtype

    def addtransf(self, tensors, key, *, axes=1):
        """

        Apply a linear transformation to already specified process points. The
        result of the transformation is represented by a new key.

        Parameters
        ----------
        tensors : dict
            Dictionary mapping keys of the GP to arrays/scalars. Each array is
            matrix-multiplied with the process array represented by its key,
            while scalars are just multiplied. Finally, the keys are summed
            over.
        key : hashable
            A new key under which the transformation is placed.
        axes : int
            Number of axes to be summed over for matrix multiplication,
            referring to trailing axes for tensors in ``tensors``, and to
            heading axes for process points. Default 1.

        Returns
        -------
        gp : GP
            A new GP object with the applied modifications.

        Raises
        ------
        KeyError :
            ``key`` is already used, or a key in ``tensors`` is not in the GP.
        ValueError :
            ``key`` is None, ``tensors`` is empty, contains non-finite values,
            or has shapes incompatible with the process points.

        Notes
        -----
        The multiplication between the tensors and the process is done with
        np.tensordot with, by default, 1-axis contraction. For >2d arrays this
        is different from numpy's matrix multiplication, which would act on the
        second-to-last dimension of the second array.

        """
        # Note: it may seem nice that when an array has less axes than `axes`,
        # the summation would be restricted only on the existing axes. However
        # this brings about the ambiguous case where only one of the factors has
        # not enough axes. How many axes do you sum over on the other?

        # Check axes.
        assert isinstance(axes, int) and axes >= 0, axes

        # Check key.
        if key is None:
            raise ValueError('key can not be None')
        if key in self._elements:
            raise KeyError(f'key {key!r} already in GP')

        # Check keys.
        for k in tensors:
            if k not in self._elements:
                raise KeyError(k)

        # Check tensors and convert them to jax arrays.
        if len(tensors) == 0:
            raise ValueError('empty tensors, undetermined output shape')
        tens = {}
        for k, t in tensors.items():
            t = jnp.asarray(t)
            # no need to check dtype since jax supports only numerical arrays
            with _jaxext.skipifabstract():
                if self._checkfinite and not jnp.all(jnp.isfinite(t)):
                    raise ValueError(f'tensors[{k!r}] contains infs/nans')
            rshape = self._elements[k].shape
            # A non-scalar tensor needs at least `axes` axes, otherwise the
            # negative slice start below would compare the wrong axes and
            # tensordot would fail later with an obscure error.
            if t.shape and (t.ndim < axes or t.shape[t.ndim - axes:] != rshape[:axes]):
                raise ValueError(f'tensors[{k!r}].shape = {t.shape!r} can not be multiplied with shape {rshape!r} with {axes}-axes contraction')
            tens[k] = t

        # Check shapes broadcast correctly. Use lists rather than generators
        # because they are iterated again to build the error message.
        arrays = list(tens.values())
        elements = [self._elements[k] for k in tens]
        shapes = [
            t.shape[:t.ndim - axes] + e.shape[axes:] if t.shape else e.shape
            for t, e in zip(arrays, elements)
        ]
        try:
            jnp.broadcast_shapes(*shapes)
        except ValueError:
            msg = 'can not broadcast tensors with shapes ['
            msg += ', '.join(repr(t.shape) for t in arrays)
            msg += '] contracted with arrays with shapes ['
            msg += ', '.join(repr(e.shape) for e in elements) + ']'
            raise ValueError(msg)

        # Define linear transformation.
        def equiv_lintransf(*args):
            assert len(args) == len(tens)
            out = None
            for a, (k, t) in zip(args, tens.items()):
                if t.shape:
                    b = jnp.tensordot(t, a, axes)
                else:
                    b = t * a
                if out is None:
                    out = b
                else:
                    out = out + b
            return out

        keys = list(tens.keys())
        # linear by construction, so skip the linearity check
        return self.addlintransf(equiv_lintransf, keys, key, checklin=False)

    @_base.newself
    def addlintransf(self, transf, keys, key, *, checklin=None):
        """

        Define a finite linear transformation of the evaluated process.

        Parameters
        ----------
        transf : callable
            A function with signature ``f(array1, array2, ...) -> array`` which
            computes the linear transformation. The function must be
            jax-traceable, i.e., use jax.numpy instead of numpy.
        keys : sequence
            Keys of parts of the process to be passed as inputs to the
            transformation.
        key : hashable
            The key of the newly defined points.
        checklin : bool
            If True (default), check that the given function is linear in its
            inputs. The default can be overridden at initialization of the GP
            object. Note that an affine function (x -> a + bx) is not linear.

        Returns
        -------
        gp : GP
            A new GP object with the applied modifications.

        Raises
        ------
        RuntimeError :
            The transformation seems not to be linear. To disable the linearity
            check, initialize the GP with ``checklin=False``.

        """

        # TODO elementwise operations can be applied more efficiently to
        # primary gvars (typical case), so the method could use an option
        # `elementwise`. What is the reliable way to check it is indeed
        # elementwise with a single random vector? Zero items of the tangent
        # at random with p=0.5 and check they stay zero? (And of course check
        # the shape is preserved.)

        # Check key.
        if key is None:
            raise ValueError('key can not be None')
        if key in self._elements:
            raise KeyError(f'key {key!r} already in GP')

        # Check keys.
        for k in keys:
            if k not in self._elements:
                raise KeyError(k)

        # Determine shape by abstract evaluation, without computing anything.
        class ArrayMockup:
            def __init__(self, elem):
                self.shape = elem.shape
                self.dtype = float
        inp = [ArrayMockup(self._elements[k]) for k in keys]
        out = jax.eval_shape(transf, *inp)
        shape = out.shape

        # Check that the transformation is linear.
        if checklin is None:
            checklin = self._checklin
        if checklin:
            shapes = [self._elements[k].shape for k in keys]
            self._checklinear(transf, shapes)

        self._elements[key] = self._LinTransf(transf, keys, shape)

    @_base.newself
    def addcov(self, covblocks, key=None, *, decomps=None):
        """

        Add user-defined prior covariance matrix blocks.

        Covariance matrices defined with `addcov` represent arbitrary
        finite-dimensional zero-mean Gaussian variables, assumed independent
        from all other variables in the GP object.

        Parameters
        ----------
        covblocks : array or dictionary of arrays
            If an array: a covariance matrix (or tensor) to be added under key
            ``key``. If a dictionary: a mapping from pairs of keys to the
            corresponding covariance matrix blocks. A missing off-diagonal
            block in the dictionary is interpreted as a matrix of zeros,
            unless the corresponding transposed block is specified.
        key : hashable
            If ``covblocks`` is an array, the dictionary key under which
            ``covblocks`` is added. Can not be specified if ``covblocks`` is a
            dictionary.
        decomps : Decomposition or dict of Decompositions
            Pre-computed decompositions of (not necessarily all) diagonal
            blocks, as produced by `decompose`. The keys are single
            GP keys and not pairs like in ``covblocks``.

        Raises
        ------
        KeyError :
            A key is already used in the GP.
        ValueError :
            ``covblocks`` and/or ``key`` and ``decomps`` are malformed or
            inconsistent.
        TypeError :
            Wrong type of ``covblocks`` or ``decomps``.

        """

        # TODO maybe allow passing only the lower/upper triangular part for
        # the diagonal blocks, like I meta-allow for out of diagonal blocks?

        # TODO with multiple blocks and a single decomp, the decomp could be
        # interpreted as the decomposition of the whole block matrix.

        # Check type of `covblocks` and standardize it to dictionary.
        if hasattr(covblocks, 'keys'):
            if key is not None:
                raise ValueError('can not specify key if covblocks is a dictionary')
            if None in covblocks:
                raise ValueError('None key in covblocks not allowed')
            if decomps is not None and not hasattr(decomps, 'keys'):
                raise TypeError('covblocks is dictionary but decomps is not')
        else:
            if key is None:
                raise ValueError('covblocks is not dictionary but key is None')
            covblocks = {(key, key): covblocks}
            if decomps is not None:
                decomps = {key: decomps}

        if decomps is None:
            decomps = {}

        # Convert blocks to jax arrays and determine shapes from diagonal
        # blocks.
        shapes = {}
        preblocks = {}
        for keys, block in covblocks.items():
            # TODO maybe check that keys is a 2-tuple
            for k in keys:
                if k in self._elements:
                    raise KeyError(f'key {k!r} already in GP')
            xkey, ykey = keys
            if block is None:
                # because jnp.asarray(None) interprets None as nan
                # (see jax issue #14506)
                raise TypeError(f'block {keys!r} is None')
            block = jnp.asarray(block)

            if xkey == ykey:

                if block.ndim % 2 == 1:
                    raise ValueError(f'diagonal block {xkey!r} has odd number of axes')

                half = block.ndim // 2
                head = block.shape[:half]
                tail = block.shape[half:]
                if head != tail:
                    raise ValueError(f'shape {block.shape!r} of diagonal block {xkey!r} is not symmetric')
                shapes[xkey] = head

                # Check symmetry on the flattened matrix: for a tensor with
                # ndim > 2, `block.T` reverses all the axes, which is not the
                # transpose of the covariance.
                with _jaxext.skipifabstract():
                    if self._checksym:
                        size = math.prod(head)
                        mat = block.reshape((size, size))
                        if not jnp.allclose(mat, mat.T):
                            raise ValueError(f'diagonal block {xkey!r} is not symmetric')

            preblocks[keys] = block

        # Check decomps is consistent with covblocks.
        for k, dec in decomps.items():
            if k not in shapes:
                raise KeyError(f'key {k!r} in decomps not found in diagonal blocks')
            if not isinstance(dec, _linalg.Decomposition):
                raise TypeError(f'decomps[{k!r}] = {dec!r} is not a decomposition')
            n = math.prod(shapes[k])
            if dec.n != n:
                raise ValueError(f'decomposition matrix size {dec.n} != diagonal block size {n} for key {k!r}')

        # Reshape blocks to square matrices and check that the shapes of out of
        # diagonal blocks match those of diagonal ones.
        blocks = {}
        for keys, block in preblocks.items():
            with _jaxext.skipifabstract():
                if self._checkfinite and not jnp.all(jnp.isfinite(block)):
                    raise ValueError(f'block {keys!r} not finite')
            xkey, ykey = keys
            if xkey == ykey:
                size = math.prod(shapes[xkey])
                blocks[keys] = block.reshape((size, size))
            else:
                for k in keys:
                    if k not in shapes:
                        raise KeyError(f'key {k!r} from off-diagonal block {keys!r} not found in diagonal blocks')
                eshape = shapes[xkey] + shapes[ykey]
                if block.shape != eshape:
                    raise ValueError(f'shape {block.shape!r} of block {keys!r} is not {eshape!r} as expected from diagonal blocks')
                xsize = math.prod(shapes[xkey])
                ysize = math.prod(shapes[ykey])
                block = block.reshape((xsize, ysize))
                blocks[keys] = block
                # fill in the transposed block unless the user gave it
                revkeys = keys[::-1]
                if revkeys not in preblocks:
                    blocks[revkeys] = block.T

        # Check symmetry of out of diagonal blocks.
        if self._checksym:
            with _jaxext.skipifabstract():
                for (xkey, ykey), block in blocks.items():
                    if xkey != ykey:
                        blockT = blocks[ykey, xkey]
                        if not jnp.allclose(block.T, blockT):
                            raise ValueError(f'block {(xkey, ykey)!r} is not the transpose of block {(ykey, xkey)!r}')

        # Create _Cov objects. All the keys share the same `blocks` dict,
        # which is how _makecovblock recognizes them as correlated.
        for k, shape in shapes.items():
            self._elements[k] = self._Cov(blocks, shape)
            decomp = decomps.get(k)
            if decomp is not None:
                self._decompcache[k,] = decomp

    def _makecovblock_points(self, xkey, ykey):
        """Covariance block between two sets of points, via the kernel."""
        x = self._elements[xkey]
        y = self._elements[ykey]

        assert isinstance(x, self._Points)
        assert isinstance(y, self._Points)

        kernel = self._crosskernel(x.proc, y.proc)
        if kernel is self._zerokernel:
            # TODO handle zero cov block efficiently
            return jnp.zeros((x.size, y.size))

        kernel = kernel.linop('diff', x.deriv, y.deriv)

        if x is y and not self._checksym and self._halfmatrix:
            # evaluate the kernel only on the upper triangle, then expand
            ix, iy, back = self._triu_indices_and_back(x.size)
            flat = x.x.reshape(-1)
            ax = flat[ix]
            ay = flat[iy]
            halfcov = kernel(ax, ay)
            cov = halfcov[back]
            # TODO to avoid inefficiencies like in BART, maybe _Kernel should
            # have a method outer(x) that by default simply does self(x[None,
            # :], x[:, None]) but can be overwritten. This halfmatrix impl could
            # be moved there with an option outer(x, *, half=False). To carry
            # over custom implementations of outer, there should be a callable
            # attribute _outer, optionally set at initialization, that is
            # transformed by kernel operations.
        else:
            ax = x.x.reshape(-1)[:, None]
            ay = y.x.reshape(-1)[None, :]
            cov = kernel(ax, ay)

        return cov

    def _makecovblock_lintransf_any(self, xkey, ykey):
        """Covariance block where the row key is a linear transformation."""
        x = self._elements[xkey]
        y = self._elements[ykey]
        assert isinstance(x, self._LinTransf)

        # Gather covariance matrices to be transformed.
        covs = []
        for k in x.keys:
            elem = self._elements[k]
            cov = self._covblock(k, ykey)
            assert cov.shape == (elem.size, y.size)
            cov = cov.reshape(elem.shape + (y.size,))
            covs.append(cov)

        # Apply transformation, mapping over the trailing column axis.
        t = jax.vmap(x.transf, -1, -1)
        cov = t(*covs)
        assert cov.shape == x.shape + (y.size,)
        return cov.reshape((x.size, y.size)) # don't leave out the ()!
        # the () probably was an obscure autograd bug, I don't think it will
        # be a problem again with jax

    def _makecovblock(self, xkey, ykey):
        """Compute (without caching) the flattened covariance block."""
        x = self._elements[xkey]
        y = self._elements[ykey]
        if isinstance(x, self._Points) and isinstance(y, self._Points):
            cov = self._makecovblock_points(xkey, ykey)
        elif isinstance(x, self._LinTransf):
            cov = self._makecovblock_lintransf_any(xkey, ykey)
        elif isinstance(y, self._LinTransf):
            cov = self._makecovblock_lintransf_any(ykey, xkey)
            cov = cov.T
        elif isinstance(x, self._Cov) and isinstance(y, self._Cov) and x.blocks is y.blocks and (xkey, ykey) in x.blocks:
            cov = x.blocks[xkey, ykey]
        else:
            # TODO handle zero cov block efficiently
            cov = jnp.zeros((x.size, y.size))

        with _jaxext.skipifabstract():
            if self._checkfinite and not jnp.all(jnp.isfinite(cov)):
                raise RuntimeError(f'covariance block {(xkey, ykey)!r} is not finite')
            if self._checksym and xkey == ykey and not jnp.allclose(cov, cov.T):
                raise RuntimeError(f'covariance block {(xkey, ykey)!r} is not symmetric')

        return cov

    def _covblock(self, row, col):
        """Cached flattened covariance block between keys `row` and `col`."""

        if (row, col) not in self._covblocks:
            block = self._makecovblock(row, col)
            if row != col:
                if self._checksym:
                    with _jaxext.skipifabstract():
                        blockT = self._makecovblock(col, row)
                        if not jnp.allclose(block.T, blockT):
                            msg = 'covariance block {!r} is not symmetric'
                            raise RuntimeError(msg.format((row, col)))
                # cache the transposed block too
                self._covblocks[col, row] = block.T
            self._covblocks[row, col] = block

        return self._covblocks[row, col]

    def _assemblecovblocks(self, rowkeys, colkeys=None):
        """Assemble the full covariance matrix between lists of keys."""
        if colkeys is None:
            colkeys = rowkeys
        blocks = [
            [self._covblock(row, col) for col in colkeys]
            for row in rowkeys
        ]
        return jnp.block(blocks)

    def _checkpos(self, cov):
        """
        Raise numpy.linalg.LinAlgError if the minimum eigenvalue of `cov`,
        estimated with lobpcg, is negative beyond a tolerance proportional to
        the matrix size, the machine epsilon, the maximum eigenvalue and
        `posepsfac`.
        """
        with _jaxext.skipifabstract():
            # a dense alternative would be jnp.linalg.eigvalsh(cov); lobpcg
            # is used to estimate only the extreme eigenvalues
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore', r'Exited at iteration .+? with accuracies')
                warnings.filterwarnings('ignore', r'Exited postprocessing with accuracies')
                X = numpy.random.randn(len(cov), 1)
                A = numpy.asarray(cov)
                (mineigv,), _ = sparse.linalg.lobpcg(A, X, largest=False)
                (maxeigv,), _ = sparse.linalg.lobpcg(A, X, largest=True)
            assert mineigv <= maxeigv
            if mineigv < 0:
                bound = -len(cov) * jnp.finfo(cov.dtype).eps * maxeigv * self._posepsfac
                if mineigv < bound:
                    msg = 'covariance matrix is not positive definite: '
                    msg += 'mineigv = {:.4g} < {:.4g}'.format(mineigv, bound)
                    raise numpy.linalg.LinAlgError(msg)

    # list of key sets already checked for positivity
    _checkpos_cache = functools.cached_property(lambda self: [])

    def _checkpos_keys(self, keys):
        """Check positivity of the joint covariance of `keys`, if enabled."""
        # TODO go back to ancestors of _LinTransf?
        if not self._checkpositive:
            return
        keys = set(keys)
        # a subset of an already checked set is positive too
        for prev_keys in self._checkpos_cache:
            if keys.issubset(prev_keys):
                return
        cov = self._assemblecovblocks(list(keys))
        self._checkpos(cov)
        self._checkpos_cache.append(keys)

    def _priorpointscov(self, key):
        """
        Create new primary gvars for a _Points or _Cov element, correlated
        with the gvars already created for other such elements.
        """

        x = self._elements[key]
        classes = (self._Points, self._Cov)
        assert isinstance(x, classes)
        mean = numpy.zeros(x.size)
        cov = self._covblock(key, key).astype(float)
        assert cov.shape == 2 * mean.shape, cov.shape

        # get preexisting primary gvars to be correlated with the new ones
        preitems = [
            k
            for k, px in self._elements.items()
            if isinstance(px, classes)
            and k in self._priordict
        ]
        if preitems:
            prex = numpy.concatenate([
                numpy.reshape(self._priordict[k], -1)
                for k in preitems
            ])
            precov = numpy.concatenate([
                self._covblock(k, key).astype(float)
                for k in preitems
            ])
            g = gvar.gvar(mean, cov, prex, precov, fast=True)
        else:
            g = gvar.gvar(mean, cov, fast=True)

        return g.reshape(x.shape)

    def _priorlintransf(self, key):
        """Build the gvars of a _LinTransf by transforming their jacobians."""
        x = self._elements[key]
        assert isinstance(x, self._LinTransf)

        # Gather all gvars to be transformed.
        elems = [
            self._prior(k).reshape(-1)
            for k in x.keys
        ]
        g = numpy.concatenate(elems)

        # Extract jacobian and split it.
        slices = self._slices(x.keys)
        jac, indices = _gvarext.jacobian(g)
        jacs = [
            jac[s].reshape(self._elements[k].shape + indices.shape)
            for s, k in zip(slices, x.keys)
        ]
        # TODO the jacobian can be extracted much more efficiently when the
        # elements are _Points or _Cov, since in that case the gvars are primary
        # and contiguous within each block, so each jacobian is the identity + a
        # range. Then write a function _gvarext.merge_jacobians to combine
        # them, which also can be optimized knowing the indices are
        # non-overlapping ranges.

        # Apply transformation.
        t = jax.vmap(x.transf, -1, -1)
        outjac = t(*jacs)
        assert outjac.shape == x.shape + indices.shape

        # Rebuild gvars.
        outg = _gvarext.from_jacobian(numpy.zeros(x.shape), outjac, indices)
        return outg

    def _prior(self, key):
        """Cached prior gvars for a single key."""
        prior = self._priordict.get(key, None)
        if prior is None:
            x = self._elements[key]
            if isinstance(x, (self._Points, self._Cov)):
                prior = self._priorpointscov(key)
            elif isinstance(x, self._LinTransf):
                prior = self._priorlintransf(key)
            else: # pragma: no cover
                raise TypeError(type(x))
            self._priordict[key] = prior
        return prior

    def prior(self, key=None, *, raw=False):
        """

        Return an array or a dictionary of arrays of gvars representing the
        prior for the Gaussian process. The returned object is not unique but
        the gvars stored inside are, so all the correlations are kept between
        objects returned by different calls to `prior`.

        Calling without arguments returns the complete prior as a dictionary.
        If you specify ``key``, only the array for the requested key is returned.

        Parameters
        ----------
        key : None, key or list of keys
            Key(s) corresponding to one passed to `addx`, `addtransf`,
            `addlintransf` or `addcov`. None for all keys.
        raw : bool
            If True, instead of returning a collection of gvars return
            their covariance matrix as would be returned by `gvar.evalcov`.
            Default False.

        Returns
        -------
        If raw=False (default):

        prior : np.ndarray or dict
            A collection of gvars representing the prior.

        If raw=True:

        cov : np.ndarray or dict
            The covariance matrix of the prior.
        """
        raw = bool(raw)

        if key is None:
            outkeys = list(self._elements)
        elif isinstance(key, list):
            outkeys = key
        else:
            outkeys = None

        self._checkpos_keys([key] if outkeys is None else outkeys)

        if raw and outkeys is not None:
            return {
                (row, col):
                self._covblock(row, col).reshape(
                    self._elements[row].shape +
                    self._elements[col].shape
                )
                for row in outkeys
                for col in outkeys
            }
        elif raw:
            return self._covblock(key, key).reshape(2 * self._elements[key].shape)
        elif outkeys is not None:
            return {k: self._prior(k) for k in outkeys}
        else:
            return self._prior(key)

    def _slices(self, keylist):
        """
        Return list of slices for the positions of flattened arrays
        corresponding to keys in ``keylist`` into their concatenation.
        """
        sizes = [self._elements[key].size for key in keylist]
        stops = numpy.pad(numpy.cumsum(sizes), (1, 0))
        return [slice(stops[i - 1], stops[i]) for i in range(1, len(stops))]