Coverage for src/lsqfitgp/_fit.py: 84%
550 statements
coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_fit.py
2#
3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import re 1feabcd
21import warnings 1feabcd
22import functools 1feabcd
23import time 1feabcd
24import textwrap 1feabcd
25import datetime 1feabcd
27import gvar 1feabcd
28import jax 1feabcd
29from jax import numpy as jnp 1feabcd
30import numpy 1feabcd
31from scipy import optimize 1feabcd
32from jax import tree_util 1feabcd
34from . import _GP 1feabcd
35from . import _linalg 1feabcd
36from . import _jaxext 1feabcd
37from . import _gvarext 1feabcd
38from . import _array 1feabcd
40# TODO the following token_ thing functionality may be provided by jax in the
41# future, follow the developments
43@functools.singledispatch 1feabcd
44def token_getter(x): 1feabcd
45 return x
47@functools.singledispatch 1feabcd
48def token_setter(x, token): 1feabcd
49 return token
51@token_getter.register(jnp.ndarray) 1feabcd
52@token_getter.register(numpy.ndarray) 1feabcd
53def _(x): 1feabcd
54 return x[x.ndim * (0,)] if x.size else x 1feabcd
56@token_setter.register(jnp.ndarray) 1feabcd
57@token_setter.register(numpy.ndarray) 1feabcd
58def _(x, token): 1feabcd
59 x = jnp.asarray(x) 1feabcd
60 return x.at[x.ndim * (0,)].set(token) if x.size else token 1feabcd
62def token_map_leaf(func, x): 1feabcd
63 if isinstance(x, (jnp.ndarray, numpy.ndarray)): 63 ↛ 74line 63 didn't jump to line 74 because the condition on line 63 was always true1feabcd
64 token = token_getter(x) 1feabcd
65 @jax.custom_jvp 1feabcd
66 def jaxfunc(token): 1feabcd
67 return jax.pure_callback(func, token, token, vectorized=True) 1feabcd
68 @jaxfunc.defjvp 1feabcd
69 def _(p, t): 1feabcd
70 return (jaxfunc(*p), *t) 1feabcd
71 token = jaxfunc(token) 1feabcd
72 return token_setter(x, token) 1feabcd
73 else:
74 token = token_getter(x)
75 token = func(token)
76 return token_setter(x, token)
78def token_map(func, x): 1feabcd
79 return tree_util.tree_map(lambda x: token_map_leaf(func, x), x) 1feabcd
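The token machinery above threads one scalar element of an array (or an arbitrary pytree leaf) through `jax.pure_callback`, so a host-side callback runs at execution time and in data-flow order even under `jax.jit`, while the custom JVP keeps it transparent to differentiation. A minimal usage sketch (illustrative only; `stamp` and `objective` are made-up names, not part of the module):

import time
import jax
from jax import numpy as jnp

def stamp(token):
    # host-side side effect; must pass the token through unchanged
    print('checkpoint at', time.perf_counter())
    return token

@jax.jit
def objective(x):
    x = token_map(stamp, x)      # runs before the computation below
    y = jnp.sum(x ** 2)
    return token_map(stamp, y)   # runs after it, because y depends on x

objective(jnp.arange(3.))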
81class Logger: 1feabcd
82 """ Class to manage a log. Can be used as superclass. Each line of the log
83 has a verbosity level (an integer >= 0) and is printed only if this level is
84 below a threshold. All lines are saved and the log can be retrieved. """
86 def __init__(self, target_verbosity=0): 1feabcd
87 """ set the threshold used to exclude log lines """
88 self._verbosity = target_verbosity 1feabcd
89 self._loggedlines = [] 1feabcd
91 def _indent(self, text, level=0): 1feabcd
92 """ indent a text by provided level or by global current level """
93 level = max(0, level + self.loglevel._level) 1eabcd
94 prefix = 4 * level * ' ' 1eabcd
95 return textwrap.indent(text, prefix) 1eabcd
97 def _select(self, verbosity, target_verbosity=None): 1feabcd
98 if target_verbosity is None: 98 ↛ 100line 98 didn't jump to line 100 because the condition on line 98 was always true1feabcd
99 target_verbosity = self._verbosity 1feabcd
100 if isinstance(verbosity, int): 1feabcd
101 return target_verbosity >= verbosity 1feabcd
102 else:
103 return target_verbosity in verbosity 1feabcd
105 def log(self, message, verbosity=1, *, level=0): 1feabcd
106 """
107 Print and record a message.
109 Parameters
110 ----------
111 message : str
112 The message to print. A newline is added unconditionally.
113 verbosity : int or set, default 1
114 The verbosity level(s) at which the message is printed. If an
115 integer, it's printed at all levels >= that integer. If a set, at
116 the specified levels.
117 level : int, default 0
118 The indentation level of the message.
119 """
120 if self._select(verbosity): 1feabcd
121 print(self._indent(message, level)) 1eabcd
122 self._loggedlines.append((message, verbosity, level + self.loglevel._level)) 1feabcd
124 def getlog(self, target_verbosity=None, *, base_level=0): 1feabcd
125 """ return all logged line as a single string """
126 return '\n'.join(
127 self._indent(message, base_level + level)
128 for message, verbosity, level in self._loggedlines
129 if self._select(verbosity, target_verbosity)
130 )
132 class _LogLevel: 1feabcd
133 """ shared context manager to indent messages """
135 _level = 0 1feabcd
137 @classmethod 1feabcd
138 def __enter__(cls): 1feabcd
139 cls._level += 1 1feabcd
141 @classmethod 1feabcd
142 def __exit__(cls, *_): 1feabcd
143 cls._level -= 1 1feabcd
145 loglevel = _LogLevel() 1feabcd
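A small usage sketch of the Logger protocol defined above (hypothetical messages): each message carries a verbosity and, through the shared `loglevel` context manager, an indentation level; every message is recorded, but only those whose verbosity does not exceed the threshold are printed.

log = Logger(target_verbosity=2)
log.log('fit started')                # verbosity 1 <= 2: printed
with log.loglevel:
    log.log('iteration details', 2)   # printed, indented by one level
    log.log('per-call trace', 3)      # recorded but not printed
print(log.getlog(3))                  # re-render including the verbosity-3 line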
147class empbayes_fit(Logger): 1feabcd
149 SEPARATE_JAC = False 1feabcd
151 def __init__( 1feabcd
152 self,
153 hyperprior,
154 gpfactory,
155 data,
156 *,
157 raises=True,
158 minkw={},
159 gpfactorykw={},
160 jit=True,
161 method='gradient',
162 initial='priormean',
163 verbosity=0,
164 covariance='auto',
165 fix=None,
166 mlkw={},
167 forward=False,
168 additional_loss=None,
169 ):
170 """
172 Maximum a posteriori fit.
174 Maximizes the marginal likelihood of the data with a Gaussian process
175 model that depends on hyperparameters, multiplied by a prior on the
176 hyperparameters.
178 Parameters
179 ----------
180 hyperprior : scalar, array or dictionary of scalars/arrays
181 A collection of gvars representing the prior for the
182 hyperparameters.
183 gpfactory : callable
184 A function with signature gpfactory(hyperparams) -> GP object. The
185 argument ``hyperparams`` has the same structure as the
186 `empbayes_fit` argument ``hyperprior``. gpfactory must be
187 JAX-friendly, i.e., use `jax.numpy` and `jax.scipy` instead of plain
188 `numpy`/`scipy` and avoid assignments to arrays.
189 data : dict, tuple or callable
190 Dictionary of data that is passed to `GP.marginal_likelihood` on
191 the GP object returned by ``gpfactory``. If a tuple, it contains the
192 first two arguments to `GP.marginal_likelihood`. If a callable, it
193 is called with the same arguments as ``gpfactory`` and must return
194 the argument(s) for `GP.marginal_likelihood`.
195 raises : bool, optional
196 If True (default), raise an error when the minimization fails.
197 Otherwise, use the last point of the minimization as result.
198 minkw : dict, optional
199 Keyword arguments passed to `scipy.optimize.minimize`, overwriting
200 the values specified by `empbayes_fit`.
201 gpfactorykw : dict, optional
202 Keyword arguments passed to ``gpfactory``, and also to ``data`` if
203 it is a callable. If ``jit``, ``gpfactorykw`` crosses a `jax.jit`
204 boundary, so it must contain objects understandable by `jax`.
205 jit : bool
206 If True (default), use `jax.jit` to compile the minimization target.
207 method : str
208 Minimization strategy. Options:
210 'nograd'
211 Use a gradient-free method.
212 'gradient' (default)
213 Use a gradient-only method.
214 'fisher'
215 Use a Newton method with the Fisher information matrix plus
216 the hyperprior precision matrix.
217 initial : str, scalar, array or dictionary of scalars/arrays
218 Starting point for the minimization, matching the format of
219 ``hyperprior``, or one of the following options:
221 'priormean' (default)
222 Start from the hyperprior mean.
223 'priorsample'
224 Take a random sample from the hyperprior.
225 verbosity : int
226 An integer indicating how much information is printed on the
227 terminal:
229 0 (default)
230 No logging.
231 1
232 Minimal report.
233 2
234 Detailed report.
235 3
236 Log each iteration.
237 4
238 More detailed iteration log.
239 5
240 Print the current parameter values at each iteration.
241 covariance : str
242 Method to estimate the posterior covariance matrix of the
243 hyperparameters:
245 'fisher'
246 Use the Fisher information in the MAP, plus the prior precision,
247 as precision matrix.
248 'minhess'
249 Use the hessian estimate of the minimizer as precision matrix.
250 'none'
251 Do not estimate the covariance matrix.
252 'auto' (default)
253 ``'minhess'`` if applicable, ``'none'`` otherwise.
254 fix : scalar, array or dictionary of scalars/arrays
255 A set of booleans, with the same format as ``hyperprior``,
256 indicating which hyperparameters are kept fixed to their initial
257 value. Scalars and arrays are broadcasted to the shape of
258 ``hyperprior``. If a dictionary, missing keys are treated as False.
259 mlkw : dict
260 Additional arguments passed to `GP.marginal_likelihood`.
261 forward : bool, default False
262 Use forward instead of backward derivatives. Typically, forward is
263 faster with a small number of parameters.
264 additional_loss : callable, optional
265 A function with signature ``additional_loss(hyperparams) -> float``
266 which is added to the minus log marginal posterior of the
267 hyperparameters.
269 Attributes
270 ----------
271 p : scalar, array or dictionary of scalars/arrays
272 A collection of gvars representing the hyperparameters that
273 maximize their posterior. These gvars do not track correlations
274 with the hyperprior or the data.
275 prior : scalar, array or dictionary of scalars/arrays
276 A copy of the hyperprior.
277 initial : scalar, array or dictionary of scalars/arrays
278 Starting point of the minimization, with the same format as ``p``.
279 fix : scalar, array or dictionary of scalars/arrays
280 A set of booleans, with the same format as ``p``, indicating which
281 parameters were kept fixed to the values in ``initial``.
282 pmean : scalar, array or dictionary of scalars/arrays
283 Mean of ``p``.
284 pcov : scalar, array or dictionary of scalars/arrays
285 Covariance matrix of ``p``.
286 minresult : scipy.optimize.OptimizeResult
287 The result object returned by `scipy.optimize.minimize`.
288 minargs : dict
289 The arguments passed to `scipy.optimize.minimize`.
290 gpfactory : callable
291 The ``gpfactory`` argument.
292 gpfactorykw : dict
293 The ``gpfactorykw`` argument.
294 data : dict, tuple or callable
295 The ``data`` argument.
297 Raises
298 ------
299 RuntimeError
300 The minimization failed and ``raises`` is True.
302 """
304 Logger.__init__(self, verbosity) 1feabcd
305 del verbosity 1feabcd
306 self.log('**** call lsqfitgp.empbayes_fit ****') 1feabcd
308 assert callable(gpfactory) 1feabcd
310 # analyze the hyperprior
311 hpinitial, hpunflat = self._parse_hyperprior(hyperprior, initial, fix) 1feabcd
312 del hyperprior, initial, fix 1feabcd
314 # analyze data
315 data, cachedargs = self._parse_data(data) 1feabcd
317 # define functions
318 timer, functions = self._prepare_functions( 1feabcd
319 gpfactory=gpfactory, gpfactorykw=gpfactorykw, data=data,
320 cachedargs=cachedargs, hpunflat=hpunflat, mlkw=mlkw, jit=jit,
321 forward=forward, additional_loss=additional_loss,
322 )
323 del gpfactory, gpfactorykw, data, cachedargs, mlkw, forward, additional_loss 1feabcd
325 # prepare minimizer arguments
326 minargs = self._prepare_minargs(method, functions, hpinitial) 1feabcd
328 # set up callback to time and log iterations
329 callback = self._Callback(self, functions, timer, hpunflat) 1feabcd
330 minargs.update(callback=callback) 1feabcd
332 # check invalid argument before running minimizer
333 if not covariance in ('auto', 'fisher', 'minhess', 'none'): 333 ↛ 334line 333 didn't jump to line 334 because the condition on line 333 was never true1feabcd
334 raise KeyError(covariance)
336 # add user arguments and minimize
337 minargs.update(minkw) 1feabcd
338 self.log(f'minimizer method {minargs["method"]!r}', 2) 1feabcd
339 total = time.perf_counter() 1feabcd
340 result = optimize.minimize(**minargs) 1feabcd
342 # check the minimization was successful
343 self._check_success(result, raises) 1feabcd
345 # compute posterior covariance of the hyperparameters
346 cov = self._posterior_covariance(method, covariance, result, functions['fisher']) 1feabcd
348 # log total timings and function calls
349 total = time.perf_counter() - total 1feabcd
350 self._log_totals(total, timer, callback, jit, functions) 1feabcd
352 # join posterior mean and covariance matrix
353 uresult = gvar.gvar(result.x, cov) 1feabcd
355 # set attributes
356 self.p = gvar.gvar(hpunflat(uresult)) 1feabcd
357 self.pmean = gvar.mean(self.p) 1feabcd
358 self.pcov = gvar.evalcov(self.p) 1feabcd
359 self.minresult = result 1feabcd
360 self.minargs = minargs 1feabcd
362 # tabulate hyperparameter prior and posterior
363 if self._verbosity >= 2: 1feabcd
364 self.log(_gvarext.tabulate_together( 1eabcd
365 self.prior, self.p,
366 headers=['param', 'prior', 'posterior'],
367 )) # TODO replace tabulate_together with something more flexible I
368 # can use for the callback as well. Maybe import TextMatrix from
369 # miscpy.
370 # TODO print the transformed parameters
372 self.log('**** exit lsqfitgp.empbayes_fit ****') 1feabcd
374 class _CountCalls: 1feabcd
375 """ wrap a callable to count calls """
377 def __init__(self, func): 1feabcd
378 self._func = func 1feabcd
379 self._total = 0 1feabcd
380 self._partial = 0 1feabcd
381 functools.update_wrapper(self, func) 1feabcd
383 def __call__(self, *args, **kw): 1feabcd
384 self._total += 1 1feabcd
385 self._partial += 1 1feabcd
386 return self._func(*args, **kw) 1feabcd
388 def partial(self): 1feabcd
389 """ return the partial counter and reset it """
390 result = self._partial 1feabcd
391 self._partial = 0 1feabcd
392 return result 1feabcd
394 def total(self): 1feabcd
395 """ return the total number of calls """
396 return self._total 1feabcd
398 @staticmethod 1feabcd
399 def fmtcalls(method, functions): 1feabcd
400 """
401 format a summary of the number of calls
402 method : str
403 functions: dict[str, _CountCalls]
404 """
405 def counts(): 1feabcd
406 for name, func in functions.items(): 1feabcd
407 if count := getattr(func, method)(): 1feabcd
408 yield f'{name} {count}' 1feabcd
409 return ', '.join(counts()) 1feabcd
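For reference, a tiny sketch of the counter protocol (made-up calls): `partial()` reports and resets the per-iteration count, `total()` the count over the whole run.

counted = empbayes_fit._CountCalls(lambda p: 2 * p)
counted(1.)
counted(2.)
assert counted.partial() == 2   # calls since the last partial(), which also resets it
assert counted.partial() == 0
assert counted.total() == 2     # calls since creation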
411 class _Timer: 1feabcd
412 """ object to time likelihood computations """
414 def __init__(self): 1feabcd
415 self.totals = {} 1feabcd
416 self.partials = {} 1feabcd
417 self._last_start = False 1feabcd
419 def start(self, token): 1feabcd
420 return token_map(self._start, token) 1feabcd
422 def _start(self, token): 1feabcd
423 self.stamp = time.perf_counter() 1abcd
424 self.counter = 0 1abcd
425 assert not self._last_start # forbid consecutive start() calls 1abcd
426 self._last_start = True 1abcd
427 return token 1abcd
429 def reset(self): 1feabcd
430 self.partials = {} 1feabcd
432 def partial(self, token): 1feabcd
433 return token_map(self._partial, token) 1feabcd
435 def _partial(self, token): 1feabcd
436 now = time.perf_counter() 1abcd
437 delta = now - self.stamp 1abcd
438 self.partials[self.counter] = self.partials.get(self.counter, 0) + delta 1abcd
439 self.totals[self.counter] = self.totals.get(self.counter, 0) + delta 1abcd
440 self.stamp = now 1abcd
441 self.counter += 1 1abcd
442 self._last_start = False 1abcd
443 return token 1abcd
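The timer is driven through the token machinery: `start` stamps the clock when the hyperparameters enter `make_decomp` below, and each `partial` closes a stage, accumulating its duration under an integer counter. With the calls made in `_prepare_functions`, counter 0 covers GP construction and covariance evaluation, 1 the decomposition, and 2 the likelihood evaluation, which is what the 'gp&cov'/'decomp'/'likelihood' entries reported by `_log_totals` refer to. A standalone sketch of the bookkeeping (hypothetical stages; plain floats bypass the jax callback path):

timer = empbayes_fit._Timer()
x = timer.start(1.0)     # stamp the clock and pass the value through
# ... stage 0 work ...
x = timer.partial(x)     # totals[0] += time since start
# ... stage 1 work ...
x = timer.partial(x)     # totals[1] += time since the previous partial
print(timer.totals)      # {0: ..., 1: ...}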
445 def _parse_hyperprior(self, hyperprior, initial, fix): 1feabcd
447 # check fix against hyperprior and fill missing values
448 hyperprior = self._copyasarrayorbufferdict(hyperprior) 1feabcd
449 self._check_no_redundant_keys(hyperprior) 1feabcd
450 fix = self._parse_fix(hyperprior, fix) 1feabcd
451 flatfix = self._flatview(fix) 1feabcd
453 # extract distribution of free hyperparameters
454 flathp = self._flatview(hyperprior) 1feabcd
455 freehp = flathp[~flatfix] 1feabcd
456 mean = gvar.mean(freehp) 1feabcd
457 cov = gvar.evalcov(freehp) # TODO use evalcov_blocks 1feabcd
458 dec = _linalg.Chol(cov) 1feabcd
459 assert dec.n == freehp.size 1feabcd
460 self.log(f'{freehp.size}/{flathp.size} free hyperparameters', 2) 1feabcd
462 # determine starting point for minimization
463 initial = self._parse_initial(hyperprior, initial, dec) 1feabcd
464 flatinitial = self._flatview(initial) 1feabcd
465 x0 = dec.pinv_correlate(flatinitial[~flatfix] - mean) 1feabcd
466 # TODO for initial = 'priormean', x0 is zero, skip decorrelate
467 # for initial = 'priorsample', x0 is iid normal, but I have to sync
468 # it with the user-exposed unflattened initial in _parse_initial
470 # make function to correlate, add fixed values, and reshape to original
471 # format
472 fixed_indices, = jnp.nonzero(flatfix) 1feabcd
473 unfixed_indices, = jnp.nonzero(~flatfix) 1feabcd
474 fixed_values = jnp.asarray(flatinitial[flatfix]) 1feabcd
475 def unflat(x): 1feabcd
476 assert x.ndim == 1 1feabcd
477 if x.dtype == object: 1feabcd
478 jac, indices = _gvarext.jacobian(x) 1feabcd
479 xmean = mean + dec.correlate(gvar.mean(x)) 1feabcd
480 xjac = dec.correlate(jac) 1feabcd
481 x = _gvarext.from_jacobian(xmean, xjac, indices) 1feabcd
482 y = numpy.empty(flatfix.size, x.dtype) 1feabcd
483 numpy.put(y, unfixed_indices, x) 1feabcd
484 numpy.put(y, fixed_indices, fixed_values) 1feabcd
485 else:
486 x = mean + dec.correlate(x) 1feabcd
487 y = jnp.empty(flatfix.size, x.dtype) 1feabcd
488 y = y.at[unfixed_indices].set(x) 1feabcd
489 y = y.at[fixed_indices].set(fixed_values) 1feabcd
490 return self._unflatview(y, hyperprior) 1feabcd
492 self.prior = hyperprior 1feabcd
493 return x0, unflat 1feabcd
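`_parse_hyperprior` whitens the free hyperparameters: the minimizer works in coordinates that are standard normal under the prior, and `unflat` maps them back through the Cholesky factor of the prior covariance (re-inserting the fixed values) into the user's format. A plain-numpy illustration of the transformation pair (made-up numbers; it assumes `Chol.correlate` multiplies by the Cholesky factor and `pinv_correlate` applies its pseudoinverse, as used above):

import numpy as np

mean = np.array([1.0, -2.0])
cov = np.array([[1.0, 0.5],
                [0.5, 2.0]])
L = np.linalg.cholesky(cov)

x = np.array([0.3, -0.1])                # whitened coordinates, N(0, I) a priori
hp = mean + L @ x                        # the "correlate" step performed by unflat()
x_back = np.linalg.solve(L, hp - mean)   # the "pinv_correlate" step that builds x0
assert np.allclose(x, x_back)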
495 @staticmethod 1feabcd
496 def _check_no_redundant_keys(hyperprior): 1feabcd
497 if not hasattr(hyperprior, 'keys'): 1feabcd
498 return 1abcd
499 for k in hyperprior: 1feabcd
500 m = hyperprior.extension_pattern.match(k) 1feabcd
501 if m and m.group(1) in hyperprior.invfcn: 1feabcd
502 altk = m.group(2) 1feabcd
503 if altk in hyperprior: 503 ↛ 504line 503 didn't jump to line 504 because the condition on line 503 was never true1feabcd
504 raise ValueError(f'duplicate keys {altk!r} and {k!r} in hyperprior')
506 def _parse_fix(self, hyperprior, fix): 1feabcd
508 if fix is None: 508 ↛ 514line 508 didn't jump to line 514 because the condition on line 508 was always true1feabcd
509 if hasattr(hyperprior, 'keys'): 1feabcd
510 fix = gvar.BufferDict(hyperprior, buf=numpy.zeros(hyperprior.size, bool)) 1feabcd
511 else:
512 fix = numpy.zeros(hyperprior.shape, bool) 1abcd
513 else:
514 fix = self._copyasarrayorbufferdict(fix)
515 if hasattr(fix, 'keys'):
516 assert hasattr(hyperprior, 'keys'), 'fix is dictionary but hyperprior is array'
517 assert all(hyperprior.has_dictkey(k) for k in fix), 'some keys in fix are missing in hyperprior'
518 newfix = {}
519 for k, v in hyperprior.items():
520 key = None
521 m = hyperprior.extension_pattern.match(k)
522 if m and m.group(1) in hyperprior.invfcn:
523 altk = m.group(2)
524 if altk in fix:
525 assert k not in fix, f'duplicate keys {k!r} and {altk!r} in fix'
526 key = altk
527 if key is None and k in fix:
528 key = k
529 if key is None:
530 elem = numpy.zeros(v.shape, bool)
531 else:
532 elem = numpy.broadcast_to(fix[key], v.shape)
533 newfix[k] = elem
534 fix = gvar.BufferDict(newfix, dtype=bool)
535 else:
536 assert not hasattr(hyperprior, 'keys'), 'fix is array but hyperprior is dictionary'
537 fix = numpy.broadcast_to(fix, hyperprior.shape).astype(bool)
539 self.fix = fix 1feabcd
540 return fix 1feabcd
542 def _parse_initial(self, hyperprior, initial, dec): 1feabcd
544 if not isinstance(initial, str): 544 ↛ 545line 544 didn't jump to line 545 because the condition on line 544 was never true1feabcd
545 self.log('start from provided point', 2)
546 initial = self._copyasarrayorbufferdict(initial)
547 if hasattr(hyperprior, 'keys'):
548 assert hasattr(initial, 'keys'), 'hyperprior is dictionary but initial is array'
549 assert set(hyperprior.keys()) == set(initial.keys())
550 assert all(hyperprior[k].shape == initial[k].shape for k in hyperprior)
551 else:
552 assert not hasattr(initial, 'keys'), 'hyperprior is array but initial is dictionary'
553 assert hyperprior.shape == initial.shape
555 elif initial == 'priormean': 555 ↛ 559line 555 didn't jump to line 559 because the condition on line 555 was always true1feabcd
556 self.log('start from prior mean', 2) 1feabcd
557 initial = gvar.mean(hyperprior) 1feabcd
559 elif initial == 'priorsample':
560 self.log('start from a random sample from the prior', 2)
561 if dec.n < hyperprior.size:
562 flathp = self._flatview(hyperprior)
563 cov = gvar.evalcov(flathp) # TODO use evalcov_blocks
564 fulldec = _linalg.Chol(cov)
565 else:
566 fulldec = dec
567 iid = numpy.random.randn(fulldec.m)
568 flatinitial = numpy.asarray(fulldec.correlate(iid))
569 initial = self._unflatview(flatinitial, hyperprior)
571 else:
572 raise KeyError(initial)
574 self.initial = initial 1feabcd
575 return initial 1feabcd
577 def _parse_data(self, data): 1feabcd
579 self.data = data 1feabcd
580 if isinstance(data, tuple) and len(data) == 1: 1feabcd
581 data, = data 1abcd
583 if callable(data): 1feabcd
584 self.log('data is callable', 2) 1eabcd
585 cachedargs = None 1eabcd
586 elif isinstance(data, tuple): 1feabcd
587 self.log('data errors provided separately', 2) 1abcd
588 assert len(data) == 2 1abcd
589 cachedargs = data 1abcd
590 elif (gdata := self._copyasarrayorbufferdict(data)).dtype == object: 1feabcd
591 self.log('data has errors as gvars', 2) 1eabcd
592 data = gvar.gvar(gdata) 1eabcd
593 # convert to gvar because non-gvars in the array would upset
594 # gvar.mean and gvar.evalcov
595 cachedargs = (gvar.mean(data), gvar.evalcov(data)) 1eabcd
596 else:
597 self.log('data has no errors', 2) 1feabcd
598 cachedargs = (data,) 1feabcd
600 return data, cachedargs 1feabcd
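For reference, two of the accepted data formats in code (made-up arrays; the 'data' key and the 'offset' hyperparameter are hypothetical and must match what `gpfactory` defines); the third form, a tuple, passes the mean and covariance positionally to `GP.marginal_likelihood` as described in the docstring.

import numpy as np
import gvar

y = np.sin(np.linspace(0, 10, 30))

# errors carried by gvars:
data = {'data': gvar.gvar(y, np.full(y.size, 0.1))}

# or a callable, re-evaluated at each iteration with the current hyperparameters:
def data(hp, **gpfactorykw):
    return {'data': y - hp['offset']}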
602 def _prepare_functions(self, *, gpfactory, gpfactorykw, data, cachedargs, 1feabcd
603 hpunflat, mlkw, jit, forward, additional_loss):
605 timer = self._Timer() 1feabcd
606 firstcall = [None] 1feabcd
608 def make_decomp(p, **kw): 1feabcd
609 """ decomposition of the prior covariance and data """
611 # start timer and convert hypers to user format
612 p = timer.start(p) 1feabcd
613 hp = hpunflat(p) 1feabcd
615 # create GP object
616 gp = gpfactory(hp, **kw) 1feabcd
617 assert isinstance(gp, _GP.GP) 1feabcd
619 # extract data
620 if cachedargs: 1feabcd
621 args = cachedargs 1feabcd
622 else:
623 args = data(hp, **kw) 1eabcd
624 if not isinstance(args, tuple): 1eabcd
625 args = (args,) 1eabcd
627 # decompose covariance matrix and flatten data
628 decomp, r = gp._prior_decomp(*args, covtransf=timer.partial, **mlkw) 1feabcd
629 r = r.astype(float) # int data upsets jax 1feabcd
631 # log number of datapoints
632 if firstcall: 1feabcd
633 # it is convenient to do this here because the data is flattened.
634 # it works under jit since the first call is the tracing one
635 firstcall.pop() 1feabcd
636 xdtype = gp._get_x_dtype() 1feabcd
637 nd = '?' if xdtype is None else _array._nd(xdtype) 1feabcd
638 self.log(f'{r.size} datapoints, {nd} covariates') 1feabcd
640 # compute user loss
641 if additional_loss is None: 1feabcd
642 loss = 0. 1feabcd
643 else:
644 loss = additional_loss(hp) 1eabcd
646 # split timer and return decomposition
647 return timer.partial(decomp), r, loss 1feabcd
648 # TODO what's the correct way of checkpointing r?
650 # define wrapper to collect call stats, pass user args, compile
651 def wrap(func): 1feabcd
652 if jit: 652 ↛ 654line 652 didn't jump to line 654 because the condition on line 652 was always true1feabcd
653 func = jax.jit(func) 1feabcd
654 func = functools.partial(func, **gpfactorykw) 1feabcd
655 return self._CountCalls(func) 1feabcd
656 if jit: 656 ↛ 660line 656 didn't jump to line 660 because the condition on line 656 was always true1feabcd
657 self.log('compile functions with jax jit', 2) 1feabcd
659 # log derivation method
660 modename = 'forward' if forward else 'reverse' 1feabcd
661 self.log(f'{modename}-mode autodiff (if used)', 2) 1feabcd
663 # TODO time the derivatives separately => maybe I need a custom
664 # derivative rule for timer token acknowledgement?
666 def prior(p): 1feabcd
667 # the marginal prior of the hyperparameters is a Normal with
668 # identity covariance matrix because p is transformed to make it so
669 return 1/2 * (len(p) * jnp.log(2 * jnp.pi) + p @ p) 1feabcd
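`prior` is minus the log density of a standard normal in the whitened coordinates; a quick cross-check against scipy (illustrative, with made-up numbers):

import numpy as np
from scipy import stats

p = np.array([0.3, -1.2, 0.7])
manual = 0.5 * (p.size * np.log(2 * np.pi) + p @ p)
reference = -stats.multivariate_normal(np.zeros(p.size)).logpdf(p)
assert np.allclose(manual, reference)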
671 def grad_prior(p): 1feabcd
672 return p 1feabcd
674 def fisher_prior(p): 1feabcd
675 return jnp.eye(len(p)) 1abcd
677 @wrap 1feabcd
678 def fun(p, **kw): 1feabcd
679 """ minus log marginal posterior of the hyperparameters (not
680 normalized) """
681 decomp, r, loss = make_decomp(p, **kw) 1abcd
682 cond, _, _, _, _ = decomp.minus_log_normal_density(r, value=True) 1abcd
683 post = cond + prior(p) + loss 1abcd
684 # TODO what's the correct way of checkpointing prior and loss?
685 return timer.partial(post) 1abcd
687 def make_gradfwd_fisher_args(p, **kw): 1feabcd
688 def make_decomp_tee(p): 1eabcd
689 decomp, r, loss = make_decomp(p, **kw) 1eabcd
690 return (decomp.matrix(), r, loss), (decomp, r, loss) 1eabcd
691 (dK, dr, grad_loss), (decomp, r, loss) = jax.jacfwd(make_decomp_tee, has_aux=True)(p) 1eabcd
692 lkw = dict(dK=dK, dr=dr) 1eabcd
693 return decomp, r, lkw, loss, grad_loss 1eabcd
695 def make_gradrev_args(p, **kw): 1feabcd
696 def make_decomp_loss(p): 1feabcd
697 def make_decomp_r(p): 1feabcd
698 def make_decomp_K(p): 1feabcd
699 decomp, r, loss = make_decomp(p, **kw) 1feabcd
700 return decomp.matrix(), (decomp, r, loss) 1feabcd
701 _, dK_vjp, (decomp, r, loss) = jax.vjp(make_decomp_K, p, has_aux=True) 1feabcd
702 return r, (decomp, r, dK_vjp, loss) 1feabcd
703 _, dr_vjp, (decomp, r, dK_vjp, loss) = jax.vjp(make_decomp_r, p, has_aux=True) 1feabcd
704 return loss, (decomp, r, dK_vjp, dr_vjp, loss) 1feabcd
705 grad_loss, (decomp, r, dK_vjp, dr_vjp, loss) = jax.grad(make_decomp_loss, has_aux=True)(p) 1feabcd
706 unpack = lambda f: lambda x: f(x)[0] 1feabcd
707 dK_vjp = unpack(dK_vjp) 1feabcd
708 dr_vjp = unpack(dr_vjp) 1feabcd
709 lkw = dict(dK_vjp=dK_vjp, dr_vjp=dr_vjp) 1feabcd
710 return decomp, r, lkw, loss, grad_loss 1feabcd
712 def make_jac_args(p, **kw): 1feabcd
713 if forward: 1feabcd
714 out = make_gradfwd_fisher_args(p, **kw) 1eabcd
715 out[2].update(gradfwd=True) # out[2] is lkw 1eabcd
716 else:
717 out = make_gradrev_args(p, **kw) 1feabcd
718 out[2].update(gradrev=True) 1feabcd
719 return out 1feabcd
721 @wrap 1feabcd
722 def fun_and_jac(p, **kw): 1feabcd
723 """ fun and its gradient """
724 decomp, r, lkw, loss, grad_loss = make_jac_args(p, **kw) 1feabcd
725 cond, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, value=True, **lkw) 1feabcd
726 post = cond + prior(p) + loss 1feabcd
727 grad_cond = gradfwd if forward else gradrev 1feabcd
728 grad_post = grad_cond + grad_prior(p) + grad_loss 1feabcd
729 return timer.partial((post, grad_post)) 1feabcd
731 @wrap 1feabcd
732 def jac(p, **kw): 1feabcd
733 """ gradient of fun """
734 decomp, r, lkw, _, grad_loss = make_jac_args(p, **kw)
735 _, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, **lkw)
736 grad_cond = gradfwd if forward else gradrev
737 grad_post = grad_cond + grad_prior(p) + grad_loss
738 return timer.partial(grad_post)
740 @wrap 1feabcd
741 def fisher(p, **kw): 1feabcd
742 """ fisher matrix """
743 if additional_loss is not None: 1abcd
744 raise NotImplementedError( 1abcd
745 'Fisher matrix not implemented with additional_loss. It '
746 'is possible but I did not prioritize it. If you need it, '
747 'open an issue on github.')
748 decomp, r, lkw, _, _ = make_gradfwd_fisher_args(p, **kw) 1abcd
749 _, _, _, fisher_cond, _ = decomp.minus_log_normal_density(r, fisher=True, **lkw) 1abcd
750 fisher_post = fisher_cond + fisher_prior(p) 1abcd
751 return timer.partial(fisher_post) 1abcd
753 # set attributes
754 self.gpfactory = gpfactory 1feabcd
755 self.gpfactorykw = gpfactorykw 1feabcd
757 return timer, { 1feabcd
758 'fun': fun,
759 'jac': jac,
760 'fun&jac': fun_and_jac,
761 'fisher': fisher,
762 }
764 def _prepare_minargs(self, method, functions, hpinitial): 1feabcd
765 minargs = dict(fun=functions['fun&jac'], jac=True, x0=hpinitial) 1feabcd
766 if self.SEPARATE_JAC: 766 ↛ 767line 766 didn't jump to line 767 because the condition on line 766 was never true1feabcd
767 minargs.update(fun=functions['fun'], jac=functions['jac'])
768 if method == 'nograd': 1feabcd
769 minargs.update(fun=functions['fun'], jac=None, method='nelder-mead') 1abcd
770 elif method == 'gradient': 1feabcd
771 minargs.update(method='bfgs') 1feabcd
772 elif method == 'fisher': 1abcd
773 minargs.update(hess=functions['fisher'], method='dogleg') 1abcd
774 # dogleg requires positive definiteness, fisher is p.s.d.
775 # trust-constr has more options, but it seems to be slower than
776 # dogleg, so I keep dogleg as default
777 else:
778 raise KeyError(method) 1abcd
779 self.log(f'method {method!r}', 2) 1feabcd
780 return minargs 1feabcd
782 # TODO add method with fisher matvec instead of fisher matrix
784 def _log_totals(self, total, timer, callback, jit, functions): 1feabcd
785 times = { 1feabcd
786 'gp&cov': timer.totals[0],
787 'decomp': timer.totals[1],
788 'likelihood': timer.totals[2],
789 'jit': None, # set now and delete later to keep it before 'other'
790 'other': total - sum(timer.totals.values()),
791 }
792 if jit: 792 ↛ 799line 792 didn't jump to line 799 because the condition on line 792 was always true1feabcd
793 overhead = callback.estimate_firstcall_overhead() 1feabcd
794 # TODO this estimation ignores the jit compilation of the function
795 # used to compute the precision matrix, to be precise I should
796 # manually split the jit into compilation + evaluation or hook into
797 # it somehow. Maybe the jit object keeps a compilation wall time
798 # stat?
799 if jit and overhead is not None: 1feabcd
800 times['jit'] = overhead 1feabcd
801 times['other'] -= overhead 1feabcd
802 else:
803 del times['jit'] 1abcd
804 self.log('', 4) 1feabcd
805 calls = self._CountCalls.fmtcalls('total', functions) 1feabcd
806 self.log(f'calls: {calls}') 1feabcd
807 self.log(f'total time: {callback.fmttime(total)}') 1feabcd
808 self.log(f'partials: {callback.fmttimes(times)}', 2) 1feabcd
810 def _check_success(self, result, raises): 1feabcd
811 if result.success: 1feabcd
812 self.log(f'minimization succeeded: {result.message}') 1feabcd
813 else:
814 msg = f'minimization failed: {result.message}' 1fabcd
815 if raises: 1fabcd
816 raise RuntimeError(msg) 1abcd
817 elif self._verbosity == 0: 817 ↛ 820line 817 didn't jump to line 820 because the condition on line 817 was always true1f
818 warnings.warn(msg) 1f
819 else:
820 self.log(msg)
822 def _posterior_covariance(self, method, covariance, minimizer_result, fisher_func): 1feabcd
824 if covariance == 'auto': 824 ↛ 830line 824 didn't jump to line 830 because the condition on line 824 was always true1feabcd
825 if hasattr(minimizer_result, 'hess_inv') or hasattr(minimizer_result, 'hess'): 1feabcd
826 covariance = 'minhess' 1feabcd
827 else:
828 covariance = 'none' 1abcd
830 if covariance == 'fisher': 830 ↛ 831line 830 didn't jump to line 831 because the condition on line 830 was never true1feabcd
831 self.log('use fisher plus prior precision as precision', 2)
832 if method == 'fisher':
833 prec = minimizer_result.hess
834 else:
835 prec = fisher_func(minimizer_result.x)
836 cov = _linalg.Chol(prec).ginv()
838 elif covariance == 'minhess': 1feabcd
839 if hasattr(minimizer_result, 'hess_inv'): 1feabcd
840 hessinv = minimizer_result.hess_inv 1feabcd
841 if isinstance(hessinv, optimize.LbfgsInvHessProduct): 1feabcd
842 self.log(f'convert LBFGS({hessinv.n_corrs}) hessian inverse to BFGS as covariance', 2) 1eabcd
843 cov = self._invhess_lbfgs_to_bfgs(hessinv) 1eabcd
844 # TODO this still gives a too wide cov when the minimization
845 # terminates due to bad linear search, is it because of
846 # dropped updates? This is currently keeping me from setting
847 # l-bfgs-b as default minimization method.
848 elif isinstance(hessinv, numpy.ndarray): 848 ↛ 863line 848 didn't jump to line 863 because the condition on line 848 was always true1feabcd
849 self.log('use minimizer estimate of inverse hessian as covariance', 2) 1feabcd
850 cov = hessinv 1feabcd
851 elif hasattr(minimizer_result, 'hess'): 851 ↛ 855line 851 didn't jump to line 855 because the condition on line 851 was always true1abcd
852 self.log('use minimizer hessian as precision', 2) 1abcd
853 cov = _linalg.Chol(minimizer_result.hess).ginv() 1abcd
854 else:
855 raise RuntimeError('the minimizer did not return an estimate of the hessian')
857 elif covariance == 'none': 857 ↛ 861line 857 didn't jump to line 861 because the condition on line 857 was always true1abcd
858 cov = numpy.full(minimizer_result.x.size, numpy.nan) 1abcd
860 else:
861 raise KeyError(covariance)
863 return cov 1feabcd
865 @staticmethod 1feabcd
866 def _invhess_lbfgs_to_bfgs(lbfgs): 1feabcd
867 bfgs = optimize.BFGS() 1eabcd
868 bfgs.initialize(lbfgs.shape[0], 'inv_hess') 1eabcd
869 for i in range(lbfgs.n_corrs): 1eabcd
870 bfgs.update(lbfgs.sk[i], lbfgs.yk[i]) 1eabcd
871 return bfgs.get_matrix() 1eabcd
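The conversion above replays the L-BFGS memory pairs (sk, yk) into a dense BFGS estimate of the inverse Hessian, so that an `l-bfgs-b` run can also provide a covariance. A standalone sketch on a made-up quadratic problem (not part of the fit):

import numpy as np
from scipy import optimize

res = optimize.minimize(
    lambda x: 0.5 * x @ x, np.ones(4),
    jac=lambda x: x, method='l-bfgs-b',
)
# res.hess_inv is a LbfgsInvHessProduct; densify it
cov = empbayes_fit._invhess_lbfgs_to_bfgs(res.hess_inv)
assert cov.shape == (4, 4)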
873 class _Callback: 1feabcd
874 """ Iteration callback for scipy.optimize.minimize """
876 def __init__(self, this, functions, timer, unflat): 1feabcd
877 self.it = 0 1feabcd
878 self.stamp = time.perf_counter() 1feabcd
879 self.this = this 1feabcd
880 self.functions = functions 1feabcd
881 self.timer = timer 1feabcd
882 self.unflat = unflat 1feabcd
883 self.tail_overhead = 0 1feabcd
884 self.tail_overhead_iter = 0 1feabcd
886 def __call__(self, intermediate_result, arg2=None): 1feabcd
888 if isinstance(intermediate_result, optimize.OptimizeResult): 888 ↛ 889line 888 didn't jump to line 889 because the condition on line 888 was never true1feabcd
889 p = intermediate_result.x
890 elif isinstance(intermediate_result, numpy.ndarray): 890 ↛ 893line 890 didn't jump to line 893 because the condition on line 890 was always true1feabcd
891 p = intermediate_result 1feabcd
892 else:
893 raise TypeError(type(intermediate_result))
895 self.it += 1 1feabcd
896 now = time.perf_counter() 1feabcd
897 duration = now - self.stamp 1feabcd
899 worktime = sum(self.timer.partials.values()) 1feabcd
900 if worktime: 1feabcd
901 overhead = duration - worktime 1feabcd
902 assert overhead >= 0, (duration, worktime) 1feabcd
903 if self.it == 1: 1feabcd
904 self.first_overhead = overhead 1feabcd
905 else:
906 self.tail_overhead_iter += 1 1feabcd
907 self.tail_overhead += overhead 1feabcd
909 # level 3 log
910 calls = self.this._CountCalls.fmtcalls('partial', self.functions) 1feabcd
911 times = self.fmttime(duration) 1feabcd
912 self.this.log(f'iter {self.it}, time: {times}, calls: {calls}', {3}) 1feabcd
914 # level 4 log
915 tot = self.fmttime(duration) 1feabcd
916 if self.timer.partials: 1feabcd
917 times = { 1feabcd
918 'gp&cov': self.timer.partials[0],
919 'dec': self.timer.partials[1],
920 'like': self.timer.partials[2],
921 'other': duration - sum(self.timer.partials.values()),
922 }
923 times = self.fmttimes(times) 1feabcd
924 else:
925 times = 'n/d' 1abcd
926 self.this.log(f'\niteration {self.it}', 4) 1feabcd
927 with self.this.loglevel: 1feabcd
928 self.this.log(f'total time: {tot}', 4) 1feabcd
929 self.this.log(f'partial: {times}', 4) 1feabcd
930 self.this.log(f'calls: {calls}', 4) 1feabcd
932 # level 5 log
933 nicep = self.unflat(p) 1feabcd
934 nicep = self.this._copyasarrayorbufferdict(nicep) 1feabcd
935 with self.this.loglevel: 1feabcd
936 self.this.log(f'parameters = {nicep}', 5) 1feabcd
937 # TODO write a method to format the parameters nicely. => use
938 # gvar.tabulate? => nope, need actual gvars
939 # TODO does this logging add significant overhead?
941 self.stamp = now 1feabcd
942 self.timer.reset() 1feabcd
944 pattern = re.compile( 1feabcd
945 r'((\d+) days, )?(\d{1,2}):(\d\d):(\d\d(\.\d{6})?)')
947 @classmethod 1feabcd
948 def fmttime(cls, seconds): 1feabcd
949 if seconds < 0: 949 ↛ 950line 949 didn't jump to line 950 because the condition on line 949 was never true1feabcd
950 prefix = '-'
951 seconds = -seconds
952 else:
953 prefix = '' 1feabcd
954 return prefix + cls._fmttime_positive(seconds) 1feabcd
956 @classmethod 1feabcd
957 def _fmttime_positive(cls, seconds): 1feabcd
958 td = datetime.timedelta(seconds=seconds) 1feabcd
959 m = cls.pattern.fullmatch(str(td)) 1feabcd
960 _, day, hour, minute, second, _ = m.groups() 1feabcd
961 hour = int(hour) 1feabcd
962 minute = int(minute) 1feabcd
963 second = float(second) 1feabcd
964 if day: 964 ↛ 965line 964 didn't jump to line 965 because the condition on line 964 was never true1feabcd
965 return f'{day.lstrip("0")}d{hour:02d}h'
966 elif hour: 966 ↛ 967line 966 didn't jump to line 967 because the condition on line 966 was never true1feabcd
967 return f'{hour}h{minute:02d}m'
968 elif minute: 1feabcd
969 return f'{minute}m{second:02.0f}s' 1e
970 elif second >= 0.0995: 1feabcd
971 return f'{second:#.2g}'.rstrip('.') + 's' 1feabcd
972 elif second >= 0.0000995: 1feabcd
973 return f'{second * 1e3:#.2g}'.rstrip('.') + 'ms' 1feabcd
974 else:
975 return f'{second * 1e6:.0f}μs' 1feab
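For reference, the formatter above keeps roughly two significant units, with the resolution shrinking as the duration does; for example:

assert empbayes_fit._Callback.fmttime(0.00123) == '1.2ms'
assert empbayes_fit._Callback.fmttime(3.456) == '3.5s'
assert empbayes_fit._Callback.fmttime(75) == '1m15s'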
977 @classmethod 1feabcd
978 def fmttimes(cls, times): 1feabcd
979 """ times = dict label -> seconds """
980 return ', '.join(f'{k} {cls.fmttime(v)}' for k, v in times.items()) 1feabcd
982 def estimate_firstcall_overhead(self): 1feabcd
983 if self.tail_overhead_iter and hasattr(self, 'first_overhead'): 1feabcd
984 typical_overhead = self.tail_overhead / self.tail_overhead_iter 1feabcd
985 return self.first_overhead - typical_overhead 1feabcd
987 @staticmethod 1feabcd
988 def _copyasarrayorbufferdict(x): 1feabcd
989 if hasattr(x, 'keys'): 1feabcd
990 return gvar.BufferDict(x) 1feabcd
991 else:
992 return numpy.array(x) 1abcd
994 @staticmethod 1feabcd
995 def _flatview(x): 1feabcd
996 if hasattr(x, 'reshape'): 1feabcd
997 return x.reshape(-1) 1abcd
998 elif hasattr(x, 'buf'): 1feabcd
999 return x.buf 1feabcd
1000 else: # pragma: no cover
1001 raise NotImplementedError
1003 @staticmethod 1feabcd
1004 def _unflatview(x, original): 1feabcd
1005 if isinstance(original, numpy.ndarray): 1feabcd
1006 # TODO is this never applied to jax arrays?
1007 out = x.reshape(original.shape) 1abcd
1008 # if not out.shape:
1009 # try:
1010 # out = out.item()
1011 # except jax.errors.ConcretizationTypeError:
1012 # pass
1013 return out 1abcd
1014 elif isinstance(original, gvar.BufferDict): 1feabcd
1015 # normally I would do BufferDict(original, buf=x) but it does not
1016 # work with JAX tracers
1017 b = gvar.BufferDict(original) 1feabcd
1018 b._extension = {} 1feabcd
1019 b._buf = x 1feabcd
1020 # b.buf = x does not work because BufferDict checks that the
1021 # array is a numpy array
1022 # TODO maybe make a feature request to gvar to accept array_like
1023 # buf
1024 return b 1feabcd
1025 else: # pragma: no cover
1026 raise NotImplementedError
1029# TODO would it be meaningful to add correlation of the fit result with the data
1030# and hyperprior?
1032# TODO add the second order correction. It probably requires more than the
1033# gradient and inv_hess, but maybe by getting a little help from
1034# marginal_likelihood I can use the least-squares optimized second order
1035# correction on the residuals term and invent something for the logdet term.
1037# TODO it raises very often with "Desired error not necessarily achieved due to
1038 precision loss.". I tried doing a forward grad on the logdet but it does not fix
1039# the problem. I still suspect it's the logdet, maybe the value itself and not
1040# the derivative, because as the matrix changes the regularization can change a
1041# lot the value of the logdet. How do I stabilize it? => scipy's l-bfgs-b seems
1042 # to fail the line search less often
1044# TODO compute the logGBF for the whole fit (see the gpbart code). In its doc,
1045# specify that 1) additional_loss may break the normalization if the user does
1046# not know what they are doing 2) the calculation of the log determinant term
1047# heavily depends on the regularization if the covariance matrix is singular;
1048# this won't happen if there are independent error terms in the model as usual.
1050# TODO empbayes_fit(autoeps=True) tries to double epsabs until the minimization
1051 succeeds, with some maximum number of tries. autoeps=dict(maxrepeat=5,
1052# increasefactor=2, initial=1e-16, startfromzero=True) allows to configure the
1053# algorithm.
1055# TODO empbayes_fit(maxiter=100) sets the maximum number of minimization
1056# iterations. maxiter=dict(iter=100, calls=200, callsperiter=10) allows to
1057# configure it more finely. The calls limits are cumulative on all functions
1058# (need to make a class counter in _CountCalls), I can probably implement them
1059# by returning nan when the limit is surpassed, I hope the minimizer stops
1060# immediately on nan (test this). => Callback can raise StopIteration.
1062# TODO can I approximate the hessian with only function values and no gradient,
1063# i.e., when using nelder-mead? => See Hare (2022), although I would not know
1064# how to apply it properly to the optimization history. Somehow I need to keep
1065# only the "last" iterations.
1067# TODO is there a better algorithm than lbfgs for inaccurate functions? consider
1068# SC-BFGS (https://github.com/frankecurtis/SCBFGS). See Basak (2022). And NonOpt
1069# (https://github.com/frankecurtis/NonOpt).
1071# TODO can I estimate the error on the likelihood with the matrices? It requires
1072# the condition number. Basak (2022) gives wide bounds. I could try an upper
1073# bound and see how it compares to the true error, assuming that the matrix was
1074# as ill-conditioned as possible, i.e., use eps as the lowest eigenvalue, and
1075# gershgorin as the highest one.
1077# TODO look into jaxopt: it has improved a lot since the last time I saw it. In
1078# particular, it implements l-bfgs and has a "do not stop on failed line search"
1079# option. And it probably supports float32, although a skim of the docs suggests
1080# it does not work well. => See also optimistix.
1082# TODO reimplement the timing system with host_callback.id_tap. It should
1083# preserve the order because id_tap takes inputs and outputs. I must take care
1084# to make all callbacks happen at runtime instead of having some of them at
1085# compile time. I tried once but failed. Currently host_callback is
1086# experimental, maybe wait until it isn't. => I think it fails because it's
1087# asynchronous and there is only one device. Maybe host_callback.call would
1088# work? => I think they are developing something like my token machinery.
1090# TODO dictionary argument jitkw, arguments passed to jax.jit?
1092# TODO parameter float32: bool to use short float type. I think that scipy's
1093# optimize may break down with short floats with default options, I hope that
1094# changing termination tolerances does the trick.
1096# TODO make separate_jac a parameter
1098# TODO add options in _CountCalls to track inputs and/or outputs to some maximum
1099# buffer length, activate it if the method (after applying user options,
1100# lowercasing, and inferring minimize's default) is l-bfgs-b and the covariance
1101# is minhess or auto, to the order specified in the arguments to l-bfgs-b (after
1102# defaults inference if missing) (add tests in test_fit to check that the
1103# defaults stay as inferred), to be used if l-bfgs-b returns a crooked hessian.
1104# --- Alternative: if covariance = 'auto', it could be appropriate to use fisher
1105# per definition. --- Alternative: add option covariance = 'lbfgs(<order>)' that
1106# does this for any method, although this would require computing the gradients
1107# afterwards if the gradient was not used. These alternatives are not mutually
1108# exclusive.
1110# TODO make a helper function/class method that takes in a data transformation
1111# dependent on the hypers and outputs an additional loss (the log jacobian of
1112# the appropriate function with the appropriate sign)