# lsqfitgp/_linalg/_decomp.py
#
# Copyright (c) 2023, 2024, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

"""
Copy-pasted from the notes:

2023-02-14
==========

My current decomposition system is a mess. I can't take reverse gradients.
I can't straightforwardly implement optimized algorithms that compute together
likelihood, gradient, and Fisher matrix. Jax patterns break down
unpredictably. I have to redesign it from scratch.

Guidelines and requirements:

  - Sufficient modularity to implement composite decompositions (Woodbury,
    Block)
  - Does not mess up jax in any way
  - Caches decompositions
  - Favors optimizing together the likelihood and its derivatives

Operations (in the following I indicate with lowercase the inputs which are
typically vectors or tall matrices, and with uppercase the inputs which are
typically large matrices, since optimization requires taking this into
account):

    pinv_bilinear(A, r) => A'K⁺r (for the posterior mean)
    pinv_bilinear_robj(A, r) same but r can be gvars
    ginv_quad(A) => A'K⁻A (for the posterior covariance)
        I want the pseudoinverse for the mean because the data may not be
        in the span and I want to project it orthogonally, while for the
        covariance I expect A and K to come from a pos def matrix so they
        are coherent
    ginv_diagquad(A) => diag(A'K⁻A) (for the posterior variance)
    minus_log_normal_density(
        r: 1d array,       # the residuals (data - prior mean)
        dr_vjp: callable,  # x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp: callable,  # x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp: callable,  # x -> ∂r_i/∂p_j x_j, fishvec
        dK_jvp: callable,  # x -> ∂K_ij/∂p_k x_k, fishvec
        dr: 2d array,      # ∂r_i/∂p_j, gradfwd and fisher
        dK: 3d array,      # ∂K_ij/∂p_k, gradfwd and fisher
        vec: 1d array,     # input vector of fishvec, same size as params
        value: bool,
        gradrev: bool,
        gradfwd: bool,
        fisher: bool,
        fishvec: bool,
    )
        This computes on request
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε)dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: Fisher matrix times vec
        There should be options for omitting the pieces with ε. I also need
        a way to make densities with different values of ε comparable with
        each other (may not be possible; if it is, it probably requires a
        history of ranks and ε). gradfwd/rev form K⁺ explicitly to compute
        tr(K⁺dK) for efficiency.
    correlate(x)
        Zx where K = ZZ'.
    back_correlate(X)
        Z'X, this is used by Sandwich and Woodbury.

Since I also want to compute the Student density, I could split
minus_log_normal_density's return value into logdet and quad. The gradient
splits nicely between the two terms, but I have to redo the calculation of
the Fisher matrix for the Student distribution. Alternatively, I could use
the Normal Fisher. => See Lange et al. (1989, app. B). => I think I can
split the gradient and Fisher matrix too.

2023-03-07
==========

To compute a Fisher-vector product when there are many parameters, do

    tr(K⁺ dK K⁺ dK) v = K_vjp(K⁺ K_jvp(v) K⁺)

"""

# TODO to automate this further, I could take in a function that generates K
# (or its pieces) and the arguments to the function. But how would this play
# together with passing decomposition objects as pieces?

# TODO split this file by class

# TODO consider using lineax for implementing non-materialized decomps

import abc
import functools

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import linalg as jlinalg
from jax import lax

from .. import _jaxext
from . import _pytree
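
# A minimal sketch of the Fisher-vector identity in the 2023-03-07 note of
# the module docstring: tr(K⁺ dK K⁺ dK) v assembled from one jvp and one vjp
# of the map p -> K(p), without materializing dK. `_fishvec_identity_sketch`
# and `K_fun` are illustrative names, not part of this module; K_fun(p) is
# assumed to return a symmetric positive definite matrix.
def _fishvec_identity_sketch(K_fun, p, v):
    K, K_vjp = jax.vjp(K_fun, p)             # K(p) and x -> x_ij ∂K_ij/∂p_k
    _, K_jvp_v = jax.jvp(K_fun, (p,), (v,))  # M = ∂K_ij/∂p_k v_k
    invK_M = jnp.linalg.solve(K, K_jvp_v)         # K⁻¹M
    invK_M_invK = jnp.linalg.solve(K, invK_M.T)   # K⁻¹M'K⁻¹ = K⁻¹MK⁻¹ (M symmetric)
    return K_vjp(invK_M_invK)[0]             # (Fv)_q = tr(K⁻¹ ∂K/∂p_q K⁻¹ M)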

class Decomposition(_pytree.AutoPyTree, abc.ABC):
    """
    Abstract base class for decompositions of positive semidefinite matrices.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kw):
        """ Decompose the input matrix """
        pass

    @abc.abstractmethod
    def matrix(self):
        """ The input matrix """
        pass

    @abc.abstractmethod
    def ginv_linear(self, X):
        """ Compute K⁻X """
        pass

    @abc.abstractmethod
    def pinv_bilinear(self, A, r):
        """ Compute A'K⁺r """
        pass

    @abc.abstractmethod
    def pinv_bilinear_robj(self, A, r):
        """ Compute A'K⁺r, where r can be an array of objects """
        pass

    @abc.abstractmethod
    def ginv_quad(self, A):
        """ Compute A'K⁻A """
        pass

    @abc.abstractmethod
    def ginv_diagquad(self, A):
        """ Compute diag(A'K⁻A) """
        pass

    @abc.abstractmethod
    def correlate(self, x):
        """ Compute Zx where K = ZZ' """
        pass

    @abc.abstractmethod
    def back_correlate(self, X):
        """ Compute Z'X """
        pass

    @abc.abstractmethod
    def pinv_correlate(self, x):
        """ Compute Z⁺x """
        pass

    @abc.abstractmethod
    def minus_log_normal_density(self,
        r,
        *,
        dr_vjp=None,
        dK_vjp=None,
        dr_jvp_vec=None,
        dK_jvp_vec=None,
        dr=None,
        dK=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Compute minus the log of a Normal density with covariance matrix K,
        and its derivatives.

        If an input derivative is not specified, it is assumed to be zero.

        Parameters
        ----------
        r : 1d array
            The residuals (value - mean).
        dr_vjp : callable
            x -> x_i ∂r_i/∂p_j, for gradrev and fishvec.
        dK_vjp : callable
            x -> x_ij ∂K_ij/∂p_k, for gradrev and fishvec.
        dr_jvp_vec : 1d array
            ∂r_i/∂p_j vec_j, for fishvec.
        dK_jvp_vec : 2d array
            ∂K_ij/∂p_k vec_k, for fishvec.
        dr : 2d array
            ∂r_i/∂p_j, for gradfwd and fisher.
        dK : 3d array
            ∂K_ij/∂p_k, for gradfwd and fisher.
        value, gradrev, gradfwd, fisher, fishvec : bool
            These parameters indicate which of the return values to compute.
            Default all False.

        Returns
        -------
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε)dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: Fisher matrix @ vec
        """
        pass

    def _parseeps(self, K, epsrel, epsabs, maxeigv=None):
        """ Determine eps from input arguments """
        machine_eps = jnp.finfo(_jaxext.float_type(K)).eps
        if epsrel == 'auto':
            epsrel = len(K) * machine_eps
        if epsabs == 'auto':
            epsabs = machine_eps
        if maxeigv is None:
            maxeigv = eigval_bound(K)
        self._eps = epsrel * maxeigv + epsabs
        return self._eps

    @property
    def eps(self):
        """
        The threshold below which eigenvalues are too small to be determined.
        """
        return self._eps

    @property
    @abc.abstractmethod
    def n(self):
        """ Number of rows/columns of the matrix """
        pass

    @property
    @abc.abstractmethod
    def m(self):
        """ Number of columns of Z """
        pass

    def ginv(self):
        """ Compute K⁻ """
        return self.ginv_quad(jnp.eye(self.n))

def solve_triangular_python(a, b, *, lower=False):
    """
    Pure python implementation of scipy.linalg.solve_triangular for when
    a or b are object arrays.
    """
    # TODO maybe commit this to gvar.linalg
    a = numpy.asarray(a)
    x = numpy.copy(b)

    vec = x.ndim < 2
    if vec:
        x = x[:, None]

    n = a.shape[-1]
    assert x.shape[-2] == n

    if not lower:
        a = a[..., ::-1, ::-1]
        x = x[..., ::-1, :]

    x[..., 0, :] /= a[..., 0, 0, None]
    for i in range(1, n):
        x[..., i:, :] -= x[..., None, i - 1, :] * a[..., i:, i - 1, None]
        x[..., i, :] /= a[..., i, i, None]

    if not lower:
        x = x[..., ::-1, :]

    if vec:
        x = numpy.squeeze(x, -1)
    return x
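
# Illustrative self-check (hypothetical helper, not part of the module): on
# ordinary float arrays the pure-python solver should satisfy a @ x = b,
# matching what scipy.linalg.solve_triangular would return.
def _check_solve_triangular_python():
    rng = numpy.random.default_rng(0)
    a = numpy.tril(rng.standard_normal((5, 5))) + 5 * numpy.eye(5)
    b = rng.standard_normal(5)
    x = solve_triangular_python(a, b, lower=True)
    assert numpy.allclose(a @ x, b)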

def solve_triangular_batched(a, b, *, lower=False):
    """ Version of jax.scipy.linalg.solve_triangular that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    batch_shape = jnp.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_shape = batch_shape + a.shape[-2:]
    b_shape = batch_shape + b.shape[-2:]
    result = lax.linalg.triangular_solve(
        jnp.broadcast_to(a, a_shape), jnp.broadcast_to(b, b_shape),
        left_side=True, lower=lower,
    )
    assert result.shape == b_shape

    if vec:
        result = result.squeeze(-1)
    return result
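
# Illustrative sketch (hypothetical helper): the batching broadcasts like
# matmul, e.g. a stack of triangular matrices against one right-hand side.
def _check_solve_triangular_batched():
    a = jnp.tril(jnp.arange(1.0, 19.0).reshape(2, 3, 3)) + 3 * jnp.eye(3)
    b = jnp.ones((3, 4))  # broadcast against the batch of 2 matrices
    x = solve_triangular_batched(a, b, lower=True)
    assert x.shape == (2, 3, 4)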

def solve_batched(a, b, **kw):
    """ Version of jax.scipy.linalg.solve that batches matmul-like """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    @functools.partial(jnp.vectorize, signature='(i,j),(j,k)->(i,k)')
    def solve_batched(a, b):
        return jlinalg.solve(a, b, **kw)
    result = solve_batched(a, b)

    if vec:
        result = result.squeeze(-1)
    return result
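
# Illustrative sketch (hypothetical helper): same matmul-like broadcasting,
# but for general square systems.
def _check_solve_batched():
    a = jnp.arange(2.0, 4.0)[:, None, None] * jnp.eye(3)  # batch of 2 scaled identities
    b = jnp.ones((2, 3, 1))
    x = solve_batched(a, b)
    assert x.shape == (2, 3, 1)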

def eigval_bound(K):
    """
    Upper bound on the largest magnitude eigenvalue of the matrix, from
    Gershgorin's theorem.
    """
    return jnp.max(jnp.sum(jnp.abs(K), axis=1))
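
# Illustrative sketch (hypothetical helper): by Gershgorin's theorem the
# row-sum bound dominates the true spectral radius.
def _check_eigval_bound():
    rng = numpy.random.default_rng(0)
    A = rng.standard_normal((6, 6))
    K = A @ A.T  # symmetric positive semidefinite test matrix
    assert eigval_bound(jnp.asarray(K)) >= numpy.max(numpy.abs(numpy.linalg.eigvalsh(K)))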

def diag_scale_pow2(K):
    """
    Compute a vector s of powers of 2 such that diag(K / outer(s, s)) ~ 1.
    """
    d = jnp.diag(K)
    return jnp.where(d, jnp.exp2(jnp.rint(0.5 * jnp.log2(d))), 1)
    # Golub and Van Loan (2013) say this is not a totally general heuristic
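
# Illustrative sketch (hypothetical helper): the power-of-2 rescaling is
# exact in floating point and brings the diagonal within a factor 2 of 1.
def _check_diag_scale_pow2():
    K = jnp.array([[4.0, 0.2], [0.2, 1e-4]])
    s = diag_scale_pow2(K)
    d = jnp.diag(K / s / s[:, None])
    assert jnp.all((d >= 0.5) & (d <= 2.0))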

def transpose(x):
    """ Swap the last two axes of array x; corresponds to matrix transposition
    with the broadcasting convention of matmul """
    if x.ndim < 2:
        return x
    elif isinstance(x, jnp.ndarray):
        return jnp.swapaxes(x, -2, -1)
    else:
        # need to support numpy because this function is used with gvars
        return numpy.swapaxes(x, -2, -1)

class Chol(Decomposition):
    """ Cholesky decomposition. The matrix is regularized by adding a small
    multiple of the identity. """

    def __init__(self, K, *, epsrel='auto', epsabs=0):
        # K <- K + Iε
        # K = LL'
        self._K = K
        s = diag_scale_pow2(K)
        K = K / s / s[:, None]
        eps = self._parseeps(K, epsrel, epsabs)
        K = K.at[jnp.diag_indices_from(K)].add(eps)
        L = jlinalg.cholesky(K, lower=True)
        with _jaxext.skipifabstract():
            if not jnp.all(jnp.isfinite(L)):
                # TODO check that jax fills with nan after failed row, detect
                # and report minor index like scipy
                raise numpy.linalg.LinAlgError('cholesky decomposition not '
                    'finite, probably matrix not pos def numerically')
        self._L = L * s[:, None]
        self._eps = eps * jnp.min(s * s)

    def matrix(self):
        return self._K

    def ginv_linear(self, X):
        # = K⁻¹X
        # K⁻¹ = L'⁻¹L⁻¹
        # K⁻¹X = L'⁻¹(L⁻¹X)
        invLX = jlinalg.solve_triangular(self._L, X, lower=True)
        return jlinalg.solve_triangular(self._L.T, invLX, lower=False)

    def pinv_bilinear(self, A, r):
        # = A'K⁻¹r = A'L'⁻¹L⁻¹r = (L⁻¹A)'(L⁻¹r)
        invLr = jlinalg.solve_triangular(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLr

    def pinv_bilinear_robj(self, A, r):
        # = A'K⁻¹r
        invLr = solve_triangular_python(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return numpy.asarray(invLA).T @ invLr

    def ginv_quad(self, A):
        # = A'K⁻¹A = A'L'⁻¹L⁻¹A = (L⁻¹A)'(L⁻¹A)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLA

    def ginv_diagquad(self, A):
        # = diag(A'K⁻¹A)
        # X = L⁻¹A
        # diag(A'K⁻¹A)_i = diag(X'X)_i = ∑_j X'_ij X_ji = ∑_j X_ji X_ji
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return jnp.einsum('ji,ji->i', invLA, invLA)

    def correlate(self, x):
        # = Lx
        return self._L @ x

    def back_correlate(self, X):
        # = L'X
        return self._L.T @ X

    def pinv_correlate(self, x):
        # = L⁻¹x
        return jlinalg.solve_triangular(self._L, x, lower=True)

    def minus_log_normal_density(self,
        r,                # 1d array, the residuals (data - prior mean)
        *,
        dr_vjp=None,      # callable, x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp=None,      # callable, x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp_vec=None,  # 1d array, ∂r_i/∂p_j v_j, fishvec
        dK_jvp_vec=None,  # 2d array, ∂K_ij/∂p_k v_k, fishvec
        dr=None,          # 2d array, ∂r_i/∂p_j, gradfwd and fisher
        dK=None,          # 3d array, ∂K_ij/∂p_k, gradfwd and fisher
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):

        L = self._L

        out = {}

        # compute shared factors
        grad = (
            (gradrev and (dK_vjp is not None or dr_vjp is not None))
            or (gradfwd and (dK is not None or dr is not None))
        )
        if value or grad:
            invLr = jlinalg.solve_triangular(L, r, lower=True)
        if grad:
            invKr = jlinalg.solve_triangular(L.T, invLr, lower=False)
        if (gradrev and dK_vjp is not None) or (gradfwd and dK is not None):
            invL = jlinalg.solve_triangular(L, jnp.eye(len(L)), lower=True)
            invK = invL.T @ invL

        if value:
            # = 1/2 n log 2π
            # + 1/2 log det K
            # + 1/2 r'K⁻¹r
            # K = LL'
            # K⁻¹ = L'⁻¹L⁻¹
            # det K = (det L)² =
            #       = (∏_i L_ii)²
            # r'K⁻¹r = r'L'⁻¹L⁻¹r =
            #        = (L⁻¹r)'(L⁻¹r)
            out['value'] = 1/2 * (
                len(L) * jnp.log(2 * jnp.pi) +
                2 * jnp.sum(jnp.log(jnp.diag(L))) +
                invLr @ invLr
            )
        else:
            out['value'] = None

        if gradrev:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK) = K⁻¹_ij dK_ji =
            #           = K⁻¹_ij dK_ij =
            #           = dK_vjp(K⁻¹)
            # r'K⁻¹dr = r_i K⁻¹_ij dr_j =
            #         = (K⁻¹r)_j dr_j =
            #         = dr_vjp(K⁻¹r)
            # r'K⁻¹dKK⁻¹r = r_i K⁻¹_ij dK_jl K⁻¹_lm r_m =
            #             = (K⁻¹r)_j dK_jl (K⁻¹r)_l =
            #             = dK_vjp((K⁻¹r) ⊗ (K⁻¹r))
            out['gradrev'] = 0
            if dK_vjp is not None:
                tr_invK_dK = dK_vjp(invK)
                r_invK_dK_invK_r = dK_vjp(jnp.outer(invKr, invKr))
                out['gradrev'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr_vjp is not None:
                r_invK_dr = dr_vjp(invKr)
                out['gradrev'] += r_invK_dr
        else:
            out['gradrev'] = None

        if gradfwd:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK)_k = K⁻¹_ij dK_ijk
            # r'K⁻¹dr = (K⁻¹r)'dr
            # (r'K⁻¹dKK⁻¹r)_k = r_i K⁻¹_ij dK_jlk K⁻¹_lm r_m =
            #                 = (K⁻¹r)_j dK_jlk (K⁻¹r)_l
            out['gradfwd'] = 0
            if dK is not None:
                tr_invK_dK = jnp.einsum('ij,ijk->k', invK, dK)
                r_invK_dK_invK_r = jnp.einsum('i,ijk,j->k', invKr, dK, invKr)
                out['gradfwd'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr is not None:
                r_invK_dr = invKr @ dr
                out['gradfwd'] += r_invK_dr
        else:
            out['gradfwd'] = None

        if fisher:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K)
            # + dr'K⁻¹d'r
            # tr(K⁻¹dKK⁻¹d'K)_ij = tr(L'⁻¹L⁻¹dKL'⁻¹L⁻¹d'K)_ij =
            #                    = tr(L⁻¹dKL'⁻¹L⁻¹d'KL'⁻¹)_ij =
            #                    = (L⁻¹dKL'⁻¹)_kli (L⁻¹dKL'⁻¹)_klj
            # (L⁻¹dKL'⁻¹)_ijk = L⁻¹_il dK_lmk L'⁻¹_mj =
            #                 = L⁻¹_il L⁻¹_jm dK_lmk
            # (dr'K⁻¹d'r)_kq = dr'_k L'⁻¹L⁻¹dr_q =
            #                = (L⁻¹dr_k)_i (L⁻¹dr_q)_i
            out['fisher'] = 0
            if dK is not None:
                invL_dK = solve_triangular_batched(L,
                    jnp.moveaxis(dK, 2, 0),
                    lower=True)  # kim: L⁻¹_il dK_lmk
                invL_dK_invL = solve_triangular_batched(L,
                    jnp.swapaxes(invL_dK, 1, 2),
                    lower=True)  # kji: L⁻¹_jm (L⁻¹_il dK_lmk)
                tr_invK_dK_invK_dK = jnp.einsum('kij,qij->kq', invL_dK_invL, invL_dK_invL)
                out['fisher'] += 1/2 * tr_invK_dK_invK_dK
            if dr is not None:
                invLdr = jlinalg.solve_triangular(L, dr, lower=True)
                dr_invK_dr = invLdr.T @ invLdr
                out['fisher'] += dr_invK_dr
        else:
            out['fisher'] = None

        if fishvec:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K) v
            # + dr'K⁻¹d'r v
            # tr(K⁻¹dKK⁻¹d'K) v = dK_vjp(K⁻¹ dK_jvp(v) K⁻¹) =
            #                   = dK_vjp(L'⁻¹L⁻¹ dK_jvp(v) L'⁻¹L⁻¹)
            # dr'K⁻¹d'r v = dr'K⁻¹dr_jvp(v) =
            #             = dr_vjp(K⁻¹dr_jvp(v)) =
            #             = dr_vjp(L'⁻¹L⁻¹ dr_jvp(v))
            out['fishvec'] = 0
            if not (dK_jvp_vec is None and dK_vjp is None):
                invL_dKv = jlinalg.solve_triangular(L, dK_jvp_vec, lower=True)
                invK_dKv = jlinalg.solve_triangular(L.T, invL_dKv, lower=False)
                invL_dKv_invK = jlinalg.solve_triangular(L, invK_dKv.T, lower=True)
                invK_dKv_invK = jlinalg.solve_triangular(L.T, invL_dKv_invK, lower=False)
                tr_invK_dK_invK_dK_v = dK_vjp(invK_dKv_invK)
                out['fishvec'] += 1/2 * tr_invK_dK_invK_dK_v
            if not (dr_jvp_vec is None and dr_vjp is None):
                invL_drv = jlinalg.solve_triangular(L, dr_jvp_vec, lower=True)
                invK_drv = jlinalg.solve_triangular(L.T, invL_drv, lower=False)
                dr_invK_drv_v = dr_vjp(invK_drv)
                out['fishvec'] += dr_invK_drv_v
        else:
            out['fishvec'] = None

        return tuple(out.values())  # TODO a namedtuple

    @classmethod
    def make_derivs(cls,
        K_fun, r_fun, primal,
        *,
        args=(),
        kw={},
        vec=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Prepare the arguments for `minus_log_normal_density`.

        Parameters
        ----------
        K_fun, r_fun : callable
            Functions with signature ``f(primal, *args, **kw)`` that produce
            the `K` init argument and the `r` argument of
            `minus_log_normal_density`, respectively.
        primal : 1d array
            The first argument to `K_fun` and `r_fun`.
        args : tuple
            Additional positional arguments to `K_fun` and `r_fun`.
        kw : dict
            Keyword arguments to `K_fun` and `r_fun`.
        vec : 1d array
            A tangent vector to compute the jacobian-vector products.
        value, gradrev, gradfwd, fisher, fishvec : bool
            Arguments to `minus_log_normal_density`, used to determine which
            derivatives are needed.

        Returns
        -------
        K : 2d array
            Output of `K_fun`.
        r : 1d array
            Output of `r_fun`.
        out : dict
            Dictionary with the derivative arguments to
            `minus_log_normal_density`.
        """

        partial = lambda f: lambda x: f(x, *args, **kw)
        K_fun = partial(K_fun)
        r_fun = partial(r_fun)

        out = {}

        if gradrev or fishvec:
            K, dK_vjp = jax.vjp(K_fun, primal)
            r, dr_vjp = jax.vjp(r_fun, primal)
            out['dK_vjp'] = lambda x: dK_vjp(x)[0]
            out['dr_vjp'] = lambda x: dr_vjp(x)[0]
        else:
            K = K_fun(primal)
            r = r_fun(primal)
        if fishvec:
            _, out['dK_jvp_vec'] = jax.jvp(K_fun, (primal,), (vec,))
            _, out['dr_jvp_vec'] = jax.jvp(r_fun, (primal,), (vec,))
        if gradfwd or fisher:
            out['dK'] = jax.jacfwd(K_fun)(primal)
            out['dr'] = jax.jacfwd(r_fun)(primal)

        return K, r, out

    @property
    def n(self):
        return len(self._L)

    m = n
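
# End-to-end illustrative sketch (hypothetical model functions `K_fun` and
# `r_fun`, not part of the module): pair `Chol.make_derivs` with
# `minus_log_normal_density` to get the likelihood value and its forward-mode
# gradient with respect to a kernel parameter.
def _demo_chol_density():
    def K_fun(p):  # squared-exponential kernel on a small grid
        x = jnp.arange(5.0)
        return jnp.exp(-(x[:, None] - x[None, :]) ** 2 / p[0] ** 2)
    def r_fun(p):  # residuals, here independent of the parameter
        return jnp.linspace(0.0, 1.0, 5)
    primal = jnp.array([1.5])
    K, r, derivs = Chol.make_derivs(K_fun, r_fun, primal,
        value=True, gradfwd=True)
    value, gradrev, gradfwd, fisher, fishvec = Chol(K).minus_log_normal_density(
        r, **derivs, value=True, gradfwd=True)
    return value, gradfwd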