Coverage for src/lsqfitgp/copula/_distr.py: 95%

213 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/copula/_distr.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" define Distr and distribution """ 

21 

22import abc 1feabcd

23import functools 1feabcd

24import collections 1feabcd

25import numbers 1feabcd

26import inspect 1feabcd

27import types 1feabcd

28import math 1feabcd

29 

30import gvar 1feabcd

31import numpy 1feabcd

32import jax 1feabcd

33from jax import numpy as jnp 1feabcd

34 

35from .. import _gvarext 1feabcd

36from .. import _array 1feabcd

37from .. import _signature 1feabcd

38from . import _base 1feabcd

39 

40######### The following 5 functions are adapted from numpy.lib.mixins ######### 

41 

def _disables_array_ufunc(obj):
    """Tell whether `obj` opts out of numpy ufuncs.

    An object opts out by setting ``__array_ufunc__ = None``; objects that do
    not define the attribute at all do not opt out.
    """
    attr = getattr(obj, '__array_ufunc__', NotImplemented)
    return attr is None

45 

def _binary_method(ufunc, name):
    """Build a forward binary special method (e.g. ``__add__``) from a ufunc.

    The returned method computes ``ufunc(self, other)``. It returns
    NotImplemented when `other` disables ufuncs (``__array_ufunc__ = None``),
    so that Python falls back to the reflected method of `other`.
    """
    def func(self, other):
        if not _disables_array_ufunc(other):
            return ufunc(self, other)
        return NotImplemented
    func.__name__ = f'__{name}__'
    return func

54 

def _reflected_binary_method(ufunc, name):
    """Build a reflected binary special method (e.g. ``__radd__``) from a ufunc.

    The returned method computes ``ufunc(other, self)``, i.e. with the
    operands swapped. It returns NotImplemented when `other` disables ufuncs
    (``__array_ufunc__ = None``).
    """
    def func(self, other):
        if not _disables_array_ufunc(other):
            return ufunc(other, self)
        return NotImplemented
    func.__name__ = f'__r{name}__'
    return func

63 

def _numeric_methods(ufunc, name):
    """Return the pair ``(forward, reflected)`` of binary special methods
    that apply `ufunc`, e.g. ``(__add__, __radd__)``."""
    forward = _binary_method(ufunc, name)
    reflected = _reflected_binary_method(ufunc, name)
    return forward, reflected

68 

def _unary_method(ufunc, name):
    """Build a unary special method (e.g. ``__neg__``) that returns
    ``ufunc(self)``."""
    def func(self):
        return ufunc(self)
    func.__name__ = f'__{name}__'
    return func

75 

76############################################################################### 

77 

class Distr(_base.DistrBase):
    r"""

    Abstract base class to represent probability distributions.

    A `Distr` object represents a probability distribution of a variable in
    :math:`\mathbb R^n`, and provides a transformation function from a
    (multivariate) Normal variable to the target random variable.

    The main functionality is defined in `DistrBase`. The additional attributes
    and methods `params`, `signature`, and `invfcn` are not intended for common
    usage.

    Parameters
    ----------
    *params : tuple of scalar, array or Distr
        The parameters of the distribution. If the parameters have leading axes
        other than those required, the distribution is repeated i.i.d.
        over those axes. If a parameter is an instance of `Distr` itself, it
        is a random parameter and its distribution is accounted for.
    shape : int or tuple of int
        The shape of the array of i.i.d. variables to be represented, scalar by
        default. If the variable is multivariate, this shape adds as leading
        axes in the array. This shape broadcasts with the non-core shapes of the
        parameters.
    name : str, optional
        If specified, the distribution is defined for usage with
        `gvar.BufferDict` using `gvar.BufferDict.add_distribution`, and for
        convenience the constructor returns an array of gvars with the
        appropriate shape instead of the `Distr` object. See `add_distribution`.

    Returns
    -------
    If `name` is None (default):

    distr : Distr
        An object representing the distribution.

    Else:

    gvars : array of gvars
        An array of primary gvars that can be set as value in a
        `gvar.BufferDict` under a key that uses the just defined name.

    Attributes
    ----------
    params : tuple
        The parameters as passed to the constructor.
    signature : Signature
        An object representing the signature of `invfcn`. This is a class
        attribute.

    Methods
    -------
    invfcn : classmethod
        Transformation function from a (multivariate) Normal variable to the
        target random variable.

    Examples
    --------

    Use directly with `gvar.BufferDict` by setting `name`:

    >>> copula = gvar.BufferDict({
    ...     'A(x)': lgp.copula.beta(1, 1, name='A'),
    ...     'B(y)': lgp.copula.beta(3, 5, name='B'),
    ... })
    >>> copula['x']
    0.50(40)
    >>> copula['y']
    0.36(18)

    Corresponding "unrolled" usage:

    >>> A = lgp.copula.beta(1, 1)
    >>> B = lgp.copula.beta(3, 5)
    >>> A.add_distribution('A')
    >>> B.add_distribution('B')
    >>> copula = gvar.BufferDict({
    ...     'A(x)': A.gvars(),
    ...     'B(y)': B.gvars(),
    ... })

    Notice that, although the name used for `add_distribution` must be globally
    unique, for convenience it is permitted to redefine the same distribution
    family with the same parameters, even from another `Distr` instance.

    To generate automatically sensible names and avoid repeating them twice, use
    `makedict`:

    >>> lgp.copula.makedict({
    ...     'x': lgp.copula.beta(1, 1),
    ...     'y': lgp.copula.beta(3, 5),
    ... })
    BufferDict({'__copula_beta{1, 1}(x)': 0.0(1.0), '__copula_beta{3, 5}(y)': 0.0(1.0)})

    Define a distribution with a random parameter:

    >>> X = lgp.copula.halfnorm(np.sqrt(lgp.copula.invgamma(1, 1)))
    >>> X
    halfnorm(sqrt(invgamma(1, 1)))

    Now `X` represents the model

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma &\sim \mathrm{HalfNorm}(\sigma).

    In general it is possible to transform a `Distr` with `numpy` ufuncs and
    continuous arithmetic operations.

    Repeated usage of `Distr` instances for random parameters will share
    those parameters in the distributions. The following code:

    >>> sigma2 = lgp.copula.invgamma(1, 1)
    >>> X = lgp.copula.halfnorm(np.sqrt(sigma2))
    >>> Y = lgp.copula.halfcauchy(np.sqrt(sigma2))

    corresponds to the model

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma &\sim \mathrm{HalfNorm}(\sigma), \\
        Y \mid \sigma &\sim \mathrm{HalfCauchy}(\sigma),

    with the same parameter :math:`\sigma^2` shared between the two
    distributions. However, if the distributions are now put into a
    `gvar.BufferDict`, with

    >>> sigma2.add_distribution('distr_sigma2')
    >>> X.add_distribution('distr_X')
    >>> Y.add_distribution('distr_Y')
    >>> bd = gvar.BufferDict({
    ...     'distr_sigma2(sigma2)': sigma2.gvars(),
    ...     'distr_X(X)': X.gvars(),
    ...     'distr_Y(Y)': Y.gvars(),
    ... })

    then this relationship breaks down; the model represented by the dictionary
    `bd` is

    .. math::
        \sigma^2 &\sim \mathrm{InvGamma}(1, 1), \\
        X \mid \sigma_X &\sim \mathrm{HalfNorm}(\sigma_X), \quad
        & \sigma_X^2 &\sim \mathrm{InvGamma}(1, 1), \\
        Y \mid \sigma_Y &\sim \mathrm{HalfCauchy}(\sigma_Y), \quad
        & \sigma_Y^2 &\sim \mathrm{InvGamma}(1, 1),

    with separate, independent parameters :math:`\sigma,\sigma_X,\sigma_Y`,
    because each dictionary entry is evaluated separately. Indeed, trying to do
    this with `makedict` will raise an error:

    >>> bd = lgp.copula.makedict({'sigma2': sigma2, 'X': X, 'Y': Y})
    ValueError: cross-key occurrences of object(s):
    invgamma with id 6201535248: <sigma2>, <X.0.0>, <Y.0.0>

    To use all the distributions at once while preserving the relationships,
    put them into a container of choice and wrap it as a `Copula` object:

    >>> sigmaXY = lgp.copula.Copula({'sigma2': sigma2, 'X': X, 'Y': Y})

    The `Copula` provides a `partial_invfcn` function to map Normal variables
    to a structure, with the same layout as the input one, of desired variates.
    The whole `Copula` can be used in `gvar.BufferDict`:

    >>> bd = lgp.copula.makedict({'sigmaXY': sigmaXY})
    >>> bd
    BufferDict({"__copula_{'sigma2': invgamma{1, 1}, 'X': halfnorm{sqrt{_Path{path=[{DictKey{key='sigma2'},}]}}}, 'Y': halfcauchy{sqrt{_Path{path=[{DictKey{key='sigma2'},}]}}}}(sigmaXY)": array([0.0(1.0), 0.0(1.0), 0.0(1.0)], dtype=object)})
    >>> bd['sigmaXY']
    {'sigma2': 1.4(1.7), 'X': 0.81(89), 'Y': 1.2(1.7)}
    >>> gvar.corr(bd['sigmaXY']['X'], bd['sigmaXY']['Y'])
    0.21950577757757836

    Although the actual dictionary value is a flat array, getting the unwrapped
    key reproduces the original structure.

    To apply arbitrary transformations, use manually `invfcn`:

    >>> @functools.partial(lgp.gvar_gufunc, signature='(n)->(n)')
    ... @functools.partial(jnp.vectorize, signature='(n)->(n)')
    ... def model_invfcn(normal_params):
    ...     sigma2 = lgp.copula.invgamma.invfcn(normal_params[0], 1, 1)
    ...     sigma = jnp.sqrt(sigma2)
    ...     X = lgp.copula.halfnorm.invfcn(normal_params[1], sigma)
    ...     Y = lgp.copula.halfcauchy.invfcn(normal_params[2], sigma)
    ...     return jnp.stack([sigma, X, Y])

    The `jax.numpy.vectorize` decorator makes `model_invfcn` support
    broadcasting on additional input axes, while `gvar_gufunc` makes it accept
    gvars as input.

    See also
    --------
    DistrBase, Copula, gvar.BufferDict.uniform

    Notes
    -----
    Concrete subclasses must define `invfcn`, and define the class attribute
    `signature` to the numpy signature string of `invfcn`, unless `invfcn`
    takes and returns only scalars, in which case the signature is inferred
    from its number of positional parameters. `invfcn` must be vectorized.
    The optional class attribute `dtype`, the dtype of the output of `invfcn`,
    defaults to floating point.

    """

    @classmethod
    @abc.abstractmethod
    def invfcn(cls, x, *params):
        r"""

        Normal to desired distribution transformation.

        Maps a (multivariate) Normal variable to a variable with the desired
        marginal distribution. In symbols: :math:`y = F^{-1}(\Phi(x))`. This
        function is a generalized ufunc, jax traceable, vmappable one time, and
        differentiable one time. The signature is accessible through the
        class attribute `signature`.

        Parameters
        ----------
        x : array_like
            The input Normal variable.
        *params : array_like
            The parameters of the distribution.

        Returns
        -------
        y : array_like
            The output variable with the desired marginal distribution.

        """
        pass

    def _get_x_core_shape(self, *preprocessed_params):
        # Evaluate the signature with the parameters only (no `x`) to obtain
        # the core shape of the Normal input.
        sig = self.signature.eval(None, *preprocessed_params)
        return sig.core_in_shapes[0]

    def _eval_shapes(self, shape):
        """Validate the parameters and compute the shape attributes.

        Sets `_in_shape_1` (full shape of this distribution's own Normal
        input), `distrshape` (core output shape), `shape` (full output shape),
        and, through `_compute_in_shape`, `in_shape` and `_ancestor_count`.

        Raises
        ------
        TypeError
            If the number of parameters does not match `signature`.
        """

        # check number of parameters
        if self.signature.nin != 1 + len(self.params):
            raise TypeError(f'{self.__class__.__name__} distribution has '
                f'{self.signature.nin - 1} parameters, but {len(self.params)} '
                'parameters were passed to the constructor')

        # convert shape to tuple
        if isinstance(shape, numbers.Integral):
            shape = (shape,)
        else:
            shape = tuple(shape)

        # make sure parameters have a shape (Distr parameters already have a
        # `shape` attribute and are passed through as they are)
        array_params = [
            p if hasattr(p, 'shape') else jnp.asarray(p)
            for p in self.params
        ]

        # parse signature of cls.invfcn, using an abstract array for x
        x_core_shape = self._get_x_core_shape(*array_params)
        x = jax.ShapeDtypeStruct(shape + x_core_shape, 'd')
        sig = self.signature.eval(x, *array_params)
        self._in_shape_1 = sig.in_shapes[0]
        self.distrshape, = sig.core_out_shapes
        self.shape, = sig.out_shapes

        self._compute_in_shape()

    def _compute_in_shape(self):
        # The flat Normal input holds the Normal variables of the random
        # parameters followed by this distribution's own ones.
        in_size = math.prod(self._in_shape_1)
        cache = set()
        for p in self.params:
            if isinstance(p, __class__):
                in_size += p._compute_in_size(cache)
        # a single Normal variable is represented by a scalar input
        if in_size == 1:
            self.in_shape = ()
        else:
            self.in_shape = in_size,
        # number of distinct random ancestors, as recorded in `cache` by the
        # base class; checked in `_partial_invfcn`
        self._ancestor_count = len(cache)

    def _compute_in_size(self, cache):
        # The base class returns a non-None value to short-circuit, presumably
        # when `self` was already counted through `cache` (shared parameter).
        if (out := super()._compute_in_size(cache)) is not None:
            return out
        in_size = math.prod(self._in_shape_1)
        for p in self.params:
            if isinstance(p, __class__):
                in_size += p._compute_in_size(cache)
        return in_size

    def _partial_invfcn_internal(self, x, i, cache):
        """Map the flat Normal input `x`, starting at offset `i`, to samples.

        Returns the output of `invfcn` and the offset just past the consumed
        portion of `x`. Computed outputs are stored in `cache` keyed by
        distribution object.
        """
        if (out := super()._partial_invfcn_internal(x, i, cache)) is not None:
            return out

        # first resolve random parameters, consuming their part of `x`
        concrete_params = []
        for p in self.params:

            if isinstance(p, __class__):
                p, i = p._partial_invfcn_internal(x, i, cache)
            else:
                p = jnp.asarray(p)

            concrete_params.append(p)

        # then take this distribution's own Normal variables
        in_size = math.prod(self._in_shape_1)
        assert i + in_size <= x.size
        last = x[i:i + in_size].reshape(self._in_shape_1)

        y = self.invfcn(last, *concrete_params)
        if y.shape != self.shape or y.dtype != self.dtype:
            raise ValueError(f'{self.__class__.__name__}.invfcn returned '
                f'array with shape {y.shape} and dtype {y.dtype}, while '
                f'{self.shape} and {self.dtype} were expected')

        cache[self] = y
        return y, i + in_size

    @functools.cached_property
    def _partial_invfcn(self):
        """The gufunc mapping the flat Normal input to this distribution,
        vectorized with `jnp.vectorize` and wrapped to accept gvars."""

        # determine signature
        def shapestr(shape):
            return ','.join(map(str, shape))
        signature = f'({shapestr(self.in_shape)})->({shapestr(self.shape)})'

        # wrap to support gvars
        @functools.partial(_gvarext.gvar_gufunc, signature=signature)
        # @jax.jit
        @functools.partial(jnp.vectorize, signature=signature)
        def _partial_invfcn(x):
            assert x.shape == self.in_shape
            # the internal routine slices a 1D array
            if not self.in_shape:
                x = x[None]
            cache = {}
            y, i = self._partial_invfcn_internal(x, 0, cache)
            # all the input was consumed, and each distribution evaluated once
            assert i == x.size
            assert len(cache) == 1 + self._ancestor_count
            return y

        return _partial_invfcn

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)

        # check and/or set signature attribute (the gufunc signature of invfcn)
        if not hasattr(cls, 'signature'):
            # infer an all-scalars signature from the positional parameters
            sig = inspect.signature(cls.invfcn)
            if not all(
                p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
                for p in sig.parameters.values()
            ):
                raise ValueError('can not automatically infer signature of '
                    f'{cls.__qualname__}.invfcn')
            cls.signature = ','.join(['()'] * len(sig.parameters)) + '->()'
        if not isinstance(cls.signature, _signature.Signature):
            cls.signature = _signature.Signature(cls.signature)
        cls.signature.check_nargs(cls.invfcn)

        # set dtype to float if not specified
        if getattr(cls, 'dtype', NotImplemented) is NotImplemented:
            cls.dtype = jax.dtypes.canonicalize_dtype(jnp.float64)

        # set __signature__ to take positional parameters from invfcn (minus
        # `x`) and keyword parameters from __new__ (minus `cls`)
        sig = inspect.signature(cls.invfcn)
        pos_params = list(sig.parameters.values())[1:]
        sig = inspect.signature(cls.__new__)
        key_params = [
            p for i, p in enumerate(sig.parameters.values())
            if p.kind in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
            and i > 0
        ]
        cls.__signature__ = inspect.Signature(pos_params + key_params)

    def __new__(cls, *params, shape=(), name=None):

        self = super().__new__(cls)
        self.params = params
        self._eval_shapes(shape)

        if name is None:
            return self
        else:
            self.add_distribution(name)
            return self.gvars()

    class _Descr(collections.namedtuple('Distr', 'family shape params')):
        """ static representation of a Distr object """

        def __repr__(self):
            args = list(map(repr, self.params))
            if len(self.shape) == 1:
                args += [f'shape={self.shape[0]}']
            elif self.shape:
                args += [f'shape={self.shape}']
            arglist = ', '.join(args)
            return f'{self.family.__name__}({arglist})'

    def _compute_staticdescr(self, path, cache):
        if (obj := super()._compute_staticdescr(path, cache)) is not None:
            return obj

        # random parameters are described recursively, others as lists
        params = []
        for i, p in enumerate(self.params):
            if isinstance(p, __class__):
                p = p._compute_staticdescr(path + [i], cache)
            else:
                p = numpy.asarray(p).tolist()
            params.append(p)

        return self._Descr(self.__class__, self.shape, tuple(params))

    def _shapestr(self, shape):
        # format e.g. (2, 3) as '[2,3]' and (3,) as '[3]'; empty for ()
        if shape:
            return (str(shape)
                .replace(',)', ')')
                .replace('(' , '[')
                .replace(')' , ']')
                .replace(' ', '')
            )
        else:
            return ''

    def __repr__(self, path='', cache=None):

        # The base class returns either a finished string or the cache to
        # pass down to the recursive calls.
        out = super().__repr__(path, cache)
        if isinstance(out, str):
            return out
        cache = out

        args = []
        for i, p in enumerate(self.params):

            if isinstance(p, __class__):
                p = p.__repr__('.'.join((path, str(i))).lstrip('.'), cache)
            elif hasattr(p, 'shape'):
                p = f'Array{self._shapestr(p.shape)}'
            else:
                p = repr(p)
            args.append(p)

        if len(self.shape) == 1:
            args += [f'shape={self.shape[0]}']
        elif self.shape:
            args += [f'shape={self.shape}']

        return f'{self.__class__.__name__}({", ".join(args)})'

    def __array_ufunc__(self, ufunc, method, *inputs, **kw):
        # Only plain calls of single-output elementwise ufuncs can be
        # represented: the UFunc subclass declares a '...->()' signature and
        # its output is checked to be an array. Methods (reduce, ...),
        # keyword arguments, gufuncs, and multi-output ufuncs (divmod, modf,
        # frexp, ...) are declined, so numpy raises TypeError.
        if (method != '__call__' or kw or ufunc.signature
                or ufunc.nout != 1):
            # TODO jax 0.4.15 should introduce ufunc methods
            return NotImplemented
        ufunc_class = UFunc.make_subclass(ufunc)
        return ufunc_class(*inputs)

    # TODO make this work with gufuncs. See comment in _signature.py.
    # matmul in particular.

    # continuous binary operations
    __add__, __radd__ = _numeric_methods(numpy.add, 'add')
    __sub__, __rsub__ = _numeric_methods(numpy.subtract, 'sub')
    __mul__, __rmul__ = _numeric_methods(numpy.multiply, 'mul')
    # __matmul__, __rmatmul__ = _numeric_methods(numpy.matmul, 'matmul')
    __truediv__, __rtruediv__ = _numeric_methods(numpy.divide, 'truediv')
    __mod__, __rmod__ = _numeric_methods(numpy.remainder, 'mod')
    # divmod has two outputs, which __array_ufunc__ declines, so these
    # methods raise TypeError; kept for interface compatibility
    __divmod__, __rdivmod__ = _numeric_methods(numpy.divmod, 'divmod')
    __pow__, __rpow__ = _numeric_methods(numpy.power, 'pow')

    # continuous unary operations
    __neg__ = _unary_method(numpy.negative, 'neg')
    __pos__ = _unary_method(numpy.positive, 'pos')
    __abs__ = _unary_method(numpy.absolute, 'abs')

    # TODO add __getitem__ and __array_function__

545 

class UFunc:
    """Mixin for `Distr` subclasses that represent a numpy ufunc applied to
    its operands (some of which are `Distr` instances).

    Concrete classes are generated by `make_subclass`, which combines this
    mixin with `Distr`.
    """

    def __new__(cls, *args):
        # Accept positional operands only: defining __new__ without keyword
        # parameters forbids `shape=` and `name=` from reaching Distr.__new__.
        return super().__new__(cls, *args)

    @classmethod
    def invfcn(cls, x, *args):
        # The Normal input `x` has core shape (0,), so it carries no
        # variables; the result depends only on the operands.
        return cls._ufunc(*args)

    def _get_x_core_shape(self, *_):
        return (0,)

    @classmethod
    @functools.lru_cache(maxsize=None) # functools.cache not available in 3.8
    def make_subclass(cls, ufunc):
        """Return the (cached) Distr subclass that applies `ufunc`, using the
        jax.numpy function with the same name."""
        name = ufunc.__name__
        jax_ufunc = getattr(jnp, name)
        input_cores = ['(0)'] + ['()'] * ufunc.nin
        sig = ','.join(input_cores) + '->()'

        def exec_body(ns):
            ns['_ufunc'] = jax_ufunc
            ns['signature'] = sig

        return types.new_class(name, (__class__, Distr), exec_body=exec_body)

567 

def distribution(invfcn, signature=None, dtype=None):
    r"""

    Decorator to define a distribution from a transformation function.

    Parameters
    ----------
    invfcn : function
        The transformation function from a (multivariate) standard Normal
        variable to the target random variable. The signature must be
        ``invfcn(x, *params)``. It must be jax-traceable. It does not need to
        be vectorized.
    signature : str, optional
        The signature of `invfcn`, as a numpy signature string. If not
        specified, `invfcn` is assumed to take and output scalars.
    dtype : dtype, optional
        The dtype of the output of `invfcn`. If not specified, it is assumed to
        be floating point. It is canonicalized with
        `jax.dtypes.canonicalize_dtype`, like the floating point default.

    Returns
    -------
    cls : Distr
        The new distribution class.

    Examples
    --------

    >>> @lgp.copula.distribution
    ... def uniform(x, a, b):
    ...     return a + (b - a) * jax.scipy.stats.norm.cdf(x)

    >>> @functools.partial(lgp.copula.distribution, signature='(n,m)->(n,n)')
    ... def wishart(x):
    ...     " this parametrization is terrible, do not use "
    ...     return x @ x.T

    """

    def exec_body(ns):
        if signature is not None:
            ns['signature'] = signature
        if dtype is not None:
            # Canonicalize like the float default set by Distr, so the check
            # on the dtype of the output of invfcn compares equivalent dtypes
            # (e.g. float64 requested while jax runs in 32-bit mode).
            ns['dtype'] = jax.dtypes.canonicalize_dtype(dtype)
        # jnp.vectorize turns the core function into the vectorized gufunc
        # that Distr requires; with signature=None it maps scalars.
        ns['invfcn'] = staticmethod(jnp.vectorize(invfcn, signature=signature))

    return types.new_class(invfcn.__name__, (Distr,), exec_body=exec_body)