Coverage for src/lsqfitgp/_GP/_processes.py: 100%

162 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_GP/_processes.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import abc 1feabcd

21import numbers 1feabcd

22 

23import numpy 1feabcd

24from jax import numpy as jnp 1feabcd

25 

26from .. import _Kernel 1feabcd

27from .. import _Deriv 1feabcd

28 

29from . import _base 1feabcd

30 

class GPProcesses(_base.GPBase):
    """
    GP layer that manages the definition of processes and the computation
    of the cross-covariance kernels between them.

    Processes live in ``self._procs`` as `_Proc` instances keyed by the
    user-provided process key. Kernels computed by `_crosskernel` are
    memoized in ``self._kernels`` under ``(proc key, proc key)``.
    """

    def __init__(self, *, covfun):
        self._procs = {} # proc key -> _Proc
        self._kernels = {} # (proc key, proc key) -> CrossKernel (cache)
        if covfun is not None:
            if not isinstance(covfun, _Kernel.Kernel):
                raise TypeError('covariance function must be of class Kernel')
            self._procs[self.DefaultProcess] = self._ProcKernel(covfun)

    def _clone(self):
        newself = super()._clone()
        # Shallow copies: the _Proc objects are shared between clones, so
        # they must not be mutated after construction.
        newself._procs = self._procs.copy()
        newself._kernels = self._kernels.copy()
        return newself

    class _Proc(abc.ABC):
        """
        Abstract base class for an object holding information about a process
        in a GP object.
        """

        @abc.abstractmethod
        def __init__(self): # pragma: no cover
            pass

    class _ProcKernel(_Proc):
        """An independent process defined with a kernel"""

        def __init__(self, kernel, deriv=0):
            assert isinstance(kernel, _Kernel.Kernel)
            self.kernel = kernel
            self.deriv = deriv

    class _ProcTransf(_Proc):
        """A process defined as a linear transformation of other processes"""

        def __init__(self, ops, deriv):
            """ops = dict proc key -> callable or scalar"""
            self.ops = ops
            self.deriv = deriv

    class _ProcLinTransf(_Proc):
        """A process defined by an arbitrary linear map applied to other
        processes (see `deflintransf`)"""

        def __init__(self, transf, keys, deriv):
            """transf = callable on process functions, keys = proc keys
            passed to transf, deriv = Deriv applied to the result"""
            self.transf = transf
            self.keys = keys
            self.deriv = deriv

    class _ProcKernelTransf(_Proc):
        """A process defined by an operation on the kernel of another process"""

        def __init__(self, proc, transfname, arg):
            """proc = proc key, transfname = Kernel transfname, arg = argument to transf """
            self.proc = proc
            self.transfname = transfname
            self.arg = arg

    _zerokernel = _Kernel.Zero()

    @_base.newself
    def defproc(self, key, kernel=None, *, deriv=0):
        """

        Add an independent process.

        Parameters
        ----------
        key : hashable
            The name that identifies the process in the GP object.
        kernel : Kernel
            A kernel for the process. If None, use the default kernel. The
            difference between the default process and a process defined with
            the default kernel is that, although they have the same kernel,
            they are independent.
        deriv : Deriv-like
            Derivatives to take on the process defined by the kernel.

        Raises
        ------
        KeyError
            If `key` is already used, or if `kernel` is None and the GP
            object has no default process.

        """

        if key in self._procs:
            raise KeyError(f'process key {key!r} already used in GP')

        if kernel is None:
            default = self._procs.get(self.DefaultProcess)
            if default is None:
                # same exception type as a plain dict lookup, clearer message
                raise KeyError('kernel not specified and the GP object has '
                    'no default process')
            kernel = default.kernel

        deriv = _Deriv.Deriv(deriv)

        self._procs[key] = self._ProcKernel(kernel, deriv)

    @_base.newself
    def deftransf(self, key, ops, *, deriv=0):
        """

        Define a new process as a linear combination of other processes.

        Let f_i(x), i = 1, 2, ... be already defined processes, and g_i(x) be
        deterministic functions. The new process is defined as

            h(x) = g_1(x) f_1(x) + g_2(x) f_2(x) + ...

        Parameters
        ----------
        key : hashable
            The name that identifies the new process in the GP object.
        ops : dict
            A dictionary mapping process keys to scalars or scalar
            functions. The functions must take an argument of the same kind
            as the domain of the process. The mapping is copied, so later
            changes to it do not affect the GP object.
        deriv : Deriv-like
            The linear combination is derived as specified by this
            parameter.

        Raises
        ------
        KeyError
            If a key in `ops` is not a process, or `key` is already used.
        TypeError
            If a value in `ops` is neither a numerical scalar nor callable.

        """

        # Take a private copy: storing the caller's dict would let later
        # mutations silently redefine this process (also in cloned GPs).
        ops = dict(ops.items())

        for k, func in ops.items():
            if k not in self._procs:
                raise KeyError(f'process key {k!r} not in GP object')
            if not _Kernel.is_numerical_scalar(func) and not callable(func):
                raise TypeError(f'object of type {type(func)!r} for key {k!r} is neither scalar nor callable')

        if key in self._procs:
            raise KeyError(f'process key {key!r} already used in GP')

        deriv = _Deriv.Deriv(deriv)

        self._procs[key] = self._ProcTransf(ops, deriv)

        # deftransf could be implemented in terms of deflintransf by wrapping
        # each op in a function and summing op(x) * proc(x), but deftransf
        # has linear kernel building cost, so it is kept as a separate path.

    @_base.newself
    def deflintransf(self, key, transf, procs, *, deriv=0, checklin=False):
        """

        Define a new process as a linear transformation of other processes.

        Let f_i(x), i = 1, 2, ... be already defined processes, and T
        a linear map from processes to a single process. The new process is

            h(x) = T(f_1, f_2, ...)(x).

        Parameters
        ----------
        key : hashable
            The name that identifies the new process in the GP object.
        transf : callable
            A function with signature ``transf(callable, callable, ...) -> callable``.
        procs : iterable
            The keys of the processes to be passed to the transformation.
            It is copied into a tuple, so iterators are accepted.
        deriv : Deriv-like
            The linear combination is derived as specified by this
            parameter.
        checklin : bool or None
            If True, check if the transformation is linear. If None, use the
            GP object's ``_checklin`` setting. Default False.

        Notes
        -----
        The linearity check may fail if the transformation does nontrivial
        operations with the inner function input.

        If `procs` is empty, the new process is identically zero.

        """

        # TODO support procs being a single key

        # Materialize once: the validation loop below would otherwise
        # exhaust a one-shot iterator before len() is taken, and storing
        # the caller's sequence would alias it.
        procs = tuple(procs)

        if key in self._procs:
            raise KeyError(f'process key {key!r} already used in GP')

        for k in procs:
            if k not in self._procs:
                raise KeyError(k)

        deriv = _Deriv.Deriv(deriv)

        if not procs:
            self._procs[key] = self._ProcKernel(self._zerokernel)
            return

        if checklin is None:
            checklin = self._checklin
        if checklin:
            mockup_function = lambda a: lambda _: a
            # TODO this array mockup fails with jax functions
            class Mockup(numpy.ndarray):
                __getitem__ = lambda *_: Mockup((0,))
                __getattr__ = __getitem__
            def checktransf(*arrays):
                functions = [mockup_function(a) for a in arrays]
                return transf(*functions)(Mockup((0,)))
            shapes = [(11,)] * len(procs)
            self._checklinear(checktransf, shapes, elementwise=True)

        self._procs[key] = self._ProcLinTransf(transf, procs, deriv)

    @_base.newself
    def deflinop(self, key, transfname, arg, proc):
        """

        Define a new process as the transformation of an existing one.

        Parameters
        ----------
        key : hashable
            Key for the new process.
        transfname : hashable
            A transformation recognized by the `~CrossKernel.transf` method
            of the kernel.
        arg :
            A valid argument to the transformation.
        proc : hashable
            Key of the process to be transformed.

        Raises
        ------
        KeyError
            If `key` is already used or `proc` is not a process.

        """

        if key in self._procs:
            raise KeyError(f'process key {key!r} already used in GP')
        if proc not in self._procs:
            raise KeyError(f'process {proc!r} not found')
        self._procs[key] = self._ProcKernelTransf(proc, transfname, arg)

    def defderiv(self, key, deriv, proc):
        """

        Define a new process as the derivative of an existing one.

        .. math::
            g(x) = \\frac{\\partial^n}{\\partial x^n} f(x)

        Parameters
        ----------
        key : hashable
            The key of the new process.
        deriv : Deriv-like
            Derivation order.
        proc : hashable
            The key of the process to be derived.

        Returns
        -------
        gp : GP
            A new GP object with the applied modifications.

        """
        deriv = _Deriv.Deriv(deriv)
        return self.deflinop(key, 'diff', deriv, proc)

    def defxtransf(self, key, transf, proc):
        """

        Define a new process by transforming the inputs of another one.

        .. math::
            g(x) = f(T(x))

        Parameters
        ----------
        key : hashable
            The key of the new process.
        transf : callable
            A function mapping the new kind input to the input expected by the
            transformed process.
        proc : hashable
            The key of the process to be transformed.

        Returns
        -------
        gp : GP
            A new GP object with the applied modifications.

        """
        assert callable(transf)
        return self.deflinop(key, 'xtransf', transf, proc)

    def defrescale(self, key, scalefun, proc):
        """

        Define a new process as a rescaling of an existing one.

        .. math::
            g(x) = s(x)f(x)

        Parameters
        ----------
        key : hashable
            The key of the new process.
        scalefun : callable
            A function from the domain of the process to a scalar.
        proc : hashable
            The key of the process to be transformed.

        Returns
        -------
        gp : GP
            A new GP object with the applied modifications.

        """
        assert callable(scalefun)
        return self.deflinop(key, 'rescale', scalefun, proc)

    def _crosskernel(self, xpkey, ypkey):
        """
        Return the cross kernel between processes `xpkey` and `ypkey`,
        computing and caching it (in both directions) if needed.
        """

        # Check if the kernel is in cache.
        cache = self._kernels.get((xpkey, ypkey))
        if cache is not None:
            return cache

        # Compute the kernel, dispatching on the process types. Mixed cases
        # are reduced to an "X vs any" helper, swapping arguments if needed.
        xp = self._procs[xpkey]
        yp = self._procs[ypkey]

        if isinstance(xp, self._ProcKernel) and isinstance(yp, self._ProcKernel):
            kernel = self._crosskernel_kernels(xpkey, ypkey)
        elif isinstance(xp, self._ProcTransf):
            kernel = self._crosskernel_transf_any(xpkey, ypkey)
        elif isinstance(yp, self._ProcTransf):
            kernel = self._crosskernel_transf_any(ypkey, xpkey)._swap()
        elif isinstance(xp, self._ProcLinTransf):
            kernel = self._crosskernel_lintransf_any(xpkey, ypkey)
        elif isinstance(yp, self._ProcLinTransf):
            kernel = self._crosskernel_lintransf_any(ypkey, xpkey)._swap()
        elif isinstance(xp, self._ProcKernelTransf):
            kernel = self._crosskernel_kerneltransf_any(xpkey, ypkey)
        elif isinstance(yp, self._ProcKernelTransf):
            kernel = self._crosskernel_kerneltransf_any(ypkey, xpkey)._swap()
        else: # pragma: no cover
            raise TypeError(f'unrecognized process types {type(xp)!r} and {type(yp)!r}')

        # Save cache. On the diagonal (xpkey == ypkey) the reverse entry is
        # the same dict slot, so storing a swapped copy there would make
        # later calls return a different object than this one. The zero
        # kernel is stored as-is in both directions to preserve the
        # `is self._zerokernel` identity that callers use as a shortcut.
        self._kernels[xpkey, ypkey] = kernel
        if (ypkey, xpkey) not in self._kernels:
            if kernel is self._zerokernel:
                self._kernels[ypkey, xpkey] = kernel
            else:
                self._kernels[ypkey, xpkey] = kernel._swap()

        return kernel

    def _crosskernel_kernels(self, xpkey, ypkey):
        """Cross kernel between two kernel-defined processes: nonzero only
        if they are the same process object."""
        xp = self._procs[xpkey]
        yp = self._procs[ypkey]

        if xp is yp:
            return xp.kernel.linop('diff', xp.deriv, xp.deriv)
        else:
            return self._zerokernel

    def _crosskernel_transf_any(self, xpkey, ypkey):
        """Cross kernel between a `_ProcTransf` process and any process:
        sum of the rescaled cross kernels of its components, then derived."""
        xp = self._procs[xpkey]

        kernelsum = self._zerokernel

        for pkey, factor in xp.ops.items():
            kernel = self._crosskernel(pkey, ypkey)
            if kernel is self._zerokernel:
                continue

            # scalar factors are turned into constant functions
            if not callable(factor):
                factor = (lambda f: lambda _: f)(factor)
            kernel = kernel.linop('rescale', factor, None)

            if kernelsum is self._zerokernel:
                kernelsum = kernel
            else:
                kernelsum += kernel

        return kernelsum.linop('diff', xp.deriv, 0)

    def _crosskernel_lintransf_any(self, xpkey, ypkey):
        """Cross kernel between a `_ProcLinTransf` process and any process:
        the user transformation is applied to the left argument of the cross
        kernels of its input processes, then derived."""
        xp = self._procs[xpkey]

        kernels = [self._crosskernel(pk, ypkey) for pk in xp.keys]
        kernel = _Kernel.CrossKernel._nary(xp.transf, kernels, _Kernel.CrossKernel._side.LEFT)
        kernel = kernel.linop('diff', xp.deriv, 0)

        return kernel

    def _crosskernel_kerneltransf_any(self, xpkey, ypkey):
        """Cross kernel between a `_ProcKernelTransf` process and any
        process: the kernel linop is applied to the base cross kernel."""
        xp = self._procs[xpkey]
        yp = self._procs[ypkey]

        if xp is yp:
            basekernel = self._crosskernel(xp.proc, xp.proc)
            # I could avoid handling this case separately but it allows to
            # skip defining two-step transformations A -> CrossAT -> T
        else:
            basekernel = self._crosskernel(xp.proc, ypkey)

        if basekernel is self._zerokernel:
            return self._zerokernel
        elif xp is yp:
            # transform both arguments at once
            return basekernel.linop(xp.transfname, xp.arg)
        else:
            # transform only the left argument
            return basekernel.linop(xp.transfname, xp.arg, None)