Coverage for src/bartz/mcmcstep.py: 95%
567 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-06 16:14 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-09-06 16:14 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024-2025, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking as input a bart state, and
29outputting a new state. The inputs are not modified.
31The entry points are:
33 - `State`: The dataclass that represents a BART MCMC state.
34 - `init`: Creates an initial `State` from data and configurations.
35 - `step`: Performs one full MCMC step on a `State`, returning a new `State`.
36 - `step_sparse`: Performs the MCMC update for variable selection, which is skipped in `step`.
37"""
39import math 1ab
40from dataclasses import replace 1ab
41from functools import cache, partial 1ab
42from typing import Any, Literal 1ab
44import jax 1ab
45from equinox import Module, field, tree_at 1ab
46from jax import lax, random 1ab
47from jax import numpy as jnp 1ab
48from jax.scipy.linalg import solve_triangular 1ab
49from jax.scipy.special import gammaln, logsumexp 1ab
50from jaxtyping import Array, Bool, Float32, Int32, Integer, Key, Shaped, UInt 1ab
52from bartz import grove 1ab
53from bartz.jaxext import ( 1ab
54 minimal_unsigned_dtype,
55 split,
56 truncated_normal_onesided,
57 vmap_nodoc,
58)
class Forest(Module):
    """
    Represents the MCMC state of a sum of trees.

    Parameters
    ----------
    leaf_tree
        The leaf values.
    var_tree
        The decision axes.
    split_tree
        The decision boundaries.
    affluence_tree
        Marks leaves that can be grown.
    max_split
        The maximum split index for each predictor.
    blocked_vars
        Indices of variables that are not used. This shall include at least
        the `i` such that ``max_split[i] == 0``, otherwise behavior is
        undefined. `None` if no variable is blocked.
    p_nonterminal
        The prior probability of each node being nonterminal, conditional on
        its ancestors. Includes the nodes at maximum depth which should be set
        to 0.
    p_propose_grow
        The unnormalized probability of picking a leaf for a grow proposal.
    leaf_indices
        The index of the leaf each datapoint falls into, for each tree.
    min_points_per_decision_node
        The minimum number of data points in a decision node. `None` if not
        specified.
    min_points_per_leaf
        The minimum number of data points in a leaf node. `None` if not
        specified.
    resid_batch_size
    count_batch_size
        The data batch sizes for computing the sufficient statistics. If `None`,
        they are computed with no batching.
    log_trans_prior
        The log transition and prior Metropolis-Hastings ratio for the
        proposed move on each tree. `None` if ratios are not saved.
    log_likelihood
        The log likelihood ratio. `None` if ratios are not saved.
    grow_prop_count
    prune_prop_count
        The number of grow/prune proposals made during one full MCMC cycle.
    grow_acc_count
    prune_acc_count
        The number of grow/prune moves accepted during one full MCMC cycle.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If `None`, use a uniform distribution.
    theta
        The concentration parameter for the Dirichlet prior on the variable
        distribution `s`. Required only to update `s`.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.
        See `step_theta`.
    """

    leaf_tree: Float32[Array, 'num_trees 2**d']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    split_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    max_split: UInt[Array, ' p']
    blocked_vars: UInt[Array, ' k'] | None
    p_nonterminal: Float32[Array, ' 2**d']
    p_propose_grow: Float32[Array, ' 2**(d-1)']
    leaf_indices: UInt[Array, 'num_trees n']
    min_points_per_decision_node: Int32[Array, ''] | None
    min_points_per_leaf: Int32[Array, ''] | None
    # static: batch sizes are Python ints that shape the computation
    resid_batch_size: int | None = field(static=True)
    count_batch_size: int | None = field(static=True)
    log_trans_prior: Float32[Array, ' num_trees'] | None
    log_likelihood: Float32[Array, ' num_trees'] | None
    grow_prop_count: Int32[Array, '']
    prune_prop_count: Int32[Array, '']
    grow_acc_count: Int32[Array, '']
    prune_acc_count: Int32[Array, '']
    sigma_mu2: Float32[Array, '']
    log_s: Float32[Array, ' p'] | None
    theta: Float32[Array, ''] | None
    a: Float32[Array, ''] | None
    b: Float32[Array, ''] | None
    rho: Float32[Array, ''] | None
class State(Module):
    """
    Represents the MCMC state of BART.

    Parameters
    ----------
    X
        The predictors.
    y
        The response. If the data type is `bool`, the model is binary regression.
    z
        The latent variable for binary regression. `None` in continuous
        regression.
    offset
        Constant shift added to the sum of trees.
    resid
        The residuals (`y` or `z` minus sum of trees).
    sigma2
        The error variance. `None` in binary regression.
    prec_scale
        The scale on the error precision, i.e., ``1 / error_scale ** 2``.
        `None` if `error_scale` was not specified, and always in binary
        regression.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the noise
        variance. `None` in binary regression.
    forest
        The sum of trees model.
    """

    X: UInt[Array, 'p n']
    y: Float32[Array, ' n'] | Bool[Array, ' n']
    z: None | Float32[Array, ' n']
    offset: Float32[Array, '']
    resid: Float32[Array, ' n']
    sigma2: Float32[Array, ''] | None
    prec_scale: Float32[Array, ' n'] | None
    sigma2_alpha: Float32[Array, ''] | None
    sigma2_beta: Float32[Array, ''] | None
    forest: Forest
def init(
    *,
    X: UInt[Any, 'p n'],
    y: Float32[Any, ' n'] | Bool[Any, ' n'],
    offset: float | Float32[Any, ''] = 0.0,
    max_split: UInt[Any, ' p'],
    num_trees: int,
    p_nonterminal: Float32[Any, ' d-1'],
    sigma_mu2: float | Float32[Any, ''],
    sigma2_alpha: float | Float32[Any, ''] | None = None,
    sigma2_beta: float | Float32[Any, ''] | None = None,
    error_scale: Float32[Any, ' n'] | None = None,
    min_points_per_decision_node: int | Integer[Any, ''] | None = None,
    resid_batch_size: int | None | Literal['auto'] = 'auto',
    count_batch_size: int | None | Literal['auto'] = 'auto',
    save_ratios: bool = False,
    filter_splitless_vars: bool = True,
    min_points_per_leaf: int | Integer[Any, ''] | None = None,
    log_s: Float32[Any, ' p'] | None = None,
    theta: float | Float32[Any, ''] | None = None,
    a: float | Float32[Any, ''] | None = None,
    b: float | Float32[Any, ''] | None = None,
    rho: float | Float32[Any, ''] | None = None,
) -> State:
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X
        The predictors. Note this is transposed compared to the usual convention.
    y
        The response. If the data type is `bool`, the regression model is binary
        regression with probit.
    offset
        Constant shift added to the sum of trees. 0 if not specified.
    max_split
        The maximum split index for each variable. All split ranges start at 1.
    num_trees
        The number of trees in the forest.
    p_nonterminal
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma_mu2
        The prior variance of a leaf, conditional on the tree structure. The
        prior variance of the sum of trees is ``num_trees * sigma_mu2``. The
        prior mean of leaves is always zero.
    sigma2_alpha
    sigma2_beta
        The shape and scale parameters of the inverse gamma prior on the error
        variance. Leave unspecified for binary regression.
    error_scale
        Each error is scaled by the corresponding factor in `error_scale`, so
        the error variance for ``y[i]`` is ``sigma2 * error_scale[i] ** 2``.
        Not supported for binary regression. If not specified, defaults to 1 for
        all points, but potentially skipping calculations.
    min_points_per_decision_node
        The minimum number of data points in a decision node. 0 if not
        specified.
    resid_batch_size
    count_batch_size
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios
        Whether to save the Metropolis-Hastings ratios.
    filter_splitless_vars
        Whether to check `max_split` for variables without available cutpoints.
        If any are found, they are put into a list of variables to exclude from
        the MCMC. If `False`, no check is performed, but the results may be
        wrong if any variable is blocked. The function is jax-traceable only
        if this is set to `False`.
    min_points_per_leaf
        The minimum number of datapoints in a leaf node. 0 if not specified.
        Unlike `min_points_per_decision_node`, this constraint is not taken into
        account in the Metropolis-Hastings ratio because it would be expensive
        to compute. Grow moves that would violate this constraint are vetoed.
        This parameter is independent of `min_points_per_decision_node` and
        there is no check that they are coherent. It makes sense to set
        ``min_points_per_decision_node >= 2 * min_points_per_leaf``.
    log_s
        The logarithm of the prior probability for choosing a variable to split
        along in a decision rule, conditional on the ancestors. Not normalized.
        If not specified, use a uniform distribution. If not specified and
        `theta` or `rho`, `a`, `b` are, it's initialized automatically.
    theta
        The concentration parameter for the Dirichlet prior on `s`. Required
        only to update `log_s`. If not specified, and `rho`, `a`, `b` are
        specified, it's initialized automatically.
    a
    b
    rho
        Parameters of the prior on `theta`. Required only to sample `theta`.

    Returns
    -------
    An initialized BART MCMC state.

    Raises
    ------
    ValueError
        If `y` is boolean and arguments unused in binary regression are set,
        or if `rho`, `a`, `b` are not either all `None` or all set.

    Notes
    -----
    In decision nodes, the values in ``X[i, :]`` are compared to a cutpoint out
    of the range ``[1, 2, ..., max_split[i]]``. A point belongs to the left
    child iff ``X[i, j] < cutpoint``. Thus it makes sense for ``X[i, :]`` to be
    integers in the range ``[0, 1, ..., max_split[i]]``.
    """
    p_nonterminal = jnp.asarray(p_nonterminal)
    # append a 0 for the nodes at the maximum depth, which cannot be split
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    @partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    # convert X early: its shape is needed below to size `var_tree`, and the
    # caller may pass a non-array sequence
    X = jnp.asarray(X)
    y = jnp.asarray(y)
    offset = jnp.asarray(offset)

    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(
        resid_batch_size, count_batch_size, y, 2**max_depth * num_trees
    )

    is_binary = y.dtype == bool
    if is_binary:
        # check by identity: comparing a tuple that contains e.g. a numpy
        # array against `None`s would call `array == None` and then `bool()`
        # on the result, raising an unrelated "ambiguous truth value" error
        unused_args = (error_scale, sigma2_alpha, sigma2_beta)
        if any(arg is not None for arg in unused_args):
            msg = (
                'error_scale, sigma2_alpha, and sigma2_beta must be set '
                'to `None` for binary regression.'
            )
            raise ValueError(msg)
        sigma2 = None
    else:
        sigma2_alpha = jnp.asarray(sigma2_alpha)
        sigma2_beta = jnp.asarray(sigma2_beta)
        sigma2 = sigma2_beta / sigma2_alpha

    max_split = jnp.asarray(max_split)

    if filter_splitless_vars:
        (blocked_vars,) = jnp.nonzero(max_split == 0)
        blocked_vars = blocked_vars.astype(minimal_unsigned_dtype(max_split.size))
        # see `fully_used_variables` for the type cast
    else:
        blocked_vars = None

    # check and initialize sparsity parameters
    if not _all_none_or_not_none(rho, a, b):
        msg = 'rho, a, b are not either all `None` or all set'
        raise ValueError(msg)
    if theta is None and rho is not None:
        theta = rho
    if log_s is None and theta is not None:
        log_s = jnp.zeros(max_split.size)

    return State(
        X=X,
        y=y,
        z=jnp.full(y.shape, offset) if is_binary else None,
        offset=offset,
        resid=jnp.zeros(y.shape) if is_binary else y - offset,
        sigma2=sigma2,
        prec_scale=(
            None if error_scale is None else lax.reciprocal(jnp.square(error_scale))
        ),
        sigma2_alpha=sigma2_alpha,
        sigma2_beta=sigma2_beta,
        forest=Forest(
            leaf_tree=make_forest(max_depth, jnp.float32),
            var_tree=make_forest(max_depth - 1, minimal_unsigned_dtype(X.shape[0] - 1)),
            split_tree=make_forest(max_depth - 1, max_split.dtype),
            affluence_tree=(
                make_forest(max_depth - 1, bool)
                .at[:, 1]
                .set(
                    True
                    if min_points_per_decision_node is None
                    else y.size >= min_points_per_decision_node
                )
            ),
            blocked_vars=blocked_vars,
            max_split=max_split,
            grow_prop_count=jnp.zeros((), int),
            grow_acc_count=jnp.zeros((), int),
            prune_prop_count=jnp.zeros((), int),
            prune_acc_count=jnp.zeros((), int),
            p_nonterminal=p_nonterminal[grove.tree_depths(2**max_depth)],
            p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
            leaf_indices=jnp.ones(
                (num_trees, y.size), minimal_unsigned_dtype(2**max_depth - 1)
            ),
            min_points_per_decision_node=_asarray_or_none(min_points_per_decision_node),
            min_points_per_leaf=_asarray_or_none(min_points_per_leaf),
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
            log_trans_prior=jnp.zeros(num_trees) if save_ratios else None,
            log_likelihood=jnp.zeros(num_trees) if save_ratios else None,
            sigma_mu2=jnp.asarray(sigma_mu2),
            log_s=_asarray_or_none(log_s),
            theta=_asarray_or_none(theta),
            rho=_asarray_or_none(rho),
            a=_asarray_or_none(a),
            b=_asarray_or_none(b),
        ),
    )
403def _all_none_or_not_none(*args): 1ab
404 is_none = [x is None for x in args] 1ab
405 return all(is_none) or not any(is_none) 1ab
408def _asarray_or_none(x): 1ab
409 if x is None: 1ab
410 return None 1ab
411 return jnp.asarray(x) 1ab
414def _choose_suffstat_batch_size( 1ab
415 resid_batch_size, count_batch_size, y, forest_size
416) -> tuple[int | None, ...]:
417 @cache 1ab
418 def get_platform(): 1ab
419 try: 1ab
420 device = y.devices().pop() 1ab
421 except jax.errors.ConcretizationTypeError: 1ab
422 device = jax.devices()[0] 1ab
423 platform = device.platform 1ab
424 if platform not in ('cpu', 'gpu'): 424 ↛ 425line 424 didn't jump to line 425 because the condition on line 424 was never true1ab
425 msg = f'Unknown platform: {platform}'
426 raise KeyError(msg)
427 return platform 1ab
429 if resid_batch_size == 'auto': 1ab
430 platform = get_platform() 1ab
431 n = max(1, y.size) 1ab
432 if platform == 'cpu': 432 ↛ 434line 432 didn't jump to line 434 because the condition on line 432 was always true1ab
433 resid_batch_size = 2 ** round(math.log2(n / 6)) # n/6 1ab
434 elif platform == 'gpu':
435 resid_batch_size = 2 ** round((1 + math.log2(n)) / 3) # n^1/3
436 resid_batch_size = max(1, resid_batch_size) 1ab
438 if count_batch_size == 'auto': 1ab
439 platform = get_platform() 1ab
440 if platform == 'cpu': 440 ↛ 442line 440 didn't jump to line 442 because the condition on line 440 was always true1ab
441 count_batch_size = None 1ab
442 elif platform == 'gpu':
443 n = max(1, y.size)
444 count_batch_size = 2 ** round(math.log2(n) / 2 - 2) # n^1/2
445 # /4 is good on V100, /2 on L4/T4, still haven't tried A100
446 max_memory = 2**29
447 itemsize = 4
448 min_batch_size = math.ceil(forest_size * itemsize * n / max_memory)
449 count_batch_size = max(count_batch_size, min_batch_size)
450 count_batch_size = max(1, count_batch_size)
452 return resid_batch_size, count_batch_size 1ab
@jax.jit
def step(key: Key[Array, ''], bart: State) -> State:
    """
    Perform one complete MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The updated BART mcmc state.
    """
    keys = split(key)

    if bart.y.dtype != bool:
        # continuous regression: trees first, then the error variance
        bart = step_trees(keys.pop(), bart)
        return step_sigma(keys.pop(), bart)

    # binary (probit) regression: the trees are updated with the error
    # variance temporarily fixed to 1, then reset to `None` before the latent
    # variable is resampled
    bart = replace(bart, sigma2=jnp.float32(1))
    bart = replace(step_trees(keys.pop(), bart), sigma2=None)
    return step_z(keys.pop(), bart)
def step_trees(key: Key[Array, ''], bart: State) -> State:
    """
    Update the forest in a BART MCMC step.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state, as created by `init`.

    Returns
    -------
    The updated BART mcmc state.

    Notes
    -----
    This function zeroes the proposal counters.
    """
    keys = split(key)
    proposal_key = keys.pop()
    acceptance_key = keys.pop()
    proposed = propose_moves(proposal_key, bart.forest)
    return accept_moves_and_sample_leaves(acceptance_key, bart, proposed)
class Moves(Module):
    """
    Moves proposed to modify each tree.

    Parameters
    ----------
    allowed
        Whether there is a possible move. If `False`, the other values may not
        make sense. The only case in which a move is marked as allowed but is
        then vetoed is if it does not satisfy `min_points_per_leaf`, which for
        efficiency is implemented post-hoc without changing the rest of the
        MCMC logic.
    grow
        Whether the move is a grow move or a prune move.
    num_growable
        The number of growable leaves in the original tree.
    node
        The index of the leaf to grow or node to prune.
    left
    right
        The indices of the children of `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        probability that the children of the modified node are terminal. If the
        move is PRUNE, the ratio is inverted. `None` once
        `log_trans_prior_ratio` has been computed.
    log_trans_prior_ratio
        The logarithm of the product of the transition and prior terms of the
        Metropolis-Hastings ratio for the acceptance of the proposed move.
        `None` if not yet computed. If PRUNE, the log-ratio is negated.
    grow_var
        The decision axes of the new rules.
    grow_split
        The decision boundaries of the new rules.
    var_tree
        The updated decision axes of the trees, valid whatever move.
    affluence_tree
        A partially updated `affluence_tree`, marking non-leaf nodes that would
        become leaves if the move was accepted. As produced by
        `propose_moves`, this mark only accounts for whether the new leaf would
        have available decision rules; whether the node has enough datapoints
        is marked later, in `accept_moves_parallel_stage`.
    logu
        The logarithm of a uniform (0, 1] random variable to be used to
        accept the move. It's in (-oo, 0].
    acc
        Whether the move was accepted. `None` if not yet computed.
    to_prune
        Whether the final operation to apply the move is pruning. This indicates
        an accepted prune move or a rejected grow move. `None` if not yet
        computed.
    """

    allowed: Bool[Array, ' num_trees']
    grow: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees'] | None
    log_trans_prior_ratio: None | Float32[Array, ' num_trees']
    grow_var: UInt[Array, ' num_trees']
    grow_split: UInt[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
    logu: Float32[Array, ' num_trees']
    acc: None | Bool[Array, ' num_trees']
    to_prune: None | Bool[Array, ' num_trees']
def propose_moves(key: Key[Array, ''], forest: Forest) -> Moves:
    """
    Propose one move per tree.

    Each move is either a GROW (turn a leaf into a decision node with two new
    leaves beneath it) or a PRUNE (turn the parent of two leaves back into a
    leaf, deleting its children).

    Parameters
    ----------
    key
        A jax random key.
    forest
        The `forest` field of a BART MCMC state.

    Returns
    -------
    The proposed move for each tree.
    """
    num_trees, _ = forest.leaf_tree.shape
    keys = split(key, 1 + 2 * num_trees)

    # one key per tree for each kind of proposal, then one shared key
    grow_proposal = propose_grow_moves(
        keys.pop(num_trees),
        forest.var_tree,
        forest.split_tree,
        forest.affluence_tree,
        forest.max_split,
        forest.blocked_vars,
        forest.p_nonterminal,
        forest.p_propose_grow,
        forest.log_s,
    )
    prune_proposal = propose_prune_moves(
        keys.pop(num_trees),
        forest.split_tree,
        grow_proposal.affluence_tree,
        forest.p_nonterminal,
        forest.p_propose_grow,
    )
    u, exp1mlogu = random.uniform(keys.pop(), (2, num_trees))

    # grow with probability 1/2 when both kinds are possible, otherwise pick
    # whichever is possible; strict < because u is in [0, 1)
    both_possible = grow_proposal.allowed & prune_proposal.allowed
    p_grow = jnp.where(both_possible, 0.5, grow_proposal.allowed)
    grow = u < p_grow

    node = jnp.where(grow, grow_proposal.node, prune_proposal.node)
    left_child = node << 1

    return Moves(
        allowed=grow_proposal.allowed | prune_proposal.allowed,
        grow=grow,
        num_growable=grow_proposal.num_growable,
        node=node,
        left=left_child,
        right=left_child + 1,
        partial_ratio=jnp.where(
            grow, grow_proposal.partial_ratio, prune_proposal.partial_ratio
        ),
        log_trans_prior_ratio=None,  # filled in by complete_ratio
        grow_var=grow_proposal.var,
        grow_split=grow_proposal.split,
        # a prune move leaves var_tree untouched, so the grow version is valid
        var_tree=grow_proposal.var_tree,
        # both proposals updated affluence_tree unconditionally, prune last
        affluence_tree=prune_proposal.affluence_tree,
        # 1 - exp1mlogu is in (0, 1], so logu is in (-oo, 0]
        logu=jnp.log1p(-exp1mlogu),
        acc=None,  # filled in by accept_moves_sequential_stage
        to_prune=None,  # filled in by accept_moves_sequential_stage
    )
class GrowMoves(Module):
    """
    A proposed GROW move for every tree.

    Parameters
    ----------
    allowed
        Whether a grow move can be proposed for the tree.
    num_growable
        How many leaves were candidates for the grow move.
    node
        The leaf chosen for growing, or ``2 ** d`` when no leaf is growable.
    var
    split
        The decision axis and boundary of the rule placed on `node`.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move, still missing
        the likelihood ratio and the probability of proposing the reverse
        prune move.
    var_tree
        The decision axes of the tree with `var` written at `node`.
    affluence_tree
        `affluence_tree` with each of the two would-be new leaves set to `True`
        if it would have available decision rules.
    """

    allowed: Bool[Array, ' num_trees']
    num_growable: UInt[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    var: UInt[Array, ' num_trees']
    split: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    var_tree: UInt[Array, 'num_trees 2**(d-1)']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, 0, None, None, None, None, None))
def propose_grow_moves(
    key: Key[Array, ' num_trees'],
    var_tree: UInt[Array, 'num_trees 2**(d-1)'],
    split_tree: UInt[Array, 'num_trees 2**(d-1)'],
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    blocked_vars: UInt[Array, ' k'] | None,
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
    log_s: Float32[Array, ' p'] | None,
) -> GrowMoves:
    """
    Propose a GROW move for each tree.

    A GROW move picks a leaf node and converts it to a non-terminal node with
    two leaf children.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The splitting axes of the tree.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether each leaf has enough points to be grown.
    max_split
        The maximum split index for each variable.
    blocked_vars
        The indices of the variables that have no available cutpoints, or
        `None` if no variable is blocked. Unsigned, as produced by `init`.
    p_nonterminal
        The a priori probability of a node to be nonterminal conditional on the
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.
    log_s
        Unnormalized log-probability used to choose a variable to split on
        amongst the available ones.

    Returns
    -------
    An object representing the proposed move.

    Notes
    -----
    The move is not proposed if each leaf is already at maximum depth, or has
    fewer datapoints than the requested threshold `min_points_per_decision_node`,
    or it does not have any available decision rules given its ancestors. This
    is marked by setting `allowed` to `False` and `num_growable` to 0.
    """
    keys = split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        keys.pop(), split_tree, affluence_tree, p_propose_grow
    )

    # sample a decision rule
    var, num_available_var = choose_variable(
        keys.pop(), var_tree, split_tree, max_split, leaf_to_grow, blocked_vars, log_s
    )
    split_idx, l, r = choose_split(
        keys.pop(), var, var_tree, split_tree, max_split, leaf_to_grow
    )

    # determine if the new leaves would have available decision rules; if the
    # move is blocked, these values may not make sense
    left_growable = right_growable = num_available_var > 1
    left_growable |= l < split_idx
    right_growable |= split_idx + 1 < r
    left = leaf_to_grow << 1
    right = left + 1
    affluence_tree = affluence_tree.at[left].set(left_growable)
    affluence_tree = affluence_tree.at[right].set(right_growable)

    ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, leaf_to_grow
    )

    return GrowMoves(
        allowed=num_growable > 0,
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split_idx,
        partial_ratio=ratio,
        var_tree=var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype)),
        affluence_tree=affluence_tree,
    )
def choose_leaf(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Float32[Array, ''], Int32[Array, '']]:
    """
    Choose a leaf node to grow in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points that it could be split into two leaves
        satisfying the `min_points_per_leaf` requirement.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    leaf_to_grow : Int32[Array, '']
        The index of the leaf to grow. If ``num_growable == 0``, return
        ``2 ** d``.
    num_growable : Int32[Array, '']
        The number of leaf nodes that can be grown, i.e., actual leaves of the
        tree that are marked in `affluence_tree` (see `growable_leaves`).
    prob_choose : Float32[Array, '']
        The (normalized) probability that this function had to choose that
        specific leaf, given the arguments.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned, after converting the
        selected leaf to a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)
    distr = jnp.where(is_growable, p_propose_grow, 0)
    leaf_to_grow, distr_norm = categorical(key, distr)
    # with no growable leaf, use the out-of-range sentinel 2 ** d
    leaf_to_grow = jnp.where(num_growable, leaf_to_grow, 2 * split_tree.size)
    # the sentinel is out of bounds for `distr`: JAX clamps gather indices by
    # default, reading an entry of the all-zero `distr`, so prob_choose is 0;
    # the norm is replaced by 1 to avoid 0/0
    prob_choose = distr[leaf_to_grow] / jnp.where(distr_norm, distr_norm, 1)
    # pretend the chosen leaf is a decision node; JAX drops out-of-bounds
    # scatter updates by default, so the sentinel leaves the tree unchanged
    is_parent = grove.is_leaves_parent(split_tree.at[leaf_to_grow].set(1))
    num_prunable = jnp.count_nonzero(is_parent)
    return leaf_to_grow, num_growable, prob_choose, num_prunable
def growable_leaves(
    split_tree: UInt[Array, ' 2**(d-1)'], affluence_tree: Bool[Array, ' 2**(d-1)']
) -> Bool[Array, ' 2**(d-1)']:
    """
    Mark the leaves that may be proposed for a GROW move.

    A leaf qualifies when it is an actual leaf of the tree (which excludes the
    bottom level, absent from `split_tree`) and it is marked in
    `affluence_tree`, i.e., it has available decision rules given its
    ancestors and at least `min_points_per_decision_node` points.

    Parameters
    ----------
    split_tree
        The splitting points of the tree.
    affluence_tree
        Marks leaves that can be grown.

    Returns
    -------
    A boolean mask over the nodes, `True` for leaves that can be grown.

    Notes
    -----
    `affluence_tree` alone is not enough: it may be "dirty" and also mark
    unused nodes, which the `split_tree` check filters out.
    """
    is_leaf = grove.is_actual_leaf(split_tree)
    return is_leaf & affluence_tree
def categorical(
    key: Key[Array, ''], distr: Float32[Array, ' n']
) -> tuple[Int32[Array, ''], Float32[Array, '']]:
    """
    Sample an index from an unnormalized discrete distribution.

    Parameters
    ----------
    key
        A jax random key.
    distr
        Non-negative, unnormalized weights of the outcomes.

    Returns
    -------
    u : Int32[Array, '']
        A random integer in the range ``[0, n)``. If all weights are zero,
        ``n`` is returned instead.
    norm : Float32[Array, '']
        The sum of `distr`.

    Notes
    -----
    Sampling inverts the cumulative sum rather than using the Gumbel trick, so
    it is suitable only for short arrays whose nonzero weights are not tiny.
    """
    cumulative = jnp.cumsum(distr)
    total = cumulative[-1]
    draw = random.uniform(key, (), cumulative.dtype, 0, total)
    # index of the first cumulative weight strictly above the draw
    index = jnp.searchsorted(cumulative, draw, 'right')
    return index, total
def choose_variable(
    key: Key[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
    blocked_vars: UInt[Array, ' k'] | None,
    log_s: Float32[Array, ' p'] | None,
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Choose a variable to split on for a new non-terminal node.

    Parameters
    ----------
    key
        A jax random key.
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf to grow.
    blocked_vars
        The indices of the variables that have no available cutpoints. If
        `None`, all variables are assumed unblocked. Unsigned, as produced by
        `init`.
    log_s
        The logarithm of the prior probability for choosing a variable. If
        `None`, use a uniform distribution.

    Returns
    -------
    var : Int32[Array, '']
        The index of the variable to split on.
    num_available_var : Int32[Array, '']
        The number of variables with available decision rules `var` was chosen
        from.
    """
    # variables whose split range is exhausted along the path to the leaf
    var_to_ignore = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    if blocked_vars is not None:
        # `init` casts blocked_vars so that this concatenation keeps the dtype
        # (see `fully_used_variables`)
        var_to_ignore = jnp.concatenate([var_to_ignore, blocked_vars])

    if log_s is None:
        return randint_exclude(key, max_split.size, var_to_ignore)
    else:
        return categorical_exclude(key, log_s, var_to_ignore)
def fully_used_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Find the ancestor variables of a node whose split range is empty there.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The indices of the variables with an empty split range.

    Notes
    -----
    How many variables are returned is not known in advance: the unused slots
    of the output hold `p`. Neither the fill values nor the variables follow
    any particular order, and a variable may be repeated.
    """
    ancestors = ancestor_variables(var_tree, max_split, leaf_index)

    def range_width(var):
        low, high = split_range(var_tree, split_tree, max_split, leaf_index, var)
        return high - low

    widths = jax.vmap(range_width)(ancestors)
    # the dtype of `ancestors` is already large enough to hold
    # `max_split.size`, see `ancestor_variables`
    return jnp.where(widths == 0, ancestors, max_split.size)
def ancestor_variables(
    var_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
) -> UInt[Array, ' d-2']:
    """
    Return the list of variables in the ancestors of a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    max_split
        The maximum split index for each variable. Used only to get `p`.
    node_index
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    The variable indices of the ancestors of the node.

    Notes
    -----
    The ancestors are the nodes going from the root to the parent of the node.
    The number of ancestors is not known at tracing time; unused spots in the
    output array are filled with `p`.
    """
    max_num_ancestors = grove.tree_depth(var_tree) - 1
    # the output dtype is chosen to be able to hold `p` itself
    ancestor_vars = jnp.zeros(max_num_ancestors, minimal_unsigned_dtype(max_split.size))
    carry = ancestor_vars.size - 1, node_index, ancestor_vars

    def loop(carry, _):
        i, index, ancestor_vars = carry
        index >>= 1  # move to the parent
        # cast before mixing with the fill value `p`: the dtype of `var_tree`
        # may be too narrow to represent `p`, while the output dtype is not
        var = var_tree[index].astype(ancestor_vars.dtype)
        # index 0 is not a node: we went above the root, use the fill value
        var = jnp.where(index, var, max_split.size)
        ancestor_vars = ancestor_vars.at[i].set(var)
        return (i - 1, index, ancestor_vars), None

    (_, _, ancestor_vars), _ = lax.scan(loop, carry, None, ancestor_vars.size)
    return ancestor_vars
def split_range(
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    node_index: Int32[Array, ''],
    ref_var: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Compute the interval of cutpoints on a variable still available at a node.

    Parameters
    ----------
    var_tree
        The variable indices of the tree.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    node_index
        The index of the node, assumed to be valid for `var_tree`.
    ref_var
        The variable whose split range is computed.

    Returns
    -------
    The allowed cutpoints as the half-open interval [l, r). When `ref_var` is
    out of bounds, l = r = 1.
    """
    num_steps = grove.tree_depth(var_tree) - 1
    # a split value of 0 marks a leaf, so without constraints the cutpoints
    # are 1, ..., max_split: the exclusive lower bound starts at 0
    hi_start = max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32) + 1
    init = jnp.int32(0), hi_start, node_index

    def climb(state, _):
        lo, hi, node = state
        from_right = (node & 1).astype(bool)
        node >>= 1
        cut = split_tree[node]
        # the parent constrains the range only if it is an actual node (index
        # 0 is not) and it splits on `ref_var`
        relevant = node.astype(bool) & (var_tree[node] == ref_var)
        lo = jnp.where(relevant & from_right, jnp.maximum(lo, cut), lo)
        hi = jnp.where(relevant & ~from_right, jnp.minimum(hi, cut), hi)
        return (lo, hi, node), None

    (lo, hi, _), _ = lax.scan(climb, init, None, num_steps)
    return lo + 1, hi
def randint_exclude(
    key: Key[Array, ''], sup: int | Integer[Array, ''], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Return a random integer in a range, excluding some values.

    Parameters
    ----------
    key
        A jax random key.
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude from the range. Values greater than or equal to
        `sup` are ignored. Values can appear more than once.

    Returns
    -------
    u : Int32[Array, '']
        A random integer `u` in the range ``[0, sup)`` such that ``u not in
        exclude``.
    num_allowed : Int32[Array, '']
        The number of integers in the range that were not excluded.

    Notes
    -----
    If all values in the range are excluded, return `sup`.
    """
    exclude, num_allowed = _process_exclude(sup, exclude)
    u = random.randint(key, (), 0, num_allowed)

    def loop(u, i_excluded):
        # `exclude` is sorted, so shifting `u` past every excluded value <= u
        # maps [0, num_allowed) onto the allowed values in increasing order.
        # Values >= sup (including the padding added by `_process_exclude`)
        # must not shift `u`: when everything is excluded `u` reaches `sup`,
        # and without this guard it would then overshoot it.
        shift = (i_excluded <= u) & (i_excluded < sup)
        return jnp.where(shift, u + 1, u), None

    u, _ = lax.scan(loop, u, exclude)
    return u, num_allowed


def _process_exclude(sup, exclude):
    """
    Deduplicate and sort the excluded values, and count the allowed ones.

    Parameters
    ----------
    sup
        The exclusive upper bound of the range.
    exclude
        The values to exclude. Values greater than or equal to `sup` are
        ignored in the count.

    Returns
    -------
    exclude
        The sorted unique values of `exclude`, padded with `sup` up to the
        original size.
    num_allowed
        The number of integers in ``[0, sup)`` that are not excluded.
    """
    exclude = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.count_nonzero(exclude < sup)
    return exclude, num_allowed
def categorical_exclude(
    key: Key[Array, ''], logits: Float32[Array, ' k'], exclude: Integer[Array, ' n']
) -> tuple[Int32[Array, ''], Int32[Array, '']]:
    """
    Sample a category from unnormalized log-probabilities, skipping some.

    Parameters
    ----------
    key
        A jax random key.
    logits
        The unnormalized log-probabilities of each category.
    exclude
        Categories in [0, k) that must not be drawn. Values greater than or
        equal to `logits.size` are ignored, and repetitions are allowed.

    Returns
    -------
    u : Int32[Array, '']
        The drawn category, in ``[0, k)`` and not in `exclude`.
    num_allowed : Int32[Array, '']
        How many categories in the range are not excluded.

    Notes
    -----
    The result is unspecified when every category is excluded.
    """
    unique_excluded, num_allowed = _process_exclude(logits.size, exclude)
    # the padding value `logits.size` is out of bounds, and JAX skips
    # out-of-bounds scatter updates, so only real categories are masked
    lowest = jnp.finfo(logits.dtype).min
    masked_logits = logits.at[unique_excluded].set(lowest)
    return random.categorical(key, masked_logits), num_allowed
def choose_split(
    key: Key[Array, ''],
    var: Int32[Array, ''],
    var_tree: UInt[Array, ' 2**(d-1)'],
    split_tree: UInt[Array, ' 2**(d-1)'],
    max_split: UInt[Array, ' p'],
    leaf_index: Int32[Array, ''],
) -> tuple[Int32[Array, ''], Int32[Array, ''], Int32[Array, '']]:
    """
    Draw the cutpoint of a new decision node uniformly among the allowed ones.

    Parameters
    ----------
    key
        A jax random key.
    var
        The variable the new node splits on.
    var_tree
        The splitting axes of the tree. It is not required to already hold
        `var` at `leaf_index`.
    split_tree
        The splitting points of the tree.
    max_split
        The maximum split index for each variable.
    leaf_index
        The index of the leaf being grown.

    Returns
    -------
    split : Int32[Array, '']
        The drawn cutpoint.
    l : Int32[Array, '']
    r : Int32[Array, '']
        `split` was drawn from the integer interval [l, r).

    Notes
    -----
    When `var` is out of bounds, or no cutpoint is available on it, the
    returned `split` is 0.
    """
    l, r = split_range(var_tree, split_tree, max_split, leaf_index, var)
    candidate = random.randint(key, (), l, r)
    split = jnp.where(l < r, candidate, 0)
    return split, l, r
def compute_partial_ratio(
    prob_choose: Float32[Array, ''],
    num_prunable: Int32[Array, ''],
    p_nonterminal: Float32[Array, ' 2**d'],
    leaf_to_grow: Int32[Array, ''],
) -> Float32[Array, '']:
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose
        The probability with which the leaf was picked among the growable
        leaves.
    num_prunable
        How many leaf parents could be pruned once the leaf to grow has been
        turned into a decision node.
    p_nonterminal
        The prior probability of each node being nonterminal given its
        ancestors.
    leaf_to_grow
        The index of the leaf to grow.

    Returns
    -------
    The partial transition ratio multiplied by the prior ratio.

    Notes
    -----
    The transition ratio is P(new tree => old tree) / P(old tree => new tree);
    the partial one returned here lacks the P(propose prune) factor in the
    numerator. The prior ratio is P(new tree) / P(old tree); the partial one
    lacks the P(children are leaves) factor.
    """
    # both ratios would also contain num_available_split * num_available_var
    # * s[var], but those factors cancel out.
    # p_prune and 1 - p_nonterminal[child] * I(child is growable) need the
    # count trees, which are only available in the acceptance phase, so they
    # are left out here.

    # Growing the root is possible only if the tree is just a root, in which
    # case pruning is impossible and grow is proposed with certainty.
    grows_root = leaf_to_grow == 1
    p_grow = jnp.where(grows_root, 1, 0.5)
    inv_trans_ratio = p_grow * prob_choose * num_prunable
    safe_inv_trans_ratio = jnp.where(inv_trans_ratio, inv_trans_ratio, 1)

    # filling with 0.5 when `leaf_to_grow` is out of bounds (move not allowed)
    # avoids a 0 here that would become an inf in the log of `complete_ratio`
    pnt = p_nonterminal.at[leaf_to_grow].get(mode='fill', fill_value=0.5)
    tree_ratio = pnt / (1 - pnt)

    return tree_ratio / safe_inv_trans_ratio
class PruneMoves(Module):
    """
    Represent a proposed prune move for each tree.

    Parameters
    ----------
    allowed
        Whether the move is possible.
    node
        The index of the node to prune. ``2 ** d`` if no node can be pruned.
    partial_ratio
        A factor of the Metropolis-Hastings ratio of the move. It lacks the
        likelihood ratio, the probability of proposing the prune move, and the
        prior probability that the children of the node to prune are leaves.
        This ratio is inverted, and is meant to be inverted back in
        `accept_moves_and_sample_leaves`.
    affluence_tree
        A partially updated affluence tree, where the node to prune is marked
        as growable.
    """

    allowed: Bool[Array, ' num_trees']
    node: UInt[Array, ' num_trees']
    partial_ratio: Float32[Array, ' num_trees']
    affluence_tree: Bool[Array, 'num_trees 2**(d-1)']
@partial(vmap_nodoc, in_axes=(0, 0, 0, None, None))
def propose_prune_moves(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_nonterminal: Float32[Array, ' 2**d'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> PruneMoves:
    """
    Propose a prune move on the tree structure, as part of the BART MCMC.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Which leaves could be grown.
    p_nonterminal
        The prior probability of a node being nonterminal given its
        ancestors, including at the maximum depth where it should be zero.
    p_propose_grow
        The unnormalized probability of picking each leaf to grow.

    Returns
    -------
    The proposed moves.
    """
    node, num_prunable, prob_choose, new_affluence = choose_leaf_parent(
        key, split_tree, affluence_tree, p_propose_grow
    )
    partial_ratio = compute_partial_ratio(
        prob_choose, num_prunable, p_nonterminal, node
    )
    # pruning is possible iff the tree is not just a root
    not_a_root = split_tree[1].astype(bool)
    return PruneMoves(
        allowed=not_a_root,
        node=node,
        partial_ratio=partial_ratio,
        affluence_tree=new_affluence,
    )
def choose_leaf_parent(
    key: Key[Array, ''],
    split_tree: UInt[Array, ' 2**(d-1)'],
    affluence_tree: Bool[Array, ' 2**(d-1)'],
    p_propose_grow: Float32[Array, ' 2**(d-1)'],
) -> tuple[
    Int32[Array, ''],
    Int32[Array, ''],
    Float32[Array, ''],
    Bool[Array, ' 2**(d-1)'],
]:
    """
    Pick a non-terminal node with leaf children to prune in a tree.

    Parameters
    ----------
    key
        A jax random key.
    split_tree
        The splitting points of the tree.
    affluence_tree
        Whether a leaf has enough points to be grown.
    p_propose_grow
        The unnormalized probability of choosing a leaf to grow.

    Returns
    -------
    node_to_prune : Int32[Array, '']
        The index of the node to prune. If ``num_prunable == 0``, return
        ``2 ** d``.
    num_prunable : Int32[Array, '']
        The number of leaf parents that could be pruned.
    prob_choose : Float32[Array, '']
        The (normalized) probability that `choose_leaf` would choose
        `node_to_prune` as leaf to grow, if passed the tree where
        `node_to_prune` had been pruned.
    affluence_tree : Bool[Array, ' 2**(d-1)']
        A partially updated `affluence_tree`, marking the node to prune as
        growable.
    """
    # sample a node to prune
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node_to_prune = randint_masked(key, is_prunable)
    # if nothing can be pruned, use an out-of-bounds index: the scatter
    # updates below are then skipped and `prob_choose` is filled with 0
    node_to_prune = jnp.where(num_prunable, node_to_prune, 2 * split_tree.size)

    # compute stuff for the reverse move (growing the pruned node back)
    split_tree = split_tree.at[node_to_prune].set(0)
    affluence_tree = affluence_tree.at[node_to_prune].set(True)
    is_growable_leaf = growable_leaves(split_tree, affluence_tree)
    distr_norm = jnp.sum(p_propose_grow, where=is_growable_leaf)
    prob_choose = p_propose_grow.at[node_to_prune].get(mode='fill', fill_value=0)
    # avoid dividing by zero when no leaf is growable
    prob_choose = prob_choose / jnp.where(distr_norm, distr_norm, 1)

    return node_to_prune, num_prunable, prob_choose, affluence_tree
def randint_masked(key: Key[Array, ''], mask: Bool[Array, ' n']) -> Int32[Array, '']:
    """
    Draw uniformly one of the positions where a boolean mask is set.

    Parameters
    ----------
    key
        A jax random key.
    mask
        Which values in ``[0, n)`` may be drawn.

    Returns
    -------
    A random integer ``u`` in ``[0, n)`` with ``mask[u] == True``.

    Notes
    -----
    When the mask is all `False`, `n` is returned.
    """
    running_count = jnp.cumsum(mask)
    # draw the rank among the allowed positions, then locate that position
    rank = random.randint(key, (), 0, running_count[-1])
    return jnp.searchsorted(running_count, rank, side='right')
def accept_moves_and_sample_leaves(
    key: Key[Array, ''], bart: State, moves: Moves
) -> State:
    """
    Decide on the proposed moves and draw the new leaf values.

    The work is split in a parallel stage, a sequential stage, and a final
    stage.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A valid BART mcmc state.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    A new, valid BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(key, bart, moves)
    new_bart, new_moves = accept_moves_sequential_stage(parallel_out)
    return accept_moves_final_stage(new_bart, new_moves)
class Counts(Module):
    """
    Datapoint counts in the nodes touched by the proposed move of each tree.

    Parameters
    ----------
    left
        How many datapoints fall in the left child.
    right
        How many datapoints fall in the right child.
    total
        How many datapoints fall in the parent, equal to ``left + right``.
    """

    left: UInt[Array, ' num_trees']
    right: UInt[Array, ' num_trees']
    total: UInt[Array, ' num_trees']
class Precs(Module):
    """
    Likelihood precision scale in the nodes touched by each tree's proposed move.

    The likelihood precision scale of a node is the sum, over the datapoints
    falling in the node, of their inverse squared error scales.

    Parameters
    ----------
    left
        Precision scale of the left child.
    right
        Precision scale of the right child.
    total
        Precision scale of the parent, equal to ``left + right``.
    """

    left: Float32[Array, ' num_trees']
    right: Float32[Array, ' num_trees']
    total: Float32[Array, ' num_trees']
class PreLkV(Module):
    """
    Per-tree terms of the likelihood ratio that do not need the sequential stage.

    They are computed for all trees in parallel.

    Parameters
    ----------
    sigma2_left
        The noise variance in the left child of the leaves grown or pruned by
        the moves.
    sigma2_right
        The noise variance in the right child of the leaves grown or pruned by
        the moves.
    sigma2_total
        The noise variance in the parent of the leaves grown or pruned by the
        moves.
    sqrt_term
        The **logarithm** of the square root term of the likelihood ratio.
    """

    sigma2_left: Float32[Array, ' num_trees']
    sigma2_right: Float32[Array, ' num_trees']
    sigma2_total: Float32[Array, ' num_trees']
    sqrt_term: Float32[Array, ' num_trees']
class PreLk(Module):
    """
    Likelihood ratio terms common to every tree that do not need the sequential stage.

    Parameters
    ----------
    exp_factor
        The factor the likelihood ratio is multiplied by, the same for every
        tree.
    """

    exp_factor: Float32[Array, '']
class PreLf(Module):
    """
    Terms computed in advance to draw the leaves from their posterior.

    They are computed for all trees in parallel.

    Parameters
    ----------
    mean_factor
        Multiplied by the sum of the scaled residuals, it gives the posterior
        mean.
    centered_leaves
        Zero-mean normal draws which, added to the posterior mean, give the
        posterior samples of the leaves.
    """

    mean_factor: Float32[Array, 'num_trees 2**d']
    centered_leaves: Float32[Array, 'num_trees 2**d']
class ParallelStageOut(Module):
    """
    The output of `accept_moves_parallel_stage`.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves, with `partial_ratio` set to `None` and
        `log_trans_prior_ratio` set to its final value.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node. If
        there is no precision scale, this is the number of points in each leaf.
    move_precs
        The likelihood precision scale in each node modified by the moves. If
        `bart.prec_scale` is not set, these are the counts of the number of
        points in those nodes (a `Counts` object) instead.
    prelkv
    prelk
    prelf
        Objects with pre-computed terms of the likelihood ratios and leaf
        samples.
    """

    bart: State
    moves: Moves
    prec_trees: Float32[Array, 'num_trees 2**d'] | Int32[Array, 'num_trees 2**d']
    move_precs: Precs | Counts
    prelkv: PreLkV
    prelk: PreLk
    prelf: PreLf
def accept_moves_parallel_stage(
    key: Key[Array, ''], bart: State, moves: Moves
) -> ParallelStageOut:
    """
    Pre-compute quantities used to accept moves, in parallel across trees.

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state. `bart.sigma2` must be set.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    An object with all that could be done in parallel.
    """
    # where the move is grow, modify the state like the move was accepted
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            var_tree=moves.var_tree,
            leaf_indices=apply_grow_to_indices(moves, bart.forest.leaf_indices, bart.X),
            leaf_tree=adapt_leaf_trees_to_grow_indices(bart.forest.leaf_tree, moves),
        ),
    )

    # count number of datapoints per leaf; the counts are used only by the
    # thresholds on the number of points per node, and in place of the
    # precision when there is no precision scale
    if (
        bart.forest.min_points_per_decision_node is not None
        or bart.forest.min_points_per_leaf is not None
        or bart.prec_scale is None
    ):
        count_trees, move_counts = compute_count_trees(
            bart.forest.leaf_indices, moves, bart.forest.count_batch_size
        )

    # mark which leaves & potential leaves have enough points to be grown
    if bart.forest.min_points_per_decision_node is not None:
        count_half_trees = count_trees[:, : bart.forest.var_tree.shape[1]]
        moves = replace(
            moves,
            affluence_tree=moves.affluence_tree
            & (count_half_trees >= bart.forest.min_points_per_decision_node),
        )

    # copy updated affluence_tree to state
    bart = tree_at(lambda bart: bart.forest.affluence_tree, bart, moves.affluence_tree)

    # veto grow move if new leaves don't have enough datapoints
    if bart.forest.min_points_per_leaf is not None:
        moves = replace(
            moves,
            allowed=moves.allowed
            & (move_counts.left >= bart.forest.min_points_per_leaf)
            & (move_counts.right >= bart.forest.min_points_per_leaf),
        )

    # count number of datapoints per leaf, weighted by error precision scale
    if bart.prec_scale is None:
        prec_trees = count_trees
        move_precs = move_counts
    else:
        prec_trees, move_precs = compute_prec_trees(
            bart.prec_scale,
            bart.forest.leaf_indices,
            moves,
            bart.forest.count_batch_size,
        )
    assert move_precs is not None

    # compute some missing information about moves
    moves = complete_ratio(moves, bart.forest.p_nonterminal)
    save_ratios = bart.forest.log_likelihood is not None
    bart = replace(
        bart,
        forest=replace(
            bart.forest,
            grow_prop_count=jnp.sum(moves.grow),
            prune_prop_count=jnp.sum(moves.allowed & ~moves.grow),
            log_trans_prior=moves.log_trans_prior_ratio if save_ratios else None,
        ),
    )

    # pre-compute some likelihood ratio & posterior terms
    assert bart.sigma2 is not None  # `step` shall temporarily set it to 1
    prelkv, prelk = precompute_likelihood_terms(
        bart.sigma2, bart.forest.sigma_mu2, move_precs
    )
    prelf = precompute_leaf_terms(key, prec_trees, bart.sigma2, bart.forest.sigma_mu2)

    return ParallelStageOut(
        bart=bart,
        moves=moves,
        prec_trees=prec_trees,
        move_precs=move_precs,
        prelkv=prelkv,
        prelk=prelk,
        prelf=prelf,
    )
@partial(vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(
    moves: Moves, leaf_indices: UInt[Array, 'num_trees n'], X: UInt[Array, 'p n']
) -> UInt[Array, 'num_trees n']:
    """
    Move the datapoints of a grown leaf into its new children.

    Parameters
    ----------
    moves
        The proposed moves, see `propose_moves`.
    leaf_indices
        The index of the leaf each datapoint falls into.
    X
        The predictors matrix.

    Returns
    -------
    The leaf indices after the grow move.
    """
    # when the move is not a grow, target the index one past the end of the
    # leaf tree (2 ** d), which no datapoint has
    past_the_end = jnp.array(2 * moves.var_tree.size)
    target = jnp.where(moves.grow, moves.node, past_the_end)
    goes_right = X[moves.grow_var, :] >= moves.grow_split
    first_child = moves.node.astype(leaf_indices.dtype) << 1
    in_target = leaf_indices == target
    return jnp.where(in_target, first_child + goes_right, leaf_indices)
def compute_count_trees(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves, batch_size: int | None
) -> tuple[Int32[Array, 'num_trees 2**d'], Counts]:
    """
    Count how many datapoints fall in each node of each tree.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, for the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size used for the summation.

    Returns
    -------
    count_trees : Int32[Array, 'num_trees 2**d']
        How many points fall in each potential or actual leaf node.
    counts : Counts
        How many points fall in the nodes grown or pruned by the moves.
    """
    num_trees, half_size = moves.var_tree.shape
    tree_size = 2 * half_size
    trees = jnp.arange(num_trees)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # counts in the two children involved in each move
    left = count_trees[trees, moves.left]
    right = count_trees[trees, moves.right]
    counts = Counts(left=left, right=right, total=left + right)

    # the decision node of the move gets the total of its children
    count_trees = count_trees.at[trees, moves.node].set(counts.total)
    return count_trees, counts
def count_datapoints_per_leaf(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int | None
) -> Int32[Array, 'num_trees 2**d']:
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size to use for the summation. If `None`, loop over
        trees instead of batching the data.

    Returns
    -------
    The number of points in each leaf node, with `tree_size` columns.
    """
    if batch_size is not None:
        return _count_vec(leaf_indices, tree_size, batch_size)
    return _count_scan(leaf_indices, tree_size)
def _count_scan(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int
) -> Int32[Array, 'num_trees {tree_size}']:
    """Count datapoints per leaf one tree at a time with `lax.scan`."""

    def count_one_tree(carry, tree_leaf_indices):
        counts = _aggregate_scatter(1, tree_leaf_indices, tree_size, jnp.uint32)
        return carry, counts

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees
1771def _aggregate_scatter( 1ab
1772 values: Shaped[Array, '*'],
1773 indices: Integer[Array, '*'],
1774 size: int,
1775 dtype: jnp.dtype,
1776) -> Shaped[Array, ' {size}']:
1777 return jnp.zeros(size, dtype).at[indices].add(values) 1ab
def _count_vec(
    leaf_indices: UInt[Array, 'num_trees n'], tree_size: int, batch_size: int
) -> Int32[Array, 'num_trees 2**d']:
    """Count datapoints per leaf for all trees at once, summing over data batches."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size
    )
1789def _aggregate_batched_alltrees( 1ab
1790 values: Shaped[Array, '*'],
1791 indices: UInt[Array, 'num_trees n'],
1792 size: int,
1793 dtype: jnp.dtype,
1794 batch_size: int,
1795) -> Shaped[Array, 'num_trees {size}']:
1796 num_trees, n = indices.shape 1ab
1797 tree_indices = jnp.arange(num_trees) 1ab
1798 nbatches = n // batch_size + bool(n % batch_size) 1ab
1799 batch_indices = jnp.arange(n) % nbatches 1ab
1800 return ( 1ab
1801 jnp.zeros((num_trees, size, nbatches), dtype)
1802 .at[tree_indices[:, None], indices, batch_indices]
1803 .add(values)
1804 .sum(axis=2)
1805 )
def compute_prec_trees(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    moves: Moves,
    batch_size: int | None,
) -> tuple[Float32[Array, 'num_trees 2**d'], Precs]:
    """
    Sum the likelihood precision scale over the datapoints in each node.

    Parameters
    ----------
    prec_scale
        The scale of the error precision of each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into, for the deeper
        version of the tree (post-GROW, pre-PRUNE).
    moves
        The proposed moves, see `propose_moves`.
    batch_size
        The data batch size used for the summation.

    Returns
    -------
    prec_trees : Float32[Array, 'num_trees 2**d']
        The likelihood precision scale in each potential or actual leaf node.
    precs : Precs
        The likelihood precision scale in the nodes involved in the moves.
    """
    num_trees, half_size = moves.var_tree.shape
    tree_size = 2 * half_size
    trees = jnp.arange(num_trees)

    prec_trees = prec_per_leaf(prec_scale, leaf_indices, tree_size, batch_size)

    # precision in the two children involved in each move
    left = prec_trees[trees, moves.left]
    right = prec_trees[trees, moves.right]
    precs = Precs(left=left, right=right, total=left + right)

    # the decision node of the move gets the total of its children
    prec_trees = prec_trees.at[trees, moves.node].set(precs.total)
    return prec_trees, precs
def prec_per_leaf(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, 'num_trees {tree_size}']:
    """
    Sum the likelihood precision scale over the datapoints of each leaf.

    Parameters
    ----------
    prec_scale
        The scale of the error precision of each datapoint.
    leaf_indices
        The index of the leaf each datapoint falls into.
    tree_size
        The size of the leaf tree array (2 ** d).
    batch_size
        The data batch size used for the summation. With `None`, the trees
        are processed one at a time instead.

    Returns
    -------
    The likelihood precision scale in each leaf node.
    """
    if batch_size is not None:
        return _prec_vec(prec_scale, leaf_indices, tree_size, batch_size)
    return _prec_scan(prec_scale, leaf_indices, tree_size)
def _prec_scan(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum precision scales per leaf one tree at a time with `lax.scan`."""

    def prec_one_tree(carry, tree_leaf_indices):
        precs = _aggregate_scatter(prec_scale, tree_leaf_indices, tree_size, jnp.float32)
        return carry, precs

    _, prec_trees = lax.scan(prec_one_tree, None, leaf_indices)
    return prec_trees
def _prec_vec(
    prec_scale: Float32[Array, ' n'],
    leaf_indices: UInt[Array, 'num_trees n'],
    tree_size: int,
    batch_size: int,
) -> Float32[Array, 'num_trees {tree_size}']:
    """Sum precision scales per leaf for all trees at once, over data batches."""
    prec_trees = _aggregate_batched_alltrees(
        prec_scale, leaf_indices, tree_size, jnp.float32, batch_size
    )
    return prec_trees
def complete_ratio(moves: Moves, p_nonterminal: Float32[Array, ' 2**d']) -> Moves:
    """
    Finish computing the non-likelihood part of the Metropolis-Hastings ratio.

    The missing factors are the probability of proposing prune rather than
    grow in the inverse transition, and the prior probability that the
    children nodes are leaves.

    Parameters
    ----------
    moves
        The proposed moves. They must already account for the thresholds on
        the number of datapoints per node, which `accept_moves_parallel_stage`
        applies.
    p_nonterminal
        The prior probability of each node being nonterminal given its
        ancestors, including at the maximum depth where it should be zero.

    Returns
    -------
    The moves with `log_trans_prior_ratio` filled in and `partial_ratio=None`.
    """
    # can the children be grown? Children outside `affluence_tree` (e.g., at
    # the bottom level) count as not growable
    num_trees, _ = moves.affluence_tree.shape
    trees = jnp.arange(num_trees)

    def is_growable(child):
        return moves.affluence_tree.at[trees, child].get(
            mode='fill', fill_value=False
        )

    left_growable = is_growable(moves.left)
    right_growable = is_growable(moves.right)

    # probability of proposing prune, for grow moves and for prune moves
    can_grow_again = (moves.num_growable >= 2) | left_growable | right_growable
    p_prune_if_grow = jnp.where(can_grow_again, 0.5, 1)
    p_prune_if_prune = jnp.where(moves.num_growable, 0.5, 1)
    p_prune = jnp.where(moves.grow, p_prune_if_grow, p_prune_if_prune)

    # prior probability that both children are terminal
    pt_left = 1 - p_nonterminal[moves.left] * left_growable
    pt_right = 1 - p_nonterminal[moves.right] * right_growable
    pt_children = pt_left * pt_right

    log_ratio = jnp.log(moves.partial_ratio * pt_children * p_prune)
    return replace(moves, log_trans_prior_ratio=log_ratio, partial_ratio=None)
@vmap_nodoc
def adapt_leaf_trees_to_grow_indices(
    leaf_trees: Float32[Array, 'num_trees 2**d'], moves: Moves
) -> Float32[Array, 'num_trees 2**d']:
    """
    Make leaf values readable through the post-grow leaf indices.

    For a grow move, the value of the node being grown is copied into both of
    its would-be children. Leaf indices computed as if the grow had been
    accepted therefore still read the original tree's value. Trees whose move
    is not a grow are returned unchanged.

    Parameters
    ----------
    leaf_trees
        The leaf values.
    moves
        The proposed moves, see `propose_moves`.

    Returns
    -------
    The modified leaf values.
    """
    # An index equal to the array size is out of bounds, and JAX drops
    # out-of-bounds scatter updates, so non-grow moves write nothing.
    skip = leaf_trees.size
    parent_value = leaf_trees[moves.node]
    left_target = jnp.where(moves.grow, moves.left, skip)
    right_target = jnp.where(moves.grow, moves.right, skip)
    with_left = leaf_trees.at[left_target].set(parent_value)
    return with_left.at[right_target].set(parent_value)
def precompute_likelihood_terms(
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
    move_precs: Precs | Counts,
) -> tuple[PreLkV, PreLk]:
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.
    move_precs
        The likelihood precision scale in the leaves grown or pruned by the
        moves, in the attributes `left`, `right`, and `total` (left + right).

    Returns
    -------
    prelkv : PreLkV
        Pre-computed terms of the likelihood ratio that differ across trees.
    prelk : PreLk
        Pre-computed terms of the likelihood ratio shared by all trees.
    """

    def marginal_var(prec):
        # sigma2 + prec * sigma_mu2, the variance term of one node
        return sigma2 + prec * sigma_mu2

    var_left = marginal_var(move_precs.left)
    var_right = marginal_var(move_precs.right)
    var_total = marginal_var(move_precs.total)

    # log of the square-root factor of the likelihood ratio
    log_sqrt_factor = jnp.log(sigma2 * var_total / (var_left * var_right)) / 2

    per_tree = PreLkV(
        sigma2_left=var_left,
        sigma2_right=var_right,
        sigma2_total=var_total,
        sqrt_term=log_sqrt_factor,
    )
    shared = PreLk(exp_factor=sigma_mu2 / (2 * sigma2))
    return per_tree, shared
def precompute_leaf_terms(
    key: Key[Array, ''],
    prec_trees: Float32[Array, 'num_trees 2**d'],
    sigma2: Float32[Array, ''],
    sigma_mu2: Float32[Array, ''],
) -> PreLf:
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    key
        A jax random key.
    prec_trees
        The likelihood precision scale in each potential or actual leaf node.
    sigma2
        The error variance, or the global error variance factor if
        `prec_scale` is set.
    sigma_mu2
        The prior variance of each leaf.

    Returns
    -------
    Pre-computed terms for leaf sampling: the factor that maps the residual
    sums to the posterior mean, and zero-mean normal draws with the posterior
    variance.
    """
    likelihood_prec = prec_trees / sigma2
    prior_prec = lax.reciprocal(sigma_mu2)
    posterior_var = lax.reciprocal(likelihood_prec + prior_prec)
    standard_normal = random.normal(key, prec_trees.shape, sigma2.dtype)

    # Derivation of mean_factor:
    # | mean = mean_lk * prec_lk * var_post
    # | resid_tree = mean_lk * prec_tree -->
    # | --> mean_lk = resid_tree / prec_tree (kind of)
    # | mean_factor =
    # | = mean / resid_tree =
    # | = resid_tree / prec_tree * prec_lk * var_post / resid_tree =
    # | = 1 / prec_tree * prec_tree / sigma2 * var_post =
    # | = var_post / sigma2
    mean_factor = posterior_var / sigma2
    centered_leaves = standard_normal * jnp.sqrt(posterior_var)

    return PreLf(mean_factor=mean_factor, centered_leaves=centered_leaves)
def accept_moves_sequential_stage(pso: ParallelStageOut) -> tuple[State, Moves]:
    """
    Accept/reject the moves one tree at a time.

    This is the most performance-sensitive function because it contains all and
    only the parts of the algorithm that can not be parallelized across trees:
    the residuals are threaded through a `lax.scan` over the trees.

    Parameters
    ----------
    pso
        The output of `accept_moves_parallel_stage`.

    Returns
    -------
    bart : State
        A partially updated BART mcmc state.
    moves : Moves
        The accepted/rejected moves, with `acc` and `to_prune` set.
    """
    bart = pso.bart
    forest = bart.forest

    # inputs common to every tree, closed over by the scan body
    shared = SeqStageInAllTrees(
        bart.X,
        forest.resid_batch_size,
        bart.prec_scale,
        forest.log_likelihood is not None,
        pso.prelk,
    )

    # inputs sliced along the leading (tree) axis by the scan
    per_tree = SeqStageInPerTree(
        forest.leaf_tree,
        pso.prec_trees,
        pso.moves,
        pso.move_precs,
        forest.leaf_indices,
        pso.prelkv,
        pso.prelf,
    )

    def scan_body(resid, tree_inputs):
        resid, leaf_tree, acc, to_prune, lkratio = accept_move_and_sample_leaves(
            resid, shared, tree_inputs
        )
        return resid, (leaf_tree, acc, to_prune, lkratio)

    final_resid, (leaf_trees, acc, to_prune, lkratio) = lax.scan(
        scan_body, bart.resid, per_tree
    )

    new_forest = replace(forest, leaf_tree=leaf_trees, log_likelihood=lkratio)
    new_bart = replace(bart, resid=final_resid, forest=new_forest)
    new_moves = replace(pso.moves, acc=acc, to_prune=to_prune)
    return new_bart, new_moves
class SeqStageInAllTrees(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are shared by all trees.

    Parameters
    ----------
    X
        The predictors.
    resid_batch_size
        The batch size for computing the sum of residuals in each leaf, or
        `None` for no batching.
    prec_scale
        The scale of the precision of the error on each datapoint. If None, it
        is assumed to be 1.
    save_ratios
        Whether to save the acceptance ratios.
    prelk
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    """

    # Field order matters: instances are constructed positionally in
    # `accept_moves_sequential_stage`.
    X: UInt[Array, 'p n']
    # static: `sum_resid` branches on it in Python (`if batch_size is None`)
    resid_batch_size: int | None = field(static=True)
    prec_scale: Float32[Array, ' n'] | None
    # static: `accept_move_and_sample_leaves` branches on it in Python
    save_ratios: bool = field(static=True)
    prelk: PreLk
class SeqStageInPerTree(Module):
    """
    The inputs to `accept_move_and_sample_leaves` that are separate for each tree.

    Parameters
    ----------
    leaf_tree
        The leaf values of the tree.
    prec_tree
        The likelihood precision scale in each potential or actual leaf node.
    move
        The proposed move, see `propose_moves`.
    move_precs
        The likelihood precision scale in each node modified by the moves.
    leaf_indices
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv
    prelf
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.
    """

    # Field order matters: instances are constructed positionally in
    # `accept_moves_sequential_stage`. There every field carries a leading
    # tree axis that `lax.scan` slices off, so within the scan body (and in
    # `accept_move_and_sample_leaves`) each field refers to a single tree, as
    # the shape annotations below describe.
    leaf_tree: Float32[Array, ' 2**d']
    prec_tree: Float32[Array, ' 2**d']
    move: Moves
    move_precs: Precs | Counts
    leaf_indices: UInt[Array, ' n']
    prelkv: PreLkV
    prelf: PreLf
def accept_move_and_sample_leaves(
    resid: Float32[Array, ' n'], at: SeqStageInAllTrees, pt: SeqStageInPerTree
) -> tuple[
    Float32[Array, ' n'],
    Float32[Array, ' 2**d'],
    Bool[Array, ''],
    Bool[Array, ''],
    Float32[Array, ''] | None,
]:
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    resid
        The residuals (data minus forest value).
    at
        The inputs that are the same for all trees.
    pt
        The inputs that are separate for each tree.

    Returns
    -------
    resid : Float32[Array, 'n']
        The updated residuals (data minus forest value).
    leaf_tree : Float32[Array, '2**d']
        The new leaf values of the tree.
    acc : Bool[Array, '']
        Whether the move was accepted.
    to_prune : Bool[Array, '']
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    log_lk_ratio : Float32[Array, ''] | None
        The logarithm of the likelihood ratio for the move. `None` if not to be
        saved.
    """
    # sum residuals in each leaf, in tree proposed by grow move
    if at.prec_scale is None:
        scaled_resid = resid
    else:
        scaled_resid = resid * at.prec_scale
    resid_tree = sum_resid(
        scaled_resid, pt.leaf_indices, pt.leaf_tree.size, at.resid_batch_size
    )

    # subtract starting tree from function
    # (`resid` already has this tree's current values subtracted; adding back
    # precision scale times leaf value makes the per-leaf sums exclude this
    # tree's contribution)
    resid_tree += pt.prec_tree * pt.leaf_tree

    # sum residuals in parent node modified by move
    # (the sums were computed on the tree as if grown, so the parent's sum is
    # the sum of its two children)
    resid_left = resid_tree[pt.move.left]
    resid_right = resid_tree[pt.move.right]
    resid_total = resid_left + resid_right
    assert pt.move.node.dtype == jnp.int32
    resid_tree = resid_tree.at[pt.move.node].set(resid_total)

    # compute acceptance ratio
    log_lk_ratio = compute_likelihood_ratio(
        resid_total, resid_left, resid_right, pt.prelkv, at.prelk
    )
    log_ratio = pt.move.log_trans_prior_ratio + log_lk_ratio
    # the ratio is expressed in the grow direction; a prune uses its inverse
    log_ratio = jnp.where(pt.move.grow, log_ratio, -log_ratio)
    if not at.save_ratios:
        log_lk_ratio = None

    # determine whether to accept the move
    # (`logu` is presumably the log of a uniform draw -- see `propose_moves`)
    acc = pt.move.allowed & (pt.move.logu <= log_ratio)

    # compute leaves posterior and sample leaves
    mean_post = resid_tree * pt.prelf.mean_factor
    leaf_tree = mean_post + pt.prelf.centered_leaves

    # copy leaves around such that the leaf indices point to the correct leaf
    # to_prune is True for a rejected grow or an accepted prune: in both cases
    # the children must disappear, so they take the parent's value (writes to
    # index `leaf_tree.size` are out of bounds and dropped by JAX)
    to_prune = acc ^ pt.move.grow
    leaf_tree = (
        leaf_tree.at[jnp.where(to_prune, pt.move.left, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
        .at[jnp.where(to_prune, pt.move.right, leaf_tree.size)]
        .set(leaf_tree[pt.move.node])
    )

    # replace old tree with new tree in function values
    resid += (pt.leaf_tree - leaf_tree)[pt.leaf_indices]

    return resid, leaf_tree, acc, to_prune, log_lk_ratio
def sum_resid(
    scaled_resid: Float32[Array, ' n'],
    leaf_indices: UInt[Array, ' n'],
    tree_size: int,
    batch_size: int | None,
) -> Float32[Array, ' {tree_size}']:
    """
    Sum the residuals in each leaf.

    Parameters
    ----------
    scaled_resid
        The residuals (data minus forest value) multiplied by the error
        precision scale.
    leaf_indices
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size
        The size of the tree array (2 ** d).
    batch_size
        The data batch size for the aggregation, or `None` for a single
        unbatched scatter-add. Batching increases numerical accuracy and
        parallelism.

    Returns
    -------
    The sum of the residuals at data points in each leaf.
    """
    if batch_size is not None:
        return _aggregate_batched_onetree(
            scaled_resid, leaf_indices, tree_size, jnp.float32, batch_size=batch_size
        )
    return _aggregate_scatter(scaled_resid, leaf_indices, tree_size, jnp.float32)
2309def _aggregate_batched_onetree( 1ab
2310 values: Shaped[Array, '*'],
2311 indices: Integer[Array, '*'],
2312 size: int,
2313 dtype: jnp.dtype,
2314 batch_size: int,
2315) -> Float32[Array, ' {size}']:
2316 (n,) = indices.shape 1ab
2317 nbatches = n // batch_size + bool(n % batch_size) 1ab
2318 batch_indices = jnp.arange(n) % nbatches 1ab
2319 return ( 1ab
2320 jnp.zeros((size, nbatches), dtype)
2321 .at[indices, batch_indices]
2322 .add(values)
2323 .sum(axis=1)
2324 )
def compute_likelihood_ratio(
    total_resid: Float32[Array, ''],
    left_resid: Float32[Array, ''],
    right_resid: Float32[Array, ''],
    prelkv: PreLkV,
    prelk: PreLk,
) -> Float32[Array, '']:
    """
    Compute the log-likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid
    left_resid
    right_resid
        The sum of the residuals (scaled by error precision scale) of the
        datapoints falling in the nodes involved in the moves.
    prelkv
    prelk
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    The logarithm of the likelihood ratio P(data | new tree) / P(data | old tree).
    """
    left_term = left_resid * left_resid / prelkv.sigma2_left
    right_term = right_resid * right_resid / prelkv.sigma2_right
    total_term = total_resid * total_resid / prelkv.sigma2_total
    exponent = prelk.exp_factor * (left_term + right_term - total_term)
    # sqrt_term is already a logarithm, so the two parts add up in log space
    return prelkv.sqrt_term + exponent
def accept_moves_final_stage(bart: State, moves: Moves) -> State:
    """
    Post-process the mcmc state after accepting/rejecting the moves.

    Kept separate from `accept_moves_sequential_stage` because everything
    here can be computed in parallel across trees: acceptance counts, leaf
    indices, and split trees.

    Parameters
    ----------
    bart
        A partially updated BART mcmc state.
    moves
        The proposed moves (see `propose_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The fully updated BART mcmc state.
    """
    accepted_grow = moves.acc & moves.grow
    accepted_prune = moves.acc & ~moves.grow
    new_forest = replace(
        bart.forest,
        grow_acc_count=jnp.sum(accepted_grow),
        prune_acc_count=jnp.sum(accepted_prune),
        leaf_indices=apply_moves_to_leaf_indices(bart.forest.leaf_indices, moves),
        split_tree=apply_moves_to_split_trees(bart.forest.split_tree, moves),
    )
    return replace(bart, forest=new_forest)
@vmap_nodoc
def apply_moves_to_leaf_indices(
    leaf_indices: UInt[Array, 'num_trees n'], moves: Moves
) -> UInt[Array, 'num_trees n']:
    """
    Update the leaf indices to match the accepted move.

    Datapoints that sit in one of the two children of a move flagged
    `to_prune` are sent back to the parent node. All other indices are kept.

    Parameters
    ----------
    leaf_indices
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated leaf indices.
    """
    # all bits set except the lowest (...1111111110); clearing the lowest bit
    # maps both children onto `moves.left`, assuming the heap layout where the
    # right child is the even left child plus one
    clear_low_bit = ~jnp.array(1, leaf_indices.dtype)
    in_children = (leaf_indices & clear_low_bit) == moves.left
    collapse = in_children & moves.to_prune
    parent = moves.node.astype(leaf_indices.dtype)
    return jnp.where(collapse, parent, leaf_indices)
@vmap_nodoc
def apply_moves_to_split_trees(
    split_tree: UInt[Array, 'num_trees 2**(d-1)'], moves: Moves
) -> UInt[Array, 'num_trees 2**(d-1)']:
    """
    Update the split trees to match the accepted move.

    Parameters
    ----------
    split_tree
        The cutpoints of the decision nodes in the initial trees.
    moves
        The proposed moves (see `propose_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    The updated split trees.
    """
    assert moves.to_prune is not None
    # index == size is out of bounds, so JAX drops that scatter update
    dropped = split_tree.size
    grow_target = jnp.where(moves.grow, moves.node, dropped)
    prune_target = jnp.where(moves.to_prune, moves.node, dropped)
    new_cutpoint = moves.grow_split.astype(split_tree.dtype)
    # The grow write comes first: for a rejected grow `to_prune` is set, so
    # the second write resets the same node to 0 and undoes it.
    grown = split_tree.at[grow_target].set(new_cutpoint)
    return grown.at[prune_target].set(0)
def step_sigma(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the error variance (factor).

    Draws ``sigma2 = beta / Gamma(alpha)``, i.e. an inverse-gamma sample. Here
    ``alpha = sigma2_alpha + n / 2`` and ``beta = sigma2_beta + r' W r / 2``,
    where ``W`` is the diagonal precision scale (the identity if
    `prec_scale` is None).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART mcmc state.

    Returns
    -------
    The new BART mcmc state, with an updated `sigma2`.
    """
    resid = bart.resid
    weighted = resid if bart.prec_scale is None else resid * bart.prec_scale
    shape = bart.sigma2_alpha + resid.size / 2
    rate = bart.sigma2_beta + (resid @ weighted) / 2

    # random.gamma seems to be slow at compiling, maybe cdf inversion would
    # be better, but it's not implemented in jax
    gamma_draw = random.gamma(key, shape)
    return replace(bart, sigma2=rate / gamma_draw)
@jax.jit
def _sample_wishart_bartlett(
    key: Key[Array, ''], df: Integer[Array, ''], scale_inv: Float32[Array, 'k k']
) -> Float32[Array, 'k k']:
    """
    Sample a precision matrix W ~ Wishart(df, scale_inv^-1) using Bartlett decomposition.

    Parameters
    ----------
    key
        A JAX random key.
    df
        Degrees of freedom. Every chi-squared draw below has ``df - i``
        degrees of freedom for ``i = 0, ..., k - 1``, so ``df > k - 1`` is
        required; this is not checked.
    scale_inv
        Scale matrix of the corresponding Inverse Wishart distribution, i.e.
        the inverse of the Wishart scale matrix. Before factorization, a small
        jitter proportional to machine epsilon is added to its diagonal.

    Returns
    -------
    A sample from Wishart(df, scale_inv^-1).
    """
    keys = split(key)

    k = scale_inv.shape[0]

    # Gershgorin estimate for max eigenvalue
    # (the maximum absolute row sum bounds the spectral radius)
    rho = jnp.max(jnp.sum(jnp.abs(scale_inv), axis=1))
    u = k * rho * jnp.finfo(scale_inv.dtype).eps + jnp.finfo(scale_inv.dtype).eps

    # Stabilize the matrix
    # (diagonal jitter scaled by the eigenvalue bound guards the Cholesky
    # factorization against a numerically singular input)
    scale_inv = scale_inv.at[jnp.diag_indices(k)].add(u)
    L = jnp.linalg.cholesky(scale_inv)

    # Diagonal elements: A_ii ~ sqrt(chi^2(df - i)), with 0-based i
    # chi^2(k) = Gamma(k/2, scale=2)
    df_vector = df - jnp.arange(k)
    chi2_samples = random.gamma(keys.pop(), df_vector / 2.0) * 2.0
    diag_A = jnp.sqrt(chi2_samples)

    # strictly lower triangle of A: independent standard normals
    off_diag_A = random.normal(keys.pop(), (k, k))
    A = jnp.tril(off_diag_A, -1) + jnp.diag(diag_A)
    # T = L^-T A. Since scale_inv = L L^T, L^-T (L^-T)^T = scale_inv^-1, so
    # T T^T is Bartlett's construction with scale matrix scale_inv^-1.
    T = solve_triangular(L, A, lower=True, trans='T')

    return T @ T.T
def step_z(key: Key[Array, ''], bart: State) -> State:
    """
    MCMC-update the latent variable for binary regression.

    The part of `z` explained by the trees plus offset (``z - resid``) is kept
    fixed. A new residual is drawn with `truncated_normal_onesided`, with the
    truncation side selected by ``~y`` and the bound at minus the trees plus
    offset. Presumably this constrains ``z > 0`` exactly when ``y`` is True
    (TODO confirm against `truncated_normal_onesided`).

    Parameters
    ----------
    key
        A jax random key.
    bart
        A BART MCMC state.

    Returns
    -------
    The updated BART MCMC state, with new `z` and `resid`.
    """
    fixed_part = bart.z - bart.resid
    assert bart.y.dtype == bool
    new_resid = truncated_normal_onesided(key, (), ~bart.y, -fixed_part)
    return replace(bart, z=fixed_part + new_resid, resid=new_resid)
def step_s(key: Key[Array, ''], bart: State) -> State:
    """
    Update `log_s` using Dirichlet sampling.

    The prior is s ~ Dirichlet(theta/p, ..., theta/p), and the posterior
    is s ~ Dirichlet(theta/p + varcount, ..., theta/p + varcount), where
    varcount is the count of how many times each variable is used in the
    current forest.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s`. The stored value is the log
    of independent Gamma draws, not normalized; normalizing ``exp(log_s)``
    yields the Dirichlet sample.

    Notes
    -----
    This full conditional is approximated, because it does not take into account
    that there are forbidden decision rules.
    """
    assert bart.forest.theta is not None
    forest = bart.forest

    # how many times each of the p variables is used in the forest
    num_vars = forest.max_split.size
    usage = grove.var_histogram(num_vars, forest.var_tree, forest.split_tree)

    # Dirichlet posterior concentration, sampled through log-Gamma draws
    concentration = forest.theta / num_vars + usage
    new_log_s = random.loggamma(key, concentration)

    return replace(bart, forest=replace(forest, log_s=new_log_s))
def step_theta(key: Key[Array, ''], bart: State, *, num_grid: int = 1000) -> State:
    """
    Update `theta`.

    The prior is theta / (theta + rho) ~ Beta(a, b). The full conditional of
    lambda = theta / (theta + rho) is evaluated on a grid and sampled as a
    categorical distribution.

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.
    num_grid
        The number of points in the evenly-spaced grid used to sample
        theta / (theta + rho).

    Returns
    -------
    Updated BART state with re-sampled `theta`.
    """
    forest = bart.forest
    assert forest.log_s is not None
    assert forest.rho is not None
    assert forest.a is not None
    assert forest.b is not None

    # midpoints of num_grid equal bins in (0, 1); symmetric around 1/2
    half_bin = 1 / (2 * num_grid)
    lamda_grid = jnp.linspace(half_bin, 1 - half_bin, num_grid)

    # normalize s so that exp(log_s) sums to 1
    log_s = forest.log_s - logsumexp(forest.log_s)

    logp, theta_grid = _log_p_lamda(lamda_grid, log_s, forest.rho, forest.a, forest.b)
    new_theta = theta_grid[random.categorical(key, logp)]

    return replace(bart, forest=replace(forest, theta=new_theta))
2625def _log_p_lamda( 1ab
2626 lamda: Float32[Array, ' num_grid'],
2627 log_s: Float32[Array, ' p'],
2628 rho: Float32[Array, ''],
2629 a: Float32[Array, ''],
2630 b: Float32[Array, ''],
2631) -> tuple[Float32[Array, ' num_grid'], Float32[Array, ' num_grid']]:
2632 # in the following I use lamda[::-1] == 1 - lamda
2633 theta = rho * lamda / lamda[::-1] 1ab
2634 p = log_s.size 1ab
2635 return ( 1ab
2636 (a - 1) * jnp.log1p(-lamda[::-1]) # log(lambda)
2637 + (b - 1) * jnp.log1p(-lamda) # log(1 - lambda)
2638 + gammaln(theta)
2639 - p * gammaln(theta / p)
2640 + theta / p * jnp.sum(log_s)
2641 ), theta
def step_sparse(key: Key[Array, ''], bart: State) -> State:
    """
    Update the sparsity parameters.

    Always re-samples `log_s` with `step_s`. `theta` is re-sampled with
    `step_theta` only when its prior is configured (`rho` is not None).

    Parameters
    ----------
    key
        Random key for sampling.
    bart
        The current BART state.

    Returns
    -------
    Updated BART state with re-sampled `log_s` and `theta`.
    """
    keys = split(key)
    bart = step_s(keys.pop(), bart)
    if bart.forest.rho is None:
        return bart
    return step_theta(keys.pop(), bart)