Coverage for src/lsqfitgp/_Kernel/_util.py: 100%
64 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_Kernel/_util.py
2#
3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import numbers 1efabcd
21import operator 1efabcd
23import numpy 1efabcd
24import jax 1efabcd
25from jax import numpy as jnp 1efabcd
26from jax import tree_util 1efabcd
28from .. import _array 1efabcd
def is_numerical_scalar(x):
    """
    Check whether `x` is a scalar number.

    Accepts any `numbers.Number` instance (Python scalars and numpy scalar
    types) and 0-dimensional numpy or jax arrays.

    NOTE(review): 0-dim arrays are accepted whatever their dtype; only
    non-array objects are filtered by type.
    """
    # do not use jnp.isscalar because it returns False for strongly
    # typed 0-dim arrays; do not use jnp.ndim(•) == 0 because it accepts
    # non-numerical types
    if isinstance(x, numbers.Number):
        return True
    if not isinstance(x, (numpy.ndarray, jnp.ndarray)):
        return False
    return x.ndim == 0
def is_nonnegative_integer_scalar(x):
    """
    Check whether `x` is a scalar integer greater than or equal to zero.

    Recognized inputs are `numbers.Integral` instances (Python ints and
    numpy integer scalars) and 0-dim numpy or jax arrays with an integer
    dtype. For jax tracers the value is not available, so the answer is
    True only if the dtype is unsigned.
    """
    # python scalars and numpy scalars
    if isinstance(x, numbers.Integral) and x >= 0:
        return True

    # 0-dim numpy arrays
    numpy_int_0d = (
        isinstance(x, numpy.ndarray)
        and x.ndim == 0
        and numpy.issubdtype(x.dtype, numpy.integer)
    )
    if numpy_int_0d and x.item() >= 0:
        return True

    jax_int_0d = (
        isinstance(x, jnp.ndarray)
        and x.ndim == 0
        and jnp.issubdtype(x.dtype, jnp.integer)
    )
    if not jax_int_0d:
        return False

    try:
        # concrete jax arrays
        return x.item() >= 0
    except jax.errors.ConcretizationTypeError:
        # jax tracers: only an unsigned dtype guarantees non-negativity
        return jnp.issubdtype(x.dtype, jnp.unsignedinteger)
def is_scalar_cond_trueontracer(x, cond):
    """
    Check whether `x` is a numerical scalar satisfying `cond`.

    `cond` is called on a Python/numpy scalar: `x` itself if it is a
    `numbers.Number`, otherwise ``x.item()`` for 0-dim numpy or jax arrays
    with a numeric dtype. If `x` is a jax tracer, its value is unknown and
    the function optimistically returns True. Anything else gives False.
    """
    # python scalars and numpy scalars
    if isinstance(x, numbers.Number) and cond(x):
        return True

    # 0-dim numpy arrays
    numpy_num_0d = (
        isinstance(x, numpy.ndarray)
        and x.ndim == 0
        and numpy.issubdtype(x.dtype, numpy.number)
    )
    if numpy_num_0d and cond(x.item()):
        return True

    jax_num_0d = (
        isinstance(x, jnp.ndarray)
        and x.ndim == 0
        and jnp.issubdtype(x.dtype, jnp.number)
    )
    if not jax_num_0d:
        return False

    try:
        # concrete jax arrays
        return cond(x.item())
    except jax.errors.ConcretizationTypeError:
        # jax tracers: the value can't be inspected, assume the condition holds
        return True
def is_nonnegative_scalar_trueontracer(x):
    """
    Check whether `x` is a numerical scalar >= 0, answering True for jax
    tracers whose value is unknown.

    Wrapper around `is_scalar_cond_trueontracer` with the condition
    ``value >= 0``.
    """
    def nonnegative(value):
        return value >= 0
    return is_scalar_cond_trueontracer(x, nonnegative)
# TODO reimplement with tree_reduce, closuring ndim to recognize shaped fields
def _reduce_recurse_dtype(fun, args, reductor, npreductor, jnpreductor, **kw):
    """
    Evaluate `fun` on the leaf fields of structured arrays and reduce the results.

    Parameters
    ----------
    fun : callable
        Called as ``fun(*leaf_args, **kw)`` once the arrays have no named
        fields.
    args : tuple
        Arrays indexed field by field in lockstep; only the dtype of
        ``args[0]`` is inspected to find the fields.
    reductor : callable
        Binary function combining the results of different fields.
    npreductor, jnpreductor : callable
        Called as ``red(result, axis=...)`` to collapse the subarray axes of
        shaped fields. The jax one is used when the field result is a jax
        array, the numpy one otherwise.
    **kw
        Forwarded to `fun`.

    Returns
    -------
    ``fun(*args, **kw)`` if ``args[0]`` is not structured; otherwise the
    combination of all the leaf results, with the broadcast shape of `args`.

    Raises
    ------
    ValueError
        If ``args[0]`` has a structured dtype with no fields.
    """
    x = args[0]
    names = x.dtype.names
    if names is None:
        return fun(*args, **kw)
    if not names:
        # without this check the accumulator would stay None: the shape
        # assertion would fail obscurely, or with assertions disabled the
        # function would silently return None
        raise ValueError('cannot reduce over a structured dtype with no fields')

    acc = None
    for name in names:
        recargs = tuple(arg[name] for arg in args)
        result = _reduce_recurse_dtype(fun, recargs, reductor, npreductor, jnpreductor, **kw)

        # a shaped field adds trailing subarray axes to the result; collapse
        # them so every field contributes an array of the broadcast shape
        dtype = x.dtype[name]
        if dtype.ndim:
            axis = tuple(range(-dtype.ndim, 0))
            red = jnpreductor if isinstance(result, jnp.ndarray) else npreductor
            result = red(result, axis=axis)

        acc = result if acc is None else reductor(acc, result)

    assert acc.shape == _array.broadcast(*args).shape
    return acc
def sum_recurse_dtype(fun, *args, **kw):
    """
    Sum ``fun(*leaf_args, **kw)`` over all the leaf fields of structured `args`.

    Shaped (subarray) fields are also summed over their subarray axes. If
    the first argument is not structured, this is just ``fun(*args, **kw)``.
    """
    return _reduce_recurse_dtype(
        fun,
        args,
        reductor=operator.add,
        npreductor=numpy.sum,
        jnpreductor=jnp.sum,
        **kw,
    )
def prod_recurse_dtype(fun, *args, **kw):
    """
    Multiply ``fun(*leaf_args, **kw)`` over all the leaf fields of structured `args`.

    Shaped (subarray) fields are also multiplied over their subarray axes.
    If the first argument is not structured, this is just
    ``fun(*args, **kw)``.
    """
    return _reduce_recurse_dtype(
        fun,
        args,
        reductor=operator.mul,
        npreductor=numpy.prod,
        jnpreductor=jnp.prod,
        **kw,
    )
def ufunc_recurse_dtype(ufunc, x, *args):
    """
    Apply `ufunc` to all the leaf fields of the operands.

    If `x` has no named fields, this is ``ufunc(x, *args)``. Otherwise each
    operand is wrapped as a `_array.StructuredArray` and `ufunc` is mapped
    over the matching leaves with `jax.tree_util.tree_map`. The output is
    checked to have the broadcast shape of all the operands.
    """
    operands = (x, *args)
    target_shape = jnp.broadcast_shapes(*[operand.shape for operand in operands])

    if x.dtype.names is not None:
        wrapped = [_array.StructuredArray(operand) for operand in operands]
        out = tree_util.tree_map(ufunc, *wrapped)
    else:
        out = ufunc(*operands)

    assert out.shape == target_shape
    return out