# lsqfitgp/_fit.py
#
# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import re
import warnings
import functools
import time
import textwrap
import datetime

import gvar
import jax
from jax import numpy as jnp
import numpy
from scipy import optimize
from jax import tree_util

from . import _GP
from . import _linalg
from . import _jaxext
from . import _gvarext
from . import _array

# TODO the following token_* functionality may be provided by jax in the
# future; follow the developments

@functools.singledispatch
def token_getter(x):
    return x

@functools.singledispatch
def token_setter(x, token):
    return token

@token_getter.register(jnp.ndarray)
@token_getter.register(numpy.ndarray)
def _(x):
    return x[x.ndim * (0,)] if x.size else x

@token_setter.register(jnp.ndarray)
@token_setter.register(numpy.ndarray)
def _(x, token):
    x = jnp.asarray(x)
    return x.at[x.ndim * (0,)].set(token) if x.size else token

def token_map_leaf(func, x):
    if isinstance(x, (jnp.ndarray, numpy.ndarray)):
        token = token_getter(x)
        @jax.custom_jvp
        def jaxfunc(token):
            return jax.pure_callback(func, token, token, vectorized=True)
        @jaxfunc.defjvp
        def _(p, t):
            return (jaxfunc(*p), *t)
        token = jaxfunc(token)
        return token_setter(x, token)
    else:
        token = token_getter(x)
        token = func(token)
        return token_setter(x, token)

def token_map(func, x):
    return tree_util.tree_map(lambda x: token_map_leaf(func, x), x)
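# Illustrative note (not part of the original code): token_map threads a
# host-side callback through jax arrays by acting on a scalar "token" (the
# first array element) via jax.pure_callback, so the callback runs at
# execution time even under jax.jit and keeps its ordering through the data
# dependence. The _Timer class further below uses it roughly like this:
#
#     timer = empbayes_fit._Timer()
#     p = timer.start(p)        # the timing callback fires when p is computed
#     ...                       # expensive jax computation using p
#     out = timer.partial(out)  # records the time elapsed since start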

class Logger:
    """ Class to manage a log. Can be used as a superclass. Each line of the
    log has a verbosity level (an integer >= 0) and is printed only if this
    level is at or below a threshold. All lines are saved and the log can be
    retrieved. """

    def __init__(self, target_verbosity=0):
        """ set the threshold used to exclude log lines """
        self._verbosity = target_verbosity
        self._loggedlines = []

    def _indent(self, text, level=0):
        """ indent a text by the provided level plus the global current level """
        level = max(0, level + self.loglevel._level)
        prefix = 4 * level * ' '
        return textwrap.indent(text, prefix)

    def _select(self, verbosity, target_verbosity=None):
        if target_verbosity is None:
            target_verbosity = self._verbosity
        if isinstance(verbosity, int):
            return target_verbosity >= verbosity
        else:
            return target_verbosity in verbosity

    def log(self, message, verbosity=1, *, level=0):
        """
        Print and record a message.

        Parameters
        ----------
        message : str
            The message to print. A newline is added unconditionally.
        verbosity : int or set, default 1
            The verbosity level(s) at which the message is printed. If an
            integer, it's printed at all levels >= that integer. If a set, at
            the specified levels.
        level : int, default 0
            The indentation level of the message.
        """
        if self._select(verbosity):
            print(self._indent(message, level))
        self._loggedlines.append((message, verbosity, level + self.loglevel._level))

    def getlog(self, target_verbosity=None, *, base_level=0):
        """ return all logged lines as a single string """
        return '\n'.join(
            self._indent(message, base_level + level)
            for message, verbosity, level in self._loggedlines
            if self._select(verbosity, target_verbosity)
        )

    class _LogLevel:
        """ shared context manager to indent messages """

        _level = 0

        @classmethod
        def __enter__(cls):
            cls._level += 1

        @classmethod
        def __exit__(cls, *_):
            cls._level -= 1

    loglevel = _LogLevel()
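# Illustrative sketch (not part of the original code) of how the Logger
# interface above is used by empbayes_fit below:
#
#     log = Logger(target_verbosity=2)
#     log.log('shown, default verbosity 1 is within the threshold')  # printed
#     with log.loglevel:
#         log.log('indented detail, only at level 3', 3)  # recorded, not printed
#     text = log.getlog()  # recorded lines allowed by the verbosity, as a string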

class empbayes_fit(Logger):

    SEPARATE_JAC = False

    def __init__(
        self,
        hyperprior,
        gpfactory,
        data,
        *,
        raises=True,
        minkw={},
        gpfactorykw={},
        jit=True,
        method='gradient',
        initial='priormean',
        verbosity=0,
        covariance='auto',
        fix=None,
        mlkw={},
        forward=False,
        additional_loss=None,
    ):
        """

        Maximum a posteriori fit.

        Maximizes the marginal likelihood of the data with a Gaussian process
        model that depends on hyperparameters, multiplied by a prior on the
        hyperparameters.

        Parameters
        ----------
        hyperprior : scalar, array or dictionary of scalars/arrays
            A collection of gvars representing the prior for the
            hyperparameters.
        gpfactory : callable
            A function with signature gpfactory(hyperparams) -> GP object. The
            argument ``hyperparams`` has the same structure as the
            `empbayes_fit` argument ``hyperprior``. gpfactory must be
            JAX-friendly, i.e., use `jax.numpy` and `jax.scipy` instead of
            plain `numpy`/`scipy` and avoid assignments to arrays.
        data : dict, tuple or callable
            Dictionary of data that is passed to `GP.marginal_likelihood` on
            the GP object returned by ``gpfactory``. If a tuple, it contains
            the first two arguments to `GP.marginal_likelihood`. If a callable,
            it is called with the same arguments as ``gpfactory`` and must
            return the argument(s) for `GP.marginal_likelihood`.
        raises : bool, optional
            If True (default), raise an error when the minimization fails.
            Otherwise, use the last point of the minimization as result.
        minkw : dict, optional
            Keyword arguments passed to `scipy.optimize.minimize`, overriding
            the values specified by `empbayes_fit`.
        gpfactorykw : dict, optional
            Keyword arguments passed to ``gpfactory``, and also to ``data`` if
            it is a callable. If ``jit``, ``gpfactorykw`` crosses a `jax.jit`
            boundary, so it must contain objects understandable by `jax`.
        jit : bool
            If True (default), use `jax.jit` to compile the minimization
            target.
        method : str
            Minimization strategy. Options:

            'nograd'
                Use a gradient-free method.
            'gradient' (default)
                Use a gradient-only method.
            'fisher'
                Use a Newton method with the Fisher information matrix plus
                the hyperprior precision matrix.
        initial : str, scalar, array, dictionary of scalars/arrays
            Starting point for the minimization, matching the format of
            ``hyperprior``, or one of the following options:

            'priormean' (default)
                Start from the hyperprior mean.
            'priorsample'
                Take a random sample from the hyperprior.
        verbosity : int
            An integer indicating how much information is printed on the
            terminal:

            0 (default)
                No logging.
            1
                Minimal report.
            2
                Detailed report.
            3
                Log each iteration.
            4
                More detailed iteration log.
            5
                Print the current parameter values at each iteration.
        covariance : str
            Method to estimate the posterior covariance matrix of the
            hyperparameters:

            'fisher'
                Use the Fisher information at the MAP, plus the prior
                precision, as precision matrix.
            'minhess'
                Use the hessian estimate of the minimizer as precision matrix.
            'none'
                Do not estimate the covariance matrix.
            'auto' (default)
                ``'minhess'`` if applicable, ``'none'`` otherwise.
        fix : scalar, array or dictionary of scalars/arrays
            A set of booleans, with the same format as ``hyperprior``,
            indicating which hyperparameters are kept fixed to their initial
            value. Scalars and arrays are broadcast to the shape of
            ``hyperprior``. If a dictionary, missing keys are treated as False.
        mlkw : dict
            Additional arguments passed to `GP.marginal_likelihood`.
        forward : bool, default False
            Use forward instead of backward derivatives. Typically, forward is
            faster with a small number of parameters.
        additional_loss : callable, optional
            A function with signature ``additional_loss(hyperparams) -> float``
            which is added to the minus log marginal posterior of the
            hyperparameters.

        Attributes
        ----------
        p : scalar, array or dictionary of scalars/arrays
            A collection of gvars representing the hyperparameters that
            maximize their posterior. These gvars do not track correlations
            with the hyperprior or the data.
        prior : scalar, array or dictionary of scalars/arrays
            A copy of the hyperprior.
        initial : scalar, array or dictionary of scalars/arrays
            Starting point of the minimization, with the same format as ``p``.
        fix : scalar, array or dictionary of scalars/arrays
            A set of booleans, with the same format as ``p``, indicating which
            parameters were kept fixed to the values in ``initial``.
        pmean : scalar, array or dictionary of scalars/arrays
            Mean of ``p``.
        pcov : scalar, array or dictionary of scalars/arrays
            Covariance matrix of ``p``.
        minresult : scipy.optimize.OptimizeResult
            The result object returned by `scipy.optimize.minimize`.
        minargs : dict
            The arguments passed to `scipy.optimize.minimize`.
        gpfactory : callable
            The ``gpfactory`` argument.
        gpfactorykw : dict
            The ``gpfactorykw`` argument.
        data : dict, tuple or callable
            The ``data`` argument.

        Raises
        ------
        RuntimeError
            The minimization failed and ``raises`` is True.
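
        Examples
        --------
        A minimal usage sketch; the data, kernel, and prior below are
        illustrative placeholders (assuming the usual `lsqfitgp` public
        interface with `GP`, `ExpQuad` and `GP.addx`), not a recommendation::

            import numpy as np
            import gvar
            import lsqfitgp as lgp

            x = np.linspace(0, 10, 20)
            y = gvar.gvar(np.sin(x), np.full(x.size, 0.1))

            def gpfactory(hp):
                kernel = lgp.ExpQuad(scale=hp['scale'])
                return lgp.GP(kernel).addx(x, 'data')

            fit = lgp.empbayes_fit(
                hyperprior={'log(scale)': gvar.gvar(0, 1)},
                gpfactory=gpfactory,
                data={'data': y},
            )
            scale = fit.p['scale']  # posterior gvar of the hyperparameter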

        """

        Logger.__init__(self, verbosity)
        del verbosity
        self.log('**** call lsqfitgp.empbayes_fit ****')

        assert callable(gpfactory)

        # analyze the hyperprior
        hpinitial, hpunflat = self._parse_hyperprior(hyperprior, initial, fix)
        del hyperprior, initial, fix

        # analyze data
        data, cachedargs = self._parse_data(data)

        # define functions
        timer, functions = self._prepare_functions(
            gpfactory=gpfactory, gpfactorykw=gpfactorykw, data=data,
            cachedargs=cachedargs, hpunflat=hpunflat, mlkw=mlkw, jit=jit,
            forward=forward, additional_loss=additional_loss,
        )
        del gpfactory, gpfactorykw, data, cachedargs, mlkw, forward, additional_loss

        # prepare minimizer arguments
        minargs = self._prepare_minargs(method, functions, hpinitial)

        # set up callback to time and log iterations
        callback = self._Callback(self, functions, timer, hpunflat)
        minargs.update(callback=callback)

        # check invalid argument before running minimizer
        if covariance not in ('auto', 'fisher', 'minhess', 'none'):
            raise KeyError(covariance)

        # add user arguments and minimize
        minargs.update(minkw)
        self.log(f'minimizer method {minargs["method"]!r}', 2)
        total = time.perf_counter()
        result = optimize.minimize(**minargs)

        # check the minimization was successful
        self._check_success(result, raises)

        # compute posterior covariance of the hyperparameters
        cov = self._posterior_covariance(method, covariance, result, functions['fisher'])

        # log total timings and function calls
        total = time.perf_counter() - total
        self._log_totals(total, timer, callback, jit, functions)

        # join posterior mean and covariance matrix
        uresult = gvar.gvar(result.x, cov)

        # set attributes
        self.p = gvar.gvar(hpunflat(uresult))
        self.pmean = gvar.mean(self.p)
        self.pcov = gvar.evalcov(self.p)
        self.minresult = result
        self.minargs = minargs

        # tabulate hyperparameter prior and posterior
        if self._verbosity >= 2:
            self.log(_gvarext.tabulate_together(
                self.prior, self.p,
                headers=['param', 'prior', 'posterior'],
            )) # TODO replace tabulate_together with something more flexible I
            # can use for the callback as well. Maybe import TextMatrix from
            # miscpy.
            # TODO print the transformed parameters

        self.log('**** exit lsqfitgp.empbayes_fit ****')

    class _CountCalls:
        """ wrap a callable to count calls """

        def __init__(self, func):
            self._func = func
            self._total = 0
            self._partial = 0
            functools.update_wrapper(self, func)

        def __call__(self, *args, **kw):
            self._total += 1
            self._partial += 1
            return self._func(*args, **kw)

        def partial(self):
            """ return the partial counter and reset it """
            result = self._partial
            self._partial = 0
            return result

        def total(self):
            """ return the total number of calls """
            return self._total

        @staticmethod
        def fmtcalls(method, functions):
            """
            format summary of number of calls
            method : str
            functions : dict[str, _CountCalls]
            """
            def counts():
                for name, func in functions.items():
                    if count := getattr(func, method)():
                        yield f'{name} {count}'
            return ', '.join(counts())

    class _Timer:
        """ object to time likelihood computations """

        def __init__(self):
            self.totals = {}
            self.partials = {}
            self._last_start = False

        def start(self, token):
            return token_map(self._start, token)

        def _start(self, token):
            self.stamp = time.perf_counter()
            self.counter = 0
            assert not self._last_start  # forbid consecutive start() calls
            self._last_start = True
            return token

        def reset(self):
            self.partials = {}

        def partial(self, token):
            return token_map(self._partial, token)

        def _partial(self, token):
            now = time.perf_counter()
            delta = now - self.stamp
            self.partials[self.counter] = self.partials.get(self.counter, 0) + delta
            self.totals[self.counter] = self.totals.get(self.counter, 0) + delta
            self.stamp = now
            self.counter += 1
            self._last_start = False
            return token

    def _parse_hyperprior(self, hyperprior, initial, fix):

        # check fix against hyperprior and fill missing values
        hyperprior = self._copyasarrayorbufferdict(hyperprior)
        self._check_no_redundant_keys(hyperprior)
        fix = self._parse_fix(hyperprior, fix)
        flatfix = self._flatview(fix)

        # extract distribution of free hyperparameters
        flathp = self._flatview(hyperprior)
        freehp = flathp[~flatfix]
        mean = gvar.mean(freehp)
        cov = gvar.evalcov(freehp) # TODO use evalcov_blocks
        dec = _linalg.Chol(cov)
        assert dec.n == freehp.size
        self.log(f'{freehp.size}/{flathp.size} free hyperparameters', 2)

        # determine starting point for minimization
        initial = self._parse_initial(hyperprior, initial, dec)
        flatinitial = self._flatview(initial)
        x0 = dec.pinv_correlate(flatinitial[~flatfix] - mean)
        # TODO for initial = 'priormean', x0 is zero, skip decorrelate
        # for initial = 'priorsample', x0 is iid normal, but I have to sync
        # it with the user-exposed unflattened initial in _parse_initial

        # make function to correlate, add fixed values, and reshape to original
        # format
        fixed_indices, = jnp.nonzero(flatfix)
        unfixed_indices, = jnp.nonzero(~flatfix)
        fixed_values = jnp.asarray(flatinitial[flatfix])
        def unflat(x):
            assert x.ndim == 1
            if x.dtype == object:
                jac, indices = _gvarext.jacobian(x)
                xmean = mean + dec.correlate(gvar.mean(x))
                xjac = dec.correlate(jac)
                x = _gvarext.from_jacobian(xmean, xjac, indices)
                y = numpy.empty(flatfix.size, x.dtype)
                numpy.put(y, unfixed_indices, x)
                numpy.put(y, fixed_indices, fixed_values)
            else:
                x = mean + dec.correlate(x)
                y = jnp.empty(flatfix.size, x.dtype)
                y = y.at[unfixed_indices].set(x)
                y = y.at[fixed_indices].set(fixed_values)
            return self._unflatview(y, hyperprior)

        self.prior = hyperprior
        return x0, unflat

    @staticmethod
    def _check_no_redundant_keys(hyperprior):
        if not hasattr(hyperprior, 'keys'):
            return
        for k in hyperprior:
            m = hyperprior.extension_pattern.match(k)
            if m and m.group(1) in hyperprior.invfcn:
                altk = m.group(2)
                if altk in hyperprior:
                    raise ValueError(f'duplicate keys {altk!r} and {k!r} in hyperprior')

    def _parse_fix(self, hyperprior, fix):

        if fix is None:
            if hasattr(hyperprior, 'keys'):
                fix = gvar.BufferDict(hyperprior, buf=numpy.zeros(hyperprior.size, bool))
            else:
                fix = numpy.zeros(hyperprior.shape, bool)
        else:
            fix = self._copyasarrayorbufferdict(fix)
            if hasattr(fix, 'keys'):
                assert hasattr(hyperprior, 'keys'), 'fix is dictionary but hyperprior is array'
                assert all(hyperprior.has_dictkey(k) for k in fix), 'some keys in fix are missing in hyperprior'
                newfix = {}
                for k, v in hyperprior.items():
                    key = None
                    m = hyperprior.extension_pattern.match(k)
                    if m and m.group(1) in hyperprior.invfcn:
                        altk = m.group(2)
                        if altk in fix:
                            assert k not in fix, f'duplicate keys {k!r} and {altk!r} in fix'
                            key = altk
                    if key is None and k in fix:
                        key = k
                    if key is None:
                        elem = numpy.zeros(v.shape, bool)
                    else:
                        elem = numpy.broadcast_to(fix[key], v.shape)
                    newfix[k] = elem
                fix = gvar.BufferDict(newfix, dtype=bool)
            else:
                assert not hasattr(hyperprior, 'keys'), 'fix is array but hyperprior is dictionary'
                fix = numpy.broadcast_to(fix, hyperprior.shape).astype(bool)

        self.fix = fix
        return fix

    def _parse_initial(self, hyperprior, initial, dec):

        if not isinstance(initial, str):
            self.log('start from provided point', 2)
            initial = self._copyasarrayorbufferdict(initial)
            if hasattr(hyperprior, 'keys'):
                assert hasattr(initial, 'keys'), 'hyperprior is dictionary but initial is array'
                assert set(hyperprior.keys()) == set(initial.keys())
                assert all(hyperprior[k].shape == initial[k].shape for k in hyperprior)
            else:
                assert not hasattr(initial, 'keys'), 'hyperprior is array but initial is dictionary'
                assert hyperprior.shape == initial.shape

        elif initial == 'priormean':
            self.log('start from prior mean', 2)
            initial = gvar.mean(hyperprior)

        elif initial == 'priorsample':
            self.log('start from a random sample from the prior', 2)
            if dec.n < hyperprior.size:
                flathp = self._flatview(hyperprior)
                cov = gvar.evalcov(flathp) # TODO use evalcov_blocks
                fulldec = _linalg.Chol(cov)
            else:
                fulldec = dec
            iid = numpy.random.randn(fulldec.m)
            flatinitial = numpy.asarray(fulldec.correlate(iid))
            initial = self._unflatview(flatinitial, hyperprior)

        else:
            raise KeyError(initial)

        self.initial = initial
        return initial

    def _parse_data(self, data):

        self.data = data
        if isinstance(data, tuple) and len(data) == 1:
            data, = data

        if callable(data):
            self.log('data is callable', 2)
            cachedargs = None
        elif isinstance(data, tuple):
            self.log('data errors provided separately', 2)
            assert len(data) == 2
            cachedargs = data
        elif (gdata := self._copyasarrayorbufferdict(data)).dtype == object:
            self.log('data has errors as gvars', 2)
            data = gvar.gvar(gdata)
            # convert to gvar because non-gvars in the array would upset
            # gvar.mean and gvar.evalcov
            cachedargs = (gvar.mean(data), gvar.evalcov(data))
        else:
            self.log('data has no errors', 2)
            cachedargs = (data,)

        return data, cachedargs

    def _prepare_functions(self, *, gpfactory, gpfactorykw, data, cachedargs,
        hpunflat, mlkw, jit, forward, additional_loss):

        timer = self._Timer()
        firstcall = [None]

        def make_decomp(p, **kw):
            """ decomposition of the prior covariance and data """

            # start timer and convert hypers to user format
            p = timer.start(p)
            hp = hpunflat(p)

            # create GP object
            gp = gpfactory(hp, **kw)
            assert isinstance(gp, _GP.GP)

            # extract data
            if cachedargs:
                args = cachedargs
            else:
                args = data(hp, **kw)
                if not isinstance(args, tuple):
                    args = (args,)

            # decompose covariance matrix and flatten data
            decomp, r = gp._prior_decomp(*args, covtransf=timer.partial, **mlkw)
            r = r.astype(float) # int data upsets jax

            # log number of datapoints
            if firstcall:
                # it is convenient to do here because the data is flattened.
                # works under jit since the first call is tracing
                firstcall.pop()
                xdtype = gp._get_x_dtype()
                nd = '?' if xdtype is None else _array._nd(xdtype)
                self.log(f'{r.size} datapoints, {nd} covariates')

            # compute user loss
            if additional_loss is None:
                loss = 0.
            else:
                loss = additional_loss(hp)

            # split timer and return decomposition
            return timer.partial(decomp), r, loss
            # TODO what's the correct way of checkpointing r?

        # define wrapper to collect call stats, pass user args, compile
        def wrap(func):
            if jit:
                func = jax.jit(func)
            func = functools.partial(func, **gpfactorykw)
            return self._CountCalls(func)
        if jit:
            self.log('compile functions with jax jit', 2)

        # log derivation method
        modename = 'forward' if forward else 'reverse'
        self.log(f'{modename}-mode autodiff (if used)', 2)

        # TODO time the derivatives separately => maybe I need a custom
        # derivative rule for timer token acknowledgement?

        def prior(p):
            # the marginal prior of the hyperparameters is a Normal with
            # identity covariance matrix because p is transformed to make it so
            return 1/2 * (len(p) * jnp.log(2 * jnp.pi) + p @ p)

        def grad_prior(p):
            return p

        def fisher_prior(p):
            return jnp.eye(len(p))
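        # Note (added for clarity): the minimizer works on the whitened
        # coordinates p defined in _parse_hyperprior, where the free flat
        # hyperparameters are mean + L @ p with L a Cholesky factor of the
        # prior covariance, so the prior on p is Normal(0, I):
        # -log N(p; 0, I) = 1/2 (n log 2pi + p @ p), its gradient is p, and
        # its Fisher/Hessian matrix is the identity, as implemented above.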

        @wrap
        def fun(p, **kw):
            """ minus log marginal posterior of the hyperparameters (not
            normalized) """
            decomp, r, loss = make_decomp(p, **kw)
            cond, _, _, _, _ = decomp.minus_log_normal_density(r, value=True)
            post = cond + prior(p) + loss
            # TODO what's the correct way of checkpointing prior and loss?
            return timer.partial(post)

        def make_gradfwd_fisher_args(p, **kw):
            def make_decomp_tee(p):
                decomp, r, loss = make_decomp(p, **kw)
                return (decomp.matrix(), r, loss), (decomp, r, loss)
            (dK, dr, grad_loss), (decomp, r, loss) = jax.jacfwd(make_decomp_tee, has_aux=True)(p)
            lkw = dict(dK=dK, dr=dr)
            return decomp, r, lkw, loss, grad_loss
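        # Note (added for clarity): the reverse-mode variant below builds VJP
        # closures for the covariance matrix (dK_vjp) and the flattened data
        # (dr_vjp) out of a single evaluation of make_decomp, together with the
        # reverse-mode gradient of the additional loss; the decomposition then
        # uses these pullbacks instead of the full Jacobians dK and dr used by
        # the forward-mode helper above.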

        def make_gradrev_args(p, **kw):
            def make_decomp_loss(p):
                def make_decomp_r(p):
                    def make_decomp_K(p):
                        decomp, r, loss = make_decomp(p, **kw)
                        return decomp.matrix(), (decomp, r, loss)
                    _, dK_vjp, (decomp, r, loss) = jax.vjp(make_decomp_K, p, has_aux=True)
                    return r, (decomp, r, dK_vjp, loss)
                _, dr_vjp, (decomp, r, dK_vjp, loss) = jax.vjp(make_decomp_r, p, has_aux=True)
                return loss, (decomp, r, dK_vjp, dr_vjp, loss)
            grad_loss, (decomp, r, dK_vjp, dr_vjp, loss) = jax.grad(make_decomp_loss, has_aux=True)(p)
            unpack = lambda f: lambda x: f(x)[0]
            dK_vjp = unpack(dK_vjp)
            dr_vjp = unpack(dr_vjp)
            lkw = dict(dK_vjp=dK_vjp, dr_vjp=dr_vjp)
            return decomp, r, lkw, loss, grad_loss

        def make_jac_args(p, **kw):
            if forward:
                out = make_gradfwd_fisher_args(p, **kw)
                out[2].update(gradfwd=True) # out[2] is lkw
            else:
                out = make_gradrev_args(p, **kw)
                out[2].update(gradrev=True)
            return out

        @wrap
        def fun_and_jac(p, **kw):
            """ fun and its gradient """
            decomp, r, lkw, loss, grad_loss = make_jac_args(p, **kw)
            cond, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, value=True, **lkw)
            post = cond + prior(p) + loss
            grad_cond = gradfwd if forward else gradrev
            grad_post = grad_cond + grad_prior(p) + grad_loss
            return timer.partial((post, grad_post))

        @wrap
        def jac(p, **kw):
            """ gradient of fun """
            decomp, r, lkw, _, grad_loss = make_jac_args(p, **kw)
            _, gradrev, gradfwd, _, _ = decomp.minus_log_normal_density(r, **lkw)
            grad_cond = gradfwd if forward else gradrev
            grad_post = grad_cond + grad_prior(p) + grad_loss
            return timer.partial(grad_post)

        @wrap
        def fisher(p, **kw):
            """ fisher matrix """
            if additional_loss is not None:
                raise NotImplementedError(
                    'Fisher matrix not implemented with additional_loss. It '
                    'is possible but I did not prioritize it. If you need it, '
                    'open an issue on github.')
            decomp, r, lkw, _, _ = make_gradfwd_fisher_args(p, **kw)
            _, _, _, fisher_cond, _ = decomp.minus_log_normal_density(r, fisher=True, **lkw)
            fisher_post = fisher_cond + fisher_prior(p)
            return timer.partial(fisher_post)

        # set attributes
        self.gpfactory = gpfactory
        self.gpfactorykw = gpfactorykw

        return timer, {
            'fun': fun,
            'jac': jac,
            'fun&jac': fun_and_jac,
            'fisher': fisher,
        }

    def _prepare_minargs(self, method, functions, hpinitial):
        minargs = dict(fun=functions['fun&jac'], jac=True, x0=hpinitial)
        if self.SEPARATE_JAC:
            minargs.update(fun=functions['fun'], jac=functions['jac'])
        if method == 'nograd':
            minargs.update(fun=functions['fun'], jac=None, method='nelder-mead')
        elif method == 'gradient':
            minargs.update(method='bfgs')
        elif method == 'fisher':
            minargs.update(hess=functions['fisher'], method='dogleg')
            # dogleg requires positive definiteness, fisher is p.s.d.
            # trust-constr has more options, but it seems to be slower than
            # dogleg, so I keep dogleg as default
        else:
            raise KeyError(method)
        self.log(f'method {method!r}', 2)
        return minargs

    # TODO add method with fisher matvec instead of fisher matrix

    def _log_totals(self, total, timer, callback, jit, functions):
        times = {
            'gp&cov': timer.totals[0],
            'decomp': timer.totals[1],
            'likelihood': timer.totals[2],
            'jit': None, # set now and delete later to keep it before 'other'
            'other': total - sum(timer.totals.values()),
        }
        if jit:
            overhead = callback.estimate_firstcall_overhead()
            # TODO this estimation ignores the jit compilation of the function
            # used to compute the precision matrix, to be precise I should
            # manually split the jit into compilation + evaluation or hook into
            # it somehow. Maybe the jit object keeps a compilation wall time
            # stat?
        if jit and overhead is not None:
            times['jit'] = overhead
            times['other'] -= overhead
        else:
            del times['jit']
        self.log('', 4)
        calls = self._CountCalls.fmtcalls('total', functions)
        self.log(f'calls: {calls}')
        self.log(f'total time: {callback.fmttime(total)}')
        self.log(f'partials: {callback.fmttimes(times)}', 2)

    def _check_success(self, result, raises):
        if result.success:
            self.log(f'minimization succeeded: {result.message}')
        else:
            msg = f'minimization failed: {result.message}'
            if raises:
                raise RuntimeError(msg)
            elif self._verbosity == 0:
                warnings.warn(msg)
            else:
                self.log(msg)

    def _posterior_covariance(self, method, covariance, minimizer_result, fisher_func):

        if covariance == 'auto':
            if hasattr(minimizer_result, 'hess_inv') or hasattr(minimizer_result, 'hess'):
                covariance = 'minhess'
            else:
                covariance = 'none'

        if covariance == 'fisher':
            self.log('use fisher plus prior precision as precision', 2)
            if method == 'fisher':
                prec = minimizer_result.hess
            else:
                prec = fisher_func(minimizer_result.x)
            cov = _linalg.Chol(prec).ginv()

        elif covariance == 'minhess':
            if hasattr(minimizer_result, 'hess_inv'):
                hessinv = minimizer_result.hess_inv
                if isinstance(hessinv, optimize.LbfgsInvHessProduct):
                    self.log(f'convert LBFGS({hessinv.n_corrs}) hessian inverse to BFGS as covariance', 2)
                    cov = self._invhess_lbfgs_to_bfgs(hessinv)
                    # TODO this still gives a too wide cov when the minimization
                    # terminates due to bad line search, is it because of
                    # dropped updates? This is currently keeping me from setting
                    # l-bfgs-b as default minimization method.
                elif isinstance(hessinv, numpy.ndarray):
                    self.log('use minimizer estimate of inverse hessian as covariance', 2)
                    cov = hessinv
            elif hasattr(minimizer_result, 'hess'):
                self.log('use minimizer hessian as precision', 2)
                cov = _linalg.Chol(minimizer_result.hess).ginv()
            else:
                raise RuntimeError('the minimizer did not return an estimate of the hessian')

        elif covariance == 'none':
            cov = numpy.full(minimizer_result.x.size, numpy.nan)

        else:
            raise KeyError(covariance)

        return cov

    @staticmethod
    def _invhess_lbfgs_to_bfgs(lbfgs):
        bfgs = optimize.BFGS()
        bfgs.initialize(lbfgs.shape[0], 'inv_hess')
        for i in range(lbfgs.n_corrs):
            bfgs.update(lbfgs.sk[i], lbfgs.yk[i])
        return bfgs.get_matrix()
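    # Note (added for clarity): the conversion above replays the (sk, yk)
    # update pairs stored by L-BFGS into scipy's dense BFGS updater, starting
    # from the identity matrix, so the result is only an approximation of the
    # inverse Hessian that a full-memory BFGS run would have produced.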

    class _Callback:
        """ Iteration callback for scipy.optimize.minimize """

        def __init__(self, this, functions, timer, unflat):
            self.it = 0
            self.stamp = time.perf_counter()
            self.this = this
            self.functions = functions
            self.timer = timer
            self.unflat = unflat
            self.tail_overhead = 0
            self.tail_overhead_iter = 0

        def __call__(self, intermediate_result, arg2=None):

            if isinstance(intermediate_result, optimize.OptimizeResult):
                p = intermediate_result.x
            elif isinstance(intermediate_result, numpy.ndarray):
                p = intermediate_result
            else:
                raise TypeError(type(intermediate_result))

            self.it += 1
            now = time.perf_counter()
            duration = now - self.stamp

            worktime = sum(self.timer.partials.values())
            if worktime:
                overhead = duration - worktime
                assert overhead >= 0, (duration, worktime)
                if self.it == 1:
                    self.first_overhead = overhead
                else:
                    self.tail_overhead_iter += 1
                    self.tail_overhead += overhead

            # level 3 log
            calls = self.this._CountCalls.fmtcalls('partial', self.functions)
            times = self.fmttime(duration)
            self.this.log(f'iter {self.it}, time: {times}, calls: {calls}', {3})

            # level 4 log
            tot = self.fmttime(duration)
            if self.timer.partials:
                times = {
                    'gp&cov': self.timer.partials[0],
                    'dec': self.timer.partials[1],
                    'like': self.timer.partials[2],
                    'other': duration - sum(self.timer.partials.values()),
                }
                times = self.fmttimes(times)
            else:
                times = 'n/d'
            self.this.log(f'\niteration {self.it}', 4)
            with self.this.loglevel:
                self.this.log(f'total time: {tot}', 4)
                self.this.log(f'partial: {times}', 4)
                self.this.log(f'calls: {calls}', 4)

            # level 5 log
            nicep = self.unflat(p)
            nicep = self.this._copyasarrayorbufferdict(nicep)
            with self.this.loglevel:
                self.this.log(f'parameters = {nicep}', 5)
            # TODO write a method to format the parameters nicely. => use
            # gvar.tabulate? => nope, need actual gvars
            # TODO does this logging add significant overhead?

            self.stamp = now
            self.timer.reset()

        pattern = re.compile(
            r'((\d+) days, )?(\d{1,2}):(\d\d):(\d\d(\.\d{6})?)')

        @classmethod
        def fmttime(cls, seconds):
            if seconds < 0:
                prefix = '-'
                seconds = -seconds
            else:
                prefix = ''
            return prefix + cls._fmttime_positive(seconds)

        @classmethod
        def _fmttime_positive(cls, seconds):
            td = datetime.timedelta(seconds=seconds)
            m = cls.pattern.fullmatch(str(td))
            _, day, hour, minute, second, _ = m.groups()
            hour = int(hour)
            minute = int(minute)
            second = float(second)
            if day:
                return f'{day.lstrip("0")}d{hour:02d}h'
            elif hour:
                return f'{hour}h{minute:02d}m'
            elif minute:
                return f'{minute}m{second:02.0f}s'
            elif second >= 0.0995:
                return f'{second:#.2g}'.rstrip('.') + 's'
            elif second >= 0.0000995:
                return f'{second * 1e3:#.2g}'.rstrip('.') + 'ms'
            else:
                return f'{second * 1e6:.0f}μs'
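        # Illustrative outputs (added for clarity, not exhaustive):
        # fmttime(75) -> '1m15s', fmttime(0.00123) -> '1.2ms',
        # fmttime(3.5e-6) -> '4μs'.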

        @classmethod
        def fmttimes(cls, times):
            """ times = dict label -> seconds """
            return ', '.join(f'{k} {cls.fmttime(v)}' for k, v in times.items())

        def estimate_firstcall_overhead(self):
            if self.tail_overhead_iter and hasattr(self, 'first_overhead'):
                typical_overhead = self.tail_overhead / self.tail_overhead_iter
                return self.first_overhead - typical_overhead

    @staticmethod
    def _copyasarrayorbufferdict(x):
        if hasattr(x, 'keys'):
            return gvar.BufferDict(x)
        else:
            return numpy.array(x)

    @staticmethod
    def _flatview(x):
        if hasattr(x, 'reshape'):
            return x.reshape(-1)
        elif hasattr(x, 'buf'):
            return x.buf
        else: # pragma: no cover
            raise NotImplementedError

    @staticmethod
    def _unflatview(x, original):
        if isinstance(original, numpy.ndarray):
            # TODO is this never applied to jax arrays?
            out = x.reshape(original.shape)
            # if not out.shape:
            #     try:
            #         out = out.item()
            #     except jax.errors.ConcretizationTypeError:
            #         pass
            return out
        elif isinstance(original, gvar.BufferDict):
            # normally I would do BufferDict(original, buf=x) but it does not
            # work with JAX tracers
            b = gvar.BufferDict(original)
            b._extension = {}
            b._buf = x
            # b.buf = x does not work because BufferDict checks that the
            # array is a numpy array
            # TODO maybe make a feature request to gvar to accept array_like
            # buf
            return b
        else: # pragma: no cover
            raise NotImplementedError

# TODO would it be meaningful to add correlation of the fit result with the data
# and hyperprior?

# TODO add the second order correction. It probably requires more than the
# gradient and inv_hess, but maybe by getting a little help from
# marginal_likelihood I can use the least-squares optimized second order
# correction on the residuals term and invent something for the logdet term.

# TODO it raises very often with "Desired error not necessarily achieved due to
# precision loss.". I tried doing a forward grad on the logdet but does not fix
# the problem. I still suspect it's the logdet, maybe the value itself and not
# the derivative, because as the matrix changes the regularization can change a
# lot the value of the logdet. How do I stabilize it? => scipy's l-bfgs-b seems
# to fail the line search less often

# TODO compute the logGBF for the whole fit (see the gpbart code). In its doc,
# specify that 1) additional_loss may break the normalization if the user does
# not know what they are doing 2) the calculation of the log determinant term
# heavily depends on the regularization if the covariance matrix is singular;
# this won't happen if there are independent error terms in the model as usual.

# TODO empbayes_fit(autoeps=True) tries to double epsabs until the minimization
# succeeds, with some maximum number of tries. autoeps=dict(maxrepeat=5,
# increasefactor=2, initial=1e-16, startfromzero=True) allows configuring the
# algorithm.

# TODO empbayes_fit(maxiter=100) sets the maximum number of minimization
# iterations. maxiter=dict(iter=100, calls=200, callsperiter=10) allows
# configuring it more finely. The calls limits are cumulative on all functions
# (need to make a class counter in _CountCalls), I can probably implement them
# by returning nan when the limit is surpassed, I hope the minimizer stops
# immediately on nan (test this). => Callback can raise StopIteration.

# TODO can I approximate the hessian with only function values and no gradient,
# i.e., when using nelder-mead? => See Hare (2022), although I would not know
# how to apply it properly to the optimization history. Somehow I need to keep
# only the "last" iterations.

# TODO is there a better algorithm than lbfgs for inaccurate functions? consider
# SC-BFGS (https://github.com/frankecurtis/SCBFGS). See Basak (2022). And NonOpt
# (https://github.com/frankecurtis/NonOpt).

# TODO can I estimate the error on the likelihood with the matrices? It requires
# the condition number. Basak (2022) gives wide bounds. I could try an upper
# bound and see how it compares to the true error, assuming that the matrix was
# as ill-conditioned as possible, i.e., use eps as the lowest eigenvalue, and
# gershgorin as the highest one.

# TODO look into jaxopt: it has improved a lot since the last time I saw it. In
# particular, it implements l-bfgs and has a "do not stop on failed line search"
# option. And it probably supports float32, although a skim of the docs suggests
# it does not work well. => See also optimistix.

# TODO reimplement the timing system with host_callback.id_tap. It should
# preserve the order because id_tap takes inputs and outputs. I must take care
# to make all callbacks happen at runtime instead of having some of them at
# compile time. I tried once but failed. Currently host_callback is
# experimental, maybe wait until it isn't. => I think it fails because it's
# asynchronous and there is only one device. Maybe host_callback.call would
# work? => I think they are developing something like my token machinery.

# TODO dictionary argument jitkw, arguments passed to jax.jit?

# TODO parameter float32: bool to use short float type. I think that scipy's
# optimize may break down with short floats with default options, I hope that
# changing termination tolerances does the trick.

# TODO make separate_jac a parameter

# TODO add options in _CountCalls to track inputs and/or outputs to some maximum
# buffer length, activate it if the method (after applying user options,
# lowercasing, and inferring minimize's default) is l-bfgs-b and the covariance
# is minhess or auto, to the order specified in the arguments to l-bfgs-b (after
# defaults inference if missing) (add tests in test_fit to check that the
# defaults stay as inferred), to be used if l-bfgs-b returns a crooked hessian.
# --- Alternative: if covariance = 'auto', it could be appropriate to use fisher
# per definition. --- Alternative: add option covariance = 'lbfgs(<order>)' that
# does this for any method, although this would require computing the gradients
# afterwards if the gradient was not used. These alternatives are not mutually
# exclusive.

# TODO make a helper function/class method that takes in data transf dependent
# on hypers and outputs additional loss (the log jacobian of the appropriate
# function with the appropriate sign)