Coverage for src/bartz/jaxext/__init__.py: 92%
69 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-06-27 14:46 +0000
1# bartz/src/bartz/jaxext/__init__.py
2#
3# Copyright (c) 2024-2025, Giacomo Petrillo
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Additions to jax."""
27import functools 1ab
28import math 1ab
29from collections.abc import Sequence 1ab
31import jax 1ab
32from jax import numpy as jnp 1ab
33from jax import random 1ab
34from jax.lax import scan 1ab
35from jax.scipy.special import ndtr 1ab
36from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped 1ab
38from bartz.jaxext._autobatch import autobatch # noqa: F401 1ab
39from bartz.jaxext.scipy.special import ndtri 1ab
def vmap_nodoc(fun, *args, **kw):
    """
    Vectorize `fun` with `jax.vmap`, keeping its original docstring.

    `jax.vmap` replaces the docstring of the function it wraps; this helper
    restores the original one. Use it when the docstring was already written
    with the extra vmapped axes in mind.

    Parameters
    ----------
    fun
        The function to vectorize.
    *args, **kw
        Forwarded unchanged to `jax.vmap`.

    Returns
    -------
    The vectorized function, whose ``__doc__`` is that of `fun`.
    """
    original_doc = fun.__doc__
    vectorized = jax.vmap(fun, *args, **kw)
    # overwrite the docstring that jax.vmap generated
    vectorized.__doc__ = original_doc
    return vectorized
def minimal_unsigned_dtype(value):
    """
    Pick the narrowest unsigned integer dtype whose range covers `value`.

    The candidate widths are tried from narrowest (8 bits) to 32 bits; a value
    that is not below ``2**32`` gets `jnp.uint64`.
    """
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for n_bits, dtype in candidates:
        if value < 2**n_bits:
            return dtype
    return jnp.uint64
@functools.partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If `x` has more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out` (at most `size`).
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, value):
        i_out, last, out = carry
        # advance the output cursor only when a new distinct value shows up
        i_out = jnp.where(value == last, i_out, i_out + 1)
        # after `size` distinct values, i_out >= size: drop the write instead
        # of relying on default out-of-bounds scatter semantics
        out = out.at[i_out].set(value, mode='drop')
        return (i_out, value, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan the whole sorted array: scanning only x[:size] would miss distinct
    # values that fit in the output when duplicates occur early
    (last_index, _, out), _ = scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.

    Notes
    -----
    Keys are counted along the leading axis of the split result. Both typed
    keys and legacy ``uint32`` keys (which carry a trailing key-data axis)
    are therefore counted correctly.
    """

    def __init__(self, key: Key[Array, ''], num: int = 2):
        self._keys = random.split(key, num)

    def __len__(self):
        # number of keys left; for legacy uint32 keys `.size` would also count
        # the trailing key-data words, so use the leading axis
        return self._keys.shape[0]

    def pop(self, shape: int | tuple[int, ...] | None = None) -> Key[Array, '*']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If `None`, a single key is popped.
            If an integer, that many keys are popped. If a tuple, the keys are
            reshaped to that shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If `shape` is larger than the number of keys left in the list.

        Notes
        -----
        The keys are popped from the beginning of the list, so for example
        ``list(keys.pop(2))`` is equivalent to ``[keys.pop(), keys.pop()]``.
        """
        if shape is None:
            shape = ()
        elif not isinstance(shape, tuple):
            shape = (shape,)
        size_to_pop = math.prod(shape)
        available = len(self)
        if size_to_pop > available:
            msg = f'Cannot pop {size_to_pop} keys from {available} keys'
            raise IndexError(msg)
        popped_keys = self._keys[:size_to_pop]
        self._keys = self._keys[size_to_pop:]
        # preserve any trailing key-data axis (empty for typed keys)
        return popped_keys.reshape(shape + popped_keys.shape[1:])
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
) -> Float32[Array, '*']:
    """
    Draw samples from a standard normal truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Inverse-CDF sampling, case by case (u ~ uniform[0, 1)):
    #   upper,     bound <= 0:  ndtri(ndtr(bound) * (1 - u))
    #   upper,     bound >  0: -ndtri(ndtr(-bound) + ndtr(bound) * u)
    #   not upper, bound <= 0:  ndtri(ndtr(bound) + ndtr(-bound) * u)
    #   not upper, bound >  0: -ndtri(ndtr(-bound) * (1 - u))
    # When bound > 0 the problem is mirrored through zero and the sign of the
    # result is flipped at the end.
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    flip_sign = bound > 0
    cdf_bound = ndtr(bound)
    cdf_neg_bound = ndtr(-bound)
    width = jnp.where(upper, cdf_bound, cdf_neg_bound)
    offset = jnp.where(upper, cdf_neg_bound, cdf_bound)
    u = random.uniform(key, out_shape)

    # (1 - u) is in (0, 1], so the low-side interval excludes 0
    low_side = width * (1 - u)  # in (0, ndtr(±bound)]
    high_side = offset + width * u  # in [ndtr(∓bound), 1)
    use_low_side = upper ^ flip_sign

    z = ndtri(jnp.where(use_low_side, low_side, high_side))
    return jnp.where(flip_sign, -z, z)