Coverage for src/bartz/mcmcloop.py: 88%

133 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1# bartz/src/bartz/mcmcloop.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions that implement the full BART posterior MCMC loop.""" 

26 

27import functools 1ab

28 

29import jax 1ab

30import numpy 1ab

31from jax import debug, lax, tree 1ab

32from jax import numpy as jnp 1ab

33from jaxtyping import Array, Real 1ab

34 

35from . import grove, jaxext, mcmcstep 1ab

36from .mcmcstep import State 1ab

37 

38 

def default_onlymain_extractor(state: State) -> dict[str, Real[Array, 'samples *']]:
    """
    Pick the state variables saved only in the main trace of `run_mcmc`.

    Returns the tree arrays of the forest and the offset, keyed by name.
    """
    forest = state.forest
    return {
        'leaf_trees': forest.leaf_trees,
        'var_trees': forest.var_trees,
        'split_trees': forest.split_trees,
        'offset': state.offset,
    }

47 

48 

def default_both_extractor(state: State) -> dict[str, Real[Array, 'samples *'] | None]:
    """
    Pick the state variables saved in both the burn-in and main traces.

    Returns the error variance together with the forest's proposal/acceptance
    counters and its log-likelihood / log-transition-prior fields, keyed by
    name, for use by `run_mcmc`.
    """
    forest = state.forest
    return {
        'sigma2': state.sigma2,
        'grow_prop_count': forest.grow_prop_count,
        'grow_acc_count': forest.grow_acc_count,
        'prune_prop_count': forest.prune_prop_count,
        'prune_acc_count': forest.prune_acc_count,
        'log_likelihood': forest.log_likelihood,
        'log_trans_prior': forest.log_trans_prior,
    }

60 

61 

def run_mcmc(
    key,
    bart,
    n_save,
    *,
    n_burn=0,
    n_skip=1,
    inner_loop_length=None,
    allow_overflow=False,
    inner_callback=None,
    outer_callback=None,
    callback_state=None,
    onlymain_extractor=default_onlymain_extractor,
    both_extractor=default_both_extractor,
):
    """
    Run the MCMC for the BART posterior.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A key for random number generation.
    bart : State
        The initial MCMC state, as created and updated by the functions in
        `bartz.mcmcstep`. The MCMC loop uses buffer donation to avoid copies,
        so this variable is invalidated after running `run_mcmc`. Make a copy
        beforehand to use it again.
    n_save : int
        The number of iterations to save.
    n_burn : int, default 0
        The number of initial iterations which are not saved.
    n_skip : int, default 1
        The number of iterations to skip between each saved iteration, plus 1.
        The effective burn-in is ``n_burn + n_skip - 1``.
    inner_loop_length : int, optional
        The MCMC loop is split into an outer and an inner loop. The outer loop
        is in Python, while the inner loop is in JAX. `inner_loop_length` is the
        number of iterations of the inner loop to run for each iteration of the
        outer loop, and must be positive. If not specified, the outer loop will
        iterate just once, with all iterations done in a single inner loop run
        (or not at all, if there are no iterations to do). The inner stride is
        unrelated to the stride used for saving the trace.
    allow_overflow : bool, default False
        If `False`, `inner_loop_length` must be a divisor of the total number of
        iterations ``n_burn + n_skip * n_save``. If `True` and
        `inner_loop_length` is not a divisor, some of the MCMC iterations in the
        last outer loop iteration will not be saved to the trace.
    inner_callback : callable, optional
    outer_callback : callable, optional
        Arbitrary functions run during the loop after updating the state.
        `inner_callback` is called after each update, while `outer_callback` is
        called after completing an inner loop. The callbacks are invoked with
        the following arguments, passed by keyword:

        bart : State
            The MCMC state just after updating it.
        burnin : bool
            Whether the last iteration was in the burn-in phase.
        overflow : bool
            Whether the last iteration was in the overflow phase (iterations
            not saved due to `inner_loop_length` not being a divisor of the
            total number of iterations).
        i_total : int
            The index of the last MCMC iteration (0-based).
        i_skip : int
            The number of MCMC updates from the last saved state. The initial
            state counts as saved, even if it's not copied into the trace.
        callback_state : jax pytree
            The callback state, initially set to the argument passed to
            `run_mcmc`, afterwards to the value returned by the last invocation
            of `inner_callback` or `outer_callback`.
        n_burn, n_save, n_skip : int
            The corresponding arguments as-is.
        i_outer : int
            The index of the last outer loop iteration (0-based).
        inner_loop_length : int
            The number of MCMC iterations in the inner loop.

        `inner_callback` is called under the jax jit, so the argument values are
        not available at the time the Python code is executed. Use the utilities
        in `jax.debug` to access the values at actual runtime.

        The callbacks must return two values:

        bart : State
            A possibly modified MCMC state. To avoid modifying the state,
            return the `bart` argument passed to the callback as-is.
        callback_state : jax pytree
            The new state to be passed on the next callback invocation.

        For convenience, if a callback returns `None`, the states are not
        updated.
    callback_state : jax pytree, optional
        The initial state for the callbacks.
    onlymain_extractor : callable, optional
    both_extractor : callable, optional
        Functions that extract the variables to be saved respectively only in
        the main trace and in both traces, given the MCMC state as argument.
        Must return a dict of arrays, and must be vmappable. The two dicts are
        merged to form `main_trace`, so they should not share keys.

    Returns
    -------
    bart : State
        The final MCMC state.
    burnin_trace : dict of (n_burn, ...) arrays
        The trace of the burn-in phase: the fields returned by
        `both_extractor`, each with an additional head index that runs over
        MCMC iterations. With the default extractor these are 'sigma2',
        'grow_prop_count', 'grow_acc_count', 'prune_prop_count',
        'prune_acc_count', 'log_likelihood' and 'log_trans_prior'.
    main_trace : dict of (n_save, ...) arrays
        The trace of the main phase: the fields returned by
        `onlymain_extractor` (by default 'leaf_trees', 'var_trees',
        'split_trees' and 'offset'), plus the same fields as `burnin_trace`,
        each with an additional head index that runs over saved iterations.

    Raises
    ------
    ValueError
        If `inner_loop_length` is not positive, or if it is not a divisor of
        the total number of iterations and `allow_overflow` is `False`.

    Notes
    -----
    The number of MCMC updates is ``n_burn + n_skip * n_save``. The traces do
    not include the initial state, and include the final state. With
    overflow, the extra iterations run after the last saved one are not
    recorded, but they are included in the returned `bart`.
    """

    def empty_trace(length, bart, extractor):
        # broadcast the extracted variables along a new leading axis of the
        # requested length; the values are placeholders that get overwritten
        return jax.vmap(extractor, in_axes=None, out_axes=0, axis_size=length)(bart)

    trace_both = empty_trace(n_burn + n_save, bart, both_extractor)
    trace_onlymain = empty_trace(n_save, bart, onlymain_extractor)

    # determine number of iterations for inner and outer loops
    n_iters = n_burn + n_skip * n_save
    if inner_loop_length is None:
        # a single inner loop does everything; the `max` avoids a division by
        # zero when there is nothing to do, in which case no loop is run
        inner_loop_length = max(n_iters, 1)
    elif inner_loop_length < 1:
        raise ValueError(f'{inner_loop_length=} must be positive')
    n_outer = n_iters // inner_loop_length
    if n_iters % inner_loop_length:
        if allow_overflow:
            n_outer += 1
        else:
            raise ValueError(f'{n_iters=} is not divisible by {inner_loop_length=}')

    carry = (bart, 0, key, trace_both, trace_onlymain, callback_state)
    for i_outer in range(n_outer):
        carry = _run_mcmc_inner_loop(
            carry,
            inner_loop_length,
            inner_callback,
            onlymain_extractor,
            both_extractor,
            n_burn,
            n_save,
            n_skip,
            i_outer,
        )
        if outer_callback is not None:
            bart, i_total, key, trace_both, trace_onlymain, callback_state = carry
            # the inner loop leaves i_total pointing at the *next* iteration,
            # while the callbacks expect the index of the last one
            i_last = i_total - 1
            i_skip = _compute_i_skip(i_last, n_burn, n_skip)
            rt = outer_callback(
                bart=bart,
                burnin=i_last < n_burn,
                overflow=i_last >= n_iters,
                i_total=i_last,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                bart, callback_state = rt
                carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)

    bart, _, _, trace_both, trace_onlymain, _ = carry

    # the "both" trace holds burn-in samples first, then main-phase samples
    burnin_trace = tree.map(lambda x: x[:n_burn, ...], trace_both)
    main_trace = tree.map(lambda x: x[n_burn:, ...], trace_both)
    main_trace.update(trace_onlymain)

    return bart, burnin_trace, main_trace

249 

250 

251def _compute_i_skip(i_total, n_burn, n_skip): 1ab

252 burnin = i_total < n_burn 1ab

253 return jnp.where( 1ab

254 burnin, 

255 i_total + 1, 

256 (i_total + 1) % n_skip + jnp.where(i_total + 1 < n_skip, n_burn, 0), 

257 ) 

258 

259 

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1, 2, 3, 4))
def _run_mcmc_inner_loop(
    carry,
    inner_loop_length,
    inner_callback,
    onlymain_extractor,
    both_extractor,
    n_burn,
    n_save,
    n_skip,
    i_outer,
):
    """
    Run `inner_loop_length` MCMC iterations under jit, recording the traces.

    Parameters
    ----------
    carry : tuple
        ``(bart, i_total, key, trace_both, trace_onlymain, callback_state)``,
        where `i_total` is the index of the next MCMC iteration to run. The
        carry is donated, so the input buffers are invalidated.
    inner_loop_length : int
        Number of iterations to run (static argument).
    inner_callback : callable or None
        Callback invoked after each update, see `run_mcmc` (static argument).
    onlymain_extractor, both_extractor : callable
        Trace extractors, see `run_mcmc` (static arguments).
    n_burn, n_save, n_skip : int
        As in `run_mcmc` (traced, not static).
    i_outer : int
        Index of the current outer loop iteration, forwarded to the callback.

    Returns
    -------
    carry : tuple
        The updated carry, with the same structure as the input.
    """
    n_iters = n_burn + n_save * n_skip

    def loop(carry, _):
        bart, i_total, key, trace_both, trace_onlymain, callback_state = carry

        keys = jaxext.split(key)
        key = keys.pop()
        bart = mcmcstep.step(keys.pop(), bart)

        burnin = i_total < n_burn
        if inner_callback is not None:
            i_skip = _compute_i_skip(i_total, n_burn, n_skip)
            rt = inner_callback(
                bart=bart,
                burnin=burnin,
                overflow=i_total >= n_iters,
                i_total=i_total,
                i_skip=i_skip,
                callback_state=callback_state,
                n_burn=n_burn,
                n_save=n_save,
                n_skip=n_skip,
                i_outer=i_outer,
                inner_loop_length=inner_loop_length,
            )
            if rt is not None:
                bart, callback_state = rt

        # In the main phase the same slot is rewritten n_skip times, so the
        # value kept is the last one of each stride. During burn-in the
        # main-only trace is written at index 0, which the main phase later
        # overwrites.
        i_onlymain = jnp.where(burnin, 0, (i_total - n_burn) // n_skip)
        i_both = jnp.where(burnin, i_total, n_burn + i_onlymain)

        def update_trace(index, trace, state):
            def assign_at_index(trace_array, state_array):
                if not trace_array.size:
                    # this handles the case where a trace is empty (e.g.,
                    # no burn-in) because jax refuses to index into an array
                    # of length 0
                    return trace_array
                # In the overflow phase the index runs past the end of the
                # trace; 'drop' makes explicit that those writes are discarded
                # rather than clobbering the last saved sample.
                return trace_array.at[index, ...].set(state_array, mode='drop')

            return tree.map(assign_at_index, trace, state)

        trace_onlymain = update_trace(
            i_onlymain, trace_onlymain, onlymain_extractor(bart)
        )
        trace_both = update_trace(i_both, trace_both, both_extractor(bart))

        i_total += 1
        carry = (bart, i_total, key, trace_both, trace_onlymain, callback_state)
        return carry, None

    carry, _ = lax.scan(loop, carry, None, inner_loop_length)
    return carry

324 

325 

def make_print_callbacks(dot_every_inner=1, report_every_outer=1):
    """
    Prepare logging callbacks for `run_mcmc`.

    The returned callbacks print a dot every `dot_every_inner` MCMC
    iterations, and a longer one-line report every `report_every_outer`
    outer loop iterations.

    Parameters
    ----------
    dot_every_inner : int or None, default 1
        A dot is printed every `dot_every_inner` MCMC iterations. If `None`,
        no dots are printed.
    report_every_outer : int or None, default 1
        A report is printed every `report_every_outer` outer loop
        iterations. If `None`, no reports are printed.

    Returns
    -------
    kwargs : dict
        A dictionary with the arguments to pass to `run_mcmc` as keyword
        arguments to set up the callbacks.

    Examples
    --------
    >>> run_mcmc(..., **make_print_callbacks())
    """
    initial_state = {
        'dot_every_inner': dot_every_inner,
        'report_every_outer': report_every_outer,
    }
    return {
        'inner_callback': _print_callback_inner,
        'outer_callback': _print_callback_outer,
        'callback_state': initial_state,
    }

358 

359 

def _print_callback_inner(*, i_total, callback_state, **_):
    """Inner `run_mcmc` callback: print a dot every `dot_every_inner` iterations.

    Runs under jit, so the decision to print is deferred to runtime through
    `jax.debug.callback`. Returns `None`, leaving the MCMC state untouched.
    """
    period = callback_state['dot_every_inner']
    if period is None:
        return
    debug.callback(_print_dot, (i_total + 1) % period == 0)

365 

366 

367def _print_dot(cond): 1ab

368 if cond: 

369 print('.', end='', flush=True) 

370 

371 

def _print_callback_outer(
    *,
    bart,
    burnin,
    overflow,
    i_total,
    n_burn,
    n_save,
    n_skip,
    callback_state,
    i_outer,
    inner_loop_length,
    **_,
):
    """Outer `run_mcmc` callback: schedule a progress report.

    The report is printed every `report_every_outer` outer iterations (as read
    from `callback_state`), or never if that is `None`. A leading newline is
    requested only when dots are enabled and `dot_every_inner` is smaller than
    `inner_loop_length`. Returns `None`, leaving the MCMC state untouched.
    """
    report_every_outer = callback_state['report_every_outer']
    if report_every_outer is None:
        return

    dot_every_inner = callback_state['dot_every_inner']
    newline = False if dot_every_inner is None else dot_every_inner < inner_loop_length

    forest = bart.forest
    debug.callback(
        _print_report,
        cond=(i_outer + 1) % report_every_outer == 0,
        newline=newline,
        burnin=burnin,
        overflow=overflow,
        i_total=i_total,
        n_iters=n_burn + n_save * n_skip,
        grow_prop_count=forest.grow_prop_count,
        grow_acc_count=forest.grow_acc_count,
        prune_prop_count=forest.prune_prop_count,
        prune_acc_count=forest.prune_acc_count,
        prop_total=len(forest.leaf_trees),
        fill=grove.forest_fill(forest.split_trees),
    )

408 

409 

410def _convert_jax_arrays_in_args(func): 1ab

411 """Remove jax arrays from a function arguments. 

412 

413 Converts all jax.Array instances in the arguments to either Python scalars 

414 or numpy arrays. 

415 """ 

416 

417 def convert_jax_arrays(pytree): 1ab

418 def convert_jax_arrays(val): 1ab

419 if not isinstance(val, jax.Array): 419 ↛ 420line 419 didn't jump to line 420 because the condition on line 419 was never true1ab

420 return val 

421 elif val.shape: 421 ↛ 422line 421 didn't jump to line 422 because the condition on line 421 was never true1ab

422 return numpy.array(val) 

423 else: 

424 return val.item() 1ab

425 

426 return tree.map(convert_jax_arrays, pytree) 1ab

427 

428 @functools.wraps(func) 1ab

429 def new_func(*args, **kw): 1ab

430 args = convert_jax_arrays(args) 1ab

431 kw = convert_jax_arrays(kw) 1ab

432 return func(*args, **kw) 1ab

433 

434 return new_func 1ab

435 

436 

# All jax arrays in the arguments are converted to host values first, because
# operations on them inside a host callback could deadlock with the main thread.
@_convert_jax_arrays_in_args
def _print_report(
    *,
    cond,
    newline,
    burnin,
    overflow,
    i_total,
    n_iters,
    grow_prop_count,
    grow_acc_count,
    prune_prop_count,
    prune_acc_count,
    prop_total,
    fill,
):
    """Print a one-line MCMC progress report, if `cond` is true.

    The line shows the iteration counter, the fraction of trees with a grow /
    prune proposal (relative to `prop_total`), the acceptance rate of each
    move (' n/d' when nothing was proposed), the forest fill, and a
    '(burnin)' or '(overflow)' tag when applicable. It is preceded by a
    newline if `newline` is true.
    """
    if not cond:
        return

    def acceptance(n_acc, n_prop):
        # share of accepted proposals, undefined when nothing was proposed
        return f'{n_acc / n_prop:.0%}' if n_prop else ' n/d'

    grow_prop = grow_prop_count / prop_total
    prune_prop = prune_prop_count / prop_total
    grow_acc = acceptance(grow_acc_count, grow_prop_count)
    prune_acc = acceptance(prune_acc_count, prune_prop_count)

    flag = ' (burnin)' if burnin else (' (overflow)' if overflow else '')
    lead = '\n' if newline else ''

    print(
        f'{lead}It {i_total + 1}/{n_iters} '
        f'grow P={grow_prop:.0%} A={grow_acc}, '
        f'prune P={prune_prop:.0%} A={prune_acc}, '
        f'fill={fill:.0%}{flag}'
    )

482 

483 

@jax.jit
def evaluate_trace(trace, X):
    """
    Compute predictions for all iterations of the BART MCMC.

    Parameters
    ----------
    trace : dict
        A trace of the BART MCMC, as returned by `run_mcmc`. It must contain
        the keys 'leaf_trees', 'var_trees', 'split_trees' and 'offset', whose
        leading axis runs over the saved iterations.
    X : array (p, n)
        The predictors matrix, with `p` predictors and `n` observations.

    Returns
    -------
    y : array (n_trace, n)
        The predictions for each iteration of the MCMC, i.e., the offset plus
        the sum of the per-tree values (accumulated in float32).
    """
    per_tree_values = jaxext.autobatch(
        functools.partial(grove.evaluate_forest, sum_trees=False),
        2**29,
        (None, 0, 0, 0),
    )

    def predict_sample(carry, sample):
        values = per_tree_values(
            X, sample['leaf_trees'], sample['var_trees'], sample['split_trees']
        )
        total = jnp.sum(values, axis=0, dtype=jnp.float32)
        return carry, sample['offset'] + total

    _, y = lax.scan(predict_sample, None, trace)
    return y