Coverage for src/lsqfitgp/_patch_jax.py: 88%
31 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_patch_jax.py
2#
3# Copyright (c) 2022, 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20""" modifications to the global state of jax """
22from jax import config 1fabcde
23from jax import tree_util 1fabcde
24import gvar 1fabcde
25import numpy 1fabcde
27config.update("jax_enable_x64", True) 1fabcde
class BufferDictPyTreeDef:
    """ Static structure (pytree auxiliary data) of a `gvar.BufferDict`.

    A BufferDict is flattened into a single child, its flat buffer `buf`,
    while an instance of this class records everything needed to rebuild the
    dictionary around a new buffer.

    Attributes
    ----------
    skeleton :
        A BufferDict with the same keys/layout as the flattened one, but
        backed by a zero-itemsize dummy buffer.
    layout : dict
        Maps each key to ``tuple(bd.slice_shape(key))``. Equality, hashing
        and repr of the instance are all defined in terms of this mapping.
    """

    @staticmethod
    def _skeleton(bd):
        """ Return a memoryless BufferDict with the same layout as `bd` """
        # The dummy buffer uses the empty structured dtype `[]`, which takes
        # zero bytes per item, so the skeleton carries layout but no data.
        # BufferDict mirrors the data type of the _buf attribute, so we do not
        # need to preserve it to maintain consistency. buf is not copied.
        return gvar.BufferDict(bd, buf=numpy.empty(bd.buf.shape, []))

    @staticmethod
    def _freeze(obj):
        """ Return a hashable stand-in for a layout value.

        Slices (unhashable before Python 3.12) become ``(start, stop, step)``
        tuples, tuples and lists are converted recursively, anything else is
        returned unchanged. Equal inputs always map to equal outputs.
        """
        if isinstance(obj, slice):
            return (obj.start, obj.stop, obj.step)
        if isinstance(obj, (tuple, list)):
            return tuple(BufferDictPyTreeDef._freeze(item) for item in obj)
        return obj

    def __init__(self, bd):
        """ Record the layout of the BufferDict `bd` """
        self.skeleton = self._skeleton(bd)
        # NOTE(review): per the gvar documentation, slice_shape(k) returns the
        # (slice-or-index, shape) pair locating key k in the flat buffer.
        self.layout = {k: tuple(bd.slice_shape(k)) for k in bd.keys()}
        # it is not necessary to save the data type because that's in buf

    def __eq__(self, other):
        """ Two definitions are equal iff their layouts are equal """
        if not isinstance(other, __class__):
            return NotImplemented
        return self.layout == other.layout

    def __hash__(self):
        """ Hash consistent with `__eq__`.

        `layout` is a dict, hence not hashable by itself; hash instead a
        frozenset of its (key, frozen value) pairs. The frozenset makes the
        hash independent of insertion order, matching dict equality.
        """
        return hash(frozenset(
            (key, self._freeze(value)) for key, value in self.layout.items()
        ))

    def __repr__(self):
        """ Show the layout mapping """
        return repr(self.layout)

    @classmethod
    def flatten(cls, bd):
        """ Flatten `bd` into ``((buf,), auxiliary data)``.

        The flat buffer is the only child; the returned auxiliary data is an
        instance of this class built from `bd`.
        """
        return (bd.buf,), cls(bd)

    @classmethod
    def unflatten(cls, self, children):
        """ Rebuild a BufferDict from auxiliary data and its single child.

        Parameters
        ----------
        self : BufferDictPyTreeDef
            The auxiliary data produced by `flatten`.
        children : sequence of length 1
            Contains the object to install as buffer. It is stored as is,
            without validation or copy.
        """
        buf, = children
        # copy the skeleton to permit multiple unflattening
        new = cls._skeleton(self.skeleton)
        # Write the private attributes directly, bypassing BufferDict's own
        # buffer handling, so that whatever leaf was passed in is kept
        # untouched.
        new._extension = {}
        new._buf = buf
        return new
# Teach JAX to treat gvar.BufferDict as a pytree node, using the flatten and
# unflatten functions provided by BufferDictPyTreeDef.
tree_util.register_pytree_node(
    gvar.BufferDict,
    BufferDictPyTreeDef.flatten,
    BufferDictPyTreeDef.unflatten,
)
70# TODO the current implementation of BufferDict as a pytree is not really
71# consistent with how JAX handles trees, because JAX expects to be allowed to
72# put arbitrary objects in the leaves; in particular, internally it sometimes
73# creates dummy trees filled with None. Maybe the current implementation is
74# fine with this: _buf gets set to None and, assuming the BufferDict is never
75# actually used in that inconsistent state, everything works. What this does
76# break is a subsequent flattening of the dummy, but I think JAX never does
77# that. (The reason for switching to buf-as-leaf in place of
78# dict-values-as-leaves is that the latter breaks tracing.)
79#
80# Maybe since BufferDict is simple and stable, I could read its code, bypass its
81# initialization altogether and set all the internal attributes to make it a
82# proper pytree but also compatible with tracing.
84# TODO try to drop BufferDict altogether. Currently I use it only in bcf and
85# bart to pass stuff to a precompiled function. In empbayes_fit it is rebuilt
86# by custom code.