Coverage for src/bartz/mcmcstep.py: 94%
360 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1# bartz/src/bartz/mcmcstep.py
2#
3# Copyright (c) 2024, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""
26Functions that implement the BART posterior MCMC initialization and update step.
28Functions that do MCMC steps operate by taking as input a bart state, and
29outputting a new dictionary with the new state. The input dict/arrays are not
30modified.
32In general, integer types are chosen to be the minimal types that contain the
33range of possible values.
34"""
36import functools 1a
37import math 1a
39import jax 1a
40from jax import random 1a
41from jax import numpy as jnp 1a
42from jax import lax 1a
44from . import jaxext 1a
45from . import grove 1a
def init(*,
    X,
    y,
    max_split,
    num_trees,
    p_nonterminal,
    sigma2_alpha,
    sigma2_beta,
    small_float=jnp.float32,
    large_float=jnp.float32,
    min_points_per_leaf=None,
    resid_batch_size='auto',
    count_batch_size='auto',
    save_ratios=False,
    ):
    """
    Make a BART posterior sampling MCMC initial state.

    Parameters
    ----------
    X : int array (p, n)
        The predictors. Note this is transposed compared to the usual
        convention. Any array-like is accepted; it is converted with
        `jnp.asarray`.
    y : float array (n,)
        The response.
    max_split : int array (p,)
        The maximum split index for each variable. All split ranges start at
        1. Any array-like is accepted; it is converted with `jnp.asarray`.
    num_trees : int
        The number of trees in the forest.
    p_nonterminal : float array (d - 1,)
        The probability of a nonterminal node at each depth. The maximum depth
        of trees is fixed by the length of this array.
    sigma2_alpha : float
        The shape parameter of the inverse gamma prior on the noise variance.
    sigma2_beta : float
        The scale parameter of the inverse gamma prior on the noise variance.
    small_float : dtype, default float32
        The dtype for large arrays used in the algorithm.
    large_float : dtype, default float32
        The dtype for scalars, small arrays, and arrays which require accuracy.
    min_points_per_leaf : int, optional
        The minimum number of data points in a leaf node. If not specified, no
        minimum is enforced.
    resid_batch_size, count_batch_size : int, None, str, default 'auto'
        The batch sizes, along datapoints, for summing the residuals and
        counting the number of datapoints in each leaf. `None` for no batching.
        If 'auto', pick a value based on the device of `y`, or the default
        device.
    save_ratios : bool, default False
        Whether to save the Metropolis-Hastings ratios.

    Returns
    -------
    bart : dict
        A dictionary with array values, representing a BART mcmc state. The
        keys are:

        'leaf_trees' : small_float array (num_trees, 2 ** d)
            The leaf values.
        'var_trees' : int array (num_trees, 2 ** (d - 1))
            The decision axes.
        'split_trees' : int array (num_trees, 2 ** (d - 1))
            The decision boundaries.
        'resid' : large_float array (n,)
            The residuals (data minus forest value). Large float to avoid
            roundoff.
        'sigma2' : large_float
            The noise variance. Initialized to ``sigma2_beta / sigma2_alpha``,
            or to 1 if that value is not finite and positive.
        'grow_prop_count', 'prune_prop_count' : int
            The number of grow/prune proposals made during one full MCMC cycle.
        'grow_acc_count', 'prune_acc_count' : int
            The number of grow/prune moves accepted during one full MCMC cycle.
        'p_nonterminal' : large_float array (d,)
            The probability of a nonterminal node at each depth, padded with a
            zero.
        'p_propose_grow' : large_float array (2 ** (d - 1),)
            The unnormalized probability of picking a leaf for a grow proposal.
        'sigma2_alpha' : large_float
            The shape parameter of the inverse gamma prior on the noise variance.
        'sigma2_beta' : large_float
            The scale parameter of the inverse gamma prior on the noise variance.
        'max_split' : int array (p,)
            The maximum split index for each variable.
        'y' : small_float array (n,)
            The response.
        'X' : int array (p, n)
            The predictors.
        'leaf_indices' : int array (num_trees, n)
            The index of the leaf each datapoint falls into, for each tree.
            Initialized to 1 (the root) for every datapoint.
        'min_points_per_leaf' : int or None
            The minimum number of data points in a leaf node.
        'affluence_trees' : bool array (num_trees, 2 ** (d - 1)) or None
            Whether each non-bottom leaf node contains at least twice
            `min_points_per_leaf` datapoints. If `min_points_per_leaf` is not
            specified, this is None.
        'opt' : LeafDict
            A dictionary with config values:

            'small_float' : dtype
                The dtype for large arrays used in the algorithm.
            'large_float' : dtype
                The dtype for scalars, small arrays, and arrays which require
                accuracy.
            'require_min_points' : bool
                Whether the `min_points_per_leaf` parameter is specified.
            'resid_batch_size', 'count_batch_size' : int or None
                The data batch sizes for computing the sufficient statistics.
        'ratios' : dict, optional
            If `save_ratios` is True, this field is present. It has the fields:

            'log_trans_prior' : large_float array (num_trees,)
                The log transition and prior Metropolis-Hastings ratio for the
                proposed move on each tree.
            'log_likelihood' : large_float array (num_trees,)
                The log likelihood ratio.
    """

    # convert up front: `X.shape` and `max_split.dtype` are read below, which
    # would fail on plain Python sequences
    X = jnp.asarray(X)
    max_split = jnp.asarray(max_split)

    p_nonterminal = jnp.asarray(p_nonterminal, large_float)
    p_nonterminal = jnp.pad(p_nonterminal, (0, 1))
    max_depth = p_nonterminal.size

    # stack `num_trees` identical empty trees along a new leading axis
    @functools.partial(jax.vmap, in_axes=None, out_axes=0, axis_size=num_trees)
    def make_forest(max_depth, dtype):
        return grove.make_tree(max_depth, dtype)

    small_float = jnp.dtype(small_float)
    large_float = jnp.dtype(large_float)
    y = jnp.asarray(y, small_float)
    resid_batch_size, count_batch_size = _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, 2 ** max_depth * num_trees)

    # initial noise variance from the prior parameters, with a fallback of 1
    # when the ratio is not usable
    sigma2 = jnp.array(sigma2_beta / sigma2_alpha, large_float)
    sigma2 = jnp.where(jnp.isfinite(sigma2) & (sigma2 > 0), sigma2, 1)

    bart = dict(
        leaf_trees=make_forest(max_depth, small_float),
        var_trees=make_forest(max_depth - 1, jaxext.minimal_unsigned_dtype(X.shape[0] - 1)),
        split_trees=make_forest(max_depth - 1, max_split.dtype),
        resid=jnp.asarray(y, large_float),
        sigma2=sigma2,
        grow_prop_count=jnp.zeros((), int),
        grow_acc_count=jnp.zeros((), int),
        prune_prop_count=jnp.zeros((), int),
        prune_acc_count=jnp.zeros((), int),
        p_nonterminal=p_nonterminal,
        p_propose_grow=p_nonterminal[grove.tree_depths(2 ** (max_depth - 1))],
        sigma2_alpha=jnp.asarray(sigma2_alpha, large_float),
        sigma2_beta=jnp.asarray(sigma2_beta, large_float),
        max_split=max_split,
        y=y,
        X=X,
        # every datapoint starts in the root, node index 1
        leaf_indices=jnp.ones((num_trees, y.size), jaxext.minimal_unsigned_dtype(2 ** max_depth - 1)),
        min_points_per_leaf=(
            None if min_points_per_leaf is None else
            jnp.asarray(min_points_per_leaf)
        ),
        affluence_trees=(
            None if min_points_per_leaf is None else
            make_forest(max_depth - 1, bool).at[:, 1].set(y.size >= 2 * min_points_per_leaf)
        ),
        opt=jaxext.LeafDict(
            small_float=small_float,
            large_float=large_float,
            require_min_points=min_points_per_leaf is not None,
            resid_batch_size=resid_batch_size,
            count_batch_size=count_batch_size,
        ),
    )

    if save_ratios:
        bart['ratios'] = dict(
            log_trans_prior=jnp.full(num_trees, jnp.nan),
            log_likelihood=jnp.full(num_trees, jnp.nan),
        )

    return bart
def _choose_suffstat_batch_size(resid_batch_size, count_batch_size, y, forest_size):
    """
    Resolve 'auto' batch sizes for the sufficient statistics computations.

    Parameters
    ----------
    resid_batch_size, count_batch_size : int, None, str
        The requested batch sizes. Only the string 'auto' triggers a heuristic
        choice; any other value is returned as is.
    y : array (n,)
        The response. Its size drives the heuristics, and its device (if it
        is concrete) selects the platform-specific rule.
    forest_size : int
        The total number of nodes in the forest, used on GPU to derive a lower
        bound on the count batch size from a fixed memory budget.

    Returns
    -------
    resid_batch_size, count_batch_size : int or None
        The resolved batch sizes.

    Raises
    ------
    KeyError
        If an 'auto' value must be resolved on a platform other than 'cpu' or
        'gpu'.
    """

    @functools.cache
    def platform_of_y():
        # a traced `y` has no concrete device: use the default device instead
        try:
            dev = y.devices().pop()
        except jax.errors.ConcretizationTypeError:
            dev = jax.devices()[0]
        name = dev.platform
        if name not in ('cpu', 'gpu'):
            raise KeyError(f'Unknown platform: {name}')
        return name

    if resid_batch_size == 'auto':
        platform = platform_of_y()
        num_points = max(1, y.size)
        if platform == 'cpu':
            log2_batch = math.log2(num_points / 6)  # batch ~ n/6
        else:  # 'gpu', guaranteed by platform_of_y
            log2_batch = (1 + math.log2(num_points)) / 3  # batch ~ n^1/3
        resid_batch_size = max(1, 2 ** int(round(log2_batch)))

    if count_batch_size == 'auto':
        platform = platform_of_y()
        if platform == 'cpu':
            count_batch_size = None
        else:  # 'gpu', guaranteed by platform_of_y
            num_points = max(1, y.size)
            # batch ~ n^1/2 / 4: /4 is good on V100, /2 on L4/T4, A100 untried
            heuristic = 2 ** int(round(math.log2(num_points) / 2 - 2))
            # lower bound from a 2**29-byte budget with 4-byte items
            max_memory = 2 ** 29
            itemsize = 4
            floor_size = int(math.ceil(forest_size * itemsize * num_points / max_memory))
            count_batch_size = max(1, max(heuristic, floor_size))

    return resid_batch_size, count_batch_size
def step(bart, key):
    """
    Run one complete BART MCMC iteration.

    The trees are updated first, then the noise variance is resampled given
    the new trees.

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.
    """
    sigma_key, trees_key = random.split(key)
    updated = sample_trees(bart, trees_key)
    return sample_sigma(updated, sigma_key)
def sample_trees(bart, key):
    """
    Update all the trees of the forest (BART MCMC forest step).

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.

    Notes
    -----
    The proposal counters are reset: in the returned state they count only
    the proposals made during this step.
    """
    accept_key, moves_key = random.split(key)
    proposed = sample_moves(bart, moves_key)
    return accept_moves_and_sample_leaves(bart, proposed, accept_key)
def sample_moves(bart, key):
    """
    Propose one grow or prune move for each tree of the forest.

    Parameters
    ----------
    bart : dict
        BART mcmc state.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    moves : dict
        A dictionary with fields:

        'allowed' : bool array (num_trees,)
            Whether the move is possible.
        'grow' : bool array (num_trees,)
            Whether the move is a grow move or a prune move.
        'num_growable' : int array (num_trees,)
            The number of growable leaves in the original tree.
        'node' : int array (num_trees,)
            The index of the leaf to grow or node to prune.
        'left', 'right' : int array (num_trees,)
            The indices of the children of 'node'.
        'partial_ratio' : float array (num_trees,)
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move. If the move is Prune, the ratio is inverted.
        'grow_var' : int array (num_trees,)
            The decision axes of the new rules.
        'grow_split' : int array (num_trees,)
            The decision boundaries of the new rules.
        'var_trees' : int array (num_trees, 2 ** (d - 1))
            The updated decision axes of the trees, valid whatever move.
        'logu' : float array (num_trees,)
            The logarithm of a uniform (0, 1] random variable to be used to
            accept the move. It's in (-oo, 0].
    """
    num_trees = bart['leaf_trees'].shape[0]
    keys = random.split(key, 1 + num_trees)
    uniform_key, tree_keys = keys[0], keys[1:]

    # candidate grow and prune moves for every tree
    grow_moves, prune_moves = _sample_moves_vmap_trees(
        bart['var_trees'],
        bart['split_trees'],
        bart['affluence_trees'],
        bart['max_split'],
        bart['p_nonterminal'],
        bart['p_propose_grow'],
        tree_keys,
    )

    # first row picks grow vs prune, second row is for the acceptance test
    u_choice, u_accept = random.uniform(uniform_key, (2, num_trees), bart['opt']['large_float'])

    can_grow = grow_moves['num_growable'].astype(bool)
    can_prune = prune_moves['allowed']
    # 1/2 when both moves are possible, otherwise all-or-nothing on grow
    p_grow = jnp.where(can_grow & can_prune, 0.5, can_grow)
    grow = u_choice < p_grow  # strict <, since u_choice lies in [0, 1)

    node = jnp.where(grow, grow_moves['node'], prune_moves['node'])
    left_child = node << 1

    return dict(
        allowed=grow | can_prune,
        grow=grow,
        num_growable=grow_moves['num_growable'],
        node=node,
        left=left_child,
        right=left_child + 1,
        partial_ratio=jnp.where(grow, grow_moves['partial_ratio'], prune_moves['partial_ratio']),
        grow_var=grow_moves['var'],
        grow_split=grow_moves['split'],
        var_trees=grow_moves['var_tree'],
        # 1 - u lies in (0, 1], so its log lies in (-oo, 0]
        logu=jnp.log1p(-u_accept),
    )
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, 0, None, None, None, 0))
def _sample_moves_vmap_trees(*args):
    # Vectorized over trees: the per-tree arrays (var, split, affluence) and
    # the keys are mapped, while max_split, p_nonterminal and p_propose_grow
    # are shared. Returns the (grow, prune) candidate moves of each tree.
    *tree_args, key = args
    grow_key, prune_key = random.split(key)
    return grow_move(*tree_args, grow_key), prune_move(*tree_args, prune_key)
def grow_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
    """
    Propose a grow move on a single tree (BART MCMC structure proposal).

    A grow move picks a leaf and turns it into a non-terminal node with two
    leaf children. It is impossible when no leaf can be grown, e.g., when all
    leaves are already at maximum depth.

    Parameters
    ----------
    var_tree : array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    max_split : array (p,)
        The maximum split index for each variable.
    p_nonterminal : array (d,)
        The probability of a nonterminal node at each depth.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    grow_move : dict
        A dictionary with fields:

        'num_growable' : int
            The number of growable leaves.
        'node' : int
            The index of the leaf to grow. ``2 ** d`` if there are no growable
            leaves.
        'var', 'split' : int
            The decision axis and boundary of the new rule.
        'partial_ratio' : float
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move.
        'var_tree' : array (2 ** (d - 1),)
            The updated decision axes of the tree.
    """
    leaf_key, var_key, split_key = random.split(key, 3)

    leaf_to_grow, num_growable, prob_choose, num_prunable = choose_leaf(
        split_tree, affluence_tree, p_propose_grow, leaf_key)

    var = choose_variable(var_tree, split_tree, max_split, leaf_to_grow, var_key)
    # choose_split reads the new variable from the updated tree
    new_var_tree = var_tree.at[leaf_to_grow].set(var.astype(var_tree.dtype))
    split = choose_split(new_var_tree, split_tree, max_split, leaf_to_grow, split_key)

    partial_ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow)

    return dict(
        num_growable=num_growable,
        node=leaf_to_grow,
        var=var,
        split=split,
        partial_ratio=partial_ratio,
        var_tree=new_var_tree,
    )
def choose_leaf(split_tree, affluence_tree, p_propose_grow, key):
    """
    Pick the leaf to grow in a tree.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    leaf_to_grow : int
        The index of the chosen leaf; ``2 ** d`` if ``num_growable == 0``.
    num_growable : int
        How many leaves could be grown.
    prob_choose : float
        The normalized probability with which the chosen leaf was picked.
    num_prunable : int
        How many leaf parents the tree would have after turning the chosen
        leaf into a non-terminal node.
    """
    is_growable = growable_leaves(split_tree, affluence_tree)
    num_growable = jnp.count_nonzero(is_growable)
    weights = jnp.where(is_growable, p_propose_grow, 0)
    leaf, total_weight = categorical(key, weights)
    # sentinel 2 ** d when nothing can be grown
    leaf = jnp.where(num_growable, leaf, 2 * split_tree.size)
    prob_choose = weights[leaf] / total_weight
    # leaf parents of the tree as it would be after the grow
    grown_split_tree = split_tree.at[leaf].set(1)
    num_prunable = jnp.count_nonzero(grove.is_leaves_parent(grown_split_tree))
    return leaf, num_growable, prob_choose, num_prunable
def growable_leaves(split_tree, affluence_tree):
    """
    Mask the leaves that may be proposed for a grow move.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown. If None, no restriction
        based on the number of points is applied.

    Returns
    -------
    is_growable : bool array (2 ** (d - 1),)
        True for the leaves that are not at the bottom level and, when
        `affluence_tree` is given, have at least twice the minimum number of
        points per leaf.
    """
    leaves = grove.is_actual_leaf(split_tree)
    if affluence_tree is None:
        return leaves
    return leaves & affluence_tree
def categorical(key, distr):
    """
    Return a random integer from an arbitrary distribution.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    distr : float array (n,)
        An unnormalized probability distribution.

    Returns
    -------
    u : int
        A random integer in the range ``[0, n)``, drawn with probability
        proportional to `distr`. Entries with zero probability are skipped
        because the search uses ``side='right'``. If all probabilities are
        zero, return ``n``.
    norm : float
        The normalization constant, i.e., the sum of `distr` (the last
        element of its cumulative sum).
    """
    ecdf = jnp.cumsum(distr)
    # u is in [0, ecdf[-1]); searchsorted finds the first ecdf entry > u
    u = random.uniform(key, (), ecdf.dtype, 0, ecdf[-1])
    return jnp.searchsorted(ecdf, u, 'right'), ecdf[-1]
def choose_variable(var_tree, split_tree, max_split, leaf_index, key):
    """
    Draw the decision variable for a node being grown.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    var : int
        The index of the variable to split on.

    Notes
    -----
    Only variables with a non-empty range of allowed splits at the node can
    be drawn. If every variable is exhausted, `p` is returned.
    """
    num_vars = max_split.size
    exhausted = fully_used_variables(var_tree, split_tree, max_split, leaf_index)
    return randint_exclude(key, num_vars, exhausted)
def fully_used_variables(var_tree, split_tree, max_split, leaf_index):
    """
    Find the variables with no split left available at a given node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    var_to_ignore : int array (d - 2,)
        The indices of the variables whose split range is empty. Since their
        number varies, unused slots hold `p`. Slots are in no particular
        order and a variable may appear more than once.
    """
    # only variables used by an ancestor can have a restricted range
    candidates = ancestor_variables(var_tree, max_split, leaf_index)
    ranges_of = jax.vmap(split_range, in_axes=(None, None, None, None, 0))
    low, high = ranges_of(var_tree, split_tree, max_split, leaf_index, candidates)
    is_exhausted = high - low == 0
    return jnp.where(is_exhausted, candidates, max_split.size)
def ancestor_variables(var_tree, max_split, node_index):
    """
    Collect the decision variables along the path from the root to a node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    max_split : int array (p,)
        The maximum split index for each variable. Only its size `p` is used.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.

    Returns
    -------
    ancestor_vars : int array (d - 2,)
        The variables of the node's ancestors, filled from the last slot
        (the parent) backwards towards the root. Slots beyond the root hold
        `p`.
    """
    num_slots = grove.tree_depth(var_tree) - 1
    num_vars = max_split.size
    out = jnp.zeros(num_slots, jaxext.minimal_unsigned_dtype(num_vars))

    def climb(state, _):
        slot, node, out = state
        parent = node >> 1
        # index 0 means we went past the root: no ancestor there
        var = jnp.where(parent, var_tree[parent], num_vars)
        out = out.at[slot].set(var)
        return (slot - 1, parent, out), None

    init_state = (out.size - 1, node_index, out)
    (_, _, out), _ = lax.scan(climb, init_state, None, out.size)
    return out
def split_range(var_tree, split_tree, max_split, node_index, ref_var):
    """
    Compute the allowed split values of one variable at a given node.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    node_index : int
        The index of the node, assumed to be valid for `var_tree`.
    ref_var : int
        The variable for which to measure the split range.

    Returns
    -------
    l, r : int
        The range of allowed splits is [l, r).
    """
    num_ancestors = grove.tree_depth(var_tree) - 1
    # an out-of-range ref_var (e.g. the fill value p) reads 0, which yields
    # an empty range
    upper = 1 + max_split.at[ref_var].get(mode='fill', fill_value=0).astype(jnp.int32)

    def climb(state, _):
        low, high, node = state
        is_right = (node & 1).astype(bool)
        parent = node >> 1
        cut = split_tree[parent]
        same_var = (var_tree[parent] == ref_var) & parent.astype(bool)
        # right subtree: x >= cut, so later splits must exceed cut
        low = jnp.where(same_var & is_right, jnp.maximum(low, cut), low)
        # left subtree: x < cut, so later splits must be below cut
        high = jnp.where(same_var & ~is_right, jnp.minimum(high, cut), high)
        return (low, high, parent), None

    (low, high, _), _ = lax.scan(climb, (0, upper, node_index), None, num_ancestors)
    return low + 1, high
def randint_exclude(key, sup, exclude):
    """
    Draw a uniform random integer below `sup` that avoids given values.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    sup : int
        The exclusive upper bound of the range.
    exclude : int array (n,)
        The values to avoid. Values greater than or equal to `sup` are
        ignored; repeats are allowed.

    Returns
    -------
    u : int
        A random integer in ``[0, sup)`` not contained in `exclude`. If the
        whole range is excluded, `sup` is returned.
    """
    # sorted distinct values, padded with `sup` (which never shifts the result)
    banned_sorted = jnp.unique(exclude, size=exclude.size, fill_value=sup)
    num_allowed = sup - jnp.count_nonzero(banned_sorted < sup)
    # rank of the result among the allowed values
    rank = random.randint(key, (), 0, num_allowed)

    def skip(value, banned):
        # visiting banned values in increasing order and stepping over each
        # one at or below the running value maps the rank to the value
        return jnp.where(banned <= value, value + 1, value), None

    value, _ = lax.scan(skip, rank, banned_sorted)
    return value
def choose_split(var_tree, split_tree, max_split, leaf_index, key):
    """
    Draw the split point for a node being grown.

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree. It must already hold the new
        node's variable at `leaf_index`.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    max_split : int array (p,)
        The maximum split index for each variable.
    leaf_index : int
        The index of the leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    split : int
        The split point, uniform over the allowed range of the variable.
    """
    new_var = var_tree[leaf_index]
    low, high = split_range(var_tree, split_tree, max_split, leaf_index, new_var)
    return random.randint(key, (), low, high)
def compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, leaf_to_grow):
    """
    Compute the product of the transition and prior ratios of a grow move.

    Parameters
    ----------
    prob_choose : float
        The normalized probability of choosing `leaf_to_grow` among the
        growable leaves of the tree before the grow move.
    num_prunable : int
        The number of leaf parents that could be pruned, after converting the
        leaf to be grown to a non-terminal node.
    p_nonterminal : array (d,)
        The probability of a nonterminal node at each depth.
    leaf_to_grow : int
        The index of the leaf to grow.

    Returns
    -------
    ratio : float
        The transition ratio P(new tree -> old tree) / P(old tree -> new tree)
        times the prior ratio P(new tree) / P(old tree), but the transition
        ratio is missing the factor P(propose prune) in the numerator.
    """

    # the two ratios also contain factors num_available_split *
    # num_available_var, but they cancel out

    # p_prune can't be computed here because it needs the count trees, which are
    # computed in the acceptance phase

    prune_allowed = leaf_to_grow != 1
        # prune allowed <---> the initial tree is not a root
        # leaf to grow is root --> the tree can only be a root
        # tree is a root --> the only leaf I can grow is root

    p_grow = jnp.where(prune_allowed, 0.5, 1)

    # P(old -> new) / (1 / num_prunable): the 1 / num_prunable is the
    # prune-side choice probability, moved to the denominator
    inv_trans_ratio = p_grow * prob_choose * num_prunable

    depth = grove.tree_depths(2 ** (p_nonterminal.size - 1))[leaf_to_grow]
    p_parent = p_nonterminal[depth]
    cp_children = 1 - p_nonterminal[depth + 1]
    tree_ratio = cp_children * cp_children * p_parent / (1 - p_parent)

    return tree_ratio / inv_trans_ratio
def prune_move(var_tree, split_tree, affluence_tree, max_split, p_nonterminal, p_propose_grow, key):
    """
    Propose a prune move on a single tree (BART MCMC structure proposal).

    Parameters
    ----------
    var_tree : int array (2 ** (d - 1),)
        The variable indices of the tree.
    split_tree : int array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    max_split : int array (p,)
        The maximum split index for each variable.
    p_nonterminal : float array (d,)
        The probability of a nonterminal node at each depth.
    p_propose_grow : float array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    prune_move : dict
        A dictionary with fields:

        'allowed' : bool
            Whether the move is possible.
        'node' : int
            The index of the node to prune. ``2 ** d`` if no node can be pruned.
        'partial_ratio' : float
            A factor of the Metropolis-Hastings ratio of the move. It lacks
            the likelihood ratio and the probability of proposing the prune
            move. This ratio is inverted.
    """
    node, num_prunable, prob_choose = choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key)
    # the move is possible iff the root has a split, i.e., the tree is not
    # just the root
    is_allowed = split_tree[1].astype(bool)
    # computed as the ratio of the reverse grow; the accept stage inverts it
    ratio = compute_partial_ratio(prob_choose, num_prunable, p_nonterminal, node)
    return dict(
        allowed=is_allowed,
        node=node,
        partial_ratio=ratio,
    )
def choose_leaf_parent(split_tree, affluence_tree, p_propose_grow, key):
    """
    Pick uniformly a non-terminal node whose children are both leaves.

    Parameters
    ----------
    split_tree : array (2 ** (d - 1),)
        The splitting points of the tree.
    affluence_tree : bool array (2 ** (d - 1),) or None
        Whether a leaf has enough points to be grown.
    p_propose_grow : array (2 ** (d - 1),)
        The unnormalized probability of choosing a leaf to grow.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    node_to_prune : int
        The index of the chosen node; ``2 ** d`` if ``num_prunable == 0``.
    num_prunable : int
        How many leaf parents the tree has.
    prob_choose : float
        The normalized probability that a grow proposal on the pruned tree
        would pick the chosen node.
    """
    is_prunable = grove.is_leaves_parent(split_tree)
    num_prunable = jnp.count_nonzero(is_prunable)
    node = randint_masked(key, is_prunable)
    # sentinel 2 ** d when nothing can be pruned
    node = jnp.where(num_prunable, node, 2 * split_tree.size)

    # the tree as it would look after the prune: the node becomes a growable leaf
    pruned_split_tree = split_tree.at[node].set(0)
    if affluence_tree is None:
        pruned_affluence_tree = None
    else:
        pruned_affluence_tree = affluence_tree.at[node].set(True)
    growable = growable_leaves(pruned_split_tree, pruned_affluence_tree)

    prob_choose = p_propose_grow[node] / jnp.sum(p_propose_grow, where=growable)
    return node, num_prunable, prob_choose
def randint_masked(key, mask):
    """
    Draw a uniform random index among the positions where a mask is True.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        A jax random key.
    mask : bool array (n,)
        The mask indicating the allowed values.

    Returns
    -------
    u : int
        A random integer in ``[0, n)`` with ``mask[u] == True``. If the mask
        is all `False`, `n` is returned.
    """
    running_count = jnp.cumsum(mask)
    num_true = running_count[-1]
    # rank among the allowed positions, mapped to the position holding it
    rank = random.randint(key, (), 0, num_true)
    return jnp.searchsorted(running_count, rank, 'right')
def accept_moves_and_sample_leaves(bart, moves, key):
    """
    Decide on the proposed moves and draw new leaf values.

    The work is split into a parallel stage (quantities computable
    independently per tree), a sequential stage, and a final stage.

    Parameters
    ----------
    bart : dict
        A BART mcmc state.
    moves : dict
        The proposed moves, see `sample_moves`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        The new BART mcmc state.
    """
    parallel_out = accept_moves_parallel_stage(bart, moves, key)
    new_bart, new_moves, count_trees, move_counts, prelkv, prelk, prelf = parallel_out
    new_bart, new_moves = accept_moves_sequential_stage(
        new_bart, count_trees, new_moves, move_counts, prelkv, prelk, prelf)
    return accept_moves_final_stage(new_bart, new_moves)
def accept_moves_parallel_stage(bart, moves, key):
    """
    Precompute, in parallel across trees, what is needed to accept moves.

    Parameters
    ----------
    bart : dict
        A BART mcmc state. The dict itself is not modified; a shallow copy is
        updated and returned.
    moves : dict
        The proposed moves, see `sample_moves`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        A partially updated BART mcmc state.
    moves : dict
        The proposed moves, with the field 'partial_ratio' replaced
        by 'log_trans_prior_ratio'.
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    move_counts : dict
        The number of points in the nodes modified by the moves.
    prelkv, prelk, prelf : dict
        Dictionaries with pre-computed terms of the likelihood ratios and
        leaf samples.
    """
    new_bart = bart.copy()

    # apply the grow moves provisionally, as if they were accepted
    new_bart['var_trees'] = moves['var_trees']
    new_bart['leaf_indices'] = apply_grow_to_indices(moves, new_bart['leaf_indices'], new_bart['X'])
    new_bart['leaf_trees'] = adapt_leaf_trees_to_grow_indices(new_bart['leaf_trees'], moves)

    # datapoints per node
    count_trees, move_counts = compute_count_trees(
        new_bart['leaf_indices'], moves, new_bart['opt']['count_batch_size'])
    if new_bart['opt']['require_min_points']:
        num_non_bottom = new_bart['var_trees'].shape[1]
        new_bart['affluence_trees'] = count_trees[:, :num_non_bottom] >= 2 * new_bart['min_points_per_leaf']

    # finish the move ratios and count the proposals of this step
    moves = complete_ratio(moves, move_counts, new_bart['min_points_per_leaf'])
    new_bart['grow_prop_count'] = jnp.sum(moves['grow'])
    new_bart['prune_prop_count'] = jnp.sum(moves['allowed'] & ~moves['grow'])

    noise_var = new_bart['sigma2']
    prelkv, prelk = precompute_likelihood_terms(count_trees, noise_var, move_counts)
    prelf = precompute_leaf_terms(count_trees, noise_var, key)

    return new_bart, moves, count_trees, move_counts, prelkv, prelk, prelf
@functools.partial(jaxext.vmap_nodoc, in_axes=(0, 0, None))
def apply_grow_to_indices(moves, leaf_indices, X):
    """
    Move the datapoints of each grown leaf into the leaf's children.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    leaf_indices : array (num_trees, n)
        The index of the leaf each datapoint falls into.
    X : array (p, n)
        The predictors matrix.

    Returns
    -------
    grow_leaf_indices : array (num_trees, n)
        The leaf indices after applying the grow moves. Trees whose move is
        not a grow are returned unchanged.
    """
    dtype = leaf_indices.dtype

    # when the move is not a grow, target an index past the end of the leaf
    # array, which no datapoint can have
    past_end = jnp.array(2 * moves['var_trees'].size)
    target = jnp.where(moves['grow'], moves['node'], past_end)

    # children of `node` are at 2 * node (left) and 2 * node + 1 (right)
    first_child = moves['node'].astype(dtype) << 1
    goes_right = X[moves['grow_var'], :] >= moves['grow_split']
    child_index = first_child + goes_right

    in_grown_leaf = leaf_indices == target
    return jnp.where(in_grown_leaf, child_index, leaf_indices)
def compute_count_trees(leaf_indices, moves, batch_size):
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move is
        accepted.
    moves : dict
        The proposed moves, see `sample_moves`.
    batch_size : int or None
        The data batch size to use for the summation.

    Returns
    -------
    count_trees : int array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node. At the
        index of the node modified by each move, the entry is overwritten with
        the total count of that node's two children.
    counts : dict
        The counts of the number of points in the nodes modified by the moves,
        with fields 'left', 'right', and 'total', one entry per tree.
    """
    ntree, half_tree_size = moves['var_trees'].shape
    # the leaf arrays are twice as long as the decision-node arrays
    tree_size = 2 * half_tree_size
    tree_indices = jnp.arange(ntree)

    count_trees = count_datapoints_per_leaf(leaf_indices, tree_size, batch_size)

    # count datapoints in the nodes modified by the moves
    counts = dict()
    counts['left'] = count_trees[tree_indices, moves['left']]
    counts['right'] = count_trees[tree_indices, moves['right']]
    counts['total'] = counts['left'] + counts['right']

    # write the total count of the two children into their parent node
    count_trees = count_trees.at[tree_indices, moves['node']].set(counts['total'])

    return count_trees, counts
def count_datapoints_per_leaf(leaf_indices, tree_size, batch_size):
    """
    Count the number of datapoints in each leaf.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into.
    tree_size : int
        The size of the leaf tree array (2 ** d).
    batch_size : int or None
        The data batch size to use for the summation. If None, the trees are
        processed one at a time with a sequential scan and no data batching.

    Returns
    -------
    count_trees : uint32 array (num_trees, 2 ** d)
        The number of points in each leaf node.
    """
    if batch_size is None:
        return _count_scan(leaf_indices, tree_size)
    return _count_vec(leaf_indices, tree_size, batch_size)
def _count_scan(leaf_indices, tree_size):
    """Count datapoints per leaf, one tree at a time via `lax.scan`."""
    def count_one_tree(carry, indices):
        # each datapoint contributes 1 to the leaf it falls into
        counts = _aggregate_scatter(1, indices, tree_size, jnp.uint32)
        return carry, counts

    _, count_trees = lax.scan(count_one_tree, None, leaf_indices)
    return count_trees
def _aggregate_scatter(values, indices, size, dtype):
    """
    Sum `values` into a zero-initialized array of length `size`, bucketed by
    `indices`; repeated indices accumulate.
    """
    buckets = jnp.zeros(size, dtype)
    return buckets.at[indices].add(values)
def _count_vec(leaf_indices, tree_size, batch_size):
    """Count datapoints per leaf, vectorized over trees and batched over data."""
    # uint16 is super-slow on gpu, don't use it even if n < 2^16
    return _aggregate_batched_alltrees(
        1, leaf_indices, tree_size, jnp.uint32, batch_size,
    )
def _aggregate_batched_alltrees(values, indices, size, dtype, batch_size):
    """
    Sum `values` per index for all trees at once, with the data split in batches.

    Datapoint ``i`` goes to batch ``i % nbatches``, where ``nbatches`` is the
    ceiling of ``n / batch_size``. Each batch is accumulated separately and
    the partial sums are added at the end.

    Returns an array of shape ``(ntree, size)``.
    """
    ntree, n = indices.shape
    nbatches = -(-n // batch_size)  # ceiling division
    tree_axis = jnp.arange(ntree)[:, None]
    batch_axis = jnp.arange(n) % nbatches
    partial_sums = jnp.zeros((ntree, size, nbatches), dtype)
    partial_sums = partial_sums.at[tree_axis, indices, batch_axis].add(values)
    return partial_sums.sum(axis=2)
def complete_ratio(moves, move_counts, min_points_per_leaf):
    """
    Finish the non-likelihood part of the Metropolis-Hastings ratio.

    The partial ratio is multiplied by the probability of proposing a prune
    move, and the result is stored in log scale.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    moves : dict
        A new dict equal to `moves`, except that the field 'partial_ratio' is
        replaced by 'log_trans_prior_ratio'.
    """
    p_prune = compute_p_prune(
        moves, move_counts['left'], move_counts['right'], min_points_per_leaf,
    )
    completed = dict(moves)
    partial_ratio = completed.pop('partial_ratio')
    completed['log_trans_prior_ratio'] = jnp.log(partial_ratio * p_prune)
    return completed
def compute_p_prune(moves, left_count, right_count, min_points_per_leaf):
    """
    Compute the probability of proposing a prune move.

    For a grow move, this is the probability in the tree obtained after
    accepting the grow. For a prune move, it is the probability in the current
    tree. It is 0.5 when a grow could also be proposed, and 1 otherwise.

    Parameters
    ----------
    moves : dict
        The proposed moves, see `sample_moves`.
    left_count, right_count : int
        The number of datapoints in the proposed children of the leaf to grow.
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.

    Returns
    -------
    p_prune : float
        The probability of proposing a prune move.
    """
    # --- grow move: would another grow be possible after accepting it? ---
    # some leaf other than the one being grown is growable
    others_growable = moves['num_growable'] >= 2
    # the new children have room to become decision nodes
    half_size = moves['var_trees'].shape[1]
    children_growable = moves['node'] < half_size // 2
    if min_points_per_leaf is not None:
        threshold = 2 * min_points_per_leaf
        enough_points = (left_count >= threshold) | (right_count >= threshold)
        children_growable = children_growable & enough_points
    can_grow_after = others_growable | children_growable
    p_prune_if_grow = jnp.where(can_grow_after, 0.5, 1)

    # --- prune move: does the current tree have a growable leaf? ---
    p_prune_if_prune = jnp.where(moves['num_growable'], 0.5, 1)

    return jnp.where(moves['grow'], p_prune_if_grow, p_prune_if_prune)
@jaxext.vmap_nodoc
def adapt_leaf_trees_to_grow_indices(leaf_trees, moves):
    """
    Copy the value of each leaf to grow into its would-be children.

    After this, indexing the leaf values with the leaf indices of the grown
    tree gives the same values as the original tree.

    Parameters
    ----------
    leaf_trees : float array (num_trees, 2 ** d)
        The leaf values.
    moves : dict
        The proposed moves, see `sample_moves`.

    Returns
    -------
    leaf_trees : float array (num_trees, 2 ** d)
        The modified leaf values. Trees whose move is not a grow are
        unchanged.
    """
    parent_value = leaf_trees[moves['node']]
    # out-of-range target: the write is dropped when the move is not a grow
    skip = leaf_trees.size
    left_slot = jnp.where(moves['grow'], moves['left'], skip)
    right_slot = jnp.where(moves['grow'], moves['right'], skip)
    with_left = leaf_trees.at[left_slot].set(parent_value)
    return with_left.at[right_slot].set(parent_value)
def precompute_likelihood_terms(count_trees, sigma2, move_counts):
    """
    Pre-compute terms used in the likelihood ratio of the acceptance step.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node. Only its
        length (the number of trees) is used here.
    sigma2 : float
        The noise variance.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.

    Returns
    -------
    prelkv : dict
        Per-tree terms: 'sigma2_left', 'sigma2_right', 'sigma2_total' and
        'sqrt_term'.
    prelk : dict
        The term shared by all trees: 'exp_factor'.
    """
    # leaf prior variance, 1 / number of trees
    sigma_mu2 = 1 / len(count_trees)
    left = sigma2 + move_counts['left'] * sigma_mu2
    right = sigma2 + move_counts['right'] * sigma_mu2
    total = sigma2 + move_counts['total'] * sigma_mu2
    per_tree = dict(
        sigma2_left=left,
        sigma2_right=right,
        sigma2_total=total,
        sqrt_term=jnp.log(sigma2 * total / (left * right)) / 2,
    )
    shared = dict(exp_factor=sigma_mu2 / (2 * sigma2))
    return per_tree, shared
def precompute_leaf_terms(count_trees, sigma2, key):
    """
    Pre-compute terms used to sample leaves from their posterior.

    Parameters
    ----------
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    sigma2 : float
        The noise variance (must have a ``dtype`` attribute).
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    prelf : dict
        Dictionary with pre-computed terms of the leaf sampling, with fields:

        'mean_factor' : float array (num_trees, 2 ** d)
            Multiplying it by the sum of residuals in a leaf gives the
            posterior mean of that leaf.
        'centered_leaves' : float array (num_trees, 2 ** d)
            Zero-mean normal draws with the posterior variance. Adding the
            posterior mean to them gives the posterior leaf samples.
    """
    # leaf prior precision: the prior variance is 1 / number of trees
    prec_prior = len(count_trees)
    prec_lk = count_trees / sigma2
    var_post = lax.reciprocal(prec_lk + prec_prior)
    noise = random.normal(key, count_trees.shape, sigma2.dtype)
    return dict(
        mean_factor=var_post / sigma2,
        centered_leaves=noise * jnp.sqrt(var_post),
    )
def accept_moves_sequential_stage(bart, count_trees, moves, move_counts, prelkv, prelk, prelf):
    """
    The part of accepting the moves that has to be done one tree at a time.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    count_trees : array (num_trees, 2 ** d)
        The number of points in each potential or actual leaf node.
    moves : dict
        The proposed moves, see `sample_moves`.
    move_counts : dict
        The counts of the number of points in the nodes modified by the
        moves.
    prelkv, prelk, prelf : dict
        Dictionaries with pre-computed terms of the likelihood ratios and leaf
        samples.

    Returns
    -------
    bart : dict
        A partially updated BART mcmc state, with new 'resid' and
        'leaf_trees'. If the state has a 'ratios' dict, it is replaced by a
        new dict that also contains the acceptance ratio terms. The input
        state is not modified.
    moves : dict
        The proposed moves, with these additional fields:

        'acc' : bool array (num_trees,)
            Whether the move was accepted.
        'to_prune' : bool array (num_trees,)
            Whether, to reflect the acceptance status of the move, the state
            should be updated by pruning the leaves involved in the move.
    """
    bart = bart.copy()
    moves = moves.copy()

    def loop(resid, item):
        resid, leaf_tree, acc, to_prune, ratios = accept_move_and_sample_leaves(
            bart['X'],
            len(bart['leaf_trees']),
            bart['opt']['resid_batch_size'],
            resid,
            bart['min_points_per_leaf'],
            'ratios' in bart,
            prelk,
            *item,
        )
        return resid, (leaf_tree, acc, to_prune, ratios)

    # per-tree inputs, scanned along their leading (tree) axis
    items = (
        bart['leaf_trees'], count_trees,
        moves, move_counts,
        bart['leaf_indices'],
        prelkv, prelf,
    )
    # the residuals are carried, so each tree sees the updates of the
    # previous ones
    resid, (leaf_trees, acc, to_prune, ratios) = lax.scan(loop, bart['resid'], items)

    bart['resid'] = resid
    bart['leaf_trees'] = leaf_trees
    if 'ratios' in bart:
        # Build a new dict rather than updating in place: `bart.copy()` is
        # shallow, so the 'ratios' dict is shared with the caller's state.
        bart['ratios'] = {**bart['ratios'], **ratios}
    moves['acc'] = acc
    moves['to_prune'] = to_prune

    return bart, moves
def accept_move_and_sample_leaves(X, ntree, resid_batch_size, resid, min_points_per_leaf, save_ratios, prelk, leaf_tree, count_tree, move, move_counts, leaf_indices, prelkv, prelf):
    """
    Accept or reject a proposed move and sample the new leaf values.

    Parameters
    ----------
    X : int array (p, n)
        The predictors. Not read by the body of this function.
    ntree : int
        The number of trees in the forest. Not read by the body of this
        function.
    resid_batch_size : int, None
        The batch size for computing the sum of residuals in each leaf.
    resid : float array (n,)
        The residuals (data minus forest value).
    min_points_per_leaf : int or None
        The minimum number of data points in a leaf node.
    save_ratios : bool
        Whether to save the acceptance ratios.
    prelk : dict
        The pre-computed terms of the likelihood ratio which are shared across
        trees.
    leaf_tree : float array (2 ** d,)
        The leaf values of the tree.
    count_tree : int array (2 ** d,)
        The number of datapoints in each leaf.
    move : dict
        The proposed move, see `sample_moves`.
    move_counts : dict
        The number of datapoints in the nodes modified by the move. The
        fields 'left' and 'right' are read.
    leaf_indices : int array (n,)
        The leaf indices for the largest version of the tree compatible with
        the move.
    prelkv, prelf : dict
        The pre-computed terms of the likelihood ratio and leaf sampling which
        are specific to the tree.

    Returns
    -------
    resid : float array (n,)
        The updated residuals (data minus forest value).
    leaf_tree : float array (2 ** d,)
        The new leaf values of the tree.
    acc : bool
        Whether the move was accepted.
    to_prune : bool
        Whether, to reflect the acceptance status of the move, the state should
        be updated by pruning the leaves involved in the move.
    ratios : dict
        The acceptance ratio terms, with fields 'log_trans_prior' and
        'log_likelihood'. Empty if `save_ratios` is False.
    """

    # sum residuals and count units per leaf, in tree proposed by grow move
    resid_tree = sum_resid(resid, leaf_indices, leaf_tree.size, resid_batch_size)

    # subtract starting tree from function
    resid_tree += count_tree * leaf_tree

    # get indices of move
    node = move['node']
    assert node.dtype == jnp.int32
    left = move['left']
    right = move['right']

    # sum residuals in parent node modified by move
    resid_left = resid_tree[left]
    resid_right = resid_tree[right]
    resid_total = resid_left + resid_right
    resid_tree = resid_tree.at[node].set(resid_total)

    # compute acceptance ratio
    log_lk_ratio = compute_likelihood_ratio(resid_total, resid_left, resid_right, prelkv, prelk)
    log_ratio = move['log_trans_prior_ratio'] + log_lk_ratio
    # the ratio is computed in the grow direction; invert it for prune moves
    log_ratio = jnp.where(move['grow'], log_ratio, -log_ratio)
    ratios = {}
    if save_ratios:
        ratios.update(
            log_trans_prior=move['log_trans_prior_ratio'],
            log_likelihood=log_lk_ratio,
        )

    # determine whether to accept the move
    acc = move['allowed'] & (move['logu'] <= log_ratio)
    if min_points_per_leaf is not None:
        acc &= move_counts['left'] >= min_points_per_leaf
        acc &= move_counts['right'] >= min_points_per_leaf

    # compute leaves posterior and sample leaves
    initial_leaf_tree = leaf_tree
    mean_post = resid_tree * prelf['mean_factor']
    leaf_tree = mean_post + prelf['centered_leaves']

    # copy leaves around such that the leaf indices select the right leaf;
    # to_prune is True for a rejected grow (undo the provisional grow) and
    # for an accepted prune
    to_prune = acc ^ move['grow']
    leaf_tree = (leaf_tree
        .at[jnp.where(to_prune, left, leaf_tree.size)]
        .set(leaf_tree[node])
        .at[jnp.where(to_prune, right, leaf_tree.size)]
        .set(leaf_tree[node])
    )

    # replace old tree with new tree in function values
    resid += (initial_leaf_tree - leaf_tree)[leaf_indices]

    return resid, leaf_tree, acc, to_prune, ratios
def sum_resid(resid, leaf_indices, tree_size, batch_size):
    """
    Sum the residuals falling in each leaf.

    Parameters
    ----------
    resid : float array (n,)
        The residuals (data minus forest value).
    leaf_indices : int array (n,)
        The leaf indices of the tree (in which leaf each data point falls into).
    tree_size : int
        The size of the tree array (2 ** d).
    batch_size : int, None
        The data batch size for the aggregation. Batching increases numerical
        accuracy and parallelism. If None, no batching is done.

    Returns
    -------
    resid_tree : float32 array (2 ** d,)
        The sum of the residuals at data points in each leaf.
    """
    if batch_size is not None:
        return _aggregate_batched_onetree(
            resid, leaf_indices, tree_size, jnp.float32, batch_size,
        )
    return _aggregate_scatter(resid, leaf_indices, tree_size, jnp.float32)
def _aggregate_batched_onetree(values, indices, size, dtype, batch_size):
    """
    Sum `values` per index for a single tree, with the data split in batches.

    Datapoint ``i`` goes to batch ``i % nbatches``, where ``nbatches`` is the
    ceiling of ``n / batch_size``. Batches are summed separately and then
    added together. Returns an array of length `size`.
    """
    (n,) = indices.shape
    nbatches = -(-n // batch_size)  # ceiling division
    batch_of_point = jnp.arange(n) % nbatches
    buckets = jnp.zeros((size, nbatches), dtype)
    buckets = buckets.at[indices, batch_of_point].add(values)
    return buckets.sum(axis=1)
def compute_likelihood_ratio(total_resid, left_resid, right_resid, prelkv, prelk):
    """
    Compute the log likelihood ratio of a grow move.

    Parameters
    ----------
    total_resid : float
        The sum of the residuals in the leaf to grow.
    left_resid, right_resid : float
        The sum of the residuals in the left/right child of the leaf to grow.
    prelkv, prelk : dict
        The pre-computed terms of the likelihood ratio, see
        `precompute_likelihood_terms`.

    Returns
    -------
    log_ratio : float
        The logarithm of the likelihood ratio
        P(data | new tree) / P(data | old tree).
    """
    # exponent of the ratio; 'sqrt_term' is already the log of the
    # determinant factor
    exp_term = prelk['exp_factor'] * (
        left_resid * left_resid / prelkv['sigma2_left']
        + right_resid * right_resid / prelkv['sigma2_right']
        - total_resid * total_resid / prelkv['sigma2_total']
    )
    return prelkv['sqrt_term'] + exp_term
def accept_moves_final_stage(bart, moves):
    """
    The final part of accepting the moves, in parallel across trees.

    Parameters
    ----------
    bart : dict
        A partially updated BART mcmc state.
    moves : dict
        The proposed moves (see `sample_moves`) as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    bart : dict
        The fully updated BART mcmc state. It has the acceptance counts
        'grow_acc_count' and 'prune_acc_count', and its 'leaf_indices' and
        'split_trees' reflect the accepted moves. The input state is not
        modified.
    """
    bart = bart.copy()
    bart['grow_acc_count'] = jnp.sum(moves['acc'] & moves['grow'])
    bart['prune_acc_count'] = jnp.sum(moves['acc'] & ~moves['grow'])
    bart['leaf_indices'] = apply_moves_to_leaf_indices(bart['leaf_indices'], moves)
    bart['split_trees'] = apply_moves_to_split_trees(bart['split_trees'], moves)
    return bart
@jax.vmap
def apply_moves_to_leaf_indices(leaf_indices, moves):
    """
    Bring the leaf indices in line with the outcome of the moves.

    Where a tree must be pruned, the datapoints in the two children of the
    move's node are sent back to the node itself.

    Parameters
    ----------
    leaf_indices : int array (num_trees, n)
        The index of the leaf each datapoint falls into, if the grow move was
        accepted.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    leaf_indices : int array (num_trees, n)
        The updated leaf indices.
    """
    dtype = leaf_indices.dtype
    # clearing the lowest bit maps both siblings 2k and 2k + 1 to 2k
    clear_low_bit = ~jnp.array(1, dtype)
    sibling_pair = leaf_indices & clear_low_bit
    in_pruned_children = (sibling_pair == moves['left']) & moves['to_prune']
    parent = moves['node'].astype(dtype)
    return jnp.where(in_pruned_children, parent, leaf_indices)
@jax.vmap
def apply_moves_to_split_trees(split_trees, moves):
    """
    Bring the split trees in line with the outcome of the moves.

    A grow move writes its cutpoint at the node. Then, if the tree must be
    pruned, the node's cutpoint is reset to 0. This also undoes a rejected
    grow.

    Parameters
    ----------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The cutpoints of the decision nodes in the initial trees.
    moves : dict
        The proposed moves (see `sample_moves`), as updated by
        `accept_moves_sequential_stage`.

    Returns
    -------
    split_trees : int array (num_trees, 2 ** (d - 1))
        The updated split trees.
    """
    # out-of-bounds index: the corresponding write is dropped
    past_end = split_trees.size
    grow_at = jnp.where(moves['grow'], moves['node'], past_end)
    prune_at = jnp.where(moves['to_prune'], moves['node'], past_end)
    new_split = moves['grow_split'].astype(split_trees.dtype)
    grown = split_trees.at[grow_at].set(new_split)
    return grown.at[prune_at].set(0)
def sample_sigma(bart, key):
    """
    Noise variance sampling step of BART MCMC.

    The new variance is ``beta / G`` with ``G ~ Gamma(alpha, 1)``, which is an
    inverse-gamma draw. Here ``alpha = sigma2_alpha + n / 2`` and
    ``beta = sigma2_beta + sum(resid ** 2) / 2``.

    Parameters
    ----------
    bart : dict
        A BART mcmc state, as created by `init`.
    key : jax.dtypes.prng_key array
        A jax random key.

    Returns
    -------
    bart : dict
        A shallow copy of the state with a new 'sigma2'. The input is not
        modified.
    """
    resid = bart['resid']
    # accumulate the squared norm in the "large" float type
    sum_sq = jnp.dot(resid, resid, preferred_element_type=bart['opt']['large_float'])
    shape = bart['sigma2_alpha'] + resid.size / 2
    scale = bart['sigma2_beta'] + sum_sq / 2

    new_state = bart.copy()
    new_state['sigma2'] = scale / random.gamma(key, shape)
    return new_state