JAX extensions
Additions to jax.
- bartz.jaxext.float_type(*args)
Determine the jax floating point result type given operands/types.
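A minimal sketch of what "floating point result type" can mean. It assumes the result matches the dtype JAX itself picks when a floating-point operation such as `jnp.sin` is applied to the promoted operands. `float_type_sketch` is a hypothetical name, not bartz's implementation.

```python
import jax.numpy as jnp


def float_type_sketch(*args):
    """Hypothetical stand-in for float_type: not bartz's actual code."""
    # Ordinary JAX promotion of all operands/types to a common dtype.
    common = jnp.result_type(*args)
    # Apply a floating-point operation to a dummy scalar of that dtype and
    # read back the dtype JAX produces: integer and boolean inputs are
    # promoted to the default float, floating inputs keep their precision.
    return jnp.sin(jnp.zeros((), common)).dtype
```

With 64-bit mode disabled (the JAX default), `float_type_sketch(jnp.int32)` is float32, while `float_type_sketch(jnp.float16)` stays float16.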
- class bartz.jaxext.scipy
Mockup of the scipy module.
- class special
Mockup of the scipy.special module.
- bartz.jaxext.vmap_nodoc(fun, *args, **kw)
Acts like jax.vmap but leaves the docstring of the function unchanged. This is useful if the docstring already accounts for the additional axes that vmap adds to the arguments.
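A minimal sketch of the idea, assuming it is enough to save the docstring before calling `jax.vmap` and put it back on the returned function. `vmap_nodoc_sketch` and `squared_norm` are hypothetical names; the docstring of `squared_norm` is written for the batched call, which is the situation described above.

```python
import jax
import jax.numpy as jnp


def vmap_nodoc_sketch(fun, *args, **kw):
    """Hypothetical stand-in for vmap_nodoc: not bartz's actual code."""
    doc = fun.__doc__
    vfun = jax.vmap(fun, *args, **kw)  # jax.vmap replaces the docstring with its own wording
    vfun.__doc__ = doc  # restore the original docstring verbatim
    return vfun


def squared_norm(x):
    """Squared norm of each row of x, shape (n, d) -> (n,)."""
    return jnp.sum(x ** 2)  # called per row by vmap


batched = vmap_nodoc_sketch(squared_norm)
```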
- bartz.jaxext.huge_value(x)
Return the maximum value that can be stored in x.
- Parameters:
x (array) – A numerical numpy or jax array.
- Returns:
maxval (scalar) – The maximum value allowed by x’s type (+inf for floats).
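A minimal sketch of the documented behaviour, assuming "floats" means real floating dtypes and every other numerical dtype is an integer type. `huge_value_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def huge_value_sketch(x):
    """Hypothetical stand-in for huge_value: not bartz's actual code."""
    if jnp.issubdtype(x.dtype, jnp.floating):
        return jnp.inf  # floating types can hold +inf
    return jnp.iinfo(x.dtype).max  # largest representable integer
```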
- bartz.jaxext.minimal_unsigned_dtype(value)
Return the smallest unsigned integer dtype that can represent value.
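A minimal sketch that tries the unsigned widths from smallest to largest. `minimal_unsigned_dtype_sketch` is a hypothetical name, and the errors raised for negative or oversized values are assumptions, since the description above does not say what happens in those cases.

```python
import numpy as np


def minimal_unsigned_dtype_sketch(value):
    """Hypothetical stand-in for minimal_unsigned_dtype: not bartz's actual code."""
    if value < 0:
        raise ValueError('unsigned dtypes cannot represent negative values')
    # Return the first (smallest) unsigned dtype whose maximum covers value.
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if value <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise OverflowError(f'{value} does not fit in 64 bits')
```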
- bartz.jaxext.signed_to_unsigned(int_dtype)
Map a signed integer type to its unsigned counterpart.
Unsigned types are passed through.
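A minimal sketch of the mapping, assuming the input is anything `np.dtype` accepts and that non-integer input is an error. `signed_to_unsigned_sketch` is a hypothetical name.

```python
import numpy as np


def signed_to_unsigned_sketch(int_dtype):
    """Hypothetical stand-in for signed_to_unsigned: not bartz's actual code."""
    dtype = np.dtype(int_dtype)
    if dtype.kind == 'u':  # already unsigned: pass through
        return dtype
    if dtype.kind == 'i':  # signed: unsigned type with the same bit width
        return np.dtype(f'uint{8 * dtype.itemsize}')
    raise TypeError(f'{dtype} is not an integer dtype')
```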
- bartz.jaxext.ensure_unsigned(x)
If x has a signed integer type, cast it to the unsigned integer dtype of the same size.
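A minimal sketch, assuming arrays of any other dtype are returned unchanged. `ensure_unsigned_sketch` is a hypothetical name.

```python
import jax.numpy as jnp
import numpy as np


def ensure_unsigned_sketch(x):
    """Hypothetical stand-in for ensure_unsigned: not bartz's actual code."""
    if jnp.issubdtype(x.dtype, jnp.signedinteger):
        # Same number of bits, unsigned kind, e.g. int32 -> uint32.
        return x.astype(np.dtype(f'uint{8 * x.dtype.itemsize}'))
    return x  # unsigned, floating, etc. are left unchanged
```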
- bartz.jaxext.unique(x, size, fill_value)
Restricted version of jax.numpy.unique that uses less memory.
- Parameters:
x (1d array) – The input array.
size (int) – The length of the output.
fill_value (scalar) – The value to fill the output with if size is greater than the number of unique values in x.
- Returns:
out (array (size,)) – The unique values in x, sorted, and right-padded with fill_value.
actual_length (int) – The number of used values in out.
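One way to meet this contract with a single sort and a single scatter into a fixed-length buffer is sketched below. It is an illustration under that assumption, not necessarily how bartz saves memory. `unique_sketch` is a hypothetical name. Because `size` must be a Python int, the function can be wrapped in `jax.jit` with `size` marked static.

```python
import jax.numpy as jnp


def unique_sketch(x, size, fill_value):
    """Hypothetical stand-in for unique: not bartz's actual code."""
    x = jnp.sort(x)
    # True at the first occurrence of each distinct value in the sorted array.
    is_new = jnp.ones(x.shape, bool).at[1:].set(x[1:] != x[:-1])
    # Output slot of each distinct value; repeats are sent to slot `size`,
    # which is out of bounds and therefore dropped by the scatter below.
    slot = jnp.where(is_new, jnp.cumsum(is_new) - 1, size)
    out = jnp.full(size, fill_value, x.dtype)
    out = out.at[slot].set(x, mode='drop')  # also drops distinct values beyond `size`
    actual_length = jnp.minimum(jnp.sum(is_new), size)
    return out, actual_length
```

For example, `unique_sketch(jnp.array([3, 1, 3, 2, 1]), 5, -1)` gives `out == [1, 2, 3, -1, -1]` and `actual_length == 3`.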
- bartz.jaxext.autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False)
Batch a function so that the input and output of each batch stay below a size threshold.
- Parameters:
func (callable) – A jittable function that takes positional arguments only and whose inputs and outputs are pytrees of arrays.
max_io_nbytes (int) – The maximum number of input plus output bytes in each batch (excluding unbatched arguments).
in_axes (pytree of int or None, default 0) – A tree matching the structure of the function input, indicating along which axis each array should be batched. If a single integer, it is used for all arrays. A None axis indicates that an argument is not batched.
out_axes (pytree of int, default 0) – The same for the outputs (but non-batching is not allowed).
return_nbatches (bool, default False) – If True, the number of batches is returned as a second output.
- Returns:
batched_func (callable) – A function with the same signature as func, but that processes the input and output in batches in a loop.
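The batching rule can be sketched in a deliberately simplified setting: one array argument and one array output, both batched along axis 0 (no `in_axes`/`out_axes` pytrees). The sketch also assumes that `func` treats rows independently, so zero-padding the last batch and discarding the padded outputs is harmless. `autobatch_sketch` is a hypothetical name, and the exact batch-size rule is an assumption.

```python
import math

import jax
import jax.numpy as jnp


def autobatch_sketch(func, max_io_nbytes, return_nbatches=False):
    """Hypothetical simplification of autobatch: one array in, one array out,
    both batched along axis 0. Not bartz's actual code."""

    def batched_func(x):
        n = x.shape[0]  # assumes n > 0
        out = jax.eval_shape(func, x)  # output shape/dtype, computed abstractly
        out_nbytes = math.prod(out.shape) * out.dtype.itemsize
        bytes_per_row = (x.nbytes + out_nbytes) / n  # input + output bytes per row
        batch_size = max(1, int(max_io_nbytes // bytes_per_row))
        nbatches = -(-n // batch_size)  # ceiling division
        # Pad axis 0 so that the rows split evenly into batches.
        pad = nbatches * batch_size - n
        xp = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        xb = xp.reshape((nbatches, batch_size) + x.shape[1:])
        yb = jax.lax.map(func, xb)  # sequential loop, one batch at a time
        y = yb.reshape((nbatches * batch_size,) + yb.shape[2:])[:n]  # drop padding
        return (y, nbatches) if return_nbatches else y

    return batched_func
```

With a float32 input of 10 elements and an elementwise `func`, each row costs 8 bytes (4 in, 4 out), so `max_io_nbytes=24` gives batches of 3 rows and 4 batches in total.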
- class bartz.jaxext.split(key, num=2)
Split a key into num keys.
- Parameters:
key (jax.dtypes.prng_key array) – The key to split.
num (int) – The number of keys to split into.
- pop(shape=None)
Pop one or more keys from the list.
- Parameters:
shape (int or tuple of int, optional) – The shape of the keys to pop. If None, a single key is popped. If an integer, that many keys are popped. If a tuple, the keys are reshaped to that shape.
- Returns:
keys (jax.dtypes.prng_key array) – The popped keys.
- Raises:
IndexError – If shape requires more keys than are left in the list.
Notes
The keys are popped from the beginning of the list, so, for example, list(keys.pop(2)) is equivalent to [keys.pop(), keys.pop()].
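A minimal sketch of the key-list behaviour described above. It assumes new-style typed keys created with `jax.random.key`. `KeyListSketch` is a hypothetical name. The sketch also assumes that a tuple `shape` pops as many keys as the product of its entries. Taking keys from the front of the list is what makes `list(keys.pop(2))` match `[keys.pop(), keys.pop()]`.

```python
import math

import jax


class KeyListSketch:
    """Hypothetical stand-in for bartz.jaxext.split: not bartz's actual code."""

    def __init__(self, key, num=2):
        self._keys = jax.random.split(key, num)  # typed key array of shape (num,)

    def __len__(self):
        return self._keys.shape[0]

    def pop(self, shape=None):
        if shape is None:
            count = 1
        else:
            shape = (shape,) if isinstance(shape, int) else tuple(shape)
            count = math.prod(shape)
        if count > len(self):
            raise IndexError(f'requested {count} keys, only {len(self)} left')
        # Take keys from the front of the list and keep the rest.
        popped, self._keys = self._keys[:count], self._keys[count:]
        return popped[0] if shape is None else popped.reshape(shape)
```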