JAX extensions

bartz.jaxext

Additions to jax.

bartz.jaxext.vmap_nodoc(fun, *args, **kw)

Acts like jax.vmap but preserves the docstring of the function unchanged.

This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
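A minimal sketch of the idea, assuming it is enough to overwrite the `__doc__` attribute of the function returned by `jax.vmap`. `vmap_nodoc_sketch` and `add_one` are hypothetical names, not the bartz implementation.

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical sketch: like jax.vmap, but keep fun's docstring."""
    vfun = jax.vmap(fun, *args, **kw)
    # jax.vmap replaces the docstring with "Vectorized version of ..."; restore the original
    vfun.__doc__ = fun.__doc__
    return vfun


def add_one(x):
    """Add one to each element of x, an array with a leading batch axis."""
    return x + 1


batched = vmap_nodoc_sketch(add_one)
print(batched.__doc__)  # the docstring of add_one, unchanged
print(batched(jnp.arange(3)))  # [1 2 3]
```

Here the docstring of `add_one` is written for the batched version, which is exactly the case where keeping it unchanged makes sense.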

bartz.jaxext.minimal_unsigned_dtype(value)

Return the smallest unsigned integer dtype that can represent value.
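A minimal sketch of this lookup, assuming NumPy dtypes are an acceptable return type; `minimal_unsigned_dtype_sketch` is hypothetical and may differ from bartz in the type it returns. Note that JAX disables 64-bit types unless `jax_enable_x64` is set, so a `uint64` result may not be usable in JAX arrays by default.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical sketch: smallest unsigned dtype whose range includes value."""
    if value < 0:
        raise ValueError('unsigned dtypes cannot represent negative values')
    # try the unsigned dtypes from narrowest to widest
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in 64 bits')


print(minimal_unsigned_dtype_sketch(255))  # uint8
print(minimal_unsigned_dtype_sketch(256))  # uint16
```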

bartz.jaxext.unique(x, size, fill_value)

Restricted version of jax.numpy.unique that uses less memory.

Parameters:
  • x (Shaped[Array, '_']) – The input array.

  • size (int) – The length of the output.

  • fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.

Returns:

  • out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.

  • actual_length (int) – The number of used values in out.
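A sort-and-scatter sketch of the documented semantics. `unique_sketch` is hypothetical, not bartz's code, and whether it matches the memory profile of the real function is an assumption. Under `jax.jit`, `size` must be static, and `actual_length` comes back as a traced integer array rather than a Python int.

```python
import jax
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Hypothetical fixed-size unique: sort, flag first occurrences, scatter."""
    xs = jnp.sort(x)
    # True where a sorted element differs from its predecessor (the first one is always new)
    is_new = jnp.ones(xs.shape, bool).at[1:].set(xs[1:] != xs[:-1])
    # output slot of each new value; repeats are sent to slot `size`, which is out of bounds
    slot = jnp.where(is_new, jnp.cumsum(is_new.astype(int)) - 1, size)
    out = jnp.full(size, fill_value, xs.dtype)
    # out-of-bounds writes (repeats, and unique values beyond `size`) are dropped
    out = out.at[slot].set(xs, mode='drop')
    actual_length = jnp.minimum(jnp.count_nonzero(is_new), size)
    return out, actual_length


unique_jit = jax.jit(unique_sketch, static_argnums=1)
out, n = unique_jit(jnp.array([3, 1, 3, 2]), 5, 0)
print(out, n)  # [1 2 3 0 0] 3
```

If there are more unique values than `size`, only the smallest `size` of them are kept.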

class bartz.jaxext.split(key, num=2)

Split a key into num keys.

Parameters:
  • key (Key[Array, '']) – The key to split.

  • num (int, default: 2) – The number of keys to split into.

pop(shape=None)

Pop one or more keys from the list.

Parameters:

shape (int | tuple[int, ...] | None, default: None) – The shape of the keys to pop. If None, a single key is popped. If an integer, that many keys are popped. If a tuple, the keys are reshaped to that shape.

Returns:

Key[Array, '*'] – The popped keys as a JAX array with the requested shape.

Raises:

IndexError – If shape requires more keys than are left in the list.

Notes

The keys are popped from the beginning of the list, so for example list(keys.pop(2)) is equivalent to [keys.pop(), keys.pop()].
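The documented behaviour of pop can be illustrated with a small stand-in class. `KeyListSketch` is hypothetical: it only mirrors the semantics described above, is not the bartz class, and assumes typed keys created with `jax.random.key`.

```python
import math

import jax


class KeyListSketch:
    """Hypothetical stand-in reproducing the documented split/pop semantics."""

    def __init__(self, key, num=2):
        self._keys = jax.random.split(key, num)

    def pop(self, shape=None):
        if shape is None:
            out_shape = ()
        elif isinstance(shape, int):
            out_shape = (shape,)
        else:
            out_shape = tuple(shape)
        n = math.prod(out_shape)  # math.prod(()) == 1, i.e. a single key
        if n > len(self._keys):
            raise IndexError(f'requested {n} keys, only {len(self._keys)} left')
        # take keys from the front of the list
        popped, self._keys = self._keys[:n], self._keys[n:]
        return popped.reshape(out_shape)


keys = KeyListSketch(jax.random.key(0), 4)
pair = keys.pop(2)         # shape (2,)
grid = keys.pop((1, 2))    # shape (1, 2); all four keys are now used
```

A further `keys.pop()` on this instance would raise IndexError.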

bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound)

Sample from a one-sided truncated standard normal distribution.

Parameters:
  • key (Key[Array, '']) – JAX random key.

  • shape (Sequence[int]) – Shape of the output array; it is broadcast with the other inputs.

  • upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).

  • bound (Float32[Array, '*']) – The truncation boundary.

Returns:

Float32[Array, '*'] – Array of samples from the truncated normal distribution.
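A minimal inverse-CDF sketch, assuming the stock `jax.scipy.special.ndtr` and `ndtri` are accurate enough; the function name is hypothetical and bartz may use a different method. The [bound, ∞) case is reflected onto (-∞, -bound] so that a single code path handles both. This simple version loses accuracy far in the tail, where `ndtr(b)` underflows in float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_onesided_sketch(key, shape, upper, bound):
    """Hypothetical inverse-CDF sampler for a one-sided truncated standard normal."""
    bound = jnp.asarray(bound, jnp.float32)
    # reduce both cases to sampling on (-inf, b]; [bound, inf) is the mirror image
    b = jnp.where(upper, bound, -bound)
    # u in (0, 1): start at the smallest positive float32 to avoid ndtri(0) = -inf
    u = jax.random.uniform(key, shape, dtype=jnp.float32,
                           minval=jnp.finfo(jnp.float32).tiny, maxval=1.0)
    # u * Phi(b) is uniform on (0, Phi(b)); map it back through the inverse CDF
    x = ndtri(u * ndtr(b))
    # guard against round-off pushing a sample past the boundary
    x = jnp.minimum(x, b)
    return jnp.where(upper, x, -x)


key = jax.random.key(0)
samples = truncated_normal_onesided_sketch(
    key, (2, 3), jnp.array([True, False, True]), jnp.array([0.0, 1.0, -1.0]))
```

The output shape is the broadcast of `shape`, `upper` and `bound`, here (2, 3).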

bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)

Batch a function such that each batch is smaller than a threshold.

Parameters:
  • func (Callable) – A jittable function that takes positional arguments only and whose inputs and outputs are pytrees of arrays.

  • max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).

  • in_axes (PyTree[None | int], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axes each array should be batched. A None axis indicates to not batch an argument.

  • out_axes (PyTree[int], default: 0) – The same for outputs, except that None (no batching) is not allowed.

  • return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.

Returns:

Callable – A function with the same signature as func, except that if return_nbatches is True it also returns the number of batches.
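A heavily simplified sketch of the batching logic, assuming one array argument and one array output, both batched along axis 0, and a func that acts independently on each row. `autobatch_sketch` is hypothetical: it ignores in_axes/out_axes pytrees and uses a plain Python loop (unrolled under `jax.jit`), which may not be how bartz does it.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes, return_nbatches=False):
    """Hypothetical: batch a row-wise func so each batch's input + output bytes fit the limit."""

    def batched_func(x):
        n = x.shape[0]
        out = jax.eval_shape(func, x)  # output shape/dtype without running func
        in_nbytes = math.prod(x.shape) * x.dtype.itemsize
        out_nbytes = math.prod(out.shape) * out.dtype.itemsize
        # input + output bytes per row along the batch axis
        row_nbytes = max(1, (in_nbytes + out_nbytes) // max(n, 1))
        batch_size = max(1, max_io_nbytes // row_nbytes)
        nbatches = max(1, math.ceil(n / batch_size))
        chunks = [func(x[i:i + batch_size]) for i in range(0, n, batch_size)]
        result = jnp.concatenate(chunks, axis=0) if chunks else func(x)
        return (result, nbatches) if return_nbatches else result

    return batched_func


double = autobatch_sketch(lambda x: 2 * x, max_io_nbytes=24, return_nbatches=True)
out, nbatches = double(jnp.arange(10, dtype=jnp.float32))
```

With float32 input and output each row costs 8 bytes, so a 24-byte limit gives batches of 3 rows and 4 batches for 10 rows.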

bartz.jaxext.scipy.special

Mockup of the scipy.special module.

bartz.jaxext.scipy.special.gammainccinv(a, y)

Inverse of the survival function of the Gamma(a, 1) distribution.
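For reference, the function being inverted is the regularized upper incomplete gamma function Q(a, x) (SciPy's gammaincc), which is the survival function of Gamma(a, 1):

```latex
Q(a, x) = \frac{\Gamma(a, x)}{\Gamma(a)}
        = \frac{1}{\Gamma(a)} \int_x^{\infty} t^{a-1} e^{-t} \, dt
        = P(X > x), \qquad X \sim \mathrm{Gamma}(a, 1),
\qquad
\mathrm{gammainccinv}(a, y) = x \iff Q(a, x) = y.
```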

bartz.jaxext.scipy.special.ndtri(p)

Compute the inverse of the CDF of the standard normal distribution.

This is a patch of jax.scipy.special.ndtri.
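The inverse-CDF relationship can be checked with the stock `jax.scipy.special` functions. This only illustrates the definition; it does not show what the bartz patch changes.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri

p = jnp.array([0.025, 0.5, 0.975])
z = ndtri(p)      # standard normal quantiles, approximately [-1.96, 0.0, 1.96]
p_back = ndtr(z)  # applying the CDF recovers p up to float32 round-off
```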

bartz.jaxext.scipy.stats

Mockup of the scipy.stats module.

class bartz.jaxext.scipy.stats.invgamma

Class that represents the distribution InvGamma(a, 1).

static ppf(q, a)

Percent point function (inverse of the CDF).
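If X ~ InvGamma(a, 1), then 1/X ~ Gamma(a, 1), so P(X ≤ x) = P(1/X ≥ 1/x) = Q(a, 1/x), the Gamma(a, 1) survival function evaluated at 1/x. Inverting gives ppf(q, a) = 1 / gammainccinv(a, q). The sketch below checks this identity against SciPy; the function name is hypothetical, and whether bartz computes the ppf this way is an assumption.

```python
import numpy as np
from scipy import special, stats


def invgamma_ppf_sketch(q, a):
    """Hypothetical: quantile of InvGamma(a, 1) via the inverse Gamma(a, 1) survival function."""
    # P(X <= x) = Q(a, 1/x) for X ~ InvGamma(a, 1), hence x = 1 / Q^{-1}(a, q)
    return 1.0 / special.gammainccinv(a, q)


q = np.array([0.1, 0.5, 0.9])
print(invgamma_ppf_sketch(q, 2.5))
print(stats.invgamma.ppf(q, 2.5))  # agrees with the line above up to round-off
```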