Coverage for src/lsqfitgp/_GP/_base.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_GP/_base.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21 

22import jax 1efabcd

23from jax import numpy as jnp 1efabcd

24 

25from .. import _jaxext 1efabcd

26from .. import _utils 1efabcd

27 

class GPBase:
    """Base class holding the configuration flags shared by GP objects.

    Parameters
    ----------
    checkfinite : bool, default True
        Stored, converted to ``bool``, as ``self._checkfinite``.
    checklin : bool, default True
        Stored, converted to ``bool``, as ``self._checklin``.
    """

    def __init__(self, *, checkfinite=True, checklin=True):
        self._checkfinite = bool(checkfinite)
        self._checklin = bool(checklin)

    def _clone(self):
        """Return a new instance of the same class carrying the base flags.

        ``__init__`` is bypassed (the instance is created with
        ``object.__new__``), and only ``_checkfinite`` and ``_checklin`` are
        copied here. NOTE(review): presumably subclasses extend this method
        to copy their own state; confirm.
        """
        newself = object.__new__(self.__class__)
        newself._checkfinite = self._checkfinite
        newself._checklin = self._checklin
        return newself

    class _SingletonMeta(type):
        """Metaclass that makes a class print as its bare name."""

        def __repr__(cls):
            return cls.__name__

    class _Singleton(metaclass=_SingletonMeta):
        """Base for marker classes used as values; instantiating them raises."""

        def __new__(cls):
            raise NotImplementedError(f"{cls.__name__} can not be instantiated")

    class DefaultProcess(_Singleton):
        """ Key of the default process. """
        pass

    def _checklinear(self, func, inshapes, elementwise=False):
        """Numerically check that ``func`` is a linear transformation.

        ``func`` is differentiated with ``jax.jvp``, using the same random
        arrays both as primals and as tangents. For a linear function the
        tangent output must coincide with the primal output.

        Parameters
        ----------
        func : callable
            Function taking ``len(inshapes)`` positional array arguments.
        inshapes : sequence of shapes
            Shape of each random input array.
        elementwise : bool, default False
            If True, also check that:

            - the output has the broadcast shape of the inputs;
            - positions set to zero in every input are zero in the output.

        Raises
        ------
        RuntimeError
            If the linearity check or the elementwise check fails.
        """

        # Make input arrays; the fixed seed makes the check reproducible.
        rkey = jax.random.PRNGKey(202206091600)
        inp = []
        for shape in inshapes:
            rkey, subkey = jax.random.split(rkey)
            inp.append(jax.random.normal(subkey, shape))

        # Put zeros into the arrays to check they are preserved. The same
        # random mask is applied to every input.
        # NOTE(review): boolean-mask indexing requires each input to already
        # have the broadcast shape; confirm callers never pass inputs of
        # different (merely broadcastable) shapes.
        if elementwise:
            shape = jnp.broadcast_shapes(*inshapes)
            rkey, subkey = jax.random.split(rkey)
            zeros = jax.random.bernoulli(subkey, 0.5, shape)
            for i, a in enumerate(inp):
                inp[i] = a.at[zeros].set(0)

        # Compute JVP and check it is identical to the function itself.
        # Presumably `skipifabstract` skips the check when values are traced
        # rather than concrete; see `_jaxext` for its exact semantics.
        with _jaxext.skipifabstract():
            out0, out1 = jax.jvp(func, inp, inp)

            # A float0 tangent (non-differentiable, e.g. integer, output)
            # carries no information: the only way the function can be
            # linear is then to return zero.
            tangent_is_float0 = out1.dtype == jax.float0
            if tangent_is_float0:
                cond = jnp.allclose(out0, 0)
            else:
                cond = jnp.allclose(out0, out1)
            if not cond:
                raise RuntimeError('the transformation is not linear')

            # Check that the function is elementwise.
            if elementwise:
                # Shape first: indexing with the mask is only meaningful if
                # the output has the broadcast shape.
                preserved = out0.shape == shape and jnp.allclose(out0[zeros], 0)
                # float0 arrays do not support jnp operations, so the tangent
                # is only checked when it is a numeric array.
                if preserved and not tangent_is_float0:
                    preserved = jnp.allclose(out1[zeros], 0)
                if not preserved:
                    raise RuntimeError('the transformation is not elementwise')

85 

def newself(meth):
    """Decorator: run ``meth`` on a clone of ``self`` and return the clone.

    The decorated method receives a copy obtained with ``self._clone()``.
    Whatever the method returns is discarded; the caller gets the modified
    copy instead, and the original object is left as it was. A "Returns"
    section describing this is appended to the method's docstring.
    """

    def wrapper(self, *args, **kw):
        clone = self._clone()
        meth(clone, *args, **kw)  # return value deliberately dropped
        return clone

    wrapper = functools.wraps(meth)(wrapper)

    # The wrapper's return value is not described by the wrapped method's
    # own docstring, so document it here.
    doctail = """\
    Returns
    -------
    gp : GP
        A new GP object with the applied modifications.
    """
    wrapper.__doc__ = _utils.append_to_docstring(meth.__doc__, doctail)

    return wrapper