Coverage for src/lsqfitgp/_Kernel/_decorators.py: 100%
47 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_Kernel/_decorators.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import types 1efabcd
21import warnings 1efabcd
22import inspect 1efabcd
24from . import _crosskernel 1efabcd
25from . import _kernel 1efabcd
26from . import _stationary 1efabcd
27from . import _isotropic 1efabcd
def makekernelsubclass(core, bases, **prekw):
    """
    Create a new kernel class whose instances wrap the function `core`.

    Parameters
    ----------
    core : callable
        The kernel function. It is passed as the first positional argument
        to the constructor of the base classes. If it has a ``pyfunc``
        attribute (as `numpy.vectorize` objects do), that attribute is used
        as the source of the class name, module and docstring.
    bases : iterable of types
        The bases of the new class. The last one is the class the object
        returned by the base constructor is checked against to decide
        whether to enforce the new class on it.
    **prekw :
        Default keyword arguments for the base constructor. Keyword
        arguments given at instantiation override them, with a warning.

    Returns
    -------
    newclass : type
        The new class, a subclass of `CrossKernel`.
    """
    named_object = getattr(core, 'pyfunc', core) # np.vectorize objects
    name = getattr(named_object, '__name__', 'DecoratedKernel')
    bases = tuple(bases)

    def exec_body(ns):

        def __new__(cls, **kw):
            shared_keys = set(prekw).intersection(kw)
            if shared_keys:
                # stacklevel=2 points the warning at the user's instantiation
                # line instead of this library-internal one
                warnings.warn(f'overriding init argument(s) '
                    f'{shared_keys} of kernel {name}', stacklevel=2)
            kwargs = {**prekw, **kw}
            self = super(newclass, cls).__new__(cls, core, **kwargs)
            # The base constructor may return an object of another class;
            # enforce the new class only if the object is still an instance
            # of the targeted base and all instantiation arguments were passed
            # down to the decorated function.
            if isinstance(self, bases[-1]) and set(kw).issubset(self.initkw):
                self = self._clone(cls)
            return self

        ns['__new__'] = __new__
        ns['__wrapped__'] = named_object
        ns['__doc__'] = getattr(named_object, '__doc__', None)
        # Without an explicit __module__, type.__new__ takes it from the
        # globals of the frame that calls it (inside `types.new_class`), so
        # the class would not appear to belong to the decorated function's
        # module.
        module = getattr(named_object, '__module__', None)
        if module is not None:
            ns['__module__'] = module

    newclass = types.new_class(name, bases, exec_body=exec_body)
    assert issubclass(newclass, _crosskernel.CrossKernel)
    return newclass
def crosskernel(*args, bases=None, **kw):
    """

    Decorator to convert a function to a subclass of `CrossKernel`.

    Parameters
    ----------
    *args :
        Either a function to decorate, or no arguments. The function is used
        as the `core` argument to `CrossKernel`.
    bases : tuple of types, optional
        The bases of the new class. If not specified, use `CrossKernel`.
    **kw :
        Additional arguments are passed to `CrossKernel`.

    Returns
    -------
    class_or_dec : callable or type
        If `args` is empty, a decorator ready to be applied, else the kernel
        class.

    Raises
    ------
    ValueError :
        If more than one positional argument is passed.

    Examples
    --------

    >>> @lgp.crosskernel(derivable=True)
    ... def MyKernel(x, y, a=0, b=0):
    ...     return (x - a) * (y - b)

    Notes
    -----
    Arguments passed to the class constructor may modify the class. If the
    object returned by the constructor is an instance of the superclass
    targeted by the decorator, and all the arguments passed at instantiation
    are passed down to the decorated function, the class of the object is
    enforced to be the new class.

    """
    if bases is None:
        bases = _crosskernel.CrossKernel,

    def functional(core):
        return makekernelsubclass(core, bases, **kw)

    if len(args) == 0:
        return functional
    elif len(args) == 1:
        return functional(*args)
    else:
        raise ValueError(f'expected at most one positional argument (the '
            f'function to decorate), got {len(args)}')
def kernel(*args, **kw):
    """

    Decorator like `crosskernel`, but the generated class derives from
    `Kernel` instead of `CrossKernel`.

    Examples
    --------

    >>> @lgp.kernel(loc=10) # the default loc will be 10
    ... def MyKernel(x, y, cippa=1, lippa=42):
    ...     return cippa * (x * y) ** lippa

    """
    target_bases = (_kernel.Kernel,)
    return crosskernel(*args, bases=target_bases, **kw)
def crossstationarykernel(*args, **kw):
    """

    Decorator like `crosskernel`, but the generated class derives from
    `CrossStationaryKernel` instead of `CrossKernel`.

    """
    target_bases = (_stationary.CrossStationaryKernel,)
    return crosskernel(*args, bases=target_bases, **kw)
def stationarykernel(*args, **kw):
    """

    Decorator like `crosskernel`, but the generated class derives from
    `StationaryKernel` instead of `CrossKernel`.

    Examples
    --------

    >>> @lgp.stationarykernel(input='posabs')
    ... def MyKernel(absdelta, cippa=1, lippa=42):
    ...     return cippa * sum(
    ...         jnp.exp(-absdelta[name] / lippa)
    ...         for name in absdelta.dtype.names
    ...     )

    """
    target_bases = (_stationary.StationaryKernel,)
    return crosskernel(*args, bases=target_bases, **kw)
def crossisotropickernel(*args, **kw):
    """

    Decorator like `crosskernel`, but the generated class derives from
    `CrossIsotropicKernel` instead of `CrossKernel`.

    """
    target_bases = (_isotropic.CrossIsotropicKernel,)
    return crosskernel(*args, bases=target_bases, **kw)
def isotropickernel(*args, **kw):
    """

    Decorator like `crosskernel`, but the generated class derives from
    `IsotropicKernel` instead of `CrossKernel`.

    Examples
    --------

    >>> @lgp.isotropickernel(derivable=True)
    ... def MyKernel(distsquared, cippa=1, lippa=42):
    ...     return cippa * jnp.exp(-distsquared) + lippa

    """
    target_bases = (_isotropic.IsotropicKernel,)
    return crosskernel(*args, bases=target_bases, **kw)