Coverage for src/lsqfitgp/_fastraniter.py: 96%

41 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_fastraniter.py 

2# 

3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import itertools 1feabcd

21 

22import gvar 1feabcd

23import numpy 1feabcd

24 

25from . import _linalg 1feabcd

26 

27# TODO support jax tracing and jax random sampling 

28 

29# TODO I should have a vectorized sampler, that returns an array of samples 

30# right away. Copy the new gvar interface (from gvar 12.1). 

31 

32def _toslice(s): 1feabcd

33 if isinstance(s, slice): 1eabcd

34 return s 1eabcd

35 if isinstance(s, int): 35 ↛ 37line 35 didn't jump to line 37 because the condition on line 35 was always true1eabcd

36 return slice(s, s + 1) 1eabcd

37 raise TypeError(f'cannot convert {s!r} to slice') 

38 

def raniter(mean, cov, n=None, eps=None, rng=None):
    """

    Take random samples from a multivariate Gaussian.

    This generator mimics the interface of `gvar.raniter`, but takes as input
    the mean and covariance separately instead of a collection of gvars.

    Parameters
    ----------
    mean : scalar, array, or dictionary of scalars/arrays
        The mean of the Gaussian distribution.
    cov : scalar, array, or dictionary of scalars/arrays
        The covariance matrix. If ``mean`` is a dictionary, ``cov`` must be a
        dictionary with pairs of keys from ``mean`` as keys.
    n : int, optional
        The maximum number of iterations. Default unlimited.
    eps : float, optional
        Used to correct the eigenvalues of the covariance matrix to handle
        non-positivity due to roundoff, relative to the largest eigenvalue.
        Default is number of variables times floating point epsilon.
    rng : seed or random generator, optional
        ``rng`` is passed through `numpy.random.default_rng` to produce a random
        number generator.

    Yields
    ------
    samp : scalar, array, or dictionary of scalars/arrays
        The random sample in the same format of ``mean``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the decomposition of the covariance matrix fails. Since this is a
        generator, the error surfaces on the first iteration.

    Examples
    --------

    >>> mean = {'a': np.arange(3)}
    >>> cov = {('a', 'a'): np.eye(3)}
    >>> for sample in lgp.raniter(mean, cov, 3):
    ...     print(sample)

    """

    isdict = hasattr(mean, 'keys')  # a dict or gvar.BufferDict

    # convert mean and cov to 1d and 2d arrays
    if isdict:
        if not hasattr(mean, 'buf'):
            mean = gvar.BufferDict(mean)
        flatmean = mean.buf
        squarecov = numpy.empty((len(flatmean), len(flatmean)))
        for k1 in mean:
            slic1 = _toslice(mean.slice(k1))
            for k2 in mean:
                slic2 = _toslice(mean.slice(k2))
                sqshape = (slic1.stop - slic1.start, slic2.stop - slic2.start)
                # numpy.reshape also accepts scalars and nested lists, which
                # have no .reshape method
                squarecov[slic1, slic2] = numpy.reshape(cov[k1, k2], sqshape)
    else:  # an array or scalar
        # numpy.asarray avoids a copy when possible; numpy.array(x, copy=False)
        # raises ValueError on NumPy >= 2 whenever a copy is required
        mean = numpy.asarray(mean)
        cov = numpy.asarray(cov)
        flatmean = mean.reshape(-1)
        squarecov = cov.reshape(len(flatmean), len(flatmean))

    # decompose the covariance matrix
    try:
        covdec = _linalg.Chol(squarecov, epsrel='auto' if eps is None else eps)
    except numpy.linalg.LinAlgError as exc:
        raise numpy.linalg.LinAlgError('covariance matrix not positive definite with eps={}'.format(eps)) from exc
    # TODO when I implement a pseudoinverse or something that does not fail
    # like diagonalization, issue a warning and use the other decomposition

    # get random number generator
    rng = numpy.random.default_rng(rng)

    # take samples
    iterable = itertools.count() if n is None else range(n)
    for _ in iterable:
        # covdec.m is presumably the length of the iid vector that
        # covdec.correlate expects (defined in _linalg) -- TODO confirm
        iidsamp = rng.standard_normal(covdec.m)
        samp = flatmean + covdec.correlate(iidsamp)

        # pack the samples with the original shape
        if isdict:
            samp = gvar.BufferDict(mean, buf=samp)
        elif mean.shape:
            samp = samp.reshape(mean.shape)
        else:
            # 0-d mean: return the scalar value (the call is required;
            # bare `samp.item` would be the bound method itself)
            samp = samp.item()

        yield samp

121 

def sample(*args, **kw):
    """
    Draw a single random sample, i.e., ``next(raniter(..., n=1))``.

    All arguments are forwarded to `raniter`; ``n`` is fixed to 1 and must
    not be passed.
    """
    generator = raniter(*args, n=1, **kw)
    first = next(generator)
    return first