Coverage for src/bartz/jaxext.py: 91% (203 statements)
Report generated by coverage.py v7.8.2 on 2025-05-29 23:01 +0000
1# bartz/src/bartz/jaxext.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Additions to jax."""
27import functools 1ab
28import math 1ab
29import warnings 1ab
31import jax 1ab
32from jax import lax, random, tree_util 1ab
33from jax import numpy as jnp 1ab
34from scipy import special 1ab
def float_type(*args):
    """Return the floating point dtype jax picks for the given operands/types."""
    promoted = jnp.result_type(*args)
    # Pushing an empty array through a float-valued function lets jax apply
    # its own promotion rules without doing any actual computation.
    probe = jnp.empty(0, promoted)
    return jnp.sin(probe).dtype
def _castto(func, type):
    """Wrap `func` so that its return value is converted with ``.astype(type)``."""

    @functools.wraps(func)
    def casting_wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        return result.astype(type)

    return casting_wrapper
class scipy:
    """Mockup of the :external:py:mod:`scipy` module."""

    class special:
        """Mockup of the :external:py:mod:`scipy.special` module."""

        @staticmethod
        def gammainccinv(a, y):
            """Survival function inverse of the Gamma(a, 1) distribution."""
            a, y = jnp.asarray(a), jnp.asarray(y)
            out_dtype = float_type(a.dtype, y.dtype)
            out_shape = jnp.broadcast_shapes(a.shape, y.shape)
            # The values are computed by the scipy function through a
            # callback; jax only needs the shape and dtype of the result.
            result_spec = jax.ShapeDtypeStruct(out_shape, out_dtype)
            host_func = _castto(special.gammainccinv, out_dtype)
            return jax.pure_callback(
                host_func, result_spec, a, y, vmap_method='expand_dims'
            )

    class stats:
        """Mockup of the :external:py:mod:`scipy.stats` module."""

        class invgamma:
            """Class that represents the distribution InvGamma(a, 1)."""

            @staticmethod
            def ppf(q, a):
                """Percentile point function."""
                # Computed as the reciprocal of the Gamma(a, 1) survival
                # function inverse evaluated at q.
                gamma_isf = scipy.special.gammainccinv(a, q)
                return 1 / gamma_isf
def vmap_nodoc(fun, *args, **kw):
    """
    Apply `jax.vmap` to `fun` while keeping the docstring of `fun` as is.

    Use this when the docstring of `fun` is already written in terms of the
    vmapped arguments, i.e., it already describes the additional axes.
    """
    original_doc = fun.__doc__
    vmapped = jax.vmap(fun, *args, **kw)
    # restore the docstring of the unwrapped function on the wrapper
    vmapped.__doc__ = original_doc
    return vmapped
def huge_value(x):
    """
    Return the largest value representable by the dtype of `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        The maximum of the integer type of `x`, or +inf if `x` is not of
        integer type.
    """
    is_integer = jnp.issubdtype(x.dtype, jnp.integer)
    return jnp.iinfo(x.dtype).max if is_integer else jnp.inf
def minimal_unsigned_dtype(value):
    """Return the narrowest unsigned integer dtype able to hold `value`."""
    # candidate widths in increasing order; 64 bits is the fallback
    candidates = (
        (8, jnp.uint8),
        (16, jnp.uint16),
        (32, jnp.uint32),
    )
    for nbits, dtype in candidates:
        if value < 2**nbits:
            return dtype
    return jnp.uint64
def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart.

    Unsigned types are passed through.

    Parameters
    ----------
    int_dtype : dtype
        An integer data type.

    Returns
    -------
    uint_dtype : dtype
        `int_dtype` itself if it is unsigned, otherwise the jax unsigned
        integer type with the same number of bits.

    Raises
    ------
    TypeError
        If `int_dtype` is a signed integer type other than the 8, 16, 32 and
        64 bit ones.
    """
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        if int_dtype == signed:
            return unsigned
    # Fail loudly instead of implicitly returning None, which a caller would
    # otherwise pass on to `astype` as if it were a valid dtype.
    raise TypeError(f'no unsigned counterpart known for dtype {int_dtype}')
def ensure_unsigned(x):
    """Cast `x` to the unsigned integer dtype matching its integer dtype."""
    # unsigned dtypes map to themselves in `signed_to_unsigned`
    unsigned_dtype = signed_to_unsigned(x.dtype)
    return x.astype(unsigned_dtype)
@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # move to the next output slot only when the sorted value changes
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # writes beyond the end of `out` (more than `size` unique values)
        # are discarded
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan over the whole sorted array: looking only at the first `size`
    # elements would miss unique values whenever `x` is longer than `size`.
    (last_index, _, out), _ = jax.lax.scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes : pytree of int or None, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays. A `None` axis indicates to not batch an argument.
    out_axes : pytree of ints, default 0
        The same for outputs (but non-batching is not allowed).
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop.

    Notes
    -----
    The number of batches is the smallest divisor of the size of the batched
    axis that keeps each batch within `max_io_nbytes`. If even one item per
    batch exceeds the threshold, a warning is emitted.
    """

    def expand_axes(axes, tree):
        # broadcast a single integer axis to all the leaves of the tree
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def check_no_nones(axes, tree):
        def check_not_none(_, axis):
            assert axis is not None

        tree_util.tree_map(check_not_none, tree, axes)

    def extract_size(axes, tree):
        # common length of all the batched axes
        def get_size(x, axis):
            if axis is None:
                return None
            else:
                return x.shape[axis]

        sizes = tree_util.tree_map(get_size, tree, axes)
        # flattening discards the None entries of non-batched arguments
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize

        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor_small(dividend, min_divisor):
        # used when min_divisor <= sqrt(dividend)
        root = math.isqrt(dividend)
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # No divisor in [min_divisor, sqrt(dividend)]: the smallest valid
        # divisor is above the square root, so it is the cofactor of the
        # largest divisor below min_divisor.
        for cofactor in range(min_divisor - 1, 0, -1):
            if dividend % cofactor == 0:
                return dividend // cofactor
        return dividend

    def next_divisor_large(dividend, min_divisor):
        # used when min_divisor > sqrt(dividend): search over the cofactors,
        # which are fewer than the candidate divisors
        max_inv_divisor = dividend // min_divisor
        for inv_divisor in range(max_inv_divisor, 0, -1):
            if dividend % inv_divisor == 0:
                return dividend // inv_divisor
        return dividend

    def next_divisor(dividend, min_divisor):
        # smallest divisor of dividend that is >= min_divisor, or dividend
        # itself if min_divisor > dividend
        if dividend == 0:
            return min_divisor
        if min_divisor * min_divisor <= dividend:
            return next_divisor_small(dividend, min_divisor)
        return next_divisor_large(dividend, min_divisor)

    def pull_nonbatched(axes, tree):
        # replace non-batched arguments with None, keeping the original tree
        def pull_nonbatched(x, axis):
            if axis is None:
                return None
            else:
                return x

        return tree_util.tree_map(pull_nonbatched, tree, axes), tree

    def push_nonbatched(axes, tree, original_tree):
        # put back the non-batched arguments taken out by pull_nonbatched
        def push_nonbatched(original_x, x, axis):
            if axis is None:
                return original_x
            else:
                return x

        return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)

    def move_axes_out(axes, tree):
        # bring the batched axis of each array to the front
        def move_axis_out(x, axis):
            return jnp.moveaxis(x, axis, 0)

        return tree_util.tree_map(move_axis_out, tree, axes)

    def move_axes_in(axes, tree):
        # move the leading axis of each array back to its original position
        def move_axis_in(x, axis):
            return jnp.moveaxis(x, 0, axis)

        return tree_util.tree_map(move_axis_in, tree, axes)

    def batch(tree, nbatches):
        # split the leading axis into (nbatches, batch size)
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])

        return tree_util.tree_map(batch, tree)

    def unbatch(tree):
        # merge the two leading axes (nbatches, batch size) into one
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])

        return tree_util.tree_map(unbatch, tree)

    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype

        tree_util.tree_map(check_same, tree1, tree2)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        # determine the output structure without running the function
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        size = extract_size((in_axes, out_axes), (args, example_result))

        args, nonbatched_args = pull_nonbatched(in_axes, args)

        # choose the number of batches; the byte count excludes the
        # non-batched arguments, which were replaced with None above
        total_nbytes = sum_nbytes((args, example_result))
        min_nbatches = total_nbytes // max_io_nbytes + bool(
            total_nbytes % max_io_nbytes
        )
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # happens only when a single item per batch is still too large
            assert size == nbatches
            warnings.warn(
                f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            )

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func
class split:
    """
    Split a key into `num` keys.

    Both typed keys (`jax.random.key`) and legacy raw keys
    (`jax.random.PRNGKey`) are supported.

    Parameters
    ----------
    key : jax.dtypes.prng_key array
        The key to split.
    num : int
        The number of keys to split into.
    """

    def __init__(self, key, num=2):
        self._keys = random.split(key, num)

    def __len__(self):
        # Count along the leading axis rather than using `.size`: legacy raw
        # keys carry trailing key-data axes that are not separate keys.
        return self._keys.shape[0]

    def pop(self, shape=None):
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape : int or tuple of int, optional
            The shape of the keys to pop. If `None`, a single key is popped.
            If an integer, that many keys are popped. If a tuple, the keys are
            reshaped to that shape.

        Returns
        -------
        keys : jax.dtypes.prng_key array
            The popped keys.

        Raises
        ------
        IndexError
            If `shape` is larger than the number of keys left in the list.

        Notes
        -----
        The keys are popped from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        if shape is None:
            shape = ()
        elif not isinstance(shape, tuple):
            shape = (shape,)
        size_to_pop = math.prod(shape)
        available = len(self)
        if size_to_pop > available:
            raise IndexError(f'Cannot pop {size_to_pop} keys from {available} keys')
        popped_keys = self._keys[:size_to_pop]
        self._keys = self._keys[size_to_pop:]
        # keep the trailing key-data axes (empty for typed keys) untouched
        return popped_keys.reshape(shape + self._keys.shape[1:])