# lsqfitgp/_linalg/_decomp.py
#
# Copyright (c) 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20"""
22Copy-pasted from the notes:
242023-02-14
25==========
27My current decomposition system is a mess. I can't take reverse gradients.
28I can't straightforwardly implement optimized algorithms that compute together
29likelihood, gradient, and fisher. Jax patterns break down unpredictably. I
30have to redesign it from scratch.
32Guidelines and requirements:
34 - Sufficient modularity to implement composite decompositions (Woodbury,
35 Block)
37 - Does not mess up jax in any way
39 - Caches decompositions
41 - Favors optimizing together the likelihood and its derivatives
43Operations (in the following I indicate with lowercase inputs which are
44typically vectors or tall matrices, and uppercase inputs which are typically
45large matrices, since optimization requires taking it into account):
47 pinv_bilinear(A, r) => A'K⁺r (for the posterior mean)
48 pinv_bilinear_robj(A, r) same but r can be gvars
49 ginv_quad(A) => A'K⁻A (for the posterior covariance)
50 I want the pseudoinverse for the mean because the data may not be
51 in the span and I want to project it orthogonally, while for the
52 covariance I expect A and K to come from a pos def matrix so they are
53 coherent
54 ginv_diagquad(A) => diag(A'K⁻A) (for the posterior variance)
55 minus_log_normal_density(
56 r: 1d array, # the residuals (data - prior mean)
57 dr_vjp: callable, # x -> x_i ∂r_i/∂p_j, gradrev and fishvec
58 dK_vjp: callable, # x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
59 dr_jvp: callable, # x -> ∂r_i/∂r_j x_j, fishvec
60 dK_jvp: callable, # x -> ∂K_ij/∂p_k x_k, fishvec
61 dr: 2d array, # ∂r_i/∂p_j, gradfwd and fisher
62 dK: 3d array, # ∂K_ij/∂p_k, gradfwd and fisher
63 vec: 1d array, # input vector of fishvec, same size as params
64 value: bool,
65 gradrev: bool,
66 gradfwd: bool,
67 fisher: bool,
68 fishvec: bool,
69 )
70 This computes on request
71 value: 1/2 tr(KK⁺) log 2π
72 + 1/2 tr(I-KK⁺) log 2π
73 + 1/2 log pdet K
74 + 1/2 tr(I-KK⁺) log ε
75 + 1/2 r'(K⁺+(I-KK⁺)/ε)r
76 gradrev,
77 gradfwd: 1/2 tr(K⁺dK)
78 + r'(K⁺+(I-KK⁺)/ε) dr
79 - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
80 fisher: 1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
81 - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
82 + dr'(K⁺+(I-KK⁺)/ε)d'r
83 fishvec: fisher matrix times vec
84 There should be options for omitting the pieces with ε. I also need a
85 way to make densities with different values of ε comparable with each
86 other (may not be possible, if it is, it probably requires a history of
87 ranks and ε). gradfwd/rev form K⁺ explicitly to compute tr(K⁺dK) for
88 efficiency.
89 correlate(x)
90 Zx where K = ZZ'.
91 back_correlate(X):
92 Z'X, this is used by Sandwich and Woodbury.
94Since I also want to compute the Student density, I could split
95minus_log_normal_density's return value into logdet and quad. The gradient
96splits nicely between the two terms, but I have to redo the calculation of the
97Fisher matrix for the Student distribution. Alternatively, I could use the the
98Normal Fisher. => See Lange et al. (1989, app. B). => I think I can split the
99gradient and Fisher matrix too.
1012023-03-07
102==========
104To compute a Fisher-vector product when there are many parameters, do
106 tr(K+ dK K+ dK) v =
107 = K_vjp(K+ K_jvp(v) K+)
109"""
# TODO to automate this further, I could take in a function that generates K
# (or its pieces) and the arguments to the function. But how would this play
# together with passing decomposition objects as pieces?

# TODO split this file by class

# TODO Consider using lineax for implementing non-materialized decomps
import abc
import functools

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import linalg as jlinalg
from jax import lax

from .. import _jaxext
from . import _pytree
class Decomposition(_pytree.AutoPyTree, abc.ABC):
    """
    Abstract base class for decompositions of positive semidefinite matrices.
    """

    @abc.abstractmethod
    def __init__(self, *args, **kw):
        """ Decompose the input matrix """
        pass

    @abc.abstractmethod
    def matrix(self):
        """ The input matrix """
        pass

    @abc.abstractmethod
    def ginv_linear(self, X):
        """ Compute K⁻X """
        pass

    @abc.abstractmethod
    def pinv_bilinear(self, A, r):
        """Compute A'K⁺r."""
        pass

    @abc.abstractmethod
    def pinv_bilinear_robj(self, A, r):
        """Compute A'K⁺r, where r can be an array of objects."""
        pass

    @abc.abstractmethod
    def ginv_quad(self, A):
        """Compute A'K⁻A."""
        pass

    @abc.abstractmethod
    def ginv_diagquad(self, A):
        """Compute diag(A'K⁻A)."""
        pass

    @abc.abstractmethod
    def correlate(self, x):
        """ Compute Zx where K = ZZ' """
        pass

    @abc.abstractmethod
    def back_correlate(self, X):
        """ Compute Z'X """
        pass

    @abc.abstractmethod
    def pinv_correlate(self, x):
        """ Compute Z⁺x """
        pass

    @abc.abstractmethod
    def minus_log_normal_density(self,
        r,
        *,
        dr_vjp=None,
        dK_vjp=None,
        dr_jvp_vec=None,
        dK_jvp_vec=None,
        dr=None,
        dK=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Compute minus the log of a Normal density, and its derivatives, with
        covariance matrix K.

        If an input derivative is not specified, it is assumed to be zero.

        Parameters
        ----------
        r : 1d array
            The residuals (value - mean).
        dr_vjp : callable
            x -> x_i ∂r_i/∂p_j, for gradrev and fishvec.
        dK_vjp : callable
            x -> x_ij ∂K_ij/∂p_k, for gradrev and fishvec.
        dr_jvp_vec : 1d array
            ∂r_i/∂p_j vec_j, for fishvec.
        dK_jvp_vec : 2d array
            ∂K_ij/∂p_k vec_k, for fishvec.
        dr : 2d array
            ∂r_i/∂p_j, for gradfwd and fisher.
        dK : 3d array
            ∂K_ij/∂p_k, for gradfwd and fisher.
        value, gradrev, gradfwd, fisher, fishvec : bool
            These parameters indicate which of the return values to compute.
            Default all False.

        Returns
        -------
        value:   1/2 tr(KK⁺) log 2π
               + 1/2 tr(I-KK⁺) log 2π
               + 1/2 log pdet K
               + 1/2 tr(I-KK⁺) log ε
               + 1/2 r'(K⁺+(I-KK⁺)/ε)r
        gradrev,
        gradfwd: 1/2 tr(K⁺dK)
               + r'(K⁺+(I-KK⁺)/ε) dr
               - 1/2 r'(K⁺+2(I-KK⁺)/ε)dKK⁺r
        fisher:  1/2 tr(K⁺dK(K⁺+2(I-KK⁺)/ε)d'K)
               - 2 tr(K⁺dK(I-KK⁺)d'KK⁺)
               + dr'(K⁺+(I-KK⁺)/ε)d'r
        fishvec: fisher matrix @ vec
        """
        pass

    def _parseeps(self, K, epsrel, epsabs, maxeigv=None):
        """ Determine eps from input arguments """
        machine_eps = jnp.finfo(_jaxext.float_type(K)).eps
        if epsrel == 'auto':
            epsrel = len(K) * machine_eps
        if epsabs == 'auto':
            epsabs = machine_eps
        if maxeigv is None:
            maxeigv = eigval_bound(K)
        self._eps = epsrel * maxeigv + epsabs
        return self._eps

    @property
    def eps(self):
        """
        The threshold below which eigenvalues are too small to be determined.
        """
        return self._eps

    @property
    @abc.abstractmethod
    def n(self):
        """ Number of rows/columns of the matrix """
        pass

    @property
    @abc.abstractmethod
    def m(self):
        """ Number of columns of Z """
        pass

    def ginv(self):
        """ Compute K⁻ """
        return self.ginv_quad(jnp.eye(self.n))
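# Illustrative sketch of the intended calling pattern for a concrete
# decomposition (the names `K`, `A`, `r`, `prior_cov` are hypothetical):
#
#     dec = Chol(K)                       # decompose the prior covariance
#     mean = dec.pinv_bilinear(A, r)      # posterior mean A'K⁺r
#     cov = prior_cov - dec.ginv_quad(A)  # posterior covariance
#     var = jnp.diag(prior_cov) - dec.ginv_diagquad(A)  # posterior variance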
def solve_triangular_python(a, b, *, lower=False):
    """
    Pure python implementation of scipy.linalg.solve_triangular for when
    a or b are object arrays.
    """
    # TODO maybe commit this to gvar.linalg
    a = numpy.asarray(a)
    x = numpy.copy(b)

    vec = x.ndim < 2
    if vec:
        x = x[:, None]

    n = a.shape[-1]
    assert x.shape[-2] == n

    if not lower:
        a = a[..., ::-1, ::-1]
        x = x[..., ::-1, :]

    x[..., 0, :] /= a[..., 0, 0, None]
    for i in range(1, n):
        x[..., i:, :] -= x[..., None, i - 1, :] * a[..., i:, i - 1, None]
        x[..., i, :] /= a[..., i, i, None]

    if not lower:
        x = x[..., ::-1, :]

    if vec:
        x = numpy.squeeze(x, -1)
    return x
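# A minimal forward-substitution check (hypothetical numbers):
#
#     L = numpy.array([[2., 0.], [1., 3.]])
#     b = numpy.array([2., 7.])
#     solve_triangular_python(L, b, lower=True)  # -> [1., 2.], since L @ [1, 2] == b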
def solve_triangular_batched(a, b, *, lower=False):
    """ Version of jax.scipy.linalg.solve_triangular that batches like matmul """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    batch_shape = jnp.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a_shape = batch_shape + a.shape[-2:]
    b_shape = batch_shape + b.shape[-2:]
    result = lax.linalg.triangular_solve(
        jnp.broadcast_to(a, a_shape), jnp.broadcast_to(b, b_shape),
        left_side=True, lower=lower,
    )
    assert result.shape == b_shape

    if vec:
        result = result.squeeze(-1)
    return result
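# The batching follows the matmul broadcasting convention (hypothetical
# shapes): leading axes of `a` and `b` are broadcast together, the last two
# are the matrix dimensions.
#
#     a = jnp.broadcast_to(jnp.eye(3), (5, 3, 3))  # batch of 5 triangular matrices
#     b = jnp.ones((5, 3, 2))                      # batch of 5 right-hand sides
#     solve_triangular_batched(a, b, lower=True).shape  # (5, 3, 2)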
def solve_batched(a, b, **kw):
    """ Version of jax.scipy.linalg.solve that batches like matmul """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]

    @functools.partial(jnp.vectorize, signature='(i,j),(j,k)->(i,k)')
    def solve_batched(a, b):
        return jlinalg.solve(a, b, **kw)
    result = solve_batched(a, b)

    if vec:
        result = result.squeeze(-1)
    return result
def eigval_bound(K):
    """
    Upper bound on the largest magnitude eigenvalue of the matrix, from
    Gershgorin's theorem.
    """
    return jnp.max(jnp.sum(jnp.abs(K), axis=1))
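# Worked instance of the bound: every Gershgorin disc is centered on K_ii with
# radius ∑_{j≠i} |K_ij|, so max_i ∑_j |K_ij| bounds the spectral radius; for
# K = [[2, 1], [1, 3]] it gives 4, while the largest eigenvalue is (5+√5)/2 ≈ 3.62.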
def diag_scale_pow2(K):
    """
    Compute a vector s of powers of 2 such that diag(K / outer(s, s)) ~ 1.
    """
    d = jnp.diag(K)
    return jnp.where(d, jnp.exp2(jnp.rint(0.5 * jnp.log2(d))), 1)
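# Example: diag(K) = [4, 1000] gives s = [2, 32] (since rint(log2(1000)/2) = 5),
# and diag(K / outer(s, s)) = [1, ~0.98]. Powers of 2 only touch the floating
# point exponent, so the rescaling itself introduces no rounding error.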
def transpose(x):
    """ swap the last two axes of array x, corresponds to matrix transposition
    with the broadcasting convention of matmul """
    if x.ndim < 2:
        return x
    elif isinstance(x, jnp.ndarray):
        return jnp.swapaxes(x, -2, -1)
    else:
        # need to support numpy because this function is used with gvars
        return numpy.swapaxes(x, -2, -1)
class Chol(Decomposition):
    """Cholesky decomposition. The matrix is regularized by adding a small
    multiple of the identity."""

    def __init__(self, K, *, epsrel='auto', epsabs=0):
        # K <- K + Iε
        # K = LL'
        self._K = K
        s = diag_scale_pow2(K)
        K = K / s / s[:, None]
        eps = self._parseeps(K, epsrel, epsabs)
        K = K.at[jnp.diag_indices_from(K)].add(eps)
        L = jlinalg.cholesky(K, lower=True)
        with _jaxext.skipifabstract():
            if not jnp.all(jnp.isfinite(L)):
                # TODO check that jax fills with nan after failed row, detect
                # and report minor index like scipy
                raise numpy.linalg.LinAlgError('cholesky decomposition not '
                    'finite, probably matrix not pos def numerically')
        self._L = L * s[:, None]
        self._eps = eps * jnp.min(s * s)
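        # Note on the rescaling above: with S = diag(s), the factorization is
        # K/ss' + εI = L̃L̃', so L = SL̃ gives LL' = K + εS², i.e. the effective
        # jitter on K is at least ε·min(s²), which is what self._eps reports.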
    def matrix(self):
        return self._K

    def ginv_linear(self, X):
        # = K⁻¹X
        # K⁻¹ = L'⁻¹L⁻¹
        # K⁻¹X = L'⁻¹(L⁻¹X)
        invLX = jlinalg.solve_triangular(self._L, X, lower=True)
        return jlinalg.solve_triangular(self._L.T, invLX, lower=False)

    def pinv_bilinear(self, A, r):
        # = A'K⁻¹r = A'L'⁻¹L⁻¹r = (L⁻¹A)'(L⁻¹r)
        invLr = jlinalg.solve_triangular(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLr

    def pinv_bilinear_robj(self, A, r):
        # = A'K⁻¹r
        invLr = solve_triangular_python(self._L, r, lower=True)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return numpy.asarray(invLA).T @ invLr

    def ginv_quad(self, A):
        # = A'K⁻¹A = A'L'⁻¹L⁻¹A = (L⁻¹A)'(L⁻¹A)
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return invLA.T @ invLA

    def ginv_diagquad(self, A):
        # = diag(A'K⁻¹A)
        # X = L⁻¹A
        # diag(A'K⁻¹A)_i = diag(X'X)_i = ∑_j X'_ij X_ji = ∑_j X_ji X_ji
        invLA = jlinalg.solve_triangular(self._L, A, lower=True)
        return jnp.einsum('ji,ji->i', invLA, invLA)

    def correlate(self, x):
        # = Lx
        return self._L @ x

    def back_correlate(self, X):
        # = L'X
        return self._L.T @ X

    def pinv_correlate(self, x):
        # = L⁻¹x
        return jlinalg.solve_triangular(self._L, x, lower=True)
    def minus_log_normal_density(self,
        r,                # 1d array, the residuals (data - prior mean)
        *,
        dr_vjp=None,      # callable, x -> x_i ∂r_i/∂p_j, gradrev and fishvec
        dK_vjp=None,      # callable, x -> x_ij ∂K_ij/∂p_k, gradrev and fishvec
        dr_jvp_vec=None,  # 1d array, ∂r_i/∂p_j v_j, fishvec
        dK_jvp_vec=None,  # 2d array, ∂K_ij/∂p_k v_k, fishvec
        dr=None,          # 2d array, ∂r_i/∂p_j, gradfwd and fisher
        dK=None,          # 3d array, ∂K_ij/∂p_k, gradfwd and fisher
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):

        L = self._L

        out = {}

        # compute shared factors
        grad = (
            (gradrev and (dK_vjp is not None or dr_vjp is not None))
            or (gradfwd and (dK is not None or dr is not None))
        )
        if value or grad:
            invLr = jlinalg.solve_triangular(L, r, lower=True)
        if grad:
            invKr = jlinalg.solve_triangular(L.T, invLr, lower=False)
        if (gradrev and dK_vjp is not None) or (gradfwd and dK is not None):
            invL = jlinalg.solve_triangular(L, jnp.eye(len(L)), lower=True)
            invK = invL.T @ invL

        if value:
            # = 1/2 n log 2π
            # + 1/2 log det K
            # + 1/2 r'K⁻¹r
            # K = LL'
            # K⁻¹ = L'⁻¹L⁻¹
            # det K = (det L)² =
            #       = (∏_i L_ii)²
            # r'K⁻¹r = r'L'⁻¹L⁻¹r =
            #        = (L⁻¹r)'(L⁻¹r)
            out['value'] = 1/2 * (
                len(L) * jnp.log(2 * jnp.pi) +
                2 * jnp.sum(jnp.log(jnp.diag(L))) +
                invLr @ invLr
            )
        else:
            out['value'] = None

        if gradrev:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK) = K⁻¹_ij dK_ji =
            #           = K⁻¹_ij dK_ij =
            #           = dK_vjp(K⁻¹)
            # r'K⁻¹dr = r_i K⁻¹_ij dr_j =
            #         = (K⁻¹r)_j dr_j =
            #         = dr_vjp(K⁻¹r)
            # r'K⁻¹dKK⁻¹r = r_i K⁻¹_ij dK_jl K⁻¹_lm r_m =
            #             = (K⁻¹r)_j dK_jl (K⁻¹r)_l =
            #             = dK_vjp((K⁻¹r) ⊗ (K⁻¹r))
            out['gradrev'] = 0
            if dK_vjp is not None:
                tr_invK_dK = dK_vjp(invK)
                r_invK_dK_invK_r = dK_vjp(jnp.outer(invKr, invKr))
                out['gradrev'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr_vjp is not None:
                r_invK_dr = dr_vjp(invKr)
                out['gradrev'] += r_invK_dr
        else:
            out['gradrev'] = None

        if gradfwd:
            # = 1/2 tr(K⁻¹dK)
            # + r'K⁻¹dr
            # - 1/2 r'K⁻¹dKK⁻¹r
            # tr(K⁻¹dK)_k = K⁻¹_ij dK_ijk
            # r'K⁻¹dr = (K⁻¹r)'dr
            # (r'K⁻¹dKK⁻¹r)_k = r_i K⁻¹_ij dK_jlk K⁻¹_lm r_m =
            #                 = (K⁻¹r)_j dK_jlk (K⁻¹r)_l
            out['gradfwd'] = 0
            if dK is not None:
                tr_invK_dK = jnp.einsum('ij,ijk->k', invK, dK)
                r_invK_dK_invK_r = jnp.einsum('i,ijk,j->k', invKr, dK, invKr)
                out['gradfwd'] += 1/2 * (tr_invK_dK - r_invK_dK_invK_r)
            if dr is not None:
                r_invK_dr = invKr @ dr
                out['gradfwd'] += r_invK_dr
        else:
            out['gradfwd'] = None

        if fisher:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K)
            # + dr'K⁻¹d'r
            # tr(K⁻¹dKK⁻¹d'K)_ij = tr(L'⁻¹L⁻¹dKL'⁻¹L⁻¹d'K)_ij =
            #                    = tr(L⁻¹dKL'⁻¹L⁻¹d'KL'⁻¹)_ij =
            #                    = (L⁻¹dKL'⁻¹)_kli (L⁻¹dKL'⁻¹)_klj
            # (L⁻¹dKL'⁻¹)_ijk = L⁻¹_il dK_lmk L'⁻¹_mj =
            #                 = L⁻¹_il L⁻¹_jm dK_lmk
            # (dr'K⁻¹d'r)_kq = dr'_k L'⁻¹L⁻¹dr_q =
            #                = (L⁻¹dr_k)_i (L⁻¹dr_q)_i
            out['fisher'] = 0
            if dK is not None:
                invL_dK = solve_triangular_batched(L,
                    jnp.moveaxis(dK, 2, 0),
                    lower=True)  # kim: L⁻¹_il dK_lmk
                invL_dK_invL = solve_triangular_batched(L,
                    jnp.swapaxes(invL_dK, 1, 2),
                    lower=True)  # kji: L⁻¹_jm (L⁻¹_il dK_lmk)
                tr_invK_dK_invK_dK = jnp.einsum('kij,qij->kq', invL_dK_invL, invL_dK_invL)
                out['fisher'] += 1/2 * tr_invK_dK_invK_dK
            if dr is not None:
                invLdr = jlinalg.solve_triangular(L, dr, lower=True)
                dr_invK_dr = invLdr.T @ invLdr
                out['fisher'] += dr_invK_dr
        else:
            out['fisher'] = None

        if fishvec:
            # = 1/2 tr(K⁻¹dKK⁻¹d'K) v
            # + dr'K⁻¹d'r v
            # tr(K⁻¹dKK⁻¹d'K) v = K_vjp(K⁻¹K_jvp(v)K⁻¹) =
            #                   = K_vjp(L'⁻¹L⁻¹ K_jvp(v) L'⁻¹L⁻¹)
            # dr'K⁻¹d'r v = dr'K⁻¹dr_jvp(v) =
            #             = dr_vjp(K⁻¹dr_jvp(v)) =
            #             = dr_vjp(L'⁻¹L⁻¹ dr_jvp(v))
            out['fishvec'] = 0
            if not (dK_jvp_vec is None and dK_vjp is None):
                invL_dKv = jlinalg.solve_triangular(L, dK_jvp_vec, lower=True)
                invK_dKv = jlinalg.solve_triangular(L.T, invL_dKv, lower=False)
                invL_dKv_invK = jlinalg.solve_triangular(L, invK_dKv.T, lower=True)
                invK_dKv_invK = jlinalg.solve_triangular(L.T, invL_dKv_invK, lower=False)
                tr_invK_dK_invK_dK_v = dK_vjp(invK_dKv_invK)
                out['fishvec'] += 1/2 * tr_invK_dK_invK_dK_v
            if not (dr_jvp_vec is None and dr_vjp is None):
                invL_drv = jlinalg.solve_triangular(L, dr_jvp_vec, lower=True)
                invK_drv = jlinalg.solve_triangular(L.T, invL_drv, lower=False)
                dr_invK_drv_v = dr_vjp(invK_drv)
                out['fishvec'] += dr_invK_drv_v
        else:
            out['fishvec'] = None

        return tuple(out.values())
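    # Consistency check against a reference density (illustrative sketch, not
    # part of the class): for a well-conditioned K, `value` should match the
    # standard Normal log-density up to the regularization ε.
    #
    #     dec = Chol(K)
    #     value = dec.minus_log_normal_density(r, value=True)[0]
    #     from jax.scipy.stats import multivariate_normal
    #     ref = -multivariate_normal.logpdf(r, jnp.zeros(r.size), K)
    #     # jnp.allclose(value, ref) should hold approximately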
    @classmethod
    def make_derivs(cls,
        K_fun, r_fun, primal,
        *,
        args=(),
        kw={},
        vec=None,
        value=False,
        gradrev=False,
        gradfwd=False,
        fisher=False,
        fishvec=False,
    ):
        """
        Prepares arguments for `minus_log_normal_density`.

        Parameters
        ----------
        K_fun, r_fun : callable
            Functions with signature ``f(primal, *args, **kw)`` that produce
            the `K` init argument and the `r` `minus_log_normal_density`
            argument.
        primal : 1d array
            The first argument to `K_fun` and `r_fun`.
        args : tuple
            Additional positional arguments to `K_fun` and `r_fun`.
        kw : dict
            Keyword arguments to `K_fun` and `r_fun`.
        vec : 1d array
            A tangent vector to compute the jacobian-vector products.
        value, gradrev, gradfwd, fisher, fishvec : bool
            Arguments to `minus_log_normal_density`, used to determine which
            derivatives are needed.

        Returns
        -------
        K : 2d array
            Output of `K_fun`.
        r : 1d array
            Output of `r_fun`.
        out : dict
            Dictionary with derivative arguments to `minus_log_normal_density`.
        """

        partial = lambda f: lambda x: f(x, *args, **kw)
        K_fun = partial(K_fun)
        r_fun = partial(r_fun)

        out = {}

        if gradrev or fishvec:
            K, dK_vjp = jax.vjp(K_fun, primal)
            r, dr_vjp = jax.vjp(r_fun, primal)
            out['dK_vjp'] = lambda x: dK_vjp(x)[0]
            out['dr_vjp'] = lambda x: dr_vjp(x)[0]
        else:
            K = K_fun(primal)
            r = r_fun(primal)
        if fishvec:
            _, out['dK_jvp_vec'] = jax.jvp(K_fun, (primal,), (vec,))
            _, out['dr_jvp_vec'] = jax.jvp(r_fun, (primal,), (vec,))
        if gradfwd or fisher:
            out['dK'] = jax.jacfwd(K_fun)(primal)
            out['dr'] = jax.jacfwd(r_fun)(primal)

        return K, r, out
    @property
    def n(self):
        return len(self._L)

    m = n
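# End-to-end usage sketch (illustrative; `make_K`, `make_r`, and `params` are
# hypothetical user-defined objects): build the derivative arguments with
# `make_derivs`, then evaluate the density and its forward-mode gradient.
#
#     def make_K(p):
#         return jnp.eye(3) * (1 + p[0] ** 2)  # some parametric covariance
#
#     def make_r(p):
#         return jnp.arange(3.) - p[1]         # residuals data - mean(p)
#
#     params = jnp.array([0.5, 1.0])
#     K, r, derivs = Chol.make_derivs(make_K, make_r, params,
#                                     value=True, gradfwd=True)
#     dec = Chol(K)
#     value, _, gradfwd, _, _ = dec.minus_log_normal_density(
#         r, **derivs, value=True, gradfwd=True)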