Coverage for src/lsqfitgp/_Kernel/_alg.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_alg.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" register algops on CrossKernel and AffineSpan """ 

21 

22import functools 1efabcd

23 

24from jax import numpy as jnp 1efabcd

25from jax.scipy import special as jspecial 1efabcd

26 

27from .. import _special 1efabcd

28 

29from . import _util 1efabcd

30from ._crosskernel import CrossKernel, AffineSpan 1efabcd

31 

@CrossKernel.register_algop
def add(tcls, self, other):
    r"""

    Sum of kernels.

    .. math::
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) + \mathrm{other}(x, y), \\
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) + \mathrm{other}.

    Parameters
    ----------
    other : CrossKernel or scalar
        The kernel to add pointwise, or a constant added to every entry.

    Returns
    -------
    newkernel : CrossKernel
        A clone of `self` (built with ``self._clone``) whose core is the sum.
        If `other` is neither a numerical scalar nor a `CrossKernel`,
        ``NotImplemented`` is returned instead.

    """
    # `tcls` is not used by this implementation.
    lhs = self.core

    if _util.is_numerical_scalar(other):
        # Constant shift: the same scalar is added to every entry.
        return self._clone(core=lambda x, y, **kw: lhs(x, y, **kw) + other)

    if not isinstance(other, CrossKernel):
        return NotImplemented

    # Pointwise sum of the two kernel cores.
    rhs = other.core
    return self._clone(
        core=lambda x, y, **kw: lhs(x, y, **kw) + rhs(x, y, **kw),
    )

57 

@CrossKernel.register_algop
def mul(tcls, self, other):
    r"""

    Product of kernels.

    .. math::
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) \cdot \mathrm{other}(x, y), \\
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) \cdot \mathrm{other}.

    Parameters
    ----------
    other : CrossKernel or scalar
        The kernel to multiply pointwise, or a constant scaling every entry.

    Returns
    -------
    newkernel : CrossKernel
        A clone of `self` (built with ``self._clone``) whose core is the
        product. If `other` is neither a numerical scalar nor a
        `CrossKernel`, ``NotImplemented`` is returned instead.

    """
    # `tcls` is not used by this implementation.
    lhs = self.core

    if _util.is_numerical_scalar(other):
        # Scale every entry by the same constant.
        return self._clone(core=lambda x, y, **kw: lhs(x, y, **kw) * other)

    if not isinstance(other, CrossKernel):
        return NotImplemented

    # Pointwise product of the two kernel cores.
    rhs = other.core
    return self._clone(
        core=lambda x, y, **kw: lhs(x, y, **kw) * rhs(x, y, **kw),
    )

83 

@CrossKernel.register_algop
def pow(tcls, self, *, exponent):
    r"""

    Power of the kernel.

    .. math::
        \mathrm{newkernel}(x, y) = \mathrm{kernel}(x, y)^{\mathrm{exponent}}

    Parameters
    ----------
    exponent : nonnegative integer
        The exponent. If traced by jax, it must have unsigned integer type.

    Returns
    -------
    newkernel : CrossKernel
        A clone of `self` whose core is raised to `exponent`, or
        ``NotImplemented`` if `exponent` is not a nonnegative integer scalar.

    """
    # NOTE(review): this name shadows the builtin `pow` within this module.
    # Presumably `register_algop` keys on the function name, so confirm
    # before renaming.
    if not _util.is_nonnegative_integer_scalar(exponent):
        return NotImplemented

    base = self.core
    return self._clone(core=lambda x, y, **kw: base(x, y, **kw) ** exponent)

105 

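# TODO: for a negative integer exponent, `pow` returns NotImplemented, so
# Python ends up raising TypeError. It should instead stop the method search
# and raise ValueError. The same applies to `rpow` with a base < 1. Fix: first
# check whether the operand is a scalar, then check whether it satisfies the
# bound.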

109 

@CrossKernel.register_algop
def rpow(tcls, self, *, base):
    r"""

    Exponentiation of the kernel.

    .. math::
        \text{newkernel}(x, y) = \text{base}^{\text{kernel}(x, y)}

    Parameters
    ----------
    base : scalar
        A number >= 1. If traced by jax, the value is not checked.

    Returns
    -------
    newkernel : CrossKernel
        A clone of `self` whose core is used as the exponent of `base`, or
        ``NotImplemented`` if `base` does not pass the scalar ``>= 1`` check.

    """
    # The ``>= 1`` condition is only enforced on concrete values; a jax
    # tracer is accepted as is (see the docstring).
    if not _util.is_scalar_cond_trueontracer(base, lambda b: b >= 1):
        return NotImplemented

    exponent_core = self.core
    return self._clone(
        core=lambda x, y, **kw: base ** exponent_core(x, y, **kw),
    )

131 

# Register the elementwise functions that can be applied to a kernel. Each
# entry is the argument tuple passed to `CrossKernel.register_ufuncalgop`:
# either a bare function, or a function plus an explicit name (used for the
# lambdas, whose `__name__` is just '<lambda>'). Registration order is kept.
for _ufunc_args in [
    (jnp.tan,),
    # (lambda x: 1 / jnp.sinc(x), '1/sinc'),
    (lambda x: 1 / jnp.cos(x), '1/cos'),
    (jnp.arcsin,),
    (lambda x: 1 / jnp.arccos(x), '1/arccos'),
    (lambda x: 1 / (1 - x), '1/(1-x)'),
    (jnp.exp,),
    (lambda x: -jnp.log1p(-x), '-log1p(-x)'),
    (jnp.expm1,),
    (_special.expm1x,),
    (jnp.sinh,),
    (jnp.cosh,),
    (jnp.arctanh,),
    (jspecial.i0,),
    (jspecial.i1,),
]:
    CrossKernel.register_ufuncalgop(*_ufunc_args)
del _ufunc_args  # keep the module namespace free of the loop variable

147# @CrossKernel.register_ufuncalgop 

148# def iv(x, *, order): 

149# assert _util.is_nonnegative_scalar_trueontracer(order) 

150# return _special.iv(order, x) 

151 

# TODO: register other unary algops:
# - hypergeom (wrap the scipy implementation in _special)

154 

@functools.partial(AffineSpan.register_algop, transfname='add')
def affine_add(tcls, self, other):
    """
    Addition for `AffineSpan` kernels that keeps the tracked offset current.

    The generic ``'add'`` transformation is applied first with
    ``AffineSpan.super_transf``. If `other` is a numerical scalar, that result
    is cloned again as ``self.__class__`` with ``dynkw['offset']`` increased
    by `other`. Otherwise the generic result is returned as is.
    """
    # NOTE(review): presumably AffineSpan represents ``ampl * kernel + offset``
    # through its `dynkw` entries; confirm against the class definition.
    result = AffineSpan.super_transf('add', self, other)
    if not _util.is_numerical_scalar(other):
        return result

    updated = dict(self.dynkw)
    # Rebind instead of using `+=`, so an array stored in `self.dynkw` is
    # never modified in place.
    updated['offset'] = updated['offset'] + other
    return result._clone(self.__class__, dynkw=updated)

164 

@functools.partial(AffineSpan.register_algop, transfname='mul')
def affine_mul(tcls, self, other):
    """
    Multiplication for `AffineSpan` kernels that keeps the tracked
    coefficients current.

    The generic ``'mul'`` transformation is applied first with
    ``AffineSpan.super_transf``. If `other` is a numerical scalar, that result
    is cloned again as ``self.__class__`` with both ``dynkw['offset']`` and
    ``dynkw['ampl']`` multiplied by `other`. Otherwise the generic result is
    returned as is.
    """
    result = AffineSpan.super_transf('mul', self, other)
    if not _util.is_numerical_scalar(other):
        return result

    updated = dict(self.dynkw)
    # Same order as before (offset, then ampl), with the scalar kept as the
    # left operand. Values are rebound rather than modified in place.
    for key in ('offset', 'ampl'):
        updated[key] = other * updated[key]
    return result._clone(self.__class__, dynkw=updated)