Data processing

Functions to preprocess data.

bartz.prepcovars.parse_xinfo(xinfo)

Parse pre-defined splits in the format of the R package BART.

Parameters:

xinfo (Float[Array, 'p m']) –

A matrix with the cutpoints to use to bin each predictor. Each row shall contain a sorted list of cutpoints for one predictor. If a predictor has fewer cutpoints than the number of columns in the matrix, fill the remaining cells with NaN.

xinfo shall be a matrix even if x_train is a dataframe.

Returns:

  • splits (Float[Array, 'p m']) – xinfo modified by replacing NaN with a large value.

  • max_split (UInt[Array, 'p']) – The number of non-NaN elements in each row of xinfo.
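
The NaN-padding convention can be illustrated with a plain-Python stand-in. This is only a sketch of the documented behavior, not the bartz implementation: parse_xinfo_sketch is a hypothetical name, it works on nested lists rather than JAX arrays, and the choice of the largest finite float as the "large value" is an assumption.

```python
import math
import sys

# Assumption: the "large value" is taken to be the largest finite float.
FILL = sys.float_info.max


def parse_xinfo_sketch(xinfo):
    """Hypothetical stand-in for parse_xinfo, operating on nested lists."""
    # Replace each NaN padding cell with the fill value.
    splits = [[FILL if math.isnan(v) else v for v in row] for row in xinfo]
    # Count the real (non-NaN) cutpoints in each row.
    max_split = [sum(not math.isnan(v) for v in row) for row in xinfo]
    return splits, max_split


nan = float("nan")
# Two predictors: the first has 2 cutpoints, the second has 1.
xinfo = [
    [0.5, 1.5, nan],
    [2.0, nan, nan],
]
splits, max_split = parse_xinfo_sketch(xinfo)
```

Here max_split records how many leading entries of each row of splits are real cutpoints; the remaining entries are padding.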

bartz.prepcovars.quantilized_splits_from_matrix(X, max_bins)

Determine bins that make the distribution of each predictor uniform.

Parameters:
  • X (Real[Array, 'p n']) – A matrix with p predictors and n observations.

  • max_bins (int) – The maximum number of bins to produce.

Returns:

  • splits (Real[Array, 'p m']) – A matrix containing, for each predictor, the boundaries between bins. m is min(max_bins, n) - 1, which is an upper bound on the number of splits. Each predictor may have a different number of splits; unused values at the end of each row are filled with the maximum value representable in the type of X.

  • max_split (UInt[Array, 'p']) – The number of actually used values in each row of splits.

Raises:

ValueError – If X has no columns or if max_bins is less than 1.
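
The output layout (row width m, padding, and max_split) can be sketched in plain Python. The quantile rule below is an assumption chosen for simplicity: it places a boundary halfway between the two sorted values that straddle each equal-count position and drops boundaries that would fall between equal values. It is not necessarily the rule bartz uses, and quantile_splits_sketch is a hypothetical name.

```python
import sys

# Padding for float data: the maximum value representable in the type.
FILL = sys.float_info.max


def quantile_splits_sketch(X, max_bins):
    """Hypothetical stand-in; X is a list of p rows with n values each."""
    if not X or not X[0] or max_bins < 1:
        raise ValueError("X must have at least one column and max_bins >= 1")
    n = len(X[0])
    nbins = min(max_bins, n)
    m = nbins - 1  # upper bound on the number of splits per predictor
    splits, max_split = [], []
    for row in X:
        s = sorted(row)
        cuts = []
        for j in range(1, nbins):
            idx = j * n // nbins  # first sorted position of bin j
            # Keep the boundary only if it separates two distinct values.
            if s[idx - 1] < s[idx]:
                cuts.append((s[idx - 1] + s[idx]) / 2)
        max_split.append(len(cuts))
        # Pad unused slots at the end of the row.
        splits.append(cuts + [FILL] * (m - len(cuts)))
    return splits, max_split


X = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 5.0, 5.0, 7.0],
]
splits, max_split = quantile_splits_sketch(X, max_bins=10)
```

With n = 4 observations and max_bins = 10, each row has m = 3 slots; the second predictor has only two distinct values, so it uses one slot and the rest is padding.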

bartz.prepcovars.uniform_splits_from_matrix(X, num_bins)

Make an evenly spaced binning grid.

Parameters:
  • X (Real[Array, 'p n']) – A matrix with p predictors and n observations.

  • num_bins (int) – The number of bins to produce.

Returns:

  • splits (Real[Array, 'p m']) – A matrix containing, for each predictor, the boundaries between bins. The excluded endpoints are the minimum and maximum value in each row of X.

  • max_split (UInt[Array, 'p']) – The number of cutpoints in each row of splits, i.e., num_bins - 1.
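
As a rough illustration of the evenly spaced grid, the following plain-Python sketch divides the range of each predictor into num_bins equal intervals and keeps only the interior boundaries. uniform_splits_sketch is a hypothetical name, and the nested-list representation stands in for the JAX arrays the real function takes.

```python
def uniform_splits_sketch(X, num_bins):
    """Hypothetical stand-in; X is a list of p rows with n values each."""
    splits, max_split = [], []
    for row in X:
        lo, hi = min(row), max(row)
        step = (hi - lo) / num_bins
        # Interior boundaries only: the endpoints lo and hi are excluded.
        splits.append([lo + step * k for k in range(1, num_bins)])
        max_split.append(num_bins - 1)
    return splits, max_split


X = [
    [0.0, 4.0, 1.0],
    [10.0, 12.0, 18.0],
]
splits, max_split = uniform_splits_sketch(X, num_bins=4)
```

Every row ends up with exactly num_bins - 1 cutpoints, so no padding is needed in this case.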

bartz.prepcovars.bin_predictors(X, splits, **kw)

Bin the predictors according to the given splits.

A value x is mapped to bin i iff splits[i - 1] < x <= splits[i].

Parameters:
  • X (Real[Array, 'p n']) – A matrix with p predictors and n observations.

  • splits (Real[Array, 'p m']) – A matrix containing, for each predictor, the boundaries between bins. m is the maximum number of splits; each row may have shorter actual length, marked by padding unused locations at the end of the row with the maximum value allowed by the type.

  • **kw – Additional arguments are passed to jax.numpy.searchsorted.

Returns:

UInt[Array, 'p n'] – X but with each value replaced by the index of the bin it falls into.
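
The binning rule splits[i - 1] < x <= splits[i] is what a left-sided sorted search returns, and it can be checked with the standard library's bisect_left. This is a plain-Python sketch with the hypothetical name bin_predictors_sketch; it works on nested lists and ignores the extra keyword arguments that the real function forwards to jax.numpy.searchsorted.

```python
import sys
from bisect import bisect_left

# Padding marker for unused split slots (float data assumed).
FILL = sys.float_info.max


def bin_predictors_sketch(X, splits):
    """Hypothetical stand-in; each row of X is searched in its own splits row."""
    # bisect_left returns i with splits[i - 1] < x <= splits[i].
    return [
        [bisect_left(row_splits, x) for x in row]
        for row, row_splits in zip(X, splits)
    ]


splits = [
    [1.5, 2.5, FILL],   # two real cutpoints, one padding slot
    [0.0, FILL, FILL],  # one real cutpoint, two padding slots
]
X = [
    [1.0, 1.5, 2.0, 3.0],
    [-1.0, 0.0, 0.5, 9.0],
]
bins = bin_predictors_sketch(X, splits)
```

A value equal to a cutpoint lands in the lower bin (1.5 maps to bin 0), and because the padding is the largest representable value, no finite input is placed beyond the first padding slot.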