8. Fitting
- class lsqfitgp.empbayes_fit(hyperprior, gpfactory, data, *, raises=True, minkw={}, gpfactorykw={}, jit=True, method='gradient', initial='priormean', verbosity=0, covariance='auto', fix=None, mlkw={}, forward=False)
Maximum a posteriori fit.
Maximizes the marginal likelihood of the data with a Gaussian process model that depends on hyperparameters, multiplied by a prior on the hyperparameters.
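To make the objective concrete, here is a minimal sketch of the same idea in plain NumPy/SciPy. It is not lsqfitgp's implementation. The model is a zero-mean GP with a squared-exponential kernel, a known noise level, and a single hyperparameter (the log of the length scale) with a Gaussian hyperprior. The data, noise level and hyperprior values are assumptions chosen for illustration. The function minimized is minus the log of (marginal likelihood × hyperprior).

```python
import numpy as np
from scipy import optimize

# Toy data (assumed): a smooth function observed with known noise sd 0.1.
x = np.linspace(-5, 5, 11)
y = np.sin(x)
noise_var = 0.1 ** 2

# Gaussian hyperprior on log(scale) (assumed mean and standard deviation).
prior_mean, prior_sd = np.log(3.0), 1.0

def neg_log_posterior(theta):
    """-log(marginal likelihood * hyperprior), additive constants dropped."""
    scale = np.exp(theta[0])
    d = x[:, None] - x[None, :]
    K = np.exp(-0.5 * (d / scale) ** 2) + noise_var * np.eye(len(x))
    L = np.linalg.cholesky(K)            # K = L L^T
    alpha = np.linalg.solve(L, y)        # alpha @ alpha = y^T K^{-1} y
    # 1/2 y^T K^{-1} y + 1/2 log det K, with log det K = 2 sum(log diag L)
    neg_log_ml = 0.5 * alpha @ alpha + np.sum(np.log(np.diag(L)))
    neg_log_prior = 0.5 * ((theta[0] - prior_mean) / prior_sd) ** 2
    return neg_log_ml + neg_log_prior

# Start from the hyperprior mean, analogous to initial='priormean'.
result = optimize.minimize(neg_log_posterior, x0=[prior_mean], method='BFGS')
map_scale = np.exp(result.x[0])
print(f'MAP length scale: {map_scale:.3f}')
```

The noise term keeps `K` positive definite for every length scale the optimizer tries, so the Cholesky factorization cannot fail.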
- Parameters:
- hyperprior : scalar, array or dictionary of scalars/arrays
A collection of gvars representing the prior for the hyperparameters.
- gpfactory : callable
A function with signature gpfactory(hyperparams) -> GP object. The argument hyperparams has the same structure as the empbayes_fit argument hyperprior. gpfactory must be JAX-friendly, i.e., it must use jax.numpy and jax.scipy instead of plain numpy/scipy and avoid assignments to arrays.
- data : dict, tuple or callable
Dictionary of data that is passed to GP.marginal_likelihood on the GP object returned by gpfactory. If a tuple, it contains the first two arguments to GP.marginal_likelihood. If a callable, it is called with the same arguments as gpfactory and must return the argument(s) for GP.marginal_likelihood.
- raises : bool, optional
If True (default), raise an error when the minimization fails. Otherwise, use the last point of the minimization as result.
- minkw : dict, optional
Keyword arguments passed to scipy.optimize.minimize. They overwrite the values specified by empbayes_fit.
- gpfactorykw : dict, optional
Keyword arguments passed to gpfactory, and also to data if it is a callable. If jit is True, gpfactorykw crosses a JAX jit boundary, so it must contain objects understandable by JAX.
- jit : bool
If True (default), use JAX’s jit to compile the minimization target.
- method : str
Minimization strategy. Options:
- ‘nograd’
Use a gradient-free method.
- ‘gradient’
Use a gradient-only method (default).
- ‘fisher’
Use a Newton method with the Fisher information matrix plus the hyperprior precision matrix.
- initial : str, scalar, array or dictionary of scalars/arrays
Starting point for the minimization, matching the format of hyperprior, or one of the following options:
- ‘priormean’
Start from the hyperprior mean (default).
- ‘priorsample’
Take a random sample from the hyperprior.
- verbosity : int
An integer indicating how much information is printed on the terminal:
- 0
No logging (default).
- 1
Minimal report.
- 2
Detailed report.
- 3
Log each iteration.
- 4
More detailed iteration log.
- 5
Print the current parameter values at each iteration.
- covariance : str
Method to estimate the posterior covariance matrix of the hyperparameters:
- ‘fisher’
Use the Fisher information at the MAP, plus the prior precision, as the precision matrix.
- ‘minhess’
Use the Hessian estimate of the minimizer as the precision matrix.
- ‘none’
Do not estimate the covariance matrix.
- ‘auto’ (default)
‘minhess’ if applicable, ‘none’ otherwise.
- fix : scalar, array or dictionary of scalars/arrays
A set of booleans, with the same format as hyperprior, indicating which hyperparameters are kept fixed to their initial value. Scalars and arrays are broadcast to the shape of hyperprior. If a dictionary, missing keys are treated as False.
- mlkw : dict
Additional arguments passed to GP.marginal_likelihood.
- forward : bool
Use forward instead of backward (default) derivatives. Typically, forward is faster with a small number of parameters.
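The method and covariance options can be illustrated on a toy model with one hyperparameter, where the Fisher information has a closed form. The sketch below is an analogue built on several assumptions. It is not what empbayes_fit runs internally. SciPy's Nelder-Mead and BFGS stand in for the 'nograd' and 'gradient' minimizers, which the options above do not name. The model y_i ~ N(0, e^{2θ}) with hyperprior θ ~ N(0, 1) and its data are made up for the example.

```python
import numpy as np
from scipy import optimize

# Toy model (assumed): y_i ~ N(0, e^{2 theta}), i.e. theta = log(sigma),
# with a standard normal hyperprior theta ~ N(0, 1).
y = np.array([0.3, -1.2, 0.8, 2.1, -0.5, -1.7, 0.9, 1.4])
n, s2 = len(y), np.sum(y ** 2)
prior_precision = 1.0

def neg_log_post(theta):
    t = theta[0]
    # -log N(y; 0, e^{2t} I) - log N(t; 0, 1), additive constants dropped
    return n * t + 0.5 * s2 * np.exp(-2 * t) + 0.5 * prior_precision * t ** 2

def grad(theta):
    t = theta[0]
    return np.array([n - s2 * np.exp(-2 * t) + prior_precision * t])

def fisher_info(theta):
    # Zero-mean Gaussian likelihood with covariance K(theta):
    # I = 1/2 tr(K^-1 dK/dtheta K^-1 dK/dtheta); here dK/dtheta = 2K.
    K = np.exp(2 * theta[0]) * np.eye(n)
    A = np.linalg.solve(K, 2 * K)
    return 0.5 * np.trace(A @ A)   # equals 2n for this model

x0 = np.array([0.0])   # hyperprior mean, analogous to initial='priormean'

# Analogue of method='nograd': derivative-free simplex search.
res_nograd = optimize.minimize(neg_log_post, x0, method='Nelder-Mead')

# Analogue of method='gradient': quasi-Newton method using only the gradient.
res_grad = optimize.minimize(neg_log_post, x0, jac=grad, method='BFGS')

# Analogue of method='fisher': Newton iterations with the Hessian replaced
# by the Fisher information plus the hyperprior precision.
t = x0.copy()
for _ in range(50):
    t = t - grad(t) / (fisher_info(t) + prior_precision)

# Analogue of covariance='fisher': invert Fisher information + prior precision.
cov_fisher = 1 / (fisher_info(t) + prior_precision)

# Analogue of covariance='minhess': the minimizer's own inverse-Hessian estimate.
cov_minhess = res_grad.hess_inv[0, 0]
print(t[0], cov_fisher, cov_minhess)
```

All three minimizers should agree on the MAP. The two covariance estimates differ because the Fisher information is an expected curvature, while BFGS approximates the observed curvature of the objective from successive gradients.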
- Raises:
- RuntimeError
The minimization failed and raises is True.
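A hypothetical wrapper can mimic how raises and minkw interact with scipy.optimize.minimize. The helper name minimize_like_empbayes and the default method 'BFGS' are assumptions, not lsqfitgp code. The returned pair plays the role of the minresult and minargs attributes.

```python
import numpy as np
from scipy import optimize

def minimize_like_empbayes(fun, x0, *, raises=True, minkw=None):
    """Hypothetical wrapper, not lsqfitgp code."""
    minargs = dict(fun=fun, x0=x0, method='BFGS')   # assumed defaults
    minargs.update(minkw or {})                     # minkw overwrites them
    result = optimize.minimize(**minargs)
    if not result.success and raises:
        raise RuntimeError(f'minimization failed: {result.message}')
    # With raises=False, the last point reached is kept even after a failure.
    return result, minargs

# SciPy's Rosenbrock test function; two BFGS iterations are far too few.
x0 = np.array([-1.2, 1.0])
result, minargs = minimize_like_empbayes(
    optimize.rosen, x0, raises=False, minkw={'options': {'maxiter': 2}})
print(result.success, result.x)
```

Calling the same wrapper with the default raises=True turns the failed minimization into a RuntimeError.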
- Attributes:
- p : scalar, array or dictionary of scalars/arrays
A collection of gvars representing the hyperparameters that maximize their posterior. These gvars do not track correlations with the hyperprior or the data.
- prior : scalar, array or dictionary of scalars/arrays
A copy of the hyperprior.
- initial : scalar, array or dictionary of scalars/arrays
Starting point of the minimization, with the same format as p.
- fix : scalar, array or dictionary of scalars/arrays
A set of booleans, with the same format as p, indicating which parameters were kept fixed to the values in initial.
- pmean : scalar, array or dictionary of scalars/arrays
Mean of p.
- pcov : scalar, array or dictionary of scalars/arrays
Covariance matrix of p.
- minresult : scipy.optimize.OptimizeResult
The result object returned by scipy.optimize.minimize.
- minargs : dict
The arguments passed to scipy.optimize.minimize.
- gpfactory : callable
The gpfactory argument.
- gpfactorykw : dict
The gpfactorykw argument.
- data : dict, tuple or callable
The data argument.
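The fix and initial attributes suggest how fixing could work: fixed entries are left out of the vector the optimizer sees, then reinserted with their initial values. The sketch below rests on that assumption. expand_fix is a hypothetical helper that implements the broadcasting rules stated for the fix parameter. Plain floats stand in for gvars, and the quadratic objective is a toy.

```python
import numpy as np
from scipy import optimize

def expand_fix(fix, hyperprior):
    """Hypothetical helper: give `fix` the same layout as `hyperprior`."""
    if isinstance(hyperprior, dict):
        if not isinstance(fix, dict):
            # Assumption: a non-dictionary `fix` applies to every key.
            fix = {key: fix for key in hyperprior}
        # Keys missing from `fix` are treated as False (free parameters).
        return {key: np.broadcast_to(np.asarray(fix.get(key, False), bool),
                                     np.shape(value))
                for key, value in hyperprior.items()}
    return np.broadcast_to(np.asarray(fix, bool), np.shape(hyperprior))

# Starting values with the same format as a hyperprior (assumed layout).
initial = {'log(scale)': 0.5, 'coef': np.array([1.0, 2.0, 3.0])}
fix = expand_fix({'coef': [False, True, False]}, initial)

# Flatten to the single vector an optimizer works with.
keys = list(initial)
p0 = np.concatenate([np.ravel(initial[k]) for k in keys])
fixed = np.concatenate([np.ravel(fix[k]) for k in keys])

def objective(p):
    return np.sum((p - 10.0) ** 2)    # toy target with its minimum at p = 10

def reduced(free):
    p = p0.copy()
    p[~fixed] = free                  # fixed entries keep their initial values
    return objective(p)

res = optimize.minimize(reduced, p0[~fixed], method='BFGS')
p_final = p0.copy()
p_final[~fixed] = res.x
print(p_final)
```

Only the three free entries move toward 10. The entry of 'coef' marked True keeps its starting value of 2.0.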