Coverage for src/lsqfitgp/_Kernel/_crosskernel.py: 99%
401 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_Kernel/_crosskernel.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import enum 1feabcd
21import functools 1feabcd
22import sys 1feabcd
23import collections 1feabcd
24import types 1feabcd
25import abc 1feabcd
26import warnings 1feabcd
28import numpy 1feabcd
29from jax import numpy as jnp 1feabcd
31from .. import _array 1feabcd
32from .. import _jaxext 1feabcd
33from .. import _utils 1feabcd
35from . import _util 1feabcd
@functools.lru_cache(maxsize=None)
def least_common_superclass(*classes):
    """
    Find a "least" common superclass. The class is searched in all the MROs,
    but the comparison is done with `issubclass` to support virtual inheritance.

    For each input class, its MRO is scanned for the first entry that is a
    superclass (according to `issubclass`) of *all* the input classes. Among
    these candidates, the one found earliest in its own MRO is returned (ties
    go to the first class).

    Parameters
    ----------
    *classes : type
        One or more classes. They must be hashable because the result is
        cached.

    Returns
    -------
    cls : type
        A class `c` such that ``issubclass(k, c)`` holds for every input `k`.
    """
    mros = [c.__mro__ for c in classes]
    indices = []
    for mro in mros:
        # The candidate must be a superclass of every input, not just of each
        # other class taken one at a time: advancing a single cursor against
        # one class after another can move past a superclass of an earlier
        # class. `object` ends every MRO, so the scan always terminates.
        idx = 0
        while not all(issubclass(c, mro[idx]) for c in classes):
            idx += 1
        indices.append(idx)
    best = numpy.argmin(indices)
    return mros[best][indices[best]]
54class CrossKernel: 1feabcd
    r"""

    Base class to represent kernels, i.e., covariance functions.

    A kernel is a two-argument function that computes the covariance between
    two functions at some points according to some probability distribution:

    .. math::
        \mathrm{kernel}(x, y) = \mathrm{Cov}[f(x), g(y)].

    `CrossKernel` objects are callable, the signature is ``obj(x, y)``, and
    they can be summed and multiplied between them and with scalars. They
    are immutable; all operations return new objects.

    Parameters
    ----------
    core : callable
        A function with signature ``core(x, y)``, where ``x`` and ``y``
        are two broadcastable numpy arrays, which computes the value of the
        kernel. It is also passed the `initkw` and `dynkw` keyword arguments.
    scale, loc, derivable, maxdim, dim :
        If specified, these arguments are passed as arguments to the
        correspondingly named operators, in the order listed here. See `linop`.
        Briefly: the kernel selects only fields `dim` in the input, checks the
        dimensionality against `maxdim`, checks there are not too many
        derivatives taken on the arguments, then transforms as ``(x - loc) /
        scale``. If any argument is callable, it is passed `**initkw` and must
        return the actual argument. If an argument is a tuple, it is interpreted
        as a pair of arguments.
    forcekron : bool, default False
        If True, apply ``.transf('forcekron')`` to the kernel, before the
        operations above. Available only for `Kernel`.
    batchbytes : number, optional
        If specified, apply ``.batch(batchbytes)`` to the kernel.
    dynkw : dict, optional
        Additional keyword arguments passed to `core` that can be modified
        by transformations. Deleted by transformations by default.
    **initkw :
        Additional keyword arguments passed to `core` that can be read but not
        changed by transformations.

    Attributes
    ----------
    initkw : dict
        The `initkw` argument (exposed as a read-only mapping).
    dynkw : dict
        The `dynkw` argument, or a modification of it if the object has been
        transformed (exposed as a read-only mapping).
    core : callable
        The `core` argument partially evaluated on `initkw`, or another
        function wrapping it if the object has been transformed.

    Methods
    -------
    batch
    linop
    algop
    transf
    register_transf
    register_linop
    register_corelinop
    register_xtransf
    register_algop
    register_ufuncalgop
    transf_help
    has_transf
    list_transf
    super_transf
    inherit_transf
    inherit_all_algops
    make_linop_family

    See also
    --------
    Kernel

    Notes
    -----
    The predefined class hierarchy and the class logic of the transformations
    assume that each kernel class corresponds to a subalgebra, i.e., addition
    and multiplication preserve the class.

    """

    __slots__ = '_initkw', '_dynkw', '_core'
    # only __new__ and _clone shall access these attributes
140 @property 1feabcd
141 def initkw(self): 1feabcd
142 return types.MappingProxyType(self._initkw) 1feabcd
144 @property 1feabcd
145 def dynkw(self): 1feabcd
146 return types.MappingProxyType(self._dynkw) 1feabcd
148 @property 1feabcd
149 def core(self): 1feabcd
150 return self._core 1feabcd
152 def __new__(cls, core, *, 1feabcd
153 scale=None,
154 loc=None,
155 derivable=None,
156 maxdim=None,
157 dim=None,
158 forcekron=False,
159 batchbytes=None,
160 dynkw={},
161 **initkw,
162 ):
163 self = super().__new__(cls) 1feabcd
165 self._initkw = initkw 1feabcd
166 self._dynkw = dict(dynkw) 1feabcd
167 self._core = lambda x, y, **dynkw: core(x, y, **initkw, **dynkw) 1feabcd
169 if forcekron: 1feabcd
170 self = self.transf('forcekron') 1abcd
172 linop_args = { 1feabcd
173 'scale': scale,
174 'loc': loc,
175 'derivable': derivable,
176 'maxdim': maxdim,
177 'dim': dim,
178 }
179 for transfname, arg in linop_args.items(): 1feabcd
180 if callable(arg): 1feabcd
181 arg = arg(**initkw) 1feabcd
182 if isinstance(arg, tuple): 1feabcd
183 self = self.linop(transfname, *arg) 1eabcd
184 else:
185 self = self.linop(transfname, arg) 1feabcd
187 if batchbytes is not None: 1feabcd
188 self = self.batch(batchbytes) 1eabcd
190 return self 1feabcd
192 def __call__(self, x, y): 1feabcd
193 x = _array.asarray(x) 1feabcd
194 y = _array.asarray(y) 1feabcd
195 shape = _array.broadcast(x, y).shape 1feabcd
196 result = self.core(x, y, **self.dynkw) 1feabcd
197 assert isinstance(result, (numpy.ndarray, jnp.number, jnp.ndarray)) 1feabcd
198 assert jnp.issubdtype(result.dtype, jnp.number), result.dtype 1feabcd
199 assert result.shape == shape, (result.shape, shape) 1feabcd
200 return result 1feabcd
202 def _clone(self, cls=None, *, initkw=None, dynkw=None, core=None): 1feabcd
203 newself = object.__new__(self.__class__ if cls is None else cls) 1feabcd
204 newself._initkw = self._initkw if initkw is None else dict(initkw) 1feabcd
205 newself._dynkw = {} if dynkw is None else dict(dynkw) 1feabcd
206 newself._core = self._core if core is None else core 1feabcd
207 return newself 1feabcd
209 class _side(enum.Enum): 1feabcd
210 LEFT = 0 1feabcd
211 RIGHT = 1 1feabcd
213 @classmethod 1feabcd
214 def _nary(cls, op, kernels, side): 1feabcd
216 if side is cls._side.LEFT: 1feabcd
217 wrapper = lambda c, _, y, **kw: lambda x: c(x, y, **kw) 1feabcd
218 arg = lambda x, _: x 1feabcd
219 elif side is cls._side.RIGHT: 1abcd
220 wrapper = lambda c, x, _, **kw: lambda y: c(x, y, **kw) 1abcd
221 arg = lambda _, y: y 1abcd
222 else: # pragma: no cover
223 raise KeyError(side)
225 cores = [k.core for k in kernels] 1feabcd
226 def core(x, y, **kw): 1feabcd
227 wrapped = [wrapper(c, x, y, **kw) for c in cores] 1feabcd
228 transformed = op(*wrapped) 1feabcd
229 return transformed(arg(x, y)) 1feabcd
231 return __class__(core) 1feabcd
233 # TODO instead of wrapping the cores just by closing over an argument,
234 # define a `PartialKernel` object that has linop defined, with one arg
235 # fixed to None, and __call__ that closes over an argument. This way I
236 # could apply ops directly with deflintransf.
    # TODO move into the binary op methods the logic that decides whether to
    # return NotImplemented, while the algops will raise an exception if the
    # argument is not to their liking
242 def __add__(self, other): 1feabcd
243 return self.algop('add', other) 1feabcd
245 __radd__ = __add__ 1feabcd
247 def __mul__(self, other): 1feabcd
248 return self.algop('mul', other) 1feabcd
250 __rmul__ = __mul__ 1feabcd
252 def __pow__(self, other): 1feabcd
253 return self.algop('pow', exponent=other) 1abcd
255 def __rpow__(self, other): 1feabcd
256 return self.algop('rpow', base=other) 1abcd
258 def _swap(self): 1feabcd
259 """ permute the arguments """
260 core = self.core 1feabcd
261 return self._clone( 1feabcd
262 __class__,
263 core=lambda x, y, **kw: core(y, x, **kw),
264 )
266 # TODO make _swap a transf inherited by CrossIsotropicKernel => messes
267 # up with Kernel subclasses. I want to make Kernel subclasses identity
268 # on swap, so it can't be an inherited transf, because Kernel appears
269 # as the second base.
271 def batch(self, maxnbytes): 1feabcd
272 """
273 Return a batched version of the kernel.
275 The batched kernel processes its inputs in chunks to try to limit memory
276 usage.
278 Parameters
279 ----------
280 maxnbytes : number
281 The maximum number of input bytes per chunk, counted after
282 broadcasting the input shapes. Actual broadcasting may not occur if
283 not induced by the operations in the kernel.
285 Returns
286 -------
287 batched_kernel : CrossKernel
288 The same kernel but with batched computations.
289 """
290 core = _jaxext.batchufunc(self.core, maxnbytes=maxnbytes) 1eabcd
291 return self._clone(core=core) 1eabcd
293 @classmethod 1feabcd
294 def _crossmro(cls): 1feabcd
295 """ MRO iterator excluding subclasses of Kernel """
296 for c in cls.mro(): # pragma: no branch 1feabcd
297 if not issubclass(c, Kernel): 1feabcd
298 yield c 1feabcd
299 if c is __class__: 1feabcd
300 break 1abcd
302 _transf = {} 1feabcd
304 _Transf = collections.namedtuple('_Transf', ['func', 'doc', 'kind']) 1feabcd
306 def __init_subclass__(cls, **kw): 1feabcd
307 super().__init_subclass__(**kw) 1feabcd
308 cls._transf = {} 1feabcd
309 cls.__slots__ = () 1feabcd
311 @classmethod 1feabcd
312 def _transfmro(cls): 1feabcd
313 """ Iterator of superclasses with a _transf attribute """
314 for c in cls.mro(): # pragma: no branch 1feabcd
315 yield c 1feabcd
316 if c is __class__: 1feabcd
317 break 1feabcd
319 @classmethod 1feabcd
320 def _settransf(cls, transfname, transf): 1feabcd
321 if transfname in cls._transf: 1feabcd
322 raise KeyError(f'transformation {transfname!r} already registered ' 1abcd
323 f'for {cls.__name__}')
324 cls._transf[transfname] = cls._Transf(*transf) 1feabcd
326 @classmethod 1feabcd
327 def _alltransf(cls): 1feabcd
328 """ list all accessible transfs as dict name -> (tcls, transf) """
329 transfs = {} 1feabcd
330 for tcls in cls._transfmro(): 1feabcd
331 for name, transf in tcls._transf.items(): 1feabcd
332 transfs.setdefault(name, (tcls, transf)) 1feabcd
333 return transfs 1feabcd
335 @classmethod 1feabcd
336 def _gettransf(cls, transfname, transfmro=None): 1feabcd
337 """
338 Find a transformation.
340 The transformation is searched following the MRO up to `CrossKernel`.
342 Parameters
343 ----------
344 transfname : hashable
345 The transformation name.
347 Returns
348 -------
349 cls : type
350 The class where the transformation was found.
351 transf, doc, kind : tuple
352 The objects set by `register_transf`.
354 Raises
355 ------
356 KeyError :
357 The transformation was not found.
358 """
359 if transfmro is None: 1feabcd
360 transfmro = cls._transfmro() 1feabcd
361 for c in transfmro: 1feabcd
362 try: 1feabcd
363 return c, c._transf[transfname] 1feabcd
364 except KeyError: 1feabcd
365 pass 1feabcd
366 raise KeyError(transfname) 1abcd
368 @classmethod 1feabcd
369 def inherit_transf(cls, transfname, *, intermediates=False): 1feabcd
370 """
372 Inherit a transformation from a superclass.
374 Parameters
375 ----------
376 transfname : hashable
377 The name of the transformation.
378 intermediates : bool, default False
379 If True, make all superclasses up to the one definining the
380 transformation inherit it too.
382 Raises
383 ------
384 KeyError :
385 The transformation was not found in any superclass, or the
386 transformation is already registered on any of the target classes.
388 See also
389 --------
390 transf
392 """
393 tcls, transf = cls._gettransf(transfname) 1feabcd
394 cls._settransf(transfname, transf) 1feabcd
395 if intermediates: 1feabcd
396 for c in cls.mro()[1:]: # pragma: no branch 1feabcd
397 if c is tcls: 1feabcd
398 break 1feabcd
399 c._settransf(transfname, transf) 1feabcd
401 @classmethod 1feabcd
402 def inherit_all_algops(cls, intermediates=False): 1feabcd
403 """
405 Inherit all algebraic operations from superclasses.
407 This makes sense if the class represents a subalgebra, i.e., it
408 should be preserved by addition and multiplication.
410 Parameters
411 ----------
412 intermediates : bool, default False
413 If True, make all superclasses up to the one definining the
414 transformation inherit it too.
416 Raises
417 ------
418 KeyError :
419 An algebraic operation is already registered for one of the target
420 classes.
422 See also
423 --------
424 transf
426 """
427 mro = cls._transfmro() 1feabcd
428 next(mro) 1feabcd
429 for name, (_, transf) in next(mro)._alltransf().items(): 1feabcd
430 if transf.kind is cls._algopmarker: 1feabcd
431 cls.inherit_transf(name, intermediates=intermediates) 1feabcd
433 @classmethod 1feabcd
434 def list_transf(cls, superclasses=True): 1feabcd
435 """
436 List all the available transformations.
438 Parameters
439 ----------
440 superclasses : bool, default True
441 Include transformations defined in superclasses.
443 Returns
444 -------
445 transfs: dict of Transf
446 The dictionary keys are the transformation names, the values are
447 named tuples ``(tcls, kind, impl, doc)`` where ``tcls`` is the class
448 defining the transformation, ``kind`` is the kind of transformation,
449 ``impl`` is the implementation with signature ``impl(tcls, self,
450 *args, **kw)``, ``doc`` is the docstring.
451 """
452 if superclasses: 1abcd
453 source = cls._alltransf().items 1abcd
454 else:
455 def source(): 1abcd
456 for name, transf in cls._transf.items(): 1abcd
457 yield name, (cls, transf) 1abcd
458 return { 1abcd
459 name: cls.Transf(tcls, transf.kind, transf.func, transf.doc)
460 for name, (tcls, transf) in source()
461 }
463 Transf = collections.namedtuple('Transf', ['tcls', 'kind', 'func', 'doc']) 1feabcd
465 @classmethod 1feabcd
466 def has_transf(cls, transfname): 1feabcd
467 """
468 Check if a transformation is registered.
470 Parameters
471 ----------
472 transfname : hashable
473 The transformation name.
475 Returns
476 -------
477 has_transf : bool
478 Whether the transformation is registered.
480 See also
481 --------
482 transf
483 """
484 try: 1abcd
485 cls._gettransf(transfname) 1abcd
486 except KeyError as exc: 1abcd
487 if exc.args == (transfname,): 1abcd
488 return False 1abcd
489 else: # pragma: no cover
490 raise
491 else:
492 return True 1abcd
494 @classmethod 1feabcd
495 def transf_help(cls, transfname): 1feabcd
496 """
498 Return the documentation of a transformation.
500 Parameters
501 ----------
502 transfname : hashable
503 The name of the transformation.
505 Returns
506 -------
507 doc : str
508 The documentation of the transformation.
510 See also
511 --------
512 transf
514 """
515 _, transf = cls._gettransf(transfname) 1abcd
516 return transf.doc 1abcd
518 def transf(self, transfname, *args, **kw): 1feabcd
519 """
521 Return a transformed kernel.
523 Parameters
524 ----------
525 transfname : hashable
526 A name identifying the transformation.
527 *args, **kw :
528 Arguments to the transformation.
530 Returns
531 -------
532 newkernel : object
533 The output of the transformation.
535 Raises
536 ------
537 KeyError
538 The transformation is not defined in this class or any superclass.
540 See also
541 --------
542 linop, algop, transf_help, has_transf, list_transf, super_transf, register_transf, register_linop, register_corelinop, register_xtransf, register_algop, register_ufuncalgop
544 """
545 tcls, transf = self._gettransf(transfname) 1abcd
546 return transf.func(tcls, self, *args, **kw) 1abcd
548 @classmethod 1feabcd
549 def super_transf(cls, transfname, self, *args, **kw): 1feabcd
550 """
552 Transform the kernel using a superclass transformation.
554 This is equivalent to `transf` but is invoked on a class and the
555 object is passed explicitly. The definition of the transformation is
556 searched starting from the next class in the MRO of `self`.
558 Parameters
559 ----------
560 transfname, *args, **kw :
561 See `transf`.
562 self : CrossKernel
563 The object to transform.
565 Returns
566 -------
567 newkernel : object
568 The output of the transformation.
570 """
571 mro = list(self._transfmro()) 1eabcd
572 idx = mro.index(cls) 1eabcd
573 tcls, transf = self._gettransf(transfname, mro[idx + 1:]) 1eabcd
574 return transf.func(tcls, self, *args, **kw) 1eabcd
576 def linop(self, transfname, *args, **kw): 1feabcd
577 r"""
579 Transform kernels to represent the application of a linear operator.
581 .. math::
582 \text{kernel}_1(x, y) &= \mathrm{Cov}[f_1(x), g_1(y)], \\
583 \text{kernel}_2(x, y) &= \mathrm{Cov}[f_2(x), g_2(y)], \\
584 &\ldots \\
585 \text{newkernel}(x, y) &=
586 \mathrm{Cov}[T_f(f_1, f_2, \ldots)(x), T_g(g_1, g_2, \ldots)(y)]
588 Parameters
589 ----------
590 transfname : hashable
591 The name of the transformation.
592 *args :
593 A sequence of `CrossKernel` instances, representing the operands,
594 followed by one or two non-kernel arguments, indicating how to act
595 on each side of the kernels. If both arguments represent the
596 identity, this is a no-op. If there is only one argument, it is
597 intended that the two arguments are equal. `None` always represents
598 the identity.
600 Returns
601 -------
602 newkernel : CrossKernel
603 The transformed kernel.
605 Raises
606 ------
607 ValueError :
608 The transformation exists but was not defined by `register_linop`.
610 See also
611 --------
612 transf
614 Notes
615 -----
616 The linear operator is defined on the function the kernel represents the
617 distribution of, not in full generality on the kernel itself. When
618 multiple kernels are involved, their distributions may be considered
619 independent or not, depending on the specific operation.
621 If the result is a subclass of the class defining the transformation,
622 the result is casted to the latter. Then, if the result and all the
623 operands are instances of `Kernel`, but the two operator arguments
624 differ, the result is casted to its first non-`Kernel` superclass.
626 """
627 tcls, transf = self._gettransf(transfname) 1feabcd
628 if transf.kind is not self._linopmarker: 1feabcd
629 raise ValueError(f'the transformation {transfname!r} was not ' 1abcd
630 f'defined with register_linop and so can not be invoked '
631 f'by linop')
632 return transf.func(tcls, self, *args) 1feabcd
634 def algop(self, transfname, *operands, **kw): 1feabcd
635 r"""
637 Return a nonnegative algebraic transformation of the input kernels.
639 .. math::
640 \mathrm{newkernel}(x, y) &=
641 f(\mathrm{kernel}_1(x, y), \mathrm{kernel}_2(x, y), \ldots), \\
642 f(z_1, z_2, \ldots) &= \sum_{k_1,k_2,\ldots=0}^\infty
643 a_{k_1 k_2 \ldots} z_1^{k_1} z_2^{k_2} \ldots,
644 \quad a_* \ge 0.
646 Parameters
647 ----------
648 transfname : hashable
649 A name identifying the transformation.
650 *operands : CrossKernel, scalar
651 Arguments to the transformation in addition to self.
652 **kw :
653 Additional arguments to the transformation, not considered as
654 operands.
656 Returns
657 -------
658 newkernel : CrossKernel or NotImplemented
659 The transformed kernel, or NotImplemented if the operation is
660 not supported.
662 See also
663 --------
664 transf
666 Notes
667 -----
668 The class of `newkernel` is the least common superclass of: the
669 "natural" output of the operation, the class defining the
670 transformation, and the classes of the operands.
672 For class determination, scalars in the input count as `Constant`
673 if nonnegative or traced by jax, else `CrossConstant`.
675 """
676 tcls, transf = self._gettransf(transfname) 1feabcd
677 if transf.kind is not self._algopmarker: 1feabcd
678 raise ValueError(f'the transformation {transfname!r} was not ' 1abcd
679 f'defined with register_algop and so can not be invoked '
680 f'by algop')
681 return transf.func(tcls, self, *operands, **kw) 1feabcd
683 @classmethod 1feabcd
684 def register_transf(cls, func, transfname=None, doc=None, kind=None): 1feabcd
685 """
687 Register a transformation for use with `transf`.
689 The transformation will be accessible to subclasses.
691 Parameters
692 ----------
693 func : callable
694 A function ``func(tcls, self, *args, **kw) -> object`` implementing
695 the transformation, where ``tcls`` is the class that defines the
696 transformation.
697 transfname : hashable, optional.
698 The `transfname` parameter to `transf` this transformation will
699 be accessible under. If not specified, use the name of `func`.
700 doc : str, optional
701 The documentation of the transformation for `transf_help`. If not
702 specified, use the docstring of `func`.
703 kind : object, optional
704 An arbitrary marker.
706 Returns
707 -------
708 func : callable
709 The `func` argument as is.
711 Raises
712 ------
713 KeyError :
714 The name is already in use for another transformation in the same
715 class.
717 See also
718 --------
719 transf
721 """
722 if transfname is None: 1feabcd
723 transfname = func.__name__ 1feabcd
724 if doc is None: 1feabcd
725 doc = func.__doc__ 1feabcd
726 cls._settransf(transfname, (func, doc, kind)) 1feabcd
727 return func 1feabcd
729 # TODO forbid to override with a different kind?
    @classmethod
    def register_linop(cls, op, transfname=None, doc=None, argparser=None):
        """

        Register a transformation for use with `linop`.

        Parameters
        ----------
        op : callable
            A function ``op(tcls, self, arg1, arg2, *operands) -> CrossKernel``
            that returns the new kernel, where ``arg1`` and ``arg2`` represent
            the operators acting on each side of the kernels, and ``operands``
            are the other kernels beyond ``self``.
        transfname, doc : optional
            See `register_transf`.
        argparser : callable, optional
            A function applied to ``arg1`` and ``arg2``. Not called if the
            argument is `None`. It should map the identity to `None`.

        Returns
        -------
        op : callable
            The `op` argument as is.

        Notes
        -----
        The function `op` is called only if ``arg1`` or ``arg2`` is not `None`
        after potential conversion with `argparser`.

        The registered wrapper raises `ValueError` if the number of trailing
        non-kernel arguments is not 1 or 2, and `TypeError` if `op` does not
        return a `CrossKernel`. The class of the result is then adjusted as
        described in `linop`.

        See also
        --------
        transf

        """

        if transfname is None:
            transfname = op.__name__ # for result type error message

        @functools.wraps(op)
        def func(tcls, self, *allargs):

            # split the arguments in kernels and non-kernels: `pos` is the
            # index of the first non-kernel argument (len(allargs) if none)
            for pos, arg in enumerate(allargs):
                if not isinstance(arg, __class__):
                    break
            else:
                pos = len(allargs)
            operands = allargs[:pos]
            args = allargs[pos:]

            # check the arguments from the first non-kernel onwards are 1 or 2
            if len(args) not in (1, 2):
                raise ValueError(f'incorrect number of non-kernel tail '
                    f'arguments {len(args)}, expected 1 or 2')

            # wrap argument parser to enforce preserving None
            if argparser:
                conv = lambda x: None if x is None else argparser(x)
            else:
                conv = lambda x: x

            # determine if the two arguments count as "identical" or not
            if len(args) == 1:
                arg = conv(*args)
                different = False
                arg1 = arg2 = arg
            else:
                arg1, arg2 = args
                different = arg1 is not arg2
                arg1 = conv(arg1)
                arg2 = conv(arg2)
                different &= arg1 is not arg2
                # they must be not identical both before and after to handle
                # these cases:
                # - if the user passes identical arguments, but argparser
                #   makes copies, it must still count as identical
                # - if the user passes arguments which are not identical,
                #   but argparser sends them to the same object, then they
                #   surely represent the same transf

            # handle no-op case
            if arg1 is None and arg2 is None:
                return self

            # invoke implementation
            result = op(tcls, self, arg1, arg2, *operands)

            # check result is a kernel
            if not isinstance(result, __class__):
                raise TypeError(f'linop {transfname!r} returned '
                    f'object of type {result.__class__.__name__}, expected '
                    f'subclass of {__class__.__name__}')

            # modify class of the result: cap it at the defining class, then
            # if the two sides got different operators on Kernel inputs, drop
            # to the first non-Kernel class in the MRO
            rcls = result.__class__
            if issubclass(rcls, tcls):
                rcls = tcls
            all_operands_kernel = all(isinstance(o, Kernel) for o in operands)
            if isinstance(self, Kernel) and all_operands_kernel and different:
                rcls = next(rcls._crossmro())
            if rcls is not result.__class__:
                result = result._clone(rcls)

            return result

        cls.register_transf(func, transfname, doc, cls._linopmarker)
        return op
839 class _LinOpMarker(str): pass 1feabcd
840 _linopmarker = _LinOpMarker('linop') 1feabcd
842 @classmethod 1feabcd
843 def register_corelinop(cls, corefunc, transfname=None, doc=None, argparser=None): 1feabcd
844 """
846 Register a linear operator with a function that acts only on the core.
848 Parameters
849 ----------
850 corefunc : callable
851 A function ``corefunc(core, arg1, arg2, *cores) -> newcore``, where
852 ``core`` is the function that implements the kernel passed at
853 initialization, and ``cores`` for other operands.
854 transfname, doc, argparser :
855 See `register_linop`.
857 Returns
858 -------
859 corefunc : callable
860 The `corefunc` argument as is.
862 See also
863 --------
864 transf
866 """
867 @functools.wraps(corefunc) 1feabcd
868 def op(_, self, arg1, arg2, *operands): 1feabcd
869 cores = (o.core for o in operands) 1feabcd
870 core = corefunc(self.core, arg1, arg2, *cores) 1feabcd
871 return self._clone(core=core) 1feabcd
872 cls.register_linop(op, transfname, doc, argparser) 1feabcd
873 return corefunc 1feabcd
875 @classmethod 1feabcd
876 def register_xtransf(cls, xfunc, transfname=None, doc=None): 1feabcd
877 """
879 Register a linear operator that acts only on the input.
881 Parameters
882 ----------
883 xfunc : callable
884 A function ``xfunc(arg) -> (lambda x: newx)`` that takes in a
885 `linop` argument and produces a function to transform the input. Not
886 called if ``arg`` is `None`. To indicate the identity, return
887 `None`.
888 transfname, doc :
889 See `register_linop`. `argparser` is not provided because its
890 functionality can be included in `xfunc`.
892 Returns
893 -------
894 xfunc : callable
895 The `xfunc` argument as is.
897 See also
898 --------
899 transf
901 """
903 @functools.wraps(xfunc) 1feabcd
904 def corefunc(core, xfun, yfun): 1feabcd
905 if not xfun: 1feabcd
906 return lambda x, y, **kw: core(x, yfun(y), **kw) 1abcd
907 elif not yfun: 1feabcd
908 return lambda x, y, **kw: core(xfun(x), y, **kw) 1abcd
909 else:
910 return lambda x, y, **kw: core(xfun(x), yfun(y), **kw) 1feabcd
912 cls.register_corelinop(corefunc, transfname, doc, xfunc) 1feabcd
913 return xfunc 1feabcd
915 @classmethod 1feabcd
916 def register_algop(cls, op, transfname=None, doc=None): 1feabcd
917 """
919 Register a transformation for use with `algop`.
921 Parameters
922 ----------
923 op : callable
924 A function ``op(tcls, *kernels, **kw) -> CrossKernel |
925 NotImplemented`` that returns the new kernel. ``kernels`` may be
926 scalars but for the first argument.
927 transfname, doc :
928 See `register_transf`.
930 Returns
931 -------
932 op : callable
933 The `op` argument as is.
935 See also
936 --------
937 transf
939 """
941 if transfname is None: 1feabcd
942 transfname = op.__name__ # for error message 1feabcd
944 @functools.wraps(op) 1feabcd
945 def func(tcls, *operands, **kw): 1feabcd
946 result = op(tcls, *operands, **kw) 1feabcd
948 if result is NotImplemented: 1feabcd
949 return result 1abcd
950 elif not isinstance(result, __class__): 1feabcd
951 raise TypeError(f'algop {transfname!r} returned ' 1abcd
952 f'object of type {result.__class__.__name__}, expected '
953 f'subclass of {__class__.__name__}')
955 def classes(): 1feabcd
956 yield tcls 1feabcd
957 for o in operands: 1feabcd
958 if isinstance(o, __class__): 1feabcd
959 yield o.__class__ 1feabcd
960 elif _util.is_nonnegative_scalar_trueontracer(o): 1feabcd
961 yield Constant 1feabcd
962 elif _util.is_numerical_scalar(o): 1abcd
963 yield CrossConstant 1abcd
964 else:
965 raise TypeError(f'operands to algop {transfname!r} ' 1abcd
966 f'must be CrossKernel or numbers, found {o!r}')
967 # this type check comes after letting the implementation
968 # return NotImplemented, to support overloading
969 yield result.__class__ 1feabcd
971 lcs = least_common_superclass(*classes()) 1feabcd
972 return result._clone(lcs) 1feabcd
974 cls.register_transf(func, transfname, doc, cls._algopmarker) 1feabcd
975 return op 1feabcd
977 # TODO delete initkw (also in linop) if there's more than one kernel
978 # operand or if the class changed?
980 # TODO consider adding an option domains=callable, returns list of
981 # domains from the operands, or list of tuples right away, and the
982 # impl checks at runtime (if not traced) that the output values are in
983 # the domains, with an informative error message
    # TODO consider escalating to ancestor definitions of the transf
    # until an impl does not return NotImplemented, and then raise
    # an error instead of letting the NotImplemented escape. This would be
    # useful for partial behavior overrides in subclasses, e.g., to handle
    # multiplication by a scalar to rewire transformations.
991 class _AlgOpMarker(str): pass 1feabcd
992 _algopmarker = _AlgOpMarker('algop') 1feabcd
994 @classmethod 1feabcd
995 def register_ufuncalgop(cls, ufunc, transfname=None, doc=None): 1feabcd
996 """
998 Register an algebraic operation with a function that acts only on the
999 kernel value.
1001 Parameters
1002 ----------
1003 corefunc : callable
1004 A function ``ufunc(*values, **kw) -> value``, where ``values`` are
1005 the values yielded by the operands.
1006 transfname, doc :
1007 See `register_transf`.
1009 Returns
1010 -------
1011 ufunc : callable
1012 The `ufunc` argument as is.
1014 See also
1015 --------
1016 transf
1018 """
1019 @functools.wraps(ufunc) 1feabcd
1020 def op(_, self, *operands, **kw): 1feabcd
1021 cores = tuple( 1fabcd
1022 o.core if isinstance(o, __class__)
1023 else lambda x, y: o
1024 for o in (self, *operands)
1025 )
1026 def core(x, y, **kw): 1fabcd
1027 values = (core(x, y, **kw) for core in cores) 1abcd
1028 return ufunc(*values, **kw) 1abcd
1029 return self._clone(core=core) 1fabcd
1030 cls.register_algop(op, transfname, doc) 1feabcd
1031 return ufunc 1feabcd
@classmethod
def make_linop_family(cls, transfname, bothker, leftker, rightker=None, *,
    doc=None, argparser=None, argnames=None, translkw=None):
    """

    Form a family of kernel classes related by linear operators.

    The class this method is called on is the seed class. A new
    transformation is registered to obtain other classes from the seed class
    by applying a newly defined linop transformation.

    Parameters
    ----------
    transfname : str
        The name of the new transformation.
    bothker, leftker, rightker : CrossKernel subclasses
        The kernel classes to be obtained by applying the operator to a seed
        class object respectively on both sides, only left, or only right.
        All classes are assumed to require no positional arguments at
        construction, and recognize the same set of keyword arguments. If
        `rightker` is not specified, it is defined by subclassing `leftker`
        and transposing the kernel on object construction.
    doc, argparser : optional
        See `register_linop`.
    argnames : pair of str, optional
        If specified, `leftker` is passed an additional keyword argument
        with name ``argnames[0]`` that specifies the argument of the
        operator, `rightker` is passed ``argnames[1]``, and `bothker`
        similarly is passed both.

        If `leftker` is used to implement `rightker`, it must accept both
        arguments, and may use them to determine if it is being invoked as
        `leftker` or `rightker`.
    translkw : callable, optional
        A function with signature ``translkw(dynkw, <argnames>, **initkw) ->
        dict`` that determines the constructor arguments for a new object
        starting from the ones of the object to be transformed. By default,
        ``initkw`` is passed over, and an error is raised if ``dynkw`` is
        not empty.

    Examples
    --------

    >>> @lgp.kernel
    ... def A(x, y, *, gatto):
    ...     ''' The renowned A kernel of order gatto '''
    ...     return ...
    ...
    >>> @lgp.kernel
    ... def T(n, m, *, gatto, op1, op2):
    ...     ''' The kernel of the Topo series of an A process of order
    ...     gatto '''
    ...     return ...
    ...
    >>> @lgp.crosskernel
    ... def CrossTA(n, x, *, gatto, op1=None, op2=None):
    ...     ''' The cross covariance between the Topo series of an A
    ...     process of order gatto and the process itself '''
    ...     return ...
    ...
    >>> A.make_linop_family('topo', T, CrossTA, argnames=('op1', 'op2'))
    >>> a = A(gatto=7)
    >>> ta = a.linop('topo', True, None)
    >>> at = a.linop('topo', None, True)
    >>> t = a.linop('topo', True)

    See also
    --------
    transf

    """

    if rightker is None:

        # invent a name for rightker
        rightname = f'Cross{cls.__name__}{bothker.__name__}'

        # define how to set up rightker
        def exec_body(ns):

            if leftker.__doc__:
                header = 'Automatically generated transposed version of:\n\n'
                ns['__doc__'] = _utils.append_to_docstring(leftker.__doc__, header, front=True)

            def __new__(cls, *args, **kw):
                # `rightker` is looked up when this runs, i.e., after it has
                # been rebound to the class created below.
                self = super(rightker, cls).__new__(cls, *args, **kw)

                if self.__class__ is cls:
                    # transpose, check the result is a leftker, then turn it
                    # back into an instance of the requested class
                    self = self._swap()
                    if not isinstance(self, leftker):
                        raise TypeError(f'newly created instance of '
                            f'automatically defined {rightker.__name__} is not an '
                            f'instance of {leftker.__name__} after '
                            f'transposition. Either define transposition '
                            f'for {leftker.__name__}, or define '
                            f'{rightker.__name__} manually')
                    return self._clone(cls)

                else:
                    return self._swap()

            ns['__new__'] = __new__

        # create rightker, evil twin of leftker separated at birth
        rightker = types.new_class(rightname, (leftker,), exec_body=exec_body)

    # check which classes are symmetric
    classes = cls, bothker, leftker, rightker
    sym = tuple(issubclass(c, Kernel) for c in classes)
    exp = True, True, False, False
    if sym != exp:
        def desc(t):
            return 'Kernel' if t else 'non-Kernel'
        warnings.warn(f'Expected classes pattern {", ".join(map(desc, exp))}, '
            f'found {", ".join(map(desc, sym))}')

    # set translkw if not specified
    if translkw is None:
        def translkw(*, dynkw, **initkw):
            if dynkw:
                raise ValueError('found non-empty `dynkw`, the default '
                    'implementation of `translkw` does not support it')
            return initkw

    # function to produce the arguments to the transformed objects
    def makekw(self, arg1, arg2):
        kw = dict(dynkw=self.dynkw, **self.initkw)
        if argnames is not None:
            # dict(**a, **b) raises on duplicate keys, so an operator
            # argument already present in initkw is not silently overwritten
            if arg1 is not None:
                kw = dict(**kw, **{argnames[0]: arg1})
            if arg2 is not None:
                kw = dict(**kw, **{argnames[1]: arg2})
        return translkw(**kw)

    # register linop mapping cls to either leftker, rightker or bothker
    regkw = dict(transfname=transfname, doc=doc, argparser=argparser)

    @functools.partial(cls.register_linop, **regkw)
    def op_seed_to_siblings(_, self, arg1, arg2):
        kw = makekw(self, arg1, arg2)
        if arg2 is None:
            return leftker(**kw)
        elif arg1 is None:
            return rightker(**kw)
        else:
            return bothker(**kw)

    # register linop mapping leftker to bothker
    @functools.partial(leftker.register_linop, **regkw)
    def op_left_to_both(_, self, arg1, arg2):
        if arg1 is None:
            return bothker(**makekw(self, arg1, arg2))
        else:
            raise ValueError(f'cannot further transform '
                f'`{leftker.__name__}` on left side with linop '
                f'{transfname!r}')

    # register linop mapping rightker to bothker
    @functools.partial(rightker.register_linop, **regkw)
    def op_right_to_both(_, self, arg1, arg2):
        if arg2 is None:
            return bothker(**makekw(self, arg1, arg2))
        else:
            raise ValueError(f'cannot further transform '
                f'`{rightker.__name__}` on right side with linop '
                f'{transfname!r}')
class AffineSpan(CrossKernel, abc.ABC):
    """

    Kernel that tracks affine transformations.

    An `AffineSpan` instance accumulates the overall affine transformation
    applied to its inputs and output.

    `AffineSpan` and its subclasses are preserved by the transformations
    'scale', 'loc', 'add' (with scalar) and 'mul' (with scalar).

    `AffineSpan` cannot be instantiated directly or used as standalone
    superclass. It must be the first base before concrete superclasses.

    """

    # Names of the dynamic keyword arguments that store the accumulated
    # transformation, with their starting values (presumably the identity
    # transformation: no shift, unit scale, no offset, unit amplitude).
    _affine_dynkw = dict(lloc=0, rloc=0, lscale=1, rscale=1, offset=0, ampl=1)

    def __new__(cls, *args, dynkw=None, **kw):
        if cls is __class__:
            raise TypeError(f'cannot instantiate {__class__.__name__} directly')
        # Start from the defaults in `_affine_dynkw`, then let the explicitly
        # passed `dynkw` override them. `None` is used as default instead of
        # a shared mutable `{}`.
        new_dynkw = dict(cls._affine_dynkw)
        if dynkw is not None:
            new_dynkw.update(dynkw)
        return super().__new__(cls, *args, dynkw=new_dynkw, **kw)

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)
        # call `inherit_transf` on the subclass for every transformation
        # registered on AffineSpan itself
        for name in __class__._transf:
            cls.inherit_transf(name)

        # TODO I would like to inherit conditional on there not being another
        # definition. However, since registrations are done after class
        # creation, this method is invoked too early to do that.
        #
        # Is it possible to reimplement the transformation system to have the
        # transformations defined in the class body?
        #
        # Option 1: use a decorator that marks the transf methods and makes
        # them staticmethods. Then an __init_subclass__ above CrossKernel goes
        # through the methods and registers the marked ones.
        #
        # Option 2: try to use descriptors (classes like property). Is it
        # possible to reimplement from scratch something like classmethod? If
        # so, I could make a transfmethod that binds two arguments: tcls
        # for the class that defines it, and self for the object that invokes
        # it. Then inheriting would be just copying the thing over. If I wanted
        # to keep the invocation through method interface, I could define the
        # methods with underscores and register them somewhere.
        #
        # However, since I define kernels with decorators, I can't actually
        # count on the transformations being defined into them at class
        # creation. But that would be quite a corner case.
        #
        # To make subclassing convenient without decorators, allow to define a
        # core(x, y) method, which __init_subclass__ converts to staticmethod,
        # and have __new__ use it to set _core (I can't use the same name with
        # slots).
        #
        # To use make_linop_family without post-class registration, make it a
        # decorator with attributed subdecorators for other members:
        #
        # @linop_family('T')
        # @kernel
        # def A(...)
        # @A.left('arg')
        # @crosskernel
        # def CrossTA(...)
        # @A.right # optional, generated from left if not specified
        # ...
        # @A.both('arg1', 'arg2')
        # @kernel
        # def T(...)
        #
        # Can also be stacked on top of left/right/both to chain families. Works
        # both with decorators and explicit class definitions.

    def _clone(self, *args, **kw):
        newself = super()._clone(*args, **kw)
        # copy the accumulated affine parameters over, but only if the clone
        # is still an AffineSpan
        if isinstance(newself, __class__):
            for name in self._affine_dynkw:
                newself._dynkw[name] = self._dynkw[name]
        return newself

    @classmethod
    def __subclasshook__(cls, sub):
        if cls is __class__:
            return NotImplemented
        # to avoid algops promoting to unqualified AffineSpan: concrete
        # AffineSpan kernel classes claim `Constant` as a virtual subclass,
        # concrete cross-kernel ones claim `CrossConstant`
        if issubclass(cls, Kernel):
            if issubclass(sub, Constant):
                return True
            else:
                return NotImplemented
        elif issubclass(sub, CrossConstant):
            return True
        else:
            return NotImplemented

    # TODO I could do separately AffineLeft, AffineRight, AffineOut, and make
    # this a subclass of those three. AffineOut would also allow keeping the
    # class when adding two objects without keyword arguments in initkw and
    # dynkw beyond those managed by Affine. => Make only AffineSide and have it
    # switch side on _swap.

    # TODO when I reimplement transformations as methods, make AffineSpan not
    # a subclass of CrossKernel. Right now I have to, to avoid routing around
    # the assumption that all classes in the MRO implement the transformation
    # management logic.
1308class PreservedBySwap(CrossKernel): 1feabcd
def __new__(cls, *args, **kw):
    """
    Construct an instance of a concrete subclass; instantiating the
    `PreservedBySwap` mixin itself is refused with `TypeError`.
    """
    if cls is not __class__:
        return super().__new__(cls, *args, **kw)
    raise TypeError(f'cannot instantiate {__class__.__name__} directly')
def _swap(self):
    """
    Transpose with the inherited `_swap`, then re-clone the transposed
    kernel with ``self.__class__`` so that the class is kept.
    """
    transposed = super()._swap()
    return transposed._clone(self.__class__)
1318 # TODO when I implement transformations with methods, make this not an
1319 # instance of CrossKernel.