Coverage for src/lsqfitgp/_kernels/_arma.py: 97%

228 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_kernels/_arma.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import jax 1efabcd

21from jax import numpy as jnp 1efabcd

22from jax import lax 1efabcd

23import numpy 1efabcd

24 

25from .. import _linalg 1efabcd

26from .._linalg import _toeplitz 1efabcd

27from .. import _jaxext 1efabcd

28from .._Kernel import stationarykernel 1efabcd

29 

# use positive delta because negative indices wrap around
@stationarykernel(derivable=False, maxdim=1, input='abs')
def MA(delta, w=None, norm=False):
    """
    Discrete moving average kernel.

    .. math::
        k(\\Delta) = \\sum_{k=|\\Delta|}^{n-1} w_k w_{k-|\\Delta|},
        \\quad \\mathbf w = (w_0, \\ldots, w_{n-1}).

    The inputs must be integers. This is the autocovariance function of a
    moving average with weights :math:`\\mathbf w` applied to white noise:

    .. math::
        k(i, j) &= \\operatorname{Cov}[y_i, y_j], \\\\
        y_i &= \\sum_{k=0}^{n-1} w_k \\epsilon_{i-k}, \\\\
        \\operatorname{Cov}[\\epsilon_i,\\epsilon_j] &= \\delta_{ij}.

    With ``norm=True`` the variance is normalized to 1, which is the same as
    normalizing :math:`\\mathbf w` to unit length.

    """

    # TODO reference? must find some standard book with a treatment which is
    # not too formal yet writes clearly about the covariance function

    # TODO nd version with w.ndim == n, it's a nd convolution. use
    # jax.scipy.signal.correlate.

    weights = jnp.asarray(w)
    assert weights.ndim == 1
    nw = len(weights)

    if not nw:
        # no weights: the covariance is zero at every lag
        return jnp.zeros(delta.shape)

    # full autocorrelation of the weights; position nw - 1 holds lag 0
    acov = jnp.convolve(weights, weights[::-1])
    if norm:
        acov = acov / acov[nw - 1]

    # lags past the end of the correlation fall outside the array and are
    # filled with 0
    return acov.at[delta + nw - 1].get(mode='fill', fill_value=0)

68 

@stationarykernel(derivable=False, maxdim=1, input='abs')
def _ARBase(delta, phi=None, gamma=None, maxlag=None, slnr=None, lnc=None, norm=False):
    """
    Discrete autoregressive kernel.

    You have to specify one and only one of the sets of parameters
    ``phi+maxlag``, ``gamma+maxlag``, ``slnr+lnc``.

    Parameters
    ----------
    phi : (p,) real
        The autoregressive coefficients at lag 1...p.
    gamma : (p + 1,) real
        The autocovariance at lag 0...p.
    maxlag : int
        The maximum lag that the kernel will be evaluated on. If the actual
        inputs produce higher lags, the missing values are filled with ``nan``.
    slnr : (nr,) real
        The real roots of the characteristic polynomial, expressed in the
        following way: ``sign(slnr)`` is the sign of the root, and
        ``abs(slnr)`` is the natural logarithm of the absolute value.
    lnc : (nc,) complex
        The natural logarithm of the complex roots of the characteristic
        polynomial (:math:`\\log z = \\log|z| + i\\arg z`), where each root
        also stands for its paired conjugate.

        In `slnr` and `lnc`, the multiplicity of a root is expressed by
        repeating the root in the array (not necessarily next to each other).
        Only exact repetition counts; very close yet distinct roots are treated
        as separate and lead to numerical instability, in particular complex
        roots very close to the real line. An exactly real complex root behaves
        like a pair of identical real roots. Two complex roots also count as
        equal if conjugate, and the argument is standardized to
        :math:`[0, \\pi]`.
    norm : bool, default False
        If True, normalize the autocovariance to be 1 at lag 0. If False,
        normalize such that the variance of the generating noise is 1, or use
        the user-provided normalization if `gamma` is specified.

    Notes
    -----
    This is the covariance function of a stationary autoregressive process,
    which is defined recursively as

    .. math::
        y_i = \\sum_{k=1}^p \\phi_k y_{i-k} + \\epsilon_i,

    where :math:`\\epsilon_i` is white noise, i.e.,
    :math:`\\operatorname{Cov}[\\epsilon_i, \\epsilon_j] = \\delta_{ij}`. The
    length :math:`p` of the vector of coefficients :math:`\\boldsymbol\\phi`
    is the "order" of the process.

    The covariance function can be expressed in two ways. First as the same
    recursion defining the process:

    .. math::
        \\gamma_m = \\sum_{k=1}^p \\phi_k \\gamma_{m-k} + \\delta_{m0},

    where :math:`\\gamma_m \\equiv \\operatorname{Cov}[y_i, y_{i+m}]`. This is
    called "Yule-Walker equation." Second, as a linear combination of mixed
    power-exponentials:

    .. math::
        \\gamma_m = \\sum_{j=1}^n
                    \\sum_{l=1}^{\\mu_j}
                    a_{jl} |m|^{l-1} x_j^{-|m|},

    where :math:`x_j` and :math:`\\mu_j` are the (complex) roots and
    corresponding multiplicities of the "characteristic polynomial"

    .. math::
        P(x) = 1 - \\sum_{k=1}^p \\phi_k x^k,

    and the :math:`a_{jl}` are uniquely determined complex coefficients. The
    :math:`\\boldsymbol\\phi` vector is valid iff :math:`|x_j|>1, \\forall j`.

    There are three alternative parametrizations for this kernel.

    If you specify `phi`, the first terms of the covariance are computed
    solving the Yule-Walker equation, and then evolved up to `maxlag`. It
    is necessary to specify `maxlag` instead of letting the code figure it out
    from the actual inputs for technical reasons.

    Likewise, if you specify `gamma`, the coefficients are obtained with
    Yule-Walker and then used to evolve the covariance. The only difference is
    that the normalization can be different: starting from `phi`, the variance
    of the generating noise :math:`\\epsilon` is fixed to 1, while giving
    `gamma` directly implies an arbitrary value.

    Instead, if you specify the roots with `slnr` and `lnc`, the coefficients
    are obtained from the polynomial defined in terms of the roots, and then
    the amplitudes :math:`a_{jl}` are computed by solving a linear system with
    the covariance (from YW) as RHS. Finally, the full covariance function is
    evaluated with the analytical expression.

    The reasons for using the logarithm are that 1) in practice the roots are
    typically close to 1, so the logarithm is numerically more accurate, and 2)
    the logarithm is readily interpretable as the inverse of the correlation
    length.

    """
    # names of the parameters that were actually passed
    given = {
        name
        for name, value in dict(
            phi=phi, gamma=gamma, maxlag=maxlag, slnr=slnr, lnc=lnc,
        ).items()
        if value is not None
    }
    allowed = ({'phi', 'maxlag'}, {'gamma', 'maxlag'}, {'slnr', 'lnc'})
    if given not in allowed:
        raise ValueError('invalid set of specified parameters')

    # TODO maybe I could allow redundantly specifying gamma and phi, e.g., for
    # numerical accuracy reasons if they are determined from an analytical
    # expression.

    if 'slnr' in given:
        return _ar_with_roots(delta, slnr, lnc, norm)
    return _ar_with_phigamma(delta, phi, gamma, maxlag, norm)

186 

def _ar_with_phigamma(delta, phi, gamma, maxlag, norm):
    """
    Evaluate the AR kernel from the coefficients and/or the covariance.

    If one of `phi` and `gamma` is None, it is derived from the other with
    the Yule-Walker equations. The covariance is then extended recursively
    so that it covers at least lags 0...maxlag, and lags beyond the
    available values evaluate to nan.
    """
    if phi is None:
        phi = AR.phi_from_gamma(gamma)
    if gamma is None:
        gamma = AR.gamma_from_phi(phi)
    # accept plain sequences (e.g. lists), which do not support `/` below
    gamma = jnp.asarray(gamma, float)
    if norm:
        gamma = gamma / gamma[0]
    # a user-provided gamma may already reach past maxlag: never request a
    # negative number of new values from extend_gamma
    nextra = max(0, maxlag + 1 - len(gamma))
    acf = AR.extend_gamma(gamma, phi, nextra)
    return acf.at[delta].get(mode='fill', fill_value=jnp.nan)

196 

def _yule_walker(gamma):
    """
    Solve the Yule-Walker equations for the autoregressive coefficients.

    gamma = autocovariance at lag 0...p
    output: autoregressive coefficients at lag 1...p
    """
    acov = jnp.asarray(gamma)
    assert acov.ndim == 1
    toeplitz_col, rhs = acov[:-1], acov[1:]
    # only the lag-0 covariance given: the process has order 0
    if not toeplitz_col.size:
        return jnp.empty(0)
    return _toeplitz.solve(toeplitz_col, rhs)

210 

def _yule_walker_inv_mat(phi):
    """
    Build the matrix of the Yule-Walker system in the unknown covariance.

    Row m holds the coefficients of gamma_0...gamma_p in
    gamma_m - sum_k phi_k gamma_{|m-k|}, for m = 0...p.
    """
    coef = jnp.asarray(phi)
    assert coef.ndim == 1
    order = len(coef)
    # zero-pad both ends so that out-of-range lags, clipped to 0 or p + 1
    # below, contribute nothing
    padded = jnp.pad(coef, (1, 1))
    row = jnp.arange(order + 1)[:, None]
    col = row.T
    at_sum = padded[jnp.clip(row + col, 0, order + 1)]
    at_diff = padded[jnp.clip(row - col, 0, order + 1)]
    # in column 0 both index paths hit the same coefficient, so halve it
    halving = jnp.where(col, 1, 2)
    return jnp.eye(order + 1) - (at_sum + at_diff) / halving

221 

def _yule_walker_inv(phi):
    """
    Compute the autocovariance of an AR process from its coefficients.

    phi = autoregressive coefficients at lag 1...p
    output: autocovariance at lag 0...p, assuming driving noise has sdev 1
    """
    system = _yule_walker_inv_mat(phi)
    # the unit noise variance only enters the lag-0 equation
    rhs = jnp.zeros(len(system)).at[0].set(1)
    # _pseudo_solve could be used here too, but it is less accurate
    return jnp.linalg.solve(system, rhs)

232 

def _ar_evolve(phi, start, noise):
    """
    Run an autoregressive process forward in time.

    phi = autoregressive coefficients at lag 1...p
    start = first p values of the process (increasing time)
    noise = n noise values added at each step
    output: n new process values
    """
    phi, start, noise = (jnp.asarray(x) for x in (phi, start, noise))
    assert phi.ndim == 1
    assert phi.shape == start.shape
    assert noise.ndim == 1
    return _ar_evolve_jit(phi, start, noise)

245 

@jax.jit
def _ar_evolve_jit(phi, start, noise):
    """
    Jitted core of `_ar_evolve`.

    The last p values live in a circular buffer where `head` indexes the
    oldest entry, which is overwritten at each step. The coefficients are
    stored reversed and duplicated, so a contiguous window of length p lines
    up with the buffer for every position of `head`.
    """
    order = start.size
    window_src = jnp.concatenate([phi, phi])[::-1]

    def step(state, eps):
        buf, coefs, head = state
        weights = lax.dynamic_slice(coefs, [order - head], [order])
        new = weights @ buf + eps
        if order:
            buf = buf.at[head].set(new)
            # an actual roll might be faster for alignment reasons; untested
            head = (head + 1) % order
        return (buf, coefs, head), new

    _, values = lax.scan(step, (start, window_src, 0), noise, unroll=16)
    return values

263 

def _ar_with_roots(delta, slnr, lnc, norm):
    """
    Evaluate the AR kernel parametrized by the roots of the characteristic
    polynomial.

    The covariance at lags 0...p is obtained from the coefficients, used to
    fit the amplitudes of the analytical expression, and the latter is then
    evaluated at `delta`.
    """
    # weak point: going through phi and Yule-Walker is inaccurate for roots
    # with high multiplicity and/or close to 1
    coefs = AR.phi_from_roots(slnr, lnc)
    acov = AR.gamma_from_phi(coefs)
    if norm:
        acov = acov / acov[0]
    amplitudes = AR.ampl_from_roots(slnr, lnc, acov)
    return AR.cov_from_ampl(slnr, lnc, amplitudes, delta)

    # TODO Currently gamma is not even pos def for high multiplicity/roots close
    # to 1. Raw patch: the badness condition is gamma[0] < 0 or any(abs(gamma) >
    # gamma[0]) or gamma inf/nan. Take the smallest log|root| and assume it
    # alone determines gamma. This is best implemented as an option in
    # _gamma_from_ampl_matmul.

    # Is numerical integration of the spectrum a feasible way to get the
    # covariance? The roots correspond to peaks, and they get very high as the
    # roots get close to 1. But I know where the peaks are in advance => nope
    # because the e^iwx oscillates arbitrarily fast. However maybe I can compute
    # the first p terms, which solves my current problem with gamma. I guess I
    # just have to use a multiple of p of quadrature points. The spectrum
    # oscillates too but only up to mode p. The total calculation cost is then
    # O(p^2), better than current O(p^3). See Hamilton (1994, p. 155).

    # Other solution (Hamilton p. 319): the covariance should be equal to the
    # impulse response, so I can get gamma from phi by an evolution starting
    # from zeros. => Nope, it's equal only for AR(1).

    # condition for phi: in the region phi >= 0, it must be sum(phi) <= 1
    # (Hamilton p. 659).

    # Sketch of an alternative that skips gamma:
    # p = phi.size
    # yw = _yule_walker_inv_mat(phi)
    # b = jnp.zeros(p + 1).at[0].set(1)
    # ampl = jnp.linalg.solve(yw @ mat, b)
    # lag = delta if delta.ndim else delta[None]
    # acf = _gamma_from_ampl_matmul(slnr, lnc, lag, ampl)
    # if norm:
    #     acf0 = _gamma_from_ampl_matmul(slnr, lnc, jnp.array([0]), ampl)
    #     acf /= acf0
    # return acf if delta.ndim else acf.squeeze(0)

305 

def _pseudo_solve(a, b):
    """
    Solve ``a @ x = b`` through the SVD pseudo-inverse of `a`.

    Singular values below ``s_max * eps * len(a)`` are dropped, i.e., their
    inverse is set to zero, so a singular `a` gives the pseudo-inverse
    solution instead of inf/nan. This is less accurate than
    jnp.linalg.solve.
    """
    u, sv, vh = jnp.linalg.svd(a)
    if sv.size:
        smax = sv[0]
    else:
        smax = 0
    cutoff = smax * jnp.finfo(a.dtype).eps * len(a)
    inv_sv = jnp.where(sv < cutoff, 0, 1 / sv)
    return jnp.einsum('ij,j,jk,k', vh.conj().T, inv_sv, u.conj().T, b)

313 

@jax.jit
def _gamma_from_ampl_matmul(slnr, lnc, lag, ampl, lagnorm=None):
    """
    Evaluate the power-exponential terms of the AR covariance at `lag` and
    contract them with the amplitudes.

    Parameters
    ----------
    slnr : (nr,) real
        Real roots in signed-log form. Repeated roots are detected only when
        adjacent, so the array is assumed to be sorted (as done by
        AR._process_roots).
    lnc : (nc,) complex
        Logarithm of the complex roots, with the same assumption.
    lag : int array
        Lags at which to evaluate the terms.
    ampl : (p + 1,) or (p + 1, k) real
        Amplitudes, with p = nr + 2 nc. Row 0 multiplies the lag-0
        indicator, rows 1...nr the real roots, then two rows per complex
        root.
    lagnorm : optional
        Each term is divided by its maximum over lags in [0, lagnorm], or by
        its value at lagnorm if the maximum lies beyond it. Default p.

    Returns
    -------
    out : array
        Shape of `lag`, with an extra trailing axis of length k if `ampl`
        is 2D.
    """
    vec = ampl.ndim == 1
    if vec:
        ampl = ampl[:, None]
    p = slnr.size + 2 * lnc.size
    assert ampl.shape[-2] == p + 1
    if lagnorm is None:
        lagnorm = p

    def logcol(root, lag, llag, repeat):
        # log of lag^repeat * exp(-root * lag); the where avoids 0 * log(0)
        return -root * lag + jnp.where(repeat, repeat * llag, 0)

    def lognorm(root, repeat, lagnorm):
        # log of the maximum of lag^repeat * exp(-root * lag), attained at
        # lag = repeat / root, or of the value at lagnorm if that is closer
        maxnorm = jnp.where(repeat, repeat * (-1 + jnp.log(repeat / root)), 0)
        defnorm = logcol(root, lagnorm, jnp.log(lagnorm), repeat)
        maxloc = repeat / root
        return jnp.where(maxloc <= lagnorm, maxnorm, defnorm)

    # roots at infinity: indicator of lag 0
    # TODO remove this because it's degenerate with large roots, handle the
    # p=0 case outside of this function
    col = jnp.where(lag, 0, 1)
    out = col[..., :, None] * ampl[..., 0, :]

    # real roots
    llag = jnp.log(lag)
    val = (jnp.nan, 0, out, slnr, lag, llag, lagnorm)

    def loop(i, val):
        prevroot, repeat, out, slnr, lag, llag, lagnorm = val
        root = slnr[i]
        # count how many times this root appeared right before
        repeat = jnp.where(root == prevroot, repeat + 1, 0)
        prevroot = root
        # copysign instead of sign: a root of modulus 1 is encoded as +0.0 or
        # -0.0, and must give 1 or (-1)^lag rather than 0 (phi_from_roots
        # decodes the sign with copysign too)
        sign = jnp.copysign(1., root) ** lag
        aroot = jnp.abs(root)
        lcol = logcol(aroot, lag, llag, repeat)
        norm = lognorm(aroot, repeat, lagnorm)
        col = sign * jnp.exp(lcol - norm)
        out += col[..., :, None] * ampl[..., 1 + i, :]
        return prevroot, repeat, out, slnr, lag, llag, lagnorm

    if slnr.size:
        _, _, out, _, _, _, _ = lax.fori_loop(0, slnr.size, loop, val)

    # complex roots
    val = (jnp.nan, 0, out, lnc, lag, llag, lagnorm)

    def loop(i, val):
        prevroot, repeat, out, lnc, lag, llag, lagnorm = val
        root = lnc[i]
        repeat = jnp.where(root == prevroot, repeat + 1, 0)
        prevroot = root
        lcol = logcol(root, lag, llag, repeat)
        norm = lognorm(root.real, repeat, lagnorm)
        col = jnp.exp(lcol - norm)
        idx = 1 + slnr.size + 2 * i
        # the root and its conjugate give two real terms: real part...
        out += col.real[..., :, None] * ampl[..., idx, :]

        # ...and imaginary part; but an exactly real complex root is a pair
        # of identical real roots, so the second term gets one more power of
        # lag and the multiplicity count is bumped for the next root
        repeat = jnp.where(root.imag, repeat, repeat + 1)
        col1 = jnp.where(root.imag, -col.imag, col.real * lag)
        out += col1[..., :, None] * ampl[..., idx + 1, :]

        return prevroot, repeat, out, lnc, lag, llag, lagnorm

    if lnc.size:
        _, _, out, _, _, _, _ = lax.fori_loop(0, lnc.size, loop, val)

    if vec:
        out = out.squeeze(-1)

    return out

384 

class AR(_ARBase):

    __doc__ = _ARBase.__doc__

    @classmethod
    def phi_from_gamma(cls, gamma):
        """
        Determine the autoregressive coefficients from the covariance.

        Parameters
        ----------
        gamma : (p + 1,) array
            The autocovariance at lag 0...p.

        Returns
        -------
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.
        """
        gamma = cls._process_gamma(gamma)
        return _yule_walker(gamma)

    @classmethod
    def gamma_from_phi(cls, phi):
        """
        Determine the covariance from the autoregressive coefficients.

        Parameters
        ----------
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.

        Returns
        -------
        gamma : (p + 1,) array
            The autocovariance at lag 0...p. The normalization is
            with noise variance 1.

        Notes
        -----
        The result is wildly inaccurate for roots with high multiplicity and/or
        close to 1.
        """
        phi = cls._process_phi(phi)
        return _yule_walker_inv(phi)

        # TODO fails (nan) for very small roots. In that case the answer is that
        # gamma is a constant vector. However I can't get the constant out of
        # a degenerate phi, I need the roots, and I don't know the formula.

    @classmethod
    def extend_gamma(cls, gamma, phi, n):
        """
        Extends values of the covariance function to higher lags.

        Parameters
        ----------
        gamma : (m,) array
            The autocovariance at lag q-m+1...q, with q >= 0 and m >= p + 1.
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.
        n : int
            The number of new values to generate.

        Returns
        -------
        ext : (m + n,) array
            The autocovariance at lag q-m+1...q+n.
        """
        gamma = cls._process_gamma(gamma)
        phi = cls._process_phi(phi)
        assert gamma.size > phi.size
        # the last p values seed the recursion; written without a negative
        # start index because gamma[-0:] would select everything when p = 0
        ext = _ar_evolve(phi, gamma[len(gamma) - len(phi):], jnp.broadcast_to(0., (n,)))
        return jnp.concatenate([gamma, ext])

    @classmethod
    def phi_from_roots(cls, slnr, lnc):
        """
        Determine the autoregressive coefficients from the roots of the
        characteristic polynomial.

        Parameters
        ----------
        slnr : (nr,) real
            The real roots of the characteristic polynomial, expressed in the
            following way: ``sign(slnr)`` is the sign of the root, and
            ``abs(slnr)`` is the natural logarithm of the absolute value.
        lnc : (nc,) complex
            The natural logarithm of the complex roots of the characteristic
            polynomial (:math:`\\log z = \\log|z| + i\\arg z`), where each root
            also stands for its paired conjugate.

        Returns
        -------
        phi : (p,) real
            The autoregressive coefficients at lag 1...p, with p = nr + 2 nc.
        """
        slnr, lnc = cls._process_roots(slnr, lnc)
        r = jnp.copysign(jnp.exp(-jnp.abs(slnr)), slnr)  # works with +/-0
        c = jnp.exp(-lnc)

        # minus sign in the exponentials to do 1/z, the poly output is already
        # reversed

        roots = jnp.concatenate([r, c, c.conj()]).sort()  # <-- polyroots sorts
        coef = jnp.atleast_1d(jnp.poly(roots))

        # TODO the implementation of jnp.poly (and np.poly) is inferior to the
        # one of np.polynomial.polynomial.polyfromroots, which cares about
        # numerical accuracy and would reduce compilation time if ported to jax
        # (current one is O(p), that would be O(log p)).

        if coef.size:
            with _jaxext.skipifabstract():
                numpy.testing.assert_equal(coef[0].item(), 1)
                numpy.testing.assert_allclose(jnp.imag(coef), 0, rtol=0, atol=1e-4)
        return -coef.real[1:]

        # TODO possibly not accurate for large p. Do a test with an
        # implementation of poly which uses integer roots and non-fft convolve
        # (maybe add it as an option to my to-be-written implementation of poly)

    @classmethod
    def ampl_from_roots(cls, slnr, lnc, gamma):
        """
        Determine the amplitudes of the analytical covariance expression
        from the covariance at the first lags.

        Parameters
        ----------
        slnr, lnc :
            The roots, in the same format as in `phi_from_roots`.
        gamma : (p + 1,) real
            The autocovariance at lag 0...p, with p = nr + 2 nc.

        Returns
        -------
        ampl : (p + 1,) real
            The amplitudes, in the format accepted by `cov_from_ampl`. They
            are obtained by solving, with an SVD pseudo-inverse, the linear
            system that matches the analytical expression to `gamma` at lags
            0...p.
        """
        slnr, lnc = cls._process_roots(slnr, lnc)
        gamma = cls._process_gamma(gamma)
        assert gamma.size == 1 + slnr.size + 2 * lnc.size
        lag = jnp.arange(gamma.size)
        mat = _gamma_from_ampl_matmul(slnr, lnc, lag, jnp.eye(gamma.size))
        # return jnp.linalg.solve(mat, gamma)
        return _pseudo_solve(mat, gamma)

        # TODO I'm using pseudo-solve only because of large roots degeneracy
        # in _gamma_from_ampl_matmul, remove it after solving that

        # TODO maybe I can increase the precision of the solve with some
        # ordering of the columns of mat, I guess (reversed) global sort of the
        # roots

    @classmethod
    def cov_from_ampl(cls, slnr, lnc, ampl, lag):
        """
        Evaluate the covariance with the analytical expression.

        Parameters
        ----------
        slnr, lnc :
            The roots, in the same format as in `phi_from_roots`.
        ampl : (p + 1,) real
            The amplitudes, e.g., as computed by `ampl_from_roots`.
        lag : int scalar or array
            The lags at which to evaluate the covariance.

        Returns
        -------
        acf : real
            The covariance at `lag`, with the same shape as `lag`.
        """
        slnr, lnc = cls._process_roots(slnr, lnc)
        ampl = cls._process_ampl(ampl)
        assert ampl.size == 1 + slnr.size + 2 * lnc.size
        lag = cls._process_lag(lag)
        scalar = lag.ndim == 0
        if scalar:
            lag = lag[None]
        acf = _gamma_from_ampl_matmul(slnr, lnc, lag, ampl)
        return acf.squeeze(0) if scalar else acf

    @classmethod
    def inverse_roots_from_phi(cls, phi):
        """
        Compute the inverses of the roots of the characteristic polynomial.

        Parameters
        ----------
        phi : (p,) real
            The autoregressive coefficients at lag 1...p.

        Returns
        -------
        invroots : (p,) complex
            The roots of :math:`x^p - \\sum_{k=1}^p \\phi_k x^{p-k}`, which
            are the inverses of the roots of the characteristic polynomial.
        """
        phi = cls._process_phi(phi)
        poly = jnp.concatenate([jnp.ones(1), -phi])
        return jnp.roots(poly, strip_zeros=False)

    # TODO methods:
    # - gamma_from_roots which uses quadrature fourier transf of spectrum

    @staticmethod
    def _process_roots(slnr, lnc):
        slnr = jnp.asarray(slnr, float).sort()
        lnc = jnp.asarray(lnc, complex)
        assert slnr.ndim == lnc.ndim == 1
        # identify conjugates and arguments differing by multiples of 2 pi:
        # map the argument to [0, pi]
        imag = jnp.abs(lnc.imag) % (2 * jnp.pi)
        imag = jnp.where(imag > jnp.pi, 2 * jnp.pi - imag, imag)
        lnc = lnc.real + 1j * imag
        # sorting makes repeated roots adjacent
        lnc = lnc.sort()
        return slnr, lnc

    @staticmethod
    def _process_gamma(gamma):
        gamma = jnp.asarray(gamma, float)
        assert gamma.ndim == 1 and gamma.size >= 1
        return gamma

    @staticmethod
    def _process_phi(phi):
        phi = jnp.asarray(phi, float)
        assert phi.ndim == 1
        return phi

    @staticmethod
    def _process_ampl(ampl):
        ampl = jnp.asarray(ampl, float)
        assert ampl.ndim == 1 and ampl.size >= 1
        return ampl

    @staticmethod
    def _process_lag(lag):
        lag = jnp.asarray(lag)
        # issubdtype expects a dtype, not an array
        assert jnp.issubdtype(lag.dtype, jnp.integer)
        return lag.astype(int)