Coverage for src/bartz/mcmcstep.py: 95%
551 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-07 22:47 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-07 22:47 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking as input a bart state, and
29outputting a new state. The inputs are not modified.
31The entry points are:
33 - `State`: The dataclass that represents a BART MCMC state.
34 - `init`: Creates an initial `State` from data and configurations.
35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
37"""
39import math 1ab
40from dataclasses import replace 1ab
41from functools import cache, partial 1ab
42from typing import Any, Literal 1ab
44import jax 1ab
45from equinox import Module, field, tree_at 1ab
46from jax import lax, random 1ab
47from jax import numpy as jnp 1ab
48from jax.scipy.special import gammaln, logsumexp 1ab
49from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab
51from bartz import grove 1ab
52from bartz.jaxext import ( 1ab
53 minimal_unsigned_dtype,
54 split,
55 truncated_normal_onesided,
56 vmap_nodoc,
57)
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined. `None` if no variables are blocked.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node. `None` if the
        constraint was not specified in `init`.
    min_points_per_leaf
        The minimum number of data points in a leaf node. `None` if the
        constraint was not specified in `init`.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree. `None` unless `init` was called with
        ``save_ratios=True``.
    log_likelihood
        The log likelihood ratio. `None` unless `init` was called with
        ``save_ratios=True``.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    # static fields: plain Python values stored as pytree metadata, not as
    # array leaves
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors, with shape ``(p, n)`` (transposed compared to the usual
        convention).
    y
        The response. If the data type is `bool`, the model is binary regression.
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals (`y` or `z` minus sum of trees).
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` in binary regression, or if no `error_scale` was given.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual
        convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Required for continuous regression; leave unspecified for
        binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified but
        `theta` (or `rho`, `a`, `b`) is, it's initialized to zeros.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized to `rho`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set,
        if `y` is not boolean and `sigma2_alpha` or `sigma2_beta` is missing,
        or if `rho`, `a`, `b` are not either all set or all unset.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # append a 0 for the nodes at maximum depth, which can't be nonterminal
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    # convert X up front: `X.shape` is read below to size `var_tree`, which
    # would fail on array-likes such as nested lists
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # use identity checks: comparing a tuple that contains arrays with
        # `==`/`!=` triggers an elementwise comparison with ambiguous truth
        if any(arg is not None for arg in (error_scale, sigma2_alpha, sigma2_beta)):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        if sigma2_alpha is None or sigma2_beta is None:
            msg = (
                'sigma2_alpha and sigma2_beta must be set for continuous '
                'regression.'
            )
            raise ValueError(msg)
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        # starting value of the error variance
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        # data-dependent output size: this is why the function is not
        # jax-traceable when `filter_splitless_vars` is set
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)

    return State(
        X=X,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )
def _all_none_or_not_none(*args):
    """Return whether the arguments are all `None` or all not `None`.

    With no arguments, return `True`.
    """
    none_count = sum(arg is None for arg in args)
    return none_count in (0, len(args))
def _asarray_or_none(x):
    """Convert `x` with `jnp.asarray`, passing `None` through unchanged."""
    return None if x is None else jnp.asarray(x)
413def _choose_suffstat_batch_size( 1ab
414 resid_batch_size, count_batch_size, y, forest_size
415) -> tuple[int | None, ...]:
416 @cache 1ab
417 def get_platform(): 1ab
418 try: 1ab
419 device = y.devices().pop() 1ab
420 except jax.errors.ConcretizationTypeError: 1ab
421 device = jax.devices()[0] 1ab
422 platform = device.platform 1ab
423 if platform not in ('cpu', 'gpu'): 423 ↛ 424line 423 didn't jump to line 424 because the condition on line 423 was never true1ab
424 msg = f'Unknown platform: {platform}'
425 raise KeyError(msg)
426 return platform 1ab
428 if resid_batch_size == 'auto': 1ab
429 platform = get_platform() 1ab
430 n = max(1, y.size) 1ab
431 if platform == 'cpu': 431 ↛ 433line 431 didn't jump to line 433 because the condition on line 431 was always true1ab
432 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab
433 elif platform == 'gpu':
434 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
435 resid_batch_size = max(1, resid_batch_size) 1ab
437 if count_batch_size == 'auto': 1ab
438 platform = get_platform() 1ab
439 if platform == 'cpu': 439 ↛ 441line 439 didn't jump to line 441 because the condition on line 439 was always true1ab
440 count_batch_size = None 1ab
441 elif platform == 'gpu':
442 n = max(1, y.size)
443 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
444 # /4 is good on V100, /2 on L4/T4, still haven't tried A100
445 max_memory = 2**29
446 itemsize = 4
447 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
448 count_batch_size = max(count_batch_size, min_batch_size)
449 count_batch_size = max(1, count_batch_size)
451 return resid_batch_size, count_batch_size 1ab
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Advance the BART MCMC chain by one full step.

    The trees are always updated first. Then, for binary regression (boolean
    `y`), the latent variable is resampled with `step_z`. For continuous
    regression, the error variance is resampled with `step_sigma`.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.
    """
    keys = split(key)
    binary = bart.y.dtype == bool

    if binary:
        # temporarily give the tree update a unit error variance (probit
        # latent model); it is reset to `None` right after
        bart = replace(bart, sigma2=jnp.float32(1))
    bart = step_trees(keys.pop(), bart)

    if not binary:
        return step_sigma(keys.pop(), bart)

    bart = replace(bart, sigma2=None)
    return step_z(keys.pop(), bart)
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Update the forest: propose one move per tree, then accept/reject and
    resample the leaves.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The new BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    key_stream = split(key)
    proposed = propose_moves(key_stream.pop(), bart.forest)
    updated = accept_moves_and_sample_leaves(key_stream.pop(), bart, proposed)
    return updated
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If the
        move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. As output by `propose_moves`,
        the mark only reflects whether the leaf would have available decision
        rules; whether it would also have enough datapoints is added later, in
        `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-inf, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Draw one candidate move, GROW or PRUNE, for every tree of the forest.

    A GROW move turns a leaf into a decision node with two new leaf children.
    A PRUNE move turns the parent of two leaves back into a leaf, discarding
    its children.

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree. The fields `log_trans_prior_ratio`,
    `acc` and `to_prune` are left as `None` for later stages to fill in.
    """
    num_trees, _ = forest.leaf_tree.shape
    # one key per tree for each move type, plus one for the uniforms; the pop
    # order below determines which key is used where
    keys = split(key, 1 + 2 * num_trees)

    grow_moves = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    # the prune proposal starts from the affluence marks already updated by
    # the grow proposal, so its output carries both updates
    prune_moves = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        grow_moves.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )

    uniforms = random.uniform(keys.pop(), (2, num_trees))
    u = uniforms[0]
    exp1mlogu = uniforms[1]

    # GROW with probability 1/2 when both moves are possible, otherwise the
    # allowed one (with neither allowed, p_grow is 0 and grow is False)
    both_allowed = grow_moves.allowed & prune_moves.allowed
    p_grow = jnp.where(both_allowed, 0.5, grow_moves.allowed)
    grow = u < p_grow  # strict `<` because u is in [0, 1)

    node = jnp.where(grow, grow_moves.node, prune_moves.node)
    left = node << 1
    right = left + 1

    return Moves(
        allowed=grow_moves.allowed | prune_moves.allowed,
        grow=grow,
        num_growable=grow_moves.num_growable,
        node=node,
        left=left,
        right=right,
        partial_ratio=jnp.where(
            grow, grow_moves.partial_ratio, prune_moves.partial_ratio
        ),
        log_trans_prior_ratio=None,  # set later, in complete_ratio
        grow_var=grow_moves.var,
        grow_split=grow_moves.split,
        # a prune does not touch var_tree, so the grow version is valid for both
        var_tree=grow_moves.var_tree,
        # affluence_tree holds the updates of both proposals (prune applied last)
        affluence_tree=prune_moves.affluence_tree,
        # 1 - U with U in [0, 1) lies in (0, 1], so logu is in (-inf, 0]
        logu=jnp.log1p(-exp1mlogu),
        acc=None,  # set later, in accept_moves_sequential_stage
        to_prune=None,  # set later, in accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """
    Represent a proposed grow move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is allowed for proposal, i.e., ``num_growable > 0``.
    num_growable
        The number of leaves that can be proposed for grow.
    node
        The index of the leaf to grow. ``2 ** d`` if there are no growable
        leaves.
    var
    split
        The decision axis and boundary of the new rule.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks
        the likelihood ratio and the probability of proposing the prune
        move.
    var_tree
        The updated decision axes of the tree, with `var` written at `node`.
    affluence_tree
        A partially updated `affluence_tree` that marks each new leaf that
        would be produced as `True` if it would have available decision rules.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: UInt[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints, or
        `None` if no variable is blocked. Unsigned, as built by `init`.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if every leaf is already at maximum depth, has
    fewer datapoints than the requested threshold
    `min_points_per_decision_node`, or does not have any available decision
    rules given its ancestors. This is marked by setting `allowed` to `False`
    and `num_growable` to 0.
    """
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, split_lo, split_hi = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules: either
    # another variable is available, or `var` keeps some cutpoint on that
    # side of `split_idx` (NOTE(review): assumes `split_lo`/`split_hi` delimit
    # the cutpoint range used by `choose_split`); if the move is blocked,
    # these values may not make sense
    another_var_available = num_available_var > 1
    left_growable = another_var_available | (split_lo < split_idx)
    right_growable = another_var_available | (split_idx + 1 < split_hi)
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., those marked by
        `growable_leaves`: actual leaves of the tree whose `affluence_tree`
        entry is `True`.
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)
    # restrict the proposal distribution to the growable leaves
    distr = jnp.where(is_growable, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)
    # with nothing to grow, use 2 ** d (= 2 * split_tree.size) as sentinel
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    # avoid 0 / 0 when no leaf is growable
    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
    # count prunable nodes in the tree as it would be after the grow, marking
    # the chosen leaf as a decision node
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)
    return leaf_to_grow, num_growable, prob_choose, num_prunable
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Mask the leaves that are eligible for a grow proposal.

    A leaf is eligible when it is not at the bottom level, has available
    decision rules given its ancestors, and has at least
    `min_points_per_decision_node` points. The last two conditions are
    encoded in `affluence_tree`.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    The mask of the leaf nodes that can be proposed to grow.

    Notes
    -----
    `affluence_tree` alone is not enough because it may be "dirty", i.e., it
    may mark unused nodes as `True`. For this reason it is intersected with
    the actual leaves derived from `split_tree`.
    """
    actual_leaves = grove.is_actual_leaf(split_tree)
    return actual_leaves & affluence_tree
861def categorical( 1ab
862 key: Key[Array, ''], distr: Float32[Array, ' n']
863) -> tuple[Int32[Array, ''], Float32[Array, '']]:
864 """
865 Return a random integer from an arbitrary distribution.
867 Parameters
868 ----------
869 key
870 A jax random key.
871 distr
872 An unnormalized probability distribution.
874 Returns
875 -------
876 u : Int32[Array, '']
877 A random integer in the range ``[0, n)``. If all probabilities are zero,
878 return ``n``.
879 norm : Float32[Array, '']
880 The sum of `distr`.
882 Notes
883 -----
884 This function uses a cumsum instead of the Gumbel trick, so it's ok only
885 for small ranges with probabilities well greater than 0.
886 """
887 ecdf = jnp.cumsum(distr) 1ab
888 u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1]) 1ab
889 return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1] 1ab
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: UInt[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints, as an
        unsigned array built by `init`. If `None`, all variables are assumed
        unblocked.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables whose split range is exhausted along the path to the leaf
    var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        # also exclude variables without any cutpoint; duplicates are allowed
        var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])

    if log_s is None:
        return randint_exclude(key, max_split.size, var_to_ignore)
    else:
        return categorical_exclude(key, log_s, var_to_ignore)
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    List the ancestor variables of a node whose split range is empty there.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables with an empty split range at the node.

    Notes
    -----
    How many variables are exhausted is not known in advance: the unused
    entries of the output are set to `p`. The entries are in no particular
    order and a variable may appear multiple times.
    """
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    ranges_for = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    lo, hi = ranges_for(var_tree, split_tree, max_split, leaf_index, candidates)
    exhausted = (hi - lo) == 0
    # the dtype of `candidates` can already hold `max_split.size`, see
    # `ancestor_variables`, so the fill value does not overflow
    return jnp.where(exhausted, candidates, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Collect the split variables of all the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Only its size `p` is used.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    Ancestors go from the root down to the parent of the node, and are stored
    at the end of the output array. Since the number of ancestors is unknown
    at tracing time, the leading unused entries are set to `p`.
    """
    p = max_split.size
    out = jnp.zeros(grove.tree_depth(var_tree) - 1, minimal_unsigned_dtype(p))

    def climb(state, _):
        slot, node, out = state
        parent = node >> 1
        # index 0 means we walked past the root: mark the slot as unused
        parent_var = jnp.where(parent, var_tree[parent], p)
        out = out.at[slot].set(parent_var)
        return (slot - 1, parent, out), None

    # fill the output from the last slot backwards while climbing the tree
    initial = (out.size - 1, node_index, out)
    (_, _, out), _ = lax.scan(climb, initial, None, out.size)
    return out
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute which split indices on a variable are still usable at a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable whose split range is measured.

    Returns
    -------
    The range of allowed splits as [l, r). If `ref_var` is out of bounds, l=r=1.
    """
    num_steps = grove.tree_depth(var_tree) - 1
    # an out-of-bounds `ref_var` reads a max_split of 0, yielding l=r=1
    upper = max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32) + 1
    lower = jnp.int32(0)

    def climb(state, _):
        lower, upper, node = state
        is_right = (node & 1).astype(bool)
        parent = node >> 1
        cut = split_tree[parent]
        # only real ancestors (parent != 0) splitting on `ref_var` restrict
        # the range: from the right the cut is a lower bound, from the left
        # an upper bound
        relevant = (var_tree[parent] == ref_var) & parent.astype(bool)
        lower = jnp.where(relevant & is_right, jnp.maximum(lower, cut), lower)
        upper = jnp.where(relevant & ~is_right, jnp.minimum(upper, cut), upper)
        return (lower, upper, parent), None

    (lower, upper, _), _ = lax.scan(
        climb, (lower, upper, node_index), None, num_steps
    )
    return lower + 1, upper
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    exclude, num_allowed = _process_exclude(sup, exclude)
    # draw the rank of the result among the allowed integers
    u = random.randint(key, (), 0, num_allowed)

    def loop(u, i_excluded):
        # `exclude` holds the sorted unique excluded values, followed by the
        # padding `sup`. Stepping over every excluded value <= u maps the rank
        # to the matching allowed integer. Values >= sup (padding, or inputs
        # out of range) must not shift u: otherwise, when the whole range is
        # excluded, u would be pushed past `sup` instead of stopping at it.
        shift = (i_excluded <= u) & (i_excluded < sup)
        return jnp.where(shift, u + 1, u), None

    u, _ = lax.scan(loop, u, exclude)
    return u, num_allowed
1108def _process_exclude(sup, exclude): 1ab
1109 exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup) 1ab
1110 num_allowed = sup - jnp.count_nonzero(exclude < sup) 1ab
1111 return exclude, num_allowed 1ab
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Draw from a categorical distribution, excluding a set of values.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        The values to exclude from the range [0, k). Values greater than or
        equal to `logits.size` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, k)`` such that ``u not in exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, the result is unspecified.
    """
    exclude, num_allowed = _process_exclude(logits.size, exclude)
    # a huge finite negative logit makes excluded categories practically
    # impossible to draw
    kinda_neg_inf = jnp.finfo(logits.dtype).min
    # `exclude` contains padding equal to `logits.size` (and possibly other
    # out-of-range values) that must be ignored: request dropping of
    # out-of-bounds indices explicitly instead of relying on the default
    logits = logits.at[exclude].set(kinda_neg_inf, mode='drop')
    u = random.categorical(key, logits)
    return u, num_allowed
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Draw the cutpoint of a newly grown non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable to split on.
    var_tree
        The splitting axes of the tree. It need not already contain `var` at
        `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf being grown.

    Returns
    -------
    split : Int32[Array, '']
        The cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        The integer range [l, r) `split` was drawn from.

    Notes
    -----
    The cutpoint is 0 if `var` is out of bounds or its available split range
    at the leaf is empty.
    """
    lo, hi = split_range(var_tree, split_tree, max_split, leaf_index, var)
    candidate = random.randint(key, (), lo, hi)
    split = jnp.where(lo < hi, candidate, 0)
    return split, lo, hi
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Multiply the partial transition ratio and the partial prior ratio of a grow.

    Parameters
    ----------
    prob_choose
        The probability that the leaf had to be chosen amongst the growable
        leaves.
    num_prunable
        The number of leaf parents that could be pruned once the leaf to grow
        has become a non-terminal node.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio times the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree);
    the "partial" one lacks the factor P(propose prune) in the numerator. The
    prior ratio is P(new tree) / P(old tree); the "partial" one lacks the
    factor P(children are leaves).
    """
    # The factors num_available_split * num_available_var * s[var] appear in
    # both ratios and cancel out. p_prune and 1 - p_nonterminal[child] *
    # I(child is growable) need the count trees, which only exist in the
    # acceptance phase, so they are added later.

    # Growing the root leaf <=> the initial tree is a bare root <=> pruning
    # was not an option, so grow was proposed with probability 1.
    growing_root = leaf_to_grow == 1
    p_grow = jnp.where(growing_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    # read with a fill value: with an out-of-bounds `leaf_to_grow` (move not
    # allowed) a plain read would give 0, and later an inf once
    # `complete_ratio` takes the log
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    tree_ratio = pnt / (1 - pnt)

    safe_denominator = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)
    return tree_ratio / safe_denominator
class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_moves_and_sample_leaves`.
    affluence_tree
        The partially updated affluence tree returned by `choose_leaf_parent`,
        where the node to prune is marked as growable.
    """

    # all fields have one entry (or one row) per tree
    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on a tree, as in the BART MCMC structure update.

    Vectorized over trees on `key`, `split_tree` and `affluence_tree`.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf can be grown.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on
        the ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    An object representing the proposed moves.
    """
    node, num_prunable, prob_choose, updated_affluence = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )
    # a tree can be pruned iff it is not a bare root, i.e., the root is split
    is_not_root = split_tree[1].astype(bool)
    return PruneMoves(
        allowed=is_not_root,
        node=node,
        partial_ratio=compute_partial_ratio(
            prob_choose, num_prunable, p_nonterminal, node
        ),
        affluence_tree=updated_affluence,
    )
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable. Unchanged if ``num_prunable == 0``.
    """
    # sample a node to prune
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for reverse move; when nothing can be pruned,
    # `node_to_prune` is out of bounds and the updates below must be no-ops,
    # so out-of-bounds indices are dropped explicitly
    split_tree = split_tree.at[node_to_prune].set(0, mode='drop')
    affluence_tree = affluence_tree.at[node_to_prune].set(True, mode='drop')
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # avoid 0/0 when no leaf would be growable
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw uniformly one of the positions where a mask is true.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Which values in ``[0, n)`` are allowed.

    Returns
    -------
    A random integer `u` in ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    If the mask is all `False`, `n` is returned.
    """
    running_count = jnp.cumsum(mask)
    # 0-based rank of the chosen allowed value among the allowed ones
    rank = random.randint(key, (), 0, running_count[-1])
    # the first position whose running count exceeds the rank holds the
    # (rank + 1)-th allowed value
    return jnp.searchsorted(running_count, rank, 'right')
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide which proposed moves to accept, then sample new leaf values.

    The work is split into a stage that runs in parallel across trees, a
    stage that processes the trees one after the other, and a final stage.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new (valid) BART mcmc state.
    """
    staged = accept_moves_parallel_stage(key, bart, moves)
    return accept_moves_final_stage(*accept_moves_sequential_stage(staged))
class Counts(Module):
    """
    Number of datapoints in the nodes involved in proposed moves for each tree.

    Instances are built by `compute_count_trees`.

    Parameters
    ----------
    left
        Number of datapoints in the left child.
    right
        Number of datapoints in the right child.
    total
        Number of datapoints in the parent (``= left + right``).
    """

    # one entry per tree
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']
class Precs(Module):
    """
    Likelihood precision scale in the nodes involved in proposed moves for each tree.

    The "likelihood precision scale" of a tree node is the sum of the inverse
    squared error scales of the datapoints selected by the node. Instances are
    built by `compute_prec_trees`.

    Parameters
    ----------
    left
        Likelihood precision scale in the left child.
    right
        Likelihood precision scale in the right child.
    total
        Likelihood precision scale in the parent (``= left + right``).
    """

    # one entry per tree
    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']
class PreLkV(Module):
    """
    Non-sequential terms of the likelihood ratio for each tree.

    These terms can be computed in parallel across trees. Instances are
    returned by `precompute_likelihood_terms`.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the parent node of the leaves grown or pruned by
        the moves, i.e., in the union of the two children (presumably
        mirroring `Counts.total`; see `precompute_likelihood_terms`).
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    # one entry per tree
    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']
class PreLk(Module):
    """
    Non-sequential terms of the likelihood ratio shared by all trees.

    Instances are returned by `precompute_likelihood_terms`.

    Parameters
    ----------
    exp_factor
        The factor to multiply the likelihood ratio by, shared by all trees.
    """

    # a scalar, common to every tree
    exp_factor: Float32[Array, '']
class PreLf(Module):
    """
    Pre-computed terms used to sample leaves from their posterior.

    These terms can be computed in parallel across trees. Instances are
    returned by `precompute_leaf_terms`.

    Parameters
    ----------
    mean_factor
        The factor to be multiplied by the sum of the scaled residuals to
        obtain the posterior mean.
    centered_leaves
        The mean-zero normal values to be added to the posterior mean to
        obtain the posterior leaf samples.
    """

    # one row per tree, one column per node of the leaf tree
    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']
class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, this is instead the number of points in
        those nodes (a `Counts` object).
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    # float precision sums, or integer counts when there is no precision scale
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf
    # (`count_trees` and `move_counts` are only defined under this condition;
    # each later use is guarded by one of the same three conditions)
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        # keep only the columns that match the shape of the affluence tree
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto the move if either child leaf involved doesn't have enough
    # datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        # without a precision scale, plain counts play the role of precisions
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    # the transition/prior ratios are stored only if the state tracks the
    # log likelihood
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown leaf into its new children.

    Vectorized over trees on `moves` and `leaf_indices`.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The leaf indices after applying the move if it is a grow, unchanged
    otherwise.
    """
    # an index past the end of the leaf tree, matched by no datapoint, used
    # when the move is not a grow
    out_of_tree = jnp.array(2 * moves.var_tree.size)
    grown_leaf = jnp.where(moves.grow, moves.node, out_of_tree)
    left_child = moves.node.astype(leaf_indices.dtype) << 1
    # datapoints at or above the cutpoint go to the right child (left + 1)
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    in_grown_leaf = leaf_indices == grown_leaf
    return jnp.where(in_grown_leaf, left_child + goes_right, leaf_indices)
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[UInt[Array, 'num_trees 2**d'], Counts]:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : UInt[Array, 'num_trees 2**d']
        The number of points in each potential or actual leaf node. The sum of
        the two children is also written into the node grown or pruned by the
        move.
    counts : Counts
        The counts of the number of points in the leaves grown or pruned by the
        moves.
    """
    num_trees, half_tree_size = moves.var_tree.shape
    tree_size = 2 * half_tree_size
    tree_indices = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in nodes modified by move
    left = count_trees[tree_indices, moves.left]
    right = count_trees[tree_indices, moves.right]
    counts = Counts(left=left, right=right, total=left + right)

    # write count into non-leaf node; `moves.node` can be out of bounds when
    # there is no move to make (`PruneMoves.node` is ``2 ** d`` in that case),
    # and then the write must be skipped, so request dropping explicitly
    count_trees = count_trees.at[tree_indices, moves.node].set(
        counts.total, mode='drop'
    )

    return count_trees, counts
def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> UInt[Array, 'num_trees {tree_size}']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time without batching the data.

    Returns
    -------
    The number of points in each leaf node, as uint32, one row per tree and
    `tree_size` columns.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    else:
        return _count_vec(leaf_indices, tree_size, batch_size)
def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """Count datapoints per leaf (as uint32), one tree at a time with a scan."""

    def loop(_, leaf_indices):
        return None, _aggregate_scatter(1, leaf_indices, tree_size, jnp.uint32)

    _, count_trees = lax.scan(loop, None, leaf_indices)
    return count_trees
def _aggregate_scatter(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
) -> Shaped[Array, ' {size}']:
    """
    Sum `values` into `size` bins selected by `indices`.

    Parameters
    ----------
    values
        The values to accumulate (or a scalar added once per index).
    indices
        The bin of each value.
    size
        The number of bins.
    dtype
        The dtype of the accumulator.

    Returns
    -------
    An array of length `size` with the per-bin sums.
    """
    bins = jnp.zeros(size, dtype)
    return bins.at[indices].add(values)
def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> UInt[Array, 'num_trees {tree_size}']:
    """Count datapoints per leaf (as uint32) for all trees at once, in batches."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
def _aggregate_batched_alltrees(
    values: Shaped[Array, '*'],
    indices: UInt[Array, 'num_trees n'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Shaped[Array, 'num_trees {size}']:
    """
    Sum `values` into per-tree bins using several partial accumulators.

    Parameters
    ----------
    values
        The values to accumulate (or a scalar added once per datapoint).
    indices
        The bin of each datapoint, one row per tree.
    size
        The number of bins per tree.
    dtype
        The dtype of the accumulator.
    batch_size
        About how many datapoints go into each partial accumulator.

    Returns
    -------
    The per-bin sums, one row per tree.
    """
    num_trees, n = indices.shape
    whole, leftover = divmod(n, batch_size)
    nbatches = whole + bool(leftover)  # ceil(n / batch_size)
    tree_ids = jnp.arange(num_trees)[:, None]
    # datapoint i contributes to partial accumulator i % nbatches
    batch_ids = jnp.arange(n) % nbatches
    partial_sums = jnp.zeros((num_trees, size, nbatches), dtype)
    partial_sums = partial_sums.at[tree_ids, indices, batch_ids].add(values)
    return partial_sums.sum(axis=2)
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Compute the likelihood precision scale in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, with the deeper version
        of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size to use for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
        The sum of the two children is also written into the node grown or
        pruned by the move.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_tree_size = moves.var_tree.shape
    tree_size = 2 * half_tree_size
    tree_indices = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)

    # prec datapoints in nodes modified by move
    left = prec_trees[tree_indices, moves.left]
    right = prec_trees[tree_indices, moves.right]
    precs = Precs(left=left, right=right, total=left + right)

    # write prec into non-leaf node; `moves.node` can be out of bounds when
    # there is no move to make (`PruneMoves.node` is ``2 ** d`` in that case),
    # and then the write must be skipped, so request dropping explicitly
    prec_trees = prec_trees.at[tree_indices, moves.node].set(
        precs.total, mode='drop'
    )

    return prec_trees, precs
def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the error precision scales of the datapoints in each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the precision of the error on each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, the trees
        are processed one at a time without batching the data.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is not None:
        return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
    return _prec_scan(prec_scale, leaf_indices, tree_size)
def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum `prec_scale` per leaf (as float32), one tree at a time with a scan."""

    def per_tree(carry, tree_leaf_indices):
        sums = _aggregate_scatter(
            prec_scale, tree_leaf_indices, tree_size, jnp.float32
        )
        return carry, sums

    _, prec_trees = lax.scan(per_tree, None, leaf_indices)
    return prec_trees
def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum `prec_scale` per leaf (as float32) for all trees at once, in batches."""
    prec_trees = _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )
    return prec_trees
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish computing the non-likelihood part of the MH ratio.

    This adds the probability of proposing prune rather than grow in the
    inverse transition, and the prior probability that the child nodes are
    leaves.

    Parameters
    ----------
    moves
        The proposed moves. They must already take into account the thresholds
        on the number of datapoints per node, which is done in
        `accept_moves_parallel_stage`.
    p_nonterminal
        The a priori probability of each node being nonterminal conditional on
        its ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The updated moves, with `partial_ratio=None` and `log_trans_prior_ratio` set.
    """
    num_trees, _ = moves.affluence_tree.shape
    rows = jnp.arange(num_trees)

    def is_growable(child):
        # children outside the affluence tree count as not growable
        return moves.affluence_tree.at[rows, child].get(
            mode='fill', fill_value=False
        )

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # probability of proposing prune, depending on the kind of move:
    # - after a grow: 0.5 if some leaf (another pre-existing growable leaf, or
    #   one of the new children) can still be grown, else 1
    # - after a prune: 0.5 if `num_growable` is nonzero, else 1
    can_grow_again = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune = jnp.where(
        moves.grow,
        jnp.where(can_grow_again, 0.5, 1),
        jnp.where(moves.num_growable, 0.5, 1),
    )

    # prior probability that both children are leaves
    pt_children = (1 - p_nonterminal[moves.left] * left_growable) * (
        1 - p_nonterminal[moves.right] * right_growable
    )

    return replace(
        moves,
        log_trans_prior_ratio=jnp.log(moves.partial_ratio * pt_children * p_prune),
        partial_ratio=None,
    )
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Modify leaves such that post-grow indices work on the original tree.

    For grow moves, the value of the leaf to grow is copied into the slots of
    its would-be children. Trees whose move is not a grow are left unchanged.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    # an index one past the end: JAX skips out-of-bounds scatter updates, so
    # writes aimed here are no-ops
    nowhere = leaf_trees.size
    parent_value = leaf_trees[moves.node]
    left_target = jnp.where(moves.grow, moves.left, nowhere)
    right_target = jnp.where(moves.grow, moves.right, nowhere)
    with_left = leaf_trees.at[left_target].set(parent_value)
    return with_left.at[right_target].set(parent_value)
def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, in the attributes `left`, `right`, and `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio, one per tree.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio, shared by all trees.
    """

    def combined_var(prec):
        # sigma2 + prec * sigma_mu2 for a node with precision scale `prec`
        return sigma2 + prec * sigma_mu2

    var_left = combined_var(move_precs.left)
    var_right = combined_var(move_precs.right)
    var_total = combined_var(move_precs.total)

    # log of the square root of the determinant ratio
    log_sqrt_ratio = jnp.log(sigma2 * var_total / (var_left * var_right)) / 2

    per_tree = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        sqrt_term=log_sqrt_ratio,
    )
    shared = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return per_tree, shared
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    The posterior precision of each leaf is the likelihood precision
    `prec_trees / sigma2` plus the prior precision `1 / sigma_mu2`. The
    zero-mean part of the posterior sample is drawn here, so the caller only
    needs to add the posterior mean.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling.
    """
    post_prec = prec_trees / sigma2 + lax.reciprocal(sigma_mu2)
    var_post = lax.reciprocal(post_prec)
    std_normal = random.normal(key, prec_trees.shape, sigma2.dtype)
    return PreLf(
        # | mean = mean_lk * prec_lk * var_post
        # | resid_tree = mean_lk * prec_tree -->
        # | --> mean_lk = resid_tree / prec_tree (kind of)
        # | mean_factor =
        # | = mean / resid_tree =
        # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
        # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
        # | = var_post / sigma2
        mean_factor=var_post / sigma2,
        centered_leaves=std_normal * jnp.sqrt(var_post),
    )
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees:
    the residuals are threaded through a `lax.scan` over the trees, and each
    tree sees the residuals already updated by the previous ones.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    bart_in = pso.bart
    forest_in = bart_in.forest

    # inputs identical for every tree, closed over by the scan body
    shared = SeqStageInAllTrees(
        X=bart_in.X,
        resid_batch_size=forest_in.resid_batch_size,
        prec_scale=bart_in.prec_scale,
        save_ratios=forest_in.log_likelihood is not None,
        prelk=pso.prelk,
    )

    # inputs with a leading tree axis, sliced one tree at a time by the scan
    per_tree = SeqStageInPerTree(
        leaf_tree=forest_in.leaf_tree,
        prec_tree=pso.prec_trees,
        move=pso.moves,
        move_precs=pso.move_precs,
        leaf_indices=forest_in.leaf_indices,
        prelkv=pso.prelkv,
        prelf=pso.prelf,
    )

    def process_tree(carry_resid, tree_inputs):
        new_resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            carry_resid, shared, tree_inputs
        )
        return new_resid, (leaf_tree, acc, to_prune, lkratio)

    final_resid, outputs = lax.scan(process_tree, bart_in.resid, per_tree)
    leaf_trees, acc, to_prune, lkratio = outputs

    new_forest = replace(forest_in, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_bart = replace(bart_in, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_bart, new_moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf. If
        None, the sum is computed without batching (see `sum_resid`).
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios. If False,
        `accept_move_and_sample_leaves` returns None in place of the log
        likelihood ratio.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    # NOTE(review): assuming `Module` generates a dataclass-style __init__,
    # the field order below is the positional constructor order; keep it
    # stable.
    X: UInt[Array, 'p n']
    # static field: passed to `sum_resid` as `batch_size`
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    # static field: selects whether the log likelihood ratio is returned
    save_ratios: bool = field(static=True)
    prelk: PreLk
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    When this is built for all trees at once, each field carries a leading
    tree axis, and `lax.scan` passes one tree's slice at a time.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
        The pre-computed terms of the likelihood ratio which are specific to
        the tree.
    prelf
        The pre-computed terms for sampling the leaves of the tree.
    """

    # NOTE(review): assuming `Module` generates a dataclass-style __init__,
    # the field order below is the positional constructor order; keep it
    # stable.
    leaf_tree: Float32[Array, ' 2**d']
    # per-leaf precision scale, used to add back this tree's contribution
    # to the residual sums
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    # leaf of each datapoint in the grown version of the tree
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    The function runs these steps in order:

    1. Sum the residuals in the leaves of the tree as it would be after the
       grow move; `pt.leaf_indices` refer to that largest version of the tree.
    2. Evaluate the Metropolis-Hastings log ratio and decide acceptance.
    3. Sample all leaf values from their posterior.
    4. Update the residuals with the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function: adding back prec * leaf value in
    # each leaf gives residual sums relative to the forest without this tree
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move; the parent gets the
    # total of its two children so it is correct if the move ends up pruned
    resid_left = resid_tree[pt.move.left]
    resid_right = resid_tree[pt.move.right]
    resid_total = resid_left + resid_right
    # NOTE(review): presumably guards the index dtype used by the scatter
    # below — TODO confirm why int32 specifically is required
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[pt.move.node].set(resid_total)

    # compute acceptance ratio; the ratio is defined in the grow direction,
    # so for a prune move its inverse (negated log) is used
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move; `logu` is presumably log(U) with
    # U ~ Uniform(0, 1), drawn upstream — TODO confirm
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves (see `precompute_leaf_terms`:
    # posterior mean = resid_tree * mean_factor, plus zero-mean noise)
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf:
    # the tree ends up pruned when a grow is rejected or a prune is accepted,
    # and then the children get the parent's value, because datapoints may
    # still have a child as their leaf index (leaf indices are updated later,
    # in `apply_moves_to_leaf_indices`)
    to_prune = acc ^ pt.move.grow
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
        .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
    )

    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism. If None, no batching is used.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is not None:
        return _aggregate_batched_onetree(
            scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
        )
    return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
def _aggregate_batched_onetree(
    values: Shaped[Array, '*'],
    indices: Integer[Array, '*'],
    size: int,
    dtype: jnp.dtype,
    batch_size: int,
) -> Float32[Array, ' {size}']:
    """
    Sum `values` into `size` bins selected by `indices`, in batches.

    Datapoint ``i`` goes to batch ``i % nbatches``, with
    ``nbatches = ceil(n / batch_size)``. Each batch is scatter-added into its
    own column of a ``(size, nbatches)`` accumulator, and the columns are
    summed at the end.
    """
    (n,) = indices.shape
    nbatches = -(-n // batch_size)  # ceiling division
    batch_of_point = jnp.arange(n) % nbatches
    partial_sums = jnp.zeros((size, nbatches), dtype)
    partial_sums = partial_sums.at[indices, batch_of_point].add(values)
    return partial_sums.sum(axis=1)
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The logarithm of the likelihood ratio,
    log[P(data | grown tree) / P(data | original tree)]. `prelkv.sqrt_term`
    is already a log term, so the result is on the log scale. For a prune
    move, the negation of this value is the relevant log ratio.
    """
    # quadratic-form part of the log ratio
    exp_term = prelk.exp_factor * (
        left_resid * left_resid / prelkv.sigma2_left
        + right_resid * right_resid / prelkv.sigma2_right
        - total_resid * total_resid / prelkv.sigma2_total
    )
    # log-determinant part plus quadratic-form part
    return prelkv.sqrt_term + exp_term
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    This is kept apart from `accept_moves_sequential_stage` because everything
    here can be done in parallel across trees: counting accepted moves and
    updating leaf indices and split trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    forest = bart.forest
    accepted_grows = jnp.sum(moves.acc & moves.grow)
    accepted_prunes = jnp.sum(moves.acc & ~moves.grow)
    new_leaf_indices = apply_moves_to_leaf_indices(forest.leaf_indices, moves)
    new_split_tree = apply_moves_to_split_trees(forest.split_tree, moves)
    new_forest = replace(
        forest,
        grow_acc_count=accepted_grows,
        prune_acc_count=accepted_prunes,
        leaf_indices=new_leaf_indices,
        split_tree=new_split_tree,
    )
    return replace(bart, forest=new_forest)
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Datapoints that fall into either child of the moved node are sent back to
    the node itself when the tree has to be pruned.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # clearing the lowest bit maps both siblings (left, left + 1) to `left`
    sibling_mask = ~jnp.array(1, leaf_indices.dtype)  # ...1111111110
    in_moved_pair = (leaf_indices & sibling_mask) == moves.left
    collapse = in_moved_pair & moves.to_prune
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(collapse, parent, leaf_indices)
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # an index one past the end: JAX skips out-of-bounds scatter updates
    skip = split_tree.size
    grow_target = jnp.where(moves.grow, moves.node, skip)
    prune_target = jnp.where(moves.to_prune, moves.node, skip)
    # first write the split of every proposed grow...
    grown = split_tree.at[grow_target].set(moves.grow_split.astype(split_tree.dtype))
    # ...then clear the node wherever the tree must end up pruned (an accepted
    # prune, or a rejected grow whose split was just written)
    return grown.at[prune_target].set(0)
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    The new `sigma2` is ``beta / g`` with ``g ~ Gamma(alpha, 1)``, i.e., an
    inverse-gamma draw. Here ``alpha`` is the prior `sigma2_alpha` plus half
    the number of datapoints, and ``beta`` is the prior `sigma2_beta` plus
    half the precision-scaled squared norm of the residuals.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    scaled_resid = resid if bart.prec_scale is None else resid * bart.prec_scale

    shape_post = bart.sigma2_alpha + resid.size / 2
    rate_post = bart.sigma2_beta + (resid @ scaled_resid) / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, shape_post)
    return replace(bart, sigma2=rate_post / gamma_draw)
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    The part of `z` explained by the trees and offset (``z - resid``) is kept
    fixed. New residuals are drawn with `truncated_normal_onesided`, with the
    truncation side chosen by the boolean outcome `y`.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state.
    """
    forest_plus_offset = bart.z - bart.resid
    assert bart.y.dtype == bool
    # NOTE(review): presumably truncates at -forest_plus_offset so that z
    # lands on the side matching y — TODO confirm against the helper
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -forest_plus_offset)
    new_z = forest_plus_offset + new_resid
    return replace(bart, z=new_z, resid=new_resid)
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    `log_s` stores unnormalized log-gamma draws; subtracting their logsumexp
    yields the log of a Dirichlet draw.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`.
    """
    forest = bart.forest
    assert forest.theta is not None

    # how many times each of the p predictors is used in the forest
    num_vars = forest.max_split.size
    varcount = grove.var_histogram(num_vars, forest.var_tree, forest.split_tree)

    # Dirichlet posterior concentrations, sampled in log space
    concentration = forest.theta / num_vars + varcount
    log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=log_s))
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The posterior of
    lambda = theta / (theta + rho) is evaluated on a grid, a grid point is
    drawn, and it is mapped back to theta.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # midpoints of num_grid equal bins partitioning (0, 1); the grid is
    # symmetric around 1/2, as `_log_p_lamda` requires
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s so that it sums to 1
    log_s_normalized = forest.log_s - logsumexp(forest.log_s)

    # draw a grid point with probability proportional to exp(logp)
    logp, theta_grid = _log_p_lamda(
        lamda_grid, log_s_normalized, forest.rho, forest.a, forest.b
    )
    theta = theta_grid[random.categorical(key, logp)]

    return replace(bart, forest=replace(forest, theta=theta))
def _log_p_lamda(
    lamda: Float32[Array, ' num_grid'],
    log_s: Float32[Array, ' p'],
    rho: Float32[Array, ''],
    a: Float32[Array, ''],
    b: Float32[Array, ''],
) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
    """
    Unnormalized log posterior of lambda = theta / (theta + rho) on a grid.

    `lamda` must be symmetric around 1/2 (``lamda[::-1] == 1 - lamda``),
    because the reversed grid is used in place of ``1 - lamda``.

    Returns
    -------
    logp
        The log Beta(a, b) prior density of lambda plus the Dirichlet log
        likelihood of `log_s` given theta. Terms constant in lambda are
        dropped.
    theta
        theta = rho * lambda / (1 - lambda) at each grid point.
    """
    one_minus_lamda = lamda[::-1]
    theta = rho * lamda / one_minus_lamda
    p = log_s.size
    log_lamda = jnp.log1p(-one_minus_lamda)  # log(lambda)
    log_one_minus_lamda = jnp.log1p(-lamda)  # log(1 - lambda)
    logp = (
        (a - 1) * log_lamda
        + (b - 1) * log_one_minus_lamda
        + gammaln(theta)
        - p * gammaln(theta / p)
        + theta / p * jnp.sum(log_s)
    )
    return logp, theta
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    `step_s` always runs. `step_theta` runs afterwards only when the theta
    prior is configured (`rho` is not None).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)