Coverage for src/lsqfitgp/_linalg/_pytree.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_linalg/_pytree.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import functools 1efabcd

21 

22import numpy 1efabcd

23from jax import numpy as jnp 1efabcd

24from jax import tree_util 1efabcd

25 

class AutoPyTree:
    """
    Base class that makes every subclass a jax pytree node automatically.

    Subclasses are registered with jax when they are defined. Which
    instance attributes become pytree children is decided from their type
    at runtime (see `_jax_vars`); all the other attributes are carried
    along as auxiliary data.
    """

    def _jax_vars(self):
        """ Returns the list of instance attribute names that are treated
        as children of the pytree node: those whose value is a jax array,
        a numpy array, or another `AutoPyTree` """
        child_types = (jnp.ndarray, numpy.ndarray, AutoPyTree)
        names = []
        for name, value in vars(self).items():
            if isinstance(value, child_types):
                names.append(name)
        return names

    def __init_subclass__(cls, **kw):
        """ Registers each newly defined subclass as a jax pytree node """
        super().__init_subclass__(**kw)
        tree_util.register_pytree_node_class(cls)

    # The children are picked dynamically from the attribute types, so the
    # split between children and other attributes is computed only once and
    # then cached. This way the structure is preserved when an object is
    # rebuilt by tree_unflatten with dummies in place of the children, which
    # (per the original author's note) happens in jax.jacfwd.
    @functools.cached_property
    def _aux_data(self):
        """ Pair ``(child_names, other_items)``: the names of the children
        attributes, and the ``(name, value)`` pairs of all the remaining
        instance attributes. Computed on first access, then cached. """
        child_names = self._jax_vars()
        other_items = []
        for name, value in vars(self).items():
            if name in child_names:
                continue
            other_items.append((name, value))
        return child_names, other_items

    def tree_flatten(self):
        """JAX PyTree encoder. See `jax.tree_util.tree_flatten`."""
        child_names = self._aux_data[0]
        children = tuple(getattr(self, name) for name in child_names)
        return children, self._aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """JAX PyTree decoder. See `jax.tree_util.tree_unflatten`."""
        # Build the object without running __init__; the attributes are
        # restored one by one below.
        obj = cls.__new__(cls)
        # Writing to the instance pre-fills the cached property, so the
        # rebuilt object keeps the structure of the one it came from even
        # if the children are placeholders of some other type.
        obj._aux_data = aux_data
        child_names, other_items = aux_data
        # Children first, then the other attributes, in their saved order.
        restored = (*zip(child_names, children), *other_items)
        for name, value in restored:
            setattr(obj, name, value)
        return obj

72 for n, v in other_vars: 1efabcd

73 setattr(self, n, v) 1abcd

74 return self 1efabcd