Coverage for src/lsqfitgp/_kernels/_arma.py: 97% (228 statements; coverage.py v7.6.3, created at 2024-10-15 19:54 +0000)
# lsqfitgp/_kernels/_arma.py
#
# Copyright (c) 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import jax
from jax import numpy as jnp
from jax import lax
import numpy

from .. import _linalg
from .._linalg import _toeplitz
from .. import _jaxext
from .._Kernel import stationarykernel

# use positive delta because negative indices wrap around
@stationarykernel(derivable=False, maxdim=1, input='abs')
def MA(delta, w=None, norm=False):
33 """
34 Discrete moving average kernel.
36 .. math::
37 k(\\Delta) = \\sum_{k=|\\Delta|}^{n-1} w_k w_{k-|\\Delta|},
38 \\quad \\mathbf w = (w_0, \\ldots, w_{n-1}).
40 The inputs must be integers. It is the autocovariance function of a moving
41 average with weights :math:`\\mathbf w` applied to white noise:
43 .. math::
44 k(i, j) &= \\operatorname{Cov}[y_i, y_j], \\\\
45 y_i &= \\sum_{k=0}^{n-1} w_k \\epsilon_{i-k}, \\\\
46 \\operatorname{Cov}[\\epsilon_i,\\epsilon_j] &= \\delta_{ij}.
48 If ``norm=True``, the variance is normalized to 1, which amounts to
49 normalizing :math:`\\mathbf w` to unit length.
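
    Examples
    --------
    A usage sketch (assuming the usual `lsqfitgp` kernel convention that
    instantiating the kernel and calling it on two integer arrays yields the
    covariance matrix; the values in the comment follow from the definition
    above):

    >>> k = MA(w=[1., 0.5])              # MA(1) weights
    >>> x = jnp.arange(3)
    >>> cov = k(x[:, None], x[None, :])  # variance 1.25, lag-1 cov 0.5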
51 """

    # TODO reference? must find some standard book with a treatment which is
    # not too formal yet writes clearly about the covariance function

    # TODO nd version with w.ndim == n, it's a nd convolution. use
    # jax.scipy.signal.correlate.

    w = jnp.asarray(w)
    assert w.ndim == 1
    if len(w):
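        # full autocorrelation of w: cov has length 2 len(w) - 1, with lag 0
        # at index len(w) - 1, hence the shift in the indexed lookup below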
        cov = jnp.convolve(w, w[::-1])
        if norm:
            cov /= cov[len(w) - 1]
        return cov.at[delta + len(w) - 1].get(mode='fill', fill_value=0)
    else:
        return jnp.zeros(delta.shape)

@stationarykernel(derivable=False, maxdim=1, input='abs')
def _ARBase(delta, phi=None, gamma=None, maxlag=None, slnr=None, lnc=None, norm=False):
71 """
72 Discrete autoregressive kernel.
74 You have to specify one and only one of the sets of parameters
75 ``phi+maxlag``, ``gamma+maxlag``, ``slnr+lnc``.
77 Parameters
78 ----------
79 phi : (p,) real
80 The autoregressive coefficients at lag 1...p.
81 gamma : (p + 1,) real
82 The autocovariance at lag 0...p.
83 maxlag : int
84 The maximum lag that the kernel will be evaluated on. If the actual
85 inputs produce higher lags, the missing values are filled with ``nan``.
86 slnr : (nr,) real
87 The real roots of the characteristic polynomial, expressed in the
88 following way: ``sign(slnr)`` is the sign of the root, and
89 ``abs(slnr)`` is the natural logarithm of the absolute value.
90 lnc : (nc,) complex
91 The natural logarithm of the complex roots of the characteristic
92 polynomial (:math:`\\log z = \\log|z| + i\\arg z`), where each root
93 also stands for its paired conjugate.
95 In `slnr` and `lnc`, the multiplicity of a root is expressed by
96 repeating the root in the array (not necessarily next to each other).
97 Only exact repetition counts; very close yet distinct roots are treated
98 as separate and lead to numerical instability, in particular complex
99 roots very close to the real line. An exactly real complex root behaves
100 like a pair of identical real roots. Two complex roots also count as
101 equal if conjugate, and the argument is standardized to :math:`[0,
102 2\\pi)`.
103 norm : bool, default False
104 If True, normalize the autocovariance to be 1 at lag 0. If False,
105 normalize such that the variance of the generating noise is 1, or use
106 the user-provided normalization if `gamma` is specified.
108 Notes
109 -----
110 This is the covariance function of a stationary autoregressive process,
111 which is defined recursively as
113 .. math::
114 y_i = \\sum_{k=1}^p \\phi_k y_{i-k} + \\epsilon_i,
116 where :math:`\\epsilon_i` is white noise, i.e.,
117 :math:`\\operatorname{Cov}[\\epsilon_i, \\epsilon_j] = \\delta_{ij}`. The
118 length :math:`p` of the vector of coefficients :math:`\\boldsymbol\\phi`
119 is the "order" of the process.
121 The covariance function can be expressed in two ways. First as the same
122 recursion defining the process:
124 .. math::
125 \\gamma_m = \\sum_{k=1}^p \\phi_k \\gamma_{m-k} + \\delta_{m0},
127 where :math:`\\gamma_m \\equiv \\operatorname{Cov}[y_i, y_{i+m}]`. This is
128 called "Yule-Walker equation." Second, as a linear combination of mixed
129 power-exponentials:
131 .. math::
132 \\gamma_m = \\sum_{j=1}^n
133 \\sum_{l=1}^{\\mu_j}
134 a_{jl} |m|^{l-1} x_j^{-|m|},
136 where :math:`x_j` and :math:`\\mu_j` are the (complex) roots and
137 corresponding multiplicities of the "characteristic polynomial"
139 .. math::
140 P(x) = 1 - \\sum_{k=1}^p \\phi_k x^k,
142 and the :math:`a_{jl}` are uniquely determined complex coefficients. The
143 :math:`\\boldsymbol\\phi` vector is valid iff :math:`|x_j|>1, \\forall j`.
145 There are three alternative parametrization for this kernel.
147 If you specify `phi`, the first terms of the covariance are computed
148 solving the Yule-Walker equation, and then evolved up to `maxlag`. It
149 is necessary to specify `maxlag` instead of letting the code figure it out
150 from the actual inputs for technical reasons.
152 Likewise, if you specify `gamma`, the coefficients are obtained with
153 Yule-Walker and then used to evolve the covariance. The only difference is
154 that the normalization can be different: starting from `phi`, the variance
155 of the generating noise :math:`\\epsilon` is fixed to 1, while giving
156 `gamma` directly implies an arbitrary value.
158 Instead, if you specify the roots with `slnr` and `lnc`, the coefficients
159 are obtained from the polynomial defined in terms of the roots, and then
160 the amplitudes :math:`a_{jl}` are computed by solving a linear system with
161 the covariance (from YW) as RHS. Finally, the full covariance function is
162 evaluated with the analytical expression.
164 The reasons for using the logarithm are that 1) in practice the roots are
165 tipically close to 1, so the logarithm is numerically more accurate, and 2)
166 the logarithm is readily interpretable as the inverse of the correlation
167 length.
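
    Examples
    --------
    A usage sketch (assuming the usual `lsqfitgp` kernel convention that
    instantiating the kernel and calling it on two integer arrays yields the
    covariance matrix; the value quoted in the comment follows from solving
    the Yule-Walker equation for an AR(1) process):

    >>> k = AR(phi=[0.5], maxlag=10)     # AR(1), noise variance 1
    >>> x = jnp.arange(5)
    >>> cov = k(x[:, None], x[None, :])  # gamma_m = (4/3) * 0.5 ** m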
169 """
    cond = (
        (phi is not None and maxlag is not None and gamma is None and slnr is None and lnc is None) or
        (phi is None and maxlag is not None and gamma is not None and slnr is None and lnc is None) or
        (phi is None and maxlag is None and gamma is None and slnr is not None and lnc is not None)
    )
    if not cond:
        raise ValueError('invalid set of specified parameters')

    # TODO maybe I could allow redundantly specifying gamma and phi, e.g., for
    # numerical accuracy reasons if they are determined from an analytical
    # expression.

    if phi is None and gamma is None:
        return _ar_with_roots(delta, slnr, lnc, norm)
    else:
        return _ar_with_phigamma(delta, phi, gamma, maxlag, norm)

def _ar_with_phigamma(delta, phi, gamma, maxlag, norm):
    if phi is None:
        phi = AR.phi_from_gamma(gamma)
    if gamma is None:
        gamma = AR.gamma_from_phi(phi)
    if norm:
        gamma = gamma / gamma[0]
    acf = AR.extend_gamma(gamma, phi, maxlag + 1 - len(gamma))
    return acf.at[delta].get(mode='fill', fill_value=jnp.nan)

def _yule_walker(gamma):
    """
    gamma = autocovariance at lag 0...p
    output: autoregressive coefficients at lag 1...p
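
    example (an illustrative sketch: for AR(1) data gamma_m = 0.5^m, the
    Yule-Walker solution is phi = (0.5, 0)):

    >>> phi = _yule_walker(0.5 ** jnp.arange(3))  # approx [0.5, 0.]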
201 """
202 gamma = jnp.asarray(gamma) 1abcd
203 assert gamma.ndim == 1 1abcd
204 t = gamma[:-1] 1abcd
205 b = gamma[1:] 1abcd
206 if t.size: 1abcd
207 return _toeplitz.solve(t, b) 1abcd
208 else:
209 return jnp.empty(0) 1abcd

def _yule_walker_inv_mat(phi):
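    """
    phi = autoregressive coefficients at lag 1...p
    output: (p + 1, p + 1) matrix A of the extended Yule-Walker system
    A @ gamma = e_0, i.e., row m encodes gamma_m - sum_k phi_k gamma_|m-k| =
    delta_m0 for unit noise variance; the division by 2 in column 0 below
    compensates the double count of gamma_0 at k = m
    """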
    phi = jnp.asarray(phi)
    assert phi.ndim == 1
    p = len(phi)
    m = jnp.arange(p + 1)[:, None] # rows
    n = m.T # columns
    phi = jnp.pad(phi, (1, 1))
    kp = jnp.clip(m + n, 0, p + 1)
    km = jnp.clip(m - n, 0, p + 1)
    return jnp.eye(p + 1) - (phi[kp] + phi[km]) / jnp.where(n, 1, 2)

def _yule_walker_inv(phi):
    """
    phi = autoregressive coefficients at lag 1...p
    output: autocovariance at lag 0...p, assuming driving noise has sdev 1
    """
    a = _yule_walker_inv_mat(phi)
    b = jnp.zeros(len(a)).at[0].set(1)
    # gamma = _pseudo_solve(a, b)
    gamma = jnp.linalg.solve(a, b)
    return gamma

def _ar_evolve(phi, start, noise):
    """
    phi = autoregressive coefficients at lag 1...p
    start = first p values of the process (increasing time)
    noise = n noise values added at each step
    output: n new process values
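
    example (an illustrative sketch: with phi = 0.5 and zero noise, each new
    value is half the previous one):

    >>> ev = _ar_evolve(jnp.array([0.5]), jnp.array([1.]), jnp.zeros(3))
    >>> # ev is approx [0.5, 0.25, 0.125]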
239 """
240 phi = jnp.asarray(phi) 1abcd
241 start = jnp.asarray(start) 1abcd
242 noise = jnp.asarray(noise) 1abcd
243 assert phi.ndim == 1 and phi.shape == start.shape and noise.ndim == 1 1abcd
244 return _ar_evolve_jit(phi, start, noise) 1abcd

@jax.jit
def _ar_evolve_jit(phi, start, noise):
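
    # vals is a circular buffer of the last p process values: roll indexes the
    # oldest value, which each new value overwrites. cc holds phi doubled and
    # reversed, so a length-p dynamic slice starting at vals.size - roll
    # yields the coefficients aligned with the current buffer rotation.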

    def f(carry, eps):
        vals, cc, roll = carry
        phi = lax.dynamic_slice(cc, [vals.size - roll], [vals.size])
        nextval = phi @ vals + eps
        if vals.size:
            vals = vals.at[roll].set(nextval)
            # maybe for some weird reason like alignment, actual rolling would
            # be faster. whatever
            roll = (roll + 1) % vals.size
        return (vals, cc, roll), nextval

    cc = jnp.concatenate([phi, phi])[::-1]
    _, ev = lax.scan(f, (start, cc, 0), noise, unroll=16)
    return ev

def _ar_with_roots(delta, slnr, lnc, norm):
    phi = AR.phi_from_roots(slnr, lnc)  # <---- weak
    gamma = AR.gamma_from_phi(phi)      # <---- point
    if norm:
        gamma /= gamma[0]
    ampl = AR.ampl_from_roots(slnr, lnc, gamma)
    acf = AR.cov_from_ampl(slnr, lnc, ampl, delta)
    return acf

    # TODO Currently gamma is not even pos def for high multiplicity/roots
    # close to 1. Raw patch: the badness condition is gamma[0] < 0 or
    # any(abs(gamma) > gamma[0]) or gamma inf/nan. Take the smallest log|root|
    # and assume it alone determines gamma. This is best implemented as an
    # option in _gamma_from_ampl_matmul.

    # Is numerical integration of the spectrum a feasible way to get the
    # covariance? The roots correspond to peaks, and they get very high as the
    # roots get close to 1. But I know where the peaks are in advance => nope
    # because the e^iwx oscillates arbitrarily fast. However maybe I can
    # compute the first p terms, which solves my current problem with gamma.
    # I guess I just have to use a multiple of p of quadrature points. The
    # spectrum oscillates too but only up to mode p. The total calculation
    # cost is then O(p^2), better than current O(p^3). See Hamilton (1994,
    # p. 155).

    # Other solution (Hamilton p. 319): the covariance should be equal to the
    # impulse response, so I can get gamma from phi by an evolution starting
    # from zeros. => Nope, it's equal only for AR(1).

    # condition for phi: in the region phi >= 0, it must be sum(phi) <= 1
    # (Hamilton p. 659).

    # p = phi.size
    # yw = _yule_walker_inv_mat(phi)
    # b = jnp.zeros(p + 1).at[0].set(1)
    # ampl = jnp.linalg.solve(yw @ mat, b)
    # lag = delta if delta.ndim else delta[None]
    # acf = _gamma_from_ampl_matmul(slnr, lnc, lag, ampl)
    # if norm:
    #     acf0 = _gamma_from_ampl_matmul(slnr, lnc, jnp.array([0]), ampl)
    #     acf /= acf0
    # return acf if delta.ndim else acf.squeeze(0)

def _pseudo_solve(a, b):
    # this is less accurate than jnp.linalg.solve
    u, s, vh = jnp.linalg.svd(a)
    eps = jnp.finfo(a.dtype).eps
    s0 = s[0] if s.size else 0
    invs = jnp.where(s < s0 * eps * len(a), 0, 1 / s)
    return jnp.einsum('ij,j,jk,k', vh.conj().T, invs, u.conj().T, b)

@jax.jit
def _gamma_from_ampl_matmul(slnr, lnc, lag, ampl, lagnorm=None):

    vec = ampl.ndim == 1
    if vec:
        ampl = ampl[:, None]
    p = slnr.size + 2 * lnc.size
    assert ampl.shape[-2] == p + 1
    if lagnorm is None:
        lagnorm = p

    def logcol(root, lag, llag, repeat):
        return -root * lag + jnp.where(repeat, repeat * llag, 0)

    def lognorm(root, repeat, lagnorm):
        maxnorm = jnp.where(repeat, repeat * (-1 + jnp.log(repeat / root)), 0)
        defnorm = logcol(root, lagnorm, jnp.log(lagnorm), repeat)
        maxloc = repeat / root
        return jnp.where(maxloc <= lagnorm, maxnorm, defnorm)
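
    # lognorm(root, repeat, lagnorm) is the log of the maximum of exp(logcol)
    # over lag in [0, lagnorm]: -root * lag + repeat * log(lag) peaks at
    # lag = repeat / root. Each column is rescaled by its maximum, presumably
    # to improve the conditioning of the system solved in AR.ampl_from_roots.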

    # roots at infinity
    # TODO remove this because it's degenerate with large roots, handle the
    # p=0 case outside of this function
    col = jnp.where(lag, 0, 1)
    out = col[..., :, None] * ampl[..., 0, :]

    # real roots
    llag = jnp.log(lag)
    val = (jnp.nan, 0, out, slnr, lag, llag, lagnorm)
    def loop(i, val):
        prevroot, repeat, out, slnr, lag, llag, lagnorm = val
        root = slnr[i]
        repeat = jnp.where(root == prevroot, repeat + 1, 0)
        prevroot = root
        sign = jnp.sign(root) ** lag
        aroot = jnp.abs(root)
        lcol = logcol(aroot, lag, llag, repeat)
        norm = lognorm(aroot, repeat, lagnorm)
        col = sign * jnp.exp(lcol - norm)
        out += col[..., :, None] * ampl[..., 1 + i, :]
        return prevroot, repeat, out, slnr, lag, llag, lagnorm
    if slnr.size:
        _, _, out, _, _, _, _ = lax.fori_loop(0, slnr.size, loop, val)

    # complex roots
    val = (jnp.nan, 0, out, lnc, lag, llag, lagnorm)
    def loop(i, val):
        prevroot, repeat, out, lnc, lag, llag, lagnorm = val
        root = lnc[i]
        repeat = jnp.where(root == prevroot, repeat + 1, 0)
        prevroot = root
        lcol = logcol(root, lag, llag, repeat)
        norm = lognorm(root.real, repeat, lagnorm)
        col = jnp.exp(lcol - norm)
        idx = 1 + slnr.size + 2 * i
        out += col.real[..., :, None] * ampl[..., idx, :]

        # real complex root = a pair of identical real roots
        repeat = jnp.where(root.imag, repeat, repeat + 1)
        col1 = jnp.where(root.imag, -col.imag, col.real * lag)
        out += col1[..., :, None] * ampl[..., idx + 1, :]

        return prevroot, repeat, out, lnc, lag, llag, lagnorm
    if lnc.size:
        _, _, out, _, _, _, _ = lax.fori_loop(0, lnc.size, loop, val)

    if vec:
        out = out.squeeze(-1)

    return out

class AR(_ARBase):

    __doc__ = _ARBase.__doc__

    @classmethod
    def phi_from_gamma(cls, gamma):
        """
        Determine the autoregressive coefficients from the covariance.

        Parameters
        ----------
        gamma : (p + 1,) array
            The autocovariance at lag 0...p.

        Returns
        -------
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.
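
        Examples
        --------
        An illustrative sketch (for AR(1) data gamma_m = 0.5^m, Yule-Walker
        gives phi = 0.5):

        >>> phi = AR.phi_from_gamma(0.5 ** jnp.arange(2))  # approx [0.5]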
403 """
404 gamma = cls._process_gamma(gamma) 1abcd
405 return _yule_walker(gamma) 1abcd

    @classmethod
    def gamma_from_phi(cls, phi):
        """
        Determine the covariance from the autoregressive coefficients.

        Parameters
        ----------
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.

        Returns
        -------
        gamma : (p + 1,) array
            The autocovariance at lag 0...p. The normalization is
            with noise variance 1.

        Notes
        -----
        The result is wildly inaccurate for roots with high multiplicity
        and/or close to 1.
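
        Examples
        --------
        An illustrative sketch (for an AR(1) process with unit noise variance,
        gamma_0 = 1 / (1 - phi_1^2) and gamma_1 = phi_1 * gamma_0):

        >>> gamma = AR.gamma_from_phi(jnp.array([0.5]))  # approx [4/3, 2/3]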
427 """
428 phi = cls._process_phi(phi) 1abcd
429 return _yule_walker_inv(phi) 1abcd
431 # TODO fails (nan) for very small roots. In that case the answer is that
432 # gamma is a constant vector. However I can't get the constant out of
433 # a degenerate phi, I need the roots, and I don't know the formula.

    @classmethod
    def extend_gamma(cls, gamma, phi, n):
        """
        Extend the covariance function to higher lags.

        Parameters
        ----------
        gamma : (m,) array
            The autocovariance at lag q-m+1...q, with q >= 0 and m >= p + 1.
        phi : (p,) array
            The autoregressive coefficients at lag 1...p.
        n : int
            The number of new values to generate.

        Returns
        -------
        ext : (m + n,) array
            The autocovariance at lag q-m+1...q+n.
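
        Examples
        --------
        An illustrative sketch (with zero noise, each new AR(1) value is phi_1
        times the previous one):

        >>> gamma = jnp.array([4/3, 2/3])                      # AR(1), phi_1 = 0.5
        >>> ext = AR.extend_gamma(gamma, jnp.array([0.5]), 2)  # approx [4/3, 2/3, 1/3, 1/6]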
453 """
454 gamma = cls._process_gamma(gamma) 1abcd
455 phi = cls._process_phi(phi) 1abcd
456 assert gamma.size > phi.size 1abcd
457 ext = _ar_evolve(phi, gamma[len(gamma) - len(phi):], jnp.broadcast_to(0., (n,))) 1abcd
458 return jnp.concatenate([gamma, ext]) 1abcd

    @classmethod
    def phi_from_roots(cls, slnr, lnc):
        """
        Determine the autoregressive coefficients from the roots of the
        characteristic polynomial.

        Parameters
        ----------
        slnr : (nr,) real
            The real roots of the characteristic polynomial, expressed in the
            following way: ``sign(slnr)`` is the sign of the root, and
            ``abs(slnr)`` is the natural logarithm of the absolute value.
        lnc : (nc,) complex
            The natural logarithm of the complex roots of the characteristic
            polynomial (:math:`\\log z = \\log|z| + i\\arg z`), where each root
            also stands for its paired conjugate.

        Returns
        -------
        phi : (p,) real
            The autoregressive coefficients at lag 1...p, with p = nr + 2 nc.
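
        Examples
        --------
        An illustrative sketch (a single real root x = 2, encoded as
        slnr = [log 2], gives P(x) = 1 - x/2 and hence phi = [0.5]):

        >>> slnr = jnp.log(2.) * jnp.ones(1)
        >>> phi = AR.phi_from_roots(slnr, jnp.empty(0))  # approx [0.5]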
481 """
482 slnr, lnc = cls._process_roots(slnr, lnc) 1abcd
483 r = jnp.copysign(jnp.exp(-jnp.abs(slnr)), slnr) # works with +/-0 1abcd
484 c = jnp.exp(-lnc) 1abcd
486 # minus sign in the exponentials to do 1/z, the poly output is already
487 # reversed
489 roots = jnp.concatenate([r, c, c.conj()]).sort() # <-- polyroots sorts 1abcd
490 coef = jnp.atleast_1d(jnp.poly(roots)) 1abcd
492 # TODO the implementation of jnp.poly (and np.poly) is inferior to the
493 # one of np.polynomial.polynomial.polyfromroots, which cares about
494 # numerical accuracy and would reduce compilation time if ported to jax
495 # (current one is O(p), that would be O(log p)).
497 if coef.size: 497 ↛ 501line 497 didn't jump to line 501 because the condition on line 497 was always true1abcd
498 with _jaxext.skipifabstract(): 1abcd
499 numpy.testing.assert_equal(coef[0].item(), 1) 1abcd
500 numpy.testing.assert_allclose(jnp.imag(coef), 0, rtol=0, atol=1e-4) 1abcd
501 return -coef.real[1:] 1abcd
503 # TODO possibly not accurate for large p. Do a test with an
504 # implementation of poly which uses integer roots and non-fft convolve
505 # (maybe add it as an option to my to-be-written implementation of poly)

    @classmethod
    def ampl_from_roots(cls, slnr, lnc, gamma):
        # TODO docs
        slnr, lnc = cls._process_roots(slnr, lnc)
        gamma = cls._process_gamma(gamma)
        assert gamma.size == 1 + slnr.size + 2 * lnc.size
        lag = jnp.arange(gamma.size)
        mat = _gamma_from_ampl_matmul(slnr, lnc, lag, jnp.eye(gamma.size))
        # return jnp.linalg.solve(mat, gamma)
        return _pseudo_solve(mat, gamma)

        # TODO I'm using pseudo-solve only because of large roots degeneracy
        # in _gamma_from_ampl_matmul, remove it after solving that

        # TODO maybe I can increase the precision of the solve with some
        # ordering of the columns of mat, I guess (reversed) global sort of
        # the roots

    @classmethod
    def cov_from_ampl(cls, slnr, lnc, ampl, lag):
        # TODO docs
        slnr, lnc = cls._process_roots(slnr, lnc)
        ampl = cls._process_ampl(ampl)
        assert ampl.size == 1 + slnr.size + 2 * lnc.size
        lag = cls._process_lag(lag)
        scalar = lag.ndim == 0
        if scalar:
            lag = lag[None]
        acf = _gamma_from_ampl_matmul(slnr, lnc, lag, ampl)
        return acf.squeeze(0) if scalar else acf

    @classmethod
    def inverse_roots_from_phi(cls, phi):
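        """
        Compute the reciprocals of the roots of the characteristic polynomial.

        Parameters
        ----------
        phi : (p,) real
            The autoregressive coefficients at lag 1...p.

        Returns
        -------
        iroots : (p,) complex
            The inverses of the roots of the characteristic polynomial
            :math:`P(x) = 1 - \\sum_{k=1}^p \\phi_k x^k`.
        """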
        phi = cls._process_phi(phi)
        poly = jnp.concatenate([jnp.ones(1), -phi])
        return jnp.roots(poly, strip_zeros=False)

    # TODO methods:
    # - gamma_from_roots which uses quadrature fourier transf of spectrum

    @staticmethod
    def _process_roots(slnr, lnc):
        slnr = jnp.asarray(slnr, float).sort()
        lnc = jnp.asarray(lnc, complex)
        assert slnr.ndim == lnc.ndim == 1
        imag = jnp.abs(lnc.imag) % (2 * jnp.pi)
        imag = jnp.where(imag > jnp.pi, 2 * jnp.pi - imag, imag)
        lnc = lnc.real + 1j * imag
        lnc = lnc.sort()
        return slnr, lnc

    @staticmethod
    def _process_gamma(gamma):
        gamma = jnp.asarray(gamma, float)
        assert gamma.ndim == 1 and gamma.size >= 1
        return gamma

    @staticmethod
    def _process_phi(phi):
        phi = jnp.asarray(phi, float)
        assert phi.ndim == 1
        return phi

    @staticmethod
    def _process_ampl(ampl):
        ampl = jnp.asarray(ampl, float)
        assert ampl.ndim == 1 and ampl.size >= 1
        return ampl

    @staticmethod
    def _process_lag(lag):
        lag = jnp.asarray(lag)
        assert jnp.issubdtype(lag, jnp.integer)
        return lag.astype(int)