Coverage for src/bartz/mcmcloop.py: 97%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 18:54 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25""" 

26Functions that implement the full BART posterior MCMC loop. 

27""" 

28 

29import functools 1a

30 

31import jax 1a

32from jax import random 1a

33from jax import debug 1a

34from jax import numpy as jnp 1a

35from jax import lax 1a

36 

37from . import jaxext 1a

38from . import grove 1a

39from . import mcmcstep 1a

40 

@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    bart : dict
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`.
    n_burn : int
        The number of initial iterations which are not saved.
    n_save : int
        The number of iterations to save.
    n_skip : int
        The number of iterations to skip between each saved iteration, plus 1.
    callback : callable
        An arbitrary function run at each iteration, called with the following
        arguments, passed by keyword:

        bart : dict
            The current MCMC state.
        burnin : bool
            Whether the last iteration was in the burn-in phase.
        i_total : int
            The index of the last iteration (0-based).
        i_skip : int
            The index of the last iteration, starting from the last saved
            iteration.
        n_burn, n_save, n_skip : int
            The corresponding arguments as-is.

        Since this function is called under the jax jit, the values are not
        available at the time the Python code is executed. Use the utilities in
        `jax.debug` to access the values at actual runtime.

        The callback is a static argument of the jitted function, so it must
        be hashable.
    key : jax.dtypes.prng_key array
        The key for random number generation.

    Returns
    -------
    bart : dict
        The final MCMC state.
    burnin_trace : dict
        The trace of the burn-in phase, containing the following subset of
        fields from the `bart` dictionary, with an additional head index that
        runs over MCMC iterations: 'sigma2', 'grow_prop_count',
        'grow_acc_count', 'prune_prop_count', 'prune_acc_count', 'ratios'.
        Fields that are not present in `bart` are left out.
    main_trace : dict
        The trace of the main phase, containing the following subset of fields
        from the `bart` dictionary, with an additional head index that runs
        over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
        the fields in `burnin_trace`.

    Notes
    -----
    If `n_burn` or `n_save` is 0, the corresponding trace has the same fields
    as usual, with a head axis of length 0.
    """

    tracelist_burnin = (
        'sigma2',
        'grow_prop_count',
        'grow_acc_count',
        'prune_prop_count',
        'prune_acc_count',
        'ratios',
    )
    tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')

    callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)

    def extract(bart, tracelist):
        # keep only the traced fields that the state actually provides
        return {name: bart[name] for name in tracelist if name in bart}

    def inner_loop(carry, _, tracelist, burnin):
        bart, i_total, i_skip, key = carry
        key, subkey = random.split(key)
        bart = mcmcstep.step(bart, subkey)
        callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
        return (bart, i_total + 1, i_skip + 1, key), extract(bart, tracelist)

    def empty_trace(bart, tracelist):
        # restrict to the traced fields first, so that an empty trace has the
        # same structure as a trace produced by actually running the phase
        subset = extract(bart, tracelist)
        return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(subset)

    if n_burn > 0:
        carry = bart, 0, 0, key
        burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
        (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
    else:
        i_total = 0
        burnin_trace = empty_trace(bart, tracelist_burnin)

    # the skipped iterations do not emit anything, only the state reached at
    # the end of each group of n_skip iterations is saved
    main_loop = functools.partial(inner_loop, tracelist=(), burnin=False)

    def outer_loop(carry, _):
        bart, i_total, key = carry
        inner_carry = bart, i_total, 0, key
        (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
        return (bart, i_total, key), extract(bart, tracelist_main)

    if n_save > 0:
        carry = bart, i_total, key
        (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
    else:
        main_trace = empty_trace(bart, tracelist_main)

    return bart, burnin_trace, main_trace

135 

# The factory is cached so that the same `printevery` always yields the very
# same function object, which lets the jit of run_mcmc recognize the callback.
@functools.lru_cache
def make_simple_print_callback(printevery):
    """
    Build a callback that periodically logs the progress of the MCMC.

    Parameters
    ----------
    printevery : int
        The number of iterations between each log.

    Returns
    -------
    callback : callable
        A function in the format required by `run_mcmc`.
    """
    def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
        # the leading axis length of 'leaf_trees' is used as the number of
        # move proposals available per iteration
        num_trees = len(bart['leaf_trees'])
        n_grow = bart['grow_prop_count']
        n_prune = bart['prune_prop_count']
        # fraction of trees where each kind of move was proposed, and
        # fraction of those proposals that were accepted
        p_grow, p_prune = n_grow / num_trees, n_prune / num_trees
        a_grow = bart['grow_acc_count'] / n_grow
        a_prune = bart['prune_acc_count'] / n_prune
        total_iterations = n_burn + n_save * n_skip
        should_print = (i_total + 1) % printevery == 0
        # the values are only known at runtime under jit, so the printing is
        # delegated to a host function through jax.debug
        debug.callback(
            _simple_print_callback,
            burnin,
            i_total,
            total_iterations,
            p_grow,
            a_grow,
            p_prune,
            a_prune,
            should_print,
        )

    return callback

163 

def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
    """
    Print one progress line, unless `printcond` is false.

    The line reports the 1-based iteration number out of `n_total`, the
    grow/prune proposal fractions and acceptance fractions with two decimals,
    and a ' (burnin)' suffix when `burnin` is true.
    """
    if not printcond:
        return
    denominator = str(n_total)
    # right-align the counter to the width of the total to keep lines aligned
    numerator = str(i_total + 1).rjust(len(denominator))
    suffix = ' (burnin)' if burnin else ''
    print(f'Iteration {numerator}/{denominator} '
          f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
          f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}{suffix}')

173 

@jax.jit
def evaluate_trace(trace, X):
    """
    Evaluate the sum of trees at `X` for every iteration stored in a trace.

    Parameters
    ----------
    trace : dict
        A trace of the BART MCMC, as returned by `run_mcmc`. The fields
        'leaf_trees', 'var_trees' and 'split_trees' are read.
    X : array (p, n)
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    y : array (n_trace, n)
        The predictions for each iteration of the MCMC.
    """
    per_tree = functools.partial(grove.evaluate_forest, sum_trees=False)
    # NOTE(review): 2 ** 29 is presumably a memory budget for the batching
    # done by jaxext.autobatch -- confirm against its documentation
    batched = jaxext.autobatch(per_tree, 2 ** 29, (None, 0, 0, 0))

    def step(carry, state):
        per_tree_values = batched(
            X,
            state['leaf_trees'],
            state['var_trees'],
            state['split_trees'],
        )
        # add up the contributions along the first axis, in float32
        prediction = jnp.sum(per_tree_values, axis=0, dtype=jnp.float32)
        return carry, prediction

    # scan over the head (iteration) axis of the trace, there is no carry
    _, y = lax.scan(step, None, trace)
    return y

197 return y 1a