Quickstart

Basics

Import the bartz module and use the BART.gbart class:

import bartz
bart = bartz.BART.gbart(X, y, ...)
y_pred = bart.predict(X_test)

The interface hews to the R package BART, with a few differences explained in the documentation of BART.gbart.
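
As a concrete illustration, here is a minimal self-contained sketch on synthetic data. It uses only the calls shown above; the array layout of X and the output format of predict are assumptions here, so check the documentation of BART.gbart before relying on them.

import numpy as np
import bartz

# synthetic regression data; the (predictors, observations) layout is an
# assumption, see the BART.gbart documentation for the expected format
rng = np.random.default_rng(0)
p, n = 5, 100
X = rng.standard_normal((p, n))
y = X[0] + 0.1 * rng.standard_normal(n)

bart = bartz.BART.gbart(X, y)          # runs the MCMC with default settings
X_test = rng.standard_normal((p, 20))
y_pred = bart.predict(X_test)          # predictions at the test points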

JAX

bartz is implemented using jax, a Google library for machine learning. JAX makes it possible to run the code on GPU or TPU, and provides features such as just-in-time compilation and automatic differentiation.

For basic usage, JAX is just an alternative implementation of numpy. The arrays returned by gbart are “jax arrays” instead of “numpy arrays”, but in practice they behave the same way. If you pass numpy arrays to bartz, they are converted automatically, so you don’t have to deal with jax directly.
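
Continuing the snippet above, this means the output of predict can be used directly in numpy code, and converted explicitly only if some downstream function requires a numpy array:

import numpy as np

y_pred = bart.predict(X_test)   # a jax.Array, not a numpy.ndarray
y_pred.mean(axis=0)             # the usual numpy methods and operators work
np.asarray(y_pred)              # explicit conversion to numpy, if ever needed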

For advanced usage, refer to the jax documentation.

Advanced

bartz exposes the functions that implement the MCMC of BART, so you can use them to build your own variant of BART. See the rest of the documentation for reference; the main entry points are mcmcstep.init and mcmcloop.run_mcmc. At the moment, using the internals is the only way to change the device used by each step of the algorithm. This is useful when the data pre-processing does not fit in GPU memory: pre-process the data on CPU, then move only the MCMC state to the GPU.
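
The following sketch shows the device-placement pattern just described. The jax calls are standard; the bartz internals are only indicated in comments because their exact signatures are documented in mcmcstep and mcmcloop, and the placeholder state stands in for the object returned by mcmcstep.init.

import jax
import jax.numpy as jnp

cpu = jax.devices('cpu')[0]
gpu = jax.devices('gpu')[0]          # requires a GPU-enabled jax installation

with jax.default_device(cpu):
    # pre-process the data and build the initial state on CPU, e.g. with
    # bartz.mcmcstep.init (see its documentation for the arguments)
    state = {'placeholder': jnp.zeros(10)}   # stand-in for the real MCMC state

# jax.device_put works on any pytree, so it can move the whole MCMC state
# to the GPU while leaving the pre-processing work on the CPU
state = jax.device_put(state, gpu)

# then run the chain on the GPU, e.g. with bartz.mcmcloop.run_mcmc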