JAX extensions

bartz.jaxext.float_type(*args)

Determine the JAX floating-point result type given operands or types.
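As an illustration only, a rule of this kind can be approximated with `jnp.result_type` plus a fallback for non-floating results. The function `float_type_sketch` is hypothetical. Mapping integer or boolean inputs to JAX's default float dtype is an assumption about the intended behavior, not bartz's actual implementation.

```python
import jax.numpy as jnp
import numpy as np


def float_type_sketch(*args):
    """Hypothetical sketch: promote operands/types, then ensure a floating result."""
    # Standard JAX type promotion over arrays, scalars, or dtypes.
    dtype = jnp.result_type(*args)
    if np.issubdtype(dtype, np.inexact):
        # Already floating (or complex): keep it.
        return dtype
    # Assumption: integers and booleans map to JAX's default float dtype
    # (float32 unless 64-bit mode is enabled).
    return jnp.zeros((), dtype=float).dtype
```

With this sketch, a float16 operand keeps float16, while an int32 operand yields the default float dtype.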

bartz.jaxext.pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs)

Version of jax.pure_callback that handles ufuncs correctly; see https://github.com/google/jax/issues/17187.
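For context, here is a plain `jax.pure_callback` call that evaluates a NumPy ufunc on the host inside a jitted function. It does not reproduce the batching problem from the linked issue or bartz's workaround; `host_logaddexp` and `jitted_logaddexp` are hypothetical names, and `np.logaddexp` simply stands in for any ufunc.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_logaddexp(x, y):
    # Runs on the host with NumPy inputs; the result must match the
    # declared shape and dtype exactly.
    return np.logaddexp(x, y).astype(x.dtype)


@jax.jit
def jitted_logaddexp(x, y):
    # Declare the callback result: the ufunc's broadcast shape, x's dtype.
    result = jax.ShapeDtypeStruct(jnp.broadcast_shapes(x.shape, y.shape), x.dtype)
    return jax.pure_callback(host_logaddexp, result, x, y)
```

How such a callback is batched under `jax.vmap` is controlled by separate options of `jax.pure_callback` that have changed across JAX versions; that batching behavior is what the linked issue concerns.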

bartz.jaxext.huge_value(x)

Return the maximum value that can be stored in an array with the dtype of x.

Parameters:
x : array

A numerical numpy or jax array.

Returns:
maxval : scalar

The maximum value allowed by x’s type (+inf for floats).
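A minimal sketch of this rule, assuming only integer and floating dtypes need to be handled; `huge_value_sketch` is a hypothetical name, not bartz's code.

```python
import jax.numpy as jnp
import numpy as np


def huge_value_sketch(x):
    """Hypothetical sketch: largest value representable by x's dtype."""
    dtype = x.dtype
    if jnp.issubdtype(dtype, jnp.integer):
        # Largest finite integer for this width and signedness.
        return jnp.iinfo(dtype).max
    if jnp.issubdtype(dtype, jnp.floating):
        # Floats can hold +inf, which exceeds every finite value.
        return jnp.inf
    raise TypeError(f"unsupported dtype {dtype}")
```

Such a value is handy, for example, as the initial sentinel of a running minimum.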

bartz.jaxext.minimal_unsigned_dtype(max_value)

Return the smallest unsigned integer dtype that can represent a given maximum value (inclusive).
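A minimal sketch that tries the unsigned widths in increasing order; `minimal_unsigned_dtype_sketch` is hypothetical. The comparison is inclusive, so a maximum of 255 still fits in uint8.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(max_value):
    """Hypothetical sketch: smallest unsigned dtype whose range includes max_value."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise ValueError(f"{max_value} does not fit in 64 bits")
```

Note that JAX only keeps uint64 arrays as 64-bit when 64-bit mode is enabled.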

bartz.jaxext.signed_to_unsigned(int_dtype)

Map a signed integer type to its unsigned counterpart. Unsigned types are passed through.
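A minimal sketch of the mapping based on the dtype's kind and byte size; `signed_to_unsigned_sketch` is hypothetical.

```python
import numpy as np


def signed_to_unsigned_sketch(int_dtype):
    """Hypothetical sketch: int<N> -> uint<N>; unsigned dtypes are returned as-is."""
    dtype = np.dtype(int_dtype)
    if dtype.kind == "u":
        return dtype
    if dtype.kind == "i":
        # Same width in bits, unsigned.
        return np.dtype(f"uint{8 * dtype.itemsize}")
    raise TypeError(f"not an integer dtype: {dtype}")
```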

bartz.jaxext.ensure_unsigned(x)

If x has a signed integer type, cast it to the unsigned dtype of the same size.
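A minimal sketch, assuming non-integer and already-unsigned inputs are returned unchanged; `ensure_unsigned_sketch` is hypothetical. Non-negative values are preserved by the cast.

```python
import jax.numpy as jnp
import numpy as np


def ensure_unsigned_sketch(x):
    """Hypothetical sketch: cast signed-integer arrays to the same-width unsigned dtype."""
    if np.issubdtype(x.dtype, np.signedinteger):
        return x.astype(f"uint{8 * x.dtype.itemsize}")
    # Unsigned integers, floats, etc. pass through unchanged.
    return x


x = jnp.arange(3, dtype=jnp.int8)
print(ensure_unsigned_sketch(x).dtype)  # uint8
```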

bartz.jaxext.unique(x, size, fill_value)

Restricted version of jax.numpy.unique that uses less memory.

Parameters:
x : 1d array

The input array.

size : int

The length of the output.

fill_value : scalar

The value to fill the output with if size is greater than the number of unique values in x.

Returns:
out : array (size,)

The unique values in x, sorted, and right-padded with fill_value.

actual_length : int

The number of entries of out that hold unique values of x rather than padding.
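A jittable fixed-size unique can be sketched with a sort followed by a scatter. `unique_sketch` is hypothetical and only mirrors the documented signature; it makes no claim about matching bartz's memory behavior. For comparison, `jax.numpy.unique` itself also accepts `size` and `fill_value`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="size")
def unique_sketch(x, size, fill_value):
    """Hypothetical sketch of a fixed-size, sorted unique for a 1d array."""
    x = jnp.sort(x)
    # Mark the first element of each run of equal values.
    is_new = jnp.ones_like(x, dtype=bool).at[1:].set(x[1:] != x[:-1])
    # Output slot of each distinct value (0, 1, 2, ...).
    pos = jnp.cumsum(is_new, dtype=jnp.int32) - 1
    # Send repeated values to an out-of-range index so the scatter drops them;
    # distinct values beyond `size` fall out of range and are dropped too.
    pos = jnp.where(is_new, pos, size)
    out = jnp.full(size, fill_value, dtype=x.dtype)
    out = out.at[pos].set(x, mode="drop")
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length


out, n = unique_sketch(jnp.array([3, 1, 3, 2]), size=5, fill_value=-1)
# out holds 1, 2, 3 followed by two -1 paddings, and n == 3
```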

bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)

Batch a function such that each batch is smaller than a threshold.

Parameters:
func : callable

A jittable function that takes positional arguments only, whose inputs and outputs are pytrees of arrays.

max_io_nbytes : int

The maximum number of input plus output bytes in each batch (excluding unbatched arguments).

in_axes : pytree of int or None, default 0

A tree matching the structure of the function input, indicating along which axis each array should be batched. If a single integer is given, it is used for all arrays. A None axis means that the argument is not batched.

out_axes : pytree of int, default 0

The same for outputs, except that None (no batching) is not allowed.

return_nbatches : bool, default False

If True, the number of batches is returned as a second output.

Returns:
batched_func : callable

A function with the same signature as func, but that processes the input and output in batches in a loop.
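The core idea can be sketched for the simplest case: a single array argument and a single array output, both batched along axis 0. `autobatch_sketch` is hypothetical and is not bartz's implementation. It ignores pytrees, `in_axes`/`out_axes`, unbatched arguments and `return_nbatches`. It sizes batches from the per-row input plus output bytes and loops over them with `jax.lax.map`.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes):
    """Hypothetical, much-reduced autobatch: one array argument, batched on axis 0.

    Assumes func maps an array of shape (n, ...) to one of shape (n, ...) row by row.
    """

    def batched_func(x):
        n = x.shape[0]
        if n == 0:
            return func(x)
        out = jax.eval_shape(func, x)  # output shape/dtype without running func
        io_nbytes = x.size * x.dtype.itemsize + math.prod(out.shape) * out.dtype.itemsize
        row_nbytes = math.ceil(io_nbytes / n)
        # Largest batch that fits in the budget (at least one row).
        batch_size = min(n, max(1, max_io_nbytes // row_nbytes))
        nbatches = math.ceil(n / batch_size)
        # Pad with zero rows so all batches have the same size, as lax.map requires.
        pad = nbatches * batch_size - n
        x = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        xb = x.reshape(nbatches, batch_size, *x.shape[1:])
        yb = jax.lax.map(func, xb)  # sequential loop over batches
        # Merge the batch axis back and drop the padding rows.
        return yb.reshape(nbatches * batch_size, *yb.shape[2:])[:n]

    return batched_func


f = autobatch_sketch(lambda x: jnp.sin(x) * 2, max_io_nbytes=16)
x = jnp.arange(7.0)
print(jnp.allclose(f(x), jnp.sin(x) * 2))  # True
```

Padding keeps every batch the same shape, so `func` is traced once; the cost is that `func` also runs on a few zero-filled rows, which must not raise errors.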

class bartz.jaxext.LeafDict

Dictionary that acts as a leaf in JAX pytrees, used to store compile-time values.
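The underlying mechanism can be sketched with a bare `dict` subclass. JAX does not automatically flatten subclasses of `dict` that are not registered as pytree nodes, so such an object is treated as a single leaf. `LeafDictSketch` is hypothetical, and this sketch does not show whatever else bartz's `LeafDict` may add.

```python
import jax


class LeafDictSketch(dict):
    """Hypothetical stand-in: a dict subclass not registered as a pytree node."""


tree = {"params": {"a": 1, "b": 2}, "config": LeafDictSketch(depth=3, ntree=10)}
leaves = jax.tree_util.tree_leaves(tree)
# The plain dict is flattened into its values; the subclass stays whole.
print(len(leaves))  # 3
```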