JAX extensions
- bartz.jaxext.float_type(*args)
Determine the JAX floating-point result type given operands/types.
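One way to get this behavior, sketched here on the assumption that JAX's own promotion rules are the reference, is to combine the operands with `jnp.result_type` and then apply a floating-point ufunc to an empty array of that type. `float_type_sketch` is a hypothetical name; this is not the bartz source.

```python
import jax.numpy as jnp


def float_type_sketch(*args):
    """Hypothetical sketch: the floating dtype JAX would produce for args."""
    # Combine the operands/dtypes with JAX's own type-promotion rules.
    t = jnp.result_type(*args)
    # Apply a floating-point ufunc to an empty array of that type and read off
    # the result dtype: integer and bool types are promoted to a default float,
    # while floating types keep their precision.
    return jnp.sin(jnp.empty(0, t)).dtype
```

With JAX's default 32-bit settings, an `int32` input yields `float32`, while a `float16` input stays `float16`.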
- bartz.jaxext.pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs)
Version of jax.pure_callback that deals correctly with ufuncs; see https://github.com/google/jax/issues/17187.
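A minimal sketch of the shape bookkeeping involved, assuming the difficulty is one of rank: NumPy ufuncs broadcast by aligning axes from the right, while a batch axis is added on the left, so arguments of different rank can end up misaligned. The sketch declares the broadcast result shape with `jax.ShapeDtypeStruct` and pads every argument to a common rank. `pure_callback_ufunc_sketch` is hypothetical and is not bartz's implementation: it ignores `excluded` and passes no vmap-related options (their names vary across JAX versions), so it only illustrates use under `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pure_callback_ufunc_sketch(callback, dtype, *args):
    """Hypothetical helper: call a NumPy ufunc from traced JAX code."""
    # The ufunc broadcasts its inputs, so the declared result shape must be
    # the broadcast shape of all arguments.
    shape = np.broadcast_shapes(*(jnp.shape(a) for a in args))
    ndim = len(shape)
    # Prepend length-1 axes so every argument has the same rank. Right-aligned
    # broadcasting makes this a no-op for the result, but any batch axis added
    # in front of each argument now lines up across all of them.
    padded = [
        jnp.reshape(a, (1,) * (ndim - jnp.ndim(a)) + jnp.shape(a)) for a in args
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)
    return jax.pure_callback(callback, result, *padded)
```

For example, `jax.jit(lambda a, b: pure_callback_ufunc_sketch(np.add, jnp.float32, a, b))` applied to arrays of shapes `(3,)` and `(2, 3)` returns a `(2, 3)` array.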
- bartz.jaxext.huge_value(x)
Return the maximum value that can be stored in x.
- Parameters:
- x : array
A numerical numpy or jax array.
- Returns:
- maxval : scalar
The maximum value allowed by x’s type (+inf for floats).
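A NumPy-only sketch of the documented behavior. `huge_value_sketch` is a hypothetical name; it relies on jax arrays exposing NumPy dtypes and, as an assumption, handles only integer and real floating types.

```python
import numpy as np


def huge_value_sketch(x):
    """Hypothetical sketch: largest value representable in x's dtype."""
    # jax arrays expose a NumPy dtype, so the same logic covers both libraries.
    dtype = np.dtype(x.dtype)
    if np.issubdtype(dtype, np.floating):
        # Floating types can hold +inf, which exceeds every finite value.
        return dtype.type(np.inf)
    # Integer types: the upper end of the representable range.
    return dtype.type(np.iinfo(dtype).max)
```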
- bartz.jaxext.minimal_unsigned_dtype(max_value)
Return the smallest unsigned integer dtype that can represent a given maximum value (inclusive).
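A sketch of the idea under the assumption that the candidates are the four standard NumPy unsigned types; `minimal_unsigned_dtype_sketch` is a hypothetical name.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(max_value):
    """Hypothetical sketch: smallest unsigned dtype that can hold max_value."""
    # Try the unsigned types from narrowest to widest; the bound is inclusive,
    # so a type qualifies if its maximum is at least max_value.
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise ValueError(f'{max_value} does not fit in 64 bits')
```

For instance, 255 maps to `uint8` and 256 to `uint16`. Note that in JAX, `uint64` arrays are only available with 64-bit mode enabled.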
- bartz.jaxext.signed_to_unsigned(int_dtype)
Map a signed integer type to its unsigned counterpart. Unsigned types are passed through.
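A sketch of this mapping using NumPy type codes (`'u4'` is `uint32`, and so on). `signed_to_unsigned_sketch` is hypothetical, and raising `TypeError` for non-integer inputs is an assumption.

```python
import numpy as np


def signed_to_unsigned_sketch(int_dtype):
    """Hypothetical sketch: unsigned counterpart of an integer dtype."""
    dtype = np.dtype(int_dtype)
    if np.issubdtype(dtype, np.unsignedinteger):
        # Already unsigned: pass through unchanged.
        return dtype
    if np.issubdtype(dtype, np.signedinteger):
        # Same width, unsigned kind: 'u' followed by the size in bytes.
        return np.dtype(f'u{dtype.itemsize}')
    raise TypeError(f'not an integer dtype: {dtype}')
```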
- bartz.jaxext.ensure_unsigned(x)
If x has a signed integer type, cast it to the unsigned dtype of the same size.
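A sketch of the cast, shown with NumPy arrays; `ensure_unsigned_sketch` is a hypothetical name. Returning non-signed inputs unchanged is an assumption read from the wording above.

```python
import numpy as np


def ensure_unsigned_sketch(x):
    """Hypothetical sketch: cast signed integer arrays to same-width unsigned."""
    dtype = np.dtype(x.dtype)
    if np.issubdtype(dtype, np.signedinteger):
        # Same width, unsigned kind. In NumPy, integer casts wrap modulo
        # 2**bits, so negative values map to large ones (int8 -1 -> 255).
        return x.astype(np.dtype(f'u{dtype.itemsize}'))
    # Other dtypes are returned unchanged.
    return x
```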
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
- x : 1d array
The input array.
- size : int
The length of the output.
- fill_value : scalar
The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
- out : array (size,)
The unique values in x, sorted, and right-padded with fill_value.
- actual_length : int
The number of used values in out.
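A sketch of one sort-based way to produce a fixed-size result with this signature. It is an assumption, not bartz's actual algorithm: sort, mark the first occurrence of each value, and scatter those into a `size`-length buffer prefilled with `fill_value`. `unique_sketch` is hypothetical; `size` must be static under `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def unique_sketch(x, size, fill_value):
    """Hypothetical fixed-size unique: sort, mark first occurrences, scatter."""
    x = jnp.sort(x)
    # True where an element differs from its predecessor, i.e. at the first
    # occurrence of each distinct value (the first element is always new).
    is_new = jnp.ones_like(x, dtype=bool).at[1:].set(x[1:] != x[:-1])
    counts = is_new.astype(jnp.int32)
    # Output slot for each first occurrence; repeats get index `size`, which
    # is out of bounds and therefore dropped by the scatter below.
    idx = jnp.where(is_new, jnp.cumsum(counts) - 1, size)
    out = jnp.full(size, fill_value, x.dtype)
    out = out.at[idx].set(x, mode='drop')
    # Distinct values beyond `size` are dropped too, so cap the count.
    actual_length = jnp.minimum(jnp.sum(counts), size)
    return out, actual_length
```

For example, `unique_sketch(jnp.array([3, 1, 3, 2, 1]), 5, -1)` gives `out = [1, 2, 3, -1, -1]` and `actual_length = 3`.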
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
Batch a function such that each batch is smaller than a threshold (max_io_nbytes).
- Parameters:
- func : callable
A jittable function with positional arguments only, whose inputs and outputs are pytrees of arrays.
- max_io_nbytes : int
The maximum number of input + output bytes in each batch (excluding unbatched arguments).
- in_axes : pytree of int or None, default 0
A tree matching the structure of the function input, indicating along which axes each array should be batched. If a single integer, it is used for all arrays. An axis of None means that the corresponding argument is not batched.
- out_axes : pytree of ints, default 0
The same for outputs (but non-batching is not allowed).
- return_nbatches : bool, default False
If True, the number of batches is returned as a second output.
- Returns:
- batched_func : callable
A function with the same signature as func, but that processes the input and output in batches in a loop.
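A much simplified sketch of the batching logic, assuming a single array argument and a single array output, both batched along axis 0 (no pytrees, no None axes), and a `func` that acts row by row so zero-padded rows can be computed and then discarded. It uses the real `jax.eval_shape` and `jax.lax.map`; `autobatch_sketch` is hypothetical and is not bartz's implementation.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes, return_nbatches=False):
    """Hypothetical, simplified autobatch: one array in, one out, axis 0."""

    def batched_func(x):
        n = x.shape[0]
        # Output shape/dtype from abstract evaluation; func is not run here.
        out = jax.eval_shape(func, x)
        # Input plus output bytes attributable to one row along axis 0.
        row_nbytes = (
            math.prod(x.shape[1:]) * x.dtype.itemsize
            + math.prod(out.shape[1:]) * out.dtype.itemsize
        )
        # Largest batch within the budget (at least one row), then the count.
        batch_size = max(1, min(n, max_io_nbytes // row_nbytes))
        nbatches = -(-n // batch_size)  # ceiling division
        # Zero-pad so the rows split evenly into nbatches batches.
        pad = nbatches * batch_size - n
        x = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        xs = x.reshape((nbatches, batch_size) + x.shape[1:])
        # lax.map applies func to one batch at a time in a sequential loop.
        ys = jax.lax.map(func, xs)
        # Merge the batch axes back and drop the padded rows.
        ys = ys.reshape((nbatches * batch_size,) + ys.shape[2:])[:n]
        return (ys, nbatches) if return_nbatches else ys

    return batched_func
```

With `float32` rows of one element in and one out (8 bytes per row) and `max_io_nbytes=24`, a 10-element input is split into batches of 3, so 4 batches. In this sketch the full stacked input and output still exist; what the loop bounds is the input/output size of each call to `func`, and with it the memory for `func`'s intermediates.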