Coverage for src/lsqfitgp/_fastraniter.py: 96%
41 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_fastraniter.py
2#
3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import itertools 1feabcd
22import gvar 1feabcd
23import numpy 1feabcd
25from . import _linalg 1feabcd
27# TODO support jax tracing and jax random sampling
29# TODO I should have a vectorized sampler, that returns an array of samples
30# right away. Copy the new gvar interface (from gvar 12.1).
32def _toslice(s): 1feabcd
33 if isinstance(s, slice): 1eabcd
34 return s 1eabcd
35 if isinstance(s, int): 35 ↛ 37line 35 didn't jump to line 37 because the condition on line 35 was always true1eabcd
36 return slice(s, s + 1) 1eabcd
37 raise TypeError(f'cannot convert {s!r} to slice')
def raniter(mean, cov, n=None, eps=None, rng=None):
    """

    Take random samples from a multivariate Gaussian.

    This generator mimics the interface of `gvar.raniter`, but takes as input
    the mean and covariance separately instead of a collection of gvars.

    Parameters
    ----------
    mean : scalar, array, or dictionary of scalars/arrays
        The mean of the Gaussian distribution.
    cov : scalar, array, or dictionary of scalars/arrays
        The covariance matrix. If ``mean`` is a dictionary, ``cov`` must be a
        dictionary with pairs of keys from ``mean`` as keys.
    n : int, optional
        The maximum number of iterations. Default unlimited.
    eps : float, optional
        Used to correct the eigenvalues of the covariance matrix to handle
        non-positivity due to roundoff, relative to the largest eigenvalue.
        Default is number of variables times floating point epsilon.
    rng : seed or random generator, optional
        ``rng`` is passed through `numpy.random.default_rng` to produce a random
        number generator.

    Yields
    ------
    samp : scalar, array, or dictionary of scalars/arrays
        The random sample in the same format as ``mean``: a `gvar.BufferDict`
        if ``mean`` is a dictionary, an array with the shape of ``mean`` if it
        is a non-scalar array, or a Python scalar if ``mean`` is a scalar.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the covariance matrix can not be decomposed. Since this is a
        generator, the conversion of the inputs and the decomposition only
        happen at the first call to `next`, so errors are raised there.

    Examples
    --------

    >>> mean = {'a': np.arange(3)}
    >>> cov = {('a', 'a'): np.eye(3)}
    >>> for sample in lgp.raniter(mean, cov, 3):
    ...     print(sample)

    """

    # convert mean and cov to 1d and 2d arrays
    isdict = hasattr(mean, 'keys')  # a dict or gvar.BufferDict
    if isdict:
        if not hasattr(mean, 'buf'):
            mean = gvar.BufferDict(mean)
        flatmean = mean.buf
        squarecov = numpy.empty((len(flatmean), len(flatmean)))
        # compute each key's position in the flat buffer once, instead of
        # once per (k1, k2) pair
        slices = {k: _toslice(mean.slice(k)) for k in mean}
        for k1, slic1 in slices.items():
            for k2, slic2 in slices.items():
                sqshape = (slic1.stop - slic1.start, slic2.stop - slic2.start)
                # numpy.reshape (unlike the .reshape method) also accepts
                # plain Python scalars as covariance entries
                squarecov[slic1, slic2] = numpy.reshape(cov[k1, k2], sqshape)
    else:  # an array or scalar
        # numpy.asarray avoids copies when possible, like the former
        # numpy.array(..., copy=False), but does not raise in numpy >= 2
        # when a copy is actually required (e.g. lists or Python scalars)
        mean = numpy.asarray(mean)
        cov = numpy.asarray(cov)
        flatmean = mean.reshape(-1)
        squarecov = cov.reshape(len(flatmean), len(flatmean))

    # decompose the covariance matrix
    try:
        covdec = _linalg.Chol(squarecov, epsrel='auto' if eps is None else eps)
    except numpy.linalg.LinAlgError as exc:
        raise numpy.linalg.LinAlgError(
            f'covariance matrix not positive definite with eps={eps}'
        ) from exc
        # TODO when I implement a pseudoinverse or something that does not fail
        # like diagonalization, issue a warning and use the other decomposition

    # get random number generator
    rng = numpy.random.default_rng(rng)

    # take samples
    iterable = itertools.count() if n is None else range(n)
    for _ in iterable:
        iidsamp = rng.standard_normal(covdec.m)
        samp = flatmean + covdec.correlate(iidsamp)

        # pack the samples with the original shape
        if isdict:
            samp = gvar.BufferDict(mean, buf=samp)
        elif mean.shape:
            samp = samp.reshape(mean.shape)
        else:
            # scalar mean: return the single element as a Python scalar
            # (the call parentheses matter, .item alone is a bound method)
            samp = samp.item()

        yield samp
def sample(*args, **kw):
    """
    Draw a single random sample from a multivariate Gaussian.

    Shortcut for ``next(raniter(..., n=1))``. All positional and keyword
    arguments are forwarded to `raniter`, except ``n``, which is fixed to 1.
    """
    single_draw = raniter(*args, n=1, **kw)
    return next(single_draw)