Coverage for src/lsqfitgp/_fit.py: 84%

551 statements  

coverage.py v7.9.2, created at 2025-07-17 13:39 +0000

1# lsqfitgp/_fit.py 

2# 

3# Copyright (c) 2020, 2022, 2023, 2024, 2025, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import re 1fedabc

21import warnings 1fedabc

22import functools 1fedabc

23import time 1fedabc

24import textwrap 1fedabc

25import datetime 1fedabc

26 

27import gvar 1fedabc

28import jax 1fedabc

29from jax import numpy as jnp 1fedabc

30import numpy 1fedabc

31from scipy import optimize 1fedabc

32from jax import tree_util 1fedabc

33 

34from . import _GP 1fedabc

35from . import _linalg 1fedabc

36from . import _jaxext 1fedabc

37from . import _gvarext 1fedabc

38from . import _array 1fedabc

39 

40# TODO the following token_* functionality may be provided by jax in the 

41# future, follow the developments 

42 

43@functools.singledispatch 1fedabc

44def token_getter(x): 1fedabc

45 return x 

46 

47@functools.singledispatch 1fedabc

48def token_setter(x, token): 1fedabc

49 return token 

50 

51@token_getter.register(jnp.ndarray) 1fedabc

52@token_getter.register(numpy.ndarray) 1fedabc

53def _(x): 1fedabc

54 return x[x.ndim * (0,)] if x.size else x 1fedabc

55 

56@token_setter.register(jnp.ndarray) 1fedabc

57@token_setter.register(numpy.ndarray) 1fedabc

58def _(x, token): 1fedabc

59 x = jnp.asarray(x) 1fedabc

60 return x.at[x.ndim * (0,)].set(token) if x.size else token 1fedabc

61 

62def token_map_leaf(func, x): 1fedabc

63 if isinstance(x, (jnp.ndarray, numpy.ndarray)): 63 ↛ 74: line 63 didn't jump to line 74 because the condition on line 63 was always true 1fedabc

64 token = token_getter(x) 1fedabc

65 @jax.custom_jvp 1fedabc

66 def jaxfunc(token): 1fedabc

67 return jax.pure_callback(func, token, token, vmap_method='expand_dims') 1fedabc

68 @jaxfunc.defjvp 1fedabc

69 def _(p, t): 1fedabc

70 return (jaxfunc(*p), *t) 1fedabc

71 token = jaxfunc(token) 1fedabc

72 return token_setter(x, token) 1fedabc

73 else: 

74 token = token_getter(x) 

75 token = func(token) 

76 return token_setter(x, token) 

77 

78def token_map(func, x): 1fedabc

79 return tree_util.tree_map(lambda x: token_map_leaf(func, x), x) 1fedabc
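# A minimal sketch of how the token machinery above can be used to run a timing
# side effect at the right point inside jit-compiled code; the names `stamps`,
# `stamp` and `f` are made up for illustration, not part of the module:
#
#     stamps = []
#
#     def stamp(token):
#         stamps.append(time.perf_counter())  # side effect runs at execution time
#         return token                        # the token must be passed through
#
#     @jax.jit
#     def f(x):
#         x = token_map(stamp, x)     # timestamp when x is materialized
#         y = x @ x
#         return token_map(stamp, y)  # timestamp after the product
#
#     f(jnp.eye(3))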

80 

81class Logger: 1fedabc

82 """ Class to manage a log. Can be used as superclass. Each line of the log 

83 has a verbosity level (an integer >= 0) and is printed only if this level is 

84 at or below a threshold. All lines are saved and the log can be retrieved. """ 

85 

86 def __init__(self, target_verbosity=0): 1fedabc

87 """ set the threshold used to exclude log lines """ 

88 self._verbosity = target_verbosity 1fedabc

89 self._loggedlines = [] 1fedabc

90 

91 def _indent(self, text, level=0): 1fedabc

92 """ indent a text by provided level or by global current level """ 

93 level = max(0, level + self.loglevel._level) 1edabc

94 prefix = 4 * level * ' ' 1edabc

95 return textwrap.indent(text, prefix) 1edabc

96 

97 def _select(self, verbosity, target_verbosity=None): 1fedabc

98 if target_verbosity is None: 98 ↛ 100: line 98 didn't jump to line 100 because the condition on line 98 was always true 1fedabc

99 target_verbosity = self._verbosity 1fedabc

100 if isinstance(verbosity, int): 1fedabc

101 return target_verbosity >= verbosity 1fedabc

102 else: 

103 return target_verbosity in verbosity 1fedabc

104 

105 def log(self, message, verbosity=1, *, level=0): 1fedabc

106 """ 

107 Print and record a message. 

108 

109 Parameters 

110 ---------- 

111 message : str 

112 The message to print. A newline is added unconditionally. 

113 verbosity : int or set, default 1 

114 The verbosity level(s) at which the message is printed. If an 

115 integer, it's printed at all levels >= that integer. If a set, at 

116 the specified levels. 

117 level : int, default 0 

118 The indentation level of the message. 

119 """ 

120 if self._select(verbosity): 1fedabc

121 print(self._indent(message, level)) 1edabc

122 self._loggedlines.append((message, verbosity, level + self.loglevel._level)) 1fedabc

123 

124 def getlog(self, target_verbosity=None, *, base_level=0): 1fedabc

125 """ return all logged line as a single string """ 

126 return '\n'.join( 

127 self._indent(message, base_level + level) 

128 for message, verbosity, level in self._loggedlines 

129 if self._select(verbosity, target_verbosity) 

130 ) 

131 

132 class _LogLevel: 1fedabc

133 """ shared context manager to indent messages """ 

134 

135 _level = 0 1fedabc

136 

137 @classmethod 1fedabc

138 def __enter__(cls): 1fedabc

139 cls._level += 1 1fedabc

140 

141 @classmethod 1fedabc

142 def __exit__(cls, *_): 1fedabc

143 cls._level -= 1 1fedabc

144 

145 loglevel = _LogLevel() 1fedabc
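# A small usage sketch of the Logger class above; `lg` and the messages are
# made up for illustration:
#
#     lg = Logger(target_verbosity=2)
#     lg.log('shown at any verbosity >= 1')       # printed (default verbosity=1)
#     lg.log('detailed report line', 2)           # printed
#     lg.log('only at verbosity 3 exactly', {3})  # not printed, but still recorded
#     with lg.loglevel:
#         lg.log('indented by 4 spaces', 2)       # printed one level deeper
#     text = lg.getlog(3)                         # re-render the log at verbosity 3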

146 

147class empbayes_fit(Logger): 1fedabc

148 

149 SEPARATE_JAC = False 1fedabc

150 

151 def __init__( 1fedabc

152 self, 

153 hyperprior, 

154 gpfactory, 

155 data, 

156 *, 

157 raises=True, 

158 minkw={}, 

159 gpfactorykw={}, 

160 jit=True, 

161 method='gradient', 

162 initial='priormean', 

163 verbosity=0, 

164 covariance='auto', 

165 fix=None, 

166 mlkw={}, 

167 forward=False, 

168 additional_loss=None, 

169 ): 

170 """ 

171  

172 Maximum a posteriori fit. 

173  

174 Maximizes the marginal likelihood of the data with a Gaussian process 

175 model that depends on hyperparameters, multiplied by a prior on the 

176 hyperparameters. 

177  

178 Parameters 

179 ---------- 

180 hyperprior : scalar, array or dictionary of scalars/arrays 

181 A collection of gvars representing the prior for the 

182 hyperparameters. 

183 gpfactory : callable 

184 A function with signature gpfactory(hyperparams) -> GP object. The 

185 argument ``hyperparams`` has the same structure as the 

186 `empbayes_fit` argument ``hyperprior``. gpfactory must be 

187 JAX-friendly, i.e., use `jax.numpy` and `jax.scipy` instead of plain 

188 `numpy`/`scipy` and avoid assignments to arrays. 

189 data : dict, tuple or callable 

190 Dictionary of data that is passed to `GP.marginal_likelihood` on 

191 the GP object returned by ``gpfactory``. If a tuple, it contains the 

192 first two arguments to `GP.marginal_likelihood`. If a callable, it 

193 is called with the same arguments as ``gpfactory`` and must return 

194 the argument(s) for `GP.marginal_likelihood`. 

195 raises : bool, optional 

196 If True (default), raise an error when the minimization fails. 

197 Otherwise, use the last point of the minimization as result. 

198 minkw : dict, optional 

199 Keyword arguments passed to `scipy.optimize.minimize`, overwriting 

200 values specified by `empbayes_fit`. 

201 gpfactorykw : dict, optional 

202 Keyword arguments passed to ``gpfactory``, and also to ``data`` if 

203 it is a callable. If ``jit``, ``gpfactorykw`` crosses a `jax.jit` 

204 boundary, so it must contain objects understandable by `jax`. 

205 jit : bool 

206 If True (default), use `jax.jit` to compile the minimization target. 

207 method : str 

208 Minimization strategy. Options: 

209  

210 'nograd' 

211 Use a gradient-free method. 

212 'gradient' (default) 

213 Use a gradient-only method. 

214 'fisher' 

215 Use a Newton method with the Fisher information matrix plus 

216 the hyperprior precision matrix. 

217 initial : str, scalar, array, dictionary of scalars/arrays 

218 Starting point for the minimization, matching the format of 

219 ``hyperprior``, or one of the following options: 

220  

221 'priormean' (default) 

222 Start from the hyperprior mean. 

223 'priorsample' 

224 Take a random sample from the hyperprior. 

225 verbosity : int 

226 An integer indicating how much information is printed on the 

227 terminal: 

228  

229 0 (default) 

230 No logging. 

231 1 

232 Minimal report. 

233 2 

234 Detailed report. 

235 3 

236 Log each iteration. 

237 4 

238 More detailed iteration log. 

239 5 

240 Print the current parameter values at each iteration. 

241 covariance : str 

242 Method to estimate the posterior covariance matrix of the 

243 hyperparameters: 

244  

245 'fisher' 

246 Use the Fisher information at the MAP, plus the prior precision, 

247 as precision matrix. 

248 'minhess' 

249 Use the hessian estimate of the minimizer as precision matrix. 

250 'none' 

251 Do not estimate the covariance matrix. 

252 'auto' (default) 

253 ``'minhess'`` if applicable, ``'none'`` otherwise. 

254 fix : scalar, array or dictionary of scalars/arrays 

255 A set of booleans, with the same format as ``hyperprior``, 

256 indicating which hyperparameters are kept fixed to their initial 

257 value. Scalars and arrays are broadcasted to the shape of 

258 ``hyperprior``. If a dictionary, missing keys are treated as False. 

259 mlkw : dict 

260 Additional arguments passed to `GP.marginal_likelihood`. 

261 forward : bool, default False 

262 Use forward instead of backward derivatives. Typically, forward is 

263 faster with a small number of parameters. 

264 additional_loss : callable, optional 

265 A function with signature ``additional_loss(hyperparams) -> float`` 

266 which is added to the minus log marginal posterior of the 

267 hyperparameters. 

268  

269 Attributes 

270 ---------- 

271 p : scalar, array or dictionary of scalars/arrays 

272 A collection of gvars representing the hyperparameters that 

273 maximize their posterior. These gvars do not track correlations 

274 with the hyperprior or the data. 

275 prior : scalar, array or dictionary of scalars/arrays 

276 A copy of the hyperprior. 

277 initial : scalar, array or dictionary of scalars/arrays 

278 Starting point of the minimization, with the same format as ``p``. 

279 fix : scalar, array or dictionary of scalars/arrays 

280 A set of booleans, with the same format as ``p``, indicating which 

281 parameters were kept fixed to the values in ``initial``. 

282 pmean : scalar, array or dictionary of scalars/arrays 

283 Mean of ``p``. 

284 pcov : scalar, array or dictionary of scalars/arrays 

285 Covariance matrix of ``p``. 

286 minresult : scipy.optimize.OptimizeResult 

287 The result object returned by `scipy.optimize.minimize`. 

288 minargs : dict 

289 The arguments passed to `scipy.optimize.minimize`. 

290 gpfactory : callable 

291 The ``gpfactory`` argument. 

292 gpfactorykw : dict 

293 The ``gpfactorykw`` argument. 

294 data : dict, tuple or callable 

295 The ``data`` argument. 

296 

297 Raises 

298 ------ 

299 RuntimeError 

300 The minimization failed and ``raises`` is True. 

301  

302 """ 

303 
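# A hedged usage sketch following the docstring above. The kernel, data and key
# names are invented, and the exact GP-building calls (e.g. whether `addx`
# returns a new GP object) depend on the lsqfitgp version; consult the package
# documentation for the authoritative form.
#
#     import numpy as np
#     import gvar
#     import lsqfitgp as lgp
#
#     xdata = np.linspace(0, 10, 20)
#     ydata = gvar.gvar(np.sin(xdata), 0.1 * np.ones_like(xdata))
#
#     hyperprior = {'log(scale)': gvar.gvar(0, 1)}   # lognormal prior on the scale
#
#     def gpfactory(hp):
#         kernel = lgp.ExpQuad(scale=hp['scale'])
#         return lgp.GP(kernel).addx(xdata, 'data')
#
#     fit = lgp.empbayes_fit(hyperprior, gpfactory, {'data': ydata})
#     print(fit.p)      # MAP hyperparameters with posterior covariance, as gvars
#     print(fit.pmean)  # point estimate only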

304 Logger.__init__(self, verbosity) 1fedabc

305 del verbosity 1fedabc

306 self.log('**** call lsqfitgp.empbayes_fit ****') 1fedabc

307 

308 assert callable(gpfactory) 1fedabc

309 

310 # analyze the hyperprior 

311 hpinitial, hpunflat = self._parse_hyperprior(hyperprior, initial, fix) 1fedabc

312 del hyperprior, initial, fix 1fedabc

313 

314 # analyze data 

315 data, cachedargs = self._parse_data(data) 1fedabc

316 

317 # define functions 

318 timer, functions = self._prepare_functions( 1fedabc

319 gpfactory=gpfactory, gpfactorykw=gpfactorykw, data=data, 

320 cachedargs=cachedargs, hpunflat=hpunflat, mlkw=mlkw, jit=jit, 

321 forward=forward, additional_loss=additional_loss, 

322 ) 

323 del gpfactory, gpfactorykw, data, cachedargs, mlkw, forward, additional_loss 1fedabc

324 

325 # prepare minimizer arguments 

326 minargs = self._prepare_minargs(method, functions, hpinitial) 1fedabc

327 

328 # set up callback to time and log iterations 

329 callback = self._Callback(self, functions, timer, hpunflat) 1fedabc

330 minargs.update(callback=callback) 1fedabc

331 

332 # check invalid argument before running minimizer 

333 if covariance not in ('auto', 'fisher', 'minhess', 'none'): 333 ↛ 334: line 333 didn't jump to line 334 because the condition on line 333 was never true 1fedabc

334 raise KeyError(covariance) 

335 

336 # add user arguments and minimize 

337 minargs.update(minkw) 1fedabc

338 self.log(f'minimizer method {minargs["method"]!r}', 2) 1fedabc

339 total = time.perf_counter() 1fedabc

340 result = optimize.minimize(**minargs) 1fedabc

341 

342 # check the minimization was successful 

343 self._check_success(result, raises) 1fedabc

344 

345 # compute posterior covariance of the hyperparameters 

346 cov = self._posterior_covariance(method, covariance, result, functions['fisher']) 1fedabc

347 

348 # log total timings and function calls 

349 total = time.perf_counter() - total 1fedabc

350 self._log_totals(total, timer, callback, jit, functions) 1fedabc

351 

352 ##### temporary fix for gplepage/gvar#50 ##### 

353 cov = numpy.array(cov, order='C') 1fedabc

354 ############################################## 

355 

356 # join posterior mean and covariance matrix 

357 uresult = gvar.gvar(result.x, cov) 1fedabc

358 

359 # set attributes 

360 self.p = gvar.gvar(hpunflat(uresult)) 1fedabc

361 self.pmean = gvar.mean(self.p) 1fedabc

362 self.pcov = gvar.evalcov(self.p) 1fedabc

363 self.minresult = result 1fedabc

364 self.minargs = minargs 1fedabc

365 

366 # tabulate hyperparameter prior and posterior 

367 if self._verbosity >= 2: 1fedabc

368 self.log(_gvarext.tabulate_together( 1edabc

369 self.prior, self.p, 

370 headers=['param', 'prior', 'posterior'], 

371 )) # TODO replace tabulate_together with something more flexible I 

372 # can use for the callback as well. Maybe import TextMatrix from 

373 # miscpy. 

374 # TODO print the transformed parameters 

375 

376 self.log('**** exit lsqfitgp.empbayes_fit ****') 1fedabc

377 

378 class _CountCalls: 1fedabc

379 """ wrap a callable to count calls """ 

380 

381 def __init__(self, func): 1fedabc

382 self._func = func 1fedabc

383 self._total = 0 1fedabc

384 self._partial = 0 1fedabc

385 functools.update_wrapper(self, func) 1fedabc

386 

387 def __call__(self, *args, **kw): 1fedabc

388 self._total += 1 1fedabc

389 self._partial += 1 1fedabc

390 return self._func(*args, **kw) 1fedabc

391 

392 def partial(self): 1fedabc

393 """ return the partial counter and reset it """ 

394 result = self._partial 1fedabc

395 self._partial = 0 1fedabc

396 return result 1fedabc

397 

398 def total(self): 1fedabc

399 """ return the total number of calls """ 

400 return self._total 1fedabc

401 

402 @staticmethod 1fedabc

403 def fmtcalls(method, functions): 1fedabc

404 """ 

405 format summary of number of calls 

406 method : str 

407 functions: dict[str, _CountCalls] 

408 """ 

409 def counts(): 1fedabc

410 for name, func in functions.items(): 1fedabc

411 if count := getattr(func, method)(): 1fedabc

412 yield f'{name} {count}' 1fedabc

413 return ', '.join(counts()) 1fedabc

414 

415 class _Timer: 1fedabc

416 """ object to time likelihood computations """ 

417 

418 def __init__(self): 1fedabc

419 self.totals = {} 1fedabc

420 self.partials = {} 1fedabc

421 self._last_start = False 1fedabc

422 

423 def start(self, token): 1fedabc

424 return token_map(self._start, token) 1fedabc

425 

426 def _start(self, token): 1fedabc

427 self.stamp = time.perf_counter() 1fedabc

428 self.counter = 0 1fedabc

429 assert not self._last_start # forbid consecutive start() calls 1fedabc

430 self._last_start = True 1fedabc

431 return token 1fedabc

432 

433 def reset(self): 1fedabc

434 self.partials = {} 1fedabc

435 

436 def partial(self, token): 1fedabc

437 return token_map(self._partial, token) 1fedabc

438 

439 def _partial(self, token): 1fedabc

440 now = time.perf_counter() 1fedabc

441 delta = now - self.stamp 1fedabc

442 self.partials[self.counter] = self.partials.get(self.counter, 0) + delta 1fedabc

443 self.totals[self.counter] = self.totals.get(self.counter, 0) + delta 1fedabc

444 self.stamp = now 1fedabc

445 self.counter += 1 1fedabc

446 self._last_start = False 1fedabc

447 return token 1fedabc

448 

449 def _parse_hyperprior(self, hyperprior, initial, fix): 1fedabc

450 

451 # check fix against hyperprior and fill missing values 

452 hyperprior = self._copyasarrayorbufferdict(hyperprior) 1fedabc

453 self._check_no_redundant_keys(hyperprior) 1fedabc

454 fix = self._parse_fix(hyperprior, fix) 1fedabc

455 flatfix = self._flatview(fix) 1fedabc

456 

457 # extract distribution of free hyperparameters 

458 flathp = self._flatview(hyperprior) 1fedabc

459 freehp = flathp[~flatfix] 1fedabc

460 mean = gvar.mean(freehp) 1fedabc

461 cov = gvar.evalcov(freehp) # TODO use evalcov_blocks 1fedabc

462 dec = _linalg.Chol(cov) 1fedabc

463 assert dec.n == freehp.size 1fedabc

464 self.log(f'{freehp.size}/{flathp.size} free hyperparameters', 2) 1fedabc

465 

466 # determine starting point for minimization 

467 initial = self._parse_initial(hyperprior, initial, dec) 1fedabc

468 flatinitial = self._flatview(initial) 1fedabc

469 x0 = dec.pinv_correlate(flatinitial[~flatfix] - mean) 1fedabc

470 # TODO for initial = 'priormean', x0 is zero, skip decorrelate 

471 # for initial = 'priorsample', x0 is iid normal, but I have to sync 

472 # it with the user-exposed unflattened initial in _parse_initial 

473 

474 # make function to correlate, add fixed values, and reshape to original 

475 # format 

476 fixed_indices, = jnp.nonzero(flatfix) 1fedabc

477 unfixed_indices, = jnp.nonzero(~flatfix) 1fedabc

478 fixed_values = jnp.asarray(flatinitial[flatfix]) 1fedabc

479 def unflat(x): 1fedabc

480 assert x.ndim == 1 1fedabc

481 if x.dtype == object: 1fedabc

482 jac, indices = _gvarext.jacobian(x) 1fedabc

483 xmean = mean + dec.correlate(gvar.mean(x)) 1fedabc

484 xjac = dec.correlate(jac) 1fedabc

485 x = _gvarext.from_jacobian(xmean, xjac, indices) 1fedabc

486 y = numpy.empty(flatfix.size, x.dtype) 1fedabc

487 numpy.put(y, unfixed_indices, x) 1fedabc

488 numpy.put(y, fixed_indices, fixed_values) 1fedabc

489 else: 

490 x = mean + dec.correlate(x) 1fedabc

491 y = jnp.empty(flatfix.size, x.dtype) 1fedabc

492 y = y.at[unfixed_indices].set(x) 1fedabc

493 y = y.at[fixed_indices].set(fixed_values) 1fedabc

494 return self._unflatview(y, hyperprior) 1fedabc

495 

496 self.prior = hyperprior 1fedabc

497 return x0, unflat 1fedabc

498 
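# The reparametrization used in _parse_hyperprior, in brief: with free
# hyperparameters a priori N(mean, cov) and cov = L Lᵀ (the decomposition
# `dec`), the minimizer works on a whitened variable z that is a priori
# N(0, I), and `unflat` maps it back as hp = mean + L z. A tiny numerical
# sketch with made-up values:
#
#     mean = jnp.array([1.0, -2.0])
#     L = jnp.array([[0.5, 0.0], [0.3, 0.4]])    # stands in for dec.correlate
#     z = jnp.array([0.1, -0.2])
#     hp = mean + L @ z                          # what unflat computes
#     z_back = jnp.linalg.solve(L, hp - mean)    # what pinv_correlate undoes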

499 @staticmethod 1fedabc

500 def _check_no_redundant_keys(hyperprior): 1fedabc

501 if not hasattr(hyperprior, 'keys'): 1fedabc

502 return 1dabc

503 for k in hyperprior: 1fedabc

504 m = hyperprior.extension_pattern.match(k) 1fedabc

505 if m and m.group(1) in hyperprior.invfcn: 1fedabc

506 altk = m.group(2) 1fedabc

507 if altk in hyperprior: 507 ↛ 508: line 507 didn't jump to line 508 because the condition on line 507 was never true 1fedabc

508 raise ValueError(f'duplicate keys {altk!r} and {k!r} in hyperprior') 

509 

510 def _parse_fix(self, hyperprior, fix): 1fedabc

511 

512 if fix is None: 512 ↛ 518: line 512 didn't jump to line 518 because the condition on line 512 was always true 1fedabc

513 if hasattr(hyperprior, 'keys'): 1fedabc

514 fix = gvar.BufferDict(hyperprior, buf=numpy.zeros(hyperprior.size, bool)) 1fedabc

515 else: 

516 fix = numpy.zeros(hyperprior.shape, bool) 1dabc

517 else: 

518 fix = self._copyasarrayorbufferdict(fix) 

519 if hasattr(fix, 'keys'): 

520 assert hasattr(hyperprior, 'keys'), 'fix is dictionary but hyperprior is array' 

521 assert all(hyperprior.has_dictkey(k) for k in fix), 'some keys in fix are missing in hyperprior' 

522 newfix = {} 

523 for k, v in hyperprior.items(): 

524 key = None 

525 m = hyperprior.extension_pattern.match(k) 

526 if m and m.group(1) in hyperprior.invfcn: 

527 altk = m.group(2) 

528 if altk in fix: 

529 assert k not in fix, f'duplicate keys {k!r} and {altk!r} in fix' 

530 key = altk 

531 if key is None and k in fix: 

532 key = k 

533 if key is None: 

534 elem = numpy.zeros(v.shape, bool) 

535 else: 

536 elem = numpy.broadcast_to(fix[key], v.shape) 

537 newfix[k] = elem 

538 fix = gvar.BufferDict(newfix, dtype=bool) 

539 else: 

540 assert not hasattr(hyperprior, 'keys'), 'fix is array but hyperprior is dictionary' 

541 fix = numpy.broadcast_to(fix, hyperprior.shape).astype(bool) 

542 

543 self.fix = fix 1fedabc

544 return fix 1fedabc

545 

546 def _parse_initial(self, hyperprior, initial, dec): 1fedabc

547 

548 if not isinstance(initial, str): 548 ↛ 549: line 548 didn't jump to line 549 because the condition on line 548 was never true 1fedabc

549 self.log('start from provided point', 2) 

550 initial = self._copyasarrayorbufferdict(initial) 

551 if hasattr(hyperprior, 'keys'): 

552 assert hasattr(initial, 'keys'), 'hyperprior is dictionary but initial is array' 

553 assert set(hyperprior.keys()) == set(initial.keys()) 

554 assert all(hyperprior[k].shape == initial[k].shape for k in hyperprior) 

555 else: 

556 assert not hasattr(initial, 'keys'), 'hyperprior is array but initial is dictionary' 

557 assert hyperprior.shape == initial.shape 

558 

559 elif initial == 'priormean': 559 ↛ 563: line 559 didn't jump to line 563 because the condition on line 559 was always true 1fedabc

560 self.log('start from prior mean', 2) 1fedabc

561 initial = gvar.mean(hyperprior) 1fedabc

562 

563 elif initial == 'priorsample': 

564 self.log('start from a random sample from the prior', 2) 

565 if dec.n < hyperprior.size: 

566 flathp = self._flatview(hyperprior) 

567 cov = gvar.evalcov(flathp) # TODO use evalcov_blocks 

568 fulldec = _linalg.Chol(cov) 

569 else: 

570 fulldec = dec 

571 iid = numpy.random.randn(fulldec.m) 

572 flatinitial = numpy.asarray(fulldec.correlate(iid)) 

573 initial = self._unflatview(flatinitial, hyperprior) 

574 

575 else: 

576 raise KeyError(initial) 

577 

578 self.initial = initial 1fedabc

579 return initial 1fedabc

580 

581 def _parse_data(self, data): 1fedabc

582 

583 self.data = data 1fedabc

584 if isinstance(data, tuple) and len(data) == 1: 1fedabc

585 data, = data 1dabc

586 

587 if callable(data): 1fedabc

588 self.log('data is callable', 2) 1edabc

589 cachedargs = None 1edabc

590 elif isinstance(data, tuple): 1fedabc

591 self.log('data errors provided separately', 2) 1dabc

592 assert len(data) == 2 1dabc

593 cachedargs = data 1dabc

594 elif (gdata := self._copyasarrayorbufferdict(data)).dtype == object: 1fedabc

595 self.log('data has errors as gvars', 2) 1edabc

596 data = gvar.gvar(gdata) 1edabc

597 # convert to gvar because non-gvars in the array would upset 

598 # gvar.mean and gvar.evalcov 

599 cachedargs = (gvar.mean(data), gvar.evalcov(data)) 1edabc

600 else: 

601 self.log('data has no errors', 2) 1fedabc

602 cachedargs = (data,) 1fedabc

603 

604 return data, cachedargs 1fedabc

605 
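# The data formats handled above, illustrated with made-up names (y, yerr, cov):
#
#     {'data': gvar.gvar(y, yerr)}   # gvars: mean and covariance are cached once
#     ({'data': y}, {'data': cov})   # tuple: the first two arguments of
#                                    #   GP.marginal_likelihood, passed as-is
#     {'data': y}                    # plain values: data without errors
#     lambda hp, **kw: {'data': y}   # callable: re-evaluated at every iteration
#                                    #   with the current hyperparameters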

606 def _prepare_functions(self, *, gpfactory, gpfactorykw, data, cachedargs, 1fedabc

607 hpunflat, mlkw, jit, forward, additional_loss): 

608 

609 timer = self._Timer() 1fedabc

610 firstcall = [None] 1fedabc

611 

612 def make_decomp(p, **kw): 1fedabc

613 """ decomposition of the prior covariance and data """ 

614 

615 # start timer and convert hypers to user format 

616 p = timer.start(p) 1fedabc

617 hp = hpunflat(p) 1fedabc

618 

619 # create GP object 

620 gp = gpfactory(hp, **kw) 1fedabc

621 assert isinstance(gp, _GP.GP) 1fedabc

622 

623 # extract data 

624 if cachedargs: 1fedabc

625 args = cachedargs 1fedabc

626 else: 

627 args = data(hp, **kw) 1edabc

628 if not isinstance(args, tuple): 1edabc

629 args = (args,) 1edabc

630 

631 # decompose covariance matrix and flatten data 

632 decomp, r = gp._prior_decomp(*args, covtransf=timer.partial, **mlkw) 1fedabc

633 r = r.astype(float) # int data upsets jax 1fedabc

634 

635 # log number of datapoints 

636 if firstcall: 1fedabc

637 # it is convenient to do it here because the data is flattened. 

638 # works under jit since the first call is tracing 

639 firstcall.pop() 1fedabc

640 xdtype = gp._get_x_dtype() 1fedabc

641 nd = '?' if xdtype is None else _array._nd(xdtype) 1fedabc

642 self.log(f'{r.size} datapoints, {nd} covariates') 1fedabc

643 

644 # compute user loss 

645 if additional_loss is None: 1fedabc

646 loss = 0. 1fedabc

647 else: 

648 loss = additional_loss(hp) 1edabc

649 

650 # split timer and return decomposition 

651 return timer.partial(decomp), r, loss 1fedabc

652 # TODO what's the correct way of checkpointing r? 

653 

654 # define wrapper to collect call stats, pass user args, compile 

655 def wrap(func): 1fedabc

656 if jit: 656 ↛ 658: line 656 didn't jump to line 658 because the condition on line 656 was always true 1fedabc

657 func = jax.jit(func) 1fedabc

658 func = functools.partial(func, **gpfactorykw) 1fedabc

659 return self._CountCalls(func) 1fedabc

660 if jit: 660 ↛ 664: line 660 didn't jump to line 664 because the condition on line 660 was always true 1fedabc

661 self.log('compile functions with jax jit', 2) 1fedabc

662 

663 # log autodiff mode 

664 modename = 'forward' if forward else 'reverse' 1fedabc

665 self.log(f'{modename}-mode autodiff (if used)', 2) 1fedabc

666 

667 # TODO time the derivatives separately => maybe I need a custom 

668 # derivative rule for timer token acknowledgement? 

669 

670 def prior(p): 1fedabc

671 # the marginal prior of the hyperparameters is a Normal with 

672 # identity covariance matrix because p is transformed to make it so 

673 return 1/2 * (len(p) * jnp.log(2 * jnp.pi) + p @ p) 1fedabc

674 

675 def grad_prior(p): 1fedabc

676 return p 1fedabc

677 

678 def fisher_prior(p): 1fedabc

679 return jnp.eye(len(p)) 1dabc

680 
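# In the whitened coordinates the hyperprior is a standard Normal, so the three
# functions above are simply
#     prior(p)        = ½ (n log 2π + pᵀp)   (minus log density),
#     grad_prior(p)   = p,
#     fisher_prior(p) = I.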

681 @wrap 1fedabc

682 def fun(p, **kw): 1fedabc

683 """ minus log marginal posterior of the hyperparameters (not 

684 normalized) """ 

685 decomp, r, loss = make_decomp(p, **kw) 1dabc

686 cond, _, _, _, _ = decomp.minus_log_normal_density(r, value=True) 1dabc

687 post = cond + prior(p) + loss 1dabc

688 # TODO what's the correct way of checkpointing prior and loss? 

689 return timer.partial(post) 1dabc

690 
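# The quantity minimized is, up to a constant,
#     fun(p) = -log N(r; 0, K(hp)) + ½ (n log 2π + pᵀp) + additional_loss(hp),
# with hp = unflat(p): minus the log marginal posterior of the whitened
# hyperparameters, plus the optional user loss.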

691 def make_gradfwd_fisher_args(p, **kw): 1fedabc

692 def make_decomp_tee(p): 1edabc

693 decomp, r, loss = make_decomp(p, **kw) 1edabc

694 return (decomp.matrix(), r, loss), (decomp, r, loss) 1edabc

695 (dK, dr, grad_loss), (decomp, r, loss) = jax.jacfwd(make_decomp_tee, has_aux=True)(p) 1edabc

696 lkw = dict(dK=dK, dr=dr) 1edabc

697 return decomp, r, lkw, loss, grad_loss 1edabc

698 

699 def make_gradrev_args(p, **kw): 1fedabc

700 def make_decomp_loss(p): 1fedabc

701 def make_decomp_r(p): 1fedabc

702 def make_decomp_K(p): 1fedabc

703 decomp, r, loss = make_decomp(p, **kw) 1fedabc

704 return decomp.matrix(), (decomp, r, loss) 1fedabc

705 _, dK_vjp, (decomp, r, loss) = jax.vjp(make_decomp_K, p, has_aux=True) 1fedabc

706 return r, (decomp, r, dK_vjp, loss) 1fedabc

707 _, dr_vjp, (decomp, r, dK_vjp, loss) = jax.vjp(make_decomp_r, p, has_aux=True) 1fedabc

708 return loss, (decomp, r, dK_vjp, dr_vjp, loss) 1fedabc

709 grad_loss, (decomp, r, dK_vjp, dr_vjp, loss) = jax.grad(make_decomp_loss, has_aux=True)(p) 1fedabc

710 unpack = lambda f: lambda x: f(x)[0] 1fedabc

711 dK_vjp = unpack(dK_vjp) 1fedabc

712 dr_vjp = unpack(dr_vjp) 1fedabc

713 lkw = dict(dK_vjp=dK_vjp, dr_vjp=dr_vjp) 1fedabc

714 return decomp, r, lkw, loss, grad_loss 1fedabc

715 

716 def make_jac_args(p, **kw): 1fedabc

717 if forward: 1fedabc

718 out = make_gradfwd_fisher_args(p, **kw) 1edabc

719 out[2].update(gradfwd=True) # out[2] is lkw 1edabc

720 else: 

721 out = make_gradrev_args(p, **kw) 1fedabc

722 out[2].update(gradrev=True) 1fedabc

723 return out 1fedabc

724 

725 @wrap 1fedabc

726 def fun_and_jac(p, **kw): 1fedabc

727 """ fun and its gradient """ 

728 decomp, r, lkw, loss, grad_loss = make_jac_args(p, **kw) 1fedabc

729 cond, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, value=True, **lkw) 1fedabc

730 post = cond + prior(p) + loss 1fedabc

731 grad_cond = gradfwd if forward else gradrev 1fedabc

732 grad_post = grad_cond + grad_prior(p) + grad_loss 1fedabc

733 return timer.partial((post, grad_post)) 1fedabc

734 

735 @wrap 1fedabc

736 def jac(p, **kw): 1fedabc

737 """ gradient of fun """ 

738 decomp, r, lkw, _, grad_loss = make_jac_args(p, **kw) 

739 _, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, **lkw) 

740 grad_cond = gradfwd if forward else gradrev 

741 grad_post = grad_cond + grad_prior(p) + grad_loss 

742 return timer.partial(grad_post) 

743 

744 @wrap 1fedabc

745 def fisher(p, **kw): 1fedabc

746 """ fisher matrix """ 

747 if additional_loss is not None: 1dabc

748 raise NotImplementedError( 1dabc

749 'Fisher matrix not implemented with additional_loss. It ' 

750 'is possible but I did not prioritize it. If you need it, ' 

751 'open an issue on github.') 

752 decomp, r, lkw, _, _ = make_gradfwd_fisher_args(p, **kw) 1dabc

753 _, _, _, fisher_cond, _ = decomp.minus_log_normal_density(r, fisher=True, **lkw) 1dabc

754 fisher_post = fisher_cond + fisher_prior(p) 1dabc

755 return timer.partial(fisher_post) 1dabc

756 

757 # set attributes 

758 self.gpfactory = gpfactory 1fedabc

759 self.gpfactorykw = gpfactorykw 1fedabc

760 

761 return timer, { 1fedabc

762 'fun': fun, 

763 'jac': jac, 

764 'fun&jac': fun_and_jac, 

765 'fisher': fisher, 

766 } 

767 

768 def _prepare_minargs(self, method, functions, hpinitial): 1fedabc

769 minargs = dict(fun=functions['fun&jac'], jac=True, x0=hpinitial) 1fedabc

770 if self.SEPARATE_JAC: 770 ↛ 771: line 770 didn't jump to line 771 because the condition on line 770 was never true 1fedabc

771 minargs.update(fun=functions['fun'], jac=functions['jac']) 

772 if method == 'nograd': 1fedabc

773 minargs.update(fun=functions['fun'], jac=None, method='nelder-mead') 1dabc

774 elif method == 'gradient': 1fedabc

775 minargs.update(method='bfgs') 1fedabc

776 elif method == 'fisher': 1dabc

777 minargs.update(hess=functions['fisher'], method='dogleg') 1dabc

778 # dogleg requires positive definiteness, fisher is p.s.d. 

779 # trust-constr has more options, but it seems to be slower than 

780 # dogleg, so I keep dogleg as default 

781 else: 

782 raise KeyError(method) 1dabc

783 self.log(f'method {method!r}', 2) 1fedabc

784 return minargs 1fedabc

785 
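# Any of the arguments prepared above can be overridden through the `minkw`
# parameter of empbayes_fit, which is merged in just before calling
# scipy.optimize.minimize. A hypothetical override:
#
#     fit = empbayes_fit(hyperprior, gpfactory, data,
#         method='gradient',
#         minkw=dict(method='l-bfgs-b', options=dict(maxiter=100)))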

786 # TODO add method with fisher matvec instead of fisher matrix 

787 

788 def _log_totals(self, total, timer, callback, jit, functions): 1fedabc

789 times = { 1fedabc

790 'gp&cov': timer.totals[0], 

791 'decomp': timer.totals[1], 

792 'likelihood': timer.totals[2], 

793 'jit': None, # set now and delete later to keep it before 'other' 

794 'other': total - sum(timer.totals.values()), 

795 } 

796 if jit: 796 ↛ 803: line 796 didn't jump to line 803 because the condition on line 796 was always true 1fedabc

797 overhead = callback.estimate_firstcall_overhead() 1fedabc

798 # TODO this estimation ignores the jit compilation of the function 

799 # used to compute the precision matrix, to be precise I should 

800 # manually split the jit into compilation + evaluation or hook into 

801 # it somehow. Maybe the jit object keeps a compilation wall time 

802 # stat? 

803 if jit and overhead is not None: 1fedabc

804 times['jit'] = overhead 1fedabc

805 times['other'] -= overhead 1fedabc

806 else: 

807 del times['jit'] 1dabc

808 self.log('', 4) 1fedabc

809 calls = self._CountCalls.fmtcalls('total', functions) 1fedabc

810 self.log(f'calls: {calls}') 1fedabc

811 self.log(f'total time: {callback.fmttime(total)}') 1fedabc

812 self.log(f'partials: {callback.fmttimes(times)}', 2) 1fedabc

813 

814 def _check_success(self, result, raises): 1fedabc

815 if result.success: 1fedabc

816 self.log(f'minimization succeeded: {result.message}') 1fedabc

817 else: 

818 msg = f'minimization failed: {result.message}' 1fdabc

819 if raises: 1fdabc

820 raise RuntimeError(msg) 1dabc

821 elif self._verbosity == 0: 1fabc

822 warnings.warn(msg) 1fb

823 else: 

824 self.log(msg) 1abc

825 

826 def _posterior_covariance(self, method, covariance, minimizer_result, fisher_func): 1fedabc

827 

828 if covariance == 'auto': 828 ↛ 834: line 828 didn't jump to line 834 because the condition on line 828 was always true 1fedabc

829 if hasattr(minimizer_result, 'hess_inv') or hasattr(minimizer_result, 'hess'): 1fedabc

830 covariance = 'minhess' 1fedabc

831 else: 

832 covariance = 'none' 1dabc

833 

834 if covariance == 'fisher': 834 ↛ 835: line 834 didn't jump to line 835 because the condition on line 834 was never true 1fedabc

835 self.log('use fisher plus prior precision as precision', 2) 

836 if method == 'fisher': 

837 prec = minimizer_result.hess 

838 else: 

839 prec = fisher_func(minimizer_result.x) 

840 cov = _linalg.Chol(prec).ginv() 

841 

842 elif covariance == 'minhess': 1fedabc

843 if hasattr(minimizer_result, 'hess_inv'): 1fedabc

844 hessinv = minimizer_result.hess_inv 1fedabc

845 if isinstance(hessinv, optimize.LbfgsInvHessProduct): 1fedabc

846 self.log(f'convert LBFGS({hessinv.n_corrs}) hessian inverse to BFGS as covariance', 2) 1edabc

847 cov = self._invhess_lbfgs_to_bfgs(hessinv) 1edabc

848 # TODO this still gives too wide a cov when the minimization 

849 # terminates due to a failed line search, is it because of 

850 # dropped updates? This is currently keeping me from setting 

851 # l-bfgs-b as default minimization method. 

852 elif isinstance(hessinv, numpy.ndarray): 852 ↛ 867: line 852 didn't jump to line 867 because the condition on line 852 was always true 1fedabc

853 self.log('use minimizer estimate of inverse hessian as covariance', 2) 1fedabc

854 cov = hessinv 1fedabc

855 elif hasattr(minimizer_result, 'hess'): 855 ↛ 859: line 855 didn't jump to line 859 because the condition on line 855 was always true 1dabc

856 self.log('use minimizer hessian as precision', 2) 1dabc

857 cov = _linalg.Chol(minimizer_result.hess).ginv() 1dabc

858 else: 

859 raise RuntimeError('the minimizer did not return an estimate of the hessian') 

860 

861 elif covariance == 'none': 861 ↛ 865: line 861 didn't jump to line 865 because the condition on line 861 was always true 1dabc

862 cov = numpy.full(minimizer_result.x.size, numpy.nan) 1dabc

863 

864 else: 

865 raise KeyError(covariance) 

866 

867 return cov 1fedabc

868 

869 @staticmethod 1fedabc

870 def _invhess_lbfgs_to_bfgs(lbfgs): 1fedabc

871 bfgs = optimize.BFGS() 1edabc

872 bfgs.initialize(lbfgs.shape[0], 'inv_hess') 1edabc

873 for i in range(lbfgs.n_corrs): 1edabc

874 bfgs.update(lbfgs.sk[i], lbfgs.yk[i]) 1edabc

875 return bfgs.get_matrix() 1edabc

876 

877 class _Callback: 1fedabc

878 """ Iteration callback for scipy.optimize.minimize """ 

879 

880 def __init__(self, this, functions, timer, unflat): 1fedabc

881 self.it = 0 1fedabc

882 self.stamp = time.perf_counter() 1fedabc

883 self.this = this 1fedabc

884 self.functions = functions 1fedabc

885 self.timer = timer 1fedabc

886 self.unflat = unflat 1fedabc

887 self.tail_overhead = 0 1fedabc

888 self.tail_overhead_iter = 0 1fedabc

889 

890 def __call__(self, intermediate_result, arg2=None): 1fedabc

891 

892 if isinstance(intermediate_result, optimize.OptimizeResult): 892 ↛ 893: line 892 didn't jump to line 893 because the condition on line 892 was never true 1fedabc

893 p = intermediate_result.x 

894 elif isinstance(intermediate_result, numpy.ndarray): 894 ↛ 897: line 894 didn't jump to line 897 because the condition on line 894 was always true 1fedabc

895 p = intermediate_result 1fedabc

896 else: 

897 raise TypeError(type(intermediate_result)) 

898 

899 self.it += 1 1fedabc

900 now = time.perf_counter() 1fedabc

901 duration = now - self.stamp 1fedabc

902 

903 worktime = sum(self.timer.partials.values()) 1fedabc

904 if worktime: 1fedabc

905 overhead = duration - worktime 1fedabc

906 assert overhead >= 0, (duration, worktime) 1fedabc

907 if self.it == 1: 1fedabc

908 self.first_overhead = overhead 1fedabc

909 else: 

910 self.tail_overhead_iter += 1 1fedabc

911 self.tail_overhead += overhead 1fedabc

912 

913 # level 3 log 

914 calls = self.this._CountCalls.fmtcalls('partial', self.functions) 1fedabc

915 times = self.fmttime(duration) 1fedabc

916 self.this.log(f'iter {self.it}, time: {times}, calls: {calls}', {3}) 1fedabc

917 

918 # level 4 log 

919 tot = self.fmttime(duration) 1fedabc

920 if self.timer.partials: 1fedabc

921 times = { 1fedabc

922 'gp&cov': self.timer.partials[0], 

923 'dec': self.timer.partials[1], 

924 'like': self.timer.partials[2], 

925 'other': duration - sum(self.timer.partials.values()), 

926 } 

927 times = self.fmttimes(times) 1fedabc

928 else: 

929 times = 'n/d' 1dabc

930 self.this.log(f'\niteration {self.it}', 4) 1fedabc

931 with self.this.loglevel: 1fedabc

932 self.this.log(f'total time: {tot}', 4) 1fedabc

933 self.this.log(f'partial: {times}', 4) 1fedabc

934 self.this.log(f'calls: {calls}', 4) 1fedabc

935 

936 # level 5 log 

937 nicep = self.unflat(p) 1fedabc

938 nicep = self.this._copyasarrayorbufferdict(nicep) 1fedabc

939 with self.this.loglevel: 1fedabc

940 self.this.log(f'parameters = {nicep}', 5) 1fedabc

941 # TODO write a method to format the parameters nicely. => use 

942 # gvar.tabulate? => nope, need actual gvars 

943 # TODO does this logging add significant overhead? 

944 

945 self.stamp = now 1fedabc

946 self.timer.reset() 1fedabc

947 

948 pattern = re.compile( 1fedabc

949 r'((\d+) days, )?(\d{1,2}):(\d\d):(\d\d(\.\d{6})?)') 

950 

951 @classmethod 1fedabc

952 def fmttime(cls, seconds): 1fedabc

953 if seconds < 0: 953 ↛ 954: line 953 didn't jump to line 954 because the condition on line 953 was never true 1fedabc

954 prefix = '-' 

955 seconds = -seconds 

956 else: 

957 prefix = '' 1fedabc

958 return prefix + cls._fmttime_positive(seconds) 1fedabc

959 

960 @classmethod 1fedabc

961 def _fmttime_positive(cls, seconds): 1fedabc

962 td = datetime.timedelta(seconds=seconds) 1fedabc

963 m = cls.pattern.fullmatch(str(td)) 1fedabc

964 _, day, hour, minute, second, _ = m.groups() 1fedabc

965 hour = int(hour) 1fedabc

966 minute = int(minute) 1fedabc

967 second = float(second) 1fedabc

968 if day: 968 ↛ 969: line 968 didn't jump to line 969 because the condition on line 968 was never true 1fedabc

969 return f'{day.lstrip("0")}d{hour:02d}h' 

970 elif hour: 970 ↛ 971: line 970 didn't jump to line 971 because the condition on line 970 was never true 1fedabc

971 return f'{hour}h{minute:02d}m' 

972 elif minute: 1fedabc

973 return f'{minute}m{second:02.0f}s' 1e

974 elif second >= 0.0995: 1fedabc

975 return f'{second:#.2g}'.rstrip('.') + 's' 1fedabc

976 elif second >= 0.0000995: 1fedabc

977 return f'{second * 1e3:#.2g}'.rstrip('.') + 'ms' 1fedabc

978 else: 

979 return f'{second * 1e6:.0f}μs' 1feda

980 
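# Indicative outputs of fmttime, following the formatting rules above:
#
#     fmttime(3.5e-5)  -> '35μs'
#     fmttime(0.0042)  -> '4.2ms'
#     fmttime(0.5)     -> '0.50s'
#     fmttime(75)      -> '1m15s'
#     fmttime(4000)    -> '1h06m'
#     fmttime(200000)  -> '2d07h'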

981 @classmethod 1fedabc

982 def fmttimes(cls, times): 1fedabc

983 """ times = dict label -> seconds """ 

984 return ', '.join(f'{k} {cls.fmttime(v)}' for k, v in times.items()) 1fedabc

985 

986 def estimate_firstcall_overhead(self): 1fedabc

987 if self.tail_overhead_iter and hasattr(self, 'first_overhead'): 1fedabc

988 typical_overhead = self.tail_overhead / self.tail_overhead_iter 1fedabc

989 return self.first_overhead - typical_overhead 1fedabc

990 

991 @staticmethod 1fedabc

992 def _copyasarrayorbufferdict(x): 1fedabc

993 if hasattr(x, 'keys'): 1fedabc

994 return gvar.BufferDict(x) 1fedabc

995 else: 

996 return numpy.array(x) 1dabc

997 

998 @staticmethod 1fedabc

999 def _flatview(x): 1fedabc

1000 if hasattr(x, 'reshape'): 1fedabc

1001 return x.reshape(-1) 1dabc

1002 elif hasattr(x, 'buf'): 1fedabc

1003 return x.buf 1fedabc

1004 else: # pragma: no cover 

1005 raise NotImplementedError 

1006 

1007 @staticmethod 1fedabc

1008 def _unflatview(x, original): 1fedabc

1009 if isinstance(original, numpy.ndarray): 1fedabc

1010 # TODO is this never applied to jax arrays? 

1011 out = x.reshape(original.shape) 1dabc

1012 # if not out.shape: 

1013 # try: 

1014 # out = out.item() 

1015 # except jax.errors.ConcretizationTypeError: 

1016 # pass 

1017 return out 1dabc

1018 elif isinstance(original, gvar.BufferDict): 1fedabc

1019 # normally I would do BufferDict(original, buf=x) but it does not 

1020 # work with JAX tracers 

1021 b = gvar.BufferDict(original) 1fedabc

1022 b._extension = {} 1fedabc

1023 b._buf = x 1fedabc

1024 # b.buf = x does not work because BufferDict checks that the 

1025 # array is a numpy array 

1026 # TODO maybe make a feature request to gvar to accept array_like 

1027 # buf 

1028 return b 1fedabc

1029 else: # pragma: no cover 

1030 raise NotImplementedError 

1031 

1032 

1033# TODO would it be meaningful to add correlation of the fit result with the data 

1034# and hyperprior? 

1035 

1036# TODO add the second order correction. It probably requires more than the 

1037# gradient and inv_hess, but maybe by getting a little help from 

1038# marginal_likelihood I can use the least-squares optimized second order 

1039# correction on the residuals term and invent something for the logdet term. 

1040 

1041# TODO it raises very often with "Desired error not necessarily achieved due to 

1042 # precision loss.". I tried doing a forward grad on the logdet but it does not fix 

1043# the problem. I still suspect it's the logdet, maybe the value itself and not 

1044# the derivative, because as the matrix changes the regularization can change a 

1045# lot the value of the logdet. How do I stabilize it? => scipy's l-bfgs-b seems 

1046# to fail the linear search less often 

1047 

1048# TODO compute the logGBF for the whole fit (see the gpbart code). In its doc, 

1049# specify that 1) additional_loss may break the normalization if the user does 

1050# not know what they are doing 2) the calculation of the log determinant term 

1051# heavily depends on the regularization if the covariance matrix is singular; 

1052# this won't happen if there are independent error terms in the model as usual. 

1053 

1054# TODO empbayes_fit(autoeps=True) tries to double epsabs until the minimization 

1055 # succeeds, with some maximum number of tries. autoeps=dict(maxrepeat=5, 

1056# increasefactor=2, initial=1e-16, startfromzero=True) allows to configure the 

1057# algorithm. 

1058 

1059# TODO empbayes_fit(maxiter=100) sets the maximum number of minimization 

1060# iterations. maxiter=dict(iter=100, calls=200, callsperiter=10) allows to 

1061# configure it more finely. The calls limits are cumulative on all functions 

1062# (need to make a class counter in _CountCalls), I can probably implement them 

1063# by returning nan when the limit is surpassed, I hope the minimizer stops 

1064# immediately on nan (test this). => Callback can raise StopIteration. 

1065 

1066# TODO can I approximate the hessian with only function values and no gradient, 

1067# i.e., when using nelder-mead? => See Hare (2022), although I would not know 

1068# how to apply it properly to the optimization history. Somehow I need to keep 

1069# only the "last" iterations. 

1070 

1071# TODO is there a better algorithm than lbfgs for inaccurate functions? consider 

1072# SC-BFGS (https://github.com/frankecurtis/SCBFGS). See Basak (2022). And NonOpt 

1073# (https://github.com/frankecurtis/NonOpt). 

1074 

1075# TODO can I estimate the error on the likelihood with the matrices? It requires 

1076# the condition number. Basak (2022) gives wide bounds. I could try an upper 

1077# bound and see how it compares to the true error, assuming that the matrix was 

1078# as ill-conditioned as possible, i.e., use eps as the lowest eigenvalue, and 

1079# gershgorin as the highest one. 

1080 

1081# TODO look into jaxopt: it has improved a lot since the last time I saw it. In 

1082# particular, it implements l-bfgs and has a "do not stop on failed line search" 

1083# option. And it probably supports float32, although a skim of the docs suggests 

1084# it does not work well. => See also optimistix. 

1085 

1086# TODO reimplement the timing system with host_callback.id_tap. It should 

1087# preserve the order because id_tap takes inputs and outputs. I must take care 

1088# to make all callbacks happen at runtime instead of having some of them at 

1089# compile time. I tried once but failed. Currently host_callback is 

1090# experimental, maybe wait until it isn't. => I think it fails because it's 

1091# asynchronous and there is only one device. Maybe host_callback.call would 

1092# work? => I think they are developing something like my token machinery. 

1093 

1094# TODO dictionary argument jitkw, arguments passed to jax.jit? 

1095 

1096# TODO parameter float32: bool to use short float type. I think that scipy's 

1097# optimize may break down with short floats with default options, I hope that 

1098# changing termination tolerances does the trick. 

1099 

1100# TODO make separate_jac a parameter 

1101 

1102# TODO add options in _CountCalls to track inputs and/or outputs to some maximum 

1103# buffer length, activate it if the method (after applying user options, 

1104# lowercasing, and inferring minimize's default) is l-bfgs-b and the covariance 

1105# is minhess or auto, to the order specified in the arguments to l-bfgs-b (after 

1106# defaults inference if missing) (add tests in test_fit to check that the 

1107# defaults stay as inferred), to be used if l-bfgs-b returns a crooked hessian. 

1108# --- Alternative: if covariance = 'auto', it could be appropriate to use fisher 

1109# per definition. --- Alternative: add option covariance = 'lbfgs(<order>)' that 

1110# does this for any method, although this would require computing the gradients 

1111# afterwards if the gradient was not used. These alternatives are not mutually 

1112# exclusive. 

1113 

1114 # TODO make a helper function/class method that takes in a data transformation dependent 

1115# on hypers and outputs additional loss (the log jacobian of the appropriate 

1116# function with the appropriate sign)