Coverage for src/bartz/jaxext.py: 90%
203 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-05 18:54 +0000
1# bartz/src/bartz/jaxext.py
2#
3# Copyright (c) 2024, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25import functools 1a
26import math 1a
27import warnings 1a
29from scipy import special 1a
30import jax 1a
31from jax import numpy as jnp 1a
32from jax import tree_util 1a
33from jax import lax 1a
def float_type(*args):
    """
    Return the floating point dtype jax produces for the given operands/types.

    The operands are first combined with `jnp.result_type`; the result is then
    pushed through a floating point operation (`jnp.sin`) on an empty array,
    and the dtype of that output is returned.
    """
    common = jnp.result_type(*args)
    # zero-length probe: only the dtype of the output matters
    probe = jnp.empty(0, common)
    return jnp.sin(probe).dtype
def castto(func, type):
    """
    Wrap `func` so that its return value is converted with `.astype(type)`.

    The wrapper forwards all positional and keyword arguments unchanged and
    carries over the metadata of `func` through `functools.wraps`.
    """
    @functools.wraps(func)
    def wrapper(*args, **kw):
        raw = func(*args, **kw)
        return raw.astype(type)

    return wrapper
def pure_callback_ufunc(callback, dtype, *args, excluded=None, **kwargs):
    """
    Version of `jax.pure_callback` that deals correctly with ufuncs.

    See `<https://github.com/google/jax/issues/17187>`_.

    Parameters
    ----------
    callback : callable
        The function handed to `jax.pure_callback`; it receives the (padded)
        positional arguments.
    dtype : dtype
        The dtype declared for the result of the callback.
    *args : arrays
        The arguments of the callback. Those not listed in `excluded` are
        broadcast against each other to determine the output shape, and are
        left-padded with singleton axes up to the number of dimensions of the
        broadcast shape before being passed on.
    excluded : collection of int, optional
        Positions in `args` of the arguments that do not take part in
        broadcasting and are passed through as they are. Default none.
    **kwargs
        Additional keyword arguments forwarded to `jax.pure_callback`.

    Returns
    -------
    result : array
        The output of `jax.pure_callback`, declared with the broadcast shape
        and `dtype`.
    """
    # local import: only needed to probe the signature of jax.pure_callback
    import inspect

    if excluded is None:
        excluded = ()
    shape = jnp.broadcast_shapes(*(
        a.shape
        for i, a in enumerate(args)
        if i not in excluded
    ))
    ndim = len(shape)
    padded_args = [
        a if i in excluded
        else jnp.expand_dims(a, tuple(range(ndim - a.ndim)))
        for i, a in enumerate(args)
    ]
    result = jax.ShapeDtypeStruct(shape, dtype)

    # TODO when jax solves this, check version and piggyback on original if new

    # Newer jax versions replace the boolean `vectorized` flag with
    # `vmap_method`; 'legacy_vectorized' is the documented equivalent of
    # `vectorized=True`. Pick whichever spelling the installed jax understands.
    if 'vmap_method' in inspect.signature(jax.pure_callback).parameters:
        kwargs.setdefault('vmap_method', 'legacy_vectorized')
        return jax.pure_callback(callback, result, *padded_args, **kwargs)
    return jax.pure_callback(callback, result, *padded_args, vectorized=True, **kwargs)
class scipy:
    # Namespace laid out like the `scipy` package, holding jax-callable
    # counterparts of a few scipy functions. The nested classes are used as
    # plain namespaces: their functions take no `self` and are called through
    # the class, e.g. `scipy.special.gammainccinv(a, y)`.

    class special:

        @functools.wraps(special.gammainccinv)
        def gammainccinv(a, y):
            # here `special` is the module imported at file level, not this class
            a, y = jnp.asarray(a), jnp.asarray(y)
            out_dtype = float_type(a.dtype, y.dtype)
            # cast the output of the scipy routine to the dtype declared to jax
            host_func = castto(special.gammainccinv, out_dtype)
            return pure_callback_ufunc(host_func, out_dtype, a, y)

    class stats:

        class invgamma:

            def ppf(q, a):
                # reciprocal of the inverse regularized upper incomplete gamma
                # function evaluated at (a, q)
                inverse = scipy.special.gammainccinv(a, q)
                return 1 / inverse
@functools.wraps(jax.vmap)
def vmap_nodoc(fun, *args, **kw):
    """
    Apply `jax.vmap` to `fun`, keeping the docstring `fun` had originally.

    All arguments after `fun` are forwarded to `jax.vmap`.
    """
    # note: `functools.wraps` above replaces this docstring with jax.vmap's
    original_doc = fun.__doc__
    vectorized = jax.vmap(fun, *args, **kw)
    vectorized.__doc__ = original_doc
    return vectorized
def huge_value(x):
    """
    Return the largest value representable by the type of `x`.

    Parameters
    ----------
    x : array
        A numerical numpy or jax array.

    Returns
    -------
    maxval : scalar
        For integer dtypes, the maximum reported by `jnp.iinfo`; for any
        other dtype, positive infinity.
    """
    if not jnp.issubdtype(x.dtype, jnp.integer):
        return jnp.inf
    return jnp.iinfo(x.dtype).max
def minimal_unsigned_dtype(max_value):
    """
    Pick the narrowest unsigned integer dtype able to hold `max_value`.

    The candidates are tried from 8 to 32 bits, returning the first one whose
    range includes `max_value`; anything larger gets `jnp.uint64`.
    """
    candidates = (
        (8, jnp.uint8),
        (16, jnp.uint16),
        (32, jnp.uint32),
    )
    for nbits, dtype in candidates:
        if max_value < 2 ** nbits:
            return dtype
    return jnp.uint64
def signed_to_unsigned(int_dtype):
    """
    Map a signed integer type to its unsigned counterpart. Unsigned types are
    passed through.

    Parameters
    ----------
    int_dtype : dtype
        An integer dtype (signed or unsigned).

    Returns
    -------
    unsigned_dtype : dtype
        `int_dtype` itself if it is unsigned, otherwise the unsigned jax type
        with the same number of bits (8, 16, 32 or 64).

    Raises
    ------
    AssertionError
        If `int_dtype` is not an integer dtype.
    TypeError
        If `int_dtype` is a signed integer dtype without a known unsigned
        counterpart.
    """
    # kept as an assertion so that callers see the same exception as before
    assert jnp.issubdtype(int_dtype, jnp.integer)
    if jnp.issubdtype(int_dtype, jnp.unsignedinteger):
        return int_dtype
    counterparts = (
        (jnp.int8, jnp.uint8),
        (jnp.int16, jnp.uint16),
        (jnp.int32, jnp.uint32),
        (jnp.int64, jnp.uint64),
    )
    for signed, unsigned in counterparts:
        if int_dtype == signed:
            return unsigned
    # fail loudly instead of implicitly returning None for unlisted widths
    raise TypeError(f'no unsigned counterpart known for dtype {int_dtype!r}')
def ensure_unsigned(x):
    """
    Cast `x` to the dtype that `signed_to_unsigned` returns for `x.dtype`,
    i.e. the unsigned integer dtype of the same size for signed integers.
    """
    target = signed_to_unsigned(x.dtype)
    return x.astype(target)
@functools.partial(jax.jit, static_argnums=(1,))
def unique(x, size, fill_value):
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x : 1d array
        The input array.
    size : int
        The length of the output.
    fill_value : scalar
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : array (size,)
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output position only when a new value shows up
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # positions beyond the end of `out` are discarded instead of written
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan the whole sorted input: stopping at the first `size` elements would
    # miss unique values whenever that prefix contains duplicates
    (last_index, _, out), _ = jax.lax.scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
def autobatch(func, max_io_nbytes, in_axes=0, out_axes=0, return_nbatches=False):
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func : callable
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes : int
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes : pytree of int or None, default 0
        A tree matching the structure of the function input, indicating along
        which axes each array should be batched. If a single integer, it is
        used for all arrays. A `None` axis indicates to not batch an argument.
    out_axes : pytree of ints, default 0
        The same for outputs (but non-batching is not allowed).
    return_nbatches : bool, default False
        If True, the number of batches is returned as a second output.

    Returns
    -------
    batched_func : callable
        A function with the same signature as `func`, but that processes the
        input and output in batches in a loop. The function is jitted.

    Notes
    -----
    The number of batches is the smallest divisor of the size of the batched
    axis that brings each batch within `max_io_nbytes`. If even one element
    per batch exceeds the threshold, a warning is emitted and one batch per
    element is used.
    """

    def expand_axes(axes, tree):
        # a single integer is shorthand for "this axis for every leaf"
        if isinstance(axes, int):
            return tree_util.tree_map(lambda _: axes, tree)
        return tree_util.tree_map(lambda _, axis: axis, tree, axes)

    def check_no_nones(axes, tree):
        def check_not_none(_, axis):
            assert axis is not None
        tree_util.tree_map(check_not_none, tree, axes)

    def extract_size(axes, tree):
        # common length of all the batched axes
        def get_size(x, axis):
            if axis is None:
                return None
            else:
                return x.shape[axis]
        sizes = tree_util.tree_map(get_size, tree, axes)
        # flattening drops the None entries of the non-batched arguments
        sizes, _ = tree_util.tree_flatten(sizes)
        assert all(s == sizes[0] for s in sizes)
        return sizes[0]

    def sum_nbytes(tree):
        def nbytes(x):
            return math.prod(x.shape) * x.dtype.itemsize
        return tree_util.tree_reduce(lambda size, x: size + nbytes(x), tree, 0)

    def next_divisor(dividend, min_divisor):
        # smallest divisor of `dividend` that is >= `min_divisor`; falls back
        # to `dividend` itself if `min_divisor` is larger than `dividend`
        if dividend == 0:
            return min_divisor
        root = math.isqrt(dividend)
        # candidates up to the square root, in increasing order
        for divisor in range(min_divisor, root + 1):
            if dividend % divisor == 0:
                return divisor
        # candidates above the square root, found through their co-divisor
        # (decreasing co-divisor = increasing divisor); capping the co-divisor
        # at dividend // min_divisor keeps the divisor >= min_divisor
        for codivisor in range(min(root, dividend // min_divisor), 0, -1):
            if dividend % codivisor == 0:
                return dividend // codivisor
        return dividend

    def pull_nonbatched(axes, tree):
        # replace non-batched leaves with None, keeping the original tree aside
        def pull_nonbatched(x, axis):
            if axis is None:
                return None
            else:
                return x
        return tree_util.tree_map(pull_nonbatched, tree, axes), tree

    def push_nonbatched(axes, tree, original_tree):
        # inverse of pull_nonbatched: put the non-batched leaves back in place
        def push_nonbatched(original_x, x, axis):
            if axis is None:
                return original_x
            else:
                return x
        return tree_util.tree_map(push_nonbatched, original_tree, tree, axes)

    def move_axes_out(axes, tree):
        # bring the batched axis of each leaf to the front
        def move_axis_out(x, axis):
            return jnp.moveaxis(x, axis, 0)
        return tree_util.tree_map(move_axis_out, tree, axes)

    def move_axes_in(axes, tree):
        # move the leading axis of each leaf back to its designated position
        def move_axis_in(x, axis):
            return jnp.moveaxis(x, 0, axis)
        return tree_util.tree_map(move_axis_in, tree, axes)

    def batch(tree, nbatches):
        # split the leading axis into (nbatches, elements per batch)
        def batch(x):
            return x.reshape((nbatches, x.shape[0] // nbatches) + x.shape[1:])
        return tree_util.tree_map(batch, tree)

    def unbatch(tree):
        # merge the two leading axes back into one
        def unbatch(x):
            return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return tree_util.tree_map(unbatch, tree)

    def check_same(tree1, tree2):
        def check_same(x1, x2):
            assert x1.shape == x2.shape
            assert x1.dtype == x2.dtype
        tree_util.tree_map(check_same, tree1, tree2)

    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jax.jit
    @functools.wraps(func)
    def batched_func(*args):
        # shapes and dtypes of the output, without running the function
        example_result = jax.eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        size = extract_size((in_axes, out_axes), (args, example_result))

        args, nonbatched_args = pull_nonbatched(in_axes, args)

        # only batched inputs and the outputs count toward the threshold
        total_nbytes = sum_nbytes((args, example_result))
        # ceil division
        min_nbatches = total_nbytes // max_io_nbytes + bool(total_nbytes % max_io_nbytes)
        min_nbatches = max(1, min_nbatches)
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # can happen only when already at one element per batch
            assert size == nbatches
            warnings.warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def loop(_, args):
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = lax.scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func
@tree_util.register_pytree_node_class
class LeafDict(dict):
    """ dictionary that acts as a leaf in jax pytrees, to store compile-time
    values """

    def tree_flatten(self):
        # no children: the whole dictionary travels as auxiliary data
        children = ()
        aux_data = self
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # the auxiliary data is the dictionary itself; `children` is unused
        return aux_data

    def __repr__(self):
        contents = super().__repr__()
        return f'{__class__.__name__}({contents})'