Coverage for src/lsqfitgp/_linalg/_pytree.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_linalg/_pytree.py
2#
3# Copyright (c) 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import functools 1efabcd
22import numpy 1efabcd
23from jax import numpy as jnp 1efabcd
24from jax import tree_util 1efabcd
class AutoPyTree:
    """
    Class adding automatic recursive support for jax pytree flattening.

    Every subclass is registered as a jax pytree node when it is defined.
    The instance attributes that hold jax arrays, numpy arrays or other
    `AutoPyTree` objects become the children of the node; all the other
    instance attributes are carried along as auxiliary data.
    """

    def _jax_vars(self):
        """ Returns list of object attribute names which are to be considered
        children of the PyTree node """
        return [
            n for n, v in vars(self).items()
            if isinstance(v, (jnp.ndarray, numpy.ndarray, AutoPyTree))
        ]

    def __init_subclass__(cls, **kw):
        """Register each new subclass as a jax pytree node class."""
        super().__init_subclass__(**kw)
        tree_util.register_pytree_node_class(cls)

    # Since I decide dynamically which members are children based on their
    # type, I have to cache the jax pytree structure aux_data such that the
    # structure is preserved when constructing an object with tree_unflatten
    # with dummies as children. This happens in jax.jacfwd for some reason.
    @functools.cached_property
    def _aux_data(self):
        """
        Auxiliary data of the pytree node, computed once per instance.

        Returns
        -------
        jax_vars : tuple of str
            Names of the attributes treated as children.
        other_vars : tuple of (str, object) pairs
            Names and values of all the remaining instance attributes.

        Both items are tuples rather than lists because jax hashes and
        compares the auxiliary data of a node; a list would make the tree
        structure unhashable. The result is hashable only as long as the
        non-child attribute values themselves are.
        """
        jax_vars = tuple(self._jax_vars())
        other_vars = tuple(
            (n, v) for n, v in vars(self).items()
            if n not in jax_vars
        )
        return jax_vars, other_vars

    def tree_flatten(self):
        """JAX PyTree encoder. See `jax.tree_util.tree_flatten`."""
        aux_data = self._aux_data
        jax_vars, _ = aux_data
        children = tuple(getattr(self, n) for n in jax_vars)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """JAX PyTree decoder. See `jax.tree_util.tree_unflatten`."""
        # Bypass __init__: the object is rebuilt purely from the flattened
        # representation.
        self = cls.__new__(cls)
        # Store aux_data in the instance dict so that the cached property is
        # not recomputed from the types of the (possibly dummy) children.
        self._aux_data = aux_data
        jax_vars, other_vars = aux_data
        for n, v in zip(jax_vars, children):
            setattr(self, n, v)
        for n, v in other_vars:
            setattr(self, n, v)
        return self