Coverage for src / bartz / prepcovars.py: 85%

78 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-18 15:24 +0000

1# bartz/src/bartz/prepcovars.py 

2# 

3# Copyright (c) 2024-2025, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to preprocess data.""" 

26 

27from functools import partial 

28 

29from jax import jit, vmap 

30from jax import numpy as jnp 

31from jaxtyping import Array, Float, Integer, Real, UInt 

32 

33from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique 

34 

35 

def parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Parse pre-defined splits in the format of the R package BART.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for a predictor. If a row has
        fewer cutpoints than the number of columns in the matrix, fill the
        remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with each NaN replaced by the largest finite value of its
        dtype.
    max_split : UInt[Array, 'p']
        The number of non-NaN elements in each row of `xinfo`.
    """
    valid = jnp.logical_not(jnp.isnan(xinfo))

    # per-row count of real cutpoints; the dtype is the one picked by
    # `minimal_unsigned_dtype` for the number of columns
    count_dtype = minimal_unsigned_dtype(xinfo.shape[1])
    max_split = valid.sum(axis=1).astype(count_dtype)

    # NaN padding is swapped for a huge sentinel value
    splits = jnp.where(valid, xinfo, _huge_value(xinfo))
    return splits, max_split

64 

65 

@partial(jit, static_argnums=(1,))
def quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce. It is a static argument of the
        jitted function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : UInt[Array, ' p']
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    n_splits = min(max_bins, X.shape[1]) - 1
    if n_splits < 0:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)

    def per_batch(rows):
        # `n_splits` is captured by closure instead of being passed as an
        # argument, because autobatch needs traceable args
        return _quantilized_splits_from_matrix(rows, n_splits)

    return autobatch(per_batch, max_io_nbytes=2**29)(X)

108 

109 

@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(
    x: Real[Array, 'p n'], out_length: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute the quantilized splits of each predictor.

    The function is vectorized with `vmap` over the first axis of `x`, so the
    body operates on the observations of a single predictor at a time;
    `out_length` is shared by all predictors.

    Parameters
    ----------
    x
        A matrix with `p` predictors and `n` observations.
    out_length
        The length of each row of the output `splits`.

    Returns
    -------
    splits : Real[Array, 'p m']
        The selected midpoints between consecutive unique values of each
        predictor; unused trailing cells hold the maximum value of the type.
    max_split : UInt[Array, ' p']
        The number of used values in each row of `splits`.
    """
    sentinel = _huge_value(x)

    # sorted unique values in x, padded with the sentinel up to a fixed size
    uniques, n_unique = unique(x, size=x.size, fill_value=sentinel)

    # midpoints between consecutive unique values, computed as
    # x_i + (x_i+1 - x_i) / 2 instead of (x_i + x_i+1) / 2 to avoid overflow
    lower = uniques[:-1]
    gap = uniques[1:] - lower
    if jnp.issubdtype(x.dtype, jnp.integer):
        half_gap = _ensure_unsigned(gap) // 2
    else:
        half_gap = gap / 2
    midpoints = lower + half_gap

    # there is one midpoint less than the number of unique values
    n_midpoints = n_unique - 1
    if midpoints.size:
        # the slot right after the last valid midpoint is overwritten with the
        # sentinel (presumably because it mixes a real value with padding)
        midpoints = midpoints.at[n_midpoints].set(sentinel)

    # candidate subset of evenly spaced midpoints, used if there are more
    # midpoints than the requested maximum; the positions are computed with
    # floats rather than ints to avoid potential overflow with int32, and to
    # round to nearest instead of rounding down
    index_dtype = minimal_unsigned_dtype(midpoints.size - 1)
    positions = jnp.linspace(-1, n_midpoints, out_length + 2)[1:-1]
    indices = jnp.around(positions).astype(index_dtype)

    splits = jnp.where(
        n_midpoints > out_length, midpoints[indices], midpoints[:out_length]
    )
    count_dtype = minimal_unsigned_dtype(out_length)
    max_split = jnp.minimum(n_midpoints, out_length).astype(count_dtype)
    return splits, max_split

142 

143 

144def _huge_value(x: Array) -> int | float: 

145 """ 

146 Return the maximum value that can be stored in `x`. 

147 

148 Parameters 

149 ---------- 

150 x 

151 A numerical numpy or jax array. 

152 

153 Returns 

154 ------- 

155 The maximum value allowed by `x`'s type (finite for floats). 

156 """ 

157 if jnp.issubdtype(x.dtype, jnp.integer): 1fghijqrstuvwxyklzAabecmnod

158 return jnp.iinfo(x.dtype).max 1abecd

159 else: 

160 return float(jnp.finfo(x.dtype).max) 1fghijqrstuvwxyklzAmno

161 

162 

def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast `x` to the unsigned dtype of the same size if its integer type is signed."""
    unsigned_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)

166 

167 

def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype:
    """
    Map a signed integer type to its unsigned counterpart.

    Unsigned types are passed through.

    Parameters
    ----------
    int_dtype
        An integer data type.

    Returns
    -------
    `int_dtype` itself if it is unsigned, otherwise the unsigned type with the
    same number of bits.

    Raises
    ------
    TypeError
        If `int_dtype` is a signed integer type that is not one of the 8, 16,
        32 or 64 bit ones.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype

    # scanned in order, comparing with `==`
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        if int_dtype == signed:
            return unsigned

    msg = f'unexpected integer type {int_dtype}'
    raise TypeError(msg)

189 

190 

@partial(jit, static_argnums=(1,))
def uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce. It is a static argument of the jitted
        function.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.
    """
    row_min = X.min(axis=1)
    row_max = X.max(axis=1)

    # num_bins + 1 evenly spaced points per row from min to max, then drop the
    # two endpoints to keep only the interior boundaries
    grid = jnp.linspace(row_min, row_max, num_bins + 1, axis=1)
    splits = grid[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # one entry per predictor, each equal to the number of columns of `splits`
    num_rows, num_cuts = splits.shape
    count_dtype = minimal_unsigned_dtype(num_bins - 1)
    max_split = jnp.full(num_rows, num_cuts, count_dtype)
    return splits, max_split

220 

221 

@partial(jit, static_argnames=('method', 'side'))
def bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw
) -> UInt[Array, 'p n']:
    """
    Bin the predictors according to the given splits.

    With the default arguments of `jax.numpy.searchsorted`, a value ``x`` is
    mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued options `method` and `side` are treated as static
        arguments of the jitted function.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    # `side` must be static like `method`: it is a string, so it can not be
    # traced by jit. The inner function has its own name to avoid shadowing
    # the public one.
    @partial(autobatch, max_io_nbytes=2**29)
    @vmap
    def bin_rows(x, splits):
        # here `x` and `splits` are the rows of a single predictor
        dtype = minimal_unsigned_dtype(splits.size)
        return jnp.searchsorted(splits, x, **kw).astype(dtype)

    return bin_rows(X, splits)

253 

254 return bin_predictors(X, splits) 1fBgChijqrsDtuvwxyEklFzAGHI