Coverage for src/lsqfitgp/_linalg/_toeplitz.py: 96%
136 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_linalg/_toeplitz.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import jax 1efabcd
21from jax import numpy as jnp 1efabcd
22import numpy 1efabcd
24from . import _seqalg 1efabcd
class SymSchur(_seqalg.Producer):
    """
    Cholesky decomposition of a symmetric Toeplitz matrix

    Adapted from TOEPLITZ_CHOLESKY by John Burkardt (LGPL license)
    http://people.sc.fsu.edu/~jburkardt/py_src/toeplitz_cholesky/toeplitz_cholesky.html

    Producer for `_seqalg.sequential_algorithm` that emits, at step ``i``,
    the ``i``-th column of the lower Cholesky factor ``L`` (see `iter_out`).
    The state is a 2-row generator ``g`` built from the normalized first row
    and updated at each step with a hyperbolic rotation (Schur algorithm).

    Positive definiteness is not checked: the checks are commented out
    (presumably because they can not run on traced values under jit). A
    non positive definite input makes the ``sqrt`` in `iter` act on a
    negative number, which in jax gives nan instead of raising.
    """

    def __init__(self, t):
        """t = first row of the matrix"""
        t = jnp.asarray(t)
        assert t.ndim == 1
        # assert t[0] > 0, '1-th leading minor is not positive definite'
        self.t = t

    # presumably: indices of the other producers whose outputs this one
    # reads (none) -- TODO confirm against _seqalg.Producer
    inputs = ()

    def init(self, n):
        """Build the normalized generator for an n x n matrix."""
        t = self.t
        # drop the input so that it is not part of the state anymore
        # (presumably the remaining attributes are the loop carry)
        del self.t
        assert len(t) == n
        norm = t[0]
        # normalize so that the diagonal of the matrix is 1; the scale is
        # reapplied on each output column through snorm
        t = t / norm
        self.g = jnp.stack([t, t])
        self.snorm = jnp.sqrt(norm)

    def iter_out(self, i):
        """i-th column of Cholesky factor L"""
        # row 0 of the generator holds the current column of the
        # normalized factor; rescale it by sqrt(t[0])
        return self.g[0, :] * self.snorm

    def iter(self, i):
        """Advance the generator to the state of column i."""
        g = self.g
        # shift the first generator row down by one position
        g = g.at[0, :].set(jnp.roll(g[0, :], 1))
        # force exact zeros in column 0 (which received the element wrapped
        # around by roll) and in column i - 1 (the previous pivot)
        g = g.at[:, 0].set(0).at[:, i - 1].set(0)
        # assert g[0, i] > 0, 'what??'
        # coefficient that makes the rotation cancel g[1, i]
        rho = -g[1, i] / g[0, i]
        # assert abs(rho) < 1, f'{i+1}-th leading minor is not positive definite'
        # gamma = sqrt(1 - rho^2), written in factored form
        gamma = jnp.sqrt((1 - rho) * (1 + rho))
        # hyperbolic rotation mixing the two generator rows
        self.g = (g + g[::-1] * rho) / gamma
    # TODO Schur, Levinson, and in general algorithms with triangular matrices
    # are not efficient as implemented here: because of the fixed-size
    # constraint under jit, they can not take advantage of the triangularity,
    # so the number of operations doubles. This could be solved with a blocked
    # version that decreases the block size a few times: n blocks bring the
    # overhead factor down to 1 + 1/n. The compilation time and code size would
    # be proportional to n. Aligning the block sizes to powers of 2 would help
    # caching the compilation, bringing the compilation time back to normal
    # after warmup (does jax compilation reuse functions or inline them?)
    #
    # Anyway, this factor of 2 is not currently relevant.
class SymLevinson(_seqalg.Producer):
    """
    Cholesky decomposition of the *inverse* of a symmetric Toeplitz matrix

    Adapted from SuperGauss/DurbinLevinson.h (GPL license)
    https://cran.r-project.org/package=SuperGauss

    Producer for `_seqalg.sequential_algorithm` that emits, at step ``i``,
    the ``i``-th row of ``L^-1``, where ``L`` is the lower Cholesky factor
    of the matrix (see `iter_out`). The rows are built with the
    Durbin-Levinson recursion.

    Positive definiteness is not checked (the checks are commented out).

    Note: Schur should be more accurate than Levinson
    """

    def __init__(self, t):
        """t = first row of the matrix"""
        t = jnp.asarray(t, float)
        assert t.ndim == 1
        # assert t[0] > 0, '1-th leading minor is not positive definite'
        self.t = t

    # presumably: indices of the other producers whose outputs this one
    # reads (none) -- TODO confirm against _seqalg.Producer
    inputs = ()

    def init(self, n):
        """Allocate the recursion state for an n x n matrix."""
        self.phi1 = jnp.zeros(n)
        self.phi2 = jnp.zeros(n)
        # 1 / sqrt(nu) is the scale of the current output row (see iter_out)
        self.nu = self.t[0]
        # tlag[k] = t[k + 1]; the last entry wraps around to t[0]
        self.tlag = jnp.roll(self.t, -1)
        del self.t

    def iter_out(self, i):
        """i-th row of L^-1"""
        # the row is -phi2 with entry i replaced by 1, divided by sqrt(nu)
        return -self.phi2.at[i].set(-1) / jnp.sqrt(self.nu)

    def iter(self, i):
        """Update the Durbin-Levinson state for step i."""
        phi1 = self.phi1
        phi2 = self.phi2
        nu = self.nu
        tlag = self.tlag

        pi = i - 1  # position of the new coefficient
        rp = phi2 @ tlag
        # new coefficient (presumably the reflection coefficient of the
        # recursion), then update of the previous coefficients
        phi1 = phi1.at[pi].set((tlag[pi] - rp) / nu)
        phi1 = phi1 - phi1[pi] * phi2
        # assert abs(phi1[pi]) < 1, f'{i+1}-th leading minor is not positive definite'
        # nu *= 1 - phi1[pi]^2, written in factored form
        nu = nu * (1 - phi1[pi]) * (1 + phi1[pi])
        # coefficients reversed and rotated by i positions, as used by iter_out
        phi2 = jnp.roll(phi1[::-1], i)

        self.phi1 = phi1
        self.phi2 = phi2
        self.nu = nu
@jax.jit
def chol(t):
    """
    Lower Cholesky factor of the symmetric Toeplitz matrix with first row `t`.

    `SymSchur` yields the columns of the factor one at a time. They are
    collected by ``_seqalg.Stack(0)``, presumably one per row, which is why
    the result is transposed.
    """
    schur = SymSchur(t)
    stacker = _seqalg.Stack(0)
    _, stacked_columns = _seqalg.sequential_algorithm(len(t), [schur, stacker])
    return stacked_columns.T
@jax.jit
def chol_solve(t, *bs):
    """
    Solve a triangular system with the lower Cholesky factor ``L`` of the
    symmetric Toeplitz matrix with first row `t`, for each right hand side
    in `bs`.

    Returns the single solution when exactly one `b` is given, otherwise the
    sequence of solutions in the same order as `bs`.
    """
    schur = SymSchur(t)
    # each solver reads the columns of L produced by op 0 (the Schur producer)
    solvers = [_seqalg.SolveTriLowerColByFull(0, rhs) for rhs in bs]
    results = _seqalg.sequential_algorithm(len(t), [schur, *solvers])
    if len(bs) == 1:
        return results[1]
    return results[1:]
@jax.jit
def chol_matmul(t, b):
    """
    Multiply `b` on the left by the lower Cholesky factor ``L`` of the
    symmetric Toeplitz matrix with first row `t`.
    """
    producers = [
        SymSchur(t),                    # op 0: columns of L
        _seqalg.Rows(b),                # op 1: presumably the rows of b
        _seqalg.MatMulColByRow(0, 1),   # combines the outputs of ops 0 and 1
    ]
    _, _, product = _seqalg.sequential_algorithm(len(t), producers)
    return product
@jax.jit
def chol_transp_matmul(t, b):
    """
    Multiply `b` on the left by the transpose of the lower Cholesky factor
    ``L`` of the symmetric Toeplitz matrix with first row `t`.
    """
    producers = [
        SymSchur(t),                        # op 0: columns of L = rows of L^T
        _seqalg.MatMulRowByFull(0, b),      # op 1: output of op 0 times b
        _seqalg.Stack(1),                   # collects the outputs of op 1
    ]
    _, _, stacked = _seqalg.sequential_algorithm(len(t), producers)
    return stacked
@jax.jit
def logdet(t):
    """
    Log-determinant of the symmetric Toeplitz matrix with first row `t`.

    It is computed as twice the sum of the logarithms of the diagonal of the
    Cholesky factor produced by `SymSchur`.
    """
    schur = SymSchur(t)
    summer = _seqalg.SumLogDiag(0)
    _, half_logdet = _seqalg.sequential_algorithm(len(t), [schur, summer])
    return 2 * half_logdet
@jax.jit
def solve(t, b):
    """
    Solve the linear system with the symmetric Toeplitz matrix with first
    row `t` and right hand side `b`.

    `SymLevinson` yields the rows of ``L^-1``. Each row is first applied to
    `b`, and the results are then combined again with the rows of ``L^-1``,
    i.e. ``L^-T (L^-1 b)``.
    """
    producers = [
        SymLevinson(t),                     # op 0: rows of L^-1
        _seqalg.MatMulRowByFull(0, b),      # op 1: output of op 0 times b
        _seqalg.MatMulColByRow(0, 1),       # combines the outputs of ops 0 and 1
    ]
    _, _, solution = _seqalg.sequential_algorithm(len(t), producers)
    return solution
@jax.jit
def chol_transp_solve(t, b):
    """
    Solve a linear system with the transpose of the lower Cholesky factor
    ``L`` of the symmetric Toeplitz matrix with first row `t`, i.e. compute
    ``L^-T b`` from the rows of ``L^-1`` produced by `SymLevinson`.
    """
    producers = [
        SymLevinson(t),                 # op 0: rows of L^-1
        _seqalg.Rows(b),                # op 1: presumably the rows of b
        _seqalg.MatMulColByRow(0, 1),   # combines the outputs of ops 0 and 1
    ]
    _, _, solution = _seqalg.sequential_algorithm(len(t), producers)
    return solution
def chol_solve_numpy(t, b, diageps=None):
    """
    Solve a linear system for the Cholesky factor of a symmetric Toeplitz
    matrix, using the Schur algorithm in plain numpy. Equivalent to::

        t[0] += diageps
        m = toeplitz(t)
        l = chol(m)
        return solve(l, b)

    Numpy object arrays are supported. Broadcasts like matmul.

    Parameters
    ----------
    t : (..., n) array
        The first row or column of the matrix.
    b : (..., n, m) or (n,) array
        The right hand side of the linear system.
    diageps : scalar, optional
        Term added to the diagonal elements of the matrix for regularization.

    Returns
    -------
    x : (..., n, m) or (..., n) array
        The solution of ``l @ x = b``, with ``l`` the lower Cholesky factor.
        The last axis is dropped when `b` has fewer than 2 dimensions.

    Raises
    ------
    numpy.linalg.LinAlgError
        If a leading minor of the matrix is not positive definite.
    """
    row = numpy.array(t, subok=True)  # private copy, modified in place below
    size = row.shape[-1]

    rhs = numpy.asanyarray(b)
    is_vector = rhs.ndim < 2
    if is_vector:
        rhs = rhs[:, None]
    assert rhs.shape[-2] == size

    # promote to a floating point (or object) dtype
    row = row.astype(numpy.result_type(row, 0.1), copy=False)
    rhs = rhs.astype(numpy.result_type(rhs, 0.1), copy=False)

    if size == 0:
        out_shape = numpy.broadcast_shapes(row.shape[:-1], rhs.shape[:-2])
        if is_vector:
            out_shape += (size,)
        else:
            out_shape += rhs.shape[-2:]
        out_dtype = numpy.result_type(row.dtype, rhs.dtype)
        return numpy.empty(out_shape, out_dtype)

    if diageps is not None:
        row[..., 0] += diageps

    if numpy.any(row[..., 0] <= 0):
        message = '1-th leading minor is not positive definite'
        raise numpy.linalg.LinAlgError(message)

    # normalize to unit diagonal; the scale is reapplied at the end
    scale = numpy.copy(row[..., 0, None], subok=True)
    row /= scale

    # forward substitution happens in place on a broadcast copy of the rhs
    x = numpy.copy(numpy.broadcast_arrays(rhs, row[..., None])[0], subok=True)

    # column of L from the previous step; column 0 is the normalized row
    prev_col = row
    # Schur generator: row 0 is the shifted current column, row 1 the other
    gen = numpy.stack([numpy.roll(row, 1, -1), row], -2)

    for k in range(1, size):
        assert numpy.all(gen[..., 0, k] > 0)
        rho = -gen[..., 1, k, None, None] / gen[..., 0, k, None, None]

        if numpy.any(numpy.abs(rho) >= 1):
            message = '{}-th leading minor is not positive definite'.format(k + 1)
            raise numpy.linalg.LinAlgError(message)

        # hyperbolic rotation of the generator
        gamma = numpy.sqrt((1 - rho) * (1 + rho))
        gen[..., :, k:] += gen[..., ::-1, k:] * rho
        gen[..., :, k:] /= gamma
        col = gen[..., 0, k:]  # k-th column of L from row k

        # eliminate x[k - 1] from the remaining rows, then divide by L[k, k]
        x[..., k:, :] -= x[..., k - 1, None, :] * prev_col[..., k:, None]
        x[..., k, :] /= col[..., 0, None]

        prev_col[..., k:] = col
        gen[..., 0, k:] = numpy.roll(gen[..., 0, k:], 1, -1)

    x /= numpy.sqrt(scale[..., None])
    if is_vector:
        x = numpy.squeeze(x, -1)
    return x
def eigv_bound(t):
    """

    Bound the eigenvalues of a symmetric Toeplitz matrix.

    The bound is the maximum absolute row sum of the matrix (infinity norm,
    equivalently the Gershgorin bound), computed in O(n) from the first row.

    Parameters
    ----------
    t : array_like
        The first row of the matrix. Any array-like is accepted.

    Returns
    -------
    m : scalar
        Any eigenvalue `v` of the matrix satisfies `|v| <= m`. For an empty
        matrix (``len(t) == 0``) the bound is 0.

    """
    t = jnp.asarray(t)  # jnp functions reject plain Python sequences
    s = jnp.abs(t)
    if s.size == 0:
        # no eigenvalues: any non-negative bound holds, return the tightest
        return jnp.zeros((), s.dtype)
    c = jnp.cumsum(s)
    # the absolute sum of row k is sum(s[:k + 1]) + sum(s[1:n - k]), i.e.
    # c[k] + c[n - 1 - k] - s[0]
    d = c + c[::-1] - s[0]
    return jnp.max(d)