Coverage for src/bartz/jaxext/__init__.py: 92%

69 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-27 14:46 +0000

1# bartz/src/bartz/jaxext/__init__.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Additions to jax.""" 

26 

27import functools 1ab

28import math 1ab

29from collections.abc import Sequence 1ab

30 

31import jax 1ab

32from jax import numpy as jnp 1ab

33from jax import random 1ab

34from jax.lax import scan 1ab

35from jax.scipy.special import ndtr 1ab

36from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped 1ab

37 

38from bartz.jaxext._autobatch import autobatch # noqa: F401 1ab

39from bartz.jaxext.scipy.special import ndtri 1ab

40 

41 

def vmap_nodoc(fun, *args, **kw):
    """
    Vectorize `fun` with `jax.vmap` while keeping its docstring untouched.

    Use this when the docstring of `fun` is already written from the point of
    view of the vectorized function, i.e., it already describes the extra
    axes that vmap adds to the arguments.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args
    **kw
        Forwarded verbatim to `jax.vmap`.

    Returns
    -------
    The vmapped function, with ``__doc__`` set to the docstring of `fun`.
    """
    # Grab the docstring before wrapping, then put it back on the wrapper.
    original_doc = fun.__doc__
    vmapped = jax.vmap(fun, *args, **kw)
    vmapped.__doc__ = original_doc
    return vmapped

53 

54 

def minimal_unsigned_dtype(value):
    """
    Return the smallest unsigned integer dtype that can represent `value`.

    Parameters
    ----------
    value
        The value to represent. Only its upper bound is checked, against
        2**8, 2**16 and 2**32 in this order.

    Returns
    -------
    One of `jnp.uint8`, `jnp.uint16`, `jnp.uint32`, `jnp.uint64`. The last one
    is returned for any value that is not below 2**32.
    """
    # Candidates ordered from narrowest to widest; the first that fits wins.
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for bits, dtype in candidates:
        if value < 2**bits:
            return dtype
    return jnp.uint64

64 

65 

@functools.partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output. It is a static argument of the jitted
        function.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` contains more than `size` unique values, only the `size`
        smallest ones are kept.
    actual_length : int
        The number of used values in `out`, at most `size`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        # `i_out` is the output slot of the most recently seen unique value.
        i_out, last, out = carry
        # Advance to the next slot only when the sorted value changes.
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # Once `i_out` reaches `size` the output is full: drop the write
        # explicitly instead of relying on the default out-of-bounds mode.
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    # Seeding `last` with x[0] makes the first step write x[0] into slot 0.
    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan the whole sorted array: truncating it to `size` elements before
    # deduplicating would lose unique values hidden behind duplicates.
    (last_index, _, out), _ = scan(loop, carry, x)
    # `last_index + 1` counts all unique values; only `size` of them fit.
    return out, jnp.minimum(last_index + 1, size)

105 

106 

class split:
    """
    Split a key into `num` keys, to be consumed incrementally with `pop`.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    def __init__(self, key: Key[Array, ''], num: int = 2):
        # All the keys are generated upfront; `pop` hands them out in order.
        self._keys = random.split(key, num)

    def __len__(self):
        """Return the size of the array of keys that are still available."""
        return self._keys.size

    def pop(self, shape: int | tuple[int, ...] | None = None) -> Key[Array, '*']:
        """
        Remove keys from the front of the list and return them.

        Parameters
        ----------
        shape
            The shape of the array of keys to return. `None` means a single
            key (scalar shape), an integer means a 1d array with that many
            keys, a tuple is used as shape as is.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If more keys are requested than there are left in the list.

        Notes
        -----
        Keys are taken from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        # Normalize the requested shape to a tuple of dimensions.
        if isinstance(shape, tuple):
            dims = shape
        elif shape is None:
            dims = ()
        else:
            dims = (shape,)

        size_to_pop = math.prod(dims)
        if size_to_pop > self._keys.size:
            msg = f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
            raise IndexError(msg)

        # Take the leading keys and keep the rest for later calls.
        popped_keys, self._keys = self._keys[:size_to_pop], self._keys[size_to_pop:]
        return popped_keys.reshape(dims)

161 

162 

def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
) -> Float32[Array, '*']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Inverse-cdf sampling. With u ~ uniform and F = ndtr, the four cases are:
    #
    #   upper, bound < 0:      ndtri(F(bound) * u)
    #   upper, bound > 0:     -ndtri(F(-bound) + F(bound) * (1 - u))
    #   not upper, bound < 0:  ndtri(F(bound) + F(-bound) * (1 - u))
    #   not upper, bound > 0: -ndtri(F(-bound) * u)
    #
    # When bound > 0 the problem is mirrored around zero (sample with the
    # signs flipped, then negate the result), so that ndtri is always
    # evaluated through ndtr at a non-positive argument for the tail mass.
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    mirrored = bound > 0

    cdf_at_bound = ndtr(bound)
    cdf_at_neg_bound = ndtr(-bound)

    # `scale` is the probability mass of the allowed region, `shift` the mass
    # of the excluded one.
    scale = jnp.where(upper, cdf_at_bound, cdf_at_neg_bound)
    shift = jnp.where(upper, cdf_at_neg_bound, cdf_at_bound)

    u = random.uniform(key, out_shape)
    lower_tail_u = scale * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    upper_tail_u = shift + scale * u  # ~ uniform in [ndtr(∓bound), 1)

    # The lower tail is needed when the allowed region, after the optional
    # mirroring, lies to the left of the boundary.
    use_lower_tail = upper ^ mirrored
    quantile_level = jnp.where(use_lower_tail, lower_tail_u, upper_tail_u)
    sample = ndtri(quantile_level)

    # Undo the mirroring.
    return jnp.where(mirrored, -sample, sample)

213 return jnp.where(bound_pos, -truncated_norm, truncated_norm) 1ab