Coverage for src/bartz/jaxext.py: 91%

203 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-05-29 23:01 +0000

1# bartz/src/bartz/jaxext.py 

2# 

3# Copyright (c) 2024-2025, Giacomo Petrillo 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Additions to jax.""" 

26 

27import functools 1ab

28import math 1ab

29import warnings 1ab

30 

31import jax 1ab

32from jax import lax, random, tree_util 1ab

33from jax import numpy as jnp 1ab

34from scipy import special 1ab

35 

36 

def float_type(*args):
    """Return the jax floating point dtype resulting from the given operands/types."""
    promoted = jnp.result_type(*args)
    # Pass an empty array of the promoted type through a float-valued
    # function and read the dtype off the result.
    probe = jnp.empty(0, promoted)
    return jnp.sin(probe).dtype

41 

42 

43def _castto(func, type): 1ab

44 @functools.wraps(func) 1ab

45 def newfunc(*args, **kw): 1ab

46 return func(*args, **kw).astype(type) 1ab

47 

48 return newfunc 1ab

49 

50 

class scipy:
    """Mockup of the :external:py:mod:`scipy` module."""

    class special:
        """Mockup of the :external:py:mod:`scipy.special` module."""

        @staticmethod
        def gammainccinv(a, y):
            """Survival function inverse of the Gamma(a, 1) distribution."""
            a, y = jnp.asarray(a), jnp.asarray(y)
            # The callback output has the broadcast shape of the operands and
            # the floating point type derived from their dtypes.
            out_dtype = float_type(a.dtype, y.dtype)
            out_spec = jax.ShapeDtypeStruct(
                jnp.broadcast_shapes(a.shape, y.shape), out_dtype
            )
            # Inside a method body `special` resolves to the module-level
            # import of the real scipy.special, not to this mock class.
            host_func = _castto(special.gammainccinv, out_dtype)
            return jax.pure_callback(
                host_func, out_spec, a, y, vmap_method='expand_dims'
            )

    class stats:
        """Mockup of the :external:py:mod:`scipy.stats` module."""

        class invgamma:
            """Class that represents the distribution InvGamma(a, 1)."""

            @staticmethod
            def ppf(q, a):
                """Percentile point function."""
                gamma_isf = scipy.special.gammainccinv(a, q)
                return 1 / gamma_isf

78 

79 

def vmap_nodoc(fun, *args, **kw):
    """
    Apply `jax.vmap` to `fun`, keeping the docstring of `fun` as it is.

    Handy when the docstring of `fun` already describes the arguments with
    the additional axes introduced by vmap.
    """
    vmapped = jax.vmap(fun, *args, **kw)
    vmapped.__doc__ = fun.__doc__
    return vmapped

91 

92 

def huge_value(x):
    """
    Return the largest value representable by the type of `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        The maximum of `x`'s dtype if it is an integer type, +inf otherwise.
    """
    dtype = x.dtype
    if not jnp.issubdtype(dtype, jnp.integer):
        return jnp.inf
    return jnp.iinfo(dtype).max

111 

112 

def minimal_unsigned_dtype(value):
    """Return the narrowest unsigned integer dtype able to hold `value`."""
    # Candidates in increasing width; the widest type is the fallback.
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for nbits, dtype in candidates:
        if value < 2**nbits:
            return dtype
    return jnp.uint64

122 

123 

def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart.

    Unsigned types are passed through.

    Parameters
    ----------
    int_dtype : dtype
        An integer data type.

    Returns
    -------
    unsigned_dtype : dtype
        `int_dtype` itself if it is already unsigned, otherwise the jax
        unsigned integer type paired with it (int8 -> uint8, ...,
        int64 -> uint64).

    Raises
    ------
    TypeError
        If `int_dtype` is a signed integer type that is not one of int8,
        int16, int32, int64.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        if int_dtype == signed:
            return unsigned
    # Fail loudly instead of implicitly returning None, which a caller would
    # otherwise forward to `astype` as if it were a valid dtype.
    raise TypeError(f'no unsigned counterpart known for integer type {int_dtype!r}')

141 

142 

def ensure_unsigned(x):
    """Cast `x` to the dtype that `signed_to_unsigned` returns for its current dtype."""
    target_dtype = signed_to_unsigned(x.dtype)
    return x.astype(target_dtype)

146 

147 

@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output position only when a new value is encountered
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # positions past the end of `out` are skipped explicitly
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan the whole sorted array: truncating it to the first `size` elements
    # would lose unique values whenever duplicates fill up that prefix.
    (last_index, _, out), _ = jax.lax.scan(loop, carry, x)
    # `last_index + 1` counts all unique values, which may exceed `size`
    return out, jnp.minimum(last_index + 1, size)

185 

186 

def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes : pytree of int or None, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays. A `None` axis indicates to not batch an argument.
    out_axes : pytree of ints, default 0
        The same for outputs (but non-batching is not allowed).
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.

    Notes
    -----
    The number of batches is the smallest divisor of the size of the batched
    axis that is not less than ``ceil(total_nbytes / max_io_nbytes)``. If
    even one batch per element exceeds `max_io_nbytes`, a warning is emitted.
    """

    def expand_axes(axes, tree):
        # broadcast a single integer axis to all the leaves of `tree`
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def check_no_nones(axes, tree):
        def check_not_none(_, axis):
            assert axis is not None

        tree_util.tree_map(check_not_none, tree, axes)

    def extract_size(axes, tree):
        def get_size(x, axis):
            if axis is None:
                return None
            else:
                return x.shape[axis]

        sizes = tree_util.tree_map(get_size, tree, axes)
        # flattening discards the None entries of non-batched arguments
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize

        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor_small(dividend, min_divisor):
        # look for a divisor in [min_divisor, sqrt(dividend)]
        root = math.isqrt(dividend)
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # None found: the next divisor is above sqrt(dividend), so search
        # among the cofactors of the divisors below the square root instead
        # of jumping straight to `dividend`.
        return next_divisor_large(dividend, root + 1)

    def next_divisor_large(dividend, min_divisor):
        # iterate over the cofactor, from the largest admissible one down, so
        # that the first hit gives the smallest divisor >= min_divisor
        max_inv_divisor = dividend // min_divisor
        for inv_divisor in range(max_inv_divisor, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def next_divisor(dividend, min_divisor):
        if dividend == 0:
            return min_divisor
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)

    def pull_nonbatched(axes, tree):
        def pull_nonbatched(x, axis):
            if axis is None:
                return None
            else:
                return x

        return tree_util.tree_map(pull_nonbatched, tree, axes), tree

    def push_nonbatched(axes, tree, original_tree):
        def push_nonbatched(original_x, x, axis):
            if axis is None:
                return original_x
            else:
                return x

        return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)

    def move_axes_out(axes, tree):
        def move_axis_out(x, axis):
            return jnp.moveaxis(x, axis, 0)

        return tree_util.tree_map(move_axis_out, tree, axes)

    def move_axes_in(axes, tree):
        def move_axis_in(x, axis):
            return jnp.moveaxis(x, 0, axis)

        return tree_util.tree_map(move_axis_in, tree, axes)

    def batch(tree, nbatches):
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])

        return tree_util.tree_map(batch, tree)

    def unbatch(tree):
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

        return tree_util.tree_map(unbatch, tree)

    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype

        tree_util.tree_map(check_same, tree1, tree2)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        # shapes and dtypes of the output, obtained without running `func`
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        size = extract_size((in_axes, out_axes), (args, example_result))

        # replace non-batched arguments with None, keeping the originals aside
        args, nonbatched_args = pull_nonbatched(in_axes, args)

        total_nbytes = sum_nbytes((args, example_result))
        # ceil division
        min_nbatches = total_nbytes // max_io_nbytes + bool(
            total_nbytes % max_io_nbytes
        )
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            assert size == nbatches
            warnings.warn(
                f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            )

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func

366 

367 

class split:
    """
    Split a key into `num` keys and hand them out on request.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        The key to split.
    num : int
        The number of keys to split into.
    """

    def __init__(self, key, num=2):
        self._keys = random.split(key, num)

    def __len__(self):
        remaining = self._keys
        return remaining.size

    def pop(self, shape=None):
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape : int or tuple of int, optional
            The shape of the keys to pop. If `None`, a single key is popped.
            If an integer, that many keys are popped. If a tuple, the keys are
            reshaped to that shape.

        Returns
        -------
        keys : jax.dtypes.prng_key array
            The popped keys.

        Raises
        ------
        IndexError
            If `shape` is larger than the number of keys left in the list.

        Notes
        -----
        The keys are popped from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        # normalize `shape` to a tuple; the empty tuple stands for one key
        if shape is None:
            dims = ()
        else:
            dims = shape if isinstance(shape, tuple) else (shape,)
        size_to_pop = math.prod(dims)
        if self._keys.size < size_to_pop:
            raise IndexError(
                f'Cannot pop {size_to_pop} keys from {self._keys.size} keys'
            )
        taken, self._keys = self._keys[:size_to_pop], self._keys[size_to_pop:]
        return taken.reshape(dims)