JAX extensions
bartz.jaxext
Additions to jax.
- bartz.jaxext.vmap_nodoc(fun, *args, **kw)
Acts like jax.vmap but preserves the docstring of the function unchanged. This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
- bartz.jaxext.minimal_unsigned_dtype(value)
Return the smallest unsigned integer dtype that can represent value.
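As a rough illustration of the rule described above, the following pure-Python sketch picks the narrowest unsigned width by comparing value against powers of two. It is not the library's code: the name minimal_unsigned_dtype_sketch is hypothetical, it returns dtype names as strings rather than JAX dtypes, and the error handling for negative or oversized values is a choice made for this sketch only.

```python
def minimal_unsigned_dtype_sketch(value):
    """Return the name of the narrowest unsigned integer type holding `value`.

    Hypothetical stand-in: returns strings such as "uint8", not JAX dtypes.
    """
    if value < 0:
        # Sketch-only choice: unsigned types cannot hold negative numbers.
        raise ValueError("unsigned types cannot represent negative values")
    for bits in (8, 16, 32, 64):
        # An unsigned type with `bits` bits holds the range 0 .. 2**bits - 1.
        if value < 2**bits:
            return f"uint{bits}"
    # Sketch-only choice: nothing wider than 64 bits is considered.
    raise OverflowError("value does not fit in 64 bits")
```

For instance, 255 still fits in 8 bits, while 256 needs the next width up.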
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
x (Shaped[Array, '_']) – The input array.
size (int) – The length of the output.
fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.
actual_length (int) – The number of used values in out.
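The output convention described above (sorted unique values, right-padded to a fixed length, plus a count of the used entries) can be sketched with plain Python lists. This only mirrors the documented return values; it says nothing about how the library achieves its lower memory use. The name unique_sketch is hypothetical, and the sketch assumes size is at least the number of unique values.

```python
def unique_sketch(x, size, fill_value):
    """Return (out, actual_length) following the documented convention.

    Hypothetical list-based stand-in; assumes size >= number of unique values.
    """
    values = sorted(set(x))          # unique values, in ascending order
    actual_length = len(values)      # how many entries of `out` are real data
    # Right-pad with fill_value so the output always has length `size`.
    out = values + [fill_value] * (size - actual_length)
    return out, actual_length
```

A fixed output length is what makes this kind of function usable under jax.jit, where array shapes must be known ahead of time; actual_length tells the caller where the padding starts.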
- class bartz.jaxext.split(key, num=2)
Split a key into num keys.
- Parameters:
key (Key[Array, '']) – The key to split.
num (int, default: 2) – The number of keys to split into.
- pop(shape=None)
Pop one or more keys from the list.
- Parameters:
shape (int | tuple[int, ...] | None, default: None) – The shape of the keys to pop. If None, a single key is popped. If an integer, that many keys are popped. If a tuple, the keys are reshaped to that shape.
- Returns:
Key[Array, '*'] – The popped keys as a jax array with the requested shape.
- Raises:
IndexError – If shape is larger than the number of keys left in the list.
Notes
The keys are popped from the beginning of the list, so for example list(keys.pop(2)) is equivalent to [keys.pop(), keys.pop()].
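The pop-from-the-front behaviour in the note above can be sketched with an ordinary Python list. This is an illustration only: KeyListSketch is a hypothetical class, plain integers stand in for JAX keys, and only the None and integer forms of the argument are covered (tuple shapes are left out).

```python
class KeyListSketch:
    """Hypothetical stand-in for a list of pre-split keys."""

    def __init__(self, keys):
        self._keys = list(keys)

    def pop(self, n=None):
        """Pop one item (n is None) or the first n items, front first."""
        count = 1 if n is None else n
        if count > len(self._keys):
            raise IndexError("not enough keys left")
        # Take from the beginning and keep the remainder for later calls.
        popped, self._keys = self._keys[:count], self._keys[count:]
        return popped[0] if n is None else popped


a = KeyListSketch(range(4))
b = KeyListSketch(range(4))
# Popping two at once yields the same items, in the same order,
# as popping one at a time twice.
same_order = a.pop(2) == [b.pop(), b.pop()]
```

Popping from the front means the order in which keys are consumed does not depend on whether they are requested singly or in groups.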
- bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound)
Sample from a one-sided truncated standard normal distribution.
- Parameters:
key (Key[Array, '']) – JAX random key.
shape (Sequence[int]) – Shape of the output array, broadcast with the other inputs.
upper (Bool[Array, '*']) – True to sample from (-∞, bound], False to sample from [bound, ∞).
bound (Float32[Array, '*']) – The truncation boundary.
- Returns:
Float32[Array, '*'] – Array of samples from the truncated normal distribution.
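To make the meaning of upper and bound concrete, here is a scalar sketch using inverse-CDF sampling with the Python standard library. Inverse-CDF sampling is one standard technique for truncated distributions; whether the library uses it is not stated here. The function name and the rng argument are hypothetical, and the sketch is numerically naive for bounds far in the tails.

```python
import random
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # mean 0, standard deviation 1


def truncated_normal_onesided_sketch(rng, upper, bound):
    """Draw one standard normal sample restricted to one side of `bound`.

    upper=True  -> sample lies in (-inf, bound]
    upper=False -> sample lies in [bound, inf)
    Hypothetical scalar stand-in; not suited to extreme bounds.
    """
    p = _STD_NORMAL.cdf(bound)
    # Restrict the uniform variate to the CDF range of the allowed side.
    lo, hi = (0.0, p) if upper else (p, 1.0)
    while True:
        u = lo + (hi - lo) * rng.random()
        # inv_cdf is only defined on the open interval (0, 1).
        if 0.0 < u < 1.0:
            return _STD_NORMAL.inv_cdf(u)


rng = random.Random(0)
below = [truncated_normal_onesided_sketch(rng, True, 0.5) for _ in range(200)]
above = [truncated_normal_onesided_sketch(rng, False, 0.5) for _ in range(200)]
```

With bound set to 0.5, every value in below falls at or under 0.5 and every value in above falls at or over it, up to floating-point rounding.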
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
Batch a function such that each batch is smaller than a threshold.
- Parameters:
func (Callable) – A jittable function with positional arguments only, whose inputs and outputs are pytrees of arrays.
max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).
in_axes (PyTree[None | int], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axes each array should be batched. A None axis indicates to not batch an argument.
out_axes (PyTree[int], default: 0) – The same for outputs (but non-batching is not allowed).
return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.
- Returns:
Callable – A function with the same signature as func, save for the return value if return_nbatches is True.
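The idea of turning a byte budget into batches can be sketched in plain Python. This is only a toy model of the description above: batched_map_sketch and bytes_per_item are hypothetical, lists stand in for arrays, and the way the real function partitions the batched axis may differ. In particular, processing at least one item per batch even when a single item exceeds the budget is a choice made for this sketch.

```python
def batched_map_sketch(func, items, bytes_per_item, max_io_nbytes):
    """Apply `func` to chunks of `items` that fit a byte budget.

    `func` maps a list chunk to a list of outputs of the same length.
    `bytes_per_item` is the input + output size of one item.
    Returns (outputs, nbatches). Hypothetical stand-in for illustration.
    """
    # Largest number of items whose combined input + output fits the budget;
    # sketch-only choice: never go below one item per batch.
    batch_size = max(1, max_io_nbytes // bytes_per_item)
    outputs = []
    nbatches = 0
    for start in range(0, len(items), batch_size):
        outputs.extend(func(items[start:start + batch_size]))
        nbatches += 1
    return outputs, nbatches


def double(chunk):
    return [2 * v for v in chunk]


# 10 items, 8 bytes each (say 4 in + 4 out), budget of 24 bytes per batch.
result, nbatches = batched_map_sketch(double, list(range(10)), 8, 24)
```

The result is the same as applying the function to all items at once; only the peak amount of data handled per call changes.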
bartz.jaxext.scipy.special
Mockup of the scipy.special module.
- bartz.jaxext.scipy.special.gammainccinv(a, y)
Survival function inverse of the Gamma(a, 1) distribution.
- bartz.jaxext.scipy.special.ndtri(p)
Compute the inverse of the CDF of the standard Normal distribution.
This is a patch of jax.scipy.special.ndtri.
bartz.jaxext.scipy.stats
Mockup of the scipy.stats module.