Coverage for src/lsqfitgp/_signature.py: 90%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_signature.py 

2# 

3# Copyright (c) 2023, 2024 Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" define Signature """ 

21 

22import inspect 1eafbcd

23 

24try: 1eafbcd

25 from numpy.lib import function_base # numpy 1 1eafbcd

26except ImportError: 1eabcd

27 from numpy.lib import _function_base_impl as function_base # numpy 2 1eabcd

28import jax 1eafbcd

29from jax import numpy as jnp 1eafbcd

30 

class Signature:
    """
    Class to parse a numpy gufunc signature.

    Parameters
    ----------
    signature : str
        A gufunc signature such as ``'(n),(n,m)->(m)'``. It is parsed with
        numpy's private ``_parse_gufunc_signature``, the same parser used by
        `numpy.vectorize`.

    Attributes
    ----------
    signature : str
        The signature string.
    incores, outcores : sequence of sequence of str
        The core dimension names of each input and of each output.
    """

    def __init__(self, signature):
        self.signature = signature
        self.incores, self.outcores = function_base._parse_gufunc_signature(signature)

    @classmethod
    def from_tuples(cls, incores, outcores):
        """
        Build a signature directly from the core dimension names.

        Parameters
        ----------
        incores, outcores : sequence of sequence of str
            Core dimension names of each input and output, e.g.
            ``[('n',), ('n', 'm')]`` and ``[('m',)]``. They are stored as
            given; the `signature` string is rebuilt from them.

        Returns
        -------
        sig : Signature
        """
        self = cls.__new__(cls)  # bypass __init__, there is no string to parse

        def tuplestr(core):
            return '(' + ','.join(map(str, core)) + ')'

        self.signature = ','.join(map(tuplestr, incores)) + '->' + ','.join(map(tuplestr, outcores))
        self.incores = incores
        self.outcores = outcores
        return self

    def __repr__(self):
        return self.signature

    def check_nargs(self, func):
        """
        Check that the function has the correct number of arguments.

        Functions taking ``*args`` are always accepted. Otherwise the total
        number of parameters in ``inspect.signature(func)`` (of any kind) must
        equal the number of inputs in the signature.

        Raises
        ------
        ValueError
            If the number of parameters does not match.
        """
        sig = inspect.signature(func)
        if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in sig.parameters.values()):
            return
        if len(sig.parameters) != len(self.incores):
            raise ValueError(f'function {func} has {len(sig.parameters)} '
                f'arguments, but signature {self.signature} '
                f'requires {len(self.incores)}')

    @property
    def nin(self):
        """ Number of inputs in the signature. """
        return len(self.incores)

    @property
    def nout(self):
        """ Number of outputs in the signature. """
        return len(self.outcores)

    def eval(self, *args):
        """

        Evaluate the signature with the given arguments.

        Parameters
        ----------
        args : sequence of numpy.ndarray or None
            A missing argument can be replaced with None, provided the other
            arguments are sufficient to infer all dimension sizes.

        Returns
        -------
        sig : EvaluatedSignature
            An object with attributes `broadcast_shape`, `sizes`,
            `core_out_shapes`, `out_shapes`, `core_in_shapes`, `in_shapes`.

        """
        return self.EvaluatedSignature(self, *args)

    class EvaluatedSignature:
        """
        Shapes obtained by matching a `Signature` against concrete arguments.

        Attributes
        ----------
        broadcast_shape : tuple of int
            Broadcast shape of the non-core (loop) dimensions of the arguments
            that were not None.
        sizes : dict
            Maps each core dimension name to its size.
        core_in_shapes, core_out_shapes : tuple of tuple of int
            Core shape of each input/output.
        in_shapes, out_shapes : tuple of tuple of int
            Full shape ``broadcast_shape + core shape`` of each input/output.
        """

        def __init__(self, sig, *args):
            # explicit check instead of `assert`, which is stripped under -O
            if len(args) != len(sig.incores):
                raise ValueError(f'signature {sig.signature} requires '
                    f'{len(sig.incores)} arguments, got {len(args)}')

            known_args = []
            known_cores = []
            missing_cores = []
            for arg, core in zip(args, sig.incores):
                if arg is None:
                    missing_cores.append(core)
                else:
                    # numpy's parser is given a shape-only stand-in; the dtype
                    # 'd' is a placeholder (presumably only .shape/.ndim are
                    # read -- NOTE(review): confirm against numpy internals)
                    known_args.append(jax.ShapeDtypeStruct(arg.shape, 'd'))
                    known_cores.append(core)

            self.broadcast_shape, self.sizes = function_base._parse_input_dimensions(known_args, known_cores)

            # Every dimension of an omitted input or of an output must have
            # been fixed by the known inputs. A set comprehension works for
            # list cores too (from `from_tuples`), unlike `sum(cores, ())`.
            missing_indices = {
                name
                for core in [*missing_cores, *sig.outcores]
                for name in core
            }
            missing_indices.difference_update(self.sizes)
            if missing_indices:
                raise ValueError(f'cannot infer sizes of dimensions {missing_indices} from signature {sig.signature}')

            self.core_out_shapes, self.out_shapes = self._compute_shapes(sig.outcores)
            self.core_in_shapes, self.in_shapes = self._compute_shapes(sig.incores)

        def _compute_shapes(self, cores):
            """ Return ``(core_shapes, full_shapes)`` for the given core dimension names. """
            coreshapes = tuple(tuple(self.sizes[name] for name in core) for core in cores)
            shapes = tuple(self.broadcast_shape + core for core in coreshapes)
            return coreshapes, shapes

        def _repr(self, shapes):
            return ','.join(map(str, shapes)).replace(' ', '')

        def __repr__(self):
            return self._repr(self.in_shapes) + '->' + self._repr(self.out_shapes)

128 

129# I use numpy's internals to parse the signature, but these do not correspond to 

130# the description in 

131# https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html. In 

132# particular: 

133# - integers are parsed like identifiers 

134# - ? is not accepted 

135# Judging from numpy's issue tracker, this is one of many long-standing 

136# problems with vectorize and is unlikely to be fixed. 

137# See https://github.com/HypothesisWorks/hypothesis/blob/4e675dee1a4cba9d6902290bbc5719fd72072ec7/hypothesis-python/src/hypothesis/extra/_array_helpers.py#L289 

138# for a more complete implementation