Coverage for src/bartz/prepcovars.py: 85%
78 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/prepcovars.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to preprocess data."""
27from functools import partial 1ab
29from jax import jit, vmap 1ab
30from jax import numpy as jnp 1ab
31from jaxtyping import Array, Float, Integer, Real, UInt 1ab
33from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique 1ab
def parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Convert R-BART-style predefined cutpoints into padded splits.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for one predictor. If a row
        has fewer cutpoints than the matrix has columns, the remaining cells
        shall be filled with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with every NaN replaced by the largest finite value of its
        dtype.
    max_split : UInt[Array, ' p']
        The number of non-NaN elements in each row of `xinfo`.
    """
    valid = jnp.logical_not(jnp.isnan(xinfo))

    # count the usable cutpoints per predictor, stored in the smallest
    # unsigned dtype that can hold the number of columns
    counts = valid.sum(axis=1)
    max_split = counts.astype(minimal_unsigned_dtype(xinfo.shape[1]))

    # pad the unused cells with a sentinel that sorts after any real cutpoint
    padding = _huge_value(xinfo)
    splits = jnp.where(valid, xinfo, padding)

    return splits, max_split
@partial(jit, static_argnums=(1,))
def quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute per-predictor bin boundaries that equalize the bin populations.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce. It is a static argument of the
        jitted function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : UInt[Array, ' p']
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    out_length = min(max_bins, X.shape[1]) - 1

    # a negative length means that either the data or the bin budget is empty
    if out_length < 0:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)

    def splits_for_rows(rows):
        # closure over `out_length`, because autobatch needs traceable args
        return _quantilized_splits_from_matrix(rows, out_length)

    batched = autobatch(splits_for_rows, max_io_nbytes=2**29)
    return batched(X)
@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(
    x: Real[Array, 'p n'], out_length: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """Compute at most `out_length` splits for each row of `x`.

    The function is vectorized with `vmap` over the leading axis of `x`, so
    the body below operates on a single row at a time; `out_length` is not
    mapped.
    """
    fill = _huge_value(x)

    # sorted unique values of the row, padded with `fill`, and their count
    uniques, num_unique = unique(x, size=x.size, fill_value=fill)

    # midpoints between consecutive unique values, computed as
    # x_i + (x_i+1 - x_i) / 2 instead of (x_i + x_i+1) / 2 to avoid overflow
    lower = uniques[:-1]
    gap = uniques[1:] - lower
    if jnp.issubdtype(x.dtype, jnp.integer):
        half_gap = _ensure_unsigned(gap) // 2
    else:
        half_gap = gap / 2
    midpoints = lower + half_gap

    # there is one midpoint less than there are unique values; the entry right
    # after the last valid midpoint is overwritten with the fill value
    num_midpoints = num_unique - 1
    if midpoints.size:
        midpoints = midpoints.at[num_midpoints].set(fill)

    # pick evenly spaced midpoints, for when there are more than requested;
    # the positions are computed in floating point rather than with integers
    # to avoid potential overflow with int32, and to round to nearest instead
    # of rounding down
    positions = jnp.linspace(-1, num_midpoints, out_length + 2)[1:-1]
    index_dtype = minimal_unsigned_dtype(midpoints.size - 1)
    positions = jnp.around(positions).astype(index_dtype)
    decimated = midpoints[positions]

    # otherwise the leading midpoints are used as they are
    truncated = midpoints[:out_length]

    splits = jnp.where(num_midpoints > out_length, decimated, truncated)

    count_dtype = minimal_unsigned_dtype(out_length)
    max_split = jnp.minimum(num_midpoints, out_length).astype(count_dtype)

    return splits, max_split
144def _huge_value(x: Array) -> int | float: 1ab
145 """
146 Return the maximum value that can be stored in `x`.
148 Parameters
149 ----------
150 x
151 A numerical numpy or jax array.
153 Returns
154 -------
155 The maximum value allowed by `x`'s type (finite for floats).
156 """
157 if jnp.issubdtype(x.dtype, jnp.integer): 1ab
158 return jnp.iinfo(x.dtype).max 1ab
159 else:
160 return float(jnp.finfo(x.dtype).max) 1ab
def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast an integer array to the unsigned dtype chosen by `_signed_to_unsigned`.

    Signed dtypes map to the unsigned dtype of the same size; unsigned dtypes
    are kept as they are.
    """
    unsigned_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)
168def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype: 1ab
169 """
170 Map a signed integer type to its unsigned counterpart.
172 Unsigned types are passed through.
173 """
174 assert jnp.issubdtype(int_dtype, jnp.integer) 1ab
175 if jnp.issubdtype(int_dtype, jnp.unsignedinteger): 175 ↛ 176line 175 didn't jump to line 176 because the condition on line 175 was never true1ab
176 return int_dtype
177 match int_dtype: 1ab
178 case jnp.int8: 178 ↛ 179line 178 didn't jump to line 179 because the pattern on line 178 never matched1ab
179 return jnp.uint8
180 case jnp.int16: 180 ↛ 181line 180 didn't jump to line 181 because the pattern on line 180 never matched1ab
181 return jnp.uint16
182 case jnp.int32: 182 ↛ 184line 182 didn't jump to line 184 because the pattern on line 182 always matched1ab
183 return jnp.uint32 1ab
184 case jnp.int64:
185 return jnp.uint64
186 case _:
187 msg = f'unexpected integer type {int_dtype}'
188 raise TypeError(msg)
@partial(jit, static_argnums=(1,))
def uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Build, for each predictor, a grid of equally spaced cutpoints.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce. It is a static argument of the jitted
        function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
    """
    row_min = X.min(axis=1)
    row_max = X.max(axis=1)

    # num_bins + 1 equally spaced edges per row, then drop the two endpoints
    edges = jnp.linspace(row_min, row_max, num_bins + 1, axis=1)
    splits = edges[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # one entry per predictor, each holding the number of columns of `splits`
    num_rows, num_cols = splits.shape
    count_dtype = minimal_unsigned_dtype(num_bins - 1)
    max_split = jnp.full(num_rows, num_cols, count_dtype)

    return splits, max_split
@partial(jit, static_argnames=('side', 'method'))
def bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw
) -> UInt[Array, 'p n']:
    """
    Bin the predictors according to the given splits.

    With the default settings of `jax.numpy.searchsorted`, a value ``x`` is
    mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued options `side` and `method` are static arguments of
        the jitted function.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """
    # `side` and `method` are strings, which can not be traced by `jit`, so
    # both have to be declared static for the keywords to be forwardable.

    @partial(autobatch, max_io_nbytes=2**29)
    @vmap
    def bin_rows(x, row_splits):
        # inside `vmap`, `x` and `row_splits` are single rows; the output uses
        # the smallest unsigned dtype that can hold the number of splits
        dtype = minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    return bin_rows(X, splits)