Coverage for src/bartz/mcmcloop.py: 88%
133 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
1# bartz/src/bartz/mcmcloop.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions that implement the full BART posterior MCMC loop."""
27import functools 1ab
29import jax 1ab
30import numpy 1ab
31from jax import debug, lax, tree 1ab
32from jax import numpy as jnp 1ab
33from jaxtyping import Array, Real 1ab
35from . import grove, jaxext, mcmcstep 1ab
36from .mcmcstep import State 1ab
def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
    """
    Select the state variables saved only in the main trace.

    This is the default `onlymain_extractor` of `run_mcmc`: it picks the
    forest arrays ``leaf_trees``, ``var_trees``, ``split_trees`` and the
    ``offset`` of the state.
    """
    forest = state.forest
    return {
        'leaf_trees': forest.leaf_trees,
        'var_trees': forest.var_trees,
        'split_trees': forest.split_trees,
        'offset': state.offset,
    }
def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
    """
    Select the state variables saved in both the burn-in and the main trace.

    This is the default `both_extractor` of `run_mcmc`: it picks the error
    variance ``sigma2`` and, from the forest, the grow/prune proposal and
    acceptance counters plus ``log_likelihood`` and ``log_trans_prior``.
    """
    forest = state.forest
    return {
        'sigma2': state.sigma2,
        'grow_prop_count': forest.grow_prop_count,
        'grow_acc_count': forest.grow_acc_count,
        'prune_prop_count': forest.prune_prop_count,
        'prune_acc_count': forest.prune_acc_count,
        'log_likelihood': forest.log_likelihood,
        'log_trans_prior': forest.log_trans_prior,
    }
def run_mcmc(
    key,
    bart,
    n_save,
    *,
    n_burn=0,
    n_skip=1,
    inner_loop_length=None,
    allow_overflow=False,
    inner_callback=None,
    outer_callback=None,
    callback_state=None,
    onlymain_extractor=default_onlymain_extractor,
    both_extractor=default_both_extractor,
):
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A key for random number generation.
    bart : State
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save : int
        The number of iterations to save.
    n_burn : int, default 0
        The number of initial iterations which are not saved.
    n_skip : int, default 1
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length : int, optional
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop. If not specified, the outer loop will iterate just once,
        with all iterations done in a single inner loop run. The inner stride is
        unrelated to the stride used for saving the trace. Must be positive.
    allow_overflow : bool, default False
        If `False`, `inner_loop_length` must be a divisor of the total number of
        iterations ``n_burn + n_skip * n_save``. If `True` and
        `inner_loop_length` is not a divisor, some of the MCMC iterations in the
        last outer loop iteration will not be saved to the trace.
    inner_callback : callable, optional
    outer_callback : callable, optional
        Arbitrary functions run during the loop after updating the state.
        `inner_callback` is called after each update, while `outer_callback` is
        called after completing an inner loop. The callbacks are invoked with
        the following arguments, passed by keyword:

        bart : State
            The MCMC state just after updating it.
        burnin : bool
            Whether the last iteration was in the burn-in phase.
        overflow : bool
            Whether the last iteration was in the overflow phase (iterations
            not saved due to `inner_loop_length` not being a divisor of the
            total number of iterations).
        i_total : int
            The index of the last MCMC iteration (0-based).
        i_skip : int
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state : jax pytree
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of `inner_callback` or `outer_callback`.
        n_burn, n_save, n_skip : int
            The corresponding arguments as-is.
        i_outer : int
            The index of the last outer loop iteration (0-based).
        inner_loop_length : int
            The number of MCMC iterations in the inner loop.

        `inner_callback` is called under the jax jit, so the argument values are
        not available at the time the Python code is executed. Use the utilities
        in `jax.debug` to access the values at actual runtime.

        The callbacks must return two values:

        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : jax pytree
            The new state to be passed on the next callback invocation.

        For convenience, if a callback returns `None`, the states are not
        updated.
    callback_state : jax pytree, optional
        The initial state for the callbacks.
    onlymain_extractor : callable, optional
    both_extractor : callable, optional
        Functions that extract the variables to be saved respectively only in
        the main trace and in both traces, given the MCMC state as argument.
        Must return a dict of arrays (the two dicts are merged into the main
        trace), and must be vmappable.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : dict of (n_burn, ...) arrays
        The trace of the burn-in phase, with an additional head index that
        runs over MCMC iterations, containing the fields returned by
        `both_extractor`. With the default extractor these are 'sigma2',
        'grow_prop_count', 'grow_acc_count', 'prune_prop_count',
        'prune_acc_count', 'log_likelihood', 'log_trans_prior'.
    main_trace : dict of (n_save, ...) arrays
        The trace of the main phase, with an additional head index that runs
        over MCMC iterations, containing the fields returned by
        `onlymain_extractor` (by default 'leaf_trees', 'var_trees',
        'split_trees', 'offset'), plus the fields in `burnin_trace`.

    Raises
    ------
    ValueError
        If `inner_loop_length` is not positive, or if it is not a divisor of
        the total number of iterations and `allow_overflow` is `False`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state. If the number
    of updates is zero, the state is returned unchanged with empty traces.
    """

    def empty_trace(length, bart, extractor):
        # broadcast the extracted values along a new leading axis of size
        # `length`; the content is a placeholder overwritten during the loop
        return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)

    trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
    trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        # a single inner loop does all the work; use at least 1 so that the
        # division below is defined when there are no iterations at all
        inner_loop_length = max(n_iters, 1)
    elif inner_loop_length < 1:
        raise ValueError(f'{inner_loop_length=} must be a positive integer')
    n_outer, remainder = divmod(n_iters, inner_loop_length)
    if remainder:
        if not allow_overflow:
            raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')
        # one extra (partially overflowing) inner loop covers the remainder
        n_outer += 1

    carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            inner_callback,
            onlymain_extractor,
            both_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
        )
        if outer_callback is not None:
            bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
            i_total -= 1  # because i_total is updated at the end of the inner loop
            i_skip = _compute_i_skip(i_total, n_burn, n_skip)
            rt = outer_callback(
                bart=bart,
                burnin=i_total < n_burn,
                overflow=i_total >= n_iters,
                i_total=i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                # only rebuild the carry when the callback returned new states
                bart, callback_state = rt
                i_total += 1
                carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)

    bart, _, _, trace_both, trace_onlymain, _ = carry

    # the first n_burn rows of trace_both are the burn-in, the rest the main phase
    burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
    main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
    main_trace.update(trace_onlymain)

    return bart, burnin_trace, main_trace
251def _compute_i_skip(i_total, n_burn, n_skip): 1ab
252 burnin = i_total < n_burn 1ab
253 return jnp.where( 1ab
254 burnin,
255 i_total + 1,
256 (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0),
257 )
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry,
    inner_loop_length,
    inner_callback,
    onlymain_extractor,
    both_extractor,
    n_burn,
    n_save,
    n_skip,
    i_outer,
):
    """
    Run `inner_loop_length` MCMC iterations inside a single jitted `lax.scan`.

    Parameters
    ----------
    carry : tuple
        ``(bart, i_total, key, trace_both, trace_onlymain, callback_state)``
        as built by `run_mcmc`. This argument is donated, so its buffers must
        not be reused by the caller after the call.
    inner_loop_length : int
        Number of iterations to run (static argument).
    inner_callback : callable or None
        Callback invoked after each update, see `run_mcmc` (static argument).
    onlymain_extractor, both_extractor : callable
        Functions extracting the values written to the traces (static).
    n_burn, n_save, n_skip : int
        The corresponding arguments of `run_mcmc` (traced).
    i_outer : int
        Index of the current outer loop iteration (traced).

    Returns
    -------
    carry : tuple
        The updated carry, with ``i_total`` advanced by `inner_loop_length`.
    """

    def loop(carry, _):
        bart, i_total, key, trace_both, trace_onlymain, callback_state = carry

        # the first key popped is carried to the next iteration, the second
        # one drives this update; keep this order to preserve the RNG stream
        keys = jaxext.split(key)
        key = keys.pop()
        bart = mcmcstep.step(keys.pop(), bart)

        burnin = i_total < n_burn
        if inner_callback is not None:
            i_skip = _compute_i_skip(i_total, n_burn, n_skip)
            rt = inner_callback(
                bart=bart,
                burnin=burnin,
                overflow=i_total >= n_burn + n_save * n_skip,
                i_total=i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                bart, callback_state = rt

        # In the main phase, all iterations in a stride of n_skip write the
        # same slot, so the slot keeps the last one. During burn-in, slot 0 of
        # the main-only trace is written and later overwritten by the main
        # phase. In the overflow phase the indices are past the end of the
        # traces and the writes are discarded (see mode='drop' below).
        i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
        i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)

        def update_trace(index, trace, state):
            def assign_at_index(trace_array, state_array):
                if trace_array.size:
                    # out-of-bounds writes happen by design in the overflow
                    # phase: request dropping them explicitly instead of
                    # relying on the unchecked default indexing mode
                    return trace_array.at[index, ...].set(state_array, mode='drop')
                else:
                    # this handles the case where a trace is empty (e.g.,
                    # no burn-in) because jax refuses to index into an array
                    # of length 0
                    return trace_array

            return tree.map(assign_at_index, trace, state)

        trace_onlymain = update_trace(
            i_onlymain, trace_onlymain, onlymain_extractor(bart)
        )
        trace_both = update_trace(i_both, trace_both, both_extractor(bart))

        i_total += 1
        carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry
def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
    """
    Prepare logging callbacks for `run_mcmc`.

    The returned callbacks print a dot every `dot_every_inner` MCMC
    iterations, and a longer progress report every `report_every_outer`
    outer loop iterations.

    Parameters
    ----------
    dot_every_inner : int, default 1
        A dot is printed every `dot_every_inner` MCMC iterations. If `None`,
        no dots are printed.
    report_every_outer : int, default 1
        A report is printed every `report_every_outer` outer loop
        iterations. If `None`, no reports are printed.

    Returns
    -------
    kwargs : dict
        A dictionary with the arguments to pass to `run_mcmc` as keyword
        arguments to set up the callbacks.

    Examples
    --------
    >>> run_mcmc(..., **make_print_callbacks())
    """
    initial_state = {
        'dot_every_inner': dot_every_inner,
        'report_every_outer': report_every_outer,
    }
    return {
        'inner_callback': _print_callback_inner,
        'outer_callback': _print_callback_outer,
        'callback_state': initial_state,
    }
360def _print_callback_inner(*, i_total, callback_state, **_): 1ab
361 dot_every_inner = callback_state['dot_every_inner'] 1ab
362 if dot_every_inner is not None: 362 ↛ exitline 362 didn't return from function '_print_callback_inner' because the condition on line 362 was always true1ab
363 cond = (i_total + 1) % dot_every_inner == 0 1ab
364 debug.callback(_print_dot, cond) 1ab
367def _print_dot(cond): 1ab
368 if cond:
369 print('.', end='', flush=True)
372def _print_callback_outer( 1ab
373 *,
374 bart,
375 burnin,
376 overflow,
377 i_total,
378 n_burn,
379 n_save,
380 n_skip,
381 callback_state,
382 i_outer,
383 inner_loop_length,
384 **_,
385):
386 report_every_outer = callback_state['report_every_outer'] 1ab
387 if report_every_outer is not None: 387 ↛ exitline 387 didn't return from function '_print_callback_outer' because the condition on line 387 was always true1ab
388 dot_every_inner = callback_state['dot_every_inner'] 1ab
389 if dot_every_inner is None: 389 ↛ 390line 389 didn't jump to line 390 because the condition on line 389 was never true1ab
390 newline = False
391 else:
392 newline = dot_every_inner < inner_loop_length 1ab
393 debug.callback( 1ab
394 _print_report,
395 cond=(i_outer + 1) % report_every_outer == 0,
396 newline=newline,
397 burnin=burnin,
398 overflow=overflow,
399 i_total=i_total,
400 n_iters=n_burn + n_save * n_skip,
401 grow_prop_count=bart.forest.grow_prop_count,
402 grow_acc_count=bart.forest.grow_acc_count,
403 prune_prop_count=bart.forest.prune_prop_count,
404 prune_acc_count=bart.forest.prune_acc_count,
405 prop_total=len(bart.forest.leaf_trees),
406 fill=grove.forest_fill(bart.forest.split_trees),
407 )
410def _convert_jax_arrays_in_args(func): 1ab
411 """Remove jax arrays from a function arguments.
413 Converts all jax.Array instances in the arguments to either Python scalars
414 or numpy arrays.
415 """
417 def convert_jax_arrays(pytree): 1ab
418 def convert_jax_arrays(val): 1ab
419 if not isinstance(val, jax.Array): 419 ↛ 420line 419 didn't jump to line 420 because the condition on line 419 was never true1ab
420 return val
421 elif val.shape: 421 ↛ 422line 421 didn't jump to line 422 because the condition on line 421 was never true1ab
422 return numpy.array(val)
423 else:
424 return val.item() 1ab
426 return tree.map(convert_jax_arrays, pytree) 1ab
428 @functools.wraps(func) 1ab
429 def new_func(*args, **kw): 1ab
430 args = convert_jax_arrays(args) 1ab
431 kw = convert_jax_arrays(kw) 1ab
432 return func(*args, **kw) 1ab
434 return new_func 1ab
@_convert_jax_arrays_in_args
# convert all jax arrays in arguments because operations on them could lead to
# deadlock with the main thread
def _print_report(
    *,
    cond,
    newline,
    burnin,
    overflow,
    i_total,
    n_iters,
    grow_prop_count,
    grow_acc_count,
    prune_prop_count,
    prune_acc_count,
    prop_total,
    fill,
):
    """
    Print a one-line progress report of the MCMC, if `cond` is true.

    The line shows the iteration counter, for grow and prune moves the
    proposal count relative to `prop_total` (P) and the acceptance rate (A),
    the forest fill, and a flag for burn-in or overflow iterations. If
    `newline` is true, the line is preceded by a line break.
    """
    if not cond:
        return

    def rate(accepted, proposed):
        # placeholder when no move of this kind was proposed (avoids 0/0)
        if not proposed:
            return ' n/d'
        return f'{accepted / proposed:.0%}'

    lead = '\n' if newline else ''
    grow_frac = grow_prop_count / prop_total
    prune_frac = prune_prop_count / prop_total
    grow_rate = rate(grow_acc_count, grow_prop_count)
    prune_rate = rate(prune_acc_count, prune_prop_count)
    flag = ' (burnin)' if burnin else (' (overflow)' if overflow else '')

    print(
        f'{lead}It {i_total + 1}/{n_iters} '
        f'grow P={grow_frac:.0%} A={grow_rate}, '
        f'prune P={prune_frac:.0%} A={prune_rate}, '
        f'fill={fill:.0%}{flag}'
    )
@jax.jit
def evaluate_trace(trace, X):
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace : dict
        A trace of the BART MCMC, as returned by `run_mcmc`. It must contain
        'leaf_trees', 'var_trees', 'split_trees' and 'offset', each with a
        leading axis over the saved iterations.
    X : array (p, n)
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    y : array (n_trace, n)
        The predictions for each iteration of the MCMC.
    """
    # evaluate each tree separately, batching over the tree arrays (X is
    # shared); 2**29 is the size limit handed to `jaxext.autobatch`
    per_tree = jaxext.autobatch(
        functools.partial(grove.evaluate_forest, sum_trees=False),
        2**29,
        (None, 0, 0, 0),
    )

    def predict_one(carry, sample):
        tree_values = per_tree(
            X, sample['leaf_trees'], sample['var_trees'], sample['split_trees']
        )
        total = jnp.sum(tree_values, axis=0, dtype=jnp.float32)
        return None, sample['offset'] + total

    return lax.scan(predict_one, None, trace)[1]