Coverage for src/lsqfitgp/_patch_gvar.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_patch_gvar.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" modifications to the global state of gvar """ 

21 

22import functools 1abcdef

23 

24import gvar 1abcdef

25from jax import numpy as jnp 1abcdef

26from jax.scipy import special as jspecial 1abcdef

27 

# Names of the elementwise math functions looked up in gvar; each one is
# rebound on the gvar module to a wrapper that also accepts jax arrays.
# Order matters only for readability; the resulting object is a plain list.
gvar_ufuncs = """
    sin cos tan
    exp log sqrt fabs
    sinh cosh tanh
    arcsin arccos arctan arctan2
    arcsinh arccosh arctanh
    square erf
""".split()

49 

def _jax_ufunc(name):
    """Return the jax implementation of the function called `name`, or None.

    `jax.numpy` is searched first, then `jax.scipy.special` (e.g. `erf` only
    lives in the latter). Returning None instead of a placeholder object
    lets the caller avoid registering something that is not callable.
    """
    for namespace in (jnp, jspecial):
        func = getattr(namespace, name, None)
        if func is not None:
            return func
    return None

# Replace each listed gvar function with a single-dispatch wrapper: by default
# it calls the original gvar function, but jax arrays are routed to the jax
# implementation. Note that functools.singledispatch chooses the
# implementation from the type of the FIRST positional argument only (for
# arctan2 that is `y`).
for fname in gvar_ufuncs:
    fgvar = getattr(gvar, fname)
    fboth = functools.singledispatch(fgvar)
    fjax = _jax_ufunc(fname)
    if fjax is not None:
        fboth.register(jnp.ndarray, fjax)
    # If no jax counterpart exists, the wrapper simply falls back to the gvar
    # function for every input instead of dispatching to a non-callable.
    setattr(gvar, fname, fboth)

56 

# Re-register gvar's "log", "sqrt" and "erfinv" BufferDict distributions so
# that their inverse transformations are the patched gvar functions, which
# also accept jax arrays. Each name is paired with the inverse of the
# transformation it denotes (exp undoes log, square undoes sqrt, erf undoes
# erfinv). All three are deleted before any is re-added, as in a plain
# sequence of calls; presumably add_distribution would reject a name that is
# still defined — TODO confirm against gvar docs.
_RESET_DISTRIBUTIONS = {
    'log': gvar.exp,
    'sqrt': gvar.square,
    'erfinv': gvar.erf,
}
for _dist_name in _RESET_DISTRIBUTIONS:
    gvar.BufferDict.del_distribution(_dist_name)
for _dist_name, _inverse in _RESET_DISTRIBUTIONS.items():
    gvar.BufferDict.add_distribution(_dist_name, _inverse)
# keep loop temporaries out of the module namespace
del _dist_name, _inverse