Coverage for src / bartz / _profiler.py: 97%
63 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-18 15:24 +0000
1# bartz/src/bartz/_profiler.py
2#
3# Copyright (c) 2025, The Bartz Contributors
4#
5# This file is part of bartz.
6#
7# Permission is hereby granted, free of charge, to any person obtaining a copy
8# of this software and associated documentation files (the "Software"), to deal
9# in the Software without restriction, including without limitation the rights
10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11# copies of the Software, and to permit persons to whom the Software is
12# furnished to do so, subject to the following conditions:
13#
14# The above copyright notice and this permission notice shall be included in all
15# copies or substantial portions of the Software.
16#
17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23# SOFTWARE.
25"""Module with utilities related to profiling bartz."""
27from collections.abc import Callable, Iterator
28from contextlib import contextmanager
29from functools import wraps
30from typing import Any, TypeVar
32from jax import block_until_ready, debug, jit
33from jax.lax import cond, scan
34from jax.profiler import TraceAnnotation
35from jaxtyping import Array, Bool
37PROFILE_MODE: bool = False
39T = TypeVar('T')
40Carry = TypeVar('Carry')
def get_profile_mode() -> bool:
    """Report whether profile mode is currently switched on.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profile mode is enabled, False otherwise.
    """
    return PROFILE_MODE
def set_profile_mode(value: bool, /) -> None:
    """Switch profile mode on or off.

    Parameters
    ----------
    value
        True to enable profile mode, False to disable it.
    """
    # The flag lives at module scope, so it has to be rebound as a global.
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value
@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Temporarily override the profile mode inside a `with` block.

    Parameters
    ----------
    value
        Profile mode value in effect while the context is active.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function, but
    instead compiled in smaller pieces that are instrumented to show up in the
    jax tracer and Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    This context manager does not start jax tracing; if tracing is wanted, the
    user must set it up separately. The only effect here is to make the
    execution flow easier to interpret in the traces when a tracer is used.
    """
    previous = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # Put the earlier setting back even if the body raised.
        set_profile_mode(previous)
def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing.

    Parameters
    ----------
    func
        Function to wrap. Callables without a `__name__` attribute (for
        example `functools.partial` objects) are accepted; their type name is
        used to label the trace event.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created. The statistics on
    the actual Python function will be off, while the function
    `jab_inner_wrapper` represents the actual execution time.
    """
    jitted_func = jit(func, **kwargs)

    # Not every callable has a `__name__` (e.g. `functools.partial` objects);
    # fall back to the type name so that building the label never raises.
    func_name = getattr(func, '__name__', type(func).__name__)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function;
    # its parameters are named differently from the enclosing `**kwargs`
    # (the jit options) to avoid shadowing them
    def jab_inner_wrapper(*call_args, **call_kwargs) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*call_args, **call_kwargs)
            # block inside the annotation so the event spans the execution
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*args: Any, **kwargs: Any) -> T:
        if get_profile_mode():
            if block_before:
                # wait for pending inputs outside of the trace event
                args, kwargs = block_until_ready((args, kwargs))
            return jab_inner_wrapper(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return jab_outer_wrapper
def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Wrap a function so that it is JIT compiled unless profile mode is on.

    The choice is made at every call: with profile mode off the call goes to
    the `jax.jit`-compiled version, with profile mode on it goes to the
    original function unchanged.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        # Profile mode is read at call time, not at wrap time.
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    return wrapper
def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Run a carry-only loop: `jax.lax.scan` normally, a Python loop when profiling.

    Parameters
    ----------
    f
        Loop body with signature (carry, None) -> (carry, None).
    init
        Carry value fed to the first iteration.
    xs
        Must be None; scanning over input values is not supported.
    length
        Number of iterations to run.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).
    """
    assert xs is None
    if not get_profile_mode():
        return scan(f, init, None, length)

    # Profiling: unroll in Python so that each iteration runs separately.
    state = init
    for _step in range(length):
        state, _ignored = f(state, None)
    return state, None
def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Branch with `jax.lax.cond` normally, with a plain Python `if` when profiling.

    Parameters
    ----------
    pred
        Boolean predicate selecting the branch to run.
    true_fun
        Function called when `pred` is True.
    false_fun
        Function called when `pred` is False.
    *operands
        Arguments passed to whichever of `true_fun` and `false_fun` is called.

    Returns
    -------
    Result of either `true_fun(*operands)` or `false_fun(*operands)`.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)

    # Profiling: evaluate the predicate in Python and call one branch only.
    branch = true_fun if pred else false_fun
    return branch(*operands)
def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
):
    """Invoke a callback through `jax.debug.callback`, or directly when profiling.

    Parameters
    ----------
    callback
        Function to invoke.
    *args
        Positional arguments passed to `callback`.
    ordered
        Forwarded to `jax.debug.callback` when profile mode is off; not used
        when the callback is called directly in profiling mode.
    **kwargs
        Keyword arguments passed to `callback`.
    """
    if not get_profile_mode():
        debug.callback(callback, *args, ordered=ordered, **kwargs)
        return

    # Profiling: plain Python call, no jax callback machinery involved.
    callback(*args, **kwargs)