Quickstart
Basics
Import the bartz module and use the gbart class:
import bartz
bart = bartz.BART.gbart(X, y, ...)
y_pred = bart.predict(X_test)
The interface hews to the R package BART, with a few differences.
JAX
bartz is implemented using jax, a Google library for machine learning. JAX allows the code to run on GPU or TPU, and it provides other features such as just-in-time compilation and automatic differentiation.
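To see which hardware JAX will use, you can list the devices it detects, as in the sketch below. On a machine without a supported accelerator, or without a GPU/TPU-enabled jax installation, only the CPU appears. The sketch also assumes that bartz computations run on JAX's default backend.

```python
import jax

# Devices JAX can see; falls back to the CPU when no accelerator is available.
devices = jax.devices()

# Name of the default backend, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()

print(backend, devices)
```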
For basic usage, JAX is just an alternative implementation of numpy. The arrays returned by gbart are “jax arrays” instead of “numpy arrays”, but in practice you can use them much as you would numpy arrays. If you pass numpy arrays to bartz, they are converted automatically, so you don't have to deal with jax in any way.
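The short sketch below shows how numpy and jax arrays interoperate. It uses only standard numpy and jax.numpy calls, not anything specific to bartz. It also notes one difference worth knowing: jax arrays are immutable.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)

# jax.numpy mirrors the numpy API and accepts numpy arrays as input.
x_jax = jnp.asarray(x_np)
total = jnp.sum(x_jax ** 2)

# Convert back to a plain numpy array if another library requires one.
back = np.asarray(x_jax)

# jax arrays cannot be modified in place; .at[...].set(...) returns a new
# array and leaves x_jax unchanged.
y_jax = x_jax.at[0].set(10.0)
```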
For advanced usage, refer to the jax documentation.
Advanced
bartz also exposes the individual functions that implement the BART MCMC. You can use them to build your own variant of BART; see the rest of the documentation for reference.