Coverage for src/lsqfitgp/copula/_base.py: 96%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_base.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" define DistrBase """ 

21 

22import abc 1feabcd

23import functools 1feabcd

24import collections 1feabcd

25 

26import gvar 1feabcd

27import numpy 1feabcd

28 

class DistrBase(metaclass=abc.ABCMeta):
    r"""

    Abstract base class to represent (trees of) probability distributions.

    Attributes
    ----------
    in_shape : tuple of int
        The core shape of the input array to `partial_invfcn`.
    shape : (tree of) tuple of int
        The core shape of the output array of `partial_invfcn`.
    dtype : (tree of) dtype
        The dtype of the output array of `partial_invfcn`.
    distrshape : (tree of) tuple of int
        The sub-core shape of the output array of `partial_invfcn` that
        represents the atomic shape of the distribution.

    Methods
    -------
    partial_invfcn :
        Transformation function from a (multivariate) Normal variable to the
        target random variable.
    add_distribution :
        Register the distribution for usage with `gvar.BufferDict`.
    gvars :
        Return an array of gvars with the appropriate shape for usage with
        `gvar.BufferDict`.

    See also
    --------
    Distr, Copula

    """

    def __init_subclass__(cls, **kw):
        """ give each subclass its own registry of named distributions """
        super().__init_subclass__(**kw)
        # maps a registered name to the static description of the instance
        # that registered it, see `add_distribution`
        cls._named = {}

    # placeholders, concrete values are expected to be set elsewhere
    in_shape = NotImplemented
    shape = NotImplemented
    dtype = NotImplemented
    distrshape = NotImplemented

    def partial_invfcn(self, x):
        """

        Map independent Normal variables to the desired distribution.

        This function is a generalized ufunc. It is jax traceable and
        differentiable one time. It supports arrays of gvars as input. The
        attributes `in_shape` and `shape` give the core shapes.

        Parameters
        ----------
        x : ``(..., *in_shape)`` array
            An array of values representing draws of i.i.d. Normal variates.

        Returns
        -------
        y : (tree of) ``(..., *shape)`` array
            An array of values representing draws of the desired distribution.

        """
        # public entry point, the actual work is delegated to the subclass
        return self._partial_invfcn(x)

    @abc.abstractmethod
    def _partial_invfcn(self, x):
        """ implementation of `partial_invfcn`, to be provided by subclasses """
        pass

    def _is_same_family(self, invfcn):
        """ check if `invfcn` is a method bound to an instance of exactly the
        same class as self """
        # objects without `__self__` (e.g. plain functions) resolve to
        # NoneType, which never matches
        owner = getattr(invfcn, '__self__', None)
        return owner.__class__ is self.__class__

    def add_distribution(self, name):
        """

        Register the distribution for usage with `gvar.BufferDict`.

        Parameters
        ----------
        name : str
            The name to use for the distribution. It must be globally unique,
            and it should not contain parentheses. To redefine a distribution
            with the same name, use `gvar.BufferDict.del_distribution` first.
            However, it is allowed to reuse the name if the distribution family,
            shape and parameters are identical to those used for the existing
            definition.

        Raises
        ------
        ValueError
            If `name` is already registered by a different family, or by the
            same family with a definition that does not match this one
            (including the case where no definition was recorded for it).

        See also
        --------
        gvar.BufferDict.add_distribution, gvar.BufferDict.del_distribution

        """

        if not gvar.BufferDict.has_distribution(name):
            gvar.BufferDict.add_distribution(name, self.partial_invfcn)
            self._named[name] = self._staticdescr
            return

        invfcn = gvar.BufferDict.invfcn[name]
        if not self._is_same_family(invfcn):
            raise ValueError(f'distribution {name} already defined')

        # The name may be known to gvar without being recorded in
        # cls._named, e.g., if the inverse function was registered directly
        # through gvar. In that case the definitions can not be compared, so
        # treat it like a mismatch instead of leaking a KeyError.
        missing = object()
        existing = self._named.get(name, missing)
        if existing is missing or existing != self._staticdescr:
            raise ValueError('Attempt to overwrite existing'
                f' {self.__class__.__name__} distribution with name {name}')
        # cls._named is not updated by
        # gvar.BufferDict.del_distribution, but it is not a problem

    def gvars(self):
        """

        Return an array of gvars intended as value in a `gvar.BufferDict`.

        Returns
        -------
        gvars : array of gvars
            An array of i.i.d. standard Normal primary gvars with shape
            `in_shape`.

        """

        mean = numpy.zeros(self.in_shape)
        sdev = numpy.ones(self.in_shape)
        return gvar.gvar(mean, sdev)

    @abc.abstractmethod
    def __repr__(self, path='', cache=None):
        """ produce a representation where no object appears more than once,
        later appearances are replaced by a user-friendly identifier

        Parameters
        ----------
        path : str
            Identifier of the position of self, stored as ``<path>``.
        cache : dict, optional
            Maps already visited objects to their identifier. A new empty
            dictionary is used if not specified.

        Returns
        -------
        out : str or dict
            The stored identifier if self was already in `cache`, otherwise
            `cache` itself after recording self in it.
        """
        if cache is None:
            cache = {}
        if self in cache:
            return cache[self]
        cache[self] = f'<{path}>'
        return cache

    # wrapper type to tell apart a back-reference path from other content
    # in a static description
    class _Path(collections.namedtuple('Path', ['path'])): pass

    @abc.abstractmethod
    def _compute_staticdescr(self, path, cache):
        """ compute static description of self, can be compared

        If self is already in `cache`, return the stored entry. Otherwise
        record ``_Path(path)`` for self in `cache` and return None.
        """
        if self in cache:
            return cache[self]
        cache[self] = self._Path(path)

    @functools.cached_property
    def _staticdescr(self):
        """ static description of self, computed once starting from an empty
        path and an empty cache """
        return self._compute_staticdescr([], {})

    @abc.abstractmethod
    def _compute_in_size(self, cache):
        """ compute input size to partial_invfcn, without double counting

        Return 0 if self is already in `cache`, otherwise add self to
        `cache` and return None.
        """
        if self in cache:
            return 0
        cache.add(self)

    @abc.abstractmethod
    def _partial_invfcn_internal(self, x, i, cache):
        """ common preamble of the internal transformation

        Require `x` to be 1-dimensional. If self is already in `cache`,
        return the stored value together with `i` unchanged, otherwise
        return None.
        """
        assert x.ndim == 1
        if self in cache:
            return cache[self], i

187 return cache[self], i 1abcd