Coverage for src/lsqfitgp/_Kernel/_util.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_util.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import numbers 1efabcd

21import operator 1efabcd

22 

23import numpy 1efabcd

24import jax 1efabcd

25from jax import numpy as jnp 1efabcd

26from jax import tree_util 1efabcd

27 

28from .. import _array 1efabcd

29 

def is_numerical_scalar(x):
    """
    Tell whether `x` is a numerical scalar.

    Accepted are instances of `numbers.Number` (Python and numpy scalars)
    and 0-dimensional numpy or jax arrays. The dtype of 0-d arrays is not
    inspected.

    `jnp.isscalar` is avoided because it returns False for strongly typed
    0-dim arrays. `jnp.ndim(x) == 0` is avoided because it also accepts
    non-numerical objects.
    """
    if isinstance(x, numbers.Number):
        return True
    is_array = isinstance(x, (numpy.ndarray, jnp.ndarray))
    return is_array and x.ndim == 0

38 

def is_nonnegative_integer_scalar(x):
    """
    Tell whether `x` is a scalar integer known to be >= 0.

    Handles Python/numpy integer scalars, 0-dim numpy integer arrays, and
    0-dim jax integer arrays. For a jax tracer the value is not available,
    so the answer is True only when its dtype is unsigned.
    """
    # Python ints and numpy integer scalars.
    if isinstance(x, numbers.Integral):
        if x >= 0:
            return True

    # 0-dim numpy arrays with an integer dtype.
    if isinstance(x, numpy.ndarray):
        if (x.ndim == 0
                and numpy.issubdtype(x.dtype, numpy.integer)
                and x.item() >= 0):
            return True

    # 0-dim jax arrays with an integer dtype.
    is_jax_int_scalar = (
        isinstance(x, jnp.ndarray)
        and x.ndim == 0
        and jnp.issubdtype(x.dtype, jnp.integer)
    )
    if not is_jax_int_scalar:
        return False

    try:
        # Concrete jax array: compare the actual value.
        verdict = x.item() >= 0
    except jax.errors.ConcretizationTypeError:
        # Abstract tracer: only an unsigned dtype guarantees x >= 0.
        verdict = jnp.issubdtype(x.dtype, jnp.unsignedinteger)
    return verdict

54 

def is_scalar_cond_trueontracer(x, cond):
    """
    Tell whether `x` is a numerical scalar satisfying `cond`.

    `cond` is called on a Python/numpy scalar: either `x` itself, or the
    result of `x.item()` for 0-dim numpy/jax arrays with a numeric dtype.
    When `x` is a jax tracer its value is unknown, and the answer is True.
    Anything else gives False.
    """
    # Python and numpy scalars.
    if isinstance(x, numbers.Number):
        if cond(x):
            return True

    # 0-dim numpy arrays with a numeric dtype.
    if (isinstance(x, numpy.ndarray)
            and x.ndim == 0
            and numpy.issubdtype(x.dtype, numpy.number)):
        if cond(x.item()):
            return True

    # 0-dim jax arrays with a numeric dtype.
    is_jax_num_scalar = (
        isinstance(x, jnp.ndarray)
        and x.ndim == 0
        and jnp.issubdtype(x.dtype, jnp.number)
    )
    if not is_jax_num_scalar:
        return False

    try:
        # Concrete jax array: evaluate the condition on its value.
        verdict = cond(x.item())
    except jax.errors.ConcretizationTypeError:
        # Abstract tracer: optimistically accept.
        verdict = True
    return verdict

70 

def is_nonnegative_scalar_trueontracer(x):
    """
    Tell whether `x` is a numerical scalar >= 0; jax tracers count as True.

    Thin wrapper over `is_scalar_cond_trueontracer` with a `>= 0` check.
    """
    def nonnegative(value):
        return value >= 0
    return is_scalar_cond_trueontracer(x, nonnegative)

73 

74# TODO reimplement with tree_reduce, closuring ndim to recognize shaped fields 

def _reduce_recurse_dtype(fun, args, reductor, npreductor, jnpreductor, **kw):
    """
    Evaluate `fun` on the leaf fields of (possibly structured) arrays and
    combine the results.

    Parameters
    ----------
    fun : callable
        Called as ``fun(*leafargs, **kw)`` on non-structured arrays.
    args : tuple
        The arrays. The dtype of ``args[0]`` drives the recursion, and every
        other element is indexed with the same field names.
    reductor : callable
        Binary function folding per-field results, left to right.
    npreductor, jnpreductor : callable
        Called with ``axis=`` to collapse the trailing subarray axes of a
        field. `jnpreductor` is used when the field result is a jax array,
        otherwise `npreductor`.

    Returns
    -------
    The combined result. Its shape is asserted to match
    ``_array.broadcast(*args).shape``.
    """
    first = args[0]
    fieldnames = first.dtype.names

    # Non-structured dtype: this is a leaf, evaluate directly.
    if fieldnames is None:
        return fun(*args, **kw)

    # NOTE(review): a structured dtype with zero fields leaves `total` as
    # None, and the assertion below then fails with AttributeError.
    total = None
    for fieldname in fieldnames:
        subargs = tuple(a[fieldname] for a in args)
        value = _reduce_recurse_dtype(
            fun, subargs, reductor, npreductor, jnpreductor, **kw)

        # A subarray field contributes extra trailing axes: reduce them away.
        subndim = first.dtype[fieldname].ndim
        if subndim:
            trailing = tuple(range(-subndim, 0))
            if isinstance(value, jnp.ndarray):
                value = jnpreductor(value, axis=trailing)
            else:
                value = npreductor(value, axis=trailing)

        total = value if total is None else reductor(total, value)

    assert total.shape == _array.broadcast(*args).shape
    return total

98 

def sum_recurse_dtype(fun, *args, **kw):
    """
    Evaluate `fun` on every leaf field of `args` and add the results.

    Subarray axes of a field are summed out with `numpy.sum` or `jnp.sum`.
    Per-field results are combined with `+`.
    """
    return _reduce_recurse_dtype(
        fun,
        args,
        reductor=operator.add,
        npreductor=numpy.sum,
        jnpreductor=jnp.sum,
        **kw,
    )

101 

def prod_recurse_dtype(fun, *args, **kw):
    """
    Evaluate `fun` on every leaf field of `args` and multiply the results.

    Subarray axes of a field are multiplied out with `numpy.prod` or
    `jnp.prod`. Per-field results are combined with `*`.
    """
    return _reduce_recurse_dtype(
        fun,
        args,
        reductor=operator.mul,
        npreductor=numpy.prod,
        jnpreductor=jnp.prod,
        **kw,
    )

104 

def ufunc_recurse_dtype(ufunc, x, *args):
    """
    Apply a ufunc to all the leaf fields.

    If `x` has a non-structured dtype, `ufunc(x, *args)` is returned
    directly. Otherwise all operands are wrapped in
    `_array.StructuredArray` and `ufunc` is mapped over their leaves with
    `jax.tree_util.tree_map`. The result's shape is asserted to equal the
    broadcast shape of the operands.
    """
    operands = (x, *args)
    target_shape = jnp.broadcast_shapes(*(op.shape for op in operands))

    if x.dtype.names is not None:
        wrapped = map(_array.StructuredArray, operands)
        result = tree_util.tree_map(ufunc, *wrapped)
    else:
        result = ufunc(*operands)

    assert result.shape == target_shape
    return result