JAX extensions

bartz.jaxext

Additions to jax.

bartz.jaxext.vmap_nodoc(fun, *args, **kw)

Acts like jax.vmap but preserves the docstring of the function unchanged.

This is useful if the docstring already takes into account that the arguments have additional axes due to vmap.
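jax.vmap replaces the docstring of the function it wraps with a generic description of the vectorized function. The sketch below shows one way to keep the original docstring. It assumes that copying `__doc__` back onto the vmapped function is enough. `vmap_nodoc_sketch` and `row_norm` are hypothetical names, not bartz code:

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical stand-in for vmap_nodoc: vmap, then restore the docstring."""
    doc = fun.__doc__
    vfun = jax.vmap(fun, *args, **kw)
    vfun.__doc__ = doc  # undo the docstring rewrite done by jax.vmap
    return vfun


def row_norm(x):
    """Euclidean norm of each row of a (n, d) array."""
    return jnp.sqrt(jnp.sum(x**2))


norms = vmap_nodoc_sketch(row_norm)  # maps row_norm over axis 0
```

The docstring of `row_norm` already describes the batched (n, d) input. Keeping it unchanged after vmap is therefore what a reader of `help(norms)` needs.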

bartz.jaxext.minimal_unsigned_dtype(value)

Return the smallest unsigned integer dtype that can represent value.
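A typical use is choosing a compact dtype for indices or counts. The sketch below assumes that `value` is a non-negative Python integer and that the function returns a NumPy dtype object, which is what JAX uses for dtypes. `minimal_unsigned_dtype_sketch` is a hypothetical name:

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical: smallest unsigned integer dtype able to hold value."""
    if value < 0:
        raise ValueError('unsigned dtypes cannot hold negative values')
    # try the unsigned types from narrowest to widest
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in 64 bits')


# e.g. the dtype for indices into an array of 1000 elements
print(minimal_unsigned_dtype_sketch(999))  # uint16
```

In JAX's default 32-bit mode, arrays requested as uint64 are created as uint32.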

bartz.jaxext.unique(x, size, fill_value)

Restricted version of jax.numpy.unique that uses less memory.

Parameters:
  • x (Shaped[Array, '_']) – The input array.

  • size (int) – The length of the output.

  • fill_value (Shaped[Array, '']) – The value to fill the output with if size is greater than the number of unique values in x.

Returns:

  • out (Shaped[Array, '{size}']) – The unique values in x, sorted, and right-padded with fill_value.

  • actual_length (int) – The number of used values in out.
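A sort-based sketch illustrates these semantics:

1. Sort x.
2. Mark the first occurrence of each distinct value.
3. Scatter the marked values into a fixed-size buffer prefilled with fill_value.

`unique_sketch` is a hypothetical implementation that assumes len(x) >= 1. It is not necessarily the technique bartz uses to save memory:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=1)  # size fixes the output shape, so it is static
def unique_sketch(x, size, fill_value):
    """Hypothetical sort-based unique with a fixed output length (assumes len(x) >= 1)."""
    x = jnp.sort(x)
    # True at the first occurrence of each distinct value in the sorted array
    is_new = jnp.concatenate([jnp.ones(1, bool), x[1:] != x[:-1]])
    counts = jnp.cumsum(is_new.astype(jnp.int32))
    # output slot of each first occurrence; duplicates get the out-of-bounds slot `size`
    slot = jnp.where(is_new, counts - 1, size)
    out = jnp.full(size, fill_value, x.dtype)
    out = out.at[slot].set(x, mode='drop')  # out-of-bounds writes are discarded
    actual_length = jnp.minimum(counts[-1], size)
    return out, actual_length


out, n = unique_sketch(jnp.array([3, 1, 3, 2, 1]), 5, 0)
# out is [1, 2, 3, 0, 0] and n is 3
```

Because size fixes the output length, the function can be jitted with size as a static argument. actual_length then reports how many leading entries of out are meaningful.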

class bartz.jaxext.split(key, num=2)

Split a key into num keys.

Parameters:
  • key (Key[Array, '']) – The key to split.

  • num (int, default: 2) – The number of keys to split into.

pop(shape=())

Pop one or more keys from the list.

Parameters:

shape (int | tuple[int, ...], default: ()) – The shape of the keys to pop. If empty (default), a single key is popped and returned. If not empty, the popped key is split and reshaped to the target shape.

Returns:

Key[Array, '*'] – The popped keys as a JAX array with the requested shape.

Raises:

IndexError – If the list is empty.
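The sketch below mimics the same pop-based interface. It assumes new-style typed keys from jax.random.key, which match the Key[Array, ''] annotation, and that keys are taken from the front of the list. The class name `KeySplitSketch` is hypothetical:

```python
import math

import jax


class KeySplitSketch:
    """Hypothetical stand-in for split: a list of keys consumed with pop()."""

    def __init__(self, key, num=2):
        self._keys = jax.random.split(key, num)

    def __len__(self):
        return self._keys.shape[0]

    def pop(self, shape=()):
        if len(self) == 0:
            raise IndexError('no keys left to pop')
        # assumption: keys are popped from the front of the list
        key, self._keys = self._keys[0], self._keys[1:]
        if isinstance(shape, int):
            shape = (shape,)
        if not shape:
            return key  # empty shape: return the single popped key
        # split the popped key into as many keys as the target shape holds
        return jax.random.split(key, math.prod(shape)).reshape(shape + key.shape)


keys = KeySplitSketch(jax.random.key(0), num=3)
k = keys.pop()  # one key
ks = keys.pop((2, 3))  # a 2x3 array of keys derived from a single popped key
```

Popping a shaped batch consumes only one key from the list. That key is then split into as many keys as the target shape holds.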

bartz.jaxext.truncated_normal_onesided(key, shape, upper, bound, *, clip=True)

Sample from a one-sided truncated standard normal distribution.

Parameters:
  • key (Key[Array, '']) – JAX random key.

  • shape (Sequence[int]) – Shape of the output array, broadcast with the other inputs.

  • upper (Bool[Array, '*']) – True for (-∞, bound], False for [bound, ∞).

  • bound (Float32[Array, '*']) – The truncation boundary.

  • clip (bool, default: True) – Whether to clip the truncated uniform samples to (0, 1) before transforming them into truncated normal samples. Intended for debugging purposes.

Returns:

Float32[Array, '*'] – Array of samples from the truncated normal distribution.
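The clip parameter suggests an inverse-CDF method. The sampler draws a uniform value over the truncated range of the CDF and maps it back through the normal quantile function. The sketch below works under that assumption, using jax.scipy.special.ndtr and ndtri. `truncated_normal_onesided_sketch` is hypothetical and, unlike a careful implementation, loses accuracy when the bound lies far in the tail:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri


def truncated_normal_onesided_sketch(key, shape, upper, bound, *, clip=True):
    """Hypothetical inverse-CDF sampler; not numerically robust far in the tails."""
    # Reflect so that every case becomes "sample from (-inf, b]".
    sign = jnp.where(upper, 1.0, -1.0)
    b = sign * bound
    # Uniform on [0, Phi(b)), broadcast against b.
    u = jax.random.uniform(key, shape) * ndtr(b)
    if clip:
        # keep u strictly inside (0, 1) so that ndtri stays finite
        u = jnp.clip(u, jnp.finfo(u.dtype).tiny, 1.0 - jnp.finfo(u.dtype).eps)
    return sign * ndtri(u)


key = jax.random.key(0)
below = truncated_normal_onesided_sketch(key, (1000,), True, 0.5)  # values <= 0.5
above = truncated_normal_onesided_sketch(key, (1000,), False, 0.5)  # values >= 0.5
```

The [bound, ∞) case is handled by reflecting it onto (-∞, -bound]. As a result, one code path covers both values of upper.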

bartz.jaxext.get_default_device()

Get the current default JAX device.

Return type:

Device
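One way to query the default device with public JAX calls is to create a tiny array and check where it was placed. This approach honors the jax.default_device context manager. The sketch below uses it; `get_default_device_sketch` is a hypothetical name, and bartz may read the configuration differently:

```python
import jax
import jax.numpy as jnp


def get_default_device_sketch():
    """Hypothetical: the device on which a freshly created array is placed."""
    (device,) = jnp.zeros(()).devices()  # .devices() returns a set of one device
    return device


print(get_default_device_sketch())
```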

bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)

Batch a function such that each batch is smaller than a threshold.

Parameters:
  • func (Callable) – A jittable function with positional arguments only, with inputs and outputs pytrees of arrays.

  • max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).

  • in_axes (PyTree[int | None], default: 0) – A tree matching (a prefix of) the structure of the function input, indicating along which axes each array should be batched. A None axis indicates to not batch an argument.

  • out_axes (PyTree[int], default: 0) – The same for the outputs, except that None (no batching) is not allowed.

  • return_nbatches (bool, default: False) – If True, the number of batches is returned as a second output.

Returns:

Callable – A function with the same signature as func, except for the return value when return_nbatches is True.
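The sketch below simplifies the idea under several assumptions:

- It takes a single array argument, batched along axis 0 on both input and output (in_axes=0, out_axes=0, with no pytrees or None axes).
- func acts independently on each row.

`autobatch_sketch` is hypothetical. It measures the input and output bytes per row with jax.eval_shape and picks the largest batch size that fits in max_io_nbytes. It then pads the input to a whole number of batches and loops over them with jax.lax.map:

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes, return_nbatches=False):
    """Hypothetical simplified autobatch: one array argument, in_axes=out_axes=0."""

    def batched_func(x):
        n = x.shape[0]
        # Output shape and dtype, obtained by tracing only (no computation).
        out = jax.eval_shape(func, x)
        # Bytes that each row contributes to the input plus the output.
        row_nbytes = (math.prod(x.shape[1:]) * x.dtype.itemsize
                      + math.prod(out.shape[1:]) * out.dtype.itemsize)
        batch_size = max(1, min(n, max_io_nbytes // max(row_nbytes, 1)))
        nbatches = -(-n // batch_size)  # ceiling division
        # Pad to a whole number of batches and stack the batches on a new axis 0.
        pad = nbatches * batch_size - n
        x = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        x = x.reshape((nbatches, batch_size) + x.shape[1:])
        # Run func on one batch at a time; lax.map is a sequential loop.
        y = jax.lax.map(func, x)
        y = y.reshape((nbatches * batch_size,) + y.shape[2:])[:n]  # drop the padding
        return (y, nbatches) if return_nbatches else y

    return batched_func


f = lambda x: 2.0 * jnp.sin(x)
x = jnp.arange(30.0).reshape(10, 3)
y, nbatches = autobatch_sketch(f, max_io_nbytes=48, return_nbatches=True)(x)
```

With 10 rows of 3 float32 values in and out, each row accounts for 24 bytes. max_io_nbytes=48 therefore gives batches of 2 rows, for 5 batches in total.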

bartz.jaxext.scipy.special

Mockup of the scipy.special module.

bartz.jaxext.scipy.special.gammainccinv(a, y)

Inverse of the survival function of the Gamma(a, 1) distribution.
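Concretely, gammainccinv(a, y) returns the x for which Q(a, x) = y. Here Q(a, x) = Γ(a, x)/Γ(a) is the regularized upper incomplete gamma function (jax.scipy.special.gammaincc), that is, P(X > x) for X ~ Gamma(a, 1). Since Q decreases in x, a naive bisection recovers x. The sketch below is a hypothetical illustration (`gammainccinv_bisect`), not bartz's method. It assumes the fixed bracket [0, 100(a + 10)] contains the root:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc


def gammainccinv_bisect(a, y, n_iter=100):
    """Hypothetical: solve gammaincc(a, x) == y for x by bisection."""
    a, y = jnp.broadcast_arrays(jnp.asarray(a, jnp.float32), jnp.asarray(y, jnp.float32))
    lo = jnp.zeros_like(a)
    hi = 100.0 * (a + 10.0)  # assumed upper bracket, where Q(a, hi) is ~0

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # Q(a, x) decreases in x: if Q(a, mid) > y, the root lies to the right
        go_right = gammaincc(a, mid) > y
        return jnp.where(go_right, mid, lo), jnp.where(go_right, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_iter, body, (lo, hi))
    return 0.5 * (lo + hi)


# for a = 1, Q(1, x) = exp(-x), so the answer should be close to log(4)
x = gammainccinv_bisect(1.0, 0.25)
```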

bartz.jaxext.scipy.special.ndtri(p)

Compute the inverse of the CDF of the standard normal distribution.

This is a patch of jax.scipy.special.ndtri.

bartz.jaxext.scipy.stats

Mockup of the scipy.stats module.

class bartz.jaxext.scipy.stats.invgamma

Class that represents the distribution InvGamma(a, 1).

static ppf(q, a)

Percent point function (inverse of the CDF).
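For InvGamma(a, 1), a standard identity expresses the percent point function through the inverse of the regularized upper incomplete gamma function Q(a, z) = Γ(a, z)/Γ(a), the function that gammainccinv inverts in its second argument. The identity shows the mathematical relationship only; whether bartz computes ppf exactly this way is not stated here:

```latex
\begin{aligned}
X \sim \mathrm{InvGamma}(a, 1) &\iff 1/X \sim \mathrm{Gamma}(a, 1), \\
F_X(x) = P(X \le x) &= P\left(1/X \ge 1/x\right) = Q(a, 1/x), \\
\operatorname{ppf}(q, a) = F_X^{-1}(q) &= \frac{1}{Q^{-1}(a, q)}.
\end{aligned}
```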