Coverage for src/lsqfitgp/_Kernel/_decorators.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_decorators.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import types 1efabcd

21import warnings 1efabcd

22import inspect 1efabcd

23 

24from . import _crosskernel 1efabcd

25from . import _kernel 1efabcd

26from . import _stationary 1efabcd

27from . import _isotropic 1efabcd

28 

def makekernelsubclass(core, bases, **prekw):
    """

    Create a new class, named after `core`, that derives from `bases`.

    Parameters
    ----------
    core : callable
        The function wrapped by the new class. It is passed as first
        positional argument to the ``__new__`` of the superclass. If it has
        a ``pyfunc`` attribute (as `numpy.vectorize` objects do), the name
        and the docstring of the class are taken from that attribute.
    bases : iterable of types
        The bases of the new class.
    **prekw :
        Default keyword arguments passed to the ``__new__`` of the
        superclass. Overriding any of them at instantiation emits a warning.

    Returns
    -------
    newclass : type
        The new class. It is asserted to be a subclass of `CrossKernel`.

    """

    # np.vectorize objects keep the decorated function in `pyfunc`
    wrapped = getattr(core, 'pyfunc', core)
    name = getattr(wrapped, '__name__', 'DecoratedKernel')

    bases = tuple(bases)

    def exec_body(ns):

        def __new__(cls, **kw):
            # keys given both at decoration time and at instantiation
            shared_keys = set(prekw).intersection(kw)
            if shared_keys:
                warnings.warn(
                    f'overriding init argument(s) {shared_keys} '
                    f'of kernel {name}')

            # instantiation arguments take precedence over the defaults
            merged = {**prekw, **kw}
            self = super(newclass, cls).__new__(cls, core, **merged)

            # the superclass may have returned an object of another class;
            # switch it to `cls` only if it is an instance of the last base
            # and all the instantiation arguments are found in `initkw`
            if not isinstance(self, bases[-1]):
                return self
            if not set(kw).issubset(self.initkw):
                return self
            return self._clone(cls)

        ns['__new__'] = __new__
        ns['__wrapped__'] = wrapped
        ns['__doc__'] = wrapped.__doc__

    newclass = types.new_class(name, bases, exec_body=exec_body)
    assert issubclass(newclass, _crosskernel.CrossKernel)
    return newclass

57 

def crosskernel(*args, bases=None, **kw):
    """

    Decorator to convert a function to a subclass of `CrossKernel`.

    Parameters
    ----------
    *args :
        Either a function to decorate, or no arguments. The function is used
        as the `core` argument to `CrossKernel`.
    bases : tuple of types, optional
        The bases of the new class. If not specified, use `CrossKernel`.
    **kw :
        Additional arguments are passed to `CrossKernel`.

    Returns
    -------
    class_or_dec : callable or type
        If `args` is empty, a decorator ready to be applied, else the kernel
        class.

    Raises
    ------
    ValueError
        If more than one positional argument is passed.

    Examples
    --------

    >>> @lgp.crosskernel(derivable=True)
    ... def MyKernel(x, y, a=0, b=0):
    ...     return (x - a) * (y - b)

    Notes
    -----
    Arguments passed to the class constructor may modify the class. If the
    object returned by the constructor is an instance of the superclass
    targeted by the decorator, and all the arguments passed at instantiation
    are passed down to the decorated function, the class of the object is
    enforced to be the new class.

    """
    # validate the call signature first, with a message that states what is
    # wrong instead of the bare argument count
    if len(args) > 1:
        raise ValueError(
            f'kernel decorator expects at most one positional argument '
            f'(the function to decorate), got {len(args)}'
        )

    if bases is None:
        bases = (_crosskernel.CrossKernel,)

    def functional(core):
        return makekernelsubclass(core, bases, **kw)

    # bare use (@crosskernel) decorates right away, parametrized use
    # (@crosskernel(...)) returns the decorator
    if args:
        return functional(*args)
    return functional

104 

def kernel(*args, **kw):
    """

    Decorator to convert a function to a subclass of `Kernel`.

    Works like `crosskernel`, with `Kernel` as the only base of the new
    class.

    Examples
    --------

    >>> @lgp.kernel(loc=10) # the default loc will be 10
    ... def MyKernel(x, y, cippa=1, lippa=42):
    ...     return cippa * (x * y) ** lippa

    """
    kernel_bases = (_kernel.Kernel,)
    return crosskernel(*args, bases=kernel_bases, **kw)

119 

def crossstationarykernel(*args, **kw):
    """

    Decorator to convert a function to a subclass of
    `CrossStationaryKernel`.

    Works like `crosskernel`, with `CrossStationaryKernel` as the only base
    of the new class.

    """
    kernel_bases = (_stationary.CrossStationaryKernel,)
    return crosskernel(*args, bases=kernel_bases, **kw)

127 

def stationarykernel(*args, **kw):
    """

    Decorator to convert a function to a subclass of `StationaryKernel`.

    Works like `crosskernel`, with `StationaryKernel` as the only base of
    the new class.

    Examples
    --------

    >>> @lgp.stationarykernel(input='posabs')
    ... def MyKernel(absdelta, cippa=1, lippa=42):
    ...     return cippa * sum(
    ...         jnp.exp(-absdelta[name] / lippa)
    ...         for name in absdelta.dtype.names
    ...     )

    """
    kernel_bases = (_stationary.StationaryKernel,)
    return crosskernel(*args, bases=kernel_bases, **kw)

145 

def crossisotropickernel(*args, **kw):
    """

    Decorator to convert a function to a subclass of `CrossIsotropicKernel`.

    Works like `crosskernel`, with `CrossIsotropicKernel` as the only base
    of the new class.

    """
    kernel_bases = (_isotropic.CrossIsotropicKernel,)
    return crosskernel(*args, bases=kernel_bases, **kw)

153 

def isotropickernel(*args, **kw):
    """

    Decorator to convert a function to a subclass of `IsotropicKernel`.

    Works like `crosskernel`, with `IsotropicKernel` as the only base of the
    new class.

    Examples
    --------

    >>> @lgp.isotropickernel(derivable=True)
    ... def MyKernel(distsquared, cippa=1, lippa=42):
    ...     return cippa * jnp.exp(-distsquared) + lippa

    """
    kernel_bases = (_isotropic.IsotropicKernel,)
    return crosskernel(*args, bases=kernel_bases, **kw)