JAX extensions

Additions to jax.

bartz.jaxext.float_type(*args)

Determine the JAX floating-point result type of the given operands or types.
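As an illustration, one plausible way to compute such a type is to apply JAX's standard promotion rules with jnp.result_type and fall back to JAX's default float type when the promoted type is not floating. This is a minimal sketch under that assumption, not bartz's code. The name float_type_sketch and the fallback behavior for integer inputs are hypothetical.

```python
import jax.numpy as jnp


def float_type_sketch(*args):
    """Hypothetical sketch: floating-point type resulting from combining args."""
    # Promote the operands/types with JAX's standard type-promotion rules.
    dtype = jnp.result_type(*args)
    if jnp.issubdtype(dtype, jnp.floating):
        return dtype
    # Assumption: integer or boolean results fall back to JAX's default
    # float type (float32 unless 64-bit mode is enabled).
    return jnp.result_type(float)


print(float_type_sketch(jnp.zeros(3, jnp.float16)))  # float16
print(float_type_sketch(jnp.zeros(3, jnp.int32)))    # JAX's default float type
```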

class bartz.jaxext.scipy

Mockup of the scipy module.

class special

Mockup of the scipy.special module.

static gammainccinv(a, y)

Inverse of the survival function of the Gamma(a, 1) distribution, i.e., the value x such that gammaincc(a, x) = y.
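Since gammaincc(a, x), the survival function of Gamma(a, 1), decreases from 1 at x = 0 toward 0 as x grows, its inverse can be approximated by bisection. The sketch below does this with jax.scipy.special.gammaincc in float32. It illustrates one possible approach and is not necessarily how bartz implements gammainccinv; gammainccinv_bisect and n_iter are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaincc


def gammainccinv_bisect(a, y, n_iter=60):
    """Hypothetical sketch: x >= 0 such that gammaincc(a, x) == y, by bisection."""
    a, y = jnp.broadcast_arrays(jnp.asarray(a, jnp.float32), jnp.asarray(y, jnp.float32))

    def grow(_, hi):
        # Double the upper end of the bracket while the survival function is still above y.
        return jnp.where(gammaincc(a, hi) > y, 2 * hi, hi)

    # gammaincc(a, 0) == 1 >= y, so 0 is a valid lower end of the bracket.
    hi = jax.lax.fori_loop(0, 64, grow, a + 1.0)
    lo = jnp.zeros_like(hi)

    def bisect(_, bounds):
        lo, hi = bounds
        mid = (lo + hi) / 2
        # gammaincc decreases in x: if it is still above y, the root lies right of mid.
        above = gammaincc(a, mid) > y
        return jnp.where(above, mid, lo), jnp.where(above, hi, mid)

    lo, hi = jax.lax.fori_loop(0, n_iter, bisect, (lo, hi))
    return (lo + hi) / 2


# Gamma(1, 1) is the standard exponential, whose survival function is exp(-x),
# so the inverse at y = 0.5 should be close to log(2).
print(gammainccinv_bisect(1.0, 0.5))  # approximately 0.693
```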

class stats

Mockup of the scipy.stats module.

class invgamma

Class that represents the distribution InvGamma(a, 1).

static ppf(q, a)

Percent point function (inverse of the cumulative distribution function).
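The InvGamma(a, 1) percent point function can be expressed through the gammainccinv function of the special mockup; whether bartz computes it exactly this way is an assumption. The identity follows from the fact that if X ~ Gamma(a, 1) then 1/X ~ InvGamma(a, 1). Writing Q(a, x) for the regularized upper incomplete gamma function (scipy's gammaincc), for t > 0:

```latex
\begin{aligned}
F_{1/X}(t) &= P(1/X \le t) = P(X \ge 1/t) = Q(a, 1/t), \\
F_{1/X}(t) = q &\iff 1/t = Q^{-1}(a, q) \iff t = \frac{1}{Q^{-1}(a, q)},
\end{aligned}
```

so ppf(q, a) = 1 / gammainccinv(a, q).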

bartz.jaxext.vmap_nodoc(fun, *args, **kw)

Act like jax.vmap, but leave the docstring of the function unchanged.

This is useful when the docstring already accounts for the additional axes that vmap adds to the arguments.
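A minimal sketch of the idea, assuming it is enough to call jax.vmap and then put the original docstring back on the returned function. vmap_nodoc_sketch and row_norms are hypothetical names used for illustration; row_norms shows the intended use case, a per-row function whose docstring already describes the batched shapes.

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical sketch: jax.vmap, then restore the original docstring."""
    doc = fun.__doc__
    vfun = jax.vmap(fun, *args, **kw)
    # jax.vmap may replace the docstring with a generic description; put the original back.
    vfun.__doc__ = doc
    return vfun


def row_norms(x):
    """Euclidean norm of each row of x, shape (n, k) -> (n,)."""
    # Written for a single row; the docstring already describes the vmapped version.
    return jnp.sqrt(jnp.sum(x**2))


batched = vmap_nodoc_sketch(row_norms)  # maps over axis 0, matching the docstring
print(batched.__doc__)
```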

bartz.jaxext.huge_value(x)

Return the maximum value that can be stored in x.

Parameters:

x (array) – A numerical numpy or jax array.

Returns:

maxval (scalar) – The maximum value allowed by x’s type (+inf for floats).
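The behavior described above can be sketched with numpy's dtype utilities, which also apply to JAX arrays because both expose a numpy dtype through .dtype. This is an illustration, not bartz's code; huge_value_sketch is a hypothetical name, and only floating and integer dtypes are handled.

```python
import numpy as np


def huge_value_sketch(x):
    """Hypothetical sketch: largest value representable in x's dtype."""
    dtype = x.dtype  # numpy and jax arrays both expose .dtype
    if np.issubdtype(dtype, np.floating):
        return np.inf  # floating types: +inf
    return np.iinfo(dtype).max  # integer types: largest finite value


print(huge_value_sketch(np.zeros(3, np.int8)))     # 127
print(huge_value_sketch(np.zeros(3, np.uint16)))   # 65535
print(huge_value_sketch(np.zeros(3, np.float32)))  # inf
```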

bartz.jaxext.minimal_unsigned_dtype(value)

Return the smallest unsigned integer dtype that can represent value.
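A minimal sketch, assuming value is a non-negative Python integer and that the candidates are numpy's unsigned types tried from smallest to largest. minimal_unsigned_dtype_sketch is a hypothetical name; the actual function may return JAX dtypes or handle 64-bit types differently.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical sketch: smallest unsigned integer dtype that can hold value (>= 0)."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        # Pick the first type whose maximum is at least value.
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in 64 bits')


print(minimal_unsigned_dtype_sketch(255))  # uint8
print(minimal_unsigned_dtype_sketch(256))  # uint16
```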

bartz.jaxext.signed_to_unsigned(int_dtype)

Map a signed integer type to its unsigned counterpart.

Unsigned types are passed through.
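The mapping can be sketched with numpy dtype metadata: the dtype's kind tells signed ('i') from unsigned ('u'), and its itemsize gives the width to keep. signed_to_unsigned_sketch is a hypothetical name, and raising TypeError on non-integer input is an assumption.

```python
import numpy as np


def signed_to_unsigned_sketch(int_dtype):
    """Hypothetical sketch: map intN to uintN; unsigned dtypes pass through."""
    dtype = np.dtype(int_dtype)
    if dtype.kind == 'u':  # already unsigned
        return dtype
    if dtype.kind == 'i':  # signed: unsigned type with the same number of bits
        return np.dtype(f'uint{8 * dtype.itemsize}')
    raise TypeError(f'not an integer dtype: {dtype}')


print(signed_to_unsigned_sketch(np.int32))  # uint32
```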

bartz.jaxext.ensure_unsigned(x)

If x has signed integer type, cast it to the unsigned dtype of the same size.
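A minimal JAX sketch of this cast, assuming non-integer and already-unsigned arrays should be returned unchanged. ensure_unsigned_sketch is a hypothetical name, not bartz's implementation.

```python
import jax.numpy as jnp


def ensure_unsigned_sketch(x):
    """Hypothetical sketch: cast signed-integer arrays to the same-size unsigned type."""
    x = jnp.asarray(x)
    if jnp.issubdtype(x.dtype, jnp.signedinteger):
        # Same number of bits, unsigned (e.g. int32 -> uint32).
        return x.astype(f'uint{8 * x.dtype.itemsize}')
    return x  # unsigned, float, etc. are returned unchanged


print(ensure_unsigned_sketch(jnp.array([1, 2, 3], jnp.int32)).dtype)  # uint32
```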

bartz.jaxext.unique(x, size, fill_value)

Restricted version of jax.numpy.unique that uses less memory.

Parameters:
  • x (1d array) – The input array.

  • size (int) – The length of the output.

  • fill_value (scalar) – The value to fill the output with if size is greater than the number of unique values in x.

Returns:

  • out (array (size,)) – The unique values in x, sorted, and right-padded with fill_value.

  • actual_length (int) – The number of used values in out.
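One way to produce these two outputs with a fixed output size is to sort x, mark the first element of each run of equal values, give each run an output slot with a cumulative sum, and scatter into a pre-filled array while dropping out-of-bounds writes. This is a sketch of that technique, not necessarily how bartz reduces memory use. It assumes x is non-empty and size is a static Python int; unique_sketch is a hypothetical name.

```python
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Hypothetical sketch: sorted unique values of a non-empty 1d x, padded to size."""
    x = jnp.sort(x)
    # True at the first element of each run of equal values.
    is_new = jnp.concatenate([jnp.array([True]), x[1:] != x[:-1]])
    # Output slot of each run: 0, 0, 1, 2, 2, ... for runs of equal values.
    slot = jnp.cumsum(is_new) - 1
    # Send repeated elements to an out-of-bounds slot so the scatter drops them;
    # unique values whose slot is >= size are dropped the same way.
    slot = jnp.where(is_new, slot, size)
    out = jnp.full(size, fill_value, x.dtype).at[slot].set(x, mode='drop')
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length


out, n = unique_sketch(jnp.array([3, 1, 3, 2]), size=5, fill_value=-1)
# out holds 1, 2, 3 followed by two -1 pads; n is 3
```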

bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)

Batch a function so that the total size of the batched inputs and outputs of each batch does not exceed max_io_nbytes.

Parameters:
  • func (callable) – A jittable function that takes positional arguments only and whose inputs and outputs are pytrees of arrays.

  • max_io_nbytes (int) – The maximum number of input + output bytes in each batch (excluding unbatched arguments).

  • in_axes (pytree of int or None, default 0) – A tree matching the structure of the function input, indicating along which axes each array should be batched. If a single integer, it is used for all arrays. A None axis indicates to not batch an argument.

  • out_axes (pytree of int, default 0) – The same as in_axes, but for the outputs; None (no batching) is not allowed.

  • return_nbatches (bool, default False) – If True, the number of batches is returned as a second output.

Returns:

batched_func (callable) – A function with the same signature as func, but that processes the input and output in batches in a loop.
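A heavily simplified sketch of the batching idea, assuming a single array argument batched along axis 0, an output stacked along axis 0, and a plain Python loop over slices (no pytrees, in_axes or out_axes, no jit). The bytes per row are estimated from the input and from the output shape obtained with jax.eval_shape. autobatch_sketch is a hypothetical name and this is not bartz's implementation.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes, return_nbatches=False):
    """Hypothetical sketch of autobatch for one array argument batched on axis 0."""

    def batched_func(x):
        n = x.shape[0]
        # Output shape and dtype, computed abstractly without running func.
        out = jax.eval_shape(func, x)
        out_nbytes = math.prod(out.shape) * out.dtype.itemsize
        # Input + output bytes per row along the batch axis.
        row_nbytes = (x.nbytes + out_nbytes) / n
        # At least one row per batch, even if a single row exceeds the limit.
        batch_size = max(1, int(max_io_nbytes // row_nbytes))
        # Process consecutive slices of at most batch_size rows, then concatenate.
        pieces = [func(x[i:i + batch_size]) for i in range(0, n, batch_size)]
        result = jnp.concatenate(pieces, axis=0)
        return (result, len(pieces)) if return_nbatches else result

    return batched_func


# 10 float32 rows, 4 bytes in + 4 bytes out each: 3 rows per batch, so 4 batches.
f = autobatch_sketch(lambda x: 2 * x, max_io_nbytes=24, return_nbatches=True)
y, nbatches = f(jnp.arange(10, dtype=jnp.float32))
```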

class bartz.jaxext.split(key, num=2)

Split a key into num keys, held in a list from which they can be popped.

Parameters:
  • key (jax.dtypes.prng_key array) – The key to split.

  • num (int) – The number of keys to split into.

pop(shape=None)

Pop one or more keys from the list.

Parameters:

shape (int or tuple of int, optional) – The shape of the keys to pop. If None, a single key is popped. If an integer, that many keys are popped. If a tuple, enough keys to fill that shape are popped and returned with that shape.

Returns:

keys (jax.dtypes.prng_key array) – The popped keys.

Raises:

IndexError – If shape requests more keys than are left in the list.

Notes

The keys are popped from the beginning of the list, so for example list(keys.pop(2)) is equivalent to [keys.pop(), keys.pop()].
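The behavior described for split and pop can be sketched as a small class that splits the key once and then slices keys off the front of the resulting key array. This is an illustration under that assumption, not bartz's class; KeyListSketch is a hypothetical name, and the example uses new-style typed keys from jax.random.key.

```python
import math

import jax


class KeyListSketch:
    """Hypothetical sketch of a list of split keys with pop(), not bartz's class."""

    def __init__(self, key, num=2):
        self._keys = jax.random.split(key, num)  # key array of shape (num,)

    def __len__(self):
        return self._keys.shape[0]

    def pop(self, shape=None):
        # Translate shape into how many keys to take and the shape to return.
        if shape is None:
            count, out_shape = 1, ()
        elif isinstance(shape, int):
            count, out_shape = shape, (shape,)
        else:
            count, out_shape = math.prod(shape), tuple(shape)
        if count > len(self):
            raise IndexError(f'requested {count} keys, but only {len(self)} left')
        # Take keys from the front of the list, keeping the rest for later pops.
        popped, self._keys = self._keys[:count], self._keys[count:]
        return popped.reshape(out_shape)


keys = KeyListSketch(jax.random.key(0), 5)
k = keys.pop()         # one key, shape ()
ks = keys.pop((2, 2))  # four more keys, shape (2, 2); the list is now empty
```

Because keys come off the front, popping two keys at once yields the same keys, in the same order, as two single pops on an identical list, matching the note above.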