Quickstart

Basics

Import the bartz module and use the gbart class:

import bartz

# fit BART to the training data X, y (other arguments omitted)
bart = bartz.BART.gbart(X, y, ...)
# predict at new covariate values
y_pred = bart.predict(X_test)

The interface hews to the R package BART, with a few differences.

JAX

bartz is implemented using JAX, a Google library for numerical computing and machine learning. JAX lets the same code run on CPU, GPU, or TPU, and provides just-in-time compilation, automatic differentiation, and automatic vectorization.
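The snippet below is generic JAX, not part of bartz, and you do not need it to use bartz. It is a minimal sketch showing how to check which backend JAX is using and how `jax.jit` compiles a function for that backend. Running on a GPU or TPU assumes a JAX installation built with support for that hardware; otherwise JAX falls back to the CPU.

```python
import jax
import jax.numpy as jnp

# Backend JAX picked up, e.g. "cpu", "gpu", or "tpu"
backend = jax.default_backend()
# List of devices available to JAX on that backend
devices = jax.devices()

# jax.jit compiles the function for whichever backend is active
@jax.jit
def scaled_sum(x, c):
    return c * jnp.sum(x)

result = scaled_sum(jnp.arange(4.0), 2.0)  # 2 * (0 + 1 + 2 + 3)
```

The same code runs unchanged on every backend; only the installed JAX build decides where it executes.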

For basic usage, JAX is just an alternative implementation of numpy. The arrays returned by gbart are JAX arrays instead of numpy arrays, but for everyday use they behave the same way: they support indexing, arithmetic, and the usual numpy-style operations, and numpy.asarray converts them to plain numpy arrays. The main difference is that JAX arrays are immutable. If you pass numpy arrays to bartz, they are converted automatically, so you don't need to use JAX directly.
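A minimal sketch of working with a JAX array. The array here is built by hand as a stand-in for one returned by gbart; it is not actual bartz output. It shows a numpy-style method, conversion with `numpy.asarray`, and the functional `.at[...].set(...)` update that JAX uses in place of item assignment:

```python
import numpy as np
import jax.numpy as jnp

# Stand-in for an array returned by bartz (not real bartz output)
a = jnp.arange(6.0).reshape(2, 3)

# numpy-style methods work as expected
col_means = a.mean(axis=0)

# Convert to a plain numpy array, e.g. for a library that requires one
a_np = np.asarray(a)

# JAX arrays are immutable: `a[0, 0] = 10.0` would raise a TypeError.
# .at[...].set(...) returns an updated copy and leaves `a` unchanged.
b = a.at[0, 0].set(10.0)
```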

For advanced usage, refer to the JAX documentation.

Advanced

bartz also exposes the lower-level functions that implement the MCMC sampler of BART. You can use them to build your own variant of BART. See the rest of the documentation for reference.
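As a hypothetical illustration of writing your own MCMC step in JAX, the sketch below implements one standard ingredient of a BART Gibbs sampler: the conjugate update of the error variance sigma^2 under an inverse-gamma prior, given the residuals of the current tree ensemble. The function `sample_sigma2` and its parameters are assumptions made for this example and are not part of the bartz API; see the reference documentation for the functions bartz actually provides.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper, NOT part of bartz: one Gibbs step for sigma^2
# under an InvGamma(a, b) prior, given residuals r = y - f(x).
@jax.jit
def sample_sigma2(key, resid, a, b):
    n = resid.shape[0]
    shape = a + n / 2                      # posterior shape parameter
    scale = b + jnp.sum(resid ** 2) / 2    # posterior scale parameter
    # If G ~ Gamma(shape, 1), then scale / G ~ InvGamma(shape, scale)
    g = jax.random.gamma(key, shape)
    return scale / g

# Toy residuals with sum of squares 100 * 0.25 = 25
resid = jnp.full(100, 0.5)
# One PRNG key per draw; vmap vectorizes the step over the keys
keys = jax.random.split(jax.random.PRNGKey(0), 4000)
draws = jax.vmap(sample_sigma2, in_axes=(0, None, None, None))(keys, resid, 1.0, 1.0)
```

With a = b = 1 the posterior is InvGamma(51, 13.5), whose mean is 13.5 / 50 = 0.27, so the average of the draws should land close to that value. Writing the step as a pure function of a PRNG key makes it straightforward to jit, vmap, or chain with other steps.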