Coverage for src/lsqfitgp/_Kernel/_alg.py: 100%
71 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_Kernel/_alg.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
"""Register algebraic operations (algops) on CrossKernel and AffineSpan."""

22import functools 1efabcd
24from jax import numpy as jnp 1efabcd
25from jax.scipy import special as jspecial 1efabcd
27from .. import _special 1efabcd
29from . import _util 1efabcd
30from ._crosskernel import CrossKernel, AffineSpan 1efabcd
@CrossKernel.register_algop
def add(tcls, self, other):
    r"""

    Sum of kernels.

    .. math::
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) + \mathrm{other}(x, y), \\
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) + \mathrm{other}.

    Parameters
    ----------
    other : CrossKernel or scalar
        The other kernel.

    """
    # `tcls` is passed in by `CrossKernel.register_algop` and is not used
    # here. The result is built by cloning `self` with a new core function.
    core = self.core
    if _util.is_numerical_scalar(other):
        # Constant offset: the scalar is captured by the closure.
        newcore = lambda x, y, **kw: core(x, y, **kw) + other
    elif isinstance(other, CrossKernel):
        # Capture only the other kernel's core function, under its own name,
        # instead of rebinding the `other` parameter.
        othercore = other.core
        newcore = lambda x, y, **kw: core(x, y, **kw) + othercore(x, y, **kw)
    else:
        # Unsupported operand type: decline, so the method search continues.
        return NotImplemented
    return self._clone(core=newcore)

@CrossKernel.register_algop
def mul(tcls, self, other):
    r"""

    Product of kernels.

    .. math::
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) \cdot \mathrm{other}(x, y), \\
        \mathrm{newkernel}(x, y) &= \mathrm{kernel}(x, y) \cdot \mathrm{other}.

    Parameters
    ----------
    other : CrossKernel or scalar
        The other kernel.

    """
    # `tcls` is passed in by `CrossKernel.register_algop` and is not used
    # here. The result is built by cloning `self` with a new core function.
    core = self.core
    if _util.is_numerical_scalar(other):
        # Constant rescaling: the scalar is captured by the closure.
        newcore = lambda x, y, **kw: core(x, y, **kw) * other
    elif isinstance(other, CrossKernel):
        # Capture only the other kernel's core function, under its own name,
        # instead of rebinding the `other` parameter.
        othercore = other.core
        newcore = lambda x, y, **kw: core(x, y, **kw) * othercore(x, y, **kw)
    else:
        # Unsupported operand type: decline, so the method search continues.
        return NotImplemented
    return self._clone(core=newcore)

@CrossKernel.register_algop
def pow(tcls, self, *, exponent):
    r"""

    Power of the kernel.

    .. math::
        \mathrm{newkernel}(x, y) = \mathrm{kernel}(x, y)^{\mathrm{exponent}}

    Parameters
    ----------
    exponent : nonnegative integer
        The exponent. If traced by jax, it must have unsigned integer type.

    """
    # NOTE(review): presumably only non-negative integer exponents are allowed
    # because an integer power is a repeated product of the kernel with
    # itself, which keeps it a valid kernel — confirm.
    if not _util.is_nonnegative_integer_scalar(exponent):
        # TODO this will raise TypeError on negative integers. It should stop
        # method search and raise ValueError. Same for rpow. Check if it is a
        # scalar, then check if it satisfies the bound.
        return NotImplemented
    core = self.core
    newcore = lambda x, y, **kw: core(x, y, **kw) ** exponent
    return self._clone(core=newcore)

@CrossKernel.register_algop
def rpow(tcls, self, *, base):
    r"""

    Exponentiation of the kernel.

    .. math::
        \text{newkernel}(x, y) = \text{base}^{\text{kernel}(x, y)}

    Parameters
    ----------
    base : scalar
        A number >= 1. If traced by jax, the value is not checked.

    """
    # Decline unless `base` is a scalar with base >= 1. Per the docstring, a
    # jax-traced base is accepted without checking the bound.
    # NOTE(review): presumably the bound holds because
    # base**k = exp(log(base) * k) with log(base) >= 0 has non-negative
    # series coefficients, so the result stays a valid kernel — confirm.
    if not _util.is_scalar_cond_trueontracer(base, lambda x: x >= 1):
        return NotImplemented
    core = self.core
    newcore = lambda x, y, **kw: base ** core(x, y, **kw)
    return self._clone(core=newcore)

# Elementwise unary functions that can be applied to a kernel's values.
# Named functions are registered directly. Lambdas are given an explicit name
# as the second argument, since they have no usable `__name__`.
# NOTE(review): presumably every function here has a power series with
# non-negative coefficients, so applying it to a kernel yields a kernel —
# confirm before extending the list.
CrossKernel.register_ufuncalgop(jnp.tan)
# Disabled: CrossKernel.register_ufuncalgop(lambda x: 1 / jnp.sinc(x), '1/sinc')
CrossKernel.register_ufuncalgop(lambda x: 1 / jnp.cos(x), '1/cos')
CrossKernel.register_ufuncalgop(jnp.arcsin)
CrossKernel.register_ufuncalgop(lambda x: 1 / jnp.arccos(x), '1/arccos')
CrossKernel.register_ufuncalgop(lambda x: 1 / (1 - x), '1/(1-x)')
CrossKernel.register_ufuncalgop(jnp.exp)
CrossKernel.register_ufuncalgop(lambda x: -jnp.log1p(-x), '-log1p(-x)')
CrossKernel.register_ufuncalgop(jnp.expm1)
CrossKernel.register_ufuncalgop(_special.expm1x)
CrossKernel.register_ufuncalgop(jnp.sinh)
CrossKernel.register_ufuncalgop(jnp.cosh)
CrossKernel.register_ufuncalgop(jnp.arctanh)
CrossKernel.register_ufuncalgop(jspecial.i0)
CrossKernel.register_ufuncalgop(jspecial.i1)
# Disabled: modified Bessel function of arbitrary non-negative order.
# @CrossKernel.register_ufuncalgop
# def iv(x, *, order):
#     assert _util.is_nonnegative_scalar_trueontracer(order)
#     return _special.iv(order, x)

# TODO other unary algop:
# - hypergeom (wrap the scipy impl in _special)

@functools.partial(AffineSpan.register_algop, transfname='add')
def affine_add(tcls, self, other):
    # Compute the generic sum first, via the 'add' transformation that
    # AffineSpan.super_transf resolves. Presumably that is the CrossKernel
    # `add` algop — confirm.
    newself = AffineSpan.super_transf('add', self, other)
    if not _util.is_numerical_scalar(other):
        # Non-scalar operand: return the generic result unchanged.
        return newself
    # Adding a scalar: also shift the tracked 'offset'. Copy dynkw first so
    # that `self.dynkw` is not mutated.
    dynkw = dict(self.dynkw)
    dynkw['offset'] = dynkw['offset'] + other
    return newself._clone(self.__class__, dynkw=dynkw)

@functools.partial(AffineSpan.register_algop, transfname='mul')
def affine_mul(tcls, self, other):
    # Compute the generic product first, via the 'mul' transformation that
    # AffineSpan.super_transf resolves. Presumably that is the CrossKernel
    # `mul` algop — confirm.
    newself = AffineSpan.super_transf('mul', self, other)
    if not _util.is_numerical_scalar(other):
        # Non-scalar operand: return the generic result unchanged.
        return newself
    # Multiplying by a scalar: rescale both the tracked 'offset' and 'ampl'.
    # Copy dynkw first so that `self.dynkw` is not mutated.
    dynkw = dict(self.dynkw)
    dynkw['offset'] = other * dynkw['offset']
    dynkw['ampl'] = other * dynkw['ampl']
    return newself._clone(self.__class__, dynkw=dynkw)