Coverage for src/lsqfitgp/_linalg/_seqalg.py: 98%
111 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_linalg/_seqalg.py
2#
3# Copyright (c) 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import abc 1efabcd
22from jax import numpy as jnp 1efabcd
23from jax import lax 1efabcd
24from jax import tree_util 1efabcd
26from . import _pytree 1efabcd
class SequentialOperation(_pytree.AutoPyTree, metaclass=abc.ABCMeta):
    """
    Abstract step of a sequential algorithm run by `sequential_algorithm`.

    The loop semantics mirror jax.lax.fori_loop. `init` is called once for
    step 0, then `iter` is called for each following step, and finally
    `finalize` produces the result of the operation.
    """

    @abc.abstractmethod
    def __init__(self, *args): # pragma: no cover
        """configure the operation"""

    @property
    @abc.abstractmethod
    def inputs(self): # pragma: no cover
        """tuple of positions, in the list of operations, of the other
        operations whose per-step outputs this operation consumes"""

    @abc.abstractmethod
    def init(self, n, *inputs): # pragma: no cover
        """prepare the operation before the loop, given the total number of
        steps `n` and the step-0 outputs of the requested inputs"""

    @abc.abstractmethod
    def iter_out(self, i): # pragma: no cover
        """per-step output handed to operations that list this one in their
        `inputs`; only ever called after `init`"""

    @abc.abstractmethod
    def iter(self, i, *inputs): # pragma: no cover
        """update the state at step `i` given the step-`i` inputs"""

    @abc.abstractmethod
    def finalize(self): # pragma: no cover
        """return the final product of the operation"""
def sequential_algorithm(n, ops):
    """
    Define and execute a sequential algorithm on matrices.

    First, each operation is initialized in list order with
    ``op.init(n, *inputs)``. The inputs are the ``iter_out(0)`` outputs of
    the operations whose positions are listed in ``op.inputs``. Then, for
    ``i`` in ``1, ..., n - 1``, every operation is updated in list order
    with ``op.iter(i, *inputs)`` from the ``iter_out(i)`` outputs of its
    inputs. This loop runs inside `jax.lax.fori_loop`. Finally,
    ``finalize()`` is called on each operation.

    Parameters
    ----------
    n : int
        Number of steps of the algorithm.
    ops : list of SequentialOperation
        Instantiated `SequentialOperation`s that represent the algorithm.
        An operation may only request as inputs operations that precede it
        in the list. Negative positions are resolved as in Python sequence
        indexing before this check.

    Returns
    -------
    results : tuple
        The sequence of final outputs of each operation in `ops`.

    Raises
    ------
    ValueError
        If an operation requests as input itself, a later operation, or a
        position outside of `ops`.
    """
    nops = len(ops)
    for i, op in enumerate(ops):
        inputs = op.inputs
        # Resolve negative positions first. Otherwise e.g. -1, which points
        # to the last operation, would slip past a plain `j >= i` check.
        resolved = [j + nops if j < 0 else j for j in inputs]
        if any(not 0 <= j < i for j in resolved):
            raise ValueError(f'{i}-th operation {op.__class__.__name__} requested inputs {inputs!r} with forward references')
        args = (ops[j].iter_out(0) for j in resolved)
        op.init(n, *args)

    def body_fun(i, ops):
        # operations are updated in list order, so each one sees the
        # step-i outputs of the operations it depends on
        for op in ops:
            args = (ops[j].iter_out(i) for j in op.inputs)
            op.iter(i, *args)
        return ops

    ops = lax.fori_loop(1, n, body_fun, ops) # TODO convert to lax.scan and unroll?
    return tuple(op.finalize() for op in ops)
class Producer(SequentialOperation):
    """operation that yields an output at every step but has no final
    result: `finalize` returns None"""

    def finalize(self):
        return None
class Consumer(SequentialOperation):
    """operation that has a final result but yields no per-step output"""

    # no per-step output: requesting this operation as an input would call
    # NotImplemented, which fails
    iter_out = NotImplemented
class SingleInput(SequentialOperation):
    """operation that consumes the per-step output of exactly one other
    operation, identified by its position"""

    # placeholder that overrides the abstract property; every instance sets
    # its own tuple in __init__
    inputs = NotImplemented

    def __init__(self, input):
        self.inputs = input,
class Stack(Consumer, SingleInput):
    """collects the arrays produced at each step by the input operation into
    a single array whose leading axis runs over the steps"""

    def init(self, n, a0):
        buffer = jnp.zeros((n, *a0.shape), a0.dtype)
        self.out = buffer.at[0, ...].set(a0)

    def iter(self, i, ai):
        updated = self.out.at[i, ...].set(ai)
        self.out = updated

    def finalize(self):
        """the stacked arrays"""
        return self.out
class MatMulIterByFull(Consumer, SingleInput):
    """base for computing products a @ b where the left operand a arrives in
    pieces from the input operation and the right operand b is given whole;
    subclasses accumulate the result in `self.ab`"""

    def __init__(self, input, b):
        """input = an operation producing pieces of left operand (a)
        b = right operand, 1d or 2d; a 1d b is handled as a single column"""
        self.inputs = (input,)
        right = jnp.asarray(b)
        assert right.ndim in (1, 2)
        is_vector = right.ndim < 2
        self.vec = is_vector
        self.b = right[:, None] if is_vector else right

    @abc.abstractmethod
    def init(self, n, a0): # pragma: no cover
        self.ab = ...

    @abc.abstractmethod
    def iter(self, i, ai): # pragma: no cover
        self.ab = ...

    def finalize(self):
        """the accumulated product, with the added column axis dropped again
        when b was 1d"""
        if not self.vec:
            return self.ab
        return jnp.squeeze(self.ab, -1)
class MatMulRowByFull(Producer, MatMulIterByFull):
    """computes a @ b one row at a time, where the input operation yields the
    rows of a; the current row of the product is exposed via `iter_out`"""

    def init(self, n, a0):
        assert a0.ndim == 1
        assert self.b.shape[0] == len(a0)
        self.abi = a0 @ self.b

    def iter_out(self, i):
        row = self.abi
        # undo the column axis added in __init__ for a 1d right operand
        return jnp.squeeze(row, -1) if self.vec else row

    def iter(self, i, ai):
        """compute the i-th row of input @ b"""
        self.abi = ai @ self.b
class SolveTriLowerColByFull(MatMulIterByFull):
    """forward substitution with a lower triangular matrix a whose columns
    are yielded by the input operation, applied to the right operand b"""
    # equivalent loop on a full matrix:
    # x[0] /= a[0, 0]
    # for i in range(1, len(x)):
    #     x[i:] -= x[i - 1] * a[i:, i - 1]
    #     x[i] /= a[i, i]

    def init(self, n, a0):
        rhs = self.b
        # b is copied into self.ab below and no longer needed; presumably
        # removed so it is not carried through the loop state
        del self.b
        assert a0.shape == (n,)
        assert rhs.shape[0] == n
        self.prevai = a0.at[0].set(0)
        self.ab = rhs.at[0, :].divide(a0[0])

    def iter(self, i, ai):
        # Subtract the contribution of unknown i - 1 using the previous
        # column, whose entry i - 1 is zeroed. NOTE(review): every row is
        # updated, so this relies on the columns being zero above the
        # diagonal.
        step = self.ab[i - 1, :] * self.prevai[:, None]
        reduced = self.ab - step
        self.ab = reduced.at[i, :].divide(ai[i])
        self.prevai = ai.at[i].set(0)
class Rows(Producer):
    """yields x[i, ...] at step i for a fixed array x; takes no inputs"""

    inputs = ()

    def __init__(self, x):
        self.x = x

    def init(self, n):
        """nothing to prepare"""

    def iter_out(self, i):
        return self.x[i, ...]

    def iter(self, i):
        """stateless: nothing to update"""
class MatMulColByRow(Consumer):
    """computes a @ b as the sum over steps of the outer products between
    the columns of a and the rows of b, which are yielded by the two input
    operations; a 0d row (1d b) gives a vector result"""

    inputs = None

    def __init__(self, inputa, inputb):
        self.inputs = inputa, inputb

    def _term(self, a_col, b_row):
        """contribution of one column of a and one row of b to the product"""
        if not self.vec:
            return a_col * b_row
        return a_col[:, None] * b_row[None, :]

    def init(self, n, a0, b0):
        assert a0.ndim == 1 and b0.ndim <= 1
        self.vec = b0.ndim > 0
        self.ab = self._term(a0, b0)

    def iter(self, i, ai, bi):
        self.ab = self.ab + self._term(ai, bi)

    def finalize(self):
        return self.ab
class SumLogDiag(Consumer, SingleInput):
    """accumulates the sum of the logarithms of the diagonal of a square
    matrix whose rows/columns are yielded by the input operation"""

    def init(self, n, m0):
        assert m0.shape == (n,)
        self.sld = jnp.log(m0[0])

    def iter(self, i, mi):
        diag_term = jnp.log(mi[i])
        self.sld = self.sld + diag_term

    def finalize(self):
        """sum(log(diag(m)))"""
        return self.sld