Coverage for src/bartz/prepcovars.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1# bartz/src/bartz/prepcovars.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Functions to preprocess data.""" 

26 

27import functools 1ab

28 

29import jax 1ab

30from jax import numpy as jnp 1ab

31 

32from . import jaxext 1ab

33 

34 

@functools.partial(jax.jit, static_argnums=(1,))
def quantilized_splits_from_matrix(X, max_bins):
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    max_bins : int
        The maximum number of bins to produce.

    Returns
    -------
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``min(max_bins, n) - 1``, which is an upper bound on the number
        of splits. Each predictor may have a different number of splits; unused
        values at the end of each row are filled with the maximum value
        representable in the type of `X`.
    max_split : array (p,)
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `max_bins` is less than 1 or `X` has no observations, since then
        the number of splits ``min(max_bins, n) - 1`` would be negative.
    """
    # both checks read only static values (max_bins is a static argument,
    # shapes are static under jit), so they raise at trace time
    if max_bins < 1:
        raise ValueError(f'max_bins must be at least 1, got {max_bins}')
    if X.shape[1] < 1:
        raise ValueError('X must contain at least one observation')

    out_length = min(max_bins, X.shape[1]) - 1

    @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
    def quantilize(X):
        # wrap this function because autobatch needs traceable args; the
        # static out_length is captured by closure instead of passed in
        return _quantilized_splits_from_matrix(X, out_length)

    return quantilize(X)

66 

67 

@functools.partial(jax.vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(x, out_length):
    """
    Compute quantile-style splits for a single predictor.

    Mapped over the rows of the input matrix (``in_axes=(0, None)``), so `x`
    is one row while `out_length` is shared by all rows.

    The candidate splits are the midpoints between consecutive distinct
    values of `x`. If there are more candidates than `out_length`, an evenly
    spaced subset of them is kept; otherwise all of them are kept and the
    remaining slots hold the padding value from ``jaxext.huge_value(x)``.

    Parameters
    ----------
    x : array (n,)
        Values of one predictor.
    out_length : int
        Number of split slots to return.

    Returns
    -------
    splits : array (out_length,)
        The selected midpoints, padded at the end.
    max_split : array ()
        Number of meaningful entries in `splits`, as the smallest unsigned
        dtype that can hold `out_length`.
    """
    fill = jaxext.huge_value(x)
    uniq, n_unique = jaxext.unique(x, size=x.size, fill_value=fill)
    # one midpoint between each pair of consecutive distinct values
    n_mid = n_unique - 1

    lower, upper = uniq[:-1], uniq[1:]
    if jnp.issubdtype(x.dtype, jnp.integer):
        # halve the difference in unsigned arithmetic so that the gap between
        # a real value and the padding cannot overflow the signed type
        mids = lower + jaxext.ensure_unsigned(upper - lower) // 2
        slot = jnp.arange(
            mids.size, dtype=jaxext.minimal_unsigned_dtype(mids.size - 1)
        )
        # overwrite the meaningless midpoints computed against padding
        mids = jnp.where(slot < n_mid, mids, fill)
    else:
        # NOTE(review): no explicit masking here; this assumes the float
        # padding value propagates through the mean (e.g. inf) — confirm
        # against jaxext.huge_value
        mids = (upper + lower) / 2

    # indices calculation with float rather than int to avoid potential
    # overflow with int32, and to round to nearest instead of rounding down;
    # the endpoints -1 and n_mid are excluded, leaving out_length positions
    pick = jnp.linspace(-1, n_mid, out_length + 2)[1:-1]
    pick = jnp.around(pick).astype(jaxext.minimal_unsigned_dtype(mids.size - 1))

    subsampled = mids[pick]
    leading = mids[:out_length]
    splits = jnp.where(n_mid > out_length, subsampled, leading)

    count = jnp.minimum(n_mid, out_length)
    max_split = count.astype(jaxext.minimal_unsigned_dtype(out_length))
    return splits, max_split

95 

96 

@functools.partial(jax.jit, static_argnums=(1,))
def uniform_splits_from_matrix(X, num_bins):
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    num_bins : int
        The number of bins to produce.

    Returns
    -------
    splits : array (p, num_bins - 1)
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : array (p,)
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.

    Raises
    ------
    ValueError
        If `num_bins` is less than 1.
    """
    # num_bins is static, so this raises at trace time; previously the only
    # guard was the shape assert below, which is stripped under `python -O`
    if num_bins < 1:
        raise ValueError(f'num_bins must be at least 1, got {num_bins}')

    low = jnp.min(X, axis=1)
    high = jnp.max(X, axis=1)
    # num_bins + 1 evenly spaced edges per row; drop the outer two (min, max)
    splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # every row uses all of its cutpoints
    max_split = jnp.full(
        X.shape[0], num_bins - 1, jaxext.minimal_unsigned_dtype(num_bins - 1)
    )
    return splits, max_split

124 

125 

@functools.partial(jax.jit, static_argnames=('method', 'side'))
def bin_predictors(X, splits, **kw):
    """
    Bin the predictors according to the given splits.

    A value ``x`` is mapped to bin ``i`` iff ``splits[i - 1] < x <= splits[i]``
    (this holds for the default ``side='left'`` of `jax.numpy.searchsorted`).

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is the maximum number of splits; each row may have shorter
        actual length, marked by padding unused locations at the end of the
        row with the maximum value allowed by the type.
    **kw : dict
        Additional arguments are passed to `jax.numpy.searchsorted`. The
        string-valued options `side` and `method` are treated as static
        arguments of the jitted function.

    Returns
    -------
    X_binned : int array (p, n)
        A matrix with `p` predictors and `n` observations, where each predictor
        has been replaced by the index of the bin it falls into.
    """

    @functools.partial(jaxext.autobatch, max_io_nbytes=2**29)
    @jax.vmap
    def bin_rows(x, row_splits):
        # searchsorted returns values in [0, row_splits.size], so the
        # smallest unsigned dtype able to hold row_splits.size suffices
        dtype = jaxext.minimal_unsigned_dtype(row_splits.size)
        return jnp.searchsorted(row_splits, x, **kw).astype(dtype)

    return bin_rows(X, splits)