Coverage for src/lsqfitgp/_fit.py: 84%
551 statements
coverage.py v7.9.2, created at 2025-07-17 13:39 +0000
1# lsqfitgp/_fit.py
2#
3# Copyright (c) 2020, 2022, 2023, 2024, 2025, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import re 1fedabc
21import warnings 1fedabc
22import functools 1fedabc
23import time 1fedabc
24import textwrap 1fedabc
25import datetime 1fedabc
27import gvar 1fedabc
28import jax 1fedabc
29from jax import numpy as jnp 1fedabc
30import numpy 1fedabc
31from scipy import optimize 1fedabc
32from jax import tree_util 1fedabc
34from . import _GP 1fedabc
35from . import _linalg 1fedabc
36from . import _jaxext 1fedabc
37from . import _gvarext 1fedabc
38from . import _array 1fedabc
40# TODO the following token_ thing functionality may be provided by jax in the
41# future, follow the developments
43@functools.singledispatch 1fedabc
44def token_getter(x): 1fedabc
45 return x
47@functools.singledispatch 1fedabc
48def token_setter(x, token): 1fedabc
49 return token
51@token_getter.register(jnp.ndarray) 1fedabc
52@token_getter.register(numpy.ndarray) 1fedabc
53def _(x): 1fedabc
54 return x[x.ndim * (0,)] if x.size else x 1fedabc
56@token_setter.register(jnp.ndarray) 1fedabc
57@token_setter.register(numpy.ndarray) 1fedabc
58def _(x, token): 1fedabc
59 x = jnp.asarray(x) 1fedabc
60 return x.at[x.ndim * (0,)].set(token) if x.size else token 1fedabc
62def token_map_leaf(func, x): 1fedabc
63 if isinstance(x, (jnp.ndarray, numpy.ndarray)): 63 ↛ 74 [line 63 didn't jump to line 74 because the condition on line 63 was always true] 1fedabc
64 token = token_getter(x) 1fedabc
65 @jax.custom_jvp 1fedabc
66 def jaxfunc(token): 1fedabc
67 return jax.pure_callback(func, token, token, vmap_method='expand_dims') 1fedabc
68 @jaxfunc.defjvp 1fedabc
69 def _(p, t): 1fedabc
70 return (jaxfunc(*p), *t) 1fedabc
71 token = jaxfunc(token) 1fedabc
72 return token_setter(x, token) 1fedabc
73 else:
74 token = token_getter(x)
75 token = func(token)
76 return token_setter(x, token)
78def token_map(func, x): 1fedabc
79 return tree_util.tree_map(lambda x: token_map_leaf(func, x), x) 1fedabc
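# Illustrative sketch (not part of the module): token_map lets a host-side
# side effect run in execution order inside a jitted function, by routing a
# single "token" element of each array leaf through jax.pure_callback. The
# names below (record, stamps, f) are hypothetical.
#
#     import time
#     stamps = []
#
#     def record(token):
#         stamps.append(time.perf_counter())
#         return token
#
#     @jax.jit
#     def f(x):
#         x = token_map(record, x)      # timestamp when x is materialized
#         y = jnp.sin(x) ** 2
#         return token_map(record, y)   # timestamp after the computation
#
#     f(jnp.arange(3.))  # appends two timestamps, in execution order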
81class Logger: 1fedabc
82 """ Class to manage a log. Can be used as superclass. Each line of the log
83 has a verbosity level (an integer >= 0) and is printed only if this level is
84 below a threshold. All lines are saved and the log can be retrieved. """
86 def __init__(self, target_verbosity=0): 1fedabc
87 """ set the threshold used to exclude log lines """
88 self._verbosity = target_verbosity 1fedabc
89 self._loggedlines = [] 1fedabc
91 def _indent(self, text, level=0): 1fedabc
92 """ indent a text by provided level or by global current level """
93 level = max(0, level + self.loglevel._level) 1edabc
94 prefix = 4 * level * ' ' 1edabc
95 return textwrap.indent(text, prefix) 1edabc
97 def _select(self, verbosity, target_verbosity=None): 1fedabc
98 if target_verbosity is None: 98 ↛ 100 [line 98 didn't jump to line 100 because the condition on line 98 was always true] 1fedabc
99 target_verbosity = self._verbosity 1fedabc
100 if isinstance(verbosity, int): 1fedabc
101 return target_verbosity >= verbosity 1fedabc
102 else:
103 return target_verbosity in verbosity 1fedabc
105 def log(self, message, verbosity=1, *, level=0): 1fedabc
106 """
107 Print and record a message.
109 Parameters
110 ----------
111 message : str
112 The message to print. A newline is added unconditionally.
113 verbosity : int or set, default 1
114 The verbosity level(s) at which the message is printed. If an
115 integer, it's printed at all levels >= that integer. If a set, at
116 the specified levels.
117 level : int, default 0
118 The indentation level of the message.
119 """
120 if self._select(verbosity): 1fedabc
121 print(self._indent(message, level)) 1edabc
122 self._loggedlines.append((message, verbosity, level + self.loglevel._level)) 1fedabc
124 def getlog(self, target_verbosity=None, *, base_level=0): 1fedabc
125 """ return all logged line as a single string """
126 return '\n'.join(
127 self._indent(message, base_level + level)
128 for message, verbosity, level in self._loggedlines
129 if self._select(verbosity, target_verbosity)
130 )
132 class _LogLevel: 1fedabc
133 """ shared context manager to indent messages """
135 _level = 0 1fedabc
137 @classmethod 1fedabc
138 def __enter__(cls): 1fedabc
139 cls._level += 1 1fedabc
141 @classmethod 1fedabc
142 def __exit__(cls, *_): 1fedabc
143 cls._level -= 1 1fedabc
145 loglevel = _LogLevel() 1fedabc
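# Illustrative usage sketch (not part of the module): with a threshold of 1,
# only messages whose verbosity is <= 1 are printed, but every message is
# recorded and can be retrieved later at a different threshold.
#
#     log = Logger(target_verbosity=1)
#     log.log('summary line')                 # verbosity 1: printed
#     log.log('detailed line', verbosity=2)   # recorded but not printed
#     with log.loglevel:
#         log.log('indented line')            # printed with one indent level
#     text = log.getlog(target_verbosity=2)   # includes 'detailed line'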
147class empbayes_fit(Logger): 1fedabc
149 SEPARATE_JAC = False 1fedabc
151 def __init__( 1fedabc
152 self,
153 hyperprior,
154 gpfactory,
155 data,
156 *,
157 raises=True,
158 minkw={},
159 gpfactorykw={},
160 jit=True,
161 method='gradient',
162 initial='priormean',
163 verbosity=0,
164 covariance='auto',
165 fix=None,
166 mlkw={},
167 forward=False,
168 additional_loss=None,
169 ):
170 """
172 Maximum a posteriori fit.
174 Maximizes the marginal likelihood of the data with a Gaussian process
175 model that depends on hyperparameters, multiplied by a prior on the
176 hyperparameters.
178 Parameters
179 ----------
180 hyperprior : scalar, array or dictionary of scalars/arrays
181 A collection of gvars representing the prior for the
182 hyperparameters.
183 gpfactory : callable
184 A function with signature gpfactory(hyperparams) -> GP object. The
185 argument ``hyperparams`` has the same structure as the
186 `empbayes_fit` argument ``hyperprior``. gpfactory must be
187 JAX-friendly, i.e., use `jax.numpy` and `jax.scipy` instead of plain
188 `numpy`/`scipy` and avoid assignments to arrays.
189 data : dict, tuple or callable
190 Dictionary of data that is passed to `GP.marginal_likelihood` on
191 the GP object returned by ``gpfactory``. If a tuple, it contains the
192 first two arguments to `GP.marginal_likelihood`. If a callable, it
193 is called with the same arguments as ``gpfactory`` and must return
194 the argument(s) for `GP.marginal_likelihood`.
195 raises : bool, optional
196 If True (default), raise an error when the minimization fails.
197 Otherwise, use the last point of the minimization as result.
198 minkw : dict, optional
199 Keyword arguments passed to `scipy.optimize.minimize`, overwriting
200 the values set by `empbayes_fit`.
201 gpfactorykw : dict, optional
202 Keyword arguments passed to ``gpfactory``, and also to ``data`` if
203 it is a callable. If ``jit`` is True, ``gpfactorykw`` crosses a
204 `jax.jit` boundary, so it must contain objects understood by `jax`.
205 jit : bool
206 If True (default), use `jax.jit` to compile the minimization target.
207 method : str
208 Minimization strategy. Options:
210 'nograd'
211 Use a gradient-free method.
212 'gradient' (default)
213 Use a gradient-only method.
214 'fisher'
215 Use a Newton method with the Fisher information matrix plus
216 the hyperprior precision matrix.
217 initial : str, scalar, array, dictionary of scalars/arrays
218 Starting point for the minimization, matching the format of
219 ``hyperprior``, or one of the following options:
221 'priormean' (default)
222 Start from the hyperprior mean.
223 'priorsample'
224 Take a random sample from the hyperprior.
225 verbosity : int
226 An integer indicating how much information is printed on the
227 terminal:
229 0 (default)
230 No logging.
231 1
232 Minimal report.
233 2
234 Detailed report.
235 3
236 Log each iteration.
237 4
238 More detailed iteration log.
239 5
240 Print the current parameter values at each iteration.
241 covariance : str
242 Method to estimate the posterior covariance matrix of the
243 hyperparameters:
245 'fisher'
246 Use the Fisher information in the MAP, plus the prior precision,
247 as precision matrix.
248 'minhess'
249 Use the hessian estimate of the minimizer as precision matrix.
250 'none'
251 Do not estimate the covariance matrix.
252 'auto' (default)
253 ``'minhess'`` if applicable, ``'none'`` otherwise.
254 fix : scalar, array or dictionary of scalars/arrays
255 A set of booleans, with the same format as ``hyperprior``,
256 indicating which hyperparameters are kept fixed to their initial
257 value. Scalars and arrays are broadcast to the shape of
258 ``hyperprior``. If a dictionary, missing keys are treated as False.
259 mlkw : dict
260 Additional arguments passed to `GP.marginal_likelihood`.
261 forward : bool, default False
262 Use forward instead of backward derivatives. Typically, forward is
263 faster with a small number of parameters.
264 additional_loss : callable, optional
265 A function with signature ``additional_loss(hyperparams) -> float``
266 which is added to the minus log marginal posterior of the
267 hyperparameters.
269 Attributes
270 ----------
271 p : scalar, array or dictionary of scalars/arrays
272 A collection of gvars representing the hyperparameters that
273 maximize their posterior. These gvars do not track correlations
274 with the hyperprior or the data.
275 prior : scalar, array or dictionary of scalars/arrays
276 A copy of the hyperprior.
277 initial : scalar, array or dictionary of scalars/arrays
278 Starting point of the minimization, with the same format as ``p``.
279 fix : scalar, array or dictionary of scalars/arrays
280 A set of booleans, with the same format as ``p``, indicating which
281 parameters were kept fixed to the values in ``initial``.
282 pmean : scalar, array or dictionary of scalars/arrays
283 Mean of ``p``.
284 pcov : scalar, array or dictionary of scalars/arrays
285 Covariance matrix of ``p``.
286 minresult : scipy.optimize.OptimizeResult
287 The result object returned by `scipy.optimize.minimize`.
288 minargs : dict
289 The arguments passed to `scipy.optimize.minimize`.
290 gpfactory : callable
291 The ``gpfactory`` argument.
292 gpfactorykw : dict
293 The ``gpfactorykw`` argument.
294 data : dict, tuple or callable
295 The ``data`` argument.
297 Raises
298 ------
299 RuntimeError
300 The minimization failed and ``raises`` is True.
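
        Examples
        --------
        A minimal sketch; the kernel, the ``GP.addx`` call and the data
        layout below are illustrative assumptions rather than the only
        supported usage::

            import numpy as np
            import gvar
            import lsqfitgp as lgp

            x = np.linspace(0, 10, 20)
            y = np.sin(x)

            def gpfactory(hp):
                # hp has the same structure as the hyperprior below
                return lgp.GP(lgp.ExpQuad(scale=hp['scale'])).addx(x, 'data')

            hyperprior = {'log(scale)': gvar.gvar(0, 1)}
            fit = lgp.empbayes_fit(hyperprior, gpfactory, {'data': y})
            print(fit.p['scale'])  # posterior of the hyperparameter as a gvar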
302 """
304 Logger.__init__(self, verbosity) 1fedabc
305 del verbosity 1fedabc
306 self.log('**** call lsqfitgp.empbayes_fit ****') 1fedabc
308 assert callable(gpfactory) 1fedabc
310 # analyze the hyperprior
311 hpinitial, hpunflat = self._parse_hyperprior(hyperprior, initial, fix) 1fedabc
312 del hyperprior, initial, fix 1fedabc
314 # analyze data
315 data, cachedargs = self._parse_data(data) 1fedabc
317 # define functions
318 timer, functions = self._prepare_functions( 1fedabc
319 gpfactory=gpfactory, gpfactorykw=gpfactorykw, data=data,
320 cachedargs=cachedargs, hpunflat=hpunflat, mlkw=mlkw, jit=jit,
321 forward=forward, additional_loss=additional_loss,
322 )
323 del gpfactory, gpfactorykw, data, cachedargs, mlkw, forward, additional_loss 1fedabc
325 # prepare minimizer arguments
326 minargs = self._prepare_minargs(method, functions, hpinitial) 1fedabc
328 # set up callback to time and log iterations
329 callback = self._Callback(self, functions, timer, hpunflat) 1fedabc
330 minargs.update(callback=callback) 1fedabc
332 # check invalid argument before running minimizer
333 if not covariance in ('auto', 'fisher', 'minhess', 'none'): 333 ↛ 334 [line 333 didn't jump to line 334 because the condition on line 333 was never true] 1fedabc
334 raise KeyError(covariance)
336 # add user arguments and minimize
337 minargs.update(minkw) 1fedabc
338 self.log(f'minimizer method {minargs["method"]!r}', 2) 1fedabc
339 total = time.perf_counter() 1fedabc
340 result = optimize.minimize(**minargs) 1fedabc
342 # check the minimization was successful
343 self._check_success(result, raises) 1fedabc
345 # compute posterior covariance of the hyperparameters
346 cov = self._posterior_covariance(method, covariance, result, functions['fisher']) 1fedabc
348 # log total timings and function calls
349 total = time.perf_counter() - total 1fedabc
350 self._log_totals(total, timer, callback, jit, functions) 1fedabc
352 ##### temporary fix for gplepage/gvar#50 #####
353 cov = numpy.array(cov, order='C') 1fedabc
354 ##############################################
356 # join posterior mean and covariance matrix
357 uresult = gvar.gvar(result.x, cov) 1fedabc
359 # set attributes
360 self.p = gvar.gvar(hpunflat(uresult)) 1fedabc
361 self.pmean = gvar.mean(self.p) 1fedabc
362 self.pcov = gvar.evalcov(self.p) 1fedabc
363 self.minresult = result 1fedabc
364 self.minargs = minargs 1fedabc
366 # tabulate hyperparameter prior and posterior
367 if self._verbosity >= 2: 1fedabc
368 self.log(_gvarext.tabulate_together( 1edabc
369 self.prior, self.p,
370 headers=['param', 'prior', 'posterior'],
371 )) # TODO replace tabulate_together with something more flexible I
372 # can use for the callback as well. Maybe import TextMatrix from
373 # miscpy.
374 # TODO print the transformed parameters
376 self.log('**** exit lsqfitgp.empbayes_fit ****') 1fedabc
378 class _CountCalls: 1fedabc
379 """ wrap a callable to count calls """
381 def __init__(self, func): 1fedabc
382 self._func = func 1fedabc
383 self._total = 0 1fedabc
384 self._partial = 0 1fedabc
385 functools.update_wrapper(self, func) 1fedabc
387 def __call__(self, *args, **kw): 1fedabc
388 self._total += 1 1fedabc
389 self._partial += 1 1fedabc
390 return self._func(*args, **kw) 1fedabc
392 def partial(self): 1fedabc
393 """ return the partial counter and reset it """
394 result = self._partial 1fedabc
395 self._partial = 0 1fedabc
396 return result 1fedabc
398 def total(self): 1fedabc
399 """ return the total number of calls """
400 return self._total 1fedabc
402 @staticmethod 1fedabc
403 def fmtcalls(method, functions): 1fedabc
404 """
405 format summary of number of calls
406 method : str
407 functions: dict[str, _CountCalls]
408 """
409 def counts(): 1fedabc
410 for name, func in functions.items(): 1fedabc
411 if count := getattr(func, method)(): 1fedabc
412 yield f'{name} {count}' 1fedabc
413 return ', '.join(counts()) 1fedabc
415 class _Timer: 1fedabc
416 """ object to time likelihood computations """
418 def __init__(self): 1fedabc
419 self.totals = {} 1fedabc
420 self.partials = {} 1fedabc
421 self._last_start = False 1fedabc
423 def start(self, token): 1fedabc
424 return token_map(self._start, token) 1fedabc
426 def _start(self, token): 1fedabc
427 self.stamp = time.perf_counter() 1fedabc
428 self.counter = 0 1fedabc
429 assert not self._last_start # forbid consecutive start() calls 1fedabc
430 self._last_start = True 1fedabc
431 return token 1fedabc
433 def reset(self): 1fedabc
434 self.partials = {} 1fedabc
436 def partial(self, token): 1fedabc
437 return token_map(self._partial, token) 1fedabc
439 def _partial(self, token): 1fedabc
440 now = time.perf_counter() 1fedabc
441 delta = now - self.stamp 1fedabc
442 self.partials[self.counter] = self.partials.get(self.counter, 0) + delta 1fedabc
443 self.totals[self.counter] = self.totals.get(self.counter, 0) + delta 1fedabc
444 self.stamp = now 1fedabc
445 self.counter += 1 1fedabc
446 self._last_start = False 1fedabc
447 return token 1fedabc
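# Note: make_decomp in _prepare_functions calls start() once per likelihood
# evaluation and partial() once per stage, so counter 0 accumulates the GP and
# covariance construction, counter 1 the matrix decomposition, and counter 2
# the likelihood value; these correspond to the 'gp&cov'/'decomp'/'likelihood'
# entries reported by _log_totals and _Callback.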
449 def _parse_hyperprior(self, hyperprior, initial, fix): 1fedabc
451 # check fix against hyperprior and fill missing values
452 hyperprior = self._copyasarrayorbufferdict(hyperprior) 1fedabc
453 self._check_no_redundant_keys(hyperprior) 1fedabc
454 fix = self._parse_fix(hyperprior, fix) 1fedabc
455 flatfix = self._flatview(fix) 1fedabc
457 # extract distribution of free hyperparameters
458 flathp = self._flatview(hyperprior) 1fedabc
459 freehp = flathp[~flatfix] 1fedabc
460 mean = gvar.mean(freehp) 1fedabc
461 cov = gvar.evalcov(freehp) # TODO use evalcov_blocks 1fedabc
462 dec = _linalg.Chol(cov) 1fedabc
463 assert dec.n == freehp.size 1fedabc
464 self.log(f'{freehp.size}/{flathp.size} free hyperparameters', 2) 1fedabc
466 # determine starting point for minimization
467 initial = self._parse_initial(hyperprior, initial, dec) 1fedabc
468 flatinitial = self._flatview(initial) 1fedabc
469 x0 = dec.pinv_correlate(flatinitial[~flatfix] - mean) 1fedabc
470 # TODO for initial = 'priormean', x0 is zero, skip decorrelate
471 # for initial = 'priorsample', x0 is iid normal, but I have to sync
472 # it with the user-exposed unflattened initial in _parse_initial
474 # make function to correlate, add fixed values, and reshape to original
475 # format
476 fixed_indices, = jnp.nonzero(flatfix) 1fedabc
477 unfixed_indices, = jnp.nonzero(~flatfix) 1fedabc
478 fixed_values = jnp.asarray(flatinitial[flatfix]) 1fedabc
479 def unflat(x): 1fedabc
480 assert x.ndim == 1 1fedabc
481 if x.dtype == object: 1fedabc
482 jac, indices = _gvarext.jacobian(x) 1fedabc
483 xmean = mean + dec.correlate(gvar.mean(x)) 1fedabc
484 xjac = dec.correlate(jac) 1fedabc
485 x = _gvarext.from_jacobian(xmean, xjac, indices) 1fedabc
486 y = numpy.empty(flatfix.size, x.dtype) 1fedabc
487 numpy.put(y, unfixed_indices, x) 1fedabc
488 numpy.put(y, fixed_indices, fixed_values) 1fedabc
489 else:
490 x = mean + dec.correlate(x) 1fedabc
491 y = jnp.empty(flatfix.size, x.dtype) 1fedabc
492 y = y.at[unfixed_indices].set(x) 1fedabc
493 y = y.at[fixed_indices].set(fixed_values) 1fedabc
494 return self._unflatview(y, hyperprior) 1fedabc
496 self.prior = hyperprior 1fedabc
497 return x0, unflat 1fedabc
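# Note: with L the Cholesky factor of the free hyperparameters' prior
# covariance (see _linalg.Chol.correlate/pinv_correlate), the minimizer works
# in whitened coordinates x = L^{-1} (flat_hp - mean); unflat applies the
# inverse map flat_hp = mean + L x and re-inserts the fixed entries, so the
# prior on x is a standard normal (see prior() in _prepare_functions).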
499 @staticmethod 1fedabc
500 def _check_no_redundant_keys(hyperprior): 1fedabc
501 if not hasattr(hyperprior, 'keys'): 1fedabc
502 return 1dabc
503 for k in hyperprior: 1fedabc
504 m = hyperprior.extension_pattern.match(k) 1fedabc
505 if m and m.group(1) in hyperprior.invfcn: 1fedabc
506 altk = m.group(2) 1fedabc
507 if altk in hyperprior: 507 ↛ 508 [line 507 didn't jump to line 508 because the condition on line 507 was never true] 1fedabc
508 raise ValueError(f'duplicate keys {altk!r} and {k!r} in hyperprior')
510 def _parse_fix(self, hyperprior, fix): 1fedabc
512 if fix is None: 512 ↛ 518 [line 512 didn't jump to line 518 because the condition on line 512 was always true] 1fedabc
513 if hasattr(hyperprior, 'keys'): 1fedabc
514 fix = gvar.BufferDict(hyperprior, buf=numpy.zeros(hyperprior.size, bool)) 1fedabc
515 else:
516 fix = numpy.zeros(hyperprior.shape, bool) 1dabc
517 else:
518 fix = self._copyasarrayorbufferdict(fix)
519 if hasattr(fix, 'keys'):
520 assert hasattr(hyperprior, 'keys'), 'fix is dictionary but hyperprior is array'
521 assert all(hyperprior.has_dictkey(k) for k in fix), 'some keys in fix are missing in hyperprior'
522 newfix = {}
523 for k, v in hyperprior.items():
524 key = None
525 m = hyperprior.extension_pattern.match(k)
526 if m and m.group(1) in hyperprior.invfcn:
527 altk = m.group(2)
528 if altk in fix:
529 assert k not in fix, f'duplicate keys {k!r} and {altk!r} in fix'
530 key = altk
531 if key is None and k in fix:
532 key = k
533 if key is None:
534 elem = numpy.zeros(v.shape, bool)
535 else:
536 elem = numpy.broadcast_to(fix[key], v.shape)
537 newfix[k] = elem
538 fix = gvar.BufferDict(newfix, dtype=bool)
539 else:
540 assert not hasattr(hyperprior, 'keys'), 'fix is array but hyperprior is dictionary'
541 fix = numpy.broadcast_to(fix, hyperprior.shape).astype(bool)
543 self.fix = fix 1fedabc
544 return fix 1fedabc
546 def _parse_initial(self, hyperprior, initial, dec): 1fedabc
548 if not isinstance(initial, str): 548 ↛ 549 [line 548 didn't jump to line 549 because the condition on line 548 was never true] 1fedabc
549 self.log('start from provided point', 2)
550 initial = self._copyasarrayorbufferdict(initial)
551 if hasattr(hyperprior, 'keys'):
552 assert hasattr(initial, 'keys'), 'hyperprior is dictionary but initial is array'
553 assert set(hyperprior.keys()) == set(initial.keys())
554 assert all(hyperprior[k].shape == initial[k].shape for k in hyperprior)
555 else:
556 assert not hasattr(initial, 'keys'), 'hyperprior is array but initial is dictionary'
557 assert hyperprior.shape == initial.shape
559 elif initial == 'priormean': 559 ↛ 563 [line 559 didn't jump to line 563 because the condition on line 559 was always true] 1fedabc
560 self.log('start from prior mean', 2) 1fedabc
561 initial = gvar.mean(hyperprior) 1fedabc
563 elif initial == 'priorsample':
564 self.log('start from a random sample from the prior', 2)
565 if dec.n < hyperprior.size:
566 flathp = self._flatview(hyperprior)
567 cov = gvar.evalcov(flathp) # TODO use evalcov_blocks
568 fulldec = _linalg.Chol(cov)
569 else:
570 fulldec = dec
571 iid = numpy.random.randn(fulldec.m)
572 flatinitial = numpy.asarray(fulldec.correlate(iid))
573 initial = self._unflatview(flatinitial, hyperprior)
575 else:
576 raise KeyError(initial)
578 self.initial = initial 1fedabc
579 return initial 1fedabc
581 def _parse_data(self, data): 1fedabc
583 self.data = data 1fedabc
584 if isinstance(data, tuple) and len(data) == 1: 1fedabc
585 data, = data 1dabc
587 if callable(data): 1fedabc
588 self.log('data is callable', 2) 1edabc
589 cachedargs = None 1edabc
590 elif isinstance(data, tuple): 1fedabc
591 self.log('data errors provided separately', 2) 1dabc
592 assert len(data) == 2 1dabc
593 cachedargs = data 1dabc
594 elif (gdata := self._copyasarrayorbufferdict(data)).dtype == object: 1fedabc
595 self.log('data has errors as gvars', 2) 1edabc
596 data = gvar.gvar(gdata) 1edabc
597 # convert to gvar because non-gvars in the array would upset
598 # gvar.mean and gvar.evalcov
599 cachedargs = (gvar.mean(data), gvar.evalcov(data)) 1edabc
600 else:
601 self.log('data has no errors', 2) 1fedabc
602 cachedargs = (data,) 1fedabc
604 return data, cachedargs 1fedabc
606 def _prepare_functions(self, *, gpfactory, gpfactorykw, data, cachedargs, 1fedabc
607 hpunflat, mlkw, jit, forward, additional_loss):
609 timer = self._Timer() 1fedabc
610 firstcall = [None] 1fedabc
612 def make_decomp(p, **kw): 1fedabc
613 """ decomposition of the prior covariance and data """
615 # start timer and convert hypers to user format
616 p = timer.start(p) 1fedabc
617 hp = hpunflat(p) 1fedabc
619 # create GP object
620 gp = gpfactory(hp, **kw) 1fedabc
621 assert isinstance(gp, _GP.GP) 1fedabc
623 # extract data
624 if cachedargs: 1fedabc
625 args = cachedargs 1fedabc
626 else:
627 args = data(hp, **kw) 1edabc
628 if not isinstance(args, tuple): 1edabc
629 args = (args,) 1edabc
631 # decompose covariance matrix and flatten data
632 decomp, r = gp._prior_decomp(*args, covtransf=timer.partial, **mlkw) 1fedabc
633 r = r.astype(float) # int data upsets jax 1fedabc
635 # log number of datapoints
636 if firstcall: 1fedabc
637 # it is convenient to do this here because the data is flattened.
638 # it works under jit since the first call is the tracing one
639 firstcall.pop() 1fedabc
640 xdtype = gp._get_x_dtype() 1fedabc
641 nd = '?' if xdtype is None else _array._nd(xdtype) 1fedabc
642 self.log(f'{r.size} datapoints, {nd} covariates') 1fedabc
644 # compute user loss
645 if additional_loss is None: 1fedabc
646 loss = 0. 1fedabc
647 else:
648 loss = additional_loss(hp) 1edabc
650 # split timer and return decomposition
651 return timer.partial(decomp), r, loss 1fedabc
652 # TODO what's the correct way of checkpointing r?
654 # define wrapper to collect call stats, pass user args, compile
655 def wrap(func): 1fedabc
656 if jit: 656 ↛ 658 [line 656 didn't jump to line 658 because the condition on line 656 was always true] 1fedabc
657 func = jax.jit(func) 1fedabc
658 func = functools.partial(func, **gpfactorykw) 1fedabc
659 return self._CountCalls(func) 1fedabc
660 if jit: 660 ↛ 664 [line 660 didn't jump to line 664 because the condition on line 660 was always true] 1fedabc
661 self.log('compile functions with jax jit', 2) 1fedabc
663 # log derivation method
664 modename = 'forward' if forward else 'reverse' 1fedabc
665 self.log(f'{modename}-mode autodiff (if used)', 2) 1fedabc
667 # TODO time the derivatives separately => maybe I need a custom
668 # derivative rule for timer token acknowledgement?
670 def prior(p): 1fedabc
671 # the marginal prior of the hyperparameters is a Normal with
672 # identity covariance matrix because p is transformed to make it so
673 return 1/2 * (len(p) * jnp.log(2 * jnp.pi) + p @ p) 1fedabc
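# i.e. -log N(p; 0, I) = n/2 * log(2*pi) + 1/2 * p @ p, with n = len(p)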
675 def grad_prior(p): 1fedabc
676 return p 1fedabc
678 def fisher_prior(p): 1fedabc
679 return jnp.eye(len(p)) 1dabc
681 @wrap 1fedabc
682 def fun(p, **kw): 1fedabc
683 """ minus log marginal posterior of the hyperparameters (not
684 normalized) """
685 decomp, r, loss = make_decomp(p, **kw) 1dabc
686 cond, _, _, _, _ = decomp.minus_log_normal_density(r, value=True) 1dabc
687 post = cond + prior(p) + loss 1dabc
688 # TODO what's the correct way of checkpointing prior and loss?
689 return timer.partial(post) 1dabc
691 def make_gradfwd_fisher_args(p, **kw): 1fedabc
692 def make_decomp_tee(p): 1edabc
693 decomp, r, loss = make_decomp(p, **kw) 1edabc
694 return (decomp.matrix(), r, loss), (decomp, r, loss) 1edabc
695 (dK, dr, grad_loss), (decomp, r, loss) = jax.jacfwd(make_decomp_tee, has_aux=True)(p) 1edabc
696 lkw = dict(dK=dK, dr=dr) 1edabc
697 return decomp, r, lkw, loss, grad_loss 1edabc
699 def make_gradrev_args(p, **kw): 1fedabc
700 def make_decomp_loss(p): 1fedabc
701 def make_decomp_r(p): 1fedabc
702 def make_decomp_K(p): 1fedabc
703 decomp, r, loss = make_decomp(p, **kw) 1fedabc
704 return decomp.matrix(), (decomp, r, loss) 1fedabc
705 _, dK_vjp, (decomp, r, loss) = jax.vjp(make_decomp_K, p, has_aux=True) 1fedabc
706 return r, (decomp, r, dK_vjp, loss) 1fedabc
707 _, dr_vjp, (decomp, r, dK_vjp, loss) = jax.vjp(make_decomp_r, p, has_aux=True) 1fedabc
708 return loss, (decomp, r, dK_vjp, dr_vjp, loss) 1fedabc
709 grad_loss, (decomp, r, dK_vjp, dr_vjp, loss) = jax.grad(make_decomp_loss, has_aux=True)(p) 1fedabc
710 unpack = lambda f: lambda x: f(x)[0] 1fedabc
711 dK_vjp = unpack(dK_vjp) 1fedabc
712 dr_vjp = unpack(dr_vjp) 1fedabc
713 lkw = dict(dK_vjp=dK_vjp, dr_vjp=dr_vjp) 1fedabc
714 return decomp, r, lkw, loss, grad_loss 1fedabc
716 def make_jac_args(p, **kw): 1fedabc
717 if forward: 1fedabc
718 out = make_gradfwd_fisher_args(p, **kw) 1edabc
719 out[2].update(gradfwd=True) # out[2] is lkw 1edabc
720 else:
721 out = make_gradrev_args(p, **kw) 1fedabc
722 out[2].update(gradrev=True) 1fedabc
723 return out 1fedabc
725 @wrap 1fedabc
726 def fun_and_jac(p, **kw): 1fedabc
727 """ fun and its gradient """
728 decomp, r, lkw, loss, grad_loss = make_jac_args(p, **kw) 1fedabc
729 cond, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, value=True, **lkw) 1fedabc
730 post = cond + prior(p) + loss 1fedabc
731 grad_cond = gradfwd if forward else gradrev 1fedabc
732 grad_post = grad_cond + grad_prior(p) + grad_loss 1fedabc
733 return timer.partial((post, grad_post)) 1fedabc
735 @wrap 1fedabc
736 def jac(p, **kw): 1fedabc
737 """ gradient of fun """
738 decomp, r, lkw, _, grad_loss = make_jac_args(p, **kw)
739 _, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, **lkw)
740 grad_cond = gradfwd if forward else gradrev
741 grad_post = grad_cond + grad_prior(p) + grad_loss
742 return timer.partial(grad_post)
744 @wrap 1fedabc
745 def fisher(p, **kw): 1fedabc
746 """ fisher matrix """
747 if additional_loss is not None: 1dabc
748 raise NotImplementedError( 1dabc
749 'Fisher matrix not implemented with additional_loss. It '
750 'is possible but I did not prioritize it. If you need it, '
751 'open an issue on github.')
752 decomp, r, lkw, _, _ = make_gradfwd_fisher_args(p, **kw) 1dabc
753 _, _, _, fisher_cond, _ = decomp.minus_log_normal_density(r, fisher=True, **lkw) 1dabc
754 fisher_post = fisher_cond + fisher_prior(p) 1dabc
755 return timer.partial(fisher_post) 1dabc
757 # set attributes
758 self.gpfactory = gpfactory 1fedabc
759 self.gpfactorykw = gpfactorykw 1fedabc
761 return timer, { 1fedabc
762 'fun': fun,
763 'jac': jac,
764 'fun&jac': fun_and_jac,
765 'fisher': fisher,
766 }
768 def _prepare_minargs(self, method, functions, hpinitial): 1fedabc
769 minargs = dict(fun=functions['fun&jac'], jac=True, x0=hpinitial) 1fedabc
770 if self.SEPARATE_JAC: 770 ↛ 771 [line 770 didn't jump to line 771 because the condition on line 770 was never true] 1fedabc
771 minargs.update(fun=functions['fun'], jac=functions['jac'])
772 if method == 'nograd': 1fedabc
773 minargs.update(fun=functions['fun'], jac=None, method='nelder-mead') 1dabc
774 elif method == 'gradient': 1fedabc
775 minargs.update(method='bfgs') 1fedabc
776 elif method == 'fisher': 1dabc
777 minargs.update(hess=functions['fisher'], method='dogleg') 1dabc
778 # dogleg requires positive definiteness, fisher is p.s.d.
779 # trust-constr has more options, but it seems to be slower than
780 # dogleg, so I keep dogleg as default
781 else:
782 raise KeyError(method) 1dabc
783 self.log(f'method {method!r}', 2) 1fedabc
784 return minargs 1fedabc
786 # TODO add method with fisher matvec instead of fisher matrix
788 def _log_totals(self, total, timer, callback, jit, functions): 1fedabc
789 times = { 1fedabc
790 'gp&cov': timer.totals[0],
791 'decomp': timer.totals[1],
792 'likelihood': timer.totals[2],
793 'jit': None, # set now and delete later to keep it before 'other'
794 'other': total - sum(timer.totals.values()),
795 }
796 if jit: 796 ↛ 803 [line 796 didn't jump to line 803 because the condition on line 796 was always true] 1fedabc
797 overhead = callback.estimate_firstcall_overhead() 1fedabc
798 # TODO this estimation ignores the jit compilation of the function
799 # used to compute the precision matrix, to be precise I should
800 # manually split the jit into compilation + evaluation or hook into
801 # it somehow. Maybe the jit object keeps a compilation wall time
802 # stat?
803 if jit and overhead is not None: 1fedabc
804 times['jit'] = overhead 1fedabc
805 times['other'] -= overhead 1fedabc
806 else:
807 del times['jit'] 1dabc
808 self.log('', 4) 1fedabc
809 calls = self._CountCalls.fmtcalls('total', functions) 1fedabc
810 self.log(f'calls: {calls}') 1fedabc
811 self.log(f'total time: {callback.fmttime(total)}') 1fedabc
812 self.log(f'partials: {callback.fmttimes(times)}', 2) 1fedabc
814 def _check_success(self, result, raises): 1fedabc
815 if result.success: 1fedabc
816 self.log(f'minimization succeeded: {result.message}') 1fedabc
817 else:
818 msg = f'minimization failed: {result.message}' 1fdabc
819 if raises: 1fdabc
820 raise RuntimeError(msg) 1dabc
821 elif self._verbosity == 0: 1fabc
822 warnings.warn(msg) 1fb
823 else:
824 self.log(msg) 1abc
826 def _posterior_covariance(self, method, covariance, minimizer_result, fisher_func): 1fedabc
828 if covariance == 'auto': 828 ↛ 834 [line 828 didn't jump to line 834 because the condition on line 828 was always true] 1fedabc
829 if hasattr(minimizer_result, 'hess_inv') or hasattr(minimizer_result, 'hess'): 1fedabc
830 covariance = 'minhess' 1fedabc
831 else:
832 covariance = 'none' 1dabc
834 if covariance == 'fisher': 834 ↛ 835 [line 834 didn't jump to line 835 because the condition on line 834 was never true] 1fedabc
835 self.log('use fisher plus prior precision as precision', 2)
836 if method == 'fisher':
837 prec = minimizer_result.hess
838 else:
839 prec = fisher_func(minimizer_result.x)
840 cov = _linalg.Chol(prec).ginv()
842 elif covariance == 'minhess': 1fedabc
843 if hasattr(minimizer_result, 'hess_inv'): 1fedabc
844 hessinv = minimizer_result.hess_inv 1fedabc
845 if isinstance(hessinv, optimize.LbfgsInvHessProduct): 1fedabc
846 self.log(f'convert LBFGS({hessinv.n_corrs}) hessian inverse to BFGS as covariance', 2) 1edabc
847 cov = self._invhess_lbfgs_to_bfgs(hessinv) 1edabc
848 # TODO this still gives a cov that is too wide when the minimization
849 # terminates due to a failed line search, is it because of
850 # dropped updates? This is currently keeping me from setting
851 # l-bfgs-b as the default minimization method.
852 elif isinstance(hessinv, numpy.ndarray): 852 ↛ 867 [line 852 didn't jump to line 867 because the condition on line 852 was always true] 1fedabc
853 self.log('use minimizer estimate of inverse hessian as covariance', 2) 1fedabc
854 cov = hessinv 1fedabc
855 elif hasattr(minimizer_result, 'hess'): 855 ↛ 859 [line 855 didn't jump to line 859 because the condition on line 855 was always true] 1dabc
856 self.log('use minimizer hessian as precision', 2) 1dabc
857 cov = _linalg.Chol(minimizer_result.hess).ginv() 1dabc
858 else:
859 raise RuntimeError('the minimizer did not return an estimate of the hessian')
861 elif covariance == 'none': 861 ↛ 865 [line 861 didn't jump to line 865 because the condition on line 861 was always true] 1dabc
862 cov = numpy.full(minimizer_result.x.size, numpy.nan) 1dabc
864 else:
865 raise KeyError(covariance)
867 return cov 1fedabc
869 @staticmethod 1fedabc
870 def _invhess_lbfgs_to_bfgs(lbfgs): 1fedabc
871 bfgs = optimize.BFGS() 1edabc
872 bfgs.initialize(lbfgs.shape[0], 'inv_hess') 1edabc
873 for i in range(lbfgs.n_corrs): 1edabc
874 bfgs.update(lbfgs.sk[i], lbfgs.yk[i]) 1edabc
875 return bfgs.get_matrix() 1edabc
877 class _Callback: 1fedabc
878 """ Iteration callback for scipy.optimize.minimize """
880 def __init__(self, this, functions, timer, unflat): 1fedabc
881 self.it = 0 1fedabc
882 self.stamp = time.perf_counter() 1fedabc
883 self.this = this 1fedabc
884 self.functions = functions 1fedabc
885 self.timer = timer 1fedabc
886 self.unflat = unflat 1fedabc
887 self.tail_overhead = 0 1fedabc
888 self.tail_overhead_iter = 0 1fedabc
890 def __call__(self, intermediate_result, arg2=None): 1fedabc
892 if isinstance(intermediate_result, optimize.OptimizeResult): 892 ↛ 893 [line 892 didn't jump to line 893 because the condition on line 892 was never true] 1fedabc
893 p = intermediate_result.x
894 elif isinstance(intermediate_result, numpy.ndarray): 894 ↛ 897 [line 894 didn't jump to line 897 because the condition on line 894 was always true] 1fedabc
895 p = intermediate_result 1fedabc
896 else:
897 raise TypeError(type(intermediate_result))
899 self.it += 1 1fedabc
900 now = time.perf_counter() 1fedabc
901 duration = now - self.stamp 1fedabc
903 worktime = sum(self.timer.partials.values()) 1fedabc
904 if worktime: 1fedabc
905 overhead = duration - worktime 1fedabc
906 assert overhead >= 0, (duration, worktime) 1fedabc
907 if self.it == 1: 1fedabc
908 self.first_overhead = overhead 1fedabc
909 else:
910 self.tail_overhead_iter += 1 1fedabc
911 self.tail_overhead += overhead 1fedabc
913 # level 3 log
914 calls = self.this._CountCalls.fmtcalls('partial', self.functions) 1fedabc
915 times = self.fmttime(duration) 1fedabc
916 self.this.log(f'iter {self.it}, time: {times}, calls: {calls}', {3}) 1fedabc
918 # level 4 log
919 tot = self.fmttime(duration) 1fedabc
920 if self.timer.partials: 1fedabc
921 times = { 1fedabc
922 'gp&cov': self.timer.partials[0],
923 'dec': self.timer.partials[1],
924 'like': self.timer.partials[2],
925 'other': duration - sum(self.timer.partials.values()),
926 }
927 times = self.fmttimes(times) 1fedabc
928 else:
929 times = 'n/d' 1dabc
930 self.this.log(f'\niteration {self.it}', 4) 1fedabc
931 with self.this.loglevel: 1fedabc
932 self.this.log(f'total time: {tot}', 4) 1fedabc
933 self.this.log(f'partial: {times}', 4) 1fedabc
934 self.this.log(f'calls: {calls}', 4) 1fedabc
936 # level 5 log
937 nicep = self.unflat(p) 1fedabc
938 nicep = self.this._copyasarrayorbufferdict(nicep) 1fedabc
939 with self.this.loglevel: 1fedabc
940 self.this.log(f'parameters = {nicep}', 5) 1fedabc
941 # TODO write a method to format the parameters nicely. => use
942 # gvar.tabulate? => nope, need actual gvars
943 # TODO does this logging add significant overhead?
945 self.stamp = now 1fedabc
946 self.timer.reset() 1fedabc
948 pattern = re.compile( 1fedabc
949 r'((\d+) days, )?(\d{1,2}):(\d\d):(\d\d(\.\d{6})?)')
951 @classmethod 1fedabc
952 def fmttime(cls, seconds): 1fedabc
953 if seconds < 0: 953 ↛ 954 [line 953 didn't jump to line 954 because the condition on line 953 was never true] 1fedabc
954 prefix = '-'
955 seconds = -seconds
956 else:
957 prefix = '' 1fedabc
958 return prefix + cls._fmttime_positive(seconds) 1fedabc
960 @classmethod 1fedabc
961 def _fmttime_positive(cls, seconds): 1fedabc
962 td = datetime.timedelta(seconds=seconds) 1fedabc
963 m = cls.pattern.fullmatch(str(td)) 1fedabc
964 _, day, hour, minute, second, _ = m.groups() 1fedabc
965 hour = int(hour) 1fedabc
966 minute = int(minute) 1fedabc
967 second = float(second) 1fedabc
968 if day: 968 ↛ 969 [line 968 didn't jump to line 969 because the condition on line 968 was never true] 1fedabc
969 return f'{day.lstrip("0")}d{hour:02d}h'
970 elif hour: 970 ↛ 971 [line 970 didn't jump to line 971 because the condition on line 970 was never true] 1fedabc
971 return f'{hour}h{minute:02d}m'
972 elif minute: 1fedabc
973 return f'{minute}m{second:02.0f}s' 1e
974 elif second >= 0.0995: 1fedabc
975 return f'{second:#.2g}'.rstrip('.') + 's' 1fedabc
976 elif second >= 0.0000995: 1fedabc
977 return f'{second * 1e3:#.2g}'.rstrip('.') + 'ms' 1fedabc
978 else:
979 return f'{second * 1e6:.0f}μs' 1feda
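# For reference, outputs computed from the branches above (illustrative):
# fmttime(75) -> '1m15s', fmttime(0.5) -> '0.50s',
# fmttime(0.00123) -> '1.2ms', fmttime(5e-5) -> '50μs'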
981 @classmethod 1fedabc
982 def fmttimes(cls, times): 1fedabc
983 """ times = dict label -> seconds """
984 return ', '.join(f'{k} {cls.fmttime(v)}' for k, v in times.items()) 1fedabc
986 def estimate_firstcall_overhead(self): 1fedabc
987 if self.tail_overhead_iter and hasattr(self, 'first_overhead'): 1fedabc
988 typical_overhead = self.tail_overhead / self.tail_overhead_iter 1fedabc
989 return self.first_overhead - typical_overhead 1fedabc
991 @staticmethod 1fedabc
992 def _copyasarrayorbufferdict(x): 1fedabc
993 if hasattr(x, 'keys'): 1fedabc
994 return gvar.BufferDict(x) 1fedabc
995 else:
996 return numpy.array(x) 1dabc
998 @staticmethod 1fedabc
999 def _flatview(x): 1fedabc
1000 if hasattr(x, 'reshape'): 1fedabc
1001 return x.reshape(-1) 1dabc
1002 elif hasattr(x, 'buf'): 1fedabc
1003 return x.buf 1fedabc
1004 else: # pragma: no cover
1005 raise NotImplementedError
1007 @staticmethod 1fedabc
1008 def _unflatview(x, original): 1fedabc
1009 if isinstance(original, numpy.ndarray): 1fedabc
1010 # TODO is this never applied to jax arrays?
1011 out = x.reshape(original.shape) 1dabc
1012 # if not out.shape:
1013 # try:
1014 # out = out.item()
1015 # except jax.errors.ConcretizationTypeError:
1016 # pass
1017 return out 1dabc
1018 elif isinstance(original, gvar.BufferDict): 1fedabc
1019 # normally I would do BufferDict(original, buf=x) but it does not
1020 # work with JAX tracers
1021 b = gvar.BufferDict(original) 1fedabc
1022 b._extension = {} 1fedabc
1023 b._buf = x 1fedabc
1024 # b.buf = x does not work because BufferDict checks that the
1025 # array is a numpy array
1026 # TODO maybe make a feature request to gvar to accept array_like
1027 # buf
1028 return b 1fedabc
1029 else: # pragma: no cover
1030 raise NotImplementedError
1033# TODO would it be meaningful to add correlation of the fit result with the data
1034# and hyperprior?
1036# TODO add the second order correction. It probably requires more than the
1037# gradient and inv_hess, but maybe by getting a little help from
1038# marginal_likelihood I can use the least-squares optimized second order
1039# correction on the residuals term and invent something for the logdet term.
1041# TODO it raises very often with "Desired error not necessarily achieved due to
1042# precision loss.". I tried doing a forward grad on the logdet but it does not
1043# fix the problem. I still suspect it's the logdet, maybe the value itself and
1044# not the derivative, because as the matrix changes the regularization can
1045# change the value of the logdet a lot. How do I stabilize it? => scipy's
1046# l-bfgs-b seems to fail the line search less often
1048# TODO compute the logGBF for the whole fit (see the gpbart code). In its doc,
1049# specify that 1) additional_loss may break the normalization if the user does
1050# not know what they are doing 2) the calculation of the log determinant term
1051# heavily depends on the regularization if the covariance matrix is singular;
1052# this won't happen if there are independent error terms in the model as usual.
1054# TODO empbayes_fit(autoeps=True) tries to double epsabs until the minimization
1055# succeeds, with some maximum number of tries. autoeps=dict(maxrepeat=5,
1056# increasefactor=2, initial=1e-16, startfromzero=True) allows to configure the
1057# algorithm.
1059# TODO empbayes_fit(maxiter=100) sets the maximum number of minimization
1060# iterations. maxiter=dict(iter=100, calls=200, callsperiter=10) allows to
1061# configure it more finely. The calls limits are cumulative on all functions
1062# (need to make a class counter in _CountCalls), I can probably implement them
1063# by returning nan when the limit is surpassed, I hope the minimizer stops
1064# immediately on nan (test this). => Callback can raise StopIteration.
1066# TODO can I approximate the hessian with only function values and no gradient,
1067# i.e., when using nelder-mead? => See Hare (2022), although I would not know
1068# how to apply it properly to the optimization history. Somehow I need to keep
1069# only the "last" iterations.
1071# TODO is there a better algorithm than lbfgs for inaccurate functions? consider
1072# SC-BFGS (https://github.com/frankecurtis/SCBFGS). See Basak (2022). And NonOpt
1073# (https://github.com/frankecurtis/NonOpt).
1075# TODO can I estimate the error on the likelihood with the matrices? It requires
1076# the condition number. Basak (2022) gives wide bounds. I could try an upper
1077# bound and see how it compares to the true error, assuming that the matrix was
1078# as ill-conditioned as possible, i.e., use eps as the lowest eigenvalue, and
1079# gershgorin as the highest one.
1081# TODO look into jaxopt: it has improved a lot since the last time I saw it. In
1082# particular, it implements l-bfgs and has a "do not stop on failed line search"
1083# option. And it probably supports float32, although a skim of the docs suggests
1084# it does not work well. => See also optimistix.
1086# TODO reimplement the timing system with host_callback.id_tap. It should
1087# preserve the order because id_tap takes inputs and outputs. I must take care
1088# to make all callbacks happen at runtime instead of having some of them at
1089# compile time. I tried once but failed. Currently host_callback is
1090# experimental, maybe wait until it isn't. => I think it fails because it's
1091# asynchronous and there is only one device. Maybe host_callback.call would
1092# work? => I think they are developing something like my token machinery.
1094# TODO dictionary argument jitkw, arguments passed to jax.jit?
1096# TODO parameter float32: bool to use short float type. I think that scipy's
1097# optimize may break down with short floats with default options, I hope that
1098# changing termination tolerances does the trick.
1100# TODO make separate_jac a parameter
1102# TODO add options in _CountCalls to track inputs and/or outputs to some maximum
1103# buffer length, activate it if the method (after applying user options,
1104# lowercasing, and inferring minimize's default) is l-bfgs-b and the covariance
1105# is minhess or auto, to the order specified in the arguments to l-bfgs-b (after
1106# defaults inference if missing) (add tests in test_fit to check that the
1107# defaults stay as inferred), to be used if l-bfgs-b returns a crooked hessian.
1108# --- Alternative: if covariance = 'auto', it could be appropriate to use fisher
1109# per definition. --- Alternative: add option covariance = 'lbfgs(<order>)' that
1110# does this for any method, although this would require computing the gradients
1111# afterwards if the gradient was not used. These alternatives are not mutually
1112# exclusive.
1114# TODO make a helper function/class method that takes in data transf dependent
1115# on hypers and outputs additional loss (the log jacobian of the appropriate
1116# function with the appropriate sign)