JAX extensions
bartz.jaxext
Additions to JAX.
- bartz.jaxext.vmap_nodoc(fun, *args, **kw)
Acts like jax.vmap but preserves the docstring of the function unchanged. This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
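Because jax.vmap does not pass the docstring through unchanged, a wrapper like this can be sketched by vectorizing first and then restoring the original docstring. vmap_nodoc_sketch and row_norm2 are hypothetical names for illustration, not bartz's code.

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical re-implementation for illustration, not bartz's code."""
    vfun = jax.vmap(fun, *args, **kw)  # vectorized function
    vfun.__doc__ = fun.__doc__  # restore the original docstring verbatim
    return vfun


def row_norm2(x):
    """Squared norm of each row of x, shape (n, d) -> (n,)."""
    # Written for a single row; vmap adds the leading batch axis.
    return jnp.sum(x**2)


batched = vmap_nodoc_sketch(row_norm2)
print(batched(jnp.ones((3, 2))))  # [2. 2. 2.]
print(batched.__doc__ == row_norm2.__doc__)  # True
```

Here the docstring of row_norm2 already describes the batched shapes, which is the situation the original text describes.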
- bartz.jaxext.minimal_unsigned_dtype(value)
Return the smallest unsigned integer dtype that can represent value.
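A minimal sketch of the same idea, assuming a non-negative integer value: try the unsigned dtypes from narrowest to widest and return the first whose maximum is at least value. minimal_unsigned_dtype_sketch is a hypothetical name, not bartz's implementation.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical sketch: smallest unsigned dtype whose range includes value."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    # Note: JAX only keeps uint64 arrays when 64-bit mode (jax_enable_x64) is on.
    raise OverflowError(f'{value} does not fit in 64 bits')


print(minimal_unsigned_dtype_sketch(255))  # uint8
print(minimal_unsigned_dtype_sketch(256))  # uint16
```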
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
x (Shaped[Array, '_']) – The input array.
size (int) – The length of the output.
fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.
actual_length (int) – The number of used values in out.
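The input/output contract above can be illustrated with a sort-based JAX sketch: sort x, flag the first occurrence of each distinct value, and scatter those values into a fixed-size buffer pre-filled with fill_value. This is an assumption for illustration (the hypothetical name is unique_sketch), not necessarily how bartz saves memory.

```python
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Hypothetical sort-based sketch of the documented contract."""
    x = jnp.sort(x)
    # True at the first occurrence of each distinct value in the sorted array.
    is_new = jnp.ones(x.shape, dtype=bool).at[1:].set(x[1:] != x[:-1])
    # Output slot of each first occurrence: rank of its distinct value.
    slot = jnp.cumsum(is_new) - 1
    # Repeats get the out-of-bounds index `size`; mode='drop' discards them,
    # as well as any distinct values beyond the first `size`.
    slot = jnp.where(is_new, slot, size)
    out = jnp.full(size, fill_value, x.dtype).at[slot].set(x, mode='drop')
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length


out, n = unique_sketch(jnp.array([3, 1, 3, 2, 1]), 5, -1)
# out holds [1, 2, 3, -1, -1] and n is 3
```

Since size fixes the output shape, a function like this can be traced under jax.jit with size as a static argument.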
- class bartz.jaxext.split(key, num=2)
Split a key into num keys.
- Parameters:
key (Key[Array, '']) – The key to split.
num (int, default: 2) – The number of keys to split into.
- pop(shape=None)
Pop one or more keys from the list.
- Parameters:
shape (int | tuple[int, ...] | None, default: None) – The shape of the keys to pop. If None, a single key is popped. If an integer, that many keys are popped. If a tuple, as many keys as the product of its entries are popped and reshaped to that shape.
- Returns:
Key[Array, '*'] – The popped keys as a JAX array with the requested shape.
- Raises:
IndexError – If shape requests more keys than are left in the list.
Notes
The keys are popped from the beginning of the list, so, for example, list(keys.pop(2)) is equivalent to [keys.pop(), keys.pop()].
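The key-list behavior described above can be sketched as a small class that splits all keys up front and slices them off the front on each pop. KeyStack is a hypothetical stand-in, not bartz's class, and it assumes new-style typed keys created with jax.random.key.

```python
import math

import jax


class KeyStack:
    """Hypothetical stand-in for bartz.jaxext.split, for illustration only."""

    def __init__(self, key, num=2):
        # One call to jax.random.split produces all the keys up front.
        self._keys = jax.random.split(key, num)

    def __len__(self):
        return self._keys.shape[0]

    def pop(self, shape=None):
        # Work out how many keys to take and what shape to give them.
        if shape is None:
            count = 1
        elif isinstance(shape, int):
            count, shape = shape, (shape,)
        else:
            shape = tuple(shape)
            count = math.prod(shape)
        if count > len(self):
            raise IndexError(f'requested {count} keys, only {len(self)} left')
        # Take keys from the front of the list.
        popped, self._keys = self._keys[:count], self._keys[count:]
        return popped[0] if shape is None else popped.reshape(shape)


keys = KeyStack(jax.random.key(0), 5)
k = keys.pop()  # a single key, shape ()
grid = keys.pop((2, 2))  # the remaining four keys, reshaped to (2, 2)
```

Taking keys from the front is what makes pop(2) agree with two consecutive pop() calls.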
- bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound)
Sample from a one-sided truncated standard normal distribution.
- Parameters:
key (Key[Array, '']) – JAX random key.
shape (Sequence[int]) – Shape of the output array, broadcast with the other inputs.
upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).
bound (Float32[Array, '*']) – The truncation boundary.
- Returns:
Float32[Array, '*'] – Array of samples from the truncated normal distribution.
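One straightforward way to meet this contract is inverse-CDF sampling. The sketch below (hypothetical name truncated_normal_onesided_sketch; not necessarily the algorithm bartz uses) reduces both tails to sampling from (-∞, c] by symmetry. This naive form loses accuracy far in the tails, where ndtr underflows or rounds to 1 in float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_onesided_sketch(key, shape, upper, bound):
    """Hypothetical inverse-CDF sampler, not necessarily bartz's algorithm."""
    upper = jnp.asarray(upper)
    bound = jnp.asarray(bound, jnp.float32)
    shape = jnp.broadcast_shapes(tuple(shape), upper.shape, bound.shape)
    # Reduce both cases to sampling from (-inf, c]: for the lower tail
    # [bound, inf) sample from (-inf, -bound] and flip the sign.
    c = jnp.where(upper, bound, -bound)
    # 1 - u lies in (0, 1], which keeps ndtri away from ndtri(0) = -inf.
    u = 1.0 - jax.random.uniform(key, shape, jnp.float32)
    # Inverse CDF of N(0, 1) restricted to (-inf, c].
    x = ndtri(u * ndtr(c))
    return jnp.where(upper, x, -x)


key = jax.random.key(0)
below = truncated_normal_onesided_sketch(key, (2000,), True, 0.5)  # x <= 0.5
above = truncated_normal_onesided_sketch(key, (2000,), False, 1.0)  # x >= 1.0
```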
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
Batch a function such that the inputs and outputs of each batch together occupy at most max_io_nbytes bytes.
- Parameters:
func (Callable) – A jittable function with positional arguments only, whose inputs and outputs are pytrees of arrays.
max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).
in_axes (PyTree[int | None], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axis each array should be batched. A None axis indicates not to batch an argument.
out_axes (PyTree[int], default: 0) – The same for the outputs (but None, i.e. no batching, is not allowed).
return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.
- Returns:
Callable – A function with the same signature as func, except that, if return_nbatches is True, it also returns the number of batches as a second output.
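A much-simplified sketch of the batching idea, assuming a single array argument batched along axis 0 for both input and output (no in_axes/out_axes pytrees) and a func that acts row by row, since zero-padding rows are fed through it. autobatch_sketch is a hypothetical name, and the real function's strategy may differ.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes):
    """Hypothetical, simplified autobatch: one array argument, batched along
    axis 0 in both input and output; returns (output, nbatches)."""

    def batched(x):
        n = x.shape[0]
        out = jax.eval_shape(func, x)  # output shape/dtype, no computation
        # Bytes contributed by one row of input plus one row of output.
        row_nbytes = (math.prod(x.shape[1:]) * x.dtype.itemsize
                      + math.prod(out.shape[1:]) * jnp.dtype(out.dtype).itemsize)
        batch_size = max(1, min(n, max_io_nbytes // row_nbytes))
        nbatches = -(-n // batch_size)  # ceiling division
        # Zero-pad to a whole number of batches, then stack the batches.
        pad = nbatches * batch_size - n
        xs = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        xs = xs.reshape(nbatches, batch_size, *x.shape[1:])
        # lax.map applies func to one batch at a time, sequentially.
        ys = jax.lax.map(func, xs)
        # Undo the stacking and drop the padded rows.
        y = ys.reshape(nbatches * batch_size, *ys.shape[2:])[:n]
        return y, nbatches

    return batched


f = lambda x: jnp.sum(x**2, axis=1)  # row-wise: (rows, 4) -> (rows,)
x = jnp.arange(40, dtype=jnp.float32).reshape(10, 4)
y, nbatches = autobatch_sketch(f, 60)(x)  # 20 bytes per row -> 3 rows per batch
```

Choosing the batch size from the per-row byte count keeps every batch at or under the threshold, unless a single row already exceeds it.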
bartz.jaxext.scipy.special
Mockup of the scipy.special module.
- bartz.jaxext.scipy.special.gammainccinv(a, y)
Inverse of the survival function of the Gamma(a, 1) distribution.
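For reference, the definitions involved, assuming the usual convention of scipy.special.gammaincc (the regularized upper incomplete gamma function Q, which is the survival function of Gamma(a, 1)). The case a = 1, where Gamma(1, 1) is the unit exponential distribution, gives a closed form as a worked check.

```latex
Q(a, x) = \frac{1}{\Gamma(a)} \int_x^{\infty} t^{a-1} e^{-t} \, dt = \Pr[X > x],
\qquad X \sim \mathrm{Gamma}(a, 1)

\operatorname{gammainccinv}(a, y) = x \iff Q(a, x) = y, \qquad 0 < y < 1

a = 1: \quad Q(1, x) = e^{-x} \implies \operatorname{gammainccinv}(1, y) = -\ln y
```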
- bartz.jaxext.scipy.special.ndtri(p)
Compute the inverse of the CDF of the standard normal distribution.
This is a patch of jax.scipy.special.ndtri.
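Since the standard normal CDF is Φ(x) = (1 + erf(x/√2))/2, its inverse satisfies ndtri(p) = √2 · erfinv(2p − 1). The sketch below uses that identity as a reference check (hypothetical name ndtri_via_erfinv). It is not the bartz patch, whose details this page does not describe, and it loses relative accuracy for p very close to 0, where 2p − 1 rounds toward −1.

```python
import jax.numpy as jnp
from jax.scipy.special import erfinv, ndtr, ndtri


def ndtri_via_erfinv(p):
    """Hypothetical reference formula: Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1)."""
    p = jnp.asarray(p, jnp.float32)
    return jnp.sqrt(2.0) * erfinv(2.0 * p - 1.0)


p = jnp.array([0.05, 0.5, 0.975])
q = ndtri_via_erfinv(p)  # approximately [-1.645, 0.0, 1.960]
print(ndtr(q))  # round trip through the CDF, approximately p
print(ndtri(p))  # JAX's own ndtri, for comparison
```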
bartz.jaxext.scipy.stats
Mockup of the scipy.stats module.