Coverage for src/lsqfitgp/_Kernel/_ops.py: 99% (167 statements)

# lsqfitgp/_Kernel/_ops.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

20""" register linops on CrossKernel and AffineSpan """
import functools
import numbers
import sys

from jax import numpy as jnp
import numpy

from .. import _jaxext
from .. import _Deriv
from .. import _array

from . import _util
from ._crosskernel import CrossKernel, AffineSpan

def rescale_argparser(fun):
    if not callable(fun):
        raise ValueError("argument to 'rescale' must be a function")
    return fun

@functools.partial(CrossKernel.register_corelinop, argparser=rescale_argparser)
def rescale(core, xfun, yfun):
    r"""

    Rescale the output of the function.

    .. math::
        T(f)(x) = \mathrm{fun}(x) f(x)

    Parameters
    ----------
    xfun, yfun : callable or None
        Functions from the type of the kernel arguments to scalars.

    """
    if not xfun:
        return lambda x, y, **kw: core(x, y, **kw) * yfun(y)
    elif not yfun:
        return lambda x, y, **kw: xfun(x) * core(x, y, **kw)
    else:
        return lambda x, y, **kw: xfun(x) * core(x, y, **kw) * yfun(y)

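# Usage sketch of 'rescale' (hedged: assumes the public `CrossKernel.linop`
# method dispatches to the transform registered above, and that `lgp.ExpQuad`
# is available; the concrete names are illustrative, not defined here):
#
#     import lsqfitgp as lgp
#     base = lgp.ExpQuad()
#     # multiply the process by x on the left and y on the right:
#     # k'(x, y) = x * k(x, y) * y
#     rescaled = base.linop('rescale', lambda x: x, lambda y: y)
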
@CrossKernel.register_xtransf
def derivable(derivable):
    """
    Specify the degree of derivability of the function.

    Parameters
    ----------
    xderivable, yderivable : int or None
        Degree of derivability of the function. None means unknown.

    Notes
    -----
    The derivability check is hardcoded into the kernel core, and it cannot be
    removed afterwards by applying ``'derivable'`` again with a higher limit.

    """
    if isinstance(derivable, bool):
        derivable = sys.maxsize if derivable else 0
    elif not isinstance(derivable, numbers.Integral) or derivable < 0:
        raise ValueError(f'derivability degree {derivable!r} not valid')

    def error_func(current, n):
        raise ValueError(f'Took {current} derivatives > limit {n} on argument '
            'of a kernel. This error may be spurious if there are '
            'derivatives on values that define the input to the kernel, for '
            'example if a hyperparameter enters the calculation of x. To '
            'suppress the error, initialize the kernel with derivable=True.')

    def xtransf(x):
        if hasattr(x, 'dtype'):
            # this branch handles limit_derivatives not accepting non-jax types
            # because of being based on jax.custom_jvp; this restriction on
            # custom_jvp appeared in jax 0.4.17
            if x.dtype.names is not None:
                # structured arrays are not compatible with jax but common in
                # lsqfitgp, so I wrap them
                x = _array.StructuredArray(x)
            elif not _jaxext.is_jax_type(x.dtype):
                # there would be an error anyway if a derivative tried to pass
                # through a non-jax type
                return x
        return _jaxext.limit_derivatives(x, n=derivable, error_func=error_func)

    return xtransf

    # TODO this system does not ignore additional derivatives that are not
    # taken by .transf('diff'). Plan:
    # - understand how to associate a jax transformation to a frame
    # - make a context manager with a global stack
    # - at initialization it takes the frames of the derivatives made by diff
    # - the context is around calling core in diff.newcore
    # - it adds the frames to the stack
    # - derivable asks the context manager class to find the frames in x
    # - raises if their number is above the limit

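# Behavior sketch of the argument parsing above: booleans map to the extreme
# limits, nonnegative integers cap the derivative count per argument, and the
# cap is enforced at evaluation time by _jaxext.limit_derivatives:
#
#     derivable(True)   # limit sys.maxsize, effectively unlimited
#     derivable(False)  # limit 0, no derivatives allowed
#     derivable(2)      # at most two derivatives per argument
#     derivable(-1)     # raises ValueError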
def _asfloat(x):
    return x.astype(_jaxext.float_type(x))

def diff_argparser(deriv):
    deriv = _Deriv.Deriv(deriv)
    return deriv if deriv else None

@functools.partial(CrossKernel.register_corelinop, argparser=diff_argparser)
def diff(core, xderiv, yderiv):
    r"""

    Derive the function.

    .. math::
        T(f)(x) = \frac{\partial^n f}{\partial x^n} (x)

    Parameters
    ----------
    xderiv, yderiv : Deriv_like
        A `Deriv` or something that can be converted to a `Deriv`.

    Raises
    ------
    RuntimeError
        The derivative orders are greater than the `derivable` attribute.

    """

    # reparse derivatives because they could be None
    xderiv = _Deriv.Deriv(xderiv)
    yderiv = _Deriv.Deriv(yderiv)

    # wrapper of kernel with derivable arguments unpacked
    def f(x, y, *args, **kw):
        i = -1
        if not xderiv.implicit:
            for i, dim in enumerate(xderiv):
                x = x.at[dim].set(args[i])
        if not yderiv.implicit:
            for j, dim in enumerate(yderiv):
                y = y.at[dim].set(args[1 + i + j])
        return core(x, y, **kw)

    # last x index in case iteration on x does not run but does on y
    i = -1

    # derive w.r.t. first argument
    if xderiv.implicit:
        for _ in range(xderiv.order):
            f = _jaxext.elementwise_grad(f, 0)
    else:
        for i, dim in enumerate(xderiv):
            for _ in range(xderiv[dim]):
                f = _jaxext.elementwise_grad(f, 2 + i)

    # derive w.r.t. second argument
    if yderiv.implicit:
        for _ in range(yderiv.order):
            f = _jaxext.elementwise_grad(f, 1)
    else:
        for j, dim in enumerate(yderiv):
            for _ in range(yderiv[dim]):
                f = _jaxext.elementwise_grad(f, 2 + 1 + i + j)

    # check derivatives are ok for actual input arrays, wrap structured arrays
    def process_arg(x, deriv, pos):
        if x.dtype.names is not None:
            for dim in deriv:
                if dim not in x.dtype.names:
                    raise ValueError(f'derivative along missing field {dim!r} '
                        f'on {pos} argument')
                if not jnp.issubdtype(x.dtype[dim], jnp.number):
                    raise TypeError(f'derivative along non-numeric field '
                        f'{dim!r} on {pos} argument')
            return _array.StructuredArray(x)
        elif not deriv.implicit:
            raise ValueError('derivative on named fields with non-structured '
                f'array on {pos} argument')
        elif not jnp.issubdtype(x.dtype, jnp.number):
            raise TypeError(f'derivative along non-numeric array on '
                f'{pos} argument')
        return x

    def newcore(x, y, **kw):
        x = process_arg(x, xderiv, 'left')
        y = process_arg(y, yderiv, 'right')

        args = []

        if not xderiv.implicit:
            for dim in xderiv:
                args.append(_asfloat(x[dim]))
        elif xderiv:
            x = _asfloat(x)

        if not yderiv.implicit:
            for dim in yderiv:
                args.append(_asfloat(y[dim]))
        elif yderiv:
            y = _asfloat(y)

        return f(x, y, *args, **kw)

    return newcore

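# Usage sketch of 'diff' (hedged: assumes the public `linop` interface;
# `lgp.ExpQuad` and the field name 'time' are illustrative):
#
#     import lsqfitgp as lgp
#     k = lgp.ExpQuad()
#     dk = k.linop('diff', 1, 0)           # cov(f'(x), f(y))
#     # on structured inputs a Deriv-like spec can name a field, e.g.:
#     dk2 = k.linop('diff', 'time', None)  # derive along x['time']
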
@CrossKernel.register_xtransf
def xtransf(fun):
    r"""

    Transform the inputs of the function.

    .. math::
        T(f)(x) = f(\mathrm{fun}(x))

    Parameters
    ----------
    xfun, yfun : callable or None
        Functions mapping a new kind of input to the kind of input accepted by
        the kernel.

    """
    if not callable(fun):
        raise ValueError("argument to 'xtransf' must be a function")
    return fun

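# Usage sketch (hedged: `base` is any kernel, and a single operand is assumed
# to apply to both arguments of a symmetric kernel):
#
#     k = base.linop('xtransf', jnp.log)   # k'(x, y) = base(log(x), log(y))
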
@CrossKernel.register_xtransf
def dim(dim):
    """
    Restrict the function to a field of a structured input::

        T(f)(x) = f(x[dim])

    If the array is not structured, an exception is raised. If the field named
    `dim` has a nontrivial shape, the array passed to the kernel is still
    structured but has only the field `dim`.

    Parameters
    ----------
    xdim, ydim : None, str, or list of str
        Field names or lists of field names.

    """
    if not isinstance(dim, (str, list)):
        raise TypeError(f'dim must be a (list of) string, found {dim!r}')
    def fun(x):
        if x.dtype.names is None:
            raise ValueError(f'cannot get dim={dim!r} from non-structured input')
        elif x.dtype[dim].shape:
            return x[[dim]]
        else:
            return x[dim]
    return fun

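# Usage sketch of 'dim' (hedged; the structured dtype is illustrative):
#
#     import numpy as np
#     xy = np.zeros(5, dtype=[('x', float), ('t', float)])
#     kx = base.linop('dim', 'x')   # evaluates the kernel on xy['x'] only
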
@CrossKernel.register_xtransf
def maxdim(maxdim):
    """

    Restrict the process to a maximum input dimensionality.

    Parameters
    ----------
    xmaxdim, ymaxdim : None or int
        Maximum dimensionality of the input.

    Notes
    -----
    Once a restriction is applied, the check is hardcoded into the kernel core
    and cannot be removed by applying `maxdim` again with a larger limit.

    """
    if not isinstance(maxdim, numbers.Integral) or maxdim < 0:
        raise ValueError(f'maximum dimensionality {maxdim!r} not valid')

    def fun(x):
        nd = _array._nd(x.dtype)
        with _jaxext.skipifabstract():
            if nd > maxdim:
                raise ValueError(f'kernel applied to input with {nd} '
                    f'fields > maxdim={maxdim}')
        return x

    return fun

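# Behavior sketch: with the check above, a kernel restricted to maxdim=1
# raises when evaluated on the two-field structured `xy` from the previous
# sketch:
#
#     k = base.linop('maxdim', 1)
#     k(xy, xy)   # ValueError: kernel applied to input with 2 fields > maxdim=1
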
@CrossKernel.register_xtransf
def loc(loc):
    r"""
    Translate the process inputs:

    .. math::
        T(f)(x) = f(x - \mathrm{loc})

    Parameters
    ----------
    xloc, yloc : None or number
        Translations.

    """
    with _jaxext.skipifabstract():
        assert -jnp.inf < loc < jnp.inf, loc
    return lambda x: _util.ufunc_recurse_dtype(lambda x: x - loc, x)

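# Usage sketch: translating the inputs shifts the process, per the formula
# above (a single operand is assumed to apply to both arguments):
#
#     k = base.linop('loc', 5.0)   # k'(x, y) = base(x - 5, y - 5)
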
@CrossKernel.register_xtransf
def scale(scale):
    r"""
    Rescale the process inputs:

    .. math::
        T(f)(x) = f(x / \mathrm{scale})

    Parameters
    ----------
    xscale, yscale : None or number
        Rescaling factors.

    """
    with _jaxext.skipifabstract():
        assert 0 < scale < jnp.inf, scale
    return lambda x: _util.ufunc_recurse_dtype(lambda x: x / scale, x)

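# Usage sketch: rescaling the inputs changes the correlation length, per the
# formula above:
#
#     k = base.linop('scale', 2.0)   # k'(x, y) = base(x / 2, y / 2)
#
# so the process varies more slowly over x by a factor of two.
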
def normalize_argparser(do):
    return do if do else None

@functools.partial(CrossKernel.register_corelinop, argparser=normalize_argparser)
def normalize(core, dox, doy):
    r"""
    Rescale the process to unit variance.

    .. math::
        T(f)(x) &= f(x) / \sqrt{\mathrm{Var}[f(x)]} \\
                &= f(x) / \sqrt{\mathrm{kernel}(x, x)}

    Parameters
    ----------
    dox, doy : bool
        Whether to rescale.

    """
    if dox and doy:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(x, x, **kw) * core(y, y, **kw))
    elif dox:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(x, x, **kw))
    else:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(y, y, **kw))

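# Consistency check of the formula above: on the diagonal the transformed
# kernel is k(x, x) / sqrt(k(x, x) * k(x, x)) = 1, i.e. unit prior variance:
#
#     k = base.linop('normalize', True)
#     # k(x, x) == 1 for every x (up to floating point)
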
@CrossKernel.register_corelinop
def cond(core, cond1, cond2, other):
    r"""

    Switch between two independent processes based on a condition.

    .. math::
        T(f, g)(x) = \begin{cases}
            f(x) & \text{if $\mathrm{cond}(x)$,} \\
            g(x) & \text{otherwise.}
        \end{cases}

    Parameters
    ----------
    cond1, cond2 : callable
        Functions that are applied to an array of points and must return a
        boolean array with the same shape.
    other :
        Kernel of the process used where the condition is false.

    """
    def newcore(x, y, **kw):
        xcond = cond1(x)
        ycond = cond2(y)
        r = jnp.where(xcond & ycond, core(x, y, **kw), other(x, y, **kw))
        return jnp.where(xcond ^ ycond, 0, r)

    return newcore

    # TODO add a function `choose` to extend `cond`:
    # kernel0.linop('choose', kernel1, kernel2, ..., lambda x: x['index'])

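# Usage sketch of 'cond' (hedged on the exact linop call signature): stitch
# two independent processes at zero; mixed blocks vanish because the two
# processes are independent:
#
#     k = k1.linop('cond', lambda x: x >= 0, lambda y: y >= 0, k2)
#     # k(x, y) = k1(x, y)  where x >= 0 and y >= 0
#     #         = k2(x, y)  where x <  0 and y <  0
#     #         = 0         where the conditions disagree
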
AffineSpan.inherit_transf('maxdim')
AffineSpan.inherit_transf('derivable')

@functools.partial(AffineSpan.register_linop, transfname='loc')
def affine_loc(tcls, self, xloc, yloc):
    dynkw = dict(self.dynkw)
    newself = tcls.super_transf('loc', self, xloc, yloc)
    dynkw['lloc'] = dynkw['lloc'] + xloc * dynkw['lscale']
    dynkw['rloc'] = dynkw['rloc'] + yloc * dynkw['rscale']
    return newself._clone(self.__class__, dynkw=dynkw)

@functools.partial(AffineSpan.register_linop, transfname='scale')
def affine_scale(tcls, self, xscale, yscale):
    dynkw = dict(self.dynkw)
    newself = tcls.super_transf('scale', self, xscale, yscale)
    dynkw['lscale'] = dynkw['lscale'] * xscale
    dynkw['rscale'] = dynkw['rscale'] * yscale
    return newself._clone(self.__class__, dynkw=dynkw)
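
# Worked example of the dynkw bookkeeping above (values illustrative):
# starting from lloc=0, lscale=1 on the left side, .linop('scale', 2) sets
# lscale = 1 * 2 = 2, and a subsequent .linop('loc', 3) then sets
# lloc = 0 + 3 * 2 = 6; the accumulated affine parameters are thus updated in
# place instead of nesting one closure per applied transformation.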