Coverage for src/bartz/mcmcstep.py: 95%
491 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking as input a bart state, and
29outputting a new state. The inputs are not modified.
31The main entry points are:
33 - `State`: The dataclass that represents a BART MCMC state.
34 - `init`: Creates an initial `State` from data and configurations.
35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36"""
38import math 1ab
39from dataclasses import replace 1ab
40from functools import cache, partial 1ab
41from typing import Any, Literal 1ab
43import jax 1ab
44from equinox import Module, field, tree_at 1ab
45from jax import lax, random 1ab
46from jax import numpy as jnp 1ab
47from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab
49from bartz import grove 1ab
50from bartz.jaxext import ( 1ab
51 minimal_unsigned_dtype,
52 split,
53 truncated_normal_onesided,
54 vmap_nodoc,
55)
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        The indices of predictors with no available cutpoints, to be excluded
        from the MCMC. `None` if no filtering was requested.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoints falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node.
    min_points_per_leaf
        The minimum number of data points in a leaf node.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree.
    log_likelihood
        The log likelihood ratio.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    """

    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary
        regression.
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals (`y` or `z` minus sum of trees).
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model. Holds the per-tree state and the predictor
        metadata such as `max_split`.
    """

    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is
        binary regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1
        for all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no
        batching. If 'auto', pick a value based on the device of `y`, or the
        default device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available
        cutpoints. If any are found, they are put into a list of variables to
        exclude from the MCMC. If `False`, no check is performed, but the
        results may be wrong if any variable is blocked. The function is
        jax-traceable only if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken
        into account in the Metropolis-Hastings ratio because it would be
        expensive to compute. Grow moves that would violate this constraint
        are vetoed. This parameter is independent of
        `min_points_per_decision_node` and there is no check that they are
        coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # pad with a final 0: nodes at maximum depth are never nonterminal
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        # one blank tree per tree in the ensemble
        return grove.make_tree(max_depth, dtype)

    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # initialize the error variance at the prior mean of 1/precision
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    return State(
        X=jnp.asarray(X),
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=(
                None
                if min_points_per_decision_node is None
                else jnp.asarray(min_points_per_decision_node)
            ),
            min_points_per_leaf=(
                None
                if min_points_per_leaf is None
                else jnp.asarray(min_points_per_leaf)
            ),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
        ),
    )
355def _choose_suffstat_batch_size( 1ab
356 resid_batch_size, count_batch_size, y, forest_size
357) -> tuple[int | None, ...]:
358 @cache 1ab
359 def get_platform(): 1ab
360 try: 1ab
361 device = y.devices().pop() 1ab
362 except jax.errors.ConcretizationTypeError: 1ab
363 device = jax.devices()[0] 1ab
364 platform = device.platform 1ab
365 if platform not in ('cpu', 'gpu'): 365 ↛ 366line 365 didn't jump to line 366 because the condition on line 365 was never true1ab
366 msg = f'Unknown platform: {platform}'
367 raise KeyError(msg)
368 return platform 1ab
370 if resid_batch_size == 'auto': 1ab
371 platform = get_platform() 1ab
372 n = max(1, y.size) 1ab
373 if platform == 'cpu': 373 ↛ 375line 373 didn't jump to line 375 because the condition on line 373 was always true1ab
374 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab
375 elif platform == 'gpu':
376 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
377 resid_batch_size = max(1, resid_batch_size) 1ab
379 if count_batch_size == 'auto': 1ab
380 platform = get_platform() 1ab
381 if platform == 'cpu': 381 ↛ 383line 381 didn't jump to line 383 because the condition on line 381 was always true1ab
382 count_batch_size = None 1ab
383 elif platform == 'gpu':
384 n = max(1, y.size)
385 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
386 # /4 is good on V100, /2 on L4/T4, still haven't tried A100
387 max_memory = 2**29
388 itemsize = 4
389 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
390 count_batch_size = max(count_batch_size, min_batch_size)
391 count_batch_size = max(1, count_batch_size)
393 return resid_batch_size, count_batch_size 1ab
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Do one MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype == bool:
        # binary regression: pin the error variance to 1 while the trees are
        # updated, then resample the latent variable z
        bart = replace(bart, sigma2=jnp.float32(1))
        bart = step_trees(keys.pop(), bart)
        bart = replace(bart, sigma2=None)
        return step_z(keys.pop(), bart)

    # continuous regression: update the trees, then the error variance
    bart = step_trees(keys.pop(), bart)
    return step_sigma(keys.pop(), bart)
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Forest sampling step of BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposed = propose_moves(keys.pop(), bart.forest)
    return accept_moves_and_sample_leaves(keys.pop(), bart, proposed)
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        `True` for a grow move, `False` for a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If
        the move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that
        would become leaves if the move was accepted. This mark initially (out
        of `propose_moves`) takes into account if there would be available
        decision rules to grow the leaf, and whether there are enough
        datapoints in the node is marked in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This
        indicates an accepted prune move or a rejected grow move. `None` if
        not yet computed.
    """

    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose moves for all the trees.

    There are two types of moves: GROW (convert a leaf to a decision node and
    add two leaves beneath it) and PRUNE (convert the parent of two leaves to
    a leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape
    keys = split(key, 1 + 2 * num_trees)

    # propose both kinds of move for every tree
    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    u_move, u_accept = random.uniform(keys.pop(), (2, num_trees))

    # pick grow vs. prune: 1/2 each when both are possible, otherwise the
    # probability collapses to 0 or 1 depending on which one is allowed
    p_grow = jnp.where(
        grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
    )
    grow = u_move < p_grow  # use < instead of <= because u_move is in [0, 1)

    # indices of the node acted upon and of its children
    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # will be set in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # var_tree does not need to be updated if prune
        var_tree=grow_moves.var_tree,
        # affluence_tree is updated for both moves unconditionally, prune last
        affluence_tree=prune_moves.affluence_tree,
        logu=jnp.log1p(-u_accept),
        acc=None,  # will be set in accept_moves_sequential_stage
        to_prune=None,  # will be set in accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move can be proposed at all.
    num_growable
        How many leaves are eligible to be grown.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree.
    affluence_tree
        A partially updated `affluence_tree` that marks each new leaf that
        would be produced as `True` if it would have available decision rules.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: Int32[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    less datapoints than the requested threshold
    `min_points_per_decision_node`, or it does not have any available decision
    rules given its ancestors. This is marked by setting `allowed` to `False`
    and `num_growable` to 0.
    """
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # draw the decision rule for the new node
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars
    )
    split_idx, l, r = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # work out whether the prospective children would still have available
    # decision rules; if the move is blocked, these values may not make sense
    other_vars_free = num_available_var > 1
    left_growable = other_vars_free | (l < split_idx)
    right_growable = other_vars_free | (split_idx + 1 < r)
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two
        leaves satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., are nonterminal
        and have at least twice `min_points_per_leaf`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    can_grow = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(can_grow)
    # restrict the proposal distribution to the growable leaves
    distr = jnp.where(can_grow, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)
    # out-of-bounds sentinel index when nothing is growable
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
    # count prunable parents as if the selected leaf had been grown
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)
    return leaf_to_grow, num_growable, prob_choose, num_prunable
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Return a mask indicating the leaf nodes that can be proposed for growth.

    The condition is that a leaf is not at the bottom level, has available
    decision rules given its ancestors, and has at least
    `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask indicating the leaf nodes that can be proposed to grow.

    Notes
    -----
    This function needs `split_tree` and not just `affluence_tree` because
    `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
    """
    # restrict the (possibly dirty) affluence marks to actual leaves
    actual_leaves = grove.is_actual_leaf(split_tree)
    return affluence_tree & actual_leaves
798def categorical( 1ab
799 key: Key[Array, ''], distr: Float32[Array, ' n']
800) -> tuple[Int32[Array, ''], Float32[Array, '']]:
801 """
802 Return a random integer from an arbitrary distribution.
804 Parameters
805 ----------
806 key
807 A jax random key.
808 distr
809 An unnormalized probability distribution.
811 Returns
812 -------
813 u : Int32[Array, '']
814 A random integer in the range ``[0, n)``. If all probabilities are zero,
815 return ``n``.
816 norm : Float32[Array, '']
817 The sum of `distr`.
818 """
819 ecdf = jnp.cumsum(distr) 1ab
820 u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1]) 1ab
821 return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1] 1ab
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: Int32[Array, ' k'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # exclude variables exhausted along this leaf's ancestry, plus the
    # globally blocked ones when that list is provided
    excluded = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        excluded = jnp.concatenate([excluded, blocked_vars])
    return randint_exclude(key, max_split.size, excluded)
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Find variables in the ancestors of a node that have an empty split range.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables that have an empty split range.

    Notes
    -----
    The number of unused variables is not known in advance. Unused values in
    the array are filled with `p`. The fill values are not guaranteed to be
    placed in any particular order, and variables may appear more than once.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    # measure the remaining split range of each ancestor variable at this node
    vec_split_range = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    l, r = vec_split_range(var_tree, split_tree, max_split, leaf_index, candidates)
    # keep only the variables whose range [l, r) is empty; fill the rest
    # with p (the type of candidates is already sufficient to hold
    # max_split.size, see ancestor_variables())
    return jnp.where(l == r, candidates, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))

    def visit_parent(carry, _):
        slot, index, acc = carry
        # climb one level towards the root
        index >>= 1
        # index 0 is not a real node: fill with the sentinel p
        var = jnp.where(index, var_tree[index], max_split.size)
        return (slot - 1, index, acc.at[slot].set(var)), None

    start = (ancestor_vars.size - 1, node_index, ancestor_vars)
    (_, _, ancestor_vars), _ = lax.scan(
        visit_parent, start, None, ancestor_vars.size
    )
    return ancestor_vars
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return the range of allowed splits for a variable at a given node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable for which to measure the split range.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # start from the full range [0, max_split[ref_var]] (shifted by 1 at the
    # end); out-of-bounds ref_var yields fill_value=0, hence the empty range
    initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(
        jnp.int32
    )
    carry = jnp.int32(0), initial_r, node_index

    def loop(carry, _):
        l, r, index = carry
        # heap indexing: odd index <=> this node is its parent's right child
        right_child = (index & 1).astype(bool)
        index >>= 1
        split = split_tree[index]
        # an ancestor constrains the range only if it splits on ref_var;
        # index == 0 means we walked past the root
        cond = (var_tree[index] == ref_var) & index.astype(bool)
        # being in the right subtree raises the lower bound, the left subtree
        # lowers the upper bound
        l = jnp.where(cond & right_child, jnp.maximum(l, split), l)
        r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r)
        return (l, r, index), None

    (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors)
    return l + 1, r
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    # deduplicate and sort the excluded values; out-of-range duplicates are
    # replaced by `sup` so they are ignored below
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.count_nonzero(exclude < sup)
    # draw uniformly from the compacted range of allowed values
    u = random.randint(key, (), 0, num_allowed)

    def loop(u, i_excluded):
        # shift the draw past each excluded value at or below it; correctness
        # relies on `exclude` being sorted ascending
        return jnp.where(i_excluded <= u, u + 1, u), None

    u, _ = lax.scan(loop, u, exclude)
    return u, num_allowed
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a split point for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. Does not need to already contain `var`
        at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range `split` was drawn from is [l, r).

    Notes
    -----
    If `var` is out of bounds, or if the available split range on that variable
    is empty, return 0.
    """
    lo, hi = split_range(var_tree, split_tree, max_split, leaf_index, var)
    # the draw is made unconditionally; when the range is empty its result is
    # discarded by the selection below
    proposal = random.randint(key, (), lo, hi)
    split = jnp.where(lo < hi, proposal, 0)
    return split, lo, hi
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree).
    The "partial" transition ratio returned is missing the factor P(propose
    prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
    "partial" prior ratio is missing the factor P(children are leaves).
    """
    # both ratios also contain factors num_available_split *
    # num_available_var, but those cancel out

    # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
    # computed here because they require the count trees, which only become
    # available in the acceptance phase

    # if the leaf to grow is the root, the tree must be a root tree, and the
    # only possible proposal is growing the root, so there is no grow/prune
    # choice factor
    can_prune = leaf_to_grow != 1
    p_grow = jnp.where(can_prune, 0.5, 1)
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # .at.get with a fill: if leaf_to_grow is out of bounds (move not
    # allowed), a plain lookup could give 0 and then inf when
    # `complete_ratio` takes the log
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    prior_ratio = pnt / (1 - pnt)

    # guard the division: a zero inv_trans_ratio only happens for disallowed
    # moves, whose ratio is discarded downstream
    safe_denominator = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
    return prior_ratio / safe_denominator
class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_move_and_sample_leaves`.
    affluence_tree
        A partially updated affluence tree, marking the node to prune as
        growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Tree structure prune move proposal of BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )

    # the ratio of the inverse (grow) move; see the `PruneMoves.partial_ratio`
    # docstring for how it is meant to be inverted back
    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node_to_prune
    )

    return PruneMoves(
        allowed=split_tree[1].astype(bool),  # allowed iff the tree is not a root
        node=node_to_prune,
        partial_ratio=ratio,
        affluence_tree=affluence_tree,
    )
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would chose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # out-of-range marker when nothing is prunable (2 * split_tree.size == 2**d)
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for reverse move: simulate the tree post-prune, where the
    # pruned node becomes a growable leaf again
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # guard against division by zero when no leaf is growable
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Return a random integer in a range, including only some values.

    Parameters
    ----------
    key
        A jax random key.
    mask
        The mask indicating the allowed values.

    Returns
    -------
    A random integer in the range ``[0, n)`` such that ``mask[u] == True``.

    Notes
    -----
    If all values in the mask are `False`, return `n`.
    """
    # cumulative count of allowed values up to and including each position
    cum_counts = jnp.cumsum(mask)
    # draw a rank among the allowed values, then map it back to its position
    rank = random.randint(key, (), 0, cum_counts[-1])
    return jnp.searchsorted(cum_counts, rank, 'right')
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Accept or reject the proposed moves and sample the new leaf values.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    # three phases: tree-parallel precomputations, the intrinsically
    # sequential accept/reject sweep, then parallel finalization
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    updated_bart, completed_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(updated_bart, completed_moves)
class Counts(Module):
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    # one entry per tree
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']
class Precs(Module):
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    # one entry per tree; the weighted analog of `Counts`
    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']
class PreLkV(Module):
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the total of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    # one entry per tree; computed in `precompute_likelihood_terms`
    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']
class PreLk(Module):
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
    """

    # scalar, equal to sigma_mu2 / (2 * sigma2); see `precompute_likelihood_terms`
    exp_factor: Float32[Array, '']
class PreLf(Module):
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    # one entry per (tree, node); computed in `precompute_leaf_terms`
    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']
class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is a `Counts` with the number of
        points in those nodes instead.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; skipped only when neither
    # points-per-node threshold is active and the counts would not be needed
    # as precisions either
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grow move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        # no heteroskedasticity: the plain counts double as precisions
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to apply a grow move.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The updated leaf indices.
    """
    # heap indexing: children of `node` are `2 * node` and `2 * node + 1`
    left_child = moves.node.astype(leaf_indices.dtype) << 1
    go_right = X[moves.grow_var, :] >= moves.grow_split
    # for non-grow moves, redirect the match to an index no leaf can have, so
    # that the update below is a no-op
    out_of_range = jnp.array(2 * moves.var_tree.size)
    target_node = jnp.where(moves.grow, moves.node, out_of_range)
    updated = jnp.where(
        leaf_indices == target_node, left_child + go_right, leaf_indices
    )
    return updated
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    num_trees, half_size = moves.var_tree.shape
    full_size = 2 * half_size
    tree_idx = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, full_size, batch_size)

    # gather the counts of the two children touched by each move
    n_left = count_trees[tree_idx, moves.left]
    n_right = count_trees[tree_idx, moves.right]
    counts = Counts(left=n_left, right=n_right, total=n_left + n_right)

    # store the parent's total at the (potential) decision node
    count_trees = count_trees.at[tree_idx, moves.node].set(counts.total)

    return count_trees, counts
def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> Int32[Array, 'num_trees 2**(d-1)']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        scan-based implementation.

    Returns
    -------
    The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)
def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> Int32[Array, 'num_trees {tree_size}']:
    # process one tree per scan iteration; the carry is unused
    def one_tree(carry, tree_leaf_indices):
        counts = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, counts

    _, count_trees = lax.scan(one_tree, None, leaf_indices)
    return count_trees
def _aggregate_scatter(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
) -> Shaped[Array, ' {size}']:
    """Scatter-add `values` at `indices` into a zeroed vector of length `size`."""
    accumulator = jnp.zeros(size, dtype)
    return accumulator.at[indices].add(values)
def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**(d-1)']:
    # note: uint16 would suffice when n < 2^16, but it is super-slow on gpu,
    # so use uint32 unconditionally
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
def _aggregate_batched_alltrees(
    values: Shaped[Array, '*'],
    indices: UInt[Array, 'num_trees n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Shaped[Array, 'num_trees {size}']:
    """Scatter-add `values` into one length-`size` vector per tree, in batches."""
    num_trees, n = indices.shape
    tree_indices = jnp.arange(num_trees)
    # ceiling division: number of batches of at most batch_size datapoints
    nbatches = n // batch_size + bool(n % batch_size)
    batch_indices = jnp.arange(n) % nbatches
    # scatter into a separate accumulator per batch, then reduce over batches
    return (
        jnp.zeros((num_trees, size, nbatches), dtype)
        .at[tree_indices[:, None], indices, batch_indices]
        .add(values)
        .sum(axis=2)
    )
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_size = moves.var_tree.shape
    full_size = 2 * half_size
    tree_idx = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, full_size, batch_size)

    # gather the precision scales of the two children touched by each move
    prec_left = prec_trees[tree_idx, moves.left]
    prec_right = prec_trees[tree_idx, moves.right]
    precs = Precs(left=prec_left, right=prec_right, total=prec_left + prec_right)

    # store the parent's total at the (potential) decision node
    prec_trees = prec_trees.at[tree_idx, moves.node].set(precs.total)

    return prec_trees, precs
def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation; `None` selects the
        scan-based implementation.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is None:
        return _prec_scan(prec_scale, leaf_indices, tree_size)
    return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    # process one tree per scan iteration; the carry is unused
    def one_tree(carry, tree_leaf_indices):
        precs = _aggregate_scatter(
            prec_scale, tree_leaf_indices, tree_size, jnp.float32
        )
        return carry, precs

    _, prec_trees = lax.scan(one_tree, None, leaf_indices)
    return prec_trees
def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    # batched scatter-add of the per-datapoint precision scales
    return _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Complete non-likelihood MH ratio calculation.

    This function adds the probability of choosing a prune move over the grow
    move in the inverse transition, and the a priori probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. Must have already been updated to keep into account
        the thresholds on the number of datapoints per node, this happens in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    # can the children leaves be grown? (out-of-bounds children => False)
    num_trees, _ = moves.affluence_tree.shape
    tree_indices = jnp.arange(num_trees)
    left_growable = moves.affluence_tree.at[tree_indices, moves.left].get(
        mode='fill', fill_value=False
    )
    right_growable = moves.affluence_tree.at[tree_indices, moves.right].get(
        mode='fill', fill_value=False
    )

    # p_prune if grow: 1 only when, after growing, no leaf could be grown
    # further, so the reverse proposal must be a prune
    other_growable_leaves = moves.num_growable >= 2
    grow_again_allowed = other_growable_leaves | left_growable | right_growable
    grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1)

    # p_prune if prune: 1 only when the pre-prune tree had nothing growable
    prune_p_prune = jnp.where(moves.num_growable, 0.5, 1)

    # select p_prune depending on the kind of move
    p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune)

    # prior probability of both children being terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
        partial_ratio=None,
    )
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    The value of the leaf to grow is copied to what would be its children if the
    grow move was accepted.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    parent_value = leaf_trees[moves.node]
    # for non-grow moves, write to an out-of-bounds index instead, which makes
    # the .set() a no-op
    oob = leaf_trees.size
    left_target = jnp.where(moves.grow, moves.left, oob)
    right_target = jnp.where(moves.grow, moves.right, oob)
    updated = leaf_trees.at[left_target].set(parent_value)
    return updated.at[right_target].set(parent_value)
def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, under keys 'left', 'right', and 'total' (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio, shared by all trees.
    """
    # marginal variances in the affected children and their parent
    var_left = sigma2 + move_precs.left * sigma_mu2
    var_right = sigma2 + move_precs.right * sigma_mu2
    var_total = sigma2 + move_precs.total * sigma_mu2
    prelkv = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        # stored as a logarithm, hence the / 2 instead of a square root
        sqrt_term=jnp.log(sigma2 * var_total / (var_left * var_right)) / 2,
    )
    prelk = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return prelkv, prelk
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if `prec_scale`
        is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    # conjugate normal update: posterior precision = likelihood precision +
    # prior precision
    prec_lk = prec_trees / sigma2
    prec_prior = lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(prec_lk + prec_prior)
    # standard normal draws, scaled to the posterior sd below
    z = random.normal(key, prec_trees.shape, sigma2.dtype)
    return PreLf(
        mean_factor=var_post / sigma2,
        # | mean = mean_lk * prec_lk * var_post
        # | resid_tree = mean_lk * prec_tree -->
        # | --> mean_lk = resid_tree / prec_tree (kind of)
        # | mean_factor =
        # | = mean / resid_tree =
        # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
        # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
        # | = var_post / sigma2
        centered_leaves=z * jnp.sqrt(var_post),
    )
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    # inputs identical for every tree; hoisted out of the scan body (the scan
    # body is traced once, so this is equivalent to building it per iteration)
    shared_inputs = SeqStageInAllTrees(
        pso.bart.X,
        pso.bart.forest.resid_batch_size,
        pso.bart.prec_scale,
        pso.bart.forest.log_likelihood is not None,
        pso.prelk,
    )

    def scan_body(carry_resid, per_tree):
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            carry_resid, shared_inputs, per_tree
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    # one slice of each of these arrays per tree, consumed sequentially
    per_tree_inputs = SeqStageInPerTree(
        pso.bart.forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        pso.bart.forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )
    resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        scan_body, pso.bart.resid, per_tree_inputs
    )

    new_forest = replace(
        pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio
    )
    bart = replace(pso.bart, resid=resid, forest=new_forest)
    moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return bart, moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    X: UInt[Array, 'p n']
    # static fields are plain Python configuration rather than traced arrays
    # (presumably, per the Module/field convention used by this project)
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    save_ratios: bool = field(static=True)
    prelk: PreLk
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Instances of this class carry a leading `num_trees` axis and are scanned
    over (one slice per tree) in `accept_moves_sequential_stage`.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    leaf_tree: Float32[Array, ' 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        # heteroskedastic case: weight each residual by its precision scale
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    # adding prec_tree * leaf_tree to the per-leaf sums makes resid_tree equal
    # to the sums computed as if this tree's contribution were removed
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move
    resid_left = resid_tree[pt.move.left]
    resid_right = resid_tree[pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[pt.move.node].set(resid_total)

    # compute acceptance ratio
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the ratio is stated for grow; prune is the reverse transition, so its
    # log-ratio is the negation
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    # logu: presumably the log of a uniform draw for the Metropolis-Hastings
    # test, pre-drawn in propose_moves — confirm against its definition
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # the final tree is the pruned one iff (accepted prune) or (rejected grow),
    # i.e. acc XOR grow; when pruning, both children get the parent's value so
    # that leaf_indices pointing at either child read the correct leaf.
    # jnp.where(..., leaf_tree.size) yields an out-of-bounds index, which
    # drops the write under JAX's default scatter semantics.
    to_prune = acc ^ pt.move.grow
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
        .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
    )

    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    # unbatched path: plain scatter-add over leaves
    if batch_size is None:
        return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
    # batched path: accumulate in interleaved batches, then reduce
    return _aggregate_batched_onetree(
        scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size
    )
def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, ' {size}']:
    """Scatter-sum `values` into `size` bins given by `indices`, in batches.

    Datapoints are dealt round-robin into ``ceil(n / batch_size)`` columns of
    a (size, nbatches) accumulator, which is then reduced over columns; the
    result equals a direct scatter-add but with more accuracy and parallelism.
    """
    (n,) = indices.shape
    nbatches = -(-n // batch_size)  # ceiling division
    column = jnp.arange(n) % nbatches
    accumulator = jnp.zeros((size, nbatches), dtype)
    accumulator = accumulator.at[indices, column].add(values)
    return accumulator.sum(axis=1)
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The logarithm of the likelihood ratio P(data | new tree) / P(data | old
    tree) for a grow move; the caller negates it for a prune move.
    """
    # quadratic terms of the marginal Gaussian likelihoods; each sigma2_* is
    # sigma2 + move_precs.* * sigma_mu2, see `precompute_likelihood_terms`
    exp_term = prelk.exp_factor * (
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    # sqrt_term is the (pre-computed) log ratio of normalization constants
    return prelkv.sqrt_term + exp_term
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This function is separate from `accept_moves_sequential_stage` to signal it
    can work in parallel across trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    # tally accepted moves by kind for the acceptance-rate diagnostics
    accepted_grows = jnp.sum(moves.acc & moves.grow)
    accepted_prunes = jnp.sum(moves.acc & ~moves.grow)
    new_forest = replace(
        bart.forest,
        grow_acc_count=accepted_grows,
        prune_acc_count=accepted_prunes,
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both children of a pair to the left child,
    # so a single comparison detects membership in the pruned sibling pair
    pair_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_pruned_pair = (leaf_indices & pair_mask) == moves.left
    move_to_parent = in_pruned_pair & moves.to_prune
    return jnp.where(
        move_to_parent, moves.node.astype(leaf_indices.dtype), leaf_indices
    )
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # split_tree.size is out of bounds, so the corresponding scatter write is
    # dropped under JAX's default out-of-bounds semantics (no-op when the
    # condition is False)
    noop = split_tree.size
    grow_at = jnp.where(moves.grow, moves.node, noop)
    prune_at = jnp.where(moves.to_prune, moves.node, noop)
    grown = split_tree.at[grow_at].set(moves.grow_split.astype(split_tree.dtype))
    return grown.at[prune_at].set(0)
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    # conjugate inverse-gamma update: posterior shape parameter
    alpha = bart.sigma2_alpha + resid.size / 2
    # weighted squared norm of the residuals (weights = 1 if no prec_scale)
    if bart.prec_scale is None:
        norm2 = resid @ resid
    else:
        norm2 = resid @ (resid * bart.prec_scale)
    # posterior rate parameter
    beta = bart.sigma2_beta + norm2 / 2

    # sigma2 ~ InvGamma(alpha, beta): draw Gamma(alpha, 1) and invert
    gamma_draw = random.gamma(key, alpha)
    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    return replace(bart, sigma2=beta / gamma_draw)
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    # forest prediction plus offset, recovered from z = fit + resid
    fit = bart.z - bart.resid
    assert bart.y.dtype == bool
    # redraw the residual from a one-sided truncated Normal; the side is
    # selected by ~y with lower bound/upper bound -fit (see
    # truncated_normal_onesided for the exact convention)
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -fit)
    new_z = fit + new_resid
    return replace(bart, z=new_z, resid=new_resid)