Coverage for src/bartz/jaxext.py: 90%

203 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-05 18:54 +0000

1# bartz/src/bartz/jaxext.py 

2# 

3# Copyright (c) 2024, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25import functools 1a

26import math 1a

27import warnings 1a

28 

29from scipy import special 1a

30import jax 1a

31from jax import numpy as jnp 1a

32from jax import tree_util 1a

33from jax import lax 1a

34 

def float_type(*args):
    """
    Determine the jax floating point result type given operands/types.

    The arguments are promoted with `jnp.result_type`, and the result is the
    dtype that a floating point operation (`jnp.sin`) would produce on an
    array of that promoted type.
    """
    promoted = jnp.result_type(*args)
    # an empty probe array is enough to let jax pick the floating dtype
    probe = jnp.empty(0, promoted)
    return jnp.sin(probe).dtype

41 

def castto(func, type):
    """
    Wrap `func` so that its return value is converted with ``.astype(type)``.

    Parameters
    ----------
    func : callable
        A function returning an object with an ``astype`` method.
    type : dtype-like
        The type passed to ``astype`` on every result.

    Returns
    -------
    newfunc : callable
        A function with the metadata of `func` (via `functools.wraps`) that
        forwards all arguments to `func` and casts its result.
    """
    def cast_result(*args, **kw):
        raw = func(*args, **kw)
        return raw.astype(type)
    return functools.wraps(func)(cast_result)

47 

def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs,
    see `<https://github.com/google/jax/issues/17187>`_.

    Parameters
    ----------
    callback : callable
        A function that broadcasts its arguments like a numpy ufunc (e.g. a
        scipy ufunc), returning an array of type `dtype`.
    dtype : dtype-like
        The dtype of the result of `callback`.
    *args : arrays
        The arguments to `callback`. Those not listed in `excluded` are
        broadcast together to determine the shape of the result.
    excluded : sequence of int, optional
        Positions in `args` of arguments that do not take part in
        broadcasting; they are passed to `jax.pure_callback` unchanged.
    **kwargs :
        Passed on to `jax.pure_callback`. May include ``vmap_method`` (or
        ``vectorized`` on older jax) to override the default batching mode.

    Returns
    -------
    result : array
        The output of `callback`, with the broadcast shape of the
        non-excluded arguments and dtype `dtype`.
    """
    import inspect

    if excluded is None:
        excluded = ()
    shape = jnp.broadcast_shapes(*(
        a.shape
        for i, a in enumerate(args)
        if i not in excluded
    ))
    ndim = len(shape)
    # left-pad the broadcast arguments with singleton axes so that they all
    # have the same rank as the broadcast shape
    padded_args = [
        a if i in excluded
        else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
        for i, a in enumerate(args)
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)

    # `vectorized=True` is deprecated in recent jax in favor of
    # `vmap_method='legacy_vectorized'`, its documented equivalent. Use the new
    # argument when the installed jax has it, the old one otherwise, and let
    # the caller override the choice through `kwargs`.
    if 'vmap_method' in inspect.signature(jax.pure_callback).parameters:
        kwargs.setdefault('vmap_method', 'legacy_vectorized')
    else:
        kwargs.setdefault('vectorized', True)

    # TODO when jax solves this, check version and piggyback on original if new
    return jax.pure_callback(callback, result, *padded_args, **kwargs)

68 

class scipy:
    """
    Namespace mimicking a small subset of `scipy` that accepts jax arrays,
    implemented by wrapping the scipy functions in jax callbacks.
    """

    class special:
        """Jax-compatible wrappers of `scipy.special` functions."""

        @functools.wraps(special.gammainccinv)
        def gammainccinv(a, y):
            a_arr = jnp.asarray(a)
            y_arr = jnp.asarray(y)
            # evaluate in the floating type the two inputs promote to
            out_dtype = float_type(a_arr.dtype, y_arr.dtype)
            cast_ufunc = castto(special.gammainccinv, out_dtype)
            return pure_callback_ufunc(cast_ufunc, out_dtype, a_arr, y_arr)

    class stats:
        """Jax-compatible subset of `scipy.stats`."""

        class invgamma:
            """Inverse gamma distribution with unit scale."""

            def ppf(q, a):
                # If X ~ InvGamma(a), then 1/X ~ Gamma(a), so
                # P(X <= x) = Q(a, 1/x) with Q the regularized upper incomplete
                # gamma function; solving for x gives 1 / Q^{-1}(a, q).
                gamma_quantile = scipy.special.gammainccinv(a, q)
                return 1 / gamma_quantile

87 

@functools.wraps(jax.vmap)
def vmap_nodoc(fun, *args, **kw):
    """
    Version of `jax.vmap` that preserves the docstring of the input function.

    `jax.vmap` replaces the docstring of the function it returns; this
    wrapper restores the original ``fun.__doc__`` on the vectorized function.
    """
    original_doc = fun.__doc__
    vectorized = jax.vmap(fun, *args, **kw)
    vectorized.__doc__ = original_doc
    return vectorized

97 

def huge_value(x):
    """
    Return the maximum value that can be stored in `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        ``jnp.iinfo(x.dtype).max`` if `x` has an integer dtype, otherwise
        ``jnp.inf``.
    """
    dtype = x.dtype
    return jnp.iinfo(dtype).max if jnp.issubdtype(dtype, jnp.integer) else jnp.inf

116 

def minimal_unsigned_dtype(max_value):
    """
    Return the smallest unsigned integer dtype that can represent a given
    maximum value (inclusive).

    Parameters
    ----------
    max_value : int
        The largest value that must be representable.

    Returns
    -------
    dtype : jax scalar type
        The first of `jnp.uint8`, `jnp.uint16`, `jnp.uint32` whose range
        includes `max_value`, or `jnp.uint64` if none does.
    """
    for nbits, candidate in ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32)):
        if max_value < 2 ** nbits:
            return candidate
    return jnp.uint64

129 

def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart. Unsigned types are
    passed through.

    Parameters
    ----------
    int_dtype : dtype-like
        An integer type: a jax/numpy scalar type (e.g. `jnp.int32`), a numpy
        dtype object, or a dtype name (e.g. ``'int32'``).

    Returns
    -------
    unsigned_dtype : dtype-like
        `int_dtype` itself if it is already unsigned, otherwise the jax
        unsigned scalar type with the same number of bits (e.g. `jnp.uint32`).

    Raises
    ------
    TypeError
        If `int_dtype` is a signed integer type without a known unsigned
        counterpart. Previously this case silently returned `None`.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    # Normalize to a numpy dtype so that every accepted spelling (including
    # strings like 'int32', which do not compare equal to jnp.int32) matches.
    dtype = jnp.dtype(int_dtype)
    if dtype == jnp.int8:
        return jnp.uint8
    if dtype == jnp.int16:
        return jnp.uint16
    if dtype == jnp.int32:
        return jnp.uint32
    if dtype == jnp.int64:
        return jnp.uint64
    raise TypeError(f'no unsigned counterpart known for integer type {int_dtype!r}')

146 

def ensure_unsigned(x):
    """
    Cast an integer array to the unsigned integer type of the same width.

    `x` must have an integer dtype. If it is already unsigned, it is cast to
    its own dtype.
    """
    unsigned_dtype = signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)

152 

@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out` (at most `size`).
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # move to the next output slot whenever the sorted value changes
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # writes past the end of `out` (more than `size` unique values) are
        # discarded
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan the whole sorted array, not just its first `size` elements: with
    # duplicates, the first `size` unique values can lie further along.
    (last_index, _, out), _ = jax.lax.scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)

188 

def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes : pytree of int or None, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays. A `None` axis indicates to not batch an argument.
    out_axes : pytree of ints, default 0
        The same for outputs (but non-batching is not allowed).
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.

    Notes
    -----
    The number of batches is the smallest divisor of the batched axis length
    that is at least ``ceil(total_nbytes / max_io_nbytes)``, capped at the
    axis length. If even one element per batch exceeds `max_io_nbytes`, a
    warning is emitted.
    """

    def expand_axes(axes, tree):
        # a single int applies to every leaf; otherwise `axes` mirrors `tree`
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def check_no_nones(axes, tree):
        def check_not_none(_, axis):
            assert axis is not None
        tree_util.tree_map(check_not_none, tree, axes)

    def extract_size(axes, tree):
        # length of the batched axis, which all batched leaves must share
        def get_size(x, axis):
            if axis is None:
                return None
            return x.shape[axis]
        sizes = tree_util.tree_map(get_size, tree, axes)
        # flattening drops the `None`s of the unbatched leaves
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize
        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor_small(dividend, min_divisor):
        # Smallest divisor of `dividend` that is >= `min_divisor`, assuming
        # min_divisor ** 2 <= dividend.
        root = math.isqrt(dividend)
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # No divisor in [min_divisor, root]: the answer lies above the square
        # root, i.e., it is dividend // c for the largest divisor c <= root.
        # Since min_divisor <= root, that cofactor is >= min_divisor.
        for cofactor in range(root, 0, -1):
            if dividend % cofactor == 0:
                return dividend // cofactor
        return dividend  # not reached: 1 divides everything

    def next_divisor_large(dividend, min_divisor):
        # Smallest divisor >= min_divisor, found as dividend // inv_divisor
        # with the largest inv_divisor <= dividend // min_divisor; returns
        # `dividend` if min_divisor > dividend.
        max_inv_divisor = dividend // min_divisor
        for inv_divisor in range(max_inv_divisor, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def next_divisor(dividend, min_divisor):
        if dividend == 0:
            return min_divisor
        # pick the search that needs fewer iterations
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)

    def pull_nonbatched(axes, tree):
        # replace unbatched arguments with `None`, and keep the original tree
        def pull(x, axis):
            if axis is None:
                return None
            return x
        return tree_util.tree_map(pull, tree, axes), tree

    def push_nonbatched(axes, tree, original_tree):
        # put the unbatched arguments back in place of the `None`s
        def push(original_x, x, axis):
            if axis is None:
                return original_x
            return x
        return tree_util.tree_map(push, original_tree, tree, axes)

    def move_axes_out(axes, tree):
        def move_axis_out(x, axis):
            return jnp.moveaxis(x, axis, 0)
        return tree_util.tree_map(move_axis_out, tree, axes)

    def move_axes_in(axes, tree):
        def move_axis_in(x, axis):
            return jnp.moveaxis(x, 0, axis)
        return tree_util.tree_map(move_axis_in, tree, axes)

    def batch(tree, nbatches):
        # split the leading axis into (nbatches, batch size)
        def split_leading(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
        return tree_util.tree_map(split_leading, tree)

    def unbatch(tree):
        # merge the (nbatches, batch size) leading axes back together
        def merge_leading(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return tree_util.tree_map(merge_leading, tree)

    def check_same(tree1, tree2):
        def check_leaf(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype
        tree_util.tree_map(check_leaf, tree1, tree2)

    # saved because `batched_func` rebinds the names to the expanded trees
    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        size = extract_size((in_axes, out_axes), (args, example_result))

        args, nonbatched_args = pull_nonbatched(in_axes, args)

        total_nbytes = sum_nbytes((args, example_result))
        # ceil(total_nbytes / max_io_nbytes), at least 1
        min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func

354 

@tree_util.register_pytree_node_class
class LeafDict(dict):
    """
    Dictionary that is opaque to jax pytrees, to store compile-time values.

    When flattened it contributes no children: the whole dictionary is kept
    as the node's auxiliary data, and unflattening returns that same object.
    """

    def tree_flatten(self):
        children = ()
        aux_data = self
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # the dictionary was stored whole as auxiliary data; there are no
        # children to reinsert
        return aux_data

    def __repr__(self):
        inner = super().__repr__()
        return f'{__class__.__name__}({inner})'