Coverage for src/lsqfitgp/_gvarext/_ufunc.py: 94%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_gvarext/_ufunc.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21import string 1efabcd

22 

23import jax 1efabcd

24import gvar 1efabcd

25from jax import numpy as jnp 1efabcd

26 

27from .. import _signature 1efabcd

28 

29from ._jacobian import jacobian, from_jacobian 1efabcd

30from ._tabulate import tabulate_together 1efabcd

31 

def gvar_gufunc(func, *, signature=None):
    """
    Extend a one-argument, jax-traceable generalized ufunc to accept gvars.

    Parameters
    ----------
    func : callable
        Maps one array to one array. It has to behave as a generalized ufunc
        and must be differentiable (once) with `jax`.
    signature : str, optional
        Generalized-ufunc signature of `func`, e.g. ``'(n)->()'``. Defaults to
        ``'()->()'``, i.e., an ordinary elementwise ufunc.

    Returns
    -------
    decorated_func : callable
        Wrapper around `func`. A scalar `gvar.GVar` or an array with object
        dtype is pushed through `func`, and its jacobian with respect to the
        primary gvars is propagated with the jacobian of `func`. Any other
        input is forwarded to `func` unchanged.

    See also
    --------
    numpy.vectorize

    """

    # Only one core input and one core output are supported; the tuple
    # unpacking below fails otherwise.
    sig = _signature.Signature('()->()' if signature is None else signature)
    (in_core,) = sig.incores
    (out_core,) = sig.outcores

    # Per core element, jax.jacfwd(func) maps the input core to an array with
    # the output core dimensions followed by the input core dimensions.
    jac_signature = _signature.Signature.from_tuples(
        [in_core], [out_core + in_core])
    deriv = jnp.vectorize(jax.jacfwd(func), signature=jac_signature.signature)

    # Build the einsum contraction. The broadcast head is '...'; one letter
    # per output core dim, one per input core dim, and a final letter for the
    # axis that enumerates the primary gvars.
    n_in = len(in_core)
    n_out = len(out_core)
    letters = string.ascii_letters
    out_letters = letters[:n_out]
    in_letters = letters[n_out:n_out + n_in]
    gvar_letter = letters[n_out + n_in]
    formula = (
        f'...{out_letters}{in_letters},'
        f'...{in_letters}{gvar_letter}->'
        f'...{out_letters}{gvar_letter}'
    )

    def gvar_function(x):
        # Split the gvars into central values and jacobian w.r.t. primaries.
        mean = gvar.mean(x)
        in_jac, indices = jacobian(x)

        # Evaluate the function and its own jacobian at the central values.
        out_mean = func(mean)
        jac = deriv(mean)

        # Sanity check: the broadcast (non-core) dimensions must agree.
        head_ndim = jac.ndim - n_out - n_in
        assert jac.shape[:head_ndim] == in_jac.shape[:in_jac.ndim - 1 - n_in]

        # Chain rule: contract d(out)/d(in) with d(in)/d(primaries).
        out_jac = jnp.einsum(formula, jac, in_jac)

        return from_jacobian(out_mean, out_jac, indices)

    @functools.wraps(func)
    def decorated_func(x):
        if isinstance(x, gvar.GVar):
            result = gvar_function(x)
            # A 0-d result is unwrapped into its single element.
            return result if result.ndim else result.item()
        if getattr(x, 'dtype', None) == object:
            return gvar_function(x)
        return func(x)

    return decorated_func

116 

117 # TODO add more than one argument or output. Possibly without taking 

118 # derivatives when it's not a gvar, i.e., merge the wrappers and cycle over 

119 # args. Also implement excluded => note that jnp.vectorize only supports 

120 # positional arguments, excluded takes in only indices, not names