Coverage for src / bartz / _profiler.py: 97%

63 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-18 15:24 +0000

1# bartz/src/bartz/_profiler.py 

2# 

3# Copyright (c) 2025, The Bartz Contributors 

4# 

5# This file is part of bartz. 

6# 

7# Permission is hereby granted, free of charge, to any person obtaining a copy 

8# of this software and associated documentation files (the "Software"), to deal 

9# in the Software without restriction, including without limitation the rights 

10# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

11# copies of the Software, and to permit persons to whom the Software is 

12# furnished to do so, subject to the following conditions: 

13# 

14# The above copyright notice and this permission notice shall be included in all 

15# copies or substantial portions of the Software. 

16# 

17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

20# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

22# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

23# SOFTWARE. 

24 

25"""Module with utilities related to profiling bartz.""" 

26 

27from collections.abc import Callable, Iterator 

28from contextlib import contextmanager 

29from functools import wraps 

30from typing import Any, TypeVar 

31 

32from jax import block_until_ready, debug, jit 

33from jax.lax import cond, scan 

34from jax.profiler import TraceAnnotation 

35from jaxtyping import Array, Bool 

36 

# Module-wide profiling switch. It is read by `get_profile_mode` and rebound by
# `set_profile_mode`; it starts out disabled.
PROFILE_MODE: bool = False

# Type variables used in the signatures of the wrappers in this module.
T = TypeVar('T')  # result type of a wrapped / branched function
Carry = TypeVar('Carry')  # loop-carried state of `scan_if_not_profiling`

41 

42 

def get_profile_mode() -> bool:
    """Report whether profiling mode is currently active.

    Returns
    -------
    The current value of the module-level `PROFILE_MODE` flag: True when
    profiling mode is on, False when it is off.
    """
    current = PROFILE_MODE
    return current

51 

52 

def set_profile_mode(value: bool, /) -> None:
    """Turn profiling mode on or off.

    Parameters
    ----------
    value
        New state of the flag: True switches profiling mode on, False switches
        it off. The value is stored exactly as given, without conversion.
    """
    # Rebinding a module-level name from inside a function requires `global`.
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value

63 

64 

@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Temporarily switch profiling mode for the duration of a `with` block.

    The state active on entry is saved and put back on exit, also when the
    body of the `with` statement raises.

    Parameters
    ----------
    value
        State of profiling mode while the `with` body runs.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function, but
    instead compiled in smaller pieces that are instrumented to show up in the
    jax tracer and Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    Jax tracing is not enabled by this context manager and if used must be
    handled separately by the user; this context manager only makes sure that
    the execution flow will be more interpretable in the traces if the tracer is
    used.
    """
    saved = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # Put back whatever was active before entering, even on error.
        set_profile_mode(saved)

98 

99 

def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing. The profile mode is checked every time the returned function is
    called, not when it is created.

    Parameters
    ----------
    func
        Function to wrap. Callables without a ``__name__`` attribute (for
        example `functools.partial` objects) are accepted too.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created. The statistics on
    the actual Python function will be off, while the function
    `jab_inner_wrapper` represents the actual execution time. If `func` has no
    ``__name__``, its ``repr`` is used as ``<func_name>``.
    """
    jitted_func = jit(func, **kwargs)

    # `functools.partial` objects and other callable instances have no
    # `__name__`; fall back to their repr instead of raising AttributeError.
    func_name = getattr(func, '__name__', None)
    if func_name is None:
        func_name = repr(func)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*call_args, **call_kwargs)
            # wait for the result so the trace event covers the execution
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        if not get_profile_mode():
            return func(*call_args, **call_kwargs)
        if block_before:
            # finish pending work on the inputs outside of the trace event
            call_args, call_kwargs = block_until_ready((call_args, call_kwargs))
        return jab_inner_wrapper(*call_args, **call_kwargs)

    return jab_outer_wrapper

152 

153 

def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when not profiling.

    When profile mode is off, the function is JIT compiled. When profile mode is
    on, the function runs as-is. The profile mode is checked every time the
    returned function is called, not when it is created.

    Parameters
    ----------
    func
        Function to wrap.
    *args
        Additional positional arguments to pass to `jax.jit`.
    **kwargs
        Additional keyword arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    # Built once here and reused by every call made outside profiling mode.
    jitted_func = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        if get_profile_mode():
            # profiling: call the original, un-jitted function directly
            return func(*call_args, **call_kwargs)
        return jitted_func(*call_args, **call_kwargs)

    return wrapper

182 

183 

def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Use `jax.lax.scan`, or a plain Python loop when profiling (restricted).

    Parameters
    ----------
    f
        Loop body, called as ``f(carry, None)`` and returning ``(carry, None)``.
    init
        Carry passed to the first iteration.
    xs
        Must be None; scanning over input values is not supported.
    length
        Number of iterations to run.

    Returns
    -------
    Pair ``(final_carry, None)``; stacked per-iteration outputs are not
    supported.
    """
    assert xs is None
    if not get_profile_mode():
        return scan(f, init, None, length)

    # Profiling: run each step as a separate Python call instead of inside scan.
    state = init
    for _step in range(length):
        state, _unused = f(state, None)
    return state, None

217 

218 

def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Use `jax.lax.cond`, or a plain Python branch when profiling (restricted).

    Parameters
    ----------
    pred
        Selects the branch. In profiling mode it is evaluated with Python
        truthiness, so it must support conversion to a Python bool.
    true_fun
        Called with `operands` when `pred` is true.
    false_fun
        Called with `operands` when `pred` is false.
    *operands
        Positional arguments handed to whichever branch runs.

    Returns
    -------
    Whatever the selected branch returns.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)
    chosen = true_fun if pred else false_fun
    return chosen(*operands)

250 

251 

def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
) -> None:
    """Use `jax.debug.callback`, or call the callback directly when profiling.

    Parameters
    ----------
    callback
        Function to invoke. Its return value is discarded in both modes.
    *args
        Positional arguments forwarded to `callback`.
    ordered
        Forwarded to `jax.debug.callback` when not profiling. Ignored in
        profiling mode, where `callback` is simply called in place.
    **kwargs
        Keyword arguments forwarded to `callback`.
    """
    if get_profile_mode():
        callback(*args, **kwargs)
    else:
        debug.callback(callback, *args, ordered=ordered, **kwargs)