Coverage for src/bartz/prepcovars.py: 100%
45 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-05-29 23:01 +0000
1# bartz/src/bartz/prepcovars.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Functions to preprocess data."""
27import functools 1ab
29import jax 1ab
30from jax import numpy as jnp 1ab
32from . import jaxext 1ab
@functools.partial(jax.jit, static_argnums=(1,))
def quantilized_splits_from_matrix(X, max_bins):
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    max_bins : int
        The maximum number of bins to produce. Must be at least 1.

    Returns
    -------
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : array (p,)
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `max_bins` is less than 1.
    """
    # `max_bins` is static, so this check runs once at trace time. Without it
    # a non-positive value makes `out_length` negative, which only fails later
    # in a confusing way (or not at all) inside the per-row computation.
    if max_bins < 1:
        msg = f'max_bins must be at least 1, got {max_bins}'
        raise ValueError(msg)
    out_length = min(max_bins, X.shape[1]) - 1

    @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
    def quantilize(X):
        # wrap this function because autobatch needs traceable args
        return _quantilized_splits_from_matrix(X, out_length)

    return quantilize(X)
@functools.partial(jax.vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(x, out_length):
    """
    Compute the splits of a single predictor; `jax.vmap` maps it over rows.

    Parameters
    ----------
    x : array (n,)
        The observed values of one predictor (one row of the caller's matrix).
    out_length : int
        The length of the returned `splits` array. Not mapped by vmap.

    Returns
    -------
    splits : array (out_length,)
        Midpoints between consecutive distinct values of `x`. If there are
        more than `out_length` midpoints, an evenly spaced subset of them is
        selected; otherwise the leading `out_length` entries are kept, and
        the entries past `max_split` derive from the `jaxext.huge_value`
        padding.
    max_split : array ()
        The number of meaningful entries in `splits`: the smaller of
        `out_length` and the number of midpoints.
    """
    pad = jaxext.huge_value(x)
    distinct, n_mid = jaxext.unique(x, size=x.size, fill_value=pad)
    # k distinct values leave k - 1 gaps, hence k - 1 midpoints
    n_mid -= 1
    lo = distinct[:-1]
    hi = distinct[1:]

    if not jnp.issubdtype(x.dtype, jnp.integer):
        mid = (hi + lo) / 2
    else:
        # half the gap is taken after jaxext.ensure_unsigned (presumably so a
        # gap wider than the signed range is not read as negative); entries
        # past the last real midpoint are then reset to the padding value
        mid = lo + jaxext.ensure_unsigned(hi - lo) // 2
        slot = jnp.arange(
            mid.size, dtype=jaxext.minimal_unsigned_dtype(mid.size - 1)
        )
        mid = jnp.where(slot < n_mid, mid, pad)

    index_dtype = jaxext.minimal_unsigned_dtype(mid.size - 1)
    # decimation positions are computed in floating point rather than with
    # integers, to avoid potential int32 overflow and to round to the nearest
    # index instead of rounding down; the two outer grid points are dropped
    picks = jnp.linspace(-1, n_mid, out_length + 2)[1:-1]
    picks = jnp.around(picks).astype(index_dtype)

    splits = jnp.where(n_mid > out_length, mid[picks], mid[:out_length])
    max_split = jnp.minimum(n_mid, out_length)
    return splits, max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
@functools.partial(jax.jit, static_argnums=(1,))
def uniform_splits_from_matrix(X, num_bins):
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    num_bins : int
        The number of bins to produce. Must be at least 1.

    Returns
    -------
    splits : array (p, num_bins - 1)
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : array (p,)
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.

    Raises
    ------
    ValueError
        If `num_bins` is less than 1.
    """
    # explicit check instead of relying on the shape assert below, which is
    # stripped under `python -O`
    if num_bins < 1:
        msg = f'num_bins must be at least 1, got {num_bins}'
        raise ValueError(msg)
    num_splits = num_bins - 1

    low = jnp.min(X, axis=1)
    high = jnp.max(X, axis=1)
    # num_bins + 1 evenly spaced edges per row; dropping the outer two (the row
    # minimum and maximum) leaves the interior cutpoints
    splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
    assert splits.shape == (X.shape[0], num_splits)

    # every row uses all of its cutpoints
    max_split = jnp.full(
        X.shape[0], num_splits, jaxext.minimal_unsigned_dtype(num_splits)
    )
    return splits, max_split
@functools.partial(jax.jit, static_argnames=('side', 'method'))
def bin_predictors(X, splits, **kw):
    """
    Bin the predictors according to the given splits.

    With the default ``side='left'`` of `jax.numpy.searchsorted`, a value
    ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw : dict
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued options `side` and `method` are treated as static
        arguments of the jitted function.

    Returns
    -------
    X_binned : int array (p, n)
        A matrix with `p` predictors and `n` observations, where each predictor
        has been replaced by the index of the bin it falls into. The dtype is
        ``jaxext.minimal_unsigned_dtype(m)``.
    """

    @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
    @jax.vmap
    def bin_row(x, row_splits):
        # searchsorted returns indices in [0, m], so the dtype must hold m
        dtype = jaxext.minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    return bin_row(X, splits)