Coverage for src/bartz/prepcovars.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1# bartz/src/bartz/prepcovars.py
2#
3# Copyright (c) 2024, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25import functools 1a
27import jax 1a
28from jax import numpy as jnp 1a
30from . import jaxext 1a
31from . import grove 1a
@functools.partial(jax.jit, static_argnums=(1,))
def quantilized_splits_from_matrix(X, max_bins):
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    max_bins : int
        The maximum number of bins to produce. It is a static argument of the
        jitted function.

    Returns
    -------
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : array (p,)
        The number of actually used values in each row of `splits`.
    """
    # Both operands are Python ints at trace time (static argument and static
    # shape), so this is an ordinary integer usable as an output size.
    num_splits = min(max_bins, X.shape[1]) - 1

    def split_rows(rows):
        return _quantilized_splits_from_matrix(rows, num_splits)

    # NOTE(review): `jaxext.autobatch` presumably processes the rows in chunks
    # to cap the memory used per call at `max_io_nbytes` -- confirm in jaxext.
    batched_split_rows = jaxext.autobatch(split_rows, max_io_nbytes=2 ** 29)
    return batched_split_rows(X)
@functools.partial(jax.vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(x, out_length):
    """
    Compute the quantile-based splits of a single predictor.

    The function is vectorized with `jax.vmap` over the first axis of `x`;
    `out_length` is shared by all rows.

    Parameters
    ----------
    x : array
        The observations of one predictor (one row of the input matrix).
    out_length : int
        The number of splits to return for the row.

    Returns
    -------
    splits : array
        Midpoints between consecutive distinct values of `x`. If there are
        more midpoints than `out_length`, an evenly spaced subset of them is
        taken; otherwise the first `out_length` entries are returned as they
        are.
    max_split : scalar
        The smaller of the number of midpoints and `out_length`.
    """
    filler = jaxext.huge_value(x)
    # NOTE(review): `jaxext.unique` presumably returns the sorted distinct
    # values padded with `filler` up to `x.size`, together with the count of
    # distinct values -- confirm against jaxext.
    distinct, num_distinct = jaxext.unique(x, size=x.size, fill_value=filler)
    num_midpoints = num_distinct - 1

    lower = distinct[:-1]
    upper = distinct[1:]
    if jnp.issubdtype(x.dtype, jnp.integer):
        # Add half of the gap to the lower value instead of averaging the two
        # values directly.
        half_gap = jaxext.ensure_unsigned(upper - lower) // 2
        midpoints = lower + half_gap
        position_dtype = jaxext.minimal_unsigned_dtype(midpoints.size - 1)
        positions = jnp.arange(midpoints.size, dtype=position_dtype)
        # Entries past the last actual midpoint are overwritten with the
        # filler value.
        midpoints = jnp.where(positions < num_midpoints, midpoints, filler)
    else:
        midpoints = (upper + lower) / 2

    # The positions are computed in floating point rather than with integers
    # to avoid potential overflow with int32, and to round to nearest instead
    # of rounding down.
    fractional = jnp.linspace(-1, num_midpoints, out_length + 2)[1:-1]
    index_dtype = jaxext.minimal_unsigned_dtype(midpoints.size - 1)
    picks = jnp.around(fractional).astype(index_dtype)

    decimated = midpoints[picks]
    truncated = midpoints[:out_length]
    # Decimate only when there are more midpoints than requested splits.
    splits = jnp.where(num_midpoints <= out_length, truncated, decimated)

    count_dtype = jaxext.minimal_unsigned_dtype(out_length)
    max_split = jnp.minimum(num_midpoints, out_length).astype(count_dtype)
    return splits, max_split
@functools.partial(jax.jit, static_argnums=(1,))
def uniform_splits_from_matrix(X, num_bins):
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    num_bins : int
        The number of bins to produce. It is a static argument of the jitted
        function.

    Returns
    -------
    splits : array (p, num_bins - 1)
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : array (p,)
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
    """
    row_min = jnp.min(X, axis=1)
    row_max = jnp.max(X, axis=1)

    # The grid has num_bins + 1 points per row including both endpoints; the
    # endpoints are then dropped to keep only the interior boundaries.
    grid = jnp.linspace(row_min, row_max, num_bins + 1, axis=1)
    splits = grid[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # One entry per predictor, each equal to the number of splits per row.
    num_predictors, num_splits = splits.shape
    count_dtype = jaxext.minimal_unsigned_dtype(num_bins - 1)
    max_split = jnp.full(num_predictors, num_splits, count_dtype)
    return splits, max_split
@functools.partial(jax.jit, static_argnames=('method', 'side'))
def bin_predictors(X, splits, **kw):
    """
    Bin the predictors according to the given splits.

    With the default arguments of `jax.numpy.searchsorted`, a value ``x`` is
    mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw : dict
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued options `method` and `side` are static arguments of the
        jitted function, because strings can not be traced by `jax.jit`.

    Returns
    -------
    X_binned : int array (p, n)
        A matrix with `p` predictors and `n` observations, where each predictor
        has been replaced by the index of the bin it falls into.
    """
    def bin_row(x, row_splits):
        # A bin index lies between 0 and the number of splits of the row, so
        # the integer type is chosen from the row length.
        dtype = jaxext.minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    # Vectorize over predictors (rows of `X` and `splits` in lockstep).
    # NOTE(review): `jaxext.autobatch` presumably processes the rows in chunks
    # to cap the memory used per call at `max_io_nbytes` -- confirm in jaxext.
    bin_rows = jaxext.autobatch(jax.vmap(bin_row), max_io_nbytes=2 ** 29)
    return bin_rows(X, splits)