Coverage for src/bartz/mcmcstep.py: 91%
489 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking as input a bart state, and
29outputting a new state. The inputs are not modified.
31The main entry points are:
33 - `State`: The dataclass that represents a BART MCMC state.
34 - `init`: Creates an initial `State` from data and configurations.
35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36"""
38import math 1ab
39from dataclasses import replace 1ab
40from functools import cache, partial 1ab
41from typing import Any 1ab
43import jax 1ab
44from equinox import Module, field 1ab
45from jax import lax, random 1ab
46from jax import numpy as jnp 1ab
47from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab
49from . import grove 1ab
50from .jaxext import minimal_unsigned_dtype, split, vmap_nodoc 1ab
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    var_trees
        The decision axes.
    split_trees
        The decision boundaries.
    p_nonterminal
        The probability of a nonterminal node at each depth, padded with a
        zero.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    affluence_trees
        Whether a non-bottom leaf node contains twice `min_points_per_leaf`
        datapoints. If `min_points_per_leaf` is not specified, this is None.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If
        `None`, they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    """

    leaf_trees: Float32[Array, 'num_trees 2**d']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    split_trees: UInt[Array, 'num_trees 2**(d-1)']
    p_nonterminal: Float32[Array, 'd']
    p_propose_grow: Float32[Array, '2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_leaf: Int32[Array, ''] | None
    affluence_trees: Bool[Array, 'num_trees 2**(d-1)'] | None
    # static fields: they are part of the pytree structure, not traced data
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, 'num_trees'] | None
    log_likelihood: Float32[Array, 'num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    max_split
        The maximum split index for each predictor.
    y
        The response. If the data type is `bool`, the model is binary
        regression.
    resid
        The residuals (`y` or `z` minus sum of trees).
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    X: UInt[Array, 'p n']
    max_split: UInt[Array, 'p']
    y: Float32[Array, 'n'] | Bool[Array, 'n']
    z: None | Float32[Array, 'n']
    offset: Float32[Array, '']
    resid: Float32[Array, 'n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, 'n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, 'n'] | Bool[Any, 'n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, 'p'],
    num_trees: int,
    p_nonterminal: Float32[Any, 'd-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, 'n'] | None = None,
    min_points_per_leaf: int | None = None,
    resid_batch_size: int | None | str = 'auto',
    count_batch_size: int | None | str = 'auto',
    save_ratios: bool = False,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is
        binary regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1
        for all points, but potentially skipping calculations.
    min_points_per_leaf
        The minimum number of data points in a leaf node. 0 if not specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # pad with a zero so that nodes at maximum depth are never grown
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        # broadcast a single empty tree to a (num_trees, ...) stack
        return grove.make_tree(max_depth, dtype)

    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # check each argument with `is not None` instead of comparing tuples:
        # tuple equality would invoke array `==` on array-valued arguments,
        # which raises an ambiguous-truth-value error instead of this message
        if any(arg is not None for arg in (error_scale, sigma2_alpha, sigma2_beta)):
            raise ValueError(
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # start the error variance at the ratio beta/alpha (~ prior scale)
        sigma2 = sigma2_beta / sigma2_alpha

    bart = State(
        X=jnp.asarray(X),
        max_split=jnp.asarray(max_split),
        y=y,
        # binary: latent z starts at the offset, so residuals start at zero;
        # continuous: trees start at zero, so residuals start at y - offset
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_trees=make_forest(max_depth, jnp.float32),
            var_trees=make_forest(
                max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)
            ),
            split_trees=make_forest(max_depth - 1, max_split.dtype),
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal,
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            # every datapoint starts in the root (index 1, heap layout)
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_leaf=(
                None
                if min_points_per_leaf is None
                else jnp.asarray(min_points_per_leaf)
            ),
            # the root is growable iff the whole dataset could fill two leaves
            affluence_trees=(
                None
                if min_points_per_leaf is None
                else make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(y.size >= 2 * min_points_per_leaf)
            ),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            log_likelihood=jnp.full(num_trees, jnp.nan) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
        ),
    )

    return bart
def _choose_suffstat_batch_size(
    resid_batch_size, count_batch_size, y, forest_size
) -> tuple[int | None, ...]:
    """Resolve 'auto' batch sizes from the backend platform of `y`."""

    @cache
    def platform() -> str:
        # prefer the device `y` lives on; under tracing fall back to the
        # default device
        try:
            dev = y.devices().pop()
        except jax.errors.ConcretizationTypeError:
            dev = jax.devices()[0]
        name = dev.platform
        if name not in ('cpu', 'gpu'):
            raise KeyError(f'Unknown platform: {name}')
        return name

    if resid_batch_size == 'auto':
        backend = platform()
        n = max(1, y.size)
        if backend == 'cpu':
            # aim for roughly n / 6 points per batch
            resid_batch_size = 2 ** int(round(math.log2(n / 6)))
        elif backend == 'gpu':
            # aim for roughly n ** (1/3) points per batch
            resid_batch_size = 2 ** int(round((1 + math.log2(n)) / 3))
        resid_batch_size = max(1, resid_batch_size)

    if count_batch_size == 'auto':
        backend = platform()
        if backend == 'cpu':
            count_batch_size = None
        elif backend == 'gpu':
            n = max(1, y.size)
            # aim for roughly n ** (1/2) points per batch
            count_batch_size = 2 ** int(round(math.log2(n) / 2 - 2))
            # /4 is good on V100, /2 on L4/T4, still haven't tried A100
            max_memory = 2**29
            itemsize = 4
            # keep the per-leaf count buffers within the memory budget
            min_batch_size = int(math.ceil(forest_size * itemsize * n / max_memory))
            count_batch_size = max(count_batch_size, min_batch_size)
            count_batch_size = max(1, count_batch_size)

    return resid_batch_size, count_batch_size
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Perform a single full MCMC iteration.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype == bool:
        # binary regression: run the tree step with unit error variance,
        # then refresh the latent variable z
        state = replace(bart, sigma2=jnp.float32(1))
        state = step_trees(keys.pop(), state)
        state = replace(state, sigma2=None)
        return step_z(keys.pop(), state)
    else:
        # continuous regression: sample trees, then the error variance
        state = step_trees(keys.pop(), bart)
        return step_sigma(keys.pop(), state)
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Propose a grow/prune move for every tree, then accept or reject the
    moves and resample the leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposals = propose_moves(keys.pop(), bart.forest, bart.max_split)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposals)
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible in the first place. There are additional
        constraints that could forbid it, but they are computed at acceptance
        time.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of 'node'.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move. If the move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_trees
        The updated decision axes of the trees, valid whatever move.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This
        indicates an accepted prune move or a rejected grow move. `None` if
        not yet computed.
    """

    allowed: Bool[Array, 'num_trees']
    grow: Bool[Array, 'num_trees']
    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    left: UInt[Array, 'num_trees']
    right: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, 'num_trees']
    grow_var: UInt[Array, 'num_trees']
    grow_split: UInt[Array, 'num_trees']
    var_trees: UInt[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, 'num_trees']
    acc: None | Bool[Array, 'num_trees']
    to_prune: None | Bool[Array, 'num_trees']
def propose_moves(
    key: Key[Array, ''], forest: Forest, max_split: UInt[Array, 'p']
) -> Moves:
    """
    Propose a move for each tree in the forest.

    Each tree gets either a GROW move (turn a leaf into a decision node with
    two new leaves) or a PRUNE move (collapse a leaf parent back into a
    leaf), chosen at random between the ones that are possible.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.
    max_split
        The maximum split index for each variable, found in `State`.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_trees.shape
    keys = split(key, 1 + 2 * num_trees)

    # propose both kinds of move for every tree
    grow = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_trees,
        forest.split_trees,
        forest.affluence_trees,
        max_split,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    prune = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_trees,
        forest.affluence_trees,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u, logu = random.uniform(keys.pop(), (2, num_trees), jnp.float32)

    # pick grow with probability 1/2 when both moves are possible, with
    # probability 1 when only grow is, and 0 otherwise
    can_grow = grow.num_growable.astype(bool)
    p_grow = jnp.where(can_grow & prune.allowed, 0.5, can_grow)
    do_grow = u < p_grow  # use < instead of <= because u is in [0, 1)

    # heap indexing: children of node i are at 2i and 2i + 1
    target = jnp.where(do_grow, grow.node, prune.node)
    left_child = target << 1
    right_child = left_child + 1

    return Moves(
        allowed=do_grow | prune.allowed,
        grow=do_grow,
        num_growable=grow.num_growable,
        node=target,
        left=left_child,
        right=right_child,
        partial_ratio=jnp.where(do_grow, grow.partial_ratio, prune.partial_ratio),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow.var,
        grow_split=grow.split,
        var_trees=grow.var_tree,
        logu=jnp.log1p(-logu),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    num_growable
        The number of growable leaves.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    """

    num_growable: UInt[Array, 'num_trees']
    node: UInt[Array, 'num_trees']
    var: UInt[Array, 'num_trees']
    split: UInt[Array, 'num_trees']
    partial_ratio: Float32[Array, 'num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None))
def propose_grow_moves(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    max_split: UInt[Array, 'p'],
    p_nonterminal: Float32[Array, 'd'],
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if a leaf is already at maximum depth, or if a
    leaf has less than twice the requested minimum number of datapoints per
    leaf. This is marked by returning `num_growable` set to 0.

    The move is also not possible if the ancestors of a leaf have exhausted
    the possible decision rules that lead to a non-empty selection. This is
    marked by returning `var` set to `p` and `split` set to 0. But this does
    not block the move from counting as "proposed", even though it is
    predictably going to be rejected. This simplifies the MCMC and should not
    reduce efficiency if not in unrealistic corner cases.
    """
    keys = split(key, 3)

    # pick which leaf to grow, and gather the counts needed for the MH ratio
    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # pick the decision rule for the new node: first the axis, then the cut
    var = choose_variable(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)
    var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))

    split_idx = choose_split(keys.pop(), var_tree, split_tree, max_split, leaf_to_grow)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree,
    )

    # TODO it is not clear to me how var=p and split=0 when the move is not
    # possible lead to corrent behavior downstream. Like, the move is proposed,
    # but then it's a noop? And since it's a noop, it makes no difference if
    # it's "accepted" or "rejected", it's like it's always rejected, so who
    # cares if the likelihood ratio or a lot of other numbers are wrong? Uhm.
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two
        leaves satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : int
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : int
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf` if set.
    prob_choose : float
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    can_grow = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(can_grow)
    # zero out the proposal weight of non-growable leaves before sampling
    weights = jnp.where(can_grow, p_propose_grow, 0)
    chosen, total_weight = categorical(key, weights)
    # with no growable leaves, flag the move with an out-of-range index
    chosen = jnp.where(num_growable, chosen, 2 * split_tree.size)
    prob_choose = weights[chosen] / total_weight
    # count the prunable parents of the tree as it would be after growing
    parent_mask = grove.is_leaves_parent(split_tree.at[chosen].set(1))
    num_prunable = jnp.count_nonzero(parent_mask)
    return chosen, num_growable, prob_choose, num_prunable
def growable_leaves(
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
) -> Bool[Array, '2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level and has at least
    two times the number of minimum points per leaf.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.
    """
    mask = grove.is_actual_leaf(split_tree)
    if affluence_tree is None:
        return mask
    return mask & affluence_tree
def categorical(
    key: Key[Array, ''], distr: Float32[Array, 'n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Draw a random integer from an arbitrary unnormalized distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        An unnormalized probability distribution.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all probabilities are
        zero, return ``n``.
    norm : Float32[Array, '']
        The sum of `distr`.
    """
    cdf = jnp.cumsum(distr)
    norm = cdf[-1]
    # draw a point in [0, norm) and find which cell of the cdf it lands in
    draw = random.uniform(key, (), cdf.dtype, 0, norm)
    return jnp.searchsorted(cdf, draw, 'right'), norm
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> Int32[Array, '']:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    The index of the variable to split on.

    Notes
    -----
    The variable is chosen among the variables that have a non-empty range of
    allowed splits. If no variable has a non-empty range, return `p`.
    """
    blocked = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    return randint_exclude(key, max_split.size, blocked)
def fully_used_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return a list of variables that have an empty split range at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in
    the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    # only variables already used by ancestors can have exhausted their range
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    ranges = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lo, hi = ranges(var_tree, split_tree, max_split, leaf_index, candidates)
    # keep the candidate where its range [lo, hi) is empty, else fill with p
    return jnp.where(hi == lo, candidates, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, 'd-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    max_split : int array (p,)
        The maximum split index for each variable. Used only to get `p`.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
    # walk up the tree, filling the output array from the back
    carry = ancestor_vars.size - 1, node_index, ancestor_vars

    def loop(carry, _):
        i, index, ancestor_vars = carry
        index >>= 1  # heap layout: the parent of node i is i // 2
        var = var_tree[index]
        # index 0 means the walk went past the root: mark the slot with p
        var = jnp.where(index, var, max_split.size)
        ancestor_vars = ancestor_vars.at[i].set(var)
        return (i - 1, index, ancestor_vars), None

    (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
    return ancestor_vars
def split_range(
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds,
    l=r=0.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # out-of-bounds ref_var yields fill_value 0, which makes the range empty
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    carry = 0, initial_r, node_index

    def loop(carry, _):
        l, r, index = carry
        # heap layout: odd index <=> the node is a right child of its parent
        right_child = (index & 1).astype(bool)
        index >>= 1
        split = split_tree[index]
        # narrow the interval only at real ancestors (index 0 is past root)
        # that split on the same variable
        cond = (var_tree[index] == ref_var) & index.astype(bool)
        # a right child lies above the ancestor's cut, a left child below it
        l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
        r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
        return (l, r, index), None

    (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
    return l + 1, r
def randint_exclude(
    key: Key[Array, ''], sup: int, exclude: Integer[Array, 'n']
) -> Int32[Array, '']:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    A random integer `u` in the range ``[0, sup)`` such that
    ``u not in exclude``.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    # deduplicate, padding the unused slots with `sup` so they are ignored
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_forbidden = jnp.count_nonzero(exclude < sup)
    # draw uniformly over the allowed values, as if they were contiguous
    draw = random.randint(key, (), 0, sup - num_forbidden)

    # shift the draw up past each excluded value it reaches, in sorted order
    def shift(draw, threshold):
        return jnp.where(threshold <= draw, draw + 1, draw), None

    draw, _ = lax.scan(shift, draw, exclude)
    return draw
def choose_split(
    key: Key[Array, ''],
    var_tree: UInt[Array, '2**(d-1)'],
    split_tree: UInt[Array, '2**(d-1)'],
    max_split: UInt[Array, 'p'],
    leaf_index: Int32[Array, ''],
) -> Int32[Array, '']:
    """
    Choose a split point for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow. It is assumed that `var_tree` already
        contains the target variable at this index.

    Returns
    -------
    The cutpoint. If ``var_tree[leaf_index]`` is out of bounds, return 0.
    """
    chosen_var = var_tree[leaf_index]
    lo, hi = split_range(var_tree, split_tree, max_split, leaf_index, chosen_var)
    return random.randint(key, (), lo, hi)

    # TODO what happens if leaf_index is out of bounds? And is the value used
    # in that case?
def compute_partial_ratio( 
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, 'd'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio returned is missing the factor P(propose
    prune) in the numerator; that factor needs the count trees, which only
    become available in the acceptance phase. The prior ratio is
    P(new tree) / P(old tree). Both ratios also contain factors
    num_available_split * num_available_var, but those cancel out.
    """
    # if the leaf to grow is the root, the tree must be a bare root, so the
    # grow move is the only possible proposal (no prune from a root)
    is_root = leaf_to_grow == 1
    p_grow = jnp.where(is_root, 1, 0.5)

    trans_ratio_inv = p_grow * prob_choose * num_prunable

    depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
    p_nt = p_nonterminal[depth]
    p_leaf_child = 1 - p_nonterminal[depth + 1]
    prior_ratio = p_leaf_child * p_leaf_child * p_nt / (1 - p_nt)

    return prior_ratio / trans_ratio_inv
class PruneMoves(Module): 
    """
    Represent a proposed prune move for each tree.

    Produced by `propose_prune_moves`, one entry per tree.

    Parameters
    ----------
    allowed
        Whether the move is possible. A prune is disallowed when the tree is
        just a root (no split at the root slot).
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move. This ratio is inverted, and is meant to be inverted back in
        `accept_move_and_sample_leaves`.
    """

    allowed: Bool[Array, 'num_trees'] 
    node: UInt[Array, 'num_trees'] 
    partial_ratio: Float32[Array, 'num_trees'] 
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None)) 
def propose_prune_moves( 
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_nonterminal: Float32[Array, 'd'],
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> PruneMoves:
    """
    Tree structure prune move proposal of BART MCMC, vectorized over trees.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_nonterminal
        The probability of a nonterminal node at each depth.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    node, num_prunable, prob_choose = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )

    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node
    )

    return PruneMoves(
        # a root-only tree has nothing to prune: slot 1 holds no split
        allowed=split_tree[1].astype(bool),
        node=node,
        partial_ratio=partial_ratio,
    )
def choose_leaf_parent( 
    key: Key[Array, ''],
    split_tree: UInt[Array, '2**(d-1)'],
    affluence_tree: Bool[Array, '2**(d-1)'] | None,
    p_propose_grow: Float32[Array, '2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, '']]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would chose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    """
    prunable_mask = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(prunable_mask)
    target = jnp.where(
        num_prunable, randint_masked(key, prunable_mask), 2 * split_tree.size
    )

    # simulate the prune: clear the split and mark the merged leaf as
    # growable, then compute the probability of growing it back
    pruned_split_tree = split_tree.at[target].set(0)
    pruned_affluence_tree = affluence_tree
    if pruned_affluence_tree is not None:
        pruned_affluence_tree = pruned_affluence_tree.at[target].set(True)
    growable_mask = growable_leaves(pruned_split_tree, pruned_affluence_tree)
    norm = jnp.sum(p_propose_grow, where=growable_mask)
    prob_choose = p_propose_grow[target] / norm

    return target, num_prunable, prob_choose
def randint_masked(key: Key[Array, ''], mask: Bool[Array, 'n']) -> Int32[Array, '']: 
    """
    Draw a random index among the positions where a boolean mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Marks which indices may be drawn.

    Returns
    -------
    A random integer in ``[0, n)`` such that ``mask[u] == True``, uniform over
    the allowed positions.

    Notes
    -----
    If all values in the mask are `False`, return `n`.
    """
    # cumulative count of allowed slots; the last entry is the total
    cumulative = jnp.cumsum(mask)
    rank = random.randint(key, (), 0, cumulative[-1])
    # map the rank back to the position of the (rank + 1)-th set entry
    return jnp.searchsorted(cumulative, rank, 'right')
def accept_moves_and_sample_leaves( 
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    # split the work into what can run across all trees at once, what has to
    # run tree by tree, and a final stage that is again parallel
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    bart, moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(bart, moves)
class Counts(Module): 
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Built by `compute_count_trees`.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    left: UInt[Array, 'num_trees'] 
    right: UInt[Array, 'num_trees'] 
    total: UInt[Array, 'num_trees'] 
class Precs(Module): 
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node. Built by
    `compute_prec_trees`.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    left: Float32[Array, 'num_trees'] 
    right: Float32[Array, 'num_trees'] 
    total: Float32[Array, 'num_trees'] 
class PreLkV(Module): 
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees; they are built by
    `precompute_likelihood_terms`.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the total of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, 'num_trees'] 
    sigma2_right: Float32[Array, 'num_trees'] 
    sigma2_total: Float32[Array, 'num_trees'] 
    sqrt_term: Float32[Array, 'num_trees'] 
class PreLk(Module): 
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Built by `precompute_likelihood_terms`.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
    """

    exp_factor: Float32[Array, ''] 
class PreLf(Module): 
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees; they are built by
    `precompute_leaf_terms`.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    mean_factor: Float32[Array, 'num_trees 2**d'] 
    centered_leaves: Float32[Array, 'num_trees 2**d'] 
class ParallelStageOut(Module): 
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_counts
        The counts of the number of points in the nodes modified by the
        moves. If `bart.min_points_per_leaf` is not set and
        `bart.prec_scale` is set, they are not computed.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is set to `move_counts`.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State 
    moves: Moves 
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d'] 
    move_counts: Counts | None 
    move_precs: Precs | Counts 
    prelkv: PreLkV 
    prelk: PreLk 
    prelf: PreLf 
def accept_moves_parallel_stage( 
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-computes quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace( 
        bart,
        forest=replace(
            bart.forest,
            var_trees=moves.var_trees,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_trees=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_trees, moves),
        ),
    )

    # count number of datapoints per leaf
    if bart.forest.min_points_per_leaf is not None or bart.prec_scale is None: 
        count_trees, move_counts = compute_count_trees( 
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )
    else:
        # move_counts is passed later to a function, but then is unused under
        # this condition
        move_counts = None 

    # Check if some nodes can't surely be grown because they don't have enough
    # datapoints. This check is not actually used now, it will be used at the
    # beginning of the next step to propose moves.
    # NOTE: count_trees is guaranteed to be defined here, because this branch
    # implies the first condition of the `if` above was true.
    if bart.forest.min_points_per_leaf is not None: 
        count_half_trees = count_trees[:, : bart.forest.var_trees.shape[1]] 
        bart = replace( 
            bart,
            forest=replace(
                bart.forest,
                affluence_trees=count_half_trees >= 2 * bart.forest.min_points_per_leaf,
            ),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None: 
        prec_trees = count_trees 
        move_precs = move_counts 
    else:
        prec_trees, move_precs = compute_prec_trees( 
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )

    # compute some missing information about moves
    moves = complete_ratio(moves, move_counts, bart.forest.min_points_per_leaf) 
    bart = replace( 
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
        ),
    )

    prelkv, prelk = precompute_likelihood_terms( 
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2) 

    return ParallelStageOut( 
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_counts=move_counts,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@partial(vmap_nodoc, in_axes=(0, 0, None)) 
def apply_grow_to_indices( 
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # on non-grow moves, compare against an out-of-range index so that no
    # datapoint matches and the indices pass through unchanged
    out_of_range = jnp.array(2 * moves.var_trees.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_range)
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    # children of node i sit at 2i (left) and 2i + 1 (right)
    new_child = (moves.node.astype(leaf_indices.dtype) << 1) + goes_right
    return jnp.where(leaf_indices == target_node, new_child, leaf_indices)
def compute_count_trees( 
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    num_trees, half_size = moves.var_trees.shape
    tree_arange = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, 2 * half_size, batch_size)

    # pull out the counts of the two children involved in each move
    left_count = count_trees[tree_arange, moves.left]
    right_count = count_trees[tree_arange, moves.right]
    counts = Counts(
        left=left_count, right=right_count, total=left_count + right_count
    )

    # store the parent's count at the (currently non-leaf) parent node
    count_trees = count_trees.at[tree_arange, moves.node].set(counts.total)

    return count_trees, counts
def count_datapoints_per_leaf( 
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> Int32[Array, 'num_trees 2**(d-1)']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        sequential (per-tree scan) implementation.

    Returns
    -------
    The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)
def _count_scan( 
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> Int32[Array, 'num_trees {tree_size}']:
    # one scatter-add per tree, run sequentially with a scan
    def count_one_tree(carry, tree_leaf_indices):
        counts = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, counts

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees
1541def _aggregate_scatter( 1ab
1542 values: Shaped[Array, '*'],
1543 indices: Integer[Array, '*'],
1544 size: int,
1545 dtype: jnp.dtype,
1546) -> Shaped[Array, '{size}']:
1547 return jnp.zeros(size, dtype).at[indices].add(values) 1ab
def _count_vec( 
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**(d-1)']:
    # Batched scatter-add of a 1 for every datapoint, all trees at once; see
    # `_aggregate_batched_alltrees` for the batching scheme.
    return _aggregate_batched_alltrees( 
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
1559def _aggregate_batched_alltrees( 1ab
1560 values: Shaped[Array, '*'],
1561 indices: UInt[Array, 'num_trees n'],
1562 size: int,
1563 dtype: jnp.dtype,
1564 batch_size: int,
1565) -> Shaped[Array, 'num_trees {size}']:
1566 num_trees, n = indices.shape 1ab
1567 tree_indices = jnp.arange(num_trees) 1ab
1568 nbatches = n // batch_size + bool(n % batch_size) 1ab
1569 batch_indices = jnp.arange(n) % nbatches 1ab
1570 return ( 1ab
1571 jnp.zeros((num_trees, size, nbatches), dtype)
1572 .at[tree_indices[:, None], indices, batch_indices]
1573 .add(values)
1574 .sum(axis=2)
1575 )
def compute_prec_trees( 
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_size = moves.var_trees.shape
    tree_arange = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, 2 * half_size, batch_size)

    # pull out the precision scales of the two children involved in each move
    left_prec = prec_trees[tree_arange, moves.left]
    right_prec = prec_trees[tree_arange, moves.right]
    precs = Precs(left=left_prec, right=right_prec, total=left_prec + right_prec)

    # store the parent's value at the (currently non-leaf) parent node
    prec_trees = prec_trees.at[tree_arange, moves.node].set(precs.total)

    return prec_trees, precs
def prec_per_leaf( 
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        sequential (per-tree scan) implementation.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is None:
        return _prec_scan(prec_scale, leaf_indices, tree_size)
    return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
def _prec_scan( 
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    # one scatter-add per tree, run sequentially with a scan
    def sum_one_tree(carry, tree_leaf_indices):
        sums = _aggregate_scatter(prec_scale, tree_leaf_indices, tree_size, jnp.float32)
        return carry, sums

    _, prec_trees = lax.scan(sum_one_tree, None, leaf_indices)
    return prec_trees
def _prec_vec( 
    prec_scale: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    # Batched scatter-add of the per-datapoint precision scales, all trees at
    # once; see `_aggregate_batched_alltrees` for the batching scheme.
    return _aggregate_batched_alltrees( 
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )
def complete_ratio( 
    moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    The missing piece is the probability of choosing the prune move, which
    requires the datapoint counts.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    move_counts
        The counts of the number of points in the nodes modified by the moves.
    min_points_per_leaf
        The minimum number of data points in a leaf node.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    p_prune = compute_p_prune(moves, move_counts, min_points_per_leaf)
    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(p_prune * moves.partial_ratio),
        partial_ratio=None,
    )
1708def compute_p_prune( 1ab
1709 moves: Moves, move_counts: Counts | None, min_points_per_leaf: int | None
1710) -> Float32[Array, 'num_trees']:
1711 """
1712 Compute the probability of proposing a prune move for each tree.
1714 Parameters
1715 ----------
1716 moves
1717 The proposed moves, see `propose_moves`.
1718 move_counts
1719 The number of datapoints in the proposed children of the leaf to grow.
1720 Not used if `min_points_per_leaf` is `None`.
1721 min_points_per_leaf
1722 The minimum number of data points in a leaf node.
1724 Returns
1725 -------
1726 The probability of proposing a prune move.
1728 Notes
1729 -----
1730 This probability is computed for going from the state with the deeper tree
1731 to the one with the shallower one. This means, if grow: after accepting the
1732 grow move, if prune: right away.
1733 """
1734 # calculation in case the move is grow
1735 other_growable_leaves = moves.num_growable >= 2 1ab
1736 new_leaves_growable = moves.node < moves.var_trees.shape[1] // 2 1ab
1737 if min_points_per_leaf is not None: 1ab
1738 assert move_counts is not None 1ab
1739 any_above_threshold = move_counts.left >= 2 * min_points_per_leaf 1ab
1740 any_above_threshold |= move_counts.right >= 2 * min_points_per_leaf 1ab
1741 new_leaves_growable &= any_above_threshold 1ab
1742 grow_again_allowed = other_growable_leaves | new_leaves_growable 1ab
1743 grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1) 1ab
1745 # calculation in case the move is prune
1746 prune_p_prune = jnp.where(moves.num_growable, 0.5, 1) 1ab
1748 return jnp.where(moves.grow, grow_p_prune, prune_p_prune) 1ab
@vmap_nodoc 
def adapt_leaf_trees_to_grow_indices( 
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]
    # on non-grow moves, redirect the writes to an out-of-bounds index, so
    # that the scatter drops them and the tree is unchanged
    oob = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, oob)
    right_target = jnp.where(moves.grow, moves.right, oob)
    updated = leaf_trees.at[left_target].set(parent_value)
    return updated.at[right_target].set(parent_value)
1781def precompute_likelihood_terms( 1ab
1782 sigma2: Float32[Array, ''],
1783 sigma_mu2: Float32[Array, ''],
1784 move_precs: Precs | Counts,
1785) -> tuple[PreLkV, PreLk]:
1786 """
1787 Pre-compute terms used in the likelihood ratio of the acceptance step.
1789 Parameters
1790 ----------
1791 sigma2
1792 The error variance, or the global error variance factor is `prec_scale`
1793 is set.
1794 sigma_mu2
1795 The prior variance of each leaf.
1796 move_precs
1797 The likelihood precision scale in the leaves grown or pruned by the
1798 moves, under keys 'left', 'right', and 'total' (left + right).
1800 Returns
1801 -------
1802 prelkv : PreLkV
1803 Dictionary with pre-computed terms of the likelihood ratio, one per
1804 tree.
1805 prelk : PreLk
1806 Dictionary with pre-computed terms of the likelihood ratio, shared by
1807 all trees.
1808 """
1809 sigma2_left = sigma2 + move_precs.left * sigma_mu2 1ab
1810 sigma2_right = sigma2 + move_precs.right * sigma_mu2 1ab
1811 sigma2_total = sigma2 + move_precs.total * sigma_mu2 1ab
1812 prelkv = PreLkV( 1ab
1813 sigma2_left=sigma2_left,
1814 sigma2_right=sigma2_right,
1815 sigma2_total=sigma2_total,
1816 sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
1817 )
1818 return prelkv, PreLk( 1ab
1819 exp_factor=sigma_mu2 / (2 * sigma2),
1820 )
def precompute_leaf_terms( 
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # posterior precision = likelihood precision + prior precision
    prec_lk = prec_trees / sigma2
    var_post = lax.reciprocal(prec_lk + lax.reciprocal(sigma_mu2))
    z = random.normal(key, prec_trees.shape, sigma2.dtype)
    return PreLf(
        # derivation of the factor mapping residual sums to posterior means:
        #   mean = mean_lk * prec_lk * var_post
        #   resid_tree = mean_lk * prec_tree --> mean_lk = resid_tree / prec_tree
        #   mean_factor = mean / resid_tree
        #               = resid_tree / prec_tree * prec_lk * var_post / resid_tree
        #               = 1 / prec_tree * prec_tree / sigma2 * var_post
        #               = var_post / sigma2
        mean_factor=var_post / sigma2,
        centered_leaves=z * jnp.sqrt(var_post),
    )
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]: 
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """

    def loop(resid, pt): 
        # the scan carry is the residual vector: each tree's accept/reject
        # decision updates it before the next tree is processed
        resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves( 
            resid,
            SeqStageInAllTrees(
                pso.bart.X,
                pso.bart.forest.resid_batch_size,
                pso.bart.prec_scale,
                pso.bart.forest.min_points_per_leaf,
                pso.bart.forest.log_likelihood is not None,
                pso.prelk,
            ),
            pt,
        )
        return resid, (leaf_tree, acc, to_prune, ratios) 

    # per-tree inputs stacked along axis 0, consumed one slice per scan step
    pts = SeqStageInPerTree( 
        pso.bart.forest.leaf_trees,
        pso.prec_trees,
        pso.moves,
        pso.move_counts,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, pso.bart.resid, pts) 

    save_ratios = pso.bart.forest.log_likelihood is not None 
    bart = replace( 
        pso.bart,
        resid=resid,
        forest=replace(
            pso.bart.forest,
            leaf_trees=leaf_trees,
            log_likelihood=ratios['log_likelihood'] if save_ratios else None,
            log_trans_prior=ratios['log_trans_prior'] if save_ratios else None,
        ),
    )
    moves = replace(pso.moves, acc=acc, to_prune=to_prune) 

    return bart, moves 
class SeqStageInAllTrees(Module): 
    """
    The inputs to `accept_move_and_sample_leaves` that are the same for all trees.

    Constructed in `accept_moves_sequential_stage`.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n'] 
    resid_batch_size: int | None 
    prec_scale: Float32[Array, 'n'] | None 
    min_points_per_leaf: Int32[Array, ''] | None 
    save_ratios: bool 
    prelk: PreLk 
class SeqStageInPerTree(Module): 
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Stacked along the tree axis and scanned over in
    `accept_moves_sequential_stage`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_counts
        The counts of the number of points in the nodes modified by the
        moves.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, '2**d'] 
    prec_tree: Float32[Array, '2**d'] 
    move: Moves 
    move_counts: Counts | None 
    move_precs: Precs | Counts 
    leaf_indices: UInt[Array, 'n'] 
    prelkv: PreLkV 
    prelf: PreLf 
def accept_move_and_sample_leaves( 
    resid: Float32[Array, 'n'],
    at: SeqStageInAllTrees,
    pt: SeqStageInPerTree,
) -> tuple[
    Float32[Array, 'n'],
    Float32[Array, '2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    dict[str, Float32[Array, '']],
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict[str, Float32[Array, '']]
        The log acceptance ratio components for the move. Empty if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None: 
        scaled_resid = resid 
    else:
        # heteroskedastic case: weight each residual by its precision scale
        scaled_resid = resid * at.prec_scale 
    resid_tree = sum_resid( 
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function: adding each leaf's own
    # (precision-weighted) value re-expresses the per-leaf sums as residuals
    # w.r.t. the forest WITHOUT this tree (assumes prec_tree holds the summed
    # precision per leaf — consistent with its docstring)
    resid_tree += pt.prec_tree * pt.leaf_tree 

    # get indices of move (node is the parent, left/right its children)
    node = pt.move.node 
    assert node.dtype == jnp.int32 
    left = pt.move.left 
    right = pt.move.right 

    # sum residuals in parent node modified by move, so that resid_tree has
    # the statistics for both the grown and the pruned version of the tree
    resid_left = resid_tree[left] 
    resid_right = resid_tree[right] 
    resid_total = resid_left + resid_right 
    resid_tree = resid_tree.at[node].set(resid_total) 

    # compute acceptance ratio (in log space)
    log_lk_ratio = compute_likelihood_ratio( 
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio 
    # the ratio is computed for the grow direction; flip sign for prune moves
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio) 
    ratios = {} 
    if at.save_ratios: 
        ratios.update( 
            log_trans_prior=pt.move.log_trans_prior_ratio,
            # TODO save log_trans_prior_ratio as a vector outside of this loop,
            # then change the option everywhere to `save_likelihood_ratio`.
            log_likelihood=log_lk_ratio,
        )

    # determine whether to accept the move (Metropolis-Hastings test with
    # pre-drawn log-uniform variate pt.move.logu)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio) 
    if at.min_points_per_leaf is not None: 
        assert pt.move_counts is not None 
        # veto the move if either new leaf would hold too few datapoints
        acc &= pt.move_counts.left >= at.min_points_per_leaf 
        acc &= pt.move_counts.right >= at.min_points_per_leaf 

    # compute leaves posterior and sample leaves (for the grown tree layout)
    initial_leaf_tree = pt.leaf_tree 
    mean_post = resid_tree * pt.prelf.mean_factor 
    leaf_tree = mean_post + pt.prelf.centered_leaves 

    # copy leaves around such that the leaf indices point to the correct leaf.
    # to_prune is true iff the final tree must NOT have the two child leaves:
    # a rejected grow (acc=0, grow=1) or an accepted prune (acc=1, grow=0)
    to_prune = acc ^ pt.move.grow 
    leaf_tree = ( 
        # when not pruning, the index is out of bounds and jax drops the
        # update (default out-of-bounds semantics of .at[].set)
        leaf_tree.at[jnp.where(to_prune, left, leaf_tree.size)]
        .set(leaf_tree[node])
        .at[jnp.where(to_prune, right, leaf_tree.size)]
        .set(leaf_tree[node])
    )

    # replace old tree with new tree in function values: add the per-datapoint
    # difference between the old and new leaf values
    resid += (initial_leaf_tree - leaf_tree)[pt.leaf_indices] 

    return resid, leaf_tree, acc, to_prune, ratios 
def sum_resid( 
    scaled_resid: Float32[Array, 'n'],
    leaf_indices: UInt[Array, 'n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, '{tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    # dispatch to the plain or batched scatter-sum implementation
    if batch_size is None: 
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32) 
    return _aggregate_batched_onetree( 
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
    )
2131def _aggregate_batched_onetree( 1ab
2132 values: Shaped[Array, '*'],
2133 indices: Integer[Array, '*'],
2134 size: int,
2135 dtype: jnp.dtype,
2136 batch_size: int,
2137) -> Float32[Array, '{size}']:
2138 (n,) = indices.shape 1ab
2139 nbatches = n // batch_size + bool(n % batch_size) 1ab
2140 batch_indices = jnp.arange(n) % nbatches 1ab
2141 return ( 1ab
2142 jnp.zeros((size, nbatches), dtype)
2143 .at[indices, batch_indices]
2144 .add(values)
2145 .sum(axis=1)
2146 )
def compute_likelihood_ratio( 
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The log of the likelihood ratio P(data | new tree) / P(data | old tree)
    for a grow move; the caller negates it for a prune move.
    """
    # quadratic form in the per-node residual sums; the shared factor and the
    # per-node variance terms are precomputed once per MCMC step
    exp_term = prelk.exp_factor * ( 
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    return prelkv.sqrt_term + exp_term 
def accept_moves_final_stage(bart: State, moves: Moves) -> State: 
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    # tally accepted moves by kind
    accepted_grows = jnp.sum(moves.acc & moves.grow) 
    accepted_prunes = jnp.sum(moves.acc & ~moves.grow)
    # bring the per-tree arrays in sync with the acceptance outcomes
    updated_forest = replace(
        bart.forest,
        grow_acc_count=accepted_grows,
        prune_acc_count=accepted_prunes,
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_trees=apply_moves_to_split_trees(bart.forest.split_trees, moves),
    )
    return replace(bart, forest=updated_forest)
@vmap_nodoc 
def apply_moves_to_leaf_indices( 
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both children onto the left child index
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110 
    in_moved_children = (leaf_indices & clear_low_bit) == moves.left 
    # datapoints in the children of a pruned node are reassigned to the parent
    send_to_parent = in_moved_children & moves.to_prune
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(send_to_parent, parent, leaf_indices) 
@vmap_nodoc 
def apply_moves_to_split_trees( 
    split_trees: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_trees
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None 
    # an out-of-bounds index makes jax drop the update, so each write below
    # is a no-op when its condition does not hold
    oob = split_trees.size
    # first write the grow move's cutpoint into the node...
    grow_target = jnp.where(moves.grow, moves.node, oob) 
    updated = split_trees.at[grow_target].set(
        moves.grow_split.astype(split_trees.dtype)
    )
    # ...then clear the node if the final state must be the pruned one
    prune_target = jnp.where(moves.to_prune, moves.node, oob)
    return updated.at[prune_target].set(0)
def step_sigma(key: Key[Array, ''], bart: State) -> State: 
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid 
    # (precision-weighted) squared norm of the residuals
    if bart.prec_scale is None: 
        norm2 = resid @ resid 
    else:
        norm2 = resid @ (resid * bart.prec_scale) 
    # conjugate update: sigma2 | resid ~ InvGamma(posterior_alpha, posterior_beta)
    posterior_alpha = bart.sigma2_alpha + resid.size / 2 
    posterior_beta = bart.sigma2_beta + norm2 / 2 
    gamma_draw = random.gamma(key, posterior_alpha) 
    return replace(bart, sigma2=posterior_beta / gamma_draw) 
def step_z(key: Key[Array, ''], bart: State) -> State: 
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with resampled `z` and `resid`.
    """
    # mean of the latent variable: the forest value plus the offset
    latent_mean = bart.z - bart.resid 
    # the residual is standard normal truncated to the half-line that keeps
    # the sign of z consistent with the observed binary outcome
    lower = jnp.where(bart.y, -latent_mean, -jnp.inf) 
    upper = jnp.where(bart.y, jnp.inf, -latent_mean) 
    new_resid = random.truncated_normal(key, lower, upper) 
    # TODO jax's implementation of truncated_normal is not good, it just does
    # cdf inversion with erf and erf_inv. I can do better, at least avoiding to
    # compute one of the boundaries, and maybe also flipping and using ndtr
    # instead of erf for numerical stability (open an issue in jax?)
    return replace(bart, z=latent_mean + new_resid, resid=new_resid) 