# lsqfitgp/_linalg/_decomp.py
#
# Copyright (c) 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

20""" 

21 

22Copy-pasted from the notes: 

23 

242023-02-14 

25========== 

26 

27My current decomposition system is a mess. I can't take reverse gradients. 

28I can't straightforwardly implement optimized algorithms that compute together 

29likelihood, gradient, and fisher. Jax patterns break down unpredictably. I 

30have to redesign it from scratch. 

31 

32Guidelines and requirements: 

33 

34 - Sufficient modularity to implement composite decompositions (Woodbury, 

35 Block) 

36 

37 - Does not mess up jax in any way 

38 

39 - Caches decompositions 

40 

41 - Favors optimizing together the likelihood and its derivatives 

42 

43Operations (in the following I indicate with lowercase inputs which are 

44typically vectors or tall matrices, and uppercase inputs which are typically 

45large matrices, since optimization requires taking it into account): 

46 

47 pinv_bilinear(A, r) => A'K⁺r (for the posterior mean) 

48 pinv_bilinear_robj(A, r) same but r can be gvars 

49 ginv_quad(A) => A'K⁻A (for the posterior covariance) 

50 I want the pseudoinverse for the mean because the data may not be 

51 in the span and I want to project it orthogonally, while for the 

52 covariance I expect A and K to come from a pos def matrix so they are 

53 coherent 

54 ginv_diagquad(A) => diag(A'K⁻A) (for the posterior variance) 

55 minus_log_normal_density( 

56 r: 1d array, # the residuals (data - prior mean) 

57 dr_vjp: callable, # x -> x_i ∂r_i/∂p_j, gradrev and fishvec 

58 dK_vjp: callable, # x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec 

59 dr_jvp: callable, # x -> ∂r_i/∂r_j x_j, fishvec 

60 dK_jvp: callable, # x -> ∂K_ij/∂p_k x_k, fishvec 

61 dr: 2d array, # ∂r_i/∂p_j, gradfwd and fisher 

62 dK: 3d array, # ∂K_ij/∂p_k, gradfwd and fisher 

63 vec: 1d array, # input vector of fishvec, same size as params 

64 value: bool, 

65 gradrev: bool, 

66 gradfwd: bool, 

67 fisher: bool, 

68 fishvec: bool, 

69 ) 

70 This computes on request 

71 value: 1/2 tr(KK⁺) log 2π 

72 + 1/2 tr(I-KK⁺) log 2π 

73 + 1/2 log pdet K 

74 + 1/2 tr(I-KK⁺) log ε 

75 + 1/2 r'(K⁺+(I-KK⁺)/ε)r 

76 gradrev, 

77 gradfwd: 1/2 tr(K⁺dK) 

78 + r'(K⁺+(I-KK⁺)/ε) dr 

79 - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r 

80 fisher: 1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K) 

81 - 2 tr(K⁺dK(I-KK⁺)d'KK⁺) 

82 + dr'(K⁺+(I-KK⁺)/ε)d'r 

83 fishvec: fisher matrix times vec 

84 There should be options for omitting the pieces with ε. I also need a 

85 way to make densities with different values of ε comparable with each 

86 other (may not be possible, if it is, it probably requires a history of 

87 ranks and ε). gradfwd/rev form K⁺ explicitly to compute tr(K⁺dK) for 

88 efficiency. 

89 correlate(x) 

90 Zx where K = ZZ'. 

91 back_correlate(X): 

92 Z'X, this is used by Sandwich and Woodbury. 

93 

94Since I also want to compute the Student density, I could split 

95minus_log_normal_density's return value into logdet and quad. The gradient 

96splits nicely between the two terms, but I have to redo the calculation of the 

97Fisher matrix for the Student distribution. Alternatively, I could use the the 

98Normal Fisher. => See Lange et al. (1989, app. B). => I think I can split the 

99gradient and Fisher matrix too. 

100 

1012023-03-07 

102========== 

103 

104To compute a Fisher-vector product when there are many parameters, do 

105 

106 tr(K+ dK K+ dK) v = 

107 = K_vjp(K+ K_jvp(v) K+) 

108 

109""" 

# TODO to automate this further, I could take in a function that generates K
# (or its pieces) and the arguments to the function. But how would this play
# together with passing decomposition objects as pieces?

# TODO split this file by class

# TODO consider using lineax for implementing non-materialized decomps

import abc
import functools

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import linalg as jlinalg
from jax import lax

from .. import _jaxext
from . import _pytree
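
# A minimal sketch (not part of the module API) of the Fisher-vector identity
# in the 2023-03-07 note above: for a hypothetical covariance function
# make_K(p), parameters p, tangent vector v, and a precomputed inverse inv_K,
# the K part of the Fisher-vector product is one vjp wrapped around one jvp,
# without ever materializing the 3d array ∂K_ij/∂p_k.
def _fishvec_identity_sketch(make_K, p, v, inv_K):
    _, dK_vjp = jax.vjp(make_K, p)             # x -> x_ij ∂K_ij/∂p_k
    _, dK_jvp_v = jax.jvp(make_K, (p,), (v,))  # ∂K_ij/∂p_k v_k
    # 1/2 tr(K⁻¹ dK K⁻¹ d'K) v = 1/2 dK_vjp(K⁻¹ (dK·v) K⁻¹)
    return 1/2 * dK_vjp(inv_K @ dK_jvp_v @ inv_K)[0]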

class Decomposition(_pytree.AutoPyTree, abc.ABC):
    """
    Abstract base class for decompositions of positive semidefinite matrices.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kw):
        """ Decompose the input matrix """
        pass

    @abc.abstractmethod
    def matrix(self):
        """ The input matrix """
        pass

    @abc.abstractmethod
    def ginv_linear(self, X):
        """ Compute K⁻X """
        pass

    @abc.abstractmethod
    def pinv_bilinear(self, A, r):
        """Compute A'K⁺r."""
        pass

    @abc.abstractmethod
    def pinv_bilinear_robj(self, A, r):
        """Compute A'K⁺r, where r can be an array of objects."""
        pass

    @abc.abstractmethod
    def ginv_quad(self, A):
        """Compute A'K⁻A."""
        pass

    @abc.abstractmethod
    def ginv_diagquad(self, A):
        """Compute diag(A'K⁻A)."""
        pass

    @abc.abstractmethod
    def correlate(self, x):
        """ Compute Zx where K = ZZ' """
        pass

    @abc.abstractmethod
    def back_correlate(self, X):
        """ Compute Z'X """
        pass

    @abc.abstractmethod
    def pinv_correlate(self, x):
        """ Compute Z⁺x """
        pass

    @abc.abstractmethod
    def minus_log_normal_density(self,
        r,
        *,
        dr_vjp=None,
        dK_vjp=None,
        dr_jvp_vec=None,
        dK_jvp_vec=None,
        dr=None,
        dK=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Compute minus the log of a Normal density with covariance matrix K,
        and its derivatives.

        If an input derivative is not specified, it is assumed to be zero.

        Parameters
        ----------
        r : 1d array
            The residuals (value - mean).
        dr_vjp : callable
            x -> x_i ∂r_i/∂p_j, for gradrev and fishvec.
        dK_vjp : callable
            x -> x_ij ∂K_ij/∂p_k, for gradrev and fishvec.
        dr_jvp_vec : 1d array
            ∂r_i/∂p_j vec_j, for fishvec.
        dK_jvp_vec : 2d array
            ∂K_ij/∂p_k vec_k, for fishvec.
        dr : 2d array
            ∂r_i/∂p_j, for gradfwd and fisher.
        dK : 3d array
            ∂K_ij/∂p_k, for gradfwd and fisher.
        value, gradrev, gradfwd, fisher, fishvec : bool
            These parameters indicate which of the return values to compute.
            Default all False.

        Returns
        -------
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε)dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: the Fisher matrix @ vec
        """
        pass

    def _parseeps(self, K, epsrel, epsabs, maxeigv=None):
        """ Determine eps from input arguments """
        machine_eps = jnp.finfo(_jaxext.float_type(K)).eps
        if epsrel == 'auto':
            epsrel = len(K) * machine_eps
        if epsabs == 'auto':
            epsabs = machine_eps
        if maxeigv is None:
            maxeigv = eigval_bound(K)
        self._eps = epsrel * maxeigv + epsabs
        return self._eps

    @property
    def eps(self):
        """
        The threshold below which eigenvalues are too small to be determined.
        """
        return self._eps

    @property
    @abc.abstractmethod
    def n(self):
        """ Number of rows/columns of the matrix """
        pass

    @property
    @abc.abstractmethod
    def m(self):
        """ Number of columns of Z """
        pass

    def ginv(self):
        """ Compute K⁻ """
        return self.ginv_quad(jnp.eye(self.n))

def solve_triangular_python(a, b, *, lower=False):
    """
    Pure python implementation of scipy.linalg.solve_triangular for when
    a or b are object arrays.
    """
    # TODO maybe commit this to gvar.linalg
    a = numpy.asarray(a)
    x = numpy.copy(b)

    vec = x.ndim < 2
    if vec:
        x = x[:, None]

    n = a.shape[-1]
    assert x.shape[-2] == n

    if not lower:
        a = a[..., ::-1, ::-1]
        x = x[..., ::-1, :]

    x[..., 0, :] /= a[..., 0, 0, None]
    for i in range(1, n):
        x[..., i:, :] -= x[..., None, i - 1, :] * a[..., i:, i - 1, None]
        x[..., i, :] /= a[..., i, i, None]

    if not lower:
        x = x[..., ::-1, :]

    if vec:
        x = numpy.squeeze(x, -1)
    return x
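
# Worked example (illustrative): forward substitution on a 2×2 lower
# triangular system; b may contain gvars, in which case x does too.
#
#     a = numpy.array([[2., 0.], [1., 3.]])
#     b = numpy.array([4., 7.])
#     solve_triangular_python(a, b, lower=True)
#     # -> array([2., 1.66666667]), i.e. x₀ = 4/2, x₁ = (7 - 1·x₀)/3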

def solve_triangular_batched(a, b, *, lower=False):
    """ Version of jax.scipy.linalg.solve_triangular that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    batch_shape = jnp.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_shape = batch_shape + a.shape[-2:]
    b_shape = batch_shape + b.shape[-2:]
    result = lax.linalg.triangular_solve(
        jnp.broadcast_to(a, a_shape), jnp.broadcast_to(b, b_shape),
        left_side=True, lower=lower,
    )
    assert result.shape == b_shape

    if vec:
        result = result.squeeze(-1)
    return result
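
# For example (illustrative shapes): a of shape (5, 3, 3) with b of shape
# (3, 2) broadcasts to a batch of 5 systems solved in a single
# lax.linalg.triangular_solve call, returning shape (5, 3, 2).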

def solve_batched(a, b, **kw):
    """ Version of jax.scipy.linalg.solve that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    @functools.partial(jnp.vectorize, signature='(i,j),(j,k)->(i,k)')
    def solve_batched(a, b):
        return jlinalg.solve(a, b, **kw)
    result = solve_batched(a, b)

    if vec:
        result = result.squeeze(-1)
    return result
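
# Design note: jnp.vectorize with signature '(i,j),(j,k)->(i,k)' makes the
# dense solve broadcast over leading batch dimensions like matmul does,
# which jlinalg.solve alone does not guarantee for mismatched batch shapes.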

def eigval_bound(K):
    """
    Upper bound on the largest magnitude eigenvalue of the matrix, from
    Gershgorin's theorem.
    """
    return jnp.max(jnp.sum(jnp.abs(K), axis=1))
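
# Worked example (illustrative): for K = [[2, 1], [1, 3]] the absolute row
# sums are (3, 4), so the bound is 4, while the true largest eigenvalue is
# (5 + √5)/2 ≈ 3.62.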

def diag_scale_pow2(K):
    """
    Compute a vector s of powers of 2 such that diag(K / outer(s, s)) ~ 1.
    """
    d = jnp.diag(K)
    return jnp.where(d, jnp.exp2(jnp.rint(0.5 * jnp.log2(d))), 1)
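
# Worked example (illustrative): for diag(K) = (1024, 1/1024) this gives
# s = (32, 1/32), and K / outer(s, s) has unit diagonal; zero diagonal
# entries are left alone by mapping them to s = 1.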

def transpose(x):
    """ Swap the last two axes of array x; corresponds to matrix transposition
    with the broadcasting convention of matmul """
    if x.ndim < 2:
        return x
    elif isinstance(x, jnp.ndarray):
        return jnp.swapaxes(x, -2, -1)
    else:
        # need to support numpy because this function is used with gvars
        return numpy.swapaxes(x, -2, -1)

class Chol(Decomposition):
    """Cholesky decomposition. The matrix is regularized by adding a small
    multiple of the identity."""

    def __init__(self, K, *, epsrel='auto', epsabs=0):
        # K <- K + Iε
        # K = LL'
        self._K = K
        s = diag_scale_pow2(K)
        K = K / s / s[:, None]
        eps = self._parseeps(K, epsrel, epsabs)
        K = K.at[jnp.diag_indices_from(K)].add(eps)
        L = jlinalg.cholesky(K, lower=True)
        with _jaxext.skipifabstract():
            if not jnp.all(jnp.isfinite(L)):
                # TODO check that jax fills with nan after failed row, detect
                # and report minor index like scipy
                raise numpy.linalg.LinAlgError(
                    'cholesky decomposition not finite, '
                    'probably matrix not pos def numerically'
                )
        self._L = L * s[:, None]
        self._eps = eps * jnp.min(s * s)
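
    # Note on the scaling above (an explanatory comment, not load-bearing):
    # the Cholesky runs on the rescaled matrix K/ss', whose diagonal is ~1 by
    # construction, so one jitter ε suits the whole diagonal; e.g., for
    # K = diag(1024, 1/1024), s = (32, 1/32) and K/ss' is exactly the
    # identity. The scaling is undone by multiplying the rows of L by s, and
    # self._eps reports the regularization in units of the original K,
    # conservatively as ε·min(s²).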

    def matrix(self):
        return self._K

    def ginv_linear(self, X):
        # = K⁻¹X
        # K⁻¹ = L'⁻¹L⁻¹
        # K⁻¹X = L'⁻¹(L⁻¹X)
        invLX = jlinalg.solve_triangular(self._L, X, lower=True)
        return jlinalg.solve_triangular(self._L.T, invLX, lower=False)

    def pinv_bilinear(self, A, r):
        # = A'K⁻¹r = A'L'⁻¹L⁻¹r = (L⁻¹A)'(L⁻¹r)
        invLr = jlinalg.solve_triangular(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLr

    def pinv_bilinear_robj(self, A, r):
        # = A'K⁻¹r
        invLr = solve_triangular_python(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return numpy.asarray(invLA).T @ invLr

    def ginv_quad(self, A):
        # = A'K⁻¹A = A'L'⁻¹L⁻¹A = (L⁻¹A)'(L⁻¹A)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLA

    def ginv_diagquad(self, A):
        # = diag(A'K⁻¹A)
        # X = L⁻¹A
        # diag(A'K⁻¹A)_i = diag(X'X)_i = ∑_j X'_ij X_ji = ∑_j X_ji X_ji
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return jnp.einsum('ji,ji->i', invLA, invLA)

    def correlate(self, x):
        # = Lx
        return self._L @ x

    def back_correlate(self, X):
        # = L'X
        return self._L.T @ X

    def pinv_correlate(self, x):
        # = L⁻¹x
        return jlinalg.solve_triangular(self._L, x, lower=True)

    def minus_log_normal_density(self,
        r,                # 1d array, the residuals (data - prior mean)
        *,
        dr_vjp=None,      # callable, x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp=None,      # callable, x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp_vec=None,  # 1d array, ∂r_i/∂p_j v_j, fishvec
        dK_jvp_vec=None,  # 2d array, ∂K_ij/∂p_k v_k, fishvec
        dr=None,          # 2d array, ∂r_i/∂p_j, gradfwd and fisher
        dK=None,          # 3d array, ∂K_ij/∂p_k, gradfwd and fisher
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):

        L = self._L

        out = {}

        # compute shared factors
        grad = (
            (gradrev and (dK_vjp is not None or dr_vjp is not None))
            or (gradfwd and (dK is not None or dr is not None))
        )
        if value or grad:
            invLr = jlinalg.solve_triangular(L, r, lower=True)
        if grad:
            invKr = jlinalg.solve_triangular(L.T, invLr, lower=False)
        if (gradrev and dK_vjp is not None) or (gradfwd and dK is not None):
            invL = jlinalg.solve_triangular(L, jnp.eye(len(L)), lower=True)
            invK = invL.T @ invL

        if value:
            # = 1/2 n log 2π
            # + 1/2 log det K
            # + 1/2 r'K⁻¹r
            # K = LL'
            # K⁻¹ = L'⁻¹L⁻¹
            # det K = (det L)² =
            #       = (∏_i L_ii)²
            # r'K⁻¹r = r'L'⁻¹L⁻¹r =
            #        = (L⁻¹r)'(L⁻¹r)
            out['value'] = 1/2 * (
                len(L) * jnp.log(2 * jnp.pi) +
                2 * jnp.sum(jnp.log(jnp.diag(L))) +
                invLr @ invLr
            )
        else:
            out['value'] = None

        if gradrev:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK) = K⁻¹_ij dK_ji =
            #           = K⁻¹_ij dK_ij =
            #           = dK_vjp(K⁻¹)
            # r'K⁻¹dr = r_i K⁻¹_ij dr_j =
            #         = (K⁻¹r)_j dr_j =
            #         = dr_vjp(K⁻¹r)
            # r'K⁻¹dKK⁻¹r = r_i K⁻¹_ij dK_jl K⁻¹_lm r_m =
            #             = (K⁻¹r)_j dK_jl (K⁻¹r)_l =
            #             = dK_vjp((K⁻¹r) ⊗ (K⁻¹r))
            out['gradrev'] = 0
            if dK_vjp is not None:
                tr_invK_dK = dK_vjp(invK)
                r_invK_dK_invK_r = dK_vjp(jnp.outer(invKr, invKr))
                out['gradrev'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr_vjp is not None:
                r_invK_dr = dr_vjp(invKr)
                out['gradrev'] += r_invK_dr
        else:
            out['gradrev'] = None

        if gradfwd:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK)_k = K⁻¹_ij dK_ijk
            # r'K⁻¹dr = (K⁻¹r)'dr
            # (r'K⁻¹dKK⁻¹r)_k = r_i K⁻¹_ij dK_jlk K⁻¹_lm r_m =
            #                 = (K⁻¹r)_j dK_jlk (K⁻¹r)_l
            out['gradfwd'] = 0
            if dK is not None:
                tr_invK_dK = jnp.einsum('ij,ijk->k', invK, dK)
                r_invK_dK_invK_r = jnp.einsum('i,ijk,j->k', invKr, dK, invKr)
                out['gradfwd'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr is not None:
                r_invK_dr = invKr @ dr
                out['gradfwd'] += r_invK_dr
        else:
            out['gradfwd'] = None

        if fisher:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K)
            # + dr'K⁻¹d'r
            # tr(K⁻¹dKK⁻¹d'K)_ij = tr(L'⁻¹L⁻¹dKL'⁻¹L⁻¹d'K)_ij =
            #                    = tr(L⁻¹dKL'⁻¹L⁻¹d'KL'⁻¹)_ij =
            #                    = (L⁻¹dKL'⁻¹)_kli (L⁻¹dKL'⁻¹)_klj
            # (L⁻¹dKL'⁻¹)_ijk = L⁻¹_il dK_lmk L'⁻¹_mj =
            #                 = L⁻¹_il L⁻¹_jm dK_lmk
            # (dr'K⁻¹d'r)_kq = dr'_k L'⁻¹L⁻¹dr_q =
            #                = (L⁻¹dr_k)_i (L⁻¹dr_q)_i
            out['fisher'] = 0
            if dK is not None:
                invL_dK = solve_triangular_batched(L,
                    jnp.moveaxis(dK, 2, 0),
                    lower=True)  # kim: L⁻¹_il dK_lmk
                invL_dK_invL = solve_triangular_batched(L,
                    jnp.swapaxes(invL_dK, 1, 2),
                    lower=True)  # kji: L⁻¹_jm (L⁻¹_il dK_lmk)
                tr_invK_dK_invK_dK = jnp.einsum('kij,qij->kq', invL_dK_invL, invL_dK_invL)
                out['fisher'] += 1/2 * tr_invK_dK_invK_dK
            if dr is not None:
                invLdr = jlinalg.solve_triangular(L, dr, lower=True)
                dr_invK_dr = invLdr.T @ invLdr
                out['fisher'] += dr_invK_dr
        else:
            out['fisher'] = None

        if fishvec:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K) v
            # + dr'K⁻¹d'r v
            # tr(K⁻¹dKK⁻¹d'K) v = K_vjp(K⁻¹K_jvp(v)K⁻¹) =
            #                   = K_vjp(L'⁻¹L⁻¹ K_jvp(v) L'⁻¹L⁻¹)
            # dr'K⁻¹d'r v = dr'K⁻¹dr_jvp(v) =
            #             = dr_vjp(K⁻¹dr_jvp(v)) =
            #             = dr_vjp(L'⁻¹L⁻¹ dr_jvp(v))
            out['fishvec'] = 0
            if not (dK_jvp_vec is None and dK_vjp is None):
                invL_dKv = jlinalg.solve_triangular(L, dK_jvp_vec, lower=True)
                invK_dKv = jlinalg.solve_triangular(L.T, invL_dKv, lower=False)
                invL_dKv_invK = jlinalg.solve_triangular(L, invK_dKv.T, lower=True)
                invK_dKv_invK = jlinalg.solve_triangular(L.T, invL_dKv_invK, lower=False)
                tr_invK_dK_invK_dK_v = dK_vjp(invK_dKv_invK)
                out['fishvec'] += 1/2 * tr_invK_dK_invK_dK_v
            if not (dr_jvp_vec is None and dr_vjp is None):
                invL_drv = jlinalg.solve_triangular(L, dr_jvp_vec, lower=True)
                invK_drv = jlinalg.solve_triangular(L.T, invL_drv, lower=False)
                dr_invK_drv_v = dr_vjp(invK_drv)
                out['fishvec'] += dr_invK_drv_v
        else:
            out['fishvec'] = None

        return tuple(out.values())

    @classmethod
    def make_derivs(cls,
        K_fun, r_fun, primal,
        *,
        args=(),
        kw={},
        vec=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Prepares arguments for `minus_log_normal_density`.

        Parameters
        ----------
        K_fun, r_fun : callable
            Functions with signature ``f(primal, *args, **kw)`` that produce
            the `K` init argument and the `r` `minus_log_normal_density`
            argument.
        primal : 1d array
            The first argument to `K_fun` and `r_fun`.
        args : tuple
            Additional positional arguments to `K_fun` and `r_fun`.
        kw : dict
            Keyword arguments to `K_fun` and `r_fun`.
        vec : 1d array
            A tangent vector to compute the jacobian-vector products.
        value, gradrev, gradfwd, fisher, fishvec : bool
            Arguments to `minus_log_normal_density`, used to determine which
            derivatives are needed.

        Returns
        -------
        K : 2d array
            Output of `K_fun`.
        r : 1d array
            Output of `r_fun`.
        out : dict
            Dictionary with derivative arguments to `minus_log_normal_density`.
        """

        partial = lambda f: lambda x: f(x, *args, **kw)
        K_fun = partial(K_fun)
        r_fun = partial(r_fun)

        out = {}

        if gradrev or fishvec:
            K, dK_vjp = jax.vjp(K_fun, primal)
            r, dr_vjp = jax.vjp(r_fun, primal)
            out['dK_vjp'] = lambda x: dK_vjp(x)[0]
            out['dr_vjp'] = lambda x: dr_vjp(x)[0]
        else:
            K = K_fun(primal)
            r = r_fun(primal)
        if fishvec:
            _, out['dK_jvp_vec'] = jax.jvp(K_fun, (primal,), (vec,))
            _, out['dr_jvp_vec'] = jax.jvp(r_fun, (primal,), (vec,))
        if gradfwd or fisher:
            out['dK'] = jax.jacfwd(K_fun)(primal)
            out['dr'] = jax.jacfwd(r_fun)(primal)

        return K, r, out
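
    # Illustrative usage sketch (hypothetical K_fun/r_fun and data x, y, not
    # part of the tests): evaluate the density and its forward-mode gradient
    # at parameters p.
    #
    #     K_fun = lambda p: jnp.exp(-(x[:, None] - x[None, :]) ** 2 / p[0] ** 2)
    #     r_fun = lambda p: y - p[1]
    #     K, r, derivs = Chol.make_derivs(K_fun, r_fun, p,
    #         value=True, gradfwd=True)
    #     value, _, gradfwd, _, _ = Chol(K).minus_log_normal_density(r,
    #         **derivs, value=True, gradfwd=True)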

    @property
    def n(self):
        return len(self._L)

    m = n