Coverage for src/lsqfitgp/_signature.py: 90%
65 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_signature.py
2#
3# Copyright (c) 2023, 2024 Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" define Signature """
22import inspect 1eafbcd
24try: 1eafbcd
25 from numpy.lib import function_base # numpy 1 1eafbcd
26except ImportError: 1eabcd
27 from numpy.lib import _function_base_impl as function_base # numpy 2 1eabcd
28import jax 1eafbcd
29from jax import numpy as jnp 1eafbcd
class Signature:
    """ Class to parse a numpy gufunc signature.

    Parsing is delegated to numpy's private `_parse_gufunc_signature`.

    Parameters
    ----------
    signature : str
        The gufunc signature, e.g., ``'(n,m),(m)->(n)'``.

    Attributes
    ----------
    signature : str
        The signature string.
    incores, outcores : sequence of sequences
        The core dimension names of each input and each output.
    """

    def __init__(self, signature):
        self.signature = signature
        self.incores, self.outcores = function_base._parse_gufunc_signature(signature)

    @classmethod
    def from_tuples(cls, incores, outcores):
        """
        Build a signature from core dimension names instead of a string.

        Parameters
        ----------
        incores, outcores : sequence of sequences
            The core dimension names of each input and each output. They are
            stored as given; the names are converted with `str` only to build
            the `signature` string.

        Returns
        -------
        sig : Signature
            The new signature. `__init__` is bypassed, so the string is not
            run through the parser.
        """
        self = cls.__new__(cls)
        tuplestr = lambda t: '(' + ','.join(map(str, t)) + ')'
        self.signature = ','.join(map(tuplestr, incores)) + '->' + ','.join(map(tuplestr, outcores))
        self.incores = incores
        self.outcores = outcores
        return self

    def __repr__(self):
        return self.signature

    def check_nargs(self, func):
        """ Check that the function has the correct number of arguments.

        The check is skipped if `func` accepts ``*args``. Otherwise all the
        parameters of `func` are counted, whatever their kind or default.

        Raises
        ------
        ValueError
            If the number of parameters differs from the number of inputs in
            the signature.
        """
        sig = inspect.signature(func)
        if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()):
            return
        if len(sig.parameters) != len(self.incores):
            raise ValueError(f'function {func} has {len(sig.parameters)} '
                f'arguments, but signature {self.signature} '
                f'requires {len(self.incores)}')

    @property
    def nin(self):
        """ The number of inputs in the signature. """
        return len(self.incores)

    @property
    def nout(self):
        """ The number of outputs in the signature. """
        return len(self.outcores)

    def eval(self, *args):
        """

        Evaluate the signature with the given arguments.

        Parameters
        ----------
        args : sequence of numpy.ndarray or None
            A missing argument can be replaced with None, provided the other
            arguments are sufficient to infer all dimension sizes.

        Returns
        -------
        sig : EvaluatedSignature
            An object with attributes `broadcast_shape`, `sizes`,
            `core_out_shapes`, `out_shapes`, `core_in_shapes`, `in_shapes`.

        Raises
        ------
        ValueError
            If the size of some core dimension of a missing argument can not
            be inferred from the other arguments.

        """
        return self.EvaluatedSignature(self, *args)

    class EvaluatedSignature:
        """ Shapes obtained by applying a `Signature` to actual arguments.

        Attributes
        ----------
        broadcast_shape : tuple
            The broadcasted shape of the non-core (loop) dimensions.
        sizes : dict
            Mapping from core dimension name to its size.
        core_in_shapes, core_out_shapes : tuple of tuples
            The core shape of each input/output.
        in_shapes, out_shapes : tuple of tuples
            The full shape (`broadcast_shape` + core shape) of each
            input/output.
        """

        def __init__(self, sig, *args):

            assert len(args) == len(sig.incores), (
                f'got {len(args)} arguments, but signature {sig.signature} '
                f'requires {len(sig.incores)}')

            # separate the arguments that carry shape information from the
            # placeholders (None) whose shape has to be inferred
            known_args = []
            known_cores = []
            missing_cores = []
            for arg, core in zip(args, sig.incores):
                if arg is None:
                    missing_cores.append(core)
                else:
                    # only arg.shape is forwarded; the fixed 'd' dtype appears
                    # to be just a placeholder
                    known_args.append(jax.ShapeDtypeStruct(arg.shape, 'd'))
                    known_cores.append(core)

            self.broadcast_shape, self.sizes = function_base._parse_input_dimensions(known_args, known_cores)

            # a set comprehension, unlike sum(missing_cores, ()), works for
            # cores that are lists or other iterables, not only tuples
            missing_indices = {name for core in missing_cores for name in core}
            missing_indices.difference_update(self.sizes)
            if missing_indices:
                raise ValueError(f'cannot infer sizes of dimensions {missing_indices} from signature {sig.signature}')

            self.core_out_shapes, self.out_shapes = self._compute_shapes(sig.outcores)
            self.core_in_shapes, self.in_shapes = self._compute_shapes(sig.incores)

        def _compute_shapes(self, cores):
            """ Return the core shapes and full shapes for the given cores. """
            coreshapes = []
            shapes = []
            for core in cores:
                core = tuple(self.sizes[i] for i in core)
                coreshapes.append(core)
                shapes.append(self.broadcast_shape + core)
            return tuple(coreshapes), tuple(shapes)

        def _repr(self, shapes):
            """ Format shapes as comma-separated tuples without spaces. """
            return ','.join(map(str, shapes)).replace(' ', '')

        def __repr__(self):
            return self._repr(self.in_shapes) + '->' + self._repr(self.out_shapes)
129# I use numpy's internals to parse the signature, but these do not correspond to
130# the description in
131# https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html. In
132# particular:
133# - integers are parsed like identifiers
134# - ? is not accepted
135# Looking at the open issues, this seems to be one of many long-standing
136# problems with vectorize, and it is unlikely to be fixed.
137# See https://github.com/HypothesisWorks/hypothesis/blob/4e675dee1a4cba9d6902290bbc5719fd72072ec7/hypothesis-python/src/hypothesis/extra/_array_helpers.py#L289
138# for a more complete implementation