Coverage for src/bartz/prepcovars.py: 85%

78 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/prepcovars.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to preprocess data.""" 

26 

27from functools import partial 1ab

28 

29from jax import jit, vmap 1ab

30from jax import numpy as jnp 1ab

31from jaxtyping import Array, Float, Integer, Real, UInt 1ab

32 

33from bartz.jaxext import autobatch, minimal_unsigned_dtype, unique 1ab

34 

35 

def parse_xinfo(
    xinfo: Float[Array, 'p m'],
) -> tuple[Float[Array, 'p m'], UInt[Array, ' p']]:
    """Parse pre-defined splits in the format of the R package BART.

    Parameters
    ----------
    xinfo
        A matrix with the cutpoints to use to bin each predictor. Each row
        shall contain a sorted list of cutpoints for one predictor. If a
        predictor has fewer cutpoints than the number of columns in the
        matrix, fill the remaining cells with NaN.

        `xinfo` shall be a matrix even if `x_train` is a dataframe.

    Returns
    -------
    splits : Float[Array, 'p m']
        `xinfo` with every NaN replaced by the largest value representable
        in its dtype.
    max_split : UInt[Array, ' p']
        The number of non-NaN entries in each row of `xinfo`, stored in the
        smallest unsigned dtype able to hold the number of columns.
    """
    nan_mask = jnp.isnan(xinfo)
    count_dtype = minimal_unsigned_dtype(xinfo.shape[1])
    max_split = jnp.count_nonzero(~nan_mask, axis=1).astype(count_dtype)
    # NaN padding is swapped for a huge sentinel so that the padding sorts
    # after every real cutpoint
    splits = jnp.where(nan_mask, _huge_value(xinfo), xinfo)
    return splits, max_split

64 

65 

@partial(jit, static_argnums=(1,))
def quantilized_splits_from_matrix(
    X: Real[Array, 'p n'], max_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Compute, for each predictor, bins that make its distribution uniform.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    max_bins
        The maximum number of bins to produce. Must be a static Python int.

    Returns
    -------
    splits : Real[Array, 'p m']
        For each predictor, the boundaries between consecutive bins. The
        number of columns `m` is ``min(max_bins, n) - 1``, an upper bound on
        the number of splits of any predictor. Rows with fewer splits are
        padded at the end with the largest value representable in the dtype
        of `X`.
    max_split : UInt[Array, ' p']
        How many entries of each row of `splits` are actual splits.

    Raises
    ------
    ValueError
        If `X` has no columns or if `max_bins` is less than 1.
    """
    if min(max_bins, X.shape[1]) < 1:
        msg = f'{X.shape[1]=} and {max_bins=}, they should be both at least 1.'
        raise ValueError(msg)
    out_length = min(max_bins, X.shape[1]) - 1

    def per_chunk(chunk):
        # autobatch requires every argument to be traceable, so the static
        # output length is captured by closure instead of being passed in
        return _quantilized_splits_from_matrix(chunk, out_length)

    batched = autobatch(per_chunk, max_io_nbytes=2**29)
    return batched(X)

108 

109 

@partial(vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(
    x: Real[Array, ' n'], out_length: int
) -> tuple[Real[Array, ' m'], UInt[Array, '']]:
    """
    Compute quantile-based splits for a single predictor.

    The function is vectorized with `vmap` over the first axis of its first
    argument, so callers pass a ``(p, n)`` matrix and receive one row of
    splits and one count per predictor.

    Parameters
    ----------
    x
        The observed values of one predictor.
    out_length
        The requested number of splits. Expected not to exceed
        ``x.size - 1``; the caller in this module guarantees that.

    Returns
    -------
    splits : Real[Array, ' m']
        The midpoints between consecutive distinct values of `x`. If there
        are more than `out_length` of them, they are decimated to
        `out_length` approximately evenly spaced ones; otherwise they are
        truncated to `out_length`. Unused trailing entries hold the largest
        value representable in `x`'s dtype. The dtype matches `x`.
    max_split : UInt[Array, '']
        The number of actually used entries in `splits`.
    """
    # find the sorted unique values in x, padded with the largest
    # representable value
    huge = _huge_value(x)
    u, actual_length = unique(x, size=x.size, fill_value=huge)

    # compute the midpoints between each unique value, using
    # x_i + (x_i+1 - x_i) / 2 instead of (x_i + x_i+1) / 2 to avoid overflow
    if jnp.issubdtype(x.dtype, jnp.integer):
        # The difference of two signed values may overflow the signed type,
        # but it is always exact in the unsigned type of the same width. Half
        # of it always fits back into the original type, and adding it to
        # the lower value cannot exceed the upper one. Casting back keeps the
        # result in x's dtype: otherwise signed + unsigned would be promoted
        # (e.g. int8 -> int16, and int64 + uint64 -> float under x64).
        half_gap = _ensure_unsigned(u[1:] - u[:-1]) // 2
        midpoints = u[:-1] + half_gap.astype(x.dtype)
    else:
        midpoints = u[:-1] + (u[1:] - u[:-1]) / 2

    # the midpoint between the last real value and the padding is spurious,
    # overwrite it with the padding value; if all values are distinct the
    # index is one past the end and JAX drops the out-of-bounds update
    actual_length -= 1
    if midpoints.size:
        midpoints = midpoints.at[actual_length].set(huge)

    # take a subset of the midpoints if there are more than requested;
    # indices are computed in floating point rather than int to avoid
    # potential overflow with int32, and to round to nearest instead of
    # rounding down
    indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
    indices = jnp.around(indices).astype(minimal_unsigned_dtype(midpoints.size - 1))
    decimated_midpoints = midpoints[indices]
    truncated_midpoints = midpoints[:out_length]
    splits = jnp.where(
        actual_length > out_length, decimated_midpoints, truncated_midpoints
    )
    max_split = jnp.minimum(actual_length, out_length)
    max_split = max_split.astype(minimal_unsigned_dtype(out_length))
    return splits, max_split

142 

143 

144def _huge_value(x: Array) -> int | float: 1ab

145 """ 

146 Return the maximum value that can be stored in `x`. 

147 

148 Parameters 

149 ---------- 

150 x 

151 A numerical numpy or jax array. 

152 

153 Returns 

154 ------- 

155 The maximum value allowed by `x`'s type (finite for floats). 

156 """ 

157 if jnp.issubdtype(x.dtype, jnp.integer): 1ab

158 return jnp.iinfo(x.dtype).max 1ab

159 else: 

160 return float(jnp.finfo(x.dtype).max) 1ab

161 

162 

def _ensure_unsigned(x: Integer[Array, '*shape']) -> UInt[Array, '*shape']:
    """Cast a signed integer `x` to the equal-width unsigned type; unsigned input keeps its type."""
    target_dtype = _signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)

166 

167 

168def _signed_to_unsigned(int_dtype: jnp.dtype) -> jnp.dtype: 1ab

169 """ 

170 Map a signed integer type to its unsigned counterpart. 

171 

172 Unsigned types are passed through. 

173 """ 

174 assert jnp.issubdtype(int_dtype, jnp.integer) 1ab

175 if jnp.issubdtype(int_dtype, jnp.unsignedinteger): 175 ↛ 176line 175 didn't jump to line 176 because the condition on line 175 was never true1ab

176 return int_dtype 

177 match int_dtype: 1ab

178 case jnp.int8: 178 ↛ 179line 178 didn't jump to line 179 because the pattern on line 178 never matched1ab

179 return jnp.uint8 

180 case jnp.int16: 180 ↛ 181line 180 didn't jump to line 181 because the pattern on line 180 never matched1ab

181 return jnp.uint16 

182 case jnp.int32: 182 ↛ 184line 182 didn't jump to line 184 because the pattern on line 182 always matched1ab

183 return jnp.uint32 1ab

184 case jnp.int64: 

185 return jnp.uint64 

186 case _: 

187 msg = f'unexpected integer type {int_dtype}' 

188 raise TypeError(msg) 

189 

190 

@partial(jit, static_argnums=(1,))
def uniform_splits_from_matrix(
    X: Real[Array, 'p n'], num_bins: int
) -> tuple[Real[Array, 'p m'], UInt[Array, ' p']]:
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    num_bins
        The number of bins to produce. Must be a static Python int.

    Returns
    -------
    splits : Real[Array, 'p m']
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`, so ``m == num_bins - 1``.
    max_split : UInt[Array, ' p']
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.

    Raises
    ------
    ValueError
        If `num_bins` is less than 1.
    """
    # validate explicitly: with num_bins < 1 the grid below has a negative
    # number of splits, previously caught only by a bare `assert`
    if num_bins < 1:
        msg = f'{num_bins=}, it should be at least 1.'
        raise ValueError(msg)

    low = jnp.min(X, axis=1)
    high = jnp.max(X, axis=1)
    # num_bins + 1 evenly spaced edges per row; drop the two extremes
    splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
    num_splits = num_bins - 1
    assert splits.shape == (X.shape[0], num_splits)
    # every predictor uses all of its splits
    max_split = jnp.full(
        X.shape[0], num_splits, minimal_unsigned_dtype(num_splits)
    )
    return splits, max_split

220 

221 

@partial(jit, static_argnames=('method', 'side'))
def bin_predictors(
    X: Real[Array, 'p n'], splits: Real[Array, 'p m'], **kw
) -> UInt[Array, 'p n']:
    """
    Bin the predictors according to the given splits.

    With the default ``side='left'``, a value ``x`` is mapped to bin ``i``
    iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X
        A matrix with `p` predictors and `n` observations.
    splits
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued arguments `side` and `method` are treated as static
        (compile-time) arguments. Passing ``side='right'`` changes the rule
        above to ``splits[i - 1] <= x < splits[i]``.

    Returns
    -------
    `X` but with each value replaced by the index of the bin it falls into.
    """

    @partial(autobatch, max_io_nbytes=2**29)
    @vmap
    def bin_rows(x, row_splits):
        # bin indices range over 0..row_splits.size inclusive
        dtype = minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    return bin_rows(X, splits)