# lsqfitgp/_linalg/_decomp.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

"""

Copy-pasted from the notes:

2023-02-14
==========

My current decomposition system is a mess. I can't take reverse gradients.
I can't straightforwardly implement optimized algorithms that compute together
the likelihood, gradient, and Fisher matrix. Jax patterns break down
unpredictably. I have to redesign it from scratch.

Guidelines and requirements:

 - Sufficient modularity to implement composite decompositions (Woodbury,
   Block)

 - Does not mess up jax in any way

 - Caches decompositions

 - Favors optimizing together the likelihood and its derivatives

Operations (in the following I indicate with lowercase the inputs which are
typically vectors or tall matrices, and with uppercase the inputs which are
typically large matrices, since optimization requires taking this into
account):

    pinv_bilinear(A, r) => A'K⁺r (for the posterior mean)
    pinv_bilinear_robj(A, r) same but r can be gvars
    ginv_quad(A) => A'K⁻A (for the posterior covariance)
        I want the pseudoinverse for the mean because the data may not be
        in the span and I want to project it orthogonally, while for the
        covariance I expect A and K to come from a pos def matrix so they
        are coherent
    ginv_diagquad(A) => diag(A'K⁻A) (for the posterior variance)
    minus_log_normal_density(
        r: 1d array,        # the residuals (data - prior mean)
        dr_vjp: callable,   # x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp: callable,   # x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp: callable,   # x -> ∂r_i/∂p_j x_j, fishvec
        dK_jvp: callable,   # x -> ∂K_ij/∂p_k x_k, fishvec
        dr: 2d array,       # ∂r_i/∂p_j, gradfwd and fisher
        dK: 3d array,       # ∂K_ij/∂p_k, gradfwd and fisher
        vec: 1d array,      # input vector of fishvec, same size as params
        value: bool,
        gradrev: bool,
        gradfwd: bool,
        fisher: bool,
        fishvec: bool,
    )
        This computes on request
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε) dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: Fisher matrix times vec
        There should be options for omitting the pieces with ε. I also need
        a way to make densities with different values of ε comparable with
        each other (it may not be possible; if it is, it probably requires
        a history of ranks and ε). gradfwd/rev form K⁺ explicitly to compute
        tr(K⁺dK) for efficiency.
    correlate(x)
        Zx where K = ZZ'.
    back_correlate(X)
        Z'X, this is used by Sandwich and Woodbury.

Since I also want to compute the Student density, I could split
minus_log_normal_density's return value into logdet and quad. The gradient
splits nicely between the two terms, but I have to redo the calculation of
the Fisher matrix for the Student distribution. Alternatively, I could use
the Normal Fisher. => See Lange et al. (1989, app. B). => I think I can
split the gradient and Fisher matrix too.

2023-03-07
==========

To compute a Fisher-vector product when there are many parameters, do

    tr(K⁺ dK K⁺ dK) v =
    = K_vjp(K⁺ K_jvp(v) K⁺)

"""

# TODO to automate this further, I could take in a function that generates K
# (or its pieces) and the arguments to the function. But how would this play
# together with passing decomposition objects as pieces?

# TODO split this file by class

# TODO Consider using lineax for implementing non-materialized decomps

import abc
import functools

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import linalg as jlinalg
from jax import lax

from .. import _jaxext
from . import _pytree
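
# A minimal sketch of the Fisher-vector product identity from the 2023-03-07
# note above, assuming a symmetric covariance function K_fun of the
# parameters; this helper is illustrative and not part of the module API.
def _fishvec_identity_sketch(K_fun, p, v):
    # dK_jvp(v) = ∂K_ij/∂p_k v_k, computed without materializing the 3d
    # derivatives array dK
    K, dKv = jax.jvp(K_fun, (p,), (v,))
    # build K⁻¹ dK_jvp(v) K⁻¹; the transpose uses the symmetry of K and dKv
    invK_dKv = jnp.linalg.solve(K, dKv)
    invK_dKv_invK = jnp.linalg.solve(K, invK_dKv.T)
    # contract back onto the parameters: dK_vjp(X)_k = X_ij ∂K_ij/∂p_k
    _, dK_vjp = jax.vjp(K_fun, p)
    return 1/2 * dK_vjp(invK_dKv_invK)[0]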

class Decomposition(_pytree.AutoPyTree, abc.ABC):
    """
    Abstract base class for decompositions of positive semidefinite matrices.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kw):
        """ Decompose the input matrix """
        pass

    @abc.abstractmethod
    def matrix(self):
        """ The input matrix """
        pass

    @abc.abstractmethod
    def ginv_linear(self, X):
        """ Compute K⁻X """
        pass

    @abc.abstractmethod
    def pinv_bilinear(self, A, r):
        """ Compute A'K⁺r """
        pass

    @abc.abstractmethod
    def pinv_bilinear_robj(self, A, r):
        """ Compute A'K⁺r, where r can be an array of objects """
        pass

    @abc.abstractmethod
    def ginv_quad(self, A):
        """ Compute A'K⁻A """
        pass

    @abc.abstractmethod
    def ginv_diagquad(self, A):
        """ Compute diag(A'K⁻A) """
        pass

    @abc.abstractmethod
    def correlate(self, x):
        """ Compute Zx where K = ZZ' """
        pass

    @abc.abstractmethod
    def back_correlate(self, X):
        """ Compute Z'X """
        pass

    @abc.abstractmethod
    def pinv_correlate(self, x):
        """ Compute Z⁺x """
        pass

    @abc.abstractmethod
    def minus_log_normal_density(self,
        r,
        *,
        dr_vjp=None,
        dK_vjp=None,
        dr_jvp_vec=None,
        dK_jvp_vec=None,
        dr=None,
        dK=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Compute minus the log of a Normal density and its derivatives, with
        covariance matrix K.

        If an input derivative is not specified, it is assumed to be zero.

        Parameters
        ----------
        r : 1d array
            The residuals (value - mean)
        dr_vjp : callable
            x -> x_i ∂r_i/∂p_j, for gradrev and fishvec
        dK_vjp : callable
            x -> x_ij ∂K_ij/∂p_k, for gradrev and fishvec
        dr_jvp_vec : 1d array
            ∂r_i/∂p_j vec_j, for fishvec
        dK_jvp_vec : 2d array
            ∂K_ij/∂p_k vec_k, for fishvec
        dr : 2d array
            ∂r_i/∂p_j, for gradfwd and fisher
        dK : 3d array
            ∂K_ij/∂p_k, for gradfwd and fisher
        value, gradrev, gradfwd, fisher, fishvec : bool
            These parameters indicate which of the return values to compute.
            Default all False.

        Returns
        -------
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε) dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: Fisher matrix @ vec
        """
        pass

    def _parseeps(self, K, epsrel, epsabs, maxeigv=None):
        """ Determine eps from input arguments """
        machine_eps = jnp.finfo(_jaxext.float_type(K)).eps
        if epsrel == 'auto':
            epsrel = len(K) * machine_eps
        if epsabs == 'auto':
            epsabs = machine_eps
        if maxeigv is None:
            maxeigv = eigval_bound(K)
        self._eps = epsrel * maxeigv + epsabs
        return self._eps

    @property
    def eps(self):
        """
        The threshold below which eigenvalues are too small to be determined.
        """
        return self._eps

    @property
    @abc.abstractmethod
    def n(self):
        """ Number of rows/columns of the matrix """
        pass

    @property
    @abc.abstractmethod
    def m(self):
        """ Number of columns of Z """
        pass

    def ginv(self):
        """ Compute K⁻ """
        return self.ginv_quad(jnp.eye(self.n))

def solve_triangular_python(a, b, *, lower=False):
    """
    Pure python implementation of scipy.linalg.solve_triangular for when
    a or b are object arrays.
    """
    # TODO maybe commit this to gvar.linalg
    a = numpy.asarray(a)
    x = numpy.copy(b)

    vec = x.ndim < 2
    if vec:
        x = x[:, None]

    n = a.shape[-1]
    assert x.shape[-2] == n

    if not lower:
        a = a[..., ::-1, ::-1]
        x = x[..., ::-1, :]

    x[..., 0, :] /= a[..., 0, 0, None]
    for i in range(1, n):
        x[..., i:, :] -= x[..., None, i - 1, :] * a[..., i:, i - 1, None]
        x[..., i, :] /= a[..., i, i, None]

    if not lower:
        x = x[..., ::-1, :]

    if vec:
        x = numpy.squeeze(x, -1)
    return x
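
# A quick float check of the forward substitution above against the exact
# system; illustrative only (the object-array path works the same way).
def _solve_triangular_python_check():
    L = numpy.array([[2., 0., 0.], [4., 6., 0.], [7., 8., 10.]])
    b = numpy.array([1., 2., 3.])
    x = solve_triangular_python(L, b, lower=True)
    numpy.testing.assert_allclose(L @ x, b)
    return x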

def solve_triangular_batched(a, b, *, lower=False):
    """ Version of jax.scipy.linalg.solve_triangular that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    batch_shape = jnp.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_shape = batch_shape + a.shape[-2:]
    b_shape = batch_shape + b.shape[-2:]
    result = lax.linalg.triangular_solve(
        jnp.broadcast_to(a, a_shape), jnp.broadcast_to(b, b_shape),
        left_side=True, lower=lower,
    )
    assert result.shape == b_shape

    if vec:
        result = result.squeeze(-1)
    return result

def solve_batched(a, b, **kw):
    """ Version of jax.scipy.linalg.solve that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    @functools.partial(jnp.vectorize, signature='(i,j),(j,k)->(i,k)')
    def solve_batched(a, b):
        return jlinalg.solve(a, b, **kw)
    result = solve_batched(a, b)

    if vec:
        result = result.squeeze(-1)
    return result
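
# A small sketch of what "batches matmul-like" means: the leading axes of a
# and b broadcast against each other like in jnp.matmul; the shapes here are
# made up for illustration.
def _batched_solve_example():
    keya, keyb = jax.random.split(jax.random.PRNGKey(0))
    # batch shapes (5, 1) and (2,) broadcast to (5, 2)
    a = jnp.tril(jax.random.normal(keya, (5, 1, 3, 3))) + 5 * jnp.eye(3)
    b = jax.random.normal(keyb, (2, 3, 4))
    x = solve_triangular_batched(a, b, lower=True)
    assert x.shape == (5, 2, 3, 4)
    return x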

def eigval_bound(K):
    """
    Upper bound on the largest magnitude eigenvalue of the matrix, from
    Gershgorin's theorem.
    """
    return jnp.max(jnp.sum(jnp.abs(K), axis=1))
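
# A quick numeric check of the Gershgorin bound on a random positive
# semidefinite matrix; illustrative only.
def _eigval_bound_check():
    A = jax.random.normal(jax.random.PRNGKey(0), (6, 6))
    K = A @ A.T  # positive semidefinite
    assert jnp.max(jnp.linalg.eigvalsh(K)) <= eigval_bound(K)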

def diag_scale_pow2(K):
    """
    Compute a vector s of powers of 2 such that diag(K / outer(s, s)) ~ 1.
    """
    d = jnp.diag(K)
    return jnp.where(d, jnp.exp2(jnp.rint(0.5 * jnp.log2(d))), 1)

    # Golub and Van Loan (2013) say this is not a totally general heuristic
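
# A small illustration of the power-of-2 rescaling: after dividing by
# outer(s, s) the diagonal of K is O(1); the values are made up.
def _diag_scale_example():
    K = jnp.diag(jnp.array([1e-6, 1., 1e6]))
    s = diag_scale_pow2(K)
    return jnp.diag(K / s / s[:, None])  # roughly [1.05, 1., 0.95]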

def transpose(x):
    """ swap the last two axes of array x, corresponds to matrix transposition
    with the broadcasting convention of matmul """
    if x.ndim < 2:
        return x
    elif isinstance(x, jnp.ndarray):
        return jnp.swapaxes(x, -2, -1)
    else:
        # need to support numpy because this function is used with gvars
        return numpy.swapaxes(x, -2, -1)

class Chol(Decomposition):
    """Cholesky decomposition. The matrix is regularized by adding a small
    multiple of the identity."""

    def __init__(self, K, *, epsrel='auto', epsabs=0):
        # K <- K + Iε
        # K = LL'
        self._K = K
        s = diag_scale_pow2(K)
        K = K / s / s[:, None]
        eps = self._parseeps(K, epsrel, epsabs)
        K = K.at[jnp.diag_indices_from(K)].add(eps)
        L = jlinalg.cholesky(K, lower=True)
        with _jaxext.skipifabstract():
            if not jnp.all(jnp.isfinite(L)):
                # TODO check that jax fills with nan after failed row, detect
                # and report minor index like scipy
                raise numpy.linalg.LinAlgError('cholesky decomposition not '
                    'finite, probably matrix not pos def numerically')
        self._L = L * s[:, None]
        self._eps = eps * jnp.min(s * s)

    def matrix(self):
        return self._K

    def ginv_linear(self, X):
        # = K⁻¹X
        # K⁻¹ = L'⁻¹L⁻¹
        # K⁻¹X = L'⁻¹(L⁻¹X)
        invLX = jlinalg.solve_triangular(self._L, X, lower=True)
        return jlinalg.solve_triangular(self._L.T, invLX, lower=False)

    def pinv_bilinear(self, A, r):
        # = A'K⁻¹r = A'L'⁻¹L⁻¹r = (L⁻¹A)'(L⁻¹r)
        invLr = jlinalg.solve_triangular(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLr

    def pinv_bilinear_robj(self, A, r):
        # = A'K⁻¹r
        invLr = solve_triangular_python(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return numpy.asarray(invLA).T @ invLr

    def ginv_quad(self, A):
        # = A'K⁻¹A = A'L'⁻¹L⁻¹A = (L⁻¹A)'(L⁻¹A)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLA

    def ginv_diagquad(self, A):
        # = diag(A'K⁻¹A)
        # X = L⁻¹A
        # diag(A'K⁻¹A)_i = diag(X'X)_i = ∑_j X'_ij X_ji = ∑_j X_ji X_ji
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return jnp.einsum('ji,ji->i', invLA, invLA)

    def correlate(self, x):
        # = Lx
        return self._L @ x

    def back_correlate(self, X):
        # = L'X
        return self._L.T @ X

    def pinv_correlate(self, x):
        # = L⁻¹x
        return jlinalg.solve_triangular(self._L, x, lower=True)

    def minus_log_normal_density(self,
        r,                # 1d array, the residuals (data - prior mean)
        *,
        dr_vjp=None,      # callable, x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp=None,      # callable, x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp_vec=None,  # 1d array, ∂r_i/∂p_j v_j, fishvec
        dK_jvp_vec=None,  # 2d array, ∂K_ij/∂p_k v_k, fishvec
        dr=None,          # 2d array, ∂r_i/∂p_j, gradfwd and fisher
        dK=None,          # 3d array, ∂K_ij/∂p_k, gradfwd and fisher
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):

        L = self._L

        out = {}

        # compute shared factors
        grad = (
            (gradrev and (dK_vjp is not None or dr_vjp is not None))
            or (gradfwd and (dK is not None or dr is not None))
        )
        if value or grad:
            invLr = jlinalg.solve_triangular(L, r, lower=True)
        if grad:
            invKr = jlinalg.solve_triangular(L.T, invLr, lower=False)
        if (gradrev and dK_vjp is not None) or (gradfwd and dK is not None):
            invL = jlinalg.solve_triangular(L, jnp.eye(len(L)), lower=True)
            invK = invL.T @ invL

        if value:
            # = 1/2 n log 2π
            # + 1/2 log det K
            # + 1/2 r'K⁻¹r
            # K = LL'
            # K⁻¹ = L'⁻¹L⁻¹
            # det K = (det L)² =
            #       = (∏_i L_ii)²
            # r'K⁻¹r = r'L'⁻¹L⁻¹r =
            #        = (L⁻¹r)'(L⁻¹r)
            out['value'] = 1/2 * (
                len(L) * jnp.log(2 * jnp.pi) +
                2 * jnp.sum(jnp.log(jnp.diag(L))) +
                invLr @ invLr
            )
        else:
            out['value'] = None

        if gradrev:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK) = K⁻¹_ij dK_ji =
            #           = K⁻¹_ij dK_ij =
            #           = dK_vjp(K⁻¹)
            # r'K⁻¹dr = r_i K⁻¹_ij dr_j =
            #         = (K⁻¹r)_j dr_j =
            #         = dr_vjp(K⁻¹r)
            # r'K⁻¹dKK⁻¹r = r_i K⁻¹_ij dK_jl K⁻¹_lm r_m =
            #             = (K⁻¹r)_j dK_jl (K⁻¹r)_l =
            #             = dK_vjp((K⁻¹r) ⊗ (K⁻¹r))
            out['gradrev'] = 0
            if dK_vjp is not None:
                tr_invK_dK = dK_vjp(invK)
                r_invK_dK_invK_r = dK_vjp(jnp.outer(invKr, invKr))
                out['gradrev'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr_vjp is not None:
                r_invK_dr = dr_vjp(invKr)
                out['gradrev'] += r_invK_dr
        else:
            out['gradrev'] = None

        if gradfwd:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK)_k = K⁻¹_ij dK_ijk
            # r'K⁻¹dr = (K⁻¹r)'dr
            # (r'K⁻¹dKK⁻¹r)_k = r_i K⁻¹_ij dK_jlk K⁻¹_lm r_m =
            #                 = (K⁻¹r)_j dK_jlk (K⁻¹r)_l
            out['gradfwd'] = 0
            if dK is not None:
                tr_invK_dK = jnp.einsum('ij,ijk->k', invK, dK)
                r_invK_dK_invK_r = jnp.einsum('i,ijk,j->k', invKr, dK, invKr)
                out['gradfwd'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr is not None:
                r_invK_dr = invKr @ dr
                out['gradfwd'] += r_invK_dr
        else:
            out['gradfwd'] = None

        if fisher:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K)
            # + dr'K⁻¹d'r
            # tr(K⁻¹dKK⁻¹d'K)_ij = tr(L'⁻¹L⁻¹dKL'⁻¹L⁻¹d'K)_ij =
            #                    = tr(L⁻¹dKL'⁻¹L⁻¹d'KL'⁻¹)_ij =
            #                    = (L⁻¹dKL'⁻¹)_kli (L⁻¹dKL'⁻¹)_klj
            # (L⁻¹dKL'⁻¹)_ijk = L⁻¹_il dK_lmk L'⁻¹_mj =
            #                 = L⁻¹_il L⁻¹_jm dK_lmk
            # (dr'K⁻¹d'r)_kq = dr'_k L'⁻¹L⁻¹dr_q =
            #                = (L⁻¹dr_k)_i (L⁻¹dr_q)_i
            out['fisher'] = 0
            if dK is not None:
                invL_dK = solve_triangular_batched(L,
                    jnp.moveaxis(dK, 2, 0),
                    lower=True)  # kim: L⁻¹_il dK_lmk
                invL_dK_invL = solve_triangular_batched(L,
                    jnp.swapaxes(invL_dK, 1, 2),
                    lower=True)  # kji: L⁻¹_jm (L⁻¹_il dK_lmk)
                tr_invK_dK_invK_dK = jnp.einsum('kij,qij->kq',
                    invL_dK_invL, invL_dK_invL)
                out['fisher'] += 1/2 * tr_invK_dK_invK_dK
            if dr is not None:
                invLdr = jlinalg.solve_triangular(L, dr, lower=True)
                dr_invK_dr = invLdr.T @ invLdr
                out['fisher'] += dr_invK_dr
        else:
            out['fisher'] = None

        if fishvec:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K) v
            # + dr'K⁻¹d'r v
            # tr(K⁻¹dKK⁻¹d'K) v = K_vjp(K⁻¹K_jvp(v)K⁻¹) =
            #                   = K_vjp(L'⁻¹L⁻¹ K_jvp(v) L'⁻¹L⁻¹)
            # dr'K⁻¹d'r v = dr'K⁻¹dr_jvp(v) =
            #             = dr_vjp(K⁻¹dr_jvp(v)) =
            #             = dr_vjp(L'⁻¹L⁻¹ dr_jvp(v))
            out['fishvec'] = 0
            if not (dK_jvp_vec is None and dK_vjp is None):
                invL_dKv = jlinalg.solve_triangular(L, dK_jvp_vec, lower=True)
                invK_dKv = jlinalg.solve_triangular(L.T, invL_dKv, lower=False)
                invL_dKv_invK = jlinalg.solve_triangular(L, invK_dKv.T, lower=True)
                invK_dKv_invK = jlinalg.solve_triangular(L.T, invL_dKv_invK, lower=False)
                tr_invK_dK_invK_dK_v = dK_vjp(invK_dKv_invK)
                out['fishvec'] += 1/2 * tr_invK_dK_invK_dK_v
            if not (dr_jvp_vec is None and dr_vjp is None):
                invL_drv = jlinalg.solve_triangular(L, dr_jvp_vec, lower=True)
                invK_drv = jlinalg.solve_triangular(L.T, invL_drv, lower=False)
                dr_invK_drv_v = dr_vjp(invK_drv)
                out['fishvec'] += dr_invK_drv_v
        else:
            out['fishvec'] = None

        return tuple(out.values())  # TODO a namedtuple

    @classmethod
    def make_derivs(cls,
        K_fun, r_fun, primal,
        *,
        args=(),
        kw={},
        vec=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Prepares arguments for `minus_log_normal_density`.

        Parameters
        ----------
        K_fun, r_fun : callable
            Functions with signature ``f(primal, *args, **kw)`` that produce
            the `K` init argument and the `r` `minus_log_normal_density`
            argument.
        primal : 1d array
            The first argument to `K_fun` and `r_fun`.
        args : tuple
            Additional positional arguments to `K_fun` and `r_fun`.
        kw : dict
            Keyword arguments to `K_fun` and `r_fun`.
        vec : 1d array
            A tangent vector to compute the jacobian-vector products.
        value, gradrev, gradfwd, fisher, fishvec : bool
            Arguments to `minus_log_normal_density`, used to determine which
            derivatives are needed.

        Returns
        -------
        K : 2d array
            Output of `K_fun`.
        r : 1d array
            Output of `r_fun`.
        out : dict
            Dictionary with derivative arguments to `minus_log_normal_density`.
        """

        partial = lambda f: lambda x: f(x, *args, **kw)
        K_fun = partial(K_fun)
        r_fun = partial(r_fun)

        out = {}

        if gradrev or fishvec:
            K, dK_vjp = jax.vjp(K_fun, primal)
            r, dr_vjp = jax.vjp(r_fun, primal)
            out['dK_vjp'] = lambda x: dK_vjp(x)[0]
            out['dr_vjp'] = lambda x: dr_vjp(x)[0]
        else:
            K = K_fun(primal)
            r = r_fun(primal)
        if fishvec:
            _, out['dK_jvp_vec'] = jax.jvp(K_fun, (primal,), (vec,))
            _, out['dr_jvp_vec'] = jax.jvp(r_fun, (primal,), (vec,))
        if gradfwd or fisher:
            out['dK'] = jax.jacfwd(K_fun)(primal)
            out['dr'] = jax.jacfwd(r_fun)(primal)

        return K, r, out

    @property
    def n(self):
        return len(self._L)

    m = n
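
# A minimal end-to-end sketch of the intended workflow, tying together Chol,
# make_derivs, and minus_log_normal_density; the toy kernel, data, and
# parameter values are made up for illustration.
def _chol_usage_sketch():
    x = jnp.arange(6.)
    data = jnp.sin(x)

    def K_fun(p):
        # squared exponential kernel with length scale p[0]
        return jnp.exp(-(x[:, None] - x[None, :]) ** 2 / p[0] ** 2)

    def r_fun(p):
        # residuals: data minus a constant prior mean p[1]
        return data - p[1]

    primal = jnp.array([1.5, 0.1])
    K, r, derivs = Chol.make_derivs(
        K_fun, r_fun, primal, value=True, gradfwd=True, fisher=True,
    )
    decomp = Chol(K)
    value, _, gradfwd, fisher, _ = decomp.minus_log_normal_density(
        r, **derivs, value=True, gradfwd=True, fisher=True,
    )
    # the same decomposition provides the quadratic forms for the posterior;
    # here with A = K, i.e., evaluated at the observed points
    mean_term = decomp.pinv_bilinear(K, r)  # A'K⁺r
    cov_term = decomp.ginv_quad(K)          # A'K⁻A
    return value, gradfwd, fisher, mean_term, cov_term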