Coverage for src/bartz/mcmcloop.py: 97%
69 statements
coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the full BART posterior MCMC loop.
27"""
29import functools 1a
31import jax 1a
32from jax import random 1a
33from jax import debug 1a
34from jax import numpy as jnp 1a
35from jax import lax 1a
37from . import jaxext 1a
38from . import grove 1a
39from . import mcmcstep 1a
@functools.partial(jax.jit, static_argnums=(1, 2, 3, 4))
def run_mcmc(bart, n_burn, n_save, n_skip, callback, key):
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    bart : dict
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`.
    n_burn : int
        The number of initial iterations which are not saved.
    n_save : int
        The number of iterations to save.
    n_skip : int
        The number of iterations to skip between each saved iteration, plus 1.
    callback : callable or None
        An arbitrary function run at each iteration, called with the following
        arguments, passed by keyword:

        bart : dict
            The current MCMC state.
        burnin : bool
            Whether the last iteration was in the burn-in phase.
        i_total : int
            The index of the last iteration (0-based).
        i_skip : int
            The index of the last iteration, counted from the start of the
            burn-in phase during burn-in, and from the last saved iteration
            afterwards.
        n_burn, n_save, n_skip : int
            The corresponding arguments as-is.

        Since this function is called under the jax jit, the values are not
        available at the time the Python code is executed. Use the utilities in
        `jax.debug` to access the values at actual runtime. If None, no
        callback is run.
    key : jax.dtypes.prng_key array
        The key for random number generation.

    Returns
    -------
    bart : dict
        The final MCMC state.
    burnin_trace : dict
        The trace of the burn-in phase, containing the following subset of
        fields from the `bart` dictionary (only those actually present in
        `bart`), with an additional head index that runs over MCMC
        iterations: 'sigma2', 'grow_prop_count', 'grow_acc_count',
        'prune_prop_count', 'prune_acc_count', 'ratios'.
    main_trace : dict
        The trace of the main phase, containing the following subset of fields
        from the `bart` dictionary, with an additional head index that runs
        over MCMC iterations: 'leaf_trees', 'var_trees', 'split_trees', plus
        the fields in `burnin_trace`.

    Notes
    -----
    If `n_burn` (resp. `n_save`) is 0, `burnin_trace` (resp. `main_trace`)
    still has the same keys, with a head axis of length 0.
    """

    tracelist_burnin = (
        'sigma2',
        'grow_prop_count',
        'grow_acc_count',
        'prune_prop_count',
        'prune_acc_count',
        'ratios',
    )
    tracelist_main = tracelist_burnin + ('leaf_trees', 'var_trees', 'split_trees')

    callback_kw = dict(n_burn=n_burn, n_save=n_save, n_skip=n_skip)

    def select(state, tracelist):
        # The part of the state recorded in a trace; fields missing from the
        # state are skipped rather than raising.
        return {name: state[name] for name in tracelist if name in state}

    def inner_loop(carry, _, tracelist, burnin):
        bart, i_total, i_skip, key = carry
        key, subkey = random.split(key)
        bart = mcmcstep.step(bart, subkey)
        if callback is not None:
            callback(bart=bart, burnin=burnin, i_total=i_total, i_skip=i_skip, **callback_kw)
        return (bart, i_total + 1, i_skip + 1, key), select(bart, tracelist)

    def empty_trace(bart, tracelist):
        # Same keys as a trace produced by `lax.scan`, but with a head axis of
        # length 0. Filtering by `tracelist` keeps the structure identical to
        # the non-empty case (previously the whole state was returned).
        selected = select(bart, tracelist)
        return jax.vmap(lambda x: x, in_axes=None, out_axes=0, axis_size=0)(selected)

    if n_burn > 0:
        carry = bart, 0, 0, key
        burnin_loop = functools.partial(inner_loop, tracelist=tracelist_burnin, burnin=True)
        (bart, i_total, _, key), burnin_trace = lax.scan(burnin_loop, carry, None, n_burn)
    else:
        i_total = 0
        burnin_trace = empty_trace(bart, tracelist_burnin)

    def outer_loop(carry, _):
        bart, i_total, key = carry
        # Run `n_skip` steps without recording anything, then save one sample;
        # `i_skip` restarts from 0 after each saved iteration.
        main_loop = functools.partial(inner_loop, tracelist=(), burnin=False)
        inner_carry = bart, i_total, 0, key
        (bart, i_total, _, key), _ = lax.scan(main_loop, inner_carry, None, n_skip)
        return (bart, i_total, key), select(bart, tracelist_main)

    if n_save > 0:
        carry = bart, i_total, key
        (bart, _, _), main_trace = lax.scan(outer_loop, carry, None, n_save)
    else:
        main_trace = empty_trace(bart, tracelist_main)

    return bart, burnin_trace, main_trace
# Cached so that repeated calls with the same `printevery` return the very same
# function object; `run_mcmc` takes the callback as a static jit argument, so a
# stable identity avoids recompilation.
@functools.lru_cache
def make_simple_print_callback(printevery):
    """
    Build a callback for `run_mcmc` that periodically logs MCMC statistics.

    Parameters
    ----------
    printevery : int
        A line is printed whenever the 1-based iteration counter is a
        multiple of this number.

    Returns
    -------
    callback : callable
        A keyword-only function with the signature expected by `run_mcmc`.
        For each iteration it computes, from the current state:

        - the grow / prune proposal counts divided by the number of trees
          (``len(bart['leaf_trees'])``);
        - the grow / prune acceptance counts divided by the corresponding
          proposal counts (presumably nan when there were no proposals —
          the division is not guarded);
        - the total number of iterations, ``n_burn + n_save * n_skip``.

        It forwards them to `jax.debug.callback`, which does the printing at
        runtime.
    """
    def callback(*, bart, burnin, i_total, i_skip, n_burn, n_save, n_skip):
        num_trees = len(bart['leaf_trees'])
        grow_proposed = bart['grow_prop_count']
        prune_proposed = bart['prune_prop_count']
        should_print = (i_total + 1) % printevery == 0
        debug.callback(
            _simple_print_callback,
            burnin,
            i_total,
            n_burn + n_save * n_skip,
            grow_proposed / num_trees,
            bart['grow_acc_count'] / grow_proposed,
            prune_proposed / num_trees,
            bart['prune_acc_count'] / prune_proposed,
            should_print,
        )

    return callback
def _simple_print_callback(burnin, i_total, n_total, grow_prop, grow_acc, prune_prop, prune_acc, printcond):
    """
    Print one progress line for the MCMC, if `printcond` is true.

    Invoked through `jax.debug.callback` by the callback built in
    `make_simple_print_callback`. The iteration counter is shown 1-based and
    right-aligned to the width of `n_total`. Rates are shown with two decimals,
    and a " (burnin)" suffix is appended during burn-in.
    """
    if not printcond:
        return
    total_str = str(n_total)
    counter = str(i_total + 1).rjust(len(total_str))
    if burnin:
        suffix = ' (burnin)'
    else:
        suffix = ''
    rates = (
        f'P_grow={grow_prop:.2f} P_prune={prune_prop:.2f} '
        f'A_grow={grow_acc:.2f} A_prune={prune_acc:.2f}'
    )
    print(f'Iteration {counter}/{total_str} {rates}{suffix}')
@jax.jit
def evaluate_trace(trace, X):
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace : dict
        A trace of the BART MCMC, as returned by `run_mcmc`. It must contain
        the keys 'leaf_trees', 'var_trees' and 'split_trees', each with a
        leading axis over saved iterations.
    X : array (p, n)
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    y : array (n_trace, n)
        The predictions for each iteration of the MCMC, i.e., the sum over
        trees of the per-tree values, accumulated in float32.
    """
    # Per-tree evaluation, wrapped by `jaxext.autobatch` with limit 2 ** 29 and
    # axes (None, 0, 0, 0): X is not batched, the tree arrays are. Presumably
    # this bounds the size of intermediate arrays — see `jaxext.autobatch`.
    per_tree = jaxext.autobatch(
        functools.partial(grove.evaluate_forest, sum_trees=False),
        2 ** 29,
        (None, 0, 0, 0),
    )

    def body(carry, forest):
        contributions = per_tree(
            X, forest['leaf_trees'], forest['var_trees'], forest['split_trees'],
        )
        return carry, jnp.sum(contributions, axis=0, dtype=jnp.float32)

    # Scan sequentially over saved iterations, one forest at a time.
    _, predictions = lax.scan(body, None, trace)
    return predictions