Coverage for src/bartz/prepcovars.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 18:54 +0000

1# bartz/src/bartz/prepcovars.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25import functools 1a

26 

27import jax 1a

28from jax import numpy as jnp 1a

29 

30from . import jaxext 1a

31from . import grove 1a

32 

@functools.partial(jax.jit, static_argnums=(1,))
def quantilized_splits_from_matrix(X, max_bins):
    """
    Determine bins that make the distribution of each predictor uniform.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    max_bins : int
        The maximum number of bins to produce. Must be at least 1.

    Returns
    -------
    splits : array (p, m)
        A matrix containing, for each predictor, the boundaries between bins.
        `m` is ``max(min(max_bins, n) - 1, 0)``, which is an upper bound on
        the number of splits. Each predictor may have a different number of
        splits; unused values at the end of each row are filled with the
        maximum value representable in the type of `X`.
    max_split : array (p,)
        The number of actually used values in each row of `splits`.

    Raises
    ------
    ValueError
        If `max_bins` is less than 1.
    """
    # `max_bins` is a static argument of `jax.jit`, so this is a plain Python
    # check that runs once at trace time.
    if max_bins < 1:
        raise ValueError(f'max_bins must be at least 1, got {max_bins}')

    # Clamp at 0 so that an empty `X` (n == 0) cannot produce a negative
    # length: downstream, `midpoints[:out_length]` would then count from the
    # end instead of producing an empty row.
    out_length = max(min(max_bins, X.shape[1]) - 1, 0)

    # NOTE(review): `jaxext.autobatch` presumably splits the computation into
    # batches so that input + output stay under `max_io_nbytes`; confirm in
    # jaxext.
    @functools.partial(jaxext.autobatch, max_io_nbytes=2 ** 29)
    def quantilize(X):
        return _quantilized_splits_from_matrix(X, out_length)

    return quantilize(X)

62 

@functools.partial(jax.vmap, in_axes=(0, None))
def _quantilized_splits_from_matrix(x, out_length):
    """
    Compute the quantile-based splits of a single predictor.

    Vectorized with `jax.vmap` over the rows of a (p, n) matrix; `out_length`
    is shared by all rows.

    Parameters
    ----------
    x : array (n,)
        The observed values of one predictor.
    out_length : int
        The length of the returned `splits`.

    Returns
    -------
    splits : array (out_length,)
        Midpoints between consecutive distinct values of `x`. If there are
        more midpoints than `out_length`, an approximately evenly spaced
        subset is selected. Unused trailing slots hold
        ``jaxext.huge_value(x)``.
    max_split : unsigned int array ()
        The number of used slots in `splits`, in ``[0, out_length]``.
    """
    huge = jaxext.huge_value(x)
    # Distinct values of `x` (presumably sorted, which the midpoint logic
    # requires), right-padded with `huge` to length `x.size`, plus what is
    # presumably the count of distinct values.
    u, actual_length = jaxext.unique(x, size=x.size, fill_value=huge)
    # From here on `actual_length` counts the midpoints between distinct
    # values.
    actual_length -= 1

    if jnp.issubdtype(x.dtype, jnp.integer):
        # NOTE(review): the difference may wrap around when the two values
        # have opposite signs; `ensure_unsigned` presumably reinterprets it as
        # unsigned so that the halved gap is still correct.
        midpoints = u[:-1] + jaxext.ensure_unsigned(u[1:] - u[:-1]) // 2
    else:
        # Halve before adding: `(a + b) / 2` overflows to inf when `a` and `b`
        # are both large (easily reached in float16). Otherwise this is
        # equivalent, up to rounding in the subnormal range.
        midpoints = u[:-1] / 2 + u[1:] / 2

    # Only the first `actual_length` midpoints sit between two real distinct
    # values; the rest involve the padding, so reset them to `huge`.
    positions = jnp.arange(
        midpoints.size, dtype=jaxext.minimal_unsigned_dtype(midpoints.size - 1)
    )
    midpoints = jnp.where(positions < actual_length, midpoints, huge)

    # Evenly spaced positions strictly inside (-1, actual_length), used to
    # pick a subset when there are more midpoints than output slots. They are
    # computed in float rather than int to avoid potential overflow with
    # int32, and to round to nearest instead of rounding down.
    indices = jnp.linspace(-1, actual_length, out_length + 2)[1:-1]
    indices = jnp.around(indices).astype(
        jaxext.minimal_unsigned_dtype(midpoints.size - 1)
    )
    decimated_midpoints = midpoints[indices]
    truncated_midpoints = midpoints[:out_length]
    splits = jnp.where(
        actual_length > out_length, decimated_midpoints, truncated_midpoints
    )

    # Clamp at 0 as well: if `actual_length` is negative (e.g. an empty `x`,
    # should `jaxext.unique` report 0 distinct values), casting it to an
    # unsigned type would wrap around to a huge count.
    max_split = jnp.maximum(jnp.minimum(actual_length, out_length), 0)
    max_split = max_split.astype(jaxext.minimal_unsigned_dtype(out_length))
    return splits, max_split

84 

@functools.partial(jax.jit, static_argnums=(1,))
def uniform_splits_from_matrix(X, num_bins):
    """
    Make an evenly spaced binning grid.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    num_bins : int
        The number of bins to produce. Must be at least 1.

    Returns
    -------
    splits : array (p, num_bins - 1)
        A matrix containing, for each predictor, the boundaries between bins.
        The excluded endpoints are the minimum and maximum value in each row of
        `X`.
    max_split : array (p,)
        The number of cutpoints in each row of `splits`, i.e., ``num_bins - 1``.

    Raises
    ------
    ValueError
        If `num_bins` is less than 1.
    """
    # `num_bins` is a static argument of `jax.jit`, so this check runs at
    # trace time. Validate explicitly instead of relying on the shape
    # `assert` below, which is stripped under `python -O`.
    if num_bins < 1:
        raise ValueError(f'num_bins must be at least 1, got {num_bins}')

    low = jnp.min(X, axis=1)
    high = jnp.max(X, axis=1)
    # num_bins + 1 evenly spaced edges per row; drop the two outermost ones,
    # which are the row minimum and maximum.
    splits = jnp.linspace(low, high, num_bins + 1, axis=1)[:, 1:-1]
    assert splits.shape == (X.shape[0], num_bins - 1)

    # Every row uses all of its num_bins - 1 cutpoints.
    max_split = jnp.full(
        X.shape[0],
        num_bins - 1,
        dtype=jaxext.minimal_unsigned_dtype(num_bins - 1),
    )
    return splits, max_split

112 

@functools.partial(jax.jit, static_argnames=('method',))
def bin_predictors(X, splits, **kw):
    """
    Map each predictor value to the index of the bin that contains it.

    With the default ``side='left'`` of `jax.numpy.searchsorted`, a value
    ``x`` lands in bin ``i`` iff ``splits[i - 1] < x <= splits[i]``.

    Parameters
    ----------
    X : array (p, n)
        A matrix with `p` predictors and `n` observations.
    splits : array (p, m)
        For each predictor, the boundaries between bins. `m` is the maximum
        number of splits; a row may actually use fewer, with the unused
        trailing slots padded with the largest value of the dtype.
    **kw : dict
        Forwarded unchanged to `jax.numpy.searchsorted`.

    Returns
    -------
    X_binned : int array (p, n)
        Same layout as `X`, with every value replaced by its bin index.
    """

    @functools.partial(jaxext.autobatch, max_io_nbytes=2 ** 29)
    @jax.vmap
    def bin_rows(row, row_splits):
        # searchsorted returns indices in [0, row_splits.size], so the
        # unsigned type is sized to hold row_splits.size itself.
        index_dtype = jaxext.minimal_unsigned_dtype(row_splits.size)
        bin_index = jnp.searchsorted(row_splits, row, **kw)
        return bin_index.astype(index_dtype)

    return bin_rows(X, splits)