Coverage for src/lsqfitgp/_GP/_processes.py: 100%
162 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_GP/_processes.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import abc 1feabcd
21import numbers 1feabcd
23import numpy 1feabcd
24from jax import numpy as jnp 1feabcd
26from .. import _Kernel 1feabcd
27from .. import _Deriv 1feabcd
29from . import _base 1feabcd
31class GPProcesses(_base.GPBase): 1feabcd
33 def __init__(self, *, covfun): 1feabcd
34 self._procs = {} # proc key -> _Proc 1feabcd
35 self._kernels = {} # (proc key, proc key) -> CrossKernel 1feabcd
36 if covfun is not None: 1feabcd
37 if not isinstance(covfun, _Kernel.Kernel): 1feabcd
38 raise TypeError('covariance function must be of class Kernel') 1abcd
39 self._procs[self.DefaultProcess] = self._ProcKernel(covfun) 1feabcd
41 def _clone(self): 1feabcd
42 newself = super()._clone() 1feabcd
43 newself._procs = self._procs.copy() 1feabcd
44 newself._kernels = self._kernels.copy() 1feabcd
45 return newself 1feabcd
47 class _Proc(abc.ABC): 1feabcd
48 """
49 Abstract base class for an object holding information about a process
50 in a GP object.
51 """
53 @abc.abstractmethod 1feabcd
54 def __init__(self): # pragma: no cover 1feabcd
55 pass
57 class _ProcKernel(_Proc): 1feabcd
58 """An independent process defined with a kernel"""
60 def __init__(self, kernel, deriv=0): 1feabcd
61 assert isinstance(kernel, _Kernel.Kernel) 1feabcd
62 self.kernel = kernel 1feabcd
63 self.deriv = deriv 1feabcd
65 class _ProcTransf(_Proc): 1feabcd
66 """A process defined as a linear transformation of other processes"""
68 def __init__(self, ops, deriv): 1feabcd
69 """ops = dict proc key -> callable"""
70 self.ops = ops 1eabcd
71 self.deriv = deriv 1eabcd
73 class _ProcLinTransf(_Proc): 1feabcd
75 def __init__(self, transf, keys, deriv): 1feabcd
76 self.transf = transf 1feabcd
77 self.keys = keys 1feabcd
78 self.deriv = deriv 1feabcd
80 class _ProcKernelTransf(_Proc): 1feabcd
81 """A process defined by an operation on the kernel of another process"""
83 def __init__(self, proc, transfname, arg): 1feabcd
84 """proc = proc key, transfname = Kernel transfname, arg = argument to transf """
85 self.proc = proc 1eabcd
86 self.transfname = transfname 1eabcd
87 self.arg = arg 1eabcd
    # Single shared zero kernel: the cross kernel of independent processes.
    # The cross-kernel methods compare kernels against this exact instance
    # with `is` to skip identically-zero terms, so this object is reused
    # rather than creating new zero kernels.
    _zerokernel = _Kernel.Zero()
91 @_base.newself 1feabcd
92 def defproc(self, key, kernel=None, *, deriv=0): 1feabcd
93 """
95 Add an independent process.
97 Parameters
98 ----------
99 key : hashable
100 The name that identifies the process in the GP object.
101 kernel : Kernel
102 A kernel for the process. If None, use the default kernel. The
103 difference between the default process and a process defined with
104 the default kernel is that, although they have the same kernel,
105 they are independent.
106 deriv : Deriv-like
107 Derivatives to take on the process defined by the kernel.
109 """
111 if key in self._procs: 1feabcd
112 raise KeyError(f'process key {key!r} already used in GP') 1abcd
114 if kernel is None: 1feabcd
115 kernel = self._procs[self.DefaultProcess].kernel 1abcd
117 deriv = _Deriv.Deriv(deriv) 1feabcd
119 self._procs[key] = self._ProcKernel(kernel, deriv) 1feabcd
121 @_base.newself 1feabcd
122 def deftransf(self, key, ops, *, deriv=0): 1feabcd
123 """
125 Define a new process as a linear combination of other processes.
127 Let f_i(x), i = 1, 2, ... be already defined processes, and g_i(x) be
128 deterministic functions. The new process is defined as
130 h(x) = g_1(x) f_1(x) + g_2(x) f_2(x) + ...
132 Parameters
133 ----------
134 key : hashable
135 The name that identifies the new process in the GP object.
136 ops : dict
137 A dictionary mapping process keys to scalars or scalar
138 functions. The functions must take an argument of the same kind
139 of the domain of the process.
140 deriv : Deriv-like
141 The linear combination is derived as specified by this
142 parameter.
144 """
146 for k, func in ops.items(): 1eabcd
147 if k not in self._procs: 1eabcd
148 raise KeyError(f'process key {k!r} not in GP object') 1abcd
149 if not _Kernel.is_numerical_scalar(func) and not callable(func): 1eabcd
150 raise TypeError(f'object of type {type(func)!r} for key {k!r} is neither scalar nor callable') 1abcd
152 if key in self._procs: 1eabcd
153 raise KeyError(f'process key {key!r} already used in GP') 1abcd
155 deriv = _Deriv.Deriv(deriv) 1eabcd
157 self._procs[key] = self._ProcTransf(ops, deriv) 1eabcd
159 # we could implement deftransf in terms of deflintransf with
160 # the following code, but deftransf has linear kernel building
161 # cost so I'm leaving it around (probably not significant anyway)
162 # functions = [
163 # op if callable(op)
164 # else (lambda x: lambda _: x)(op)
165 # for op in ops.values()
166 # ]
167 # def equivalent_lintransf(*procs):
168 # def fun(x):
169 # out = None
170 # for fun, proc in zip(functions, procs):
171 # this = fun(x) * proc(x)
172 # out = this if out is None else out + this
173 # return out
174 # return fun
175 # self.deflintransf(key, equivalent_lintransf, list(ops.keys()), deriv=deriv, checklin=False)
177 @_base.newself 1feabcd
178 def deflintransf(self, key, transf, procs, *, deriv=0, checklin=False): 1feabcd
179 """
181 Define a new process as a linear combination of other processes.
183 Let f_i(x), i = 1, 2, ... be already defined processes, and T
184 a linear map from processes to a single process. The new process is
186 h(x) = T(f_1, f_2, ...)(x).
188 Parameters
189 ----------
190 key : hashable
191 The name that identifies the new process in the GP object.
192 transf : callable
193 A function with signature ``transf(callable, callable, ...) -> callable``.
194 procs : sequence
195 The keys of the processes to be passed to the transformation.
196 deriv : Deriv-like
197 The linear combination is derived as specified by this
198 parameter.
199 checklin : bool
200 If True, check if the transformation is linear. Default False.
202 Notes
203 -----
204 The linearity check may fail if the transformation does nontrivial
205 operations with the inner function input.
207 """
209 # TODO support procs being a single key
211 if key in self._procs: 1feabcd
212 raise KeyError(f'process key {key!r} already used in GP') 1abcd
214 for k in procs: 1feabcd
215 if k not in self._procs: 1feabcd
216 raise KeyError(k) 1abcd
218 deriv = _Deriv.Deriv(deriv) 1feabcd
220 if len(procs) == 0: 1feabcd
221 self._procs[key] = self._ProcKernel(self._zerokernel) 1abcd
222 return 1abcd
224 if checklin is None: 1feabcd
225 checklin = self._checklin 1abcd
226 if checklin: 1feabcd
227 mockup_function = lambda a: lambda _: a 1abcd
228 # TODO this array mockup fails with jax functions
229 class Mockup(numpy.ndarray): 1abcd
230 __getitem__ = lambda *_: Mockup((0,)) 1abcd
231 __getattr__ = __getitem__ 1abcd
232 def checktransf(*arrays): 1abcd
233 functions = [mockup_function(a) for a in arrays] 1abcd
234 return transf(*functions)(Mockup((0,))) 1abcd
235 shapes = [(11,)] * len(procs) 1abcd
236 self._checklinear(checktransf, shapes, elementwise=True) 1abcd
238 self._procs[key] = self._ProcLinTransf(transf, procs, deriv) 1feabcd
240 @_base.newself 1feabcd
241 def deflinop(self, key, transfname, arg, proc): 1feabcd
242 """
244 Define a new process as the transformation of an existing one.
246 Parameters
247 ----------
248 key : hashable
249 Key for the new process.
250 transfname : hashable
251 A transformation recognized by the `~CrossKernel.transf` method
252 of the kernel.
253 arg :
254 A valid argument to the transformation.
255 proc : hashable
256 Key of the process to be transformed.
258 """
260 if key in self._procs: 1eabcd
261 raise KeyError(f'process key {key!r} already used in GP') 1abcd
262 if proc not in self._procs: 1eabcd
263 raise KeyError(f'process {proc!r} not found') 1abcd
264 self._procs[key] = self._ProcKernelTransf(proc, transfname, arg) 1eabcd
266 def defderiv(self, key, deriv, proc): 1feabcd
267 """
269 Define a new process as the derivative of an existing one.
271 .. math::
272 g(x) = \\frac{\\partial^n}{\\partial x^n} f(x)
274 Parameters
275 ----------
276 key : hashable
277 The key of the new process.
278 deriv : Deriv-like
279 Derivation order.
280 proc : hashable
281 The key of the process to be derived.
283 Returns
284 -------
285 gp : GP
286 A new GP object with the applied modifications.
288 """
289 deriv = _Deriv.Deriv(deriv) 1eabcd
290 return self.deflinop(key, 'diff', deriv, proc) 1eabcd
292 def defxtransf(self, key, transf, proc): 1feabcd
293 """
295 Define a new process by transforming the inputs of another one.
297 .. math::
298 g(x) = f(T(x))
300 Parameters
301 ----------
302 key : hashable
303 The key of the new process.
304 transf : callable
305 A function mapping the new kind input to the input expected by the
306 transformed process.
307 proc : hashable
308 The key of the process to be transformed.
310 Returns
311 -------
312 gp : GP
313 A new GP object with the applied modifications.
315 """
316 assert callable(transf) 1abcd
317 return self.deflinop(key, 'xtransf', transf, proc) 1abcd
319 def defrescale(self, key, scalefun, proc): 1feabcd
320 """
322 Define a new process as a rescaling of an existing one.
324 .. math::
325 g(x) = s(x)f(x)
327 Parameters
328 ----------
329 key : hashable
330 The key of the new process.
331 scalefun : callable
332 A function from the domain of the process to a scalar.
333 proc : hashable
334 The key of the process to be transformed.
336 Returns
337 -------
338 gp : GP
339 A new GP object with the applied modifications.
341 """
342 assert callable(scalefun) 1eabcd
343 return self.deflinop(key, 'rescale', scalefun, proc) 1eabcd
345 def _crosskernel(self, xpkey, ypkey): 1feabcd
347 # Check if the kernel is in cache.
348 cache = self._kernels.get((xpkey, ypkey)) 1feabcd
349 if cache is not None: 1feabcd
350 return cache 1feabcd
352 # Compute the kernel.
353 xp = self._procs[xpkey] 1feabcd
354 yp = self._procs[ypkey] 1feabcd
356 if isinstance(xp, self._ProcKernel) and isinstance(yp, self._ProcKernel): 1feabcd
357 kernel = self._crosskernel_kernels(xpkey, ypkey) 1feabcd
358 elif isinstance(xp, self._ProcTransf): 1feabcd
359 kernel = self._crosskernel_transf_any(xpkey, ypkey) 1eabcd
360 elif isinstance(yp, self._ProcTransf): 1feabcd
361 kernel = self._crosskernel_transf_any(ypkey, xpkey)._swap() 1eabcd
362 elif isinstance(xp, self._ProcLinTransf): 1feabcd
363 kernel = self._crosskernel_lintransf_any(xpkey, ypkey) 1feabcd
364 elif isinstance(yp, self._ProcLinTransf): 1feabcd
365 kernel = self._crosskernel_lintransf_any(ypkey, xpkey)._swap() 1feabcd
366 elif isinstance(xp, self._ProcKernelTransf): 1eabcd
367 kernel = self._crosskernel_kerneltransf_any(xpkey, ypkey) 1eabcd
368 elif isinstance(yp, self._ProcKernelTransf): 1eabcd
369 kernel = self._crosskernel_kerneltransf_any(ypkey, xpkey)._swap() 1eabcd
370 else: # pragma: no cover
371 raise TypeError(f'unrecognized process types {type(xp)!r} and {type(yp)!r}') 1abcd
373 # Save cache.
374 self._kernels[xpkey, ypkey] = kernel 1feabcd
375 self._kernels[ypkey, xpkey] = kernel._swap() 1feabcd
377 return kernel 1feabcd
379 def _crosskernel_kernels(self, xpkey, ypkey): 1feabcd
380 xp = self._procs[xpkey] 1feabcd
381 yp = self._procs[ypkey] 1feabcd
383 if xp is yp: 1feabcd
384 return xp.kernel.linop('diff', xp.deriv, xp.deriv) 1feabcd
385 else:
386 return self._zerokernel 1feabcd
388 def _crosskernel_transf_any(self, xpkey, ypkey): 1feabcd
389 xp = self._procs[xpkey] 1eabcd
390 yp = self._procs[ypkey] 1eabcd
392 kernelsum = self._zerokernel 1eabcd
394 for pkey, factor in xp.ops.items(): 1eabcd
395 kernel = self._crosskernel(pkey, ypkey) 1eabcd
396 if kernel is self._zerokernel: 1eabcd
397 continue 1eabcd
399 if not callable(factor): 1eabcd
400 factor = (lambda f: lambda _: f)(factor) 1eabcd
401 kernel = kernel.linop('rescale', factor, None) 1eabcd
403 if kernelsum is self._zerokernel: 1eabcd
404 kernelsum = kernel 1eabcd
405 else:
406 kernelsum += kernel 1eabcd
408 return kernelsum.linop('diff', xp.deriv, 0) 1eabcd
410 def _crosskernel_lintransf_any(self, xpkey, ypkey): 1feabcd
411 xp = self._procs[xpkey] 1feabcd
412 yp = self._procs[ypkey] 1feabcd
414 kernels = [self._crosskernel(pk, ypkey) for pk in xp.keys] 1feabcd
415 kernel = _Kernel.CrossKernel._nary(xp.transf, kernels, _Kernel.CrossKernel._side.LEFT) 1feabcd
416 kernel = kernel.linop('diff', xp.deriv, 0) 1feabcd
418 return kernel 1feabcd
420 def _crosskernel_kerneltransf_any(self, xpkey, ypkey): 1feabcd
421 xp = self._procs[xpkey] 1eabcd
422 yp = self._procs[ypkey] 1eabcd
424 if xp is yp: 1eabcd
425 basekernel = self._crosskernel(xp.proc, xp.proc) 1eabcd
426 # I could avoid handling this case separately but it allows to
427 # skip defining two-step transformations A -> CrossAT -> T
428 else:
429 basekernel = self._crosskernel(xp.proc, ypkey) 1eabcd
431 if basekernel is self._zerokernel: 1eabcd
432 return self._zerokernel 1eabcd
433 elif xp is yp: 1eabcd
434 return basekernel.linop(xp.transfname, xp.arg) 1eabcd
435 else:
436 return basekernel.linop(xp.transfname, xp.arg, None) 1eabcd