Coverage for src/lsqfitgp/_linalg/_seqalg.py: 98%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_linalg/_seqalg.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import abc 1efabcd

21 

22from jax import numpy as jnp 1efabcd

23from jax import lax 1efabcd

24from jax import tree_util 1efabcd

25 

26from . import _pytree 1efabcd

27 

class SequentialOperation(_pytree.AutoPyTree, metaclass=abc.ABCMeta):
    """
    Abstract step of a sequential algorithm run by `sequential_algorithm`.

    The looping semantics mirror `jax.lax.fori_loop`: `init` handles step 0,
    `iter` is invoked for every later step, and `finalize` produces the
    result once the loop is over.
    """

    @abc.abstractmethod
    def __init__(self, *args): # pragma: no cover
        ...

    @property
    @abc.abstractmethod
    def inputs(self): # pragma: no cover
        """Tuple of positions, in the operation list, of the operations
        whose `iter_out` values this operation receives."""
        ...

    @abc.abstractmethod
    def init(self, n, *inputs): # pragma: no cover
        """Set up state before the loop, given the step-0 outputs of the
        requested `inputs`."""
        ...

    @abc.abstractmethod
    def iter_out(self, i): # pragma: no cover
        """Value handed at step `i` to the operations listing this one in
        their `inputs`; only called once `init` has run."""
        ...

    @abc.abstractmethod
    def iter(self, i, *inputs): # pragma: no cover
        """Advance the state for step `i`."""
        ...

    @abc.abstractmethod
    def finalize(self): # pragma: no cover
        """Produce the final result of the operation."""
        ...

61 

def sequential_algorithm(n, ops):
    """
    Define and execute a sequential algorithm on matrices.

    Parameters
    ----------
    n : int
        Number of steps of the algorithm.
    ops : list of SequentialOperation
        Instantiated `SequentialOperation`s that represent the algorithm.
        Each operation may request as inputs (through its `inputs`
        attribute) the per-step outputs of operations that precede it in
        the list. Negative indices are interpreted as in Python indexing,
        but must still resolve to a preceding operation.

    Return
    ------
    results : tuple
        The sequence of final outputs of each operation in `ops`.

    Raises
    ------
    ValueError
        If an operation requests an input that is not a preceding
        operation (a forward or self reference, or an out-of-range
        negative index).
    """
    nops = len(ops)
    for i, op in enumerate(ops):
        inputs = op.inputs
        # Normalize negative indices the way `ops[j]` would resolve them,
        # otherwise e.g. -1 (the last operation, not yet initialized)
        # would pass the forward-reference check below.
        normalized = [j + nops if j < 0 else j for j in inputs]
        if any(j < 0 for j in normalized):
            raise ValueError(f'{i}-th operation {op.__class__.__name__} requested out-of-range inputs {inputs!r}')
        if any(j >= i for j in normalized):
            raise ValueError(f'{i}-th operation {op.__class__.__name__} requested inputs {inputs!r} with forward references')
        args = (ops[j].iter_out(0) for j in inputs)
        op.init(n, *args)

    def body_fun(i, ops):
        # Operations are advanced in list order, so an operation reads the
        # step-i output of the earlier operations it depends on.
        for op in ops:
            args = (ops[j].iter_out(i) for j in op.inputs)
            op.iter(i, *args)
        return ops

    ops = lax.fori_loop(1, n, body_fun, ops) # TODO convert to lax.scan and unroll?
    return tuple(op.finalize() for op in ops)

91 

class Producer(SequentialOperation):
    """Operation useful only for its per-step `iter_out` values; its final
    product is always None."""

    def finalize(self):
        return None

97 

class Consumer(SequentialOperation):
    """Operation that yields only a final product through `finalize`, and
    exposes nothing to other operations while the loop runs."""

    # A plain (non-abstract) class attribute in place of the abstract
    # method, so that subclasses are not required to define `iter_out`.
    iter_out = NotImplemented

102 

class SingleInput(SequentialOperation):
    """Base for operations that read the per-step output of exactly one
    other operation."""

    # Replaces the abstract `inputs` property at class level; each instance
    # stores its own tuple in `__init__`.
    inputs = NotImplemented

    def __init__(self, input):
        # `input` is the position of the source operation in the op list
        self.inputs = input,

109 

class Stack(Consumer, SingleInput):
    """Collect the arrays produced at each step by the input operation into
    one array with a new leading axis of length `n`."""

    def init(self, n, a0):
        buffer = jnp.zeros((n, *a0.shape), a0.dtype)
        self.out = buffer.at[0].set(a0)

    def iter(self, i, ai):
        self.out = self.out.at[i].set(ai)

    def finalize(self):
        """the stacked arrays"""
        return self.out

123 

class MatMulIterByFull(Consumer, SingleInput):
    """
    Base for operations that combine a left operand `a`, received one piece
    per step from the input operation, with a fixed right operand `b`.
    Subclasses maintain the running result in `self.ab`.
    """

    def __init__(self, input, b):
        """input = index of the operation producing pieces of the left
        operand (a); b = right operand, 1D or 2D"""
        self.inputs = (input,)
        b = jnp.asarray(b)
        assert b.ndim in (1, 2)
        # a vector is handled internally as a one-column matrix; the extra
        # axis is removed again when producing outputs
        self.vec = b.ndim < 2
        self.b = b[:, None] if self.vec else b

    @abc.abstractmethod
    def init(self, n, a0): # pragma: no cover
        self.ab = ...

    @abc.abstractmethod
    def iter(self, i, ai): # pragma: no cover
        self.ab = ...

    def finalize(self):
        """the accumulated result, without the column axis if b was 1D"""
        result = self.ab
        return jnp.squeeze(result, -1) if self.vec else result

151 

class MatMulRowByFull(Producer, MatMulIterByFull):
    """At each step, multiply the row of `a` received from the input
    operation by the fixed operand `b`, and expose that row of ``a @ b``
    through `iter_out`."""

    def init(self, n, a0):
        assert a0.ndim == 1
        assert self.b.shape[0] == len(a0)
        self.abi = a0 @ self.b

    def iter_out(self, i):
        # drop the helper column axis added when b was a vector
        if not self.vec:
            return self.abi
        return jnp.squeeze(self.abi, -1)

    def iter(self, i, ai):
        """ i-th row of input @ b """
        self.abi = ai @ self.b

168 

class SolveTriLowerColByFull(MatMulIterByFull):
    """
    Forward substitution for ``a @ x = b`` with `a` lower triangular, where
    the input operation provides column `i` of `a` at step `i`. Equivalent
    to::

        x[0] /= a[0, 0]
        for i in range(1, len(x)):
            x[i:] -= x[i - 1] * a[i:, i - 1]
            x[i] /= a[i, i]

    NOTE(review): the update subtracts along the whole column (only the
    diagonal entry is zeroed), so entries above the diagonal of each
    column must be zero, as in a lower triangular matrix.
    """

    def init(self, n, a0):
        rhs = self.b
        # b is consumed: it becomes the solution accumulator `ab`
        del self.b
        assert a0.shape == (n,)
        assert rhs.shape[0] == n
        # column kept for the next step, with its diagonal entry zeroed so
        # the pivot row is not updated against itself
        self.prevai = a0.at[0].set(0)
        self.ab = rhs.at[0, :].divide(a0[0])

    def iter(self, i, ai):
        # eliminate unknown i - 1 from all the other rows, then normalize
        # row i by the diagonal element
        pivot_row = self.ab[i - 1, :]
        eliminated = self.ab - pivot_row * self.prevai[:, None]
        self.ab = eliminated.at[i, :].divide(ai[i])
        self.prevai = ai.at[i].set(0)

188 

class Rows(Producer):
    """Source operation with no inputs: at step `i` it exposes the `i`-th
    slice of `x` along the first axis."""

    inputs = ()

    def __init__(self, x):
        self.x = x

    def init(self, n):
        return None

    def iter_out(self, i):
        return self.x[i, ...]

    def iter(self, i):
        return None

204 

class MatMulColByRow(Consumer):
    """
    Accumulate, over the steps, the products of a column of `a` (from the
    first input) with the matching row of `b` (from the second input),
    i.e. the sum of outer products that makes up ``a @ b``. If the rows of
    `b` are scalars, the terms are plain scaled columns instead.
    """

    # Replaces the abstract `inputs` property at class level; the actual
    # pair of indices is set per instance in `__init__`.
    inputs = None

    def __init__(self, inputa, inputb):
        self.inputs = (inputa, inputb)

    def init(self, n, a0, b0):
        assert a0.ndim == 1 and b0.ndim <= 1
        self.vec = b0.ndim > 0
        self.ab = a0[:, None] * b0[None, :] if self.vec else a0 * b0

    def iter(self, i, ai, bi):
        term = ai[:, None] * bi[None, :] if self.vec else ai * bi
        self.ab = self.ab + term

    def finalize(self):
        return self.ab

228 

class SumLogDiag(Consumer, SingleInput):
    """Accumulate the sum of the logarithms of the diagonal of a square
    matrix whose rows/columns arrive one per step from the input
    operation."""

    def init(self, n, m0):
        assert m0.shape == (n,)
        self.sld = jnp.log(m0[0])

    def iter(self, i, mi):
        # element i of the i-th row/column is the diagonal entry
        diag_log = jnp.log(mi[i])
        self.sld = self.sld + diag_log

    def finalize(self):
        """sum(log(diag(m)))"""
        return self.sld