Coverage for src/lsqfitgp/_linalg/_toeplitz.py: 96%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_linalg/_toeplitz.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import jax 1efabcd

21from jax import numpy as jnp 1efabcd

22import numpy 1efabcd

23 

24from . import _seqalg 1efabcd

25 

class SymSchur(_seqalg.Producer):
    """
    Cholesky decomposition of a symmetric Toeplitz matrix, producing one
    column of the Cholesky factor L per step (Schur algorithm on a two-row
    generator).

    Adapted from TOEPLITZ_CHOLESKY by John Burkardt (LGPL license)
    http://people.sc.fsu.edu/~jburkardt/py_src/toeplitz_cholesky/toeplitz_cholesky.html
    """

    # Empty: presumably this producer consumes no outputs of other
    # operations in the sequential algorithm (see _seqalg.Producer).
    inputs = ()

    def __init__(self, t):
        """t = first row of the matrix"""
        first_row = jnp.asarray(t)
        assert first_row.ndim == 1
        # positivity of t[0] is not checked here
        self.t = first_row

    def init(self, n):
        """Build the generator from the stored first row, consuming it."""
        first_row = self.t
        del self.t
        assert len(first_row) == n
        diag = first_row[0]
        # Work on the matrix scaled to unit diagonal; sqrt(diag) is
        # multiplied back into every emitted column.
        unit_row = first_row / diag
        self.g = jnp.stack([unit_row, unit_row])
        self.snorm = jnp.sqrt(diag)

    def iter_out(self, i):
        """i-th column of Cholesky factor L"""
        column = self.g[0, :]
        return column * self.snorm

    def iter(self, i):
        """Advance the generator by one step of the Schur recursion."""
        gen = self.g
        # shift the first generator row by one position (cyclically)
        shifted = jnp.roll(gen[0, :], 1)
        gen = gen.at[0, :].set(shifted)
        # zero column 0 and column i - 1 in both generator rows
        gen = gen.at[:, 0].set(0)
        gen = gen.at[:, i - 1].set(0)
        # Reflection coefficient. Positive definiteness of the (i+1)-th
        # leading minor requires |rho| < 1; this is not checked here.
        rho = -gen[1, i] / gen[0, i]
        gamma = jnp.sqrt((1 - rho) * (1 + rho))
        # hyperbolic rotation mixing the two generator rows
        self.g = (gen + gen[::-1] * rho) / gamma

65 

# TODO Schur, Levinson, and in general algorithms with triangular matrices
# are not efficient as I implement them because, due to the constraint of
# fixed size in the jit, I cannot take advantage of the triangularity and
# the number of operations doubles. This could be solved with a blocked
# version which decreases the block size a few times: n blocks bring down
# the increase factor to 1 + 1/n. The compilation time and code size would
# be proportional to n. Aligning block size to powers of 2 would help
# caching the compilation, bringing the compilation time to the normal one
# after warmup (does jax compilation reuse functions or inline them?)
#
# Anyway, this factor of 2 is not currently relevant.

77 

class SymLevinson(_seqalg.Producer):
    """
    Cholesky decomposition of the *inverse* of a symmetric Toeplitz matrix,
    computed with the Durbin-Levinson recursion; each step produces one row
    of L^-1, where L is the Cholesky factor of the matrix.

    Adapted from SuperGauss/DurbinLevinson.h (GPL license)
    https://cran.r-project.org/package=SuperGauss

    Note: Schur should be more accurate than Levinson
    """

    # Empty: presumably this producer consumes no outputs of other
    # operations in the sequential algorithm (see _seqalg.Producer).
    inputs = ()

    def __init__(self, t):
        """t = first row of the matrix"""
        first_row = jnp.asarray(t, float)
        assert first_row.ndim == 1
        # positivity of t[0] is not checked here
        self.t = first_row

    def init(self, n):
        """Set up the recursion state from the stored first row, consuming it."""
        # phi1: current recursion coefficients, phi1[k] for lag k + 1
        # phi2: the same coefficients reversed, so that phi2[j] multiplies
        #       tlag[j] in the lagged sum computed in iter
        self.phi1 = jnp.zeros(n)
        self.phi2 = jnp.zeros(n)
        # nu: squared scale of the current output row, starts at t[0]
        self.nu = self.t[0]
        # tlag[k] == t[k + 1] (the last entry wraps around to t[0])
        self.tlag = jnp.roll(self.t, -1)
        del self.t

    def iter_out(self, i):
        """i-th row of L^-1"""
        row = self.phi2.at[i].set(-1)
        return -row / jnp.sqrt(self.nu)

    def iter(self, i):
        """Advance the Durbin-Levinson recursion by one step."""
        coef = self.phi1
        coef_rev = self.phi2
        scale = self.nu
        lagged = self.tlag

        k = i - 1
        lag_sum = coef_rev @ lagged
        # new reflection coefficient at position k
        coef = coef.at[k].set((lagged[k] - lag_sum) / scale)
        # update the older coefficients (coef_rev[k] is zero at this point)
        coef = coef - coef[k] * coef_rev
        # Positive definiteness of the (i+1)-th leading minor requires
        # |coef[k]| < 1; this is not checked here.
        scale = scale * (1 - coef[k]) * (1 + coef[k])

        self.phi1 = coef
        self.phi2 = jnp.roll(coef[::-1], i)
        self.nu = scale

125 

@jax.jit
def chol(t):
    """
    Cholesky factor L of the symmetric Toeplitz matrix with first row `t`.

    SymSchur emits one column of L per step; Stack(0) presumably stacks those
    outputs along the first axis, hence the final transpose.
    """
    operations = [SymSchur(t), _seqalg.Stack(0)]
    _schur, stacked_columns = _seqalg.sequential_algorithm(len(t), operations)
    return stacked_columns.T

130 

@jax.jit
def chol_solve(t, *bs):
    """
    Triangular solves with the Cholesky factor L of the symmetric Toeplitz
    matrix with first row `t`, one per right-hand side in `bs` (via
    _seqalg.SolveTriLowerColByFull, presumably computing L^-1 b).

    Returns the single result if exactly one `b` is given, otherwise the
    results for all of them.
    """
    operations = [SymSchur(t)]
    operations += [_seqalg.SolveTriLowerColByFull(0, b) for b in bs]
    results = _seqalg.sequential_algorithm(len(t), operations)
    if len(bs) == 1:
        return results[1]
    return results[1:]

136 

@jax.jit
def chol_matmul(t, b):
    """
    Multiply the Cholesky factor L of the symmetric Toeplitz matrix with
    first row `t` by `b`, accumulating column-of-L times row-of-b products
    (presumably L @ b; semantics of the _seqalg operations not shown here).
    """
    factor = SymSchur(t)
    b_rows = _seqalg.Rows(b)
    accumulate = _seqalg.MatMulColByRow(0, 1)
    _schur, _rows, product = _seqalg.sequential_algorithm(
        len(t), [factor, b_rows, accumulate])
    return product

142 

@jax.jit
def chol_transp_matmul(t, b):
    """
    Multiply the transposed Cholesky factor of the symmetric Toeplitz matrix
    with first row `t` by `b`: each column of L is multiplied as a row against
    the full `b`, and the per-step results are stacked (presumably L^T @ b).
    """
    factor = SymSchur(t)
    row_products = _seqalg.MatMulRowByFull(0, b)
    collect = _seqalg.Stack(1)
    _schur, _products, stacked = _seqalg.sequential_algorithm(
        len(t), [factor, row_products, collect])
    return stacked

148 

@jax.jit
def logdet(t):
    """
    Log-determinant of the symmetric Toeplitz matrix with first row `t`,
    computed as twice the summed log-diagonal of its Cholesky factor
    (det M = det(L)**2).
    """
    operations = [SymSchur(t), _seqalg.SumLogDiag(0)]
    _schur, half_logdet = _seqalg.sequential_algorithm(len(t), operations)
    return 2 * half_logdet

153 

@jax.jit
def solve(t, b):
    """
    Solve the linear system with the symmetric Toeplitz matrix of first row
    `t` using the rows of L^-1 from the Levinson recursion: first each row
    of L^-1 times `b`, then accumulate those rows (as columns of L^-T)
    against the products (presumably L^-T L^-1 b = M^-1 b).
    """
    inverse_factor = SymLevinson(t)
    partial = _seqalg.MatMulRowByFull(0, b)
    accumulate = _seqalg.MatMulColByRow(0, 1)
    _levinson, _partial, solution = _seqalg.sequential_algorithm(
        len(t), [inverse_factor, partial, accumulate])
    return solution

159 

@jax.jit
def chol_transp_solve(t, b):
    """
    Solve with the transposed Cholesky factor of the symmetric Toeplitz
    matrix of first row `t`: rows of L^-1 (from the Levinson recursion) are
    accumulated as columns against the rows of `b` (presumably L^-T @ b).
    """
    inverse_factor = SymLevinson(t)
    b_rows = _seqalg.Rows(b)
    accumulate = _seqalg.MatMulColByRow(0, 1)
    _levinson, _rows, solution = _seqalg.sequential_algorithm(
        len(t), [inverse_factor, b_rows, accumulate])
    return solution

165 

def chol_solve_numpy(t, b, diageps=None):
    """
    
    Solve a linear system for the cholesky factor of a symmetric Toeplitz
    matrix. The algorithm is:
    
        t[0] += diageps
        m = toeplitz(t)
        l = chol(m)
        return solve(l, b)
    
    Numpy object arrays are supported. Broadcasts like matmul.
    
    Parameters
    ----------
    t : (..., n) array
        The first row or column of the matrix. It is copied, not modified.
    b : (..., n, m) or (n,) array
        The right hand side of the linear system.
    diageps : scalar, optional
        Term added to the diagonal elements of the matrix for regularization.
    
    Returns
    -------
    x : (..., n, m) or (..., n) array
        The solution of ``l @ x = b``. Its dtype is the common type of `t`
        and `b`, promoted to at least floating point, in both the empty and
        the general case.
    
    Raises
    ------
    ValueError
        If the last axis of `t` and the row axis of `b` differ in length.
    numpy.linalg.LinAlgError
        If a leading minor of the matrix is found not positive definite.
    
    """

    # private copy of t, since it is modified in place below
    t = numpy.array(t, subok=True)
    n = t.shape[-1]

    b = numpy.asanyarray(b)
    vec = b.ndim < 2
    if vec:
        b = b[:, None]
    # explicit check instead of assert: with a 1-row b a stripped assert
    # would let broadcasting silently produce a wrong answer
    if b.shape[-2] != n:
        msg = 'length of t ({}) does not match number of rows of b ({})'.format(
            n, b.shape[-2])
        raise ValueError(msg)

    # promote integers to floating point; object arrays stay object
    t = t.astype(numpy.result_type(t, 0.1), copy=False)
    b = b.astype(numpy.result_type(b, 0.1), copy=False)

    # Output dtype and batch shape, shared by the empty and general cases.
    # The dtype must hold both t and b, otherwise the in-place updates below
    # fail (object t, float b) or silently lose precision (wide t, narrow b).
    dtype = numpy.result_type(t.dtype, b.dtype)
    batch_shape = numpy.broadcast_shapes(t.shape[:-1], b.shape[:-2])

    if n == 0:
        shape = batch_shape + ((n,) if vec else b.shape[-2:])
        return numpy.empty(shape, dtype)

    if diageps is not None:
        t[..., 0] += diageps

    if numpy.any(t[..., 0] <= 0):
        msg = '1-th leading minor is not positive definite'
        raise numpy.linalg.LinAlgError(msg)

    # Work with the matrix scaled to unit diagonal; the scale is restored at
    # the end. Copy norm because t[..., 0] is overwritten by the division.
    norm = numpy.copy(t[..., 0, None], subok=True)
    t /= norm

    # Forward substitution is done in place on a fresh, writable copy of b
    # broadcast to the common batch shape.
    invLb = numpy.broadcast_to(b, batch_shape + b.shape[-2:]).astype(dtype)

    # prevLi[..., i:] holds the previous column of L from row i onward;
    # column 0 of the unit-diagonal L is t itself (aliased, t is private).
    prevLi = t
    # Schur generator: row 0 is t shifted right by one, row 1 is t.
    g = numpy.stack([numpy.roll(t, 1, -1), t], -2)

    for i in range(1, n):

        # internal invariant of the recursion, not input validation
        assert numpy.all(g[..., 0, i] > 0)
        rho = -g[..., 1, i, None, None] / g[..., 0, i, None, None]

        if numpy.any(numpy.abs(rho) >= 1):
            msg = '{}-th leading minor is not positive definite'.format(i + 1)
            raise numpy.linalg.LinAlgError(msg)

        # hyperbolic rotation of the two generator rows
        gamma = numpy.sqrt((1 - rho) * (1 + rho))
        g[..., :, i:] += g[..., ::-1, i:] * rho
        g[..., :, i:] /= gamma
        Li = g[..., 0, i:] # i-th column of L from row i

        # eliminate unknown i - 1 from the remaining rows, then solve for i
        invLb[..., i:, :] -= invLb[..., i - 1, None, :] * prevLi[..., i:, None]
        invLb[..., i, :] /= Li[..., 0, None]
        prevLi[..., i:] = Li

        # shift the first generator row for the next step
        g[..., 0, i:] = numpy.roll(g[..., 0, i:], 1, -1)

    # undo the unit-diagonal scaling: L = sqrt(norm) * L_unit
    invLb /= numpy.sqrt(norm[..., None])
    if vec:
        invLb = numpy.squeeze(invLb, -1)
    return invLb

243 

def eigv_bound(t):
    """
    
    Bound the eigenvalues of a symmetric Toeplitz matrix.
    
    Parameters
    ----------
    t : array_like
        The first row of the matrix. Lists and tuples are accepted.
    
    Returns
    -------
    m : scalar
        Any eigenvalue `v` of the matrix satisfies `|v| <= m`. This is the
        maximum absolute row sum of the matrix (Gershgorin bound), or 0 for
        an empty matrix.
    
    """
    # jax.numpy functions reject lists/tuples, so convert explicitly
    s = jnp.abs(jnp.asarray(t))
    if s.size == 0:
        # no eigenvalues: any non-negative number is a valid bound
        return jnp.zeros((), s.dtype)
    # Row i holds |t[0..i]| and |t[1..n-1-i]|, so its absolute sum is
    # c[i] + c[n-1-i] - |t[0]| where c is the cumulative sum of |t|.
    c = jnp.cumsum(s)
    d = c + c[::-1] - s[0]
    return jnp.max(d)