Coverage for src/lsqfitgp/_patch_jax.py: 88%

31 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_patch_jax.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" modifications to the global state of jax """ 

21 

22from jax import config 1fabcde

23from jax import tree_util 1fabcde

24import gvar 1fabcde

25import numpy 1fabcde

26 

# Switch on JAX's 64-bit mode. This mutates JAX's process-wide
# configuration, which is the purpose of this module.
config.update('jax_enable_x64', True)

28 

class BufferDictPyTreeDef:
    """
    Auxiliary data (treedef) used to register `gvar.BufferDict` as a JAX pytree.

    A BufferDict is flattened into a single leaf, its `buf` array. Everything
    else needed to rebuild it is kept in an instance of this class: a
    memoryless skeleton BufferDict and the per-key layout. JAX stores this
    object in the tree structure, so it must be comparable with ``==`` and
    hashable.
    """

    @staticmethod
    def _skeleton(bd):
        """ Return a memoryless BufferDict with the same layout as `bd` """
        # The empty structured dtype `[]` has zero itemsize, so the placeholder
        # buffer has the right shape but holds no data. BufferDict mirrors the
        # data type of the _buf attribute, so we do not need to preserve it to
        # maintain consistency. buf is not copied.
        return gvar.BufferDict(bd, buf=numpy.empty(bd.buf.shape, []))

    def __init__(self, bd):
        self.skeleton = self._skeleton(bd)
        # key -> tuple(bd.slice_shape(key)); this is what __eq__ compares.
        self.layout = {k: tuple(bd.slice_shape(k)) for k in bd.keys()}
        # it is not necessary to save the data type because that's in buf

    def __eq__(self, other):
        """ Two treedefs are equal when their layouts are equal. """
        if not isinstance(other, __class__):
            return NotImplemented
        return self.layout == other.layout

    def __hash__(self):
        """ Hash consistent with __eq__. """
        # `layout` is a dict and therefore unhashable, so `hash(self.layout)`
        # would always raise TypeError. Hash the set of keys instead. It is
        # always hashable, since the keys already are dict keys. Equal
        # treedefs have equal layouts and hence equal key sets, so the hash
        # agrees with __eq__, which compares dicts without regard to order.
        return hash(frozenset(self.layout))

    def __repr__(self):
        return repr(self.layout)

    @classmethod
    def flatten(cls, bd):
        """
        Flatten BufferDict `bd` into ``((bd.buf,), treedef)``.

        The only child is the buffer array; the treedef records the layout.
        """
        return (bd.buf,), cls(bd)

    @classmethod
    def unflatten(cls, self, children):
        """
        Rebuild a BufferDict from the treedef `self` and its single child.

        Parameters
        ----------
        self : BufferDictPyTreeDef
            The auxiliary data returned by `flatten`.
        children : iterable
            Must contain exactly one element, used as the new buffer.
        """
        buf, = children
        # copy the skeleton to permit multiple unflattening
        new = cls._skeleton(self.skeleton)
        # Reset BufferDict's private `_extension` mapping. Presumably this
        # keeps the copy from sharing it with the skeleton; TODO confirm
        # against gvar.
        new._extension = {}
        # Assign the private buffer directly, bypassing BufferDict
        # initialization, so that JAX can put arbitrary objects here (e.g.
        # tracers or None).
        new._buf = buf
        return new

66 

# Teach JAX to treat gvar.BufferDict as a pytree: its buffer becomes the only
# leaf, and the layout goes into a BufferDictPyTreeDef treedef.
tree_util.register_pytree_node(
    gvar.BufferDict,
    BufferDictPyTreeDef.flatten,
    BufferDictPyTreeDef.unflatten,
)

69 

# TODO: the current implementation of BufferDict as a pytree is not fully
# consistent with how JAX handles trees, because JAX expects to be allowed to
# put arbitrary objects in the leaves; in particular, it internally sometimes
# creates dummy trees filled with None. The current implementation may cope
# with this: _buf gets set to None, and as long as the BufferDict is never
# actually used in that broken state, everything works. What this breaks is a
# subsequent flattening of the dummy, which I believe JAX never does. (The
# reason for switching to buf-as-leaf in place of dict-values-as-leaves is
# that the latter breaks tracing.)

79# 

# Maybe, since BufferDict is simple and stable, I could read its code, bypass
# its initialization altogether, and set all the internal attributes directly,
# making it a proper pytree that is also compatible with tracing.

83 

# TODO: try to drop BufferDict altogether. Currently it is used only in bcf
# and bart to pass data to a precompiled function; in empbayes_fit it is
# rebuilt by custom code.