Coverage for src/bartz/mcmcstep.py: 96%
610 statements
coverage.py v7.10.7, created at 2025-10-15 08:16 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024-2025, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking a BART state as input and
29returning a new state; the input state is not modified.
31The entry points are:
33 - `State`: The dataclass that represents a BART MCMC state.
34 - `init`: Creates an initial `State` from data and configurations.
35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
37"""
39import math 1ab
40from dataclasses import replace 1ab
41from functools import cache, partial 1ab
42from typing import Any, Literal 1ab
44import jax 1ab
45from equinox import Module, field, tree_at 1ab
46from jax import lax, random 1ab
47from jax import numpy as jnp 1ab
48from jax.scipy.linalg import solve_triangular 1ab
49from jax.scipy.special import gammaln, logsumexp 1ab
50from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab
52from bartz import grove 1ab
53from bartz.jaxext import ( 1ab
54 minimal_unsigned_dtype,
55 split,
56 truncated_normal_onesided,
57 vmap_nodoc,
58)
61class Forest(Module): 1ab
62 """
63 Represents the MCMC state of a sum of trees.
65 Parameters
66 ----------
67 leaf_tree
68 The leaf values.
69 var_tree
70 The decision axes.
71 split_tree
72 The decision boundaries.
73 affluence_tree
74 Marks leaves that can be grown.
75 max_split
76 The maximum split index for each predictor.
77 blocked_vars
78 Indices of variables that are not used. This shall include at least
79 the `i` such that ``max_split[i] == 0``, otherwise behavior is
80 undefined.
81 p_nonterminal
82 The prior probability of each node being nonterminal, conditional on
83 its ancestors. Includes the nodes at maximum depth which should be set
84 to 0.
85 p_propose_grow
86 The unnormalized probability of picking a leaf for a grow proposal.
87 leaf_indices
88 The index of the leaf each datapoint falls into, for each tree.
89 min_points_per_decision_node
90 The minimum number of data points in a decision node.
91 min_points_per_leaf
92 The minimum number of data points in a leaf node.
93 resid_batch_size
94 count_batch_size
95 The data batch sizes for computing the sufficient statistics. If `None`,
96 they are computed with no batching.
97 log_trans_prior
98 The log transition and prior Metropolis-Hastings ratio for the
99 proposed move on each tree.
100 log_likelihood
101 The log likelihood ratio.
102 grow_prop_count
103 prune_prop_count
104 The number of grow/prune proposals made during one full MCMC cycle.
105 grow_acc_count
106 prune_acc_count
107 The number of grow/prune moves accepted during one full MCMC cycle.
108 sigma_mu2
109 The prior variance of a leaf, conditional on the tree structure.
110 log_s
111 The logarithm of the prior probability for choosing a variable to split
112 along in a decision rule, conditional on the ancestors. Not normalized.
113 If `None`, use a uniform distribution.
114 theta
115 The concentration parameter for the Dirichlet prior on the variable
116 distribution `s`. Required only to update `s`.
117 a
118 b
119 rho
120 Parameters of the prior on `theta`. Required only to sample `theta`.
121 See `step_theta`.
122 """
124 leaf_tree: Float32[Array, 'num_trees 2**d'] 1ab
125 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab
126 split_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab
127 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab
128 max_split: UInt[Array, ' p'] 1ab
129 blocked_vars: UInt[Array, ' k'] | None 1ab
130 p_nonterminal: Float32[Array, ' 2**d'] 1ab
131 p_propose_grow: Float32[Array, ' 2**(d-1)'] 1ab
132 leaf_indices: UInt[Array, 'num_trees n'] 1ab
133 min_points_per_decision_node: Int32[Array, ''] | None 1ab
134 min_points_per_leaf: Int32[Array, ''] | None 1ab
135 resid_batch_size: int | None = field(static=True) 1ab
136 count_batch_size: int | None = field(static=True) 1ab
137 log_trans_prior: Float32[Array, ' num_trees'] | None 1ab
138 log_likelihood: Float32[Array, ' num_trees'] | None 1ab
139 grow_prop_count: Int32[Array, ''] 1ab
140 prune_prop_count: Int32[Array, ''] 1ab
141 grow_acc_count: Int32[Array, ''] 1ab
142 prune_acc_count: Int32[Array, ''] 1ab
143 sigma_mu2: Float32[Array, ''] 1ab
144 log_s: Float32[Array, ' p'] | None 1ab
145 theta: Float32[Array, ''] | None 1ab
146 a: Float32[Array, ''] | None 1ab
147 b: Float32[Array, ''] | None 1ab
148 rho: Float32[Array, ''] | None 1ab
151class State(Module): 1ab
152 """
153 Represents the MCMC state of BART.
155 Parameters
156 ----------
157 X
158 The predictors.
159 y
160 The response. If the data type is `bool`, the model is binary regression.
161 resid
162 The residuals (`y` or `z` minus sum of trees).
163 z
164 The latent variable for binary regression. `None` in continuous
165 regression.
166 offset
167 Constant shift added to the sum of trees.
168 sigma2
169 The error variance. `None` in binary regression.
170 prec_scale
171 The scale on the error precision, i.e., ``1 / error_scale ** 2``.
172 `None` in binary regression.
173 sigma2_alpha
174 sigma2_beta
175 The shape and scale parameters of the inverse gamma prior on the noise
176 variance. `None` in binary regression.
177 forest
178 The sum of trees model.
179 """
181 X: UInt[Array, 'p n'] 1ab
182 y: Float32[Array, ' n'] | Bool[Array, ' n'] 1ab
183 z: None | Float32[Array, ' n'] 1ab
184 offset: Float32[Array, ''] 1ab
185 resid: Float32[Array, ' n'] 1ab
186 sigma2: Float32[Array, ''] | None 1ab
187 prec_scale: Float32[Array, ' n'] | None 1ab
188 sigma2_alpha: Float32[Array, ''] | None 1ab
189 sigma2_beta: Float32[Array, ''] | None 1ab
190 forest: Forest 1ab
193def init( 1ab
194 *,
195 X: UInt[Any, 'p n'],
196 y: Float32[Any, ' n'] | Bool[Any, ' n'],
197 offset: float | Float32[Any, ''] = 0.0,
198 max_split: UInt[Any, ' p'],
199 num_trees: int,
200 p_nonterminal: Float32[Any, ' d-1'],
201 sigma_mu2: float | Float32[Any, ''],
202 sigma2_alpha: float | Float32[Any, ''] | None = None,
203 sigma2_beta: float | Float32[Any, ''] | None = None,
204 error_scale: Float32[Any, ' n'] | None = None,
205 min_points_per_decision_node: int | Integer[Any, ''] | None = None,
206 resid_batch_size: int | None | Literal['auto'] = 'auto',
207 count_batch_size: int | None | Literal['auto'] = 'auto',
208 save_ratios: bool = False,
209 filter_splitless_vars: bool = True,
210 min_points_per_leaf: int | Integer[Any, ''] | None = None,
211 log_s: Float32[Any, ' p'] | None = None,
212 theta: float | Float32[Any, ''] | None = None,
213 a: float | Float32[Any, ''] | None = None,
214 b: float | Float32[Any, ''] | None = None,
215 rho: float | Float32[Any, ''] | None = None,
216) -> State:
217 """
218 Make a BART posterior sampling MCMC initial state.
220 Parameters
221 ----------
222 X
223 The predictors. Note this is transposed compared to the usual convention.
224 y
225 The response. If the data type is `bool`, the model is binary probit
226 regression.
227 offset
228 Constant shift added to the sum of trees. 0 if not specified.
229 max_split
230 The maximum split index for each variable. All split ranges start at 1.
231 num_trees
232 The number of trees in the forest.
233 p_nonterminal
234 The probability of a nonterminal node at each depth. The maximum depth
235 of trees is fixed by the length of this array.
236 sigma_mu2
237 The prior variance of a leaf, conditional on the tree structure. The
238 prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
239 prior mean of leaves is always zero.
240 sigma2_alpha
241 sigma2_beta
242 The shape and scale parameters of the inverse gamma prior on the error
243 variance. Leave unspecified for binary regression.
244 error_scale
245 Each error is scaled by the corresponding factor in `error_scale`, so
246 the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
247 Not supported for binary regression. If not specified, defaults to 1 for
248 all points, which also allows some calculations to be skipped.
249 min_points_per_decision_node
250 The minimum number of data points in a decision node. 0 if not
251 specified.
252 resid_batch_size
253 count_batch_size
254 The batch sizes, along datapoints, for summing the residuals and
255 counting the number of datapoints in each leaf. `None` for no batching.
256 If 'auto', pick a value based on the device of `y`, or the default
257 device.
258 save_ratios
259 Whether to save the Metropolis-Hastings ratios.
260 filter_splitless_vars
261 Whether to check `max_split` for variables without available cutpoints.
262 If any are found, they are put into a list of variables to exclude from
263 the MCMC. If `False`, no check is performed, but the results may be
264 wrong if any variable is blocked. The function is jax-traceable only
265 if this is set to `False`.
266 min_points_per_leaf
267 The minimum number of datapoints in a leaf node. 0 if not specified.
268 Unlike `min_points_per_decision_node`, this constraint is not taken into
269 account in the Metropolis-Hastings ratio because it would be expensive
270 to compute. Grow moves that would violate this constraint are vetoed.
271 This parameter is independent of `min_points_per_decision_node` and
272 there is no check that they are coherent. It makes sense to set
273 ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
274 log_s
275 The logarithm of the prior probability for choosing a variable to split
276 along in a decision rule, conditional on the ancestors. Not normalized.
277 If not specified, a uniform distribution is used. If not specified
278 while `theta` (or `rho`, `a`, `b`) is, it's initialized automatically.
279 theta
280 The concentration parameter for the Dirichlet prior on `s`. Required
281 only to update `log_s`. If not specified, and `rho`, `a`, `b` are
282 specified, it's initialized automatically.
283 a
284 b
285 rho
286 Parameters of the prior on `theta`. Required only to sample `theta`.
288 Returns
289 -------
290 An initialized BART MCMC state.
292 Raises
293 ------
294 ValueError
295 If `y` is boolean and arguments unused in binary regression are set.
297 Notes
298 -----
299 In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
300 of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
301 child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
302 integers in the range ``[0, 1, ..., max_split[i]]``.
303 """
304 p_nonterminal = jnp.asarray(p_nonterminal) 1ab
305 p_nonterminal = jnp.pad(p_nonterminal, (0, 1)) 1ab
306 max_depth = p_nonterminal.size 1ab
308 @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees) 1ab
309 def make_forest(max_depth, dtype): 1ab
310 return grove.make_tree(max_depth, dtype) 1ab
312 y = jnp.asarray(y) 1ab
313 offset = jnp.asarray(offset) 1ab
315 resid_batch_size, count_batch_size = _choose_suffstat_batch_size( 1ab
316 resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
317 )
319 is_binary = y.dtype == bool 1ab
320 if is_binary: 1ab
321 if (error_scale, sigma2_alpha, sigma2_beta) != 3 * (None,):  321 ↛ 322: line 321 didn't jump to line 322 because the condition on line 321 was never true  1ab
322 msg = (
323 'error_scale, sigma2_alpha, and sigma2_beta must be set '
324 'to `None` for binary regression.'
325 )
326 raise ValueError(msg)
327 sigma2 = None 1ab
328 else:
329 sigma2_alpha = jnp.asarray(sigma2_alpha) 1ab
330 sigma2_beta = jnp.asarray(sigma2_beta) 1ab
331 sigma2 = sigma2_beta / sigma2_alpha 1ab
333 max_split = jnp.asarray(max_split) 1ab
335 if filter_splitless_vars: 1ab
336 (blocked_vars,) = jnp.nonzero(max_split == 0) 1ab
337 blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size)) 1ab
338 # see `fully_used_variables` for the type cast
339 else:
340 blocked_vars = None 1ab
342 # check and initialize sparsity parameters
343 if not _all_none_or_not_none(rho, a, b):  343 ↛ 344: line 343 didn't jump to line 344 because the condition on line 343 was never true  1ab
344 msg = 'rho, a, b are not either all `None` or all set'
345 raise ValueError(msg)
346 if theta is None and rho is not None: 1ab
347 theta = rho 1ab
348 if log_s is None and theta is not None: 1ab
349 log_s = jnp.zeros(max_split.size) 1ab
351 return State( 1ab
352 X=jnp.asarray(X),
353 y=y,
354 z=jnp.full(y.shape, offset) if is_binary else None,
355 offset=offset,
356 resid=jnp.zeros(y.shape) if is_binary else y - offset,
357 sigma2=sigma2,
358 prec_scale=(
359 None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
360 ),
361 sigma2_alpha=sigma2_alpha,
362 sigma2_beta=sigma2_beta,
363 forest=Forest(
364 leaf_tree=make_forest(max_depth, jnp.float32),
365 var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
366 split_tree=make_forest(max_depth - 1, max_split.dtype),
367 affluence_tree=(
368 make_forest(max_depth - 1, bool)
369 .at[:, 1]
370 .set(
371 True
372 if min_points_per_decision_node is None
373 else y.size >= min_points_per_decision_node
374 )
375 ),
376 blocked_vars=blocked_vars,
377 max_split=max_split,
378 grow_prop_count=jnp.zeros((), int),
379 grow_acc_count=jnp.zeros((), int),
380 prune_prop_count=jnp.zeros((), int),
381 prune_acc_count=jnp.zeros((), int),
382 p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
383 p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
384 leaf_indices=jnp.ones(
385 (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
386 ),
387 min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
388 min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
389 resid_batch_size=resid_batch_size,
390 count_batch_size=count_batch_size,
391 log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
392 log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
393 sigma_mu2=jnp.asarray(sigma_mu2),
394 log_s=_asarray_or_none(log_s),
395 theta=_asarray_or_none(theta),
396 rho=_asarray_or_none(rho),
397 a=_asarray_or_none(a),
398 b=_asarray_or_none(b),
399 ),
400 )
403def _all_none_or_not_none(*args): 1ab
404 is_none = [x is None for x in args] 1ab
405 return all(is_none) or not any(is_none) 1ab
408def _asarray_or_none(x): 1ab
409 if x is None: 1ab
410 return None 1ab
411 return jnp.asarray(x) 1ab
414def _choose_suffstat_batch_size( 1ab
415 resid_batch_size, count_batch_size, y, forest_size
416) -> tuple[int | None, ...]:
417 @cache 1ab
418 def get_platform(): 1ab
419 try: 1ab
420 device = y.devices().pop() 1ab
421 except jax.errors.ConcretizationTypeError: 1ab
422 device = jax.devices()[0] 1ab
423 platform = device.platform 1ab
424 if platform not in ('cpu', 'gpu'):  424 ↛ 425: line 424 didn't jump to line 425 because the condition on line 424 was never true  1ab
425 msg = f'Unknown platform: {platform}'
426 raise KeyError(msg)
427 return platform 1ab
429 if resid_batch_size == 'auto': 1ab
430 platform = get_platform() 1ab
431 n = max(1, y.size) 1ab
432 if platform == 'cpu':  432 ↛ 434: line 432 didn't jump to line 434 because the condition on line 432 was always true  1ab
433 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab
434 elif platform == 'gpu':
435 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
436 resid_batch_size = max(1, resid_batch_size) 1ab
438 if count_batch_size == 'auto': 1ab
439 platform = get_platform() 1ab
440 if platform == 'cpu':  440 ↛ 442: line 440 didn't jump to line 442 because the condition on line 440 was always true  1ab
441 count_batch_size = None 1ab
442 elif platform == 'gpu':
443 n = max(1, y.size)
444 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
445 # /4 is good on V100, /2 on L4/T4, still haven't tried A100
446 max_memory = 2**29
447 itemsize = 4
448 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
449 count_batch_size = max(count_batch_size, min_batch_size)
450 count_batch_size = max(1, count_batch_size)
452 return resid_batch_size, count_batch_size 1ab
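# Worked example of the heuristics above (illustrative numbers, not from the
# bartz docs): on CPU with n = 10_000 datapoints,
# resid_batch_size = 2 ** round(log2(10_000 / 6)) = 2 ** round(10.7) = 2048
# (about n/6 rounded to a power of two), while count_batch_size stays `None`,
# i.e., the counts are computed without batching.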
455@jax.jit 1ab
456def step(key: Key[Array, ''], bart: State) -> State: 1ab
457 """
458 Do one MCMC step.
460 Parameters
461 ----------
462 key
463 A jax random key.
464 bart
465 A BART mcmc state, as created by `init`.
467 Returns
468 -------
469 The new BART mcmc state.
470 """
471 keys = split(key) 1ab
473 if bart.y.dtype == bool: # binary regression 1ab
474 bart = replace(bart, sigma2=jnp.float32(1)) 1ab
475 bart = step_trees(keys.pop(), bart) 1ab
476 bart = replace(bart, sigma2=None) 1ab
477 return step_z(keys.pop(), bart) 1ab
479 else: # continuous regression
480 bart = step_trees(keys.pop(), bart) 1ab
481 return step_sigma(keys.pop(), bart) 1ab
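# A minimal usage sketch of `init` and `step` (hypothetical data and
# hyperparameters, loosely BART-like; not taken from the bartz docs).
# Predictors are assumed pre-binned to integer codes in [0, max_split], and
# the maximum tree depth is len(p_nonterminal) + 1.
def _example_init_and_step():
    import jax
    from jax import numpy as jnp

    p, n, num_trees = 5, 100, 20
    key_x, key_mcmc = jax.random.split(jax.random.key(0))
    X = jax.random.randint(key_x, (p, n), 0, 256).astype(jnp.uint8)
    y = jnp.linspace(-1.0, 1.0, n)
    state = init(
        X=X,
        y=y,
        max_split=jnp.full(p, 255, jnp.uint8),
        num_trees=num_trees,
        p_nonterminal=0.95 / (1.0 + jnp.arange(5.0)) ** 2.0,  # BART-like decay
        sigma_mu2=1.0 / num_trees,
        sigma2_alpha=2.0,
        sigma2_beta=2.0,
    )
    # each full MCMC cycle consumes a fresh key and returns a new state
    for subkey in jax.random.split(key_mcmc, 10):
        state = step(subkey, state)
    return state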
484def step_trees(key: Key[Array, ''], bart: State) -> State: 1ab
485 """
486 Forest sampling step of BART MCMC.
488 Parameters
489 ----------
490 key
491 A jax random key.
492 bart
493 A BART mcmc state, as created by `init`.
495 Returns
496 -------
497 The new BART mcmc state.
499 Notes
500 -----
501 This function zeroes the proposal counters.
502 """
503 keys = split(key) 1ab
504 moves = propose_moves(keys.pop(), bart.forest) 1ab
505 return accept_moves_and_sample_leaves(keys.pop(), bart, moves) 1ab
508class Moves(Module): 1ab
509 """
510 Moves proposed to modify each tree.
512 Parameters
513 ----------
514 allowed
515 Whether there is a possible move. If `False`, the other values may not
516 make sense. The only case in which a move is marked as allowed but is
517 then vetoed is if it does not satisfy `min_points_per_leaf`, which for
518 efficiency is implemented post-hoc without changing the rest of the
519 MCMC logic.
520 grow
521 Whether the move is a grow move or a prune move.
522 num_growable
523 The number of growable leaves in the original tree.
524 node
525 The index of the leaf to grow or node to prune.
526 left
527 right
528 The indices of the children of 'node'.
529 partial_ratio
530 A factor of the Metropolis-Hastings ratio of the move. It lacks the
531 likelihood ratio, the probability of proposing the prune move, and the
532 probability that the children of the modified node are terminal. If the
533 move is PRUNE, the ratio is inverted. `None` once
534 `log_trans_prior_ratio` has been computed.
535 log_trans_prior_ratio
536 The logarithm of the product of the transition and prior terms of the
537 Metropolis-Hastings ratio for the acceptance of the proposed move.
538 `None` if not yet computed. If PRUNE, the log-ratio is negated.
539 grow_var
540 The decision axes of the new rules.
541 grow_split
542 The decision boundaries of the new rules.
543 var_tree
544 The updated decision axes of the trees, valid whatever move.
545 affluence_tree
546 A partially updated `affluence_tree`, marking non-leaf nodes that would
547 become leaves if the move were accepted. Initially (as returned by
548 `propose_moves`) this mark only accounts for whether the leaf would have
549 available decision rules to grow; whether the node holds enough
550 datapoints is marked later, in `accept_moves_parallel_stage`.
551 logu
552 The logarithm of a uniform (0, 1] random variable to be used to
553 accept the move. It's in (-oo, 0].
554 acc
555 Whether the move was accepted. `None` if not yet computed.
556 to_prune
557 Whether the final operation to apply the move is pruning. This indicates
558 an accepted prune move or a rejected grow move. `None` if not yet
559 computed.
560 """
562 allowed: Bool[Array, ' num_trees'] 1ab
563 grow: Bool[Array, ' num_trees'] 1ab
564 num_growable: UInt[Array, ' num_trees'] 1ab
565 node: UInt[Array, ' num_trees'] 1ab
566 left: UInt[Array, ' num_trees'] 1ab
567 right: UInt[Array, ' num_trees'] 1ab
568 partial_ratio: Float32[Array, ' num_trees'] | None 1ab
569 log_trans_prior_ratio: None | Float32[Array, ' num_trees'] 1ab
570 grow_var: UInt[Array, ' num_trees'] 1ab
571 grow_split: UInt[Array, ' num_trees'] 1ab
572 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab
573 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab
574 logu: Float32[Array, ' num_trees'] 1ab
575 acc: None | Bool[Array, ' num_trees'] 1ab
576 to_prune: None | Bool[Array, ' num_trees'] 1ab
579def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves: 1ab
580 """
581 Propose moves for all the trees.
583 There are two types of moves: GROW (convert a leaf to a decision node and
584 add two leaves beneath it) and PRUNE (convert the parent of two leaves to a
585 leaf, deleting its children).
587 Parameters
588 ----------
589 key
590 A jax random key.
591 forest
592 The `forest` field of a BART MCMC state.
594 Returns
595 -------
596 The proposed move for each tree.
597 """
598 num_trees, _ = forest.leaf_tree.shape 1ab
599 keys = split(key, 1 + 2 * num_trees) 1ab
601 # compute moves
602 grow_moves = propose_grow_moves( 1ab
603 keys.pop(num_trees),
604 forest.var_tree,
605 forest.split_tree,
606 forest.affluence_tree,
607 forest.max_split,
608 forest.blocked_vars,
609 forest.p_nonterminal,
610 forest.p_propose_grow,
611 forest.log_s,
612 )
613 prune_moves = propose_prune_moves( 1ab
614 keys.pop(num_trees),
615 forest.split_tree,
616 grow_moves.affluence_tree,
617 forest.p_nonterminal,
618 forest.p_propose_grow,
619 )
621 u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees)) 1ab
623 # choose between grow or prune
624 p_grow = jnp.where( 1ab
625 grow_moves.allowed & prune_moves.allowed, 0.5, grow_moves.allowed
626 )
627 grow = u < p_grow # use < instead of <= because u is in [0, 1) 1ab
629 # compute children indices
630 node = jnp.where(grow, grow_moves.node, prune_moves.node) 1ab
631 left = node << 1 1ab
632 right = left + 1 1ab
634 return Moves( 1ab
635 allowed=grow_moves.allowed | prune_moves.allowed,
636 grow=grow,
637 num_growable=grow_moves.num_growable,
638 node=node,
639 left=left,
640 right=right,
641 partial_ratio=jnp.where(
642 grow, grow_moves.partial_ratio, prune_moves.partial_ratio
643 ),
644 log_trans_prior_ratio=None, # will be set in complete_ratio
645 grow_var=grow_moves.var,
646 grow_split=grow_moves.split,
647 # var_tree does not need to be updated if prune
648 var_tree=grow_moves.var_tree,
649 # affluence_tree is updated for both moves unconditionally, prune last
650 affluence_tree=prune_moves.affluence_tree,
651 logu=jnp.log1p(-exp1mlogu),
652 acc=None, # will be set in accept_moves_sequential_stage
653 to_prune=None, # will be set in accept_moves_sequential_stage
654 )
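# Worked example of the grow/prune choice above (illustrative values): with
# grow_moves.allowed = [True, True, False] and
# prune_moves.allowed = [True, False, False],
# p_grow = where([T, F, F], 0.5, [T, T, F]) = [0.5, 1.0, 0.0], so the first
# tree proposes GROW or PRUNE with equal probability, the second can only
# propose GROW, and the third has no allowed move at all.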
657class GrowMoves(Module): 1ab
658 """
659 Represent a proposed grow move for each tree.
661 Parameters
662 ----------
663 allowed
664 Whether the move is allowed for proposal.
665 num_growable
666 The number of leaves that can be proposed for grow.
667 node
668 The index of the leaf to grow. ``2 ** d`` if there are no growable
669 leaves.
670 var
671 split
672 The decision axis and boundary of the new rule.
673 partial_ratio
674 A factor of the Metropolis-Hastings ratio of the move. It lacks
675 the likelihood ratio and the probability of proposing the prune
676 move.
677 var_tree
678 The updated decision axes of the tree.
679 affluence_tree
680 A partially updated `affluence_tree` that marks each new leaf that
681 would be produced as `True` if it would have available decision rules.
682 """
684 allowed: Bool[Array, ' num_trees'] 1ab
685 num_growable: UInt[Array, ' num_trees'] 1ab
686 node: UInt[Array, ' num_trees'] 1ab
687 var: UInt[Array, ' num_trees'] 1ab
688 split: UInt[Array, ' num_trees'] 1ab
689 partial_ratio: Float32[Array, ' num_trees'] 1ab
690 var_tree: UInt[Array, 'num_trees 2**(d-1)'] 1ab
691 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab
694@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None)) 1ab
695def propose_grow_moves( 1ab
696 key: Key[Array, ' num_trees'],
697 var_tree: UInt[Array, 'num_trees 2**(d-1)'],
698 split_tree: UInt[Array, 'num_trees 2**(d-1)'],
699 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
700 max_split: UInt[Array, ' p'],
701 blocked_vars: Int32[Array, ' k'] | None,
702 p_nonterminal: Float32[Array, ' 2**d'],
703 p_propose_grow: Float32[Array, ' 2**(d-1)'],
704 log_s: Float32[Array, ' p'] | None,
705) -> GrowMoves:
706 """
707 Propose a GROW move for each tree.
709 A GROW move picks a leaf node and converts it to a non-terminal node with
710 two leaf children.
712 Parameters
713 ----------
714 key
715 A jax random key.
716 var_tree
717 The splitting axes of the tree.
718 split_tree
719 The splitting points of the tree.
720 affluence_tree
721 Whether each leaf has enough points to be grown.
722 max_split
723 The maximum split index for each variable.
724 blocked_vars
725 The indices of the variables that have no available cutpoints.
726 p_nonterminal
727 The a priori probability of a node to be nonterminal conditional on the
728 ancestors, including at the maximum depth where it should be zero.
729 p_propose_grow
730 The unnormalized probability of choosing a leaf to grow.
731 log_s
732 Unnormalized log-probability used to choose a variable to split on
733 amongst the available ones.
735 Returns
736 -------
737 An object representing the proposed move.
739 Notes
740 -----
741 The move is not proposed if every leaf is already at maximum depth, has
742 fewer datapoints than the requested threshold `min_points_per_decision_node`,
743 or has no available decision rules given its ancestors. This case is
744 marked by setting `allowed` to `False` and `num_growable` to 0.
745 """
746 keys = split(key, 3) 1ab
748 leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf( 1ab
749 keys.pop(), split_tree, affluence_tree, p_propose_grow
750 )
752 # sample a decision rule
753 var, num_available_var = choose_variable( 1ab
754 keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
755 )
756 split_idx, l, r = choose_split( 1ab
757 keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
758 )
760 # determine if the new leaves would have available decision rules; if the
761 # move is blocked, these values may not make sense
762 left_growable = right_growable = num_available_var > 1 1ab
763 left_growable |= l < split_idx 1ab
764 right_growable |= split_idx + 1 < r 1ab
765 left = leaf_to_grow << 1 1ab
766 right = left + 1 1ab
767 affluence_tree = affluence_tree.at[left].set(left_growable) 1ab
768 affluence_tree = affluence_tree.at[right].set(right_growable) 1ab
770 ratio = compute_partial_ratio( 1ab
771 prob_choose, num_prunable, p_nonterminal, leaf_to_grow
772 )
774 return GrowMoves( 1ab
775 allowed=num_growable > 0,
776 num_growable=num_growable,
777 node=leaf_to_grow,
778 var=var,
779 split=split_idx,
780 partial_ratio=ratio,
781 var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
782 affluence_tree=affluence_tree,
783 )
786def choose_leaf( 1ab
787 key: Key[Array, ''],
788 split_tree: UInt[Array, ' 2**(d-1)'],
789 affluence_tree: Bool[Array, ' 2**(d-1)'],
790 p_propose_grow: Float32[Array, ' 2**(d-1)'],
791) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
792 """
793 Choose a leaf node to grow in a tree.
795 Parameters
796 ----------
797 key
798 A jax random key.
799 split_tree
800 The splitting points of the tree.
801 affluence_tree
802 Whether a leaf has enough points that it could be split into two leaves
803 satisfying the `min_points_per_leaf` requirement.
804 p_propose_grow
805 The unnormalized probability of choosing a leaf to grow.
807 Returns
808 -------
809 leaf_to_grow : Int32[Array, '']
810 The index of the leaf to grow. If ``num_growable == 0``, return
811 ``2 ** d``.
812 num_growable : Int32[Array, '']
813 The number of leaf nodes that can be grown, i.e., leaves not at
814 maximum depth with at least twice `min_points_per_leaf` datapoints.
815 prob_choose : Float32[Array, '']
816 The (normalized) probability that this function had to choose that
817 specific leaf, given the arguments.
818 num_prunable : Int32[Array, '']
819 The number of leaf parents that could be pruned, after converting the
820 selected leaf to a non-terminal node.
821 """
822 is_growable = growable_leaves(split_tree, affluence_tree) 1ab
823 num_growable = jnp.count_nonzero(is_growable) 1ab
824 distr = jnp.where(is_growable, p_propose_grow, 0) 1ab
825 leaf_to_grow, distr_norm = categorical(key, distr) 1ab
826 leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size) 1ab
827 prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1) 1ab
828 is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1)) 1ab
829 num_prunable = jnp.count_nonzero(is_parent) 1ab
830 return leaf_to_grow, num_growable, prob_choose, num_prunable 1ab
833def growable_leaves( 1ab
834 split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
835) -> Bool[Array, ' 2**(d-1)']:
836 """
837 Return a mask indicating the leaf nodes that can be proposed for growth.
839 The condition is that a leaf is not at the bottom level, has available
840 decision rules given its ancestors, and has at least
841 `min_points_per_decision_node` points.
843 Parameters
844 ----------
845 split_tree
846 The splitting points of the tree.
847 affluence_tree
848 Marks leaves that can be grown.
850 Returns
851 -------
852 The mask indicating the leaf nodes that can be proposed to grow.
854 Notes
855 -----
856 This function needs `split_tree` and not just `affluence_tree` because
857 `affluence_tree` can be "dirty", i.e., mark unused nodes as `True`.
858 """
859 return grove.is_actual_leaf(split_tree) & affluence_tree 1ab
862def categorical( 1ab
863 key: Key[Array, ''], distr: Float32[Array, ' n']
864) -> tuple[Int32[Array, ''], Float32[Array, '']]:
865 """
866 Return a random integer from an arbitrary distribution.
868 Parameters
869 ----------
870 key
871 A jax random key.
872 distr
873 An unnormalized probability distribution.
875 Returns
876 -------
877 u : Int32[Array, '']
878 A random integer in the range ``[0, n)``. If all probabilities are zero,
879 return ``n``.
880 norm : Float32[Array, '']
881 The sum of `distr`.
883 Notes
884 -----
885 This function uses a cumsum instead of the Gumbel trick, so it is suited
886 only to small ranges with probabilities well above zero.
887 """
888 ecdf = jnp.cumsum(distr) 1ab
889 u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1]) 1ab
890 return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1] 1ab
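# Self-contained sketch of how `categorical` behaves (illustrative, not part
# of the original module): indices are drawn proportionally to the
# unnormalized weights, and an all-zero distribution yields the out-of-range
# index n, because then u == 0 and searchsorted(..., 'right') lands past the
# end of the cumsum.
def _example_categorical():
    import jax
    from jax import numpy as jnp

    key = jax.random.key(0)
    idx, norm = categorical(key, jnp.array([0.2, 0.0, 0.8]))
    # idx is 0 with probability 0.2 and 2 with probability 0.8; norm == 1.0
    idx_zero, _ = categorical(key, jnp.zeros(3))  # idx_zero == 3
    return idx, norm, idx_zero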
893def choose_variable( 1ab
894 key: Key[Array, ''],
895 var_tree: UInt[Array, ' 2**(d-1)'],
896 split_tree: UInt[Array, ' 2**(d-1)'],
897 max_split: UInt[Array, ' p'],
898 leaf_index: Int32[Array, ''],
899 blocked_vars: Int32[Array, ' k'] | None,
900 log_s: Float32[Array, ' p'] | None,
901) -> tuple[Int32[Array, ''], Int32[Array, '']]:
902 """
903 Choose a variable to split on for a new non-terminal node.
905 Parameters
906 ----------
907 key
908 A jax random key.
909 var_tree
910 The variable indices of the tree.
911 split_tree
912 The splitting points of the tree.
913 max_split
914 The maximum split index for each variable.
915 leaf_index
916 The index of the leaf to grow.
917 blocked_vars
918 The indices of the variables that have no available cutpoints. If
919 `None`, all variables are assumed unblocked.
920 log_s
921 The logarithm of the prior probability for choosing a variable. If
922 `None`, use a uniform distribution.
924 Returns
925 -------
926 var : Int32[Array, '']
927 The index of the variable to split on.
928 num_available_var : Int32[Array, '']
929 The number of variables with available decision rules, out of which
930 `var` was chosen.
931 """
932 var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index) 1ab
933 if blocked_vars is not None: 1ab
934 var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars]) 1ab
936 if log_s is None: 1ab
937 return randint_exclude(key, max_split.size, var_to_ignore) 1ab
938 else:
939 return categorical_exclude(key, log_s, var_to_ignore) 1ab
942def fully_used_variables( 1ab
943 var_tree: UInt[Array, ' 2**(d-1)'],
944 split_tree: UInt[Array, ' 2**(d-1)'],
945 max_split: UInt[Array, ' p'],
946 leaf_index: Int32[Array, ''],
947) -> UInt[Array, ' d-2']:
948 """
949 Find variables in the ancestors of a node that have an empty split range.
951 Parameters
952 ----------
953 var_tree
954 The variable indices of the tree.
955 split_tree
956 The splitting points of the tree.
957 max_split
958 The maximum split index for each variable.
959 leaf_index
960 The index of the node, assumed to be valid for `var_tree`.
962 Returns
963 -------
964 The indices of the variables that have an empty split range.
966 Notes
967 -----
968 The number of such variables is not known in advance. Unused slots in the
969 returned array are filled with `p`. The fill values are not guaranteed to be
970 placed in any particular order, and variables may appear more than once.
971 """
972 var_to_ignore = ancestor_variables(var_tree, max_split, leaf_index) 1ab
973 split_range_vec = jax.vmap(split_range, in_axes=(None, None, None, None, 0)) 1ab
974 l, r = split_range_vec(var_tree, split_tree, max_split, leaf_index, var_to_ignore) 1ab
975 num_split = r - l 1ab
976 return jnp.where(num_split == 0, var_to_ignore, max_split.size) 1ab
977 # the type of var_to_ignore is already sufficient to hold max_split.size,
978 # see ancestor_variables()
981def ancestor_variables( 1ab
982 var_tree: UInt[Array, ' 2**(d-1)'],
983 max_split: UInt[Array, ' p'],
984 node_index: Int32[Array, ''],
985) -> UInt[Array, ' d-2']:
986 """
987 Return the list of variables in the ancestors of a node.
989 Parameters
990 ----------
991 var_tree
992 The variable indices of the tree.
993 max_split
994 The maximum split index for each variable. Used only to get `p`.
995 node_index
996 The index of the node, assumed to be valid for `var_tree`.
998 Returns
999 -------
1000 The variable indices of the ancestors of the node.
1002 Notes
1003 -----
1004 The ancestors are the nodes going from the root to the parent of the node.
1005 The number of ancestors is not known at tracing time; unused spots in the
1006 output array are filled with `p`.
1007 """
1008 max_num_ancestors = grove.tree_depth(var_tree) - 1 1ab
1009 ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size)) 1ab
1010 carry = ancestor_vars.size - 1, node_index, ancestor_vars 1ab
1012 def loop(carry, _): 1ab
1013 i, index, ancestor_vars = carry 1ab
1014 index >>= 1 1ab
1015 var = var_tree[index] 1ab
1016 var = jnp.where(index, var, max_split.size) 1ab
1017 ancestor_vars = ancestor_vars.at[i].set(var) 1ab
1018 return (i - 1, index, ancestor_vars), None 1ab
1020 (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size) 1ab
1021 return ancestor_vars 1ab
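# Heap-indexing note (illustrative): nodes are stored in heap order, so node
# i has parent i >> 1 and the root is node 1. The scan above walks e.g. node
# 11 (binary 1011) through its ancestors 5, 2, 1, and once the index reaches
# 0 it fills the remaining output slots with p.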
1024def split_range( 1ab
1025 var_tree: UInt[Array, ' 2**(d-1)'],
1026 split_tree: UInt[Array, ' 2**(d-1)'],
1027 max_split: UInt[Array, ' p'],
1028 node_index: Int32[Array, ''],
1029 ref_var: Int32[Array, ''],
1030) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1031 """
1032 Return the range of allowed splits for a variable at a given node.
1034 Parameters
1035 ----------
1036 var_tree
1037 The variable indices of the tree.
1038 split_tree
1039 The splitting points of the tree.
1040 max_split
1041 The maximum split index for each variable.
1042 node_index
1043 The index of the node, assumed to be valid for `var_tree`.
1044 ref_var
1045 The variable for which to measure the split range.
1047 Returns
1048 -------
1049 The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
1050 """
1051 max_num_ancestors = grove.tree_depth(var_tree) - 1 1ab
1052 initial_r = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype( 1ab
1053 jnp.int32
1054 )
1055 carry = jnp.int32(0), initial_r, node_index 1ab
1057 def loop(carry, _): 1ab
1058 l, r, index = carry 1ab
1059 right_child = (index & 1).astype(bool) 1ab
1060 index >>= 1 1ab
1061 split = split_tree[index] 1ab
1062 cond = (var_tree[index] == ref_var) & index.astype(bool) 1ab
1063 l = jnp.where(cond & right_child, jnp.maximum(l, split), l) 1ab
1064 r = jnp.where(cond & ~right_child, jnp.minimum(r, split), r) 1ab
1065 return (l, r, index), None 1ab
1067 (l, r, _), _ = lax.scan(loop, carry, None, max_num_ancestors) 1ab
1068 return l + 1, r 1ab
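# Worked example (illustrative): take max_split[v] == 10 and a node reached
# by descending right of an ancestor rule with split 4 on v and then left of
# one with split 8 on v. The scan yields l = max(0, 4) = 4 and
# r = min(11, 8) = 8, so the returned cutpoint range is [5, 8): any such
# cutpoint keeps both children nonempty for points with X[v] in [4, 7].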
1071def randint_exclude( 1ab
1072 key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
1073) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1074 """
1075 Return a random integer in a range, excluding some values.
1077 Parameters
1078 ----------
1079 key
1080 A jax random key.
1081 sup
1082 The exclusive upper bound of the range.
1083 exclude
1084 The values to exclude from the range. Values greater than or equal to
1085 `sup` are ignored. Values can appear more than once.
1087 Returns
1088 -------
1089 u : Int32[Array, '']
1090 A random integer `u` in the range ``[0, sup)`` such that ``u not in
1091 exclude``.
1092 num_allowed : Int32[Array, '']
1093 The number of integers in the range that were not excluded.
1095 Notes
1096 -----
1097 If all values in the range are excluded, return `sup`.
1098 """
1099 exclude, num_allowed = _process_exclude(sup, exclude) 1ab
1100 u = random.randint(key, (), 0, num_allowed) 1ab
1102 def loop(u, i_excluded): 1ab
1103 return jnp.where(i_excluded <= u, u + 1, u), None 1ab
1105 u, _ = lax.scan(loop, u, exclude) 1ab
1106 return u, num_allowed 1ab
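# Sketch of the exclusion logic above (illustrative, not part of the original
# module): with sup = 5 and exclude = [1, 3], num_allowed == 3 and the scan
# shifts a draw u in [0, 3) past the sorted excluded values: 0 -> 0, 1 -> 2,
# 2 -> 4.
def _example_randint_exclude():
    import jax
    from jax import numpy as jnp

    key = jax.random.key(0)
    u, num_allowed = randint_exclude(key, 5, jnp.array([1, 3]))
    # u is one of 0, 2, 4 with equal probability; num_allowed == 3
    return u, num_allowed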
1109def _process_exclude(sup, exclude): 1ab
1110 exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup) 1ab
1111 num_allowed = sup - jnp.count_nonzero(exclude < sup) 1ab
1112 return exclude, num_allowed 1ab
1115def categorical_exclude( 1ab
1116 key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
1117) -> tuple[Int32[Array, ''], Int32[Array, '']]:
1118 """
1119 Draw from a categorical distribution, excluding a set of values.
1121 Parameters
1122 ----------
1123 key
1124 A jax random key.
1125 logits
1126 The unnormalized log-probabilities of each category.
1127 exclude
1128 The values to exclude from the range [0, k). Values greater than or
1129 equal to `logits.size` are ignored. Values can appear more than once.
1131 Returns
1132 -------
1133 u : Int32[Array, '']
1134 A random integer in the range ``[0, k)`` such that ``u not in exclude``.
1135 num_allowed : Int32[Array, '']
1136 The number of integers in the range that were not excluded.
1138 Notes
1139 -----
1140 If all values in the range are excluded, the result is unspecified.
1141 """
1142 exclude, num_allowed = _process_exclude(logits.size, exclude) 1ab
1143 kinda_neg_inf = jnp.finfo(logits.dtype).min 1ab
1144 logits = logits.at[exclude].set(kinda_neg_inf) 1ab
1145 u = random.categorical(key, logits) 1ab
1146 return u, num_allowed 1ab
1149def choose_split( 1ab
1150 key: Key[Array, ''],
1151 var: Int32[Array, ''],
1152 var_tree: UInt[Array, ' 2**(d-1)'],
1153 split_tree: UInt[Array, ' 2**(d-1)'],
1154 max_split: UInt[Array, ' p'],
1155 leaf_index: Int32[Array, ''],
1156) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
1157 """
1158 Choose a split point for a new non-terminal node.
1160 Parameters
1161 ----------
1162 key
1163 A jax random key.
1164 var
1165 The variable to split on.
1166 var_tree
1167 The splitting axes of the tree. Does not need to already contain `var`
1168 at `leaf_index`.
1169 split_tree
1170 The splitting points of the tree.
1171 max_split
1172 The maximum split index for each variable.
1173 leaf_index
1174 The index of the leaf to grow.
1176 Returns
1177 -------
1178 split : Int32[Array, '']
1179 The cutpoint.
1180 l : Int32[Array, '']
1181 r : Int32[Array, '']
1182 The integer range `split` was drawn from is [l, r).
1184 Notes
1185 -----
1186 If `var` is out of bounds, or if the available split range on that variable
1187 is empty, return 0.
1188 """
1189 l, r = split_range(var_tree, split_tree, max_split, leaf_index, var) 1ab
1190 return jnp.where(l < r, random.randint(key, (), l, r), 0), l, r 1ab
1193def compute_partial_ratio( 1ab
1194 prob_choose: Float32[Array, ''],
1195 num_prunable: Int32[Array, ''],
1196 p_nonterminal: Float32[Array, ' 2**d'],
1197 leaf_to_grow: Int32[Array, ''],
1198) -> Float32[Array, '']:
1199 """
1200 Compute the product of the transition and prior ratios of a grow move.
1202 Parameters
1203 ----------
1204 prob_choose
1205 The probability that the leaf had to be chosen amongst the growable
1206 leaves.
1207 num_prunable
1208 The number of leaf parents that could be pruned, after converting the
1209 leaf to be grown to a non-terminal node.
1210 p_nonterminal
1211 The a priori probability of each node being nonterminal conditional on
1212 its ancestors.
1213 leaf_to_grow
1214 The index of the leaf to grow.
1216 Returns
1217 -------
1218 The partial transition ratio times the prior ratio.
1220 Notes
1221 -----
1222 The transition ratio is P(new tree => old tree) / P(old tree => new tree).
1223 The "partial" transition ratio returned is missing the factor P(propose
1224 prune) in the numerator. The prior ratio is P(new tree) / P(old tree). The
1225 "partial" prior ratio is missing the factor P(children are leaves).
1226 """
1227 # the two ratios also contain factors num_available_split *
1228 # num_available_var * s[var], but they cancel out
1230 # p_prune and 1 - p_nonterminal[child] * I(is the child growable) can't be
1231 # computed here because they need the count trees, which are computed in the
1232 # acceptance phase
1234 prune_allowed = leaf_to_grow != 1 1ab
1235 # prune allowed <---> the initial tree is not a root
1236 # leaf to grow is root --> the tree can only be a root
1237 # tree is a root --> the only leaf I can grow is root
1238 p_grow = jnp.where(prune_allowed, 0.5, 1) 1ab
1239 inv_trans_ratio = p_grow * prob_choose * num_prunable 1ab
1241 # .at.get because if leaf_to_grow is out of bounds (move not allowed), this
1242 # would produce a 0 and then an inf when `complete_ratio` takes the log
1243 pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5) 1ab
1244 tree_ratio = pnt / (1 - pnt) 1ab
1246 return tree_ratio / jnp.where(inv_trans_ratio, inv_trans_ratio, 1) 1ab
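# Schematic of the full GROW Metropolis-Hastings ratio this factor belongs to
# (assembled later by `complete_ratio` and the sequential acceptance stage;
# notation follows the docstring above):
#
#   MH = p_prune_new / (p_grow * prob_choose * num_prunable)   # transition
#        * pnt / (1 - pnt) * P(children stay leaves)           # prior
#        * likelihood ratio
#
# This function returns only pnt / (1 - pnt) divided by
# p_grow * prob_choose * num_prunable; the remaining factors need the count
# trees computed in the acceptance phase.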
1249class PruneMoves(Module): 1ab
1250 """
1251 Represent a proposed prune move for each tree.
1253 Parameters
1254 ----------
1255 allowed
1256 Whether the move is possible.
1257 node
1258 The index of the node to prune. ``2 ** d`` if no node can be pruned.
1259 partial_ratio
1260 A factor of the Metropolis-Hastings ratio of the move. It lacks the
1261 likelihood ratio, the probability of proposing the prune move, and the
1262 prior probability that the children of the node to prune are leaves.
1263 This ratio is inverted, and is meant to be inverted back in
1264 `accept_moves_and_sample_leaves`.
1265 """
1267 allowed: Bool[Array, ' num_trees'] 1ab
1268 node: UInt[Array, ' num_trees'] 1ab
1269 partial_ratio: Float32[Array, ' num_trees'] 1ab
1270 affluence_tree: Bool[Array, 'num_trees 2**(d-1)'] 1ab
1273@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None)) 1ab
1274def propose_prune_moves( 1ab
1275 key: Key[Array, ''],
1276 split_tree: UInt[Array, ' 2**(d-1)'],
1277 affluence_tree: Bool[Array, ' 2**(d-1)'],
1278 p_nonterminal: Float32[Array, ' 2**d'],
1279 p_propose_grow: Float32[Array, ' 2**(d-1)'],
1280) -> PruneMoves:
1281 """
1282 Tree structure prune move proposal of BART MCMC.
1284 Parameters
1285 ----------
1286 key
1287 A jax random key.
1288 split_tree
1289 The splitting points of the tree.
1290 affluence_tree
1291 Whether each leaf can be grown.
1292 p_nonterminal
1293 The a priori probability of a node to be nonterminal conditional on
1294 the ancestors, including at the maximum depth where it should be zero.
1295 p_propose_grow
1296 The unnormalized probability of choosing a leaf to grow.
1298 Returns
1299 -------
1300 An object representing the proposed moves.
1301 """
1302 node_to_prune, num_prunable, prob_choose, affluence_tree = choose_leaf_parent( 1ab
1303 key, split_tree, affluence_tree, p_propose_grow
1304 )
1306 ratio = compute_partial_ratio( 1ab
1307 prob_choose, num_prunable, p_nonterminal, node_to_prune
1308 )
1310 return PruneMoves( 1ab
1311 allowed=split_tree[1].astype(bool), # allowed iff the tree is not a root
1312 node=node_to_prune,
1313 partial_ratio=ratio,
1314 affluence_tree=affluence_tree,
1315 )
1318def choose_leaf_parent( 1ab
1319 key: Key[Array, ''],
1320 split_tree: UInt[Array, ' 2**(d-1)'],
1321 affluence_tree: Bool[Array, ' 2**(d-1)'],
1322 p_propose_grow: Float32[Array, ' 2**(d-1)'],
1323) -> tuple[
1324 Int32[Array, ''],
1325 Int32[Array, ''],
1326 Float32[Array, ''],
1327 Bool[Array, 'num_trees 2**(d-1)'],
1328]:
1329 """
1330 Pick a non-terminal node with leaf children to prune in a tree.
1332 Parameters
1333 ----------
1334 key
1335 A jax random key.
1336 split_tree
1337 The splitting points of the tree.
1338 affluence_tree
1339 Whether a leaf has enough points to be grown.
1340 p_propose_grow
1341 The unnormalized probability of choosing a leaf to grow.
1343 Returns
1344 -------
1345 node_to_prune : Int32[Array, '']
1346 The index of the node to prune. If ``num_prunable == 0``, return
1347 ``2 ** d``.
1348 num_prunable : Int32[Array, '']
1349 The number of leaf parents that could be pruned.
1350 prob_choose : Float32[Array, '']
1351 The (normalized) probability that `choose_leaf` would choose
1352 `node_to_prune` as leaf to grow, if passed the tree where
1353 `node_to_prune` had been pruned.
1354 affluence_tree : Bool[Array, 'num_trees 2**(d-1)']
1355 A partially updated `affluence_tree`, marking the node to prune as
1356 growable.
1357 """
1358 # sample a node to prune
1359 is_prunable = grove.is_leaves_parent(split_tree) 1ab
1360 num_prunable = jnp.count_nonzero(is_prunable) 1ab
1361 node_to_prune = randint_masked(key, is_prunable) 1ab
1362 node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size) 1ab
1364 # compute stuff for reverse move
1365 split_tree = split_tree.at[node_to_prune].set(0) 1ab
1366 affluence_tree = affluence_tree.at[node_to_prune].set(True) 1ab
1367 is_growable_leaf = growable_leaves(split_tree, affluence_tree) 1ab
1368 distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf) 1ab
1369 prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0) 1ab
1370 prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1) 1ab
1372 return node_to_prune, num_prunable, prob_choose, affluence_tree 1ab
1375def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']: 1ab
1376 """
1377 Return a random integer in a range, including only some values.
1379 Parameters
1380 ----------
1381 key
1382 A jax random key.
1383 mask
1384 The mask indicating the allowed values.
1386 Returns
1387 -------
1388 A random integer in the range ``[0, n)`` such that ``mask[u] == True``.
1390 Notes
1391 -----
1392 If all values in the mask are `False`, return `n`.
1393 """
1394 ecdf = jnp.cumsum(mask) 1ab
1395 u = random.randint(key, (), 0, ecdf[-1]) 1ab
1396 return jnp.searchsorted(ecdf, u, 'right') 1ab
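# Sketch of `randint_masked` (illustrative, not part of the original module):
# only positions where the mask is `True` can be drawn, with equal
# probability.
def _example_randint_masked():
    import jax
    from jax import numpy as jnp

    key = jax.random.key(0)
    u = randint_masked(key, jnp.array([False, True, False, True]))
    # u is 1 or 3 with equal probability; an all-False mask would return n
    return u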
1399def accept_moves_and_sample_leaves( 1ab
1400 key: Key[Array, ''], bart: State, moves: Moves
1401) -> State:
1402 """
1403 Accept or reject the proposed moves and sample the new leaf values.
1405 Parameters
1406 ----------
1407 key
1408 A jax random key.
1409 bart
1410 A valid BART mcmc state.
1411 moves
1412 The proposed moves, see `propose_moves`.
1414 Returns
1415 -------
1416 A new (valid) BART mcmc state.
1417 """
1418 pso = accept_moves_parallel_stage(key, bart, moves) 1ab
1419 bart, moves = accept_moves_sequential_stage(pso) 1ab
1420 return accept_moves_final_stage(bart, moves) 1ab
1423class Counts(Module): 1ab
1424 """
1425 Number of datapoints in the nodes involved in proposed moves for each tree.
1427 Parameters
1428 ----------
1429 left
1430 Number of datapoints in the left child.
1431 right
1432 Number of datapoints in the right child.
1433 total
1434 Number of datapoints in the parent (``= left + right``).
1435 """
1437 left: UInt[Array, ' num_trees'] 1ab
1438 right: UInt[Array, ' num_trees'] 1ab
1439 total: UInt[Array, ' num_trees'] 1ab
1442class Precs(Module): 1ab
1443 """
1444 Likelihood precision scale in the nodes involved in proposed moves for each tree.
1446 The "likelihood precision scale" of a tree node is the sum of the inverse
1447 squared error scales of the datapoints selected by the node.
1449 Parameters
1450 ----------
1451 left
1452 Likelihood precision scale in the left child.
1453 right
1454 Likelihood precision scale in the right child.
1455 total
1456 Likelihood precision scale in the parent (``= left + right``).
1457 """
1459 left: Float32[Array, ' num_trees'] 1ab
1460 right: Float32[Array, ' num_trees'] 1ab
1461 total: Float32[Array, ' num_trees'] 1ab
1464class PreLkV(Module): 1ab
1465 """
1466 Non-sequential terms of the likelihood ratio for each tree.
1468 These terms can be computed in parallel across trees.
1470 Supports both scalar and multivariate models. In the scalar case, variance
1471 terms are 1D arrays of shape (num_trees,); In the multivariate case, they are
1472 arrays of covariance matrices with shape (num_trees, k, k).
1474 Parameters
1475 ----------
1476 sigma2_left
1477 In the scalar case, this is the noise variance in the left child of the leaves
1478 grown or pruned by the moves.
1479 In the multivariate case, this is the intermediate matrix in the quadratic form
1480 representing the contribution of the left child to the exponential term.
1481 sigma2_right
1482 In the scalar case, this is the noise variance in the right child of the leaves
1483 grown or pruned by the moves.
1484 In the multivariate case, this is the intermediate matrix in the quadratic form
1485 representing the contribution of the right child to the exponential term.
1486 sigma2_total
1487 In the scalar case, this is the noise variance in the total of the leaves
1488 grown or pruned by the moves.
1489 In the multivariate case, this is the intermediate matrix in the quadratic form
1490 representing the contribution of the parent node to the exponential term.
1491 sqrt_term
1492 The **logarithm** of the square root term of the likelihood ratio.
1493 """
1495 sigma2_left: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k'] 1ab
1496 sigma2_right: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k'] 1ab
1497 sigma2_total: Float32[Array, ' num_trees'] | Float32[Array, 'num_trees k k'] 1ab
1498 sqrt_term: Float32[Array, ' num_trees'] 1ab
1501class PreLk(Module): 1ab
1502 """
1503 Non-sequential terms of the likelihood ratio shared by all trees.
1505 Parameters
1506 ----------
1507 exp_factor
1508 The factor to multiply the likelihood ratio by, shared by all trees.
1509 """
1511 exp_factor: Float32[Array, ''] 1ab
1514class PreLf(Module): 1ab
1515 """
1516 Pre-computed terms used to sample leaves from their posterior.
1518 These terms can be computed in parallel across trees.
1520 Supports both scalar and multivariate models. In the scalar case, the arrays
1521 have shape (num_trees, 2**d); in the multivariate case, mean_factor has shape
1522 (num_trees, 2**d, k, k) and centered_leaves has shape (num_trees, 2**d, k).
1524 Parameters
1525 ----------
1526 mean_factor
1527 The factor to be multiplied by the sum of the scaled residuals to
1528 obtain the posterior mean.
1529 centered_leaves
1530 The mean-zero normal values to be added to the posterior mean to
1531 obtain the posterior leaf samples.
1532 """
1534 mean_factor: Float32[Array, 'num_trees 2**d'] | Float32[Array, 'num_trees 2**d k k'] 1ab
1535 centered_leaves: ( 1ab
1536 Float32[Array, 'num_trees 2**d'] | Float32[Array, 'num_trees 2**d k']
1537 )
1540class ParallelStageOut(Module): 1ab
1541 """
1542 The output of `accept_moves_parallel_stage`.
1544 Parameters
1545 ----------
1546 bart
1547 A partially updated BART mcmc state.
1548 moves
1549 The proposed moves, with `partial_ratio` set to `None` and
1550 `log_trans_prior_ratio` set to its final value.
1551 prec_trees
1552 The likelihood precision scale in each potential or actual leaf node. If
1553 there is no precision scale, this is the number of points in each leaf.
1554 move_counts
1555 The counts of the number of points in the nodes modified by the
1556 moves. If `bart.min_points_per_leaf` is not set and
1557 `bart.prec_scale` is set, they are not computed.
1558 move_precs
1559 The likelihood precision scale in each node modified by the moves. If
1560 `bart.prec_scale` is not set, this is set to `move_counts`.
1561 prelkv
1562 prelk
1563 prelf
1564 Objects with pre-computed terms of the likelihood ratios and leaf
1565 samples.
1566 """
1568 bart: State 1ab
1569 moves: Moves 1ab
1570 prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d'] 1ab
1571 move_precs: Precs | Counts 1ab
1572 prelkv: PreLkV 1ab
1573 prelk: PreLk 1ab
1574 prelf: PreLf 1ab
1577def accept_moves_parallel_stage( 1ab
1578 key: Key[Array, ''], bart: State, moves: Moves
1579) -> ParallelStageOut:
1580 """
1581 Pre-compute quantities used to accept moves, in parallel across trees.
1583 Parameters
1584 ----------
1585 key
1586 A jax random key.
1587 bart
1588 A BART mcmc state.
1589 moves
1590 The proposed moves, see `propose_moves`.
1592 Returns
1593 -------
1594 An object with all that could be done in parallel.
1595 """
1596 # where the move is grow, modify the state like the move was accepted
1597 bart = replace( 1ab
1598 bart,
1599 forest=replace(
1600 bart.forest,
1601 var_tree=moves.var_tree,
1602 leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
1603 leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
1604 ),
1605 )
1607 # count number of datapoints per leaf
1608 if (  1608 ↛ 1618: line 1608 didn't jump to line 1618 because the condition on line 1608 was always true
1609 bart.forest.min_points_per_decision_node is not None
1610 or bart.forest.min_points_per_leaf is not None
1611 or bart.prec_scale is None
1612 ):
1613 count_trees, move_counts = compute_count_trees( 1ab
1614 bart.forest.leaf_indices, moves, bart.forest.count_batch_size
1615 )
1617 # mark which leaves & potential leaves have enough points to be grown
1618 if bart.forest.min_points_per_decision_node is not None: 1ab
1619 count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]] 1ab
1620 moves = replace( 1ab
1621 moves,
1622 affluence_tree=moves.affluence_tree
1623 & (count_half_trees >= bart.forest.min_points_per_decision_node),
1624 )
1626 # copy updated affluence_tree to state
1627 bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree) 1ab
1629 # veto grow move if new leaves don't have enough datapoints
1630 if bart.forest.min_points_per_leaf is not None: 1ab
1631 moves = replace( 1ab
1632 moves,
1633 allowed=moves.allowed
1634 & (move_counts.left >= bart.forest.min_points_per_leaf)
1635 & (move_counts.right >= bart.forest.min_points_per_leaf),
1636 )
1638 # count number of datapoints per leaf, weighted by error precision scale
1639 if bart.prec_scale is None: 1ab
1640 prec_trees = count_trees 1ab
1641 move_precs = move_counts 1ab
1642 else:
1643 prec_trees, move_precs = compute_prec_trees( 1ab
1644 bart.prec_scale,
1645 bart.forest.leaf_indices,
1646 moves,
1647 bart.forest.count_batch_size,
1648 )
1649 assert move_precs is not None 1ab
1651 # compute some missing information about moves
1652 moves = complete_ratio(moves, bart.forest.p_nonterminal) 1ab
1653 save_ratios = bart.forest.log_likelihood is not None 1ab
1654 bart = replace( 1ab
1655 bart,
1656 forest=replace(
1657 bart.forest,
1658 grow_prop_count=jnp.sum(moves.grow),
1659 prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
1660 log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
1661 ),
1662 )
1664 # pre-compute some likelihood ratio & posterior terms
1665 assert bart.sigma2 is not None # `step` shall temporarily set it to 1 1ab
1666 prelkv, prelk = precompute_likelihood_terms( 1ab
1667 bart.sigma2, bart.forest.sigma_mu2, move_precs
1668 )
1669 prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2) 1ab
1671 return ParallelStageOut( 1ab
1672 bart=bart,
1673 moves=moves,
1674 prec_trees=prec_trees,
1675 move_precs=move_precs,
1676 prelkv=prelkv,
1677 prelk=prelk,
1678 prelf=prelf,
1679 )
1682@partial(vmap_nodoc, in_axes=(0, 0, None)) 1ab
1683def apply_grow_to_indices( 1ab
1684 moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
1685) -> UInt[Array, 'num_trees n']:
1686 """
1687 Update the leaf indices to apply a grow move.
1689 Parameters
1690 ----------
1691 moves
1692 The proposed moves, see `propose_moves`.
1693 leaf_indices
1694 The index of the leaf each datapoint falls into.
1695 X
1696 The predictors matrix.
1698 Returns
1699 -------
1700 The updated leaf indices.
1701 """
1702 left_child = moves.node.astype(leaf_indices.dtype) << 1 1ab
1703 go_right = X[moves.grow_var, :] >= moves.grow_split 1ab
1704 tree_size = jnp.array(2 * moves.var_tree.size) 1ab
1705 node_to_update = jnp.where(moves.grow, moves.node, tree_size) 1ab
1706 return jnp.where( 1ab
1707 leaf_indices == node_to_update, left_child + go_right, leaf_indices
1708 )
1711def compute_count_trees( 1ab
1712 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
1713) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
1714 """
1715 Count the number of datapoints in each leaf.
1717 Parameters
1718 ----------
1719 leaf_indices
1720 The index of the leaf each datapoint falls into, with the deeper version
1721 of the tree (post-GROW, pre-PRUNE).
1722 moves
1723 The proposed moves, see `propose_moves`.
1724 batch_size
1725 The data batch size to use for the summation.
1727 Returns
1728 -------
1729 count_trees : Int32[Array, 'num_trees 2**d']
1730 The number of points in each potential or actual leaf node.
1731 counts : Counts
1732 The counts of the number of points in the leaves grown or pruned by the
1733 moves.
1734 """
1735 num_trees, tree_size = moves.var_tree.shape 1ab
1736 tree_size *= 2 1ab
1737 tree_indices = jnp.arange(num_trees) 1ab
1739 count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size) 1ab
1741 # count datapoints in nodes modified by move
1742 left = count_trees[tree_indices, moves.left] 1ab
1743 right = count_trees[tree_indices, moves.right] 1ab
1744 counts = Counts(left=left, right=right, total=left + right) 1ab
1746 # write count into non-leaf node
1747 count_trees = count_trees.at[tree_indices, moves.node].set(counts.total) 1ab
1749 return count_trees, counts 1ab
1752def count_datapoints_per_leaf( 1ab
1753 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
1754) -> Int32[Array, 'num_trees 2**(d-1)']:
1755 """
1756 Count the number of datapoints in each leaf.
1758 Parameters
1759 ----------
1760 leaf_indices
1761 The index of the leaf each datapoint falls into.
1762 tree_size
1763 The size of the leaf tree array (2 ** d).
1764 batch_size
1765 The data batch size to use for the summation.
1767 Returns
1768 -------
1769 The number of points in each leaf node.
1770 """
1771 if batch_size is None: 1ab
1772 return _count_scan(leaf_indices, tree_size) 1ab
1773 else:
1774 return _count_vec(leaf_indices, tree_size, batch_size) 1ab
1777def _count_scan( 1ab
1778 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
1779) -> Int32[Array, 'num_trees {tree_size}']:
1780 def loop(_, leaf_indices): 1ab
1781 return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32) 1ab
1783 _, count_trees = lax.scan(loop, None, leaf_indices) 1ab
1784 return count_trees 1ab
1787def _aggregate_scatter( 1ab
1788 values: Shaped[Array, '*'],
1789 indices: Integer[Array, '*'],
1790 size: int,
1791 dtype: jnp.dtype,
1792) -> Shaped[Array, ' {size}']:
1793 return jnp.zeros(size, dtype).at[indices].add(values) 1ab
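# Illustrative sketch (not part of the module): the scatter-add pattern used
# by _aggregate_scatter, with made-up leaf indices.
import jax.numpy as jnp

indices_toy = jnp.array([0, 2, 2, 3])
counts = jnp.zeros(4, jnp.uint32).at[indices_toy].add(1)
# counts == [1, 0, 2, 1]: one point in leaf 0, two in leaf 2, one in leaf 3.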
1796def _count_vec( 1ab
1797 leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
1798) -> Int32[Array, 'num_trees 2**(d-1)']:
1799 return _aggregate_batched_alltrees( 1ab
1800 1, leaf_indices, tree_size, jnp.uint32, batch_size
1801 )
1802 # uint16 is super-slow on gpu, don't use it even if n < 2^16
1805def _aggregate_batched_alltrees( 1ab
1806 values: Shaped[Array, '*'],
1807 indices: UInt[Array, 'num_trees n'],
1808 size: int,
1809 dtype: jnp.dtype,
1810 batch_size: int,
1811) -> Shaped[Array, 'num_trees {size}']:
1812 num_trees, n = indices.shape 1ab
1813 tree_indices = jnp.arange(num_trees) 1ab
1814 nbatches = n // batch_size + bool(n % batch_size) 1ab
1815 batch_indices = jnp.arange(n) % nbatches 1ab
1816 return ( 1ab
1817 jnp.zeros((num_trees, size, nbatches), dtype)
1818 .at[tree_indices[:, None], indices, batch_indices]
1819 .add(values)
1820 .sum(axis=2)
1821 )
1824def compute_prec_trees( 1ab
1825 prec_scale: Float32[Array, ' n'],
1826 leaf_indices: UInt[Array, 'num_trees n'],
1827 moves: Moves,
1828 batch_size: int | None,
1829) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
1830 """
1831 Compute the likelihood precision scale in each leaf.
1833 Parameters
1834 ----------
1835 prec_scale
1836 The scale of the precision of the error on each datapoint.
1837 leaf_indices
1838 The index of the leaf each datapoint falls into, with the deeper version
1839 of the tree (post-GROW, pre-PRUNE).
1840 moves
1841 The proposed moves, see `propose_moves`.
1842 batch_size
1843 The data batch size to use for the summation.
1845 Returns
1846 -------
1847 prec_trees : Float32[Array, 'num_trees 2**d']
1848 The likelihood precision scale in each potential or actual leaf node.
1849 precs : Precs
1850 The likelihood precision scale in the nodes involved in the moves.
1851 """
1852 num_trees, tree_size = moves.var_tree.shape 1ab
1853 tree_size *= 2 1ab
1854 tree_indices = jnp.arange(num_trees) 1ab
1856 prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size) 1ab
1858 # prec datapoints in nodes modified by move
1859 left = prec_trees[tree_indices, moves.left] 1ab
1860 right = prec_trees[tree_indices, moves.right] 1ab
1861 precs = Precs(left=left, right=right, total=left + right) 1ab
1863 # write prec into non-leaf node
1864 prec_trees = prec_trees.at[tree_indices, moves.node].set(precs.total) 1ab
1866 return prec_trees, precs 1ab
1869def prec_per_leaf( 1ab
1870 prec_scale: Float32[Array, ' n'],
1871 leaf_indices: UInt[Array, 'num_trees n'],
1872 tree_size: int,
1873 batch_size: int | None,
1874) -> Float32[Array, 'num_trees {tree_size}']:
1875 """
1876 Compute the likelihood precision scale in each leaf.
1878 Parameters
1879 ----------
1880 prec_scale
1881 The scale of the precision of the error on each datapoint.
1882 leaf_indices
1883 The index of the leaf each datapoint falls into.
1884 tree_size
1885 The size of the leaf tree array (2 ** d).
1886 batch_size
1887 The data batch size to use for the summation.
1889 Returns
1890 -------
1891 The likelihood precision scale in each leaf node.
1892 """
1893 if batch_size is None: 1893 ↛ 1896 (line 1893 didn't jump to line 1896 because the condition on line 1893 was always true) 1ab
1894 return _prec_scan(prec_scale, leaf_indices, tree_size) 1ab
1895 else:
1896 return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
1899def _prec_scan( 1ab
1900 prec_scale: Float32[Array, ' n'],
1901 leaf_indices: UInt[Array, 'num_trees n'],
1902 tree_size: int,
1903) -> Float32[Array, 'num_trees {tree_size}']:
1904 def loop(_, leaf_indices): 1ab
1905 return None, _aggregate_scatter( 1ab
1906 prec_scale, leaf_indices, tree_size, jnp.float32
1907 )
1909 _, prec_trees = lax.scan(loop, None, leaf_indices) 1ab
1910 return prec_trees 1ab
1913def _prec_vec( 1ab
1914 prec_scale: Float32[Array, ' n'],
1915 leaf_indices: UInt[Array, 'num_trees n'],
1916 tree_size: int,
1917 batch_size: int,
1918) -> Float32[Array, 'num_trees {tree_size}']:
1919 return _aggregate_batched_alltrees(
1920 prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
1921 )
1924def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves: 1ab
1925 """
1926 Complete non-likelihood MH ratio calculation.
1928 This function adds the probability of choosing a prune move over the grow
1929 move in the inverse transition, and the a priori probability that the
1930 children nodes are leaves.
1932 Parameters
1933 ----------
1934 moves
1935 The proposed moves. Must have already been updated to keep into account
1936 the thresholds on the number of datapoints per node, this happens in
1937 `accept_moves_parallel_stage`.
1938 p_nonterminal
1939 The a priori probability of each node being nonterminal conditional on
1940 its ancestors, including at the maximum depth where it should be zero.
1942 Returns
1943 -------
1944 The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
1945 """
1946 # can the leaves be grown?
1947 num_trees, _ = moves.affluence_tree.shape 1ab
1948 tree_indices = jnp.arange(num_trees) 1ab
1949 left_growable = moves.affluence_tree.at[tree_indices, moves.left].get( 1ab
1950 mode='fill', fill_value=False
1951 )
1952 right_growable = moves.affluence_tree.at[tree_indices, moves.right].get( 1ab
1953 mode='fill', fill_value=False
1954 )
1956 # p_prune if grow
1957 other_growable_leaves = moves.num_growable >= 2 1ab
1958 grow_again_allowed = other_growable_leaves | left_growable | right_growable 1ab
1959 grow_p_prune = jnp.where(grow_again_allowed, 0.5, 1) 1ab
1961 # p_prune if prune
1962 prune_p_prune = jnp.where(moves.num_growable, 0.5, 1) 1ab
1964 # select p_prune
1965 p_prune = jnp.where(moves.grow, grow_p_prune, prune_p_prune) 1ab
1967 # prior probability of both children being terminal
1968 pt_left = 1 - p_nonterminal[moves.left] * left_growable 1ab
1969 pt_right = 1 - p_nonterminal[moves.right] * right_growable 1ab
1970 pt_children = pt_left * pt_right 1ab
1972 return replace( 1ab
1973 moves,
1974 log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
1975 partial_ratio=None,
1976 )
1979@vmap_nodoc 1ab
1980def adapt_leaf_trees_to_grow_indices( 1ab
1981 leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
1982) -> Float32[Array, 'num_trees 2**d']:
1983 """
1984 Modify leaves such that post-grow indices work on the original tree.
1986 The value of the leaf to grow is copied to what would be its children if the
1987 grow move was accepted.
1989 Parameters
1990 ----------
1991 leaf_trees
1992 The leaf values.
1993 moves
1994 The proposed moves, see `propose_moves`.
1996 Returns
1997 -------
1998 The modified leaf values.
1999 """
2000 values_at_node = leaf_trees[moves.node] 1ab
2001 return ( 1ab
2002 leaf_trees.at[jnp.where(moves.grow, moves.left, leaf_trees.size)]
2003 .set(values_at_node)
2004 .at[jnp.where(moves.grow, moves.right, leaf_trees.size)]
2005 .set(values_at_node)
2006 )
2009def precompute_likelihood_terms( 1ab
2010 sigma2: Float32[Array, ''],
2011 sigma_mu2: Float32[Array, ''],
2012 move_precs: Precs | Counts,
2013) -> tuple[PreLkV, PreLk]:
2014 """
2015 Pre-compute terms used in the likelihood ratio of the acceptance step.
2017 Parameters
2018 ----------
2019 sigma2
2020 The error variance, or the global error variance factor if `prec_scale`
2021 is set.
2022 sigma_mu2
2023 The prior variance of each leaf.
2024 move_precs
2025 The likelihood precision scale in the leaves grown or pruned by the
2026 moves, under keys 'left', 'right', and 'total' (left + right).
2028 Returns
2029 -------
2030 prelkv : PreLkV
2031 Dictionary with pre-computed terms of the likelihood ratio, one per
2032 tree.
2033 prelk : PreLk
2034 Dictionary with pre-computed terms of the likelihood ratio, shared by
2035 all trees.
2036 """
2037 sigma2_left = sigma2 + move_precs.left * sigma_mu2 1ab
2038 sigma2_right = sigma2 + move_precs.right * sigma_mu2 1ab
2039 sigma2_total = sigma2 + move_precs.total * sigma_mu2 1ab
2040 prelkv = PreLkV( 1ab
2041 sigma2_left=sigma2_left,
2042 sigma2_right=sigma2_right,
2043 sigma2_total=sigma2_total,
2044 sqrt_term=jnp.log(sigma2 * sigma2_total / (sigma2_left * sigma2_right)) / 2,
2045 )
2046 return prelkv, PreLk(exp_factor=sigma_mu2 / (2 * sigma2)) 1ab
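# Illustrative check (not part of the module) of where the sigma2 + n *
# sigma_mu2 terms come from: integrating a leaf value mu ~ N(0, sigma_mu2)
# out of y_i = mu + eps_i, eps_i ~ N(0, sigma2), the sum of the n residuals
# in a leaf has marginal variance n * sigma2 + n**2 * sigma_mu2
# = n * (sigma2 + n * sigma_mu2). Made-up parameter values.
import jax.numpy as jnp
from jax import random

n, sigma2, sigma_mu2 = 5, 2.0, 0.5
key_mu, key_eps = random.split(random.key(0))
mu = jnp.sqrt(sigma_mu2) * random.normal(key_mu, (100_000,))
eps = jnp.sqrt(sigma2) * random.normal(key_eps, (100_000, n))
resid_sum = (mu[:, None] + eps).sum(axis=1)
print(resid_sum.var())  # ~ 22.5 == n * (sigma2 + n * sigma_mu2)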
2049@partial(jnp.vectorize, signature='(k,k)->(k,k)') 1ab
2050def _chol_with_gersh(mat: Float32[Array, '... k k']) -> Float32[Array, '... k k']: 1ab
2051 """Cholesky with Gershgorin stabilization, supports batching."""
2052 rho = jnp.max(jnp.sum(jnp.abs(mat), axis=1)) 1ab
2053 u = mat.shape[0] * rho * jnp.finfo(mat.dtype).eps 1ab
2054 mat = mat.at[jnp.diag_indices_from(mat)].add(u) 1ab
2055 return jnp.linalg.cholesky(mat) 1ab
2058def _logdet_from_chol(L): 1ab
2059 """Compute logdet of A = L'L via Cholesky (sum of log of diag^2)."""
2060 return 2.0 * jnp.sum(jnp.log(jnp.diag(L))) 1ab
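# Quick check (illustrative, not part of the module): the Cholesky logdet
# above agrees with jnp.linalg.slogdet on a made-up SPD matrix.
import jax.numpy as jnp

A_toy = jnp.array([[4.0, 1.0], [1.0, 3.0]])
L_toy = jnp.linalg.cholesky(A_toy)
assert jnp.allclose(
    2.0 * jnp.sum(jnp.log(jnp.diag(L_toy))), jnp.linalg.slogdet(A_toy)[1]
)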
2063def precompute_likelihood_terms_mv( 1ab
2064 error_cov_inv: Float32[Array, 'k k'],
2065 leaf_prior_cov_inv: Float32[Array, 'k k'],
2066 move_precs: Counts,
2067) -> tuple[PreLkV, PreLk]:
2068 """
2069 Pre-compute terms used in the likelihood ratio of the acceptance step.
2071 This implementation assumes a homoskedastic error model (i.e., the residual
2072 covariance is the same for all observations). Support for heteroskedasticity
2073 is planned for future updates.
2075 Parameters
2076 ----------
2077 error_cov_inv
2078 The inverse of the error covariance matrix.
2079 leaf_prior_cov_inv
2080 The inverse of the prior covariance matrix of each leaf.
2081 move_precs
2082 The likelihood precision scale in the leaves grown or pruned by the
2083 moves, under keys 'left', 'right', and 'total' (left + right).
2085 Returns
2086 -------
2087 prelkv : PreLkV
2088 Dictionary with pre-computed terms of the likelihood ratio, one per
2089 tree.
2090 prelk : PreLk
2091 Dictionary with pre-computed terms of the likelihood ratio, shared by
2092 all trees.
2093 """
2094 nL = move_precs.left[..., None, None] 1ab
2095 nR = move_precs.right[..., None, None] 1ab
2096 nT = move_precs.total[..., None, None] 1ab
2098 L_left = _chol_with_gersh(error_cov_inv * nL + leaf_prior_cov_inv) 1ab
2099 L_right = _chol_with_gersh(error_cov_inv * nR + leaf_prior_cov_inv) 1ab
2100 L_total = _chol_with_gersh(error_cov_inv * nT + leaf_prior_cov_inv) 1ab
2102 sqrt_term = 0.5 * ( 1ab
2103 _logdet_from_chol(_chol_with_gersh(leaf_prior_cov_inv))
2104 + _logdet_from_chol(L_total)
2105 - _logdet_from_chol(L_left)
2106 - _logdet_from_chol(L_right)
2107 )
2109 def _covariance_from_chol(L): 1ab
2110 Y = solve_triangular(L, error_cov_inv, lower=True) 1ab
2111 return Y.T @ Y 1ab
2113 prelkv = PreLkV( 1ab
2114 sigma2_left=_covariance_from_chol(L_left),
2115 sigma2_right=_covariance_from_chol(L_right),
2116 sigma2_total=_covariance_from_chol(L_total),
2117 sqrt_term=sqrt_term,
2118 )
2120 return prelkv, PreLk(exp_factor=0.5) 1ab
2123def precompute_leaf_terms( 1ab
2124 key: Key[Array, ''],
2125 prec_trees: Float32[Array, 'num_trees 2**d'],
2126 sigma2: Float32[Array, ''],
2127 sigma_mu2: Float32[Array, ''],
2128 z: Float32[Array, 'num_trees 2**d'] | None = None,
2129) -> PreLf:
2130 """
2131 Pre-compute terms used to sample leaves from their posterior.
2133 Parameters
2134 ----------
2135 key
2136 A jax random key.
2137 prec_trees
2138 The likelihood precision scale in each potential or actual leaf node.
2139 sigma2
2140 The error variance, or the global error variance factor if `prec_scale`
2141 is set.
2142 sigma_mu2
2143 The prior variance of each leaf.
2144 z
2145 Optional standard normal noise to use for sampling the centered leaves.
2146 This is intended for testing purposes only.
2148 Returns
2149 -------
2150 Pre-computed terms for leaf sampling.
2151 """
2152 prec_lk = prec_trees / sigma2 1ab
2153 prec_prior = lax.reciprocal(sigma_mu2) 1ab
2154 var_post = lax.reciprocal(prec_lk + prec_prior) 1ab
2155 if z is None: 1ab
2156 z = random.normal(key, prec_trees.shape, sigma2.dtype) 1ab
2157 return PreLf( 1ab
2158 mean_factor=var_post / sigma2,
2159 # | mean = mean_lk * prec_lk * var_post
2160 # | resid_tree = mean_lk * prec_tree -->
2161 # | --> mean_lk = resid_tree / prec_tree (kind of)
2162 # | mean_factor =
2163 # | = mean / resid_tree =
2164 # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
2165 # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
2166 # | = var_post / sigma2
2167 centered_leaves=z * jnp.sqrt(var_post),
2168 )
2171def precompute_leaf_terms_mv( 1ab
2172 key: Key[Array, ''],
2173 prec_trees: Float32[Array, 'num_trees 2**d'],
2174 error_cov_inv: Float32[Array, 'k k'],
2175 leaf_prior_cov_inv: Float32[Array, 'k k'],
2176 z: Float32[Array, 'num_trees 2**d'] | None = None,
2177) -> PreLf:
2178 """
2179 Pre-compute terms used to sample leaves from their posterior.
2181 Parameters
2182 ----------
2183 key
2184 A jax random key.
2185 prec_trees
2186 The likelihood precision scale in each potential or actual leaf node.
2187 error_cov_inv
2188 The inverse of the error covariance matrix, or of the global error
2189 covariance factor if `prec_scale` is set.
2190 leaf_prior_cov_inv
2191 The inverse of the prior covariance matrix of each leaf.
2192 z
2193 Optional standard normal noise to use for sampling the centered leaves.
2194 This is intended for testing purposes only.
2196 Returns
2197 -------
2198 Pre-computed terms for leaf sampling in multivariate case.
2199 """
2200 num_trees, num_leaves = prec_trees.shape 1ab
2201 k = error_cov_inv.shape[0] 1ab
2202 n_k = prec_trees[..., None, None] # Shape: [num_trees, num_leaves, 1, 1] 1ab
2204 # Only broadcast the inverse of error covariance matrix to satisfy JAX's batching rules
2205 # for `lax.linalg.solve_triangular`, which does not support implicit broadcasting.
2206 error_cov_inv_batched = jnp.broadcast_to( 1ab
2207 error_cov_inv, (num_trees, num_leaves, k, k)
2208 )
2210 posterior_precision = leaf_prior_cov_inv + n_k * error_cov_inv_batched 1ab
2212 L_prec = _chol_with_gersh(posterior_precision) 1ab
2213 Y = solve_triangular(L_prec, error_cov_inv_batched, lower=True) 1ab
2214 mean_factor = solve_triangular(L_prec, Y, trans='T', lower=True) 1ab
2216 if z is None: 1ab
2217 z = random.normal(key, (num_trees, num_leaves, k)) 1ab
2218 centered_leaves = solve_triangular(L_prec, z, trans='T', lower=True) 1ab
2220 return PreLf( 1ab
2221 mean_factor=mean_factor, # Shape: [num_trees, num_leaves, k, k]
2222 centered_leaves=centered_leaves, # Shape: [num_trees, num_leaves, k]
2223 )
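# Illustrative check (not part of the module): sampling from N(0, P^{-1})
# given the lower Cholesky factor L of the precision P = L L^T, by solving
# L^T x = z as done above for centered_leaves. Made-up precision matrix.
import jax.numpy as jnp
from jax import random
from jax.scipy.linalg import solve_triangular

P_toy = jnp.array([[2.0, 0.3], [0.3, 1.0]])
L_toy = jnp.linalg.cholesky(P_toy)
z_toy = random.normal(random.key(0), (2, 100_000))
x = solve_triangular(L_toy, z_toy, trans='T', lower=True)
print(jnp.cov(x))  # ~ inv(P_toy)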
2226def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]: 1ab
2227 """
2228 Accept/reject the moves one tree at a time.
2230 This is the most performance-sensitive function because it contains all and
2231 only the parts of the algorithm that cannot be parallelized across trees.
2233 Parameters
2234 ----------
2235 pso
2236 The output of `accept_moves_parallel_stage`.
2238 Returns
2239 -------
2240 bart : State
2241 A partially updated BART mcmc state.
2242 moves : Moves
2243 The accepted/rejected moves, with `acc` and `to_prune` set.
2244 """
2246 def loop(resid, pt): 1ab
2247 resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves( 1ab
2248 resid,
2249 SeqStageInAllTrees(
2250 pso.bart.X,
2251 pso.bart.forest.resid_batch_size,
2252 pso.bart.prec_scale,
2253 pso.bart.forest.log_likelihood is not None,
2254 pso.prelk,
2255 ),
2256 pt,
2257 )
2258 return resid, (leaf_tree, acc, to_prune, lkratio) 1ab
2260 pts = SeqStageInPerTree( 1ab
2261 pso.bart.forest.leaf_tree,
2262 pso.prec_trees,
2263 pso.moves,
2264 pso.move_precs,
2265 pso.bart.forest.leaf_indices,
2266 pso.prelkv,
2267 pso.prelf,
2268 )
2269 resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(loop, pso.bart.resid, pts) 1ab
2271 bart = replace( 1ab
2272 pso.bart,
2273 resid=resid,
2274 forest=replace(pso.bart.forest, leaf_tree=leaf_trees, log_likelihood=lkratio),
2275 )
2276 moves = replace(pso.moves, acc=acc, to_prune=to_prune) 1ab
2278 return bart, moves 1ab
2281class SeqStageInAllTrees(Module): 1ab
2282 """
2283 The inputs to `accept_move_and_sample_leaves` that are shared by all trees.
2285 Parameters
2286 ----------
2287 X
2288 The predictors.
2289 resid_batch_size
2290 The batch size for computing the sum of residuals in each leaf.
2291 prec_scale
2292 The scale of the precision of the error on each datapoint. If None, it
2293 is assumed to be 1.
2294 save_ratios
2295 Whether to save the acceptance ratios.
2296 prelk
2297 The pre-computed terms of the likelihood ratio which are shared across
2298 trees.
2299 """
2301 X: UInt[Array, 'p n'] 1ab
2302 resid_batch_size: int | None = field(static=True) 1ab
2303 prec_scale: Float32[Array, ' n'] | None 1ab
2304 save_ratios: bool = field(static=True) 1ab
2305 prelk: PreLk 1ab
2308class SeqStageInPerTree(Module): 1ab
2309 """
2310 The inputs to `accept_move_and_sample_leaves` that are separate for each tree.
2312 Parameters
2313 ----------
2314 leaf_tree
2315 The leaf values of the tree.
2316 prec_tree
2317 The likelihood precision scale in each potential or actual leaf node.
2318 move
2319 The proposed move, see `propose_moves`.
2320 move_precs
2321 The likelihood precision scale in each node modified by the moves.
2322 leaf_indices
2323 The leaf indices for the largest version of the tree compatible with
2324 the move.
2325 prelkv
2326 prelf
2327 The pre-computed terms of the likelihood ratio and leaf sampling which
2328 are specific to the tree.
2329 """
2331 leaf_tree: Float32[Array, ' 2**d'] 1ab
2332 prec_tree: Float32[Array, ' 2**d'] 1ab
2333 move: Moves 1ab
2334 move_precs: Precs | Counts 1ab
2335 leaf_indices: UInt[Array, ' n'] 1ab
2336 prelkv: PreLkV 1ab
2337 prelf: PreLf 1ab
2340def accept_move_and_sample_leaves( 1ab
2341 resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
2342) -> tuple[
2343 Float32[Array, ' n'],
2344 Float32[Array, ' 2**d'],
2345 Bool[Array, ''],
2346 Bool[Array, ''],
2347 Float32[Array, ''] | None,
2348]:
2349 """
2350 Accept or reject a proposed move and sample the new leaf values.
2352 Parameters
2353 ----------
2354 resid
2355 The residuals (data minus forest value).
2356 at
2357 The inputs that are the same for all trees.
2358 pt
2359 The inputs that are separate for each tree.
2361 Returns
2362 -------
2363 resid : Float32[Array, 'n']
2364 The updated residuals (data minus forest value).
2365 leaf_tree : Float32[Array, '2**d']
2366 The new leaf values of the tree.
2367 acc : Bool[Array, '']
2368 Whether the move was accepted.
2369 to_prune : Bool[Array, '']
2370 Whether, to reflect the acceptance status of the move, the state should
2371 be updated by pruning the leaves involved in the move.
2372 log_lk_ratio : Float32[Array, ''] | None
2373 The logarithm of the likelihood ratio for the move. `None` if not to be
2374 saved.
2375 """
2376 # sum residuals in each leaf, in tree proposed by grow move
2377 if at.prec_scale is None: 1ab
2378 scaled_resid = resid 1ab
2379 else:
2380 scaled_resid = resid * at.prec_scale 1ab
2381 resid_tree = sum_resid( 1ab
2382 scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
2383 )
2385 # subtract starting tree from function
2386 resid_tree += pt.prec_tree * pt.leaf_tree 1ab
2388 # sum residuals in parent node modified by move
2389 resid_left = resid_tree[pt.move.left] 1ab
2390 resid_right = resid_tree[pt.move.right] 1ab
2391 resid_total = resid_left + resid_right 1ab
2392 assert pt.move.node.dtype == jnp.int32 1ab
2393 resid_tree = resid_tree.at[pt.move.node].set(resid_total) 1ab
2395 # compute acceptance ratio
2396 log_lk_ratio = compute_likelihood_ratio( 1ab
2397 resid_total, resid_left, resid_right, pt.prelkv, at.prelk
2398 )
2399 log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio 1ab
2400 log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio) 1ab
2401 if not at.save_ratios: 1ab
2402 log_lk_ratio = None 1ab
2404 # determine whether to accept the move
2405 acc = pt.move.allowed & (pt.move.logu <= log_ratio) 1ab
2407 # compute leaves posterior and sample leaves
2408 mean_post = resid_tree * pt.prelf.mean_factor 1ab
2409 leaf_tree = mean_post + pt.prelf.centered_leaves 1ab
2411 # copy leaves around such that the leaf indices point to the correct leaf
2412 to_prune = acc ^ pt.move.grow 1ab
2413 leaf_tree = ( 1ab
2414 leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
2415 .set(leaf_tree[pt.move.node])
2416 .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
2417 .set(leaf_tree[pt.move.node])
2418 )
2420 # replace old tree with new tree in function values
2421 resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices] 1ab
2423 return resid, leaf_tree, acc, to_prune, log_lk_ratio 1ab
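# Illustrative note (not part of the module): the acceptance test above is a
# Metropolis-Hastings step in log space. Assuming, as the name suggests, that
# moves.logu is the log of a Uniform(0, 1) draw made at proposal time,
#   acc  <=>  U <= exp(log_trans_prior_ratio + log_lk_ratio)
# with the sign of the ratio flipped for prune moves.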
2426def sum_resid( 1ab
2427 scaled_resid: Float32[Array, ' n'],
2428 leaf_indices: UInt[Array, ' n'],
2429 tree_size: int,
2430 batch_size: int | None,
2431) -> Float32[Array, ' {tree_size}']:
2432 """
2433 Sum the residuals in each leaf.
2435 Parameters
2436 ----------
2437 scaled_resid
2438 The residuals (data minus forest value) multiplied by the error
2439 precision scale.
2440 leaf_indices
2441 The leaf indices of the tree (in which leaf each data point falls into).
2442 tree_size
2443 The size of the tree array (2 ** d).
2444 batch_size
2445 The data batch size for the aggregation. Batching increases numerical
2446 accuracy and parallelism.
2448 Returns
2449 -------
2450 The sum of the residuals at data points in each leaf.
2451 """
2452 if batch_size is None: 1ab
2453 aggr_func = _aggregate_scatter 1ab
2454 else:
2455 aggr_func = partial(_aggregate_batched_onetree, batch_size=batch_size) 1ab
2456 return aggr_func(scaled_resid, leaf_indices, tree_size, jnp.float32) 1ab
2459def _aggregate_batched_onetree( 1ab
2460 values: Shaped[Array, '*'],
2461 indices: Integer[Array, '*'],
2462 size: int,
2463 dtype: jnp.dtype,
2464 batch_size: int,
2465) -> Float32[Array, ' {size}']:
2466 (n,) = indices.shape 1ab
2467 nbatches = n // batch_size + bool(n % batch_size) 1ab
2468 batch_indices = jnp.arange(n) % nbatches 1ab
2469 return ( 1ab
2470 jnp.zeros((size, nbatches), dtype)
2471 .at[indices, batch_indices]
2472 .add(values)
2473 .sum(axis=1)
2474 )
2477def compute_likelihood_ratio( 1ab
2478 total_resid: Float32[Array, ''],
2479 left_resid: Float32[Array, ''],
2480 right_resid: Float32[Array, ''],
2481 prelkv: PreLkV,
2482 prelk: PreLk,
2483) -> Float32[Array, '']:
2484 """
2485 Compute the likelihood ratio of a grow move.
2487 Parameters
2488 ----------
2489 total_resid
2490 left_resid
2491 right_resid
2492 The sum of the residuals (scaled by error precision scale) of the
2493 datapoints falling in the nodes involved in the moves.
2494 prelkv
2495 prelk
2496 The pre-computed terms of the likelihood ratio, see
2497 `precompute_likelihood_terms`.
2499 Returns
2500 -------
2501 The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
2502 """
2503 exp_term = prelk.exp_factor * ( 1ab
2504 left_resid * left_resid / prelkv.sigma2_left
2505 + right_resid * right_resid / prelkv.sigma2_right
2506 - total_resid * total_resid / prelkv.sigma2_total
2507 )
2508 return prelkv.sqrt_term + exp_term 1ab
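# Illustrative note (not part of the module): combining the pre-computed
# pieces, the grow-move ratio above is
#   log LR = (1/2) * log(sigma2 * sigma2_total / (sigma2_left * sigma2_right))
#          + sigma_mu2 / (2 * sigma2)
#            * (r_L**2 / sigma2_left + r_R**2 / sigma2_right
#               - r_T**2 / sigma2_total)
# with sigma2_* = sigma2 + prec_* * sigma_mu2 as in
# precompute_likelihood_terms.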
2511def compute_likelihood_ratio_mv( 1ab
2512 total_resid: Float32[Array, ' k'],
2513 left_resid: Float32[Array, ' k'],
2514 right_resid: Float32[Array, ' k'],
2515 prelkv: PreLkV,
2516 prelk: PreLk, # noqa: ARG001
2517) -> Float32[Array, '']:
2518 """
2519 Compute the likelihood ratio of a grow move, for multivariate case.
2521 Parameters
2522 ----------
2523 total_resid
2524 left_resid
2525 right_resid
2526 The sum of the residuals (scaled by error precision scale) of the
2527 datapoints falling in the nodes involved in the moves.
2528 prelkv
2529 prelk
2530 The pre-computed terms of the likelihood ratio, see
2531 `precompute_likelihood_terms_mv`.
2533 Returns
2534 -------
2535 The log-likelihood ratio log P(data | new tree) - log P(data | old tree).
2536 """
2538 def _quadratic_form(r, cov): 1ab
2539 return r @ cov @ r 1ab
2541 qf_left = _quadratic_form(left_resid, prelkv.sigma2_left) 1ab
2542 qf_right = _quadratic_form(right_resid, prelkv.sigma2_right) 1ab
2543 qf_total = _quadratic_form(total_resid, prelkv.sigma2_total) 1ab
2544 exp_term = 0.5 * (qf_left + qf_right - qf_total) 1ab
2545 return prelkv.sqrt_term + exp_term 1ab
2548def accept_moves_final_stage(bart: State, moves: Moves) -> State: 1ab
2549 """
2550 Post-process the mcmc state after accepting/rejecting the moves.
2552 This function is separate from `accept_moves_sequential_stage` to signal it
2553 can work in parallel across trees.
2555 Parameters
2556 ----------
2557 bart
2558 A partially updated BART mcmc state.
2559 moves
2560 The proposed moves (see `propose_moves`) as updated by
2561 `accept_moves_sequential_stage`.
2563 Returns
2564 -------
2565 The fully updated BART mcmc state.
2566 """
2567 return replace( 1ab
2568 bart,
2569 forest=replace(
2570 bart.forest,
2571 grow_acc_count=jnp.sum(moves.acc & moves.grow),
2572 prune_acc_count=jnp.sum(moves.acc & ~moves.grow),
2573 leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
2574 split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
2575 ),
2576 )
2579@vmap_nodoc 1ab
2580def apply_moves_to_leaf_indices( 1ab
2581 leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
2582) -> UInt[Array, 'num_trees n']:
2583 """
2584 Update the leaf indices to match the accepted move.
2586 Parameters
2587 ----------
2588 leaf_indices
2589 The index of the leaf each datapoint falls into, if the grow move was
2590 accepted.
2591 moves
2592 The proposed moves (see `propose_moves`), as updated by
2593 `accept_moves_sequential_stage`.
2595 Returns
2596 -------
2597 The updated leaf indices.
2598 """
2599 mask = ~jnp.array(1, leaf_indices.dtype) # ...1111111110 1ab
2600 is_child = (leaf_indices & mask) == moves.left 1ab
2601 return jnp.where( 1ab
2602 is_child & moves.to_prune, moves.node.astype(leaf_indices.dtype), leaf_indices
2603 )
2606@vmap_nodoc 1ab
2607def apply_moves_to_split_trees( 1ab
2608 split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
2609) -> UInt[Array, 'num_trees 2**(d-1)']:
2610 """
2611 Update the split trees to match the accepted move.
2613 Parameters
2614 ----------
2615 split_tree
2616 The cutpoints of the decision nodes in the initial trees.
2617 moves
2618 The proposed moves (see `propose_moves`), as updated by
2619 `accept_moves_sequential_stage`.
2621 Returns
2622 -------
2623 The updated split trees.
2624 """
2625 assert moves.to_prune is not None 1ab
2626 return ( 1ab
2627 split_tree.at[jnp.where(moves.grow, moves.node, split_tree.size)]
2628 .set(moves.grow_split.astype(split_tree.dtype))
2629 .at[jnp.where(moves.to_prune, moves.node, split_tree.size)]
2630 .set(0)
2631 )
2634def step_sigma(key: Key[Array, ''], bart: State) -> State: 1ab
2635 """
2636 MCMC-update the error variance (factor).
2638 Parameters
2639 ----------
2640 key
2641 A jax random key.
2642 bart
2643 A BART mcmc state.
2645 Returns
2646 -------
2647 The new BART mcmc state, with an updated `sigma2`.
2648 """
2649 resid = bart.resid 1ab
2650 alpha = bart.sigma2_alpha + resid.size / 2 1ab
2651 if bart.prec_scale is None: 1ab
2652 scaled_resid = resid 1ab
2653 else:
2654 scaled_resid = resid * bart.prec_scale 1ab
2655 norm2 = resid @ scaled_resid 1ab
2656 beta = bart.sigma2_beta + norm2 / 2 1ab
2658 sample = random.gamma(key, alpha) 1ab
2659 # random.gamma seems to be slow at compiling, maybe cdf inversion would
2660 # be better, but it's not implemented in jax
2661 return replace(bart, sigma2=beta / sample) 1ab
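# Illustrative check (not part of the module): sigma2 is drawn as
# beta / Gamma(alpha), i.e., an inverse gamma with mean beta / (alpha - 1).
# Made-up parameter values.
from jax import random

alpha_toy, beta_toy = 3.0, 2.0
samples = beta_toy / random.gamma(random.key(0), alpha_toy, (100_000,))
print(samples.mean())  # ~ beta_toy / (alpha_toy - 1) == 1.0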
2664@jax.jit 1ab
2665def _sample_wishart_bartlett( 1ab
2666 key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k']
2667) -> Float32[Array, 'k k']:
2668 """
2669 Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.
2671 Parameters
2672 ----------
2673 key
2674 A JAX random key.
2675 df
2676 Degrees of freedom.
2677 scale_inv
2678 Scale matrix of the corresponding Inverse Wishart distribution.
2680 Returns
2681 -------
2682 A sample from Wishart(df, scale_inv^-1).
2683 """
2684 keys = split(key) 1ab
2686 k = scale_inv.shape[0] 1ab
2688 # Gershgorin estimate for max eigenvalue
2689 rho = jnp.max(jnp.sum(jnp.abs(scale_inv), axis=1)) 1ab
2690 u = k * rho * jnp.finfo(scale_inv.dtype).eps + jnp.finfo(scale_inv.dtype).eps 1ab
2692 # Stabilize the matrix
2693 scale_inv = scale_inv.at[jnp.diag_indices(k)].add(u) 1ab
2694 L = jnp.linalg.cholesky(scale_inv) 1ab
2696 # Diagonal elements: A_ii ~ sqrt(chi^2(df - i))
2697 # chi^2(k) = Gamma(k/2, scale=2)
2698 df_vector = df - jnp.arange(k) 1ab
2699 chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0 1ab
2700 diag_A = jnp.sqrt(chi2_samples) 1ab
2702 off_diag_A = random.normal(keys.pop(), (k, k)) 1ab
2703 A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A) 1ab
2704 T = solve_triangular(L, A, lower=True, trans='T') 1ab
2706 return T @ T.T 1ab
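# Illustrative check (not part of the module): the mean of
# Wishart(df, scale_inv^-1) is df * scale_inv^-1. Made-up values.
import jax.numpy as jnp
from jax import random, vmap

scale_inv_toy = jnp.array([[2.0, 0.5], [0.5, 1.0]])
df_toy = jnp.array(5)
keys = random.split(random.key(0), 10_000)
samples = vmap(lambda k: _sample_wishart_bartlett(k, df_toy, scale_inv_toy))(keys)
print(samples.mean(axis=0))  # ~ df_toy * jnp.linalg.inv(scale_inv_toy)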
2709def step_z(key: Key[Array, ''], bart: State) -> State: 1ab
2710 """
2711 MCMC-update the latent variable for binary regression.
2713 Parameters
2714 ----------
2715 key
2716 A jax random key.
2717 bart
2718 A BART MCMC state.
2720 Returns
2721 -------
2722 The updated BART MCMC state.
2723 """
2724 trees_plus_offset = bart.z - bart.resid 1ab
2725 assert bart.y.dtype == bool 1ab
2726 resid = truncated_normal_onesided(key, (), ~bart.y, -trees_plus_offset) 1ab
2727 z = trees_plus_offset + resid 1ab
2728 return replace(bart, z=z, resid=resid) 1ab
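# Illustrative note (not part of the module): this is Albert-Chib style data
# augmentation for probit regression. With forest value f = z - resid, the
# update above is equivalent to drawing
#   z | y = 1  ~  N(f, 1) truncated to (0, +inf)
#   z | y = 0  ~  N(f, 1) truncated to (-inf, 0)
# and then recomputing resid = z - f.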
2731def step_s(key: Key[Array, ''], bart: State) -> State: 1ab
2732 """
2733 Update `log_s` using Dirichlet sampling.
2735 The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
2736 is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
2737 varcount is the count of how many times each variable is used in the
2738 current forest.
2740 Parameters
2741 ----------
2742 key
2743 Random key for sampling.
2744 bart
2745 The current BART state.
2747 Returns
2748 -------
2749 Updated BART state with re-sampled `log_s`.
2751 Notes
2752 -----
2753 This full conditional is approximated, because it does not take into account
2754 that there are forbidden decision rules.
2755 """
2756 assert bart.forest.theta is not None 1ab
2758 # histogram current variable usage
2759 p = bart.forest.max_split.size 1ab
2760 varcount = grove.var_histogram(p, bart.forest.var_tree, bart.forest.split_tree) 1ab
2762 # sample from Dirichlet posterior
2763 alpha = bart.forest.theta / p + varcount 1ab
2764 log_s = random.loggamma(key, alpha) 1ab
2766 # update forest with new s
2767 return replace(bart, forest=replace(bart.forest, log_s=log_s)) 1ab
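# Illustrative sketch (not part of the module): normalized Gamma(alpha_i)
# draws are Dirichlet(alpha); random.loggamma keeps the draw in log space,
# so the normalization is a logsumexp subtraction. Made-up concentrations.
import jax.numpy as jnp
from jax import random
from jax.scipy.special import logsumexp

alpha_toy = jnp.array([0.5, 1.0, 2.0])
log_g = random.loggamma(random.key(0), alpha_toy)
log_s_draw = log_g - logsumexp(log_g)  # log of a Dirichlet(alpha_toy) draw
print(jnp.exp(log_s_draw).sum())  # == 1.0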
2770def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State: 1ab
2771 """
2772 Update `theta`.
2774 The prior is theta / (theta + rho) ~ Beta(a, b).
2776 Parameters
2777 ----------
2778 key
2779 Random key for sampling.
2780 bart
2781 The current BART state.
2782 num_grid
2783 The number of points in the evenly-spaced grid used to sample
2784 theta / (theta + rho).
2786 Returns
2787 -------
2788 Updated BART state with re-sampled `theta`.
2789 """
2790 assert bart.forest.log_s is not None 1ab
2791 assert bart.forest.rho is not None 1ab
2792 assert bart.forest.a is not None 1ab
2793 assert bart.forest.b is not None 1ab
2795 # the grid points are the midpoints of num_grid bins in (0, 1)
2796 padding = 1 / (2 * num_grid) 1ab
2797 lamda_grid = jnp.linspace(padding, 1 - padding, num_grid) 1ab
2799 # normalize s
2800 log_s = bart.forest.log_s - logsumexp(bart.forest.log_s) 1ab
2802 # sample lambda
2803 logp, theta_grid = _log_p_lamda( 1ab
2804 lamda_grid, log_s, bart.forest.rho, bart.forest.a, bart.forest.b
2805 )
2806 i = random.categorical(key, logp) 1ab
2807 theta = theta_grid[i] 1ab
2809 return replace(bart, forest=replace(bart.forest, theta=theta)) 1ab
2812def _log_p_lamda( 1ab
2813 lamda: Float32[Array, ' num_grid'],
2814 log_s: Float32[Array, ' p'],
2815 rho: Float32[Array, ''],
2816 a: Float32[Array, ''],
2817 b: Float32[Array, ''],
2818) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
2819 # in the following I use lamda[::-1] == 1 - lamda
2820 theta = rho * lamda / lamda[::-1] 1ab
2821 p = log_s.size 1ab
2822 return ( 1ab
2823 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
2824 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
2825 + gammaln(theta)
2826 - p * gammaln(theta / p)
2827 + theta / p * jnp.sum(log_s)
2828 ), theta
2831def step_sparse(key: Key[Array, ''], bart: State) -> State: 1ab
2832 """
2833 Update the sparsity parameters.
2835 This invokes `step_s`, and then `step_theta` only if the parameters of
2836 the theta prior are defined.
2838 Parameters
2839 ----------
2840 key
2841 Random key for sampling.
2842 bart
2843 The current BART state.
2845 Returns
2846 -------
2847 Updated BART state with re-sampled `log_s` and `theta`.
2848 """
2849 keys = split(key) 1ab
2850 bart = step_s(keys.pop(), bart) 1ab
2851 if bart.forest.rho is not None: 1ab
2852 bart = step_theta(keys.pop(), bart) 1ab
2853 return bart 1ab