Coverage for src/lsqfitgp/_linalg/_stdcplx.py: 44%

25 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_linalg/_stdcplx.py 

2# 

3# Copyright (c) 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20""" 

21Module to estimate the time taken by standard linear algebra operations. 

22""" 

23 

24# TODO maybe I can replace this with jax compilation introspection, which 

25# provides flops estimates. See https://jax.readthedocs.io/en/latest/aot.html 

26 

27import timeit 1abcdef

28import inspect 1abcdef

29 

30from jax import random 1abcdef

31from jax import numpy as jnp 1abcdef

32from jax.scipy import linalg as jlinalg 1abcdef

33from scipy import sparse 1abcdef

34 

def benchmark(func, *args, **kwargs):
    """
    Measure the execution time of a single function call.

    Parameters
    ----------
    func : callable
        The function to time.
    *args, **kwargs
        Arguments passed to `func` at each call.

    Returns
    -------
    time : float
        The time per call in seconds, taken as the fastest out of 5
        repetitions of a batch of calls whose size is picked by
        `timeit.Timer.autorange`.
    """
    # Local import: the module level only imports jax submodules.
    import jax

    def call():
        # JAX dispatches computations asynchronously, so wait for the result;
        # otherwise only the dispatch overhead could end up being timed.
        jax.block_until_ready(func(*args, **kwargs))

    timer = timeit.Timer(call)
    n, _ = timer.autorange()
    times = timer.repeat(5, n)
    # The minimum is the measurement least disturbed by other processes.
    return min(times) / n

41 

# Table of the supported operations. Each entry maps the operation identifier
# to a pair [job, est]:
#
#   job : function performing the operation on its array arguments
#   est : leading-order complexity in terms of the arguments' shapes
#
# The number of parameters of `job` is read with `inspect.signature` elsewhere
# in this module to determine the number of arguments of the operation, so the
# jobs must remain thin wrappers with one parameter per array argument.
ops = {
    'chol': [
        lambda x: jnp.linalg.cholesky(x),
        lambda s: s[0] ** 3,
    ],
    'eigh': [
        lambda x: jnp.linalg.eigh(x),
        lambda s: s[0] ** 3,
    ],
    'qr-red': [
        lambda x: jnp.linalg.qr(x, mode='reduced'),
        lambda s: min(s) ** 2 * max(s),
    ],
    'qr-full': [
        # 'complete' yields the square Q factor. In the numpy convention
        # 'full' is a deprecated alias of 'reduced', which would not match
        # the cubic estimate below for non-square inputs.
        lambda x: jnp.linalg.qr(x, mode='complete'),
        lambda s: max(s) ** 3,
    ],
    'svd-red': [
        lambda x: jnp.linalg.svd(x, full_matrices=False),
        lambda s: min(s) ** 2 * max(s),
    ],
    'svd-full': [
        lambda x: jnp.linalg.svd(x, full_matrices=True),
        lambda s: max(s) ** 3,
    ],
    'solve_triangular': [
        lambda x, y: jlinalg.solve_triangular(x, y),
        lambda s, t: s[0] ** 2 * t[1],
    ],
    'matmul': [
        lambda x, y: jnp.matmul(x, y),
        lambda s, t: s[0] * s[1] * t[1],
    ],
}

76 

def gen_ops_factors(n): # pragma: no cover
    """
    Benchmark every operation in `ops` to calibrate the time estimates.

    Each operation is run on random float32 matrices of shape (n, n), built
    as ``m @ m.T`` from standard normal matrices ``m``. Progress is printed
    while the benchmarks run.

    Parameters
    ----------
    n : int
        The side of the square matrices used as arguments.

    Returns
    -------
    factors : dict
        A dictionary operation identifier -> measured time divided by the
        complexity estimate evaluated on the shapes of the arguments.
    """
    rng = random.PRNGKey(202208101236)
    result = {}
    for name in ops:
        func, complexity = ops[name]
        print(f'{name}({n})... ', end='', flush=True)

        # one square matrix per parameter of the job
        nargs = len(inspect.signature(func).parameters)
        rng, sub = random.split(rng)
        noise = random.normal(sub, (nargs, n, n), jnp.float32)
        matrices = noise @ jnp.swapaxes(noise, -2, -1)

        elapsed = benchmark(func, *matrices)
        print(f'{elapsed:.2g} s')

        shapes = (matrix.shape for matrix in matrices)
        result[name] = elapsed / complexity(*shapes)
    return result

90 

# Time per unit of estimated complexity of each operation, in seconds.
# Hardcoded output of gen_ops_factors(1000).
ops_factors = {
    'chol': 6.03470915928483e-12,
    'eigh': 1.824986875290051e-10,
    'qr-red': 1.1241237493231893e-10,
    'qr-full': 1.2058762495871633e-10,
    'svd-red': 4.468000000342727e-10,
    'svd-full': 4.2561762500554324e-10,
    'solve_triangular': 4.1634716559201486e-12,
    'matmul': 5.6301691802218555e-12,
}

99 

# Constant time of each operation, in seconds. Hardcoded output of
# gen_ops_factors(1): on 1x1 matrices every complexity estimate in `ops`
# evaluates to 1, so these values are the measured times per call.
ops_consts = {
    'chol': 1.810961455339566e-06,
    'eigh': 2.390482500195503e-06,
    'qr-red': 2.6676162518560884e-06,
    'qr-full': 2.6932845800183714e-06,
    'svd-red': 3.7152979196980598e-06,
    'svd-full': 3.663789590355009e-06,
    'solve_triangular': 2.170706249307841e-06,
    'matmul': 1.718031665077433e-06,
}

108 

def predtime(op, shapes, types):
    """
    Predict how long a linear algebra operation takes.

    Parameters
    ----------
    op : str
        The identifier of the operation, see `listops`.
    shapes : sequence of tuples of integers
        The shapes of the arguments.
    types : sequence of numpy data types
        The types of the arguments. They are promoted according to JAX rules.

    Returns
    -------
    time : float
        The estimated time. Not accurate. The unit of measure is seconds on a
        particular laptop cpu used to calibrate the estimate with 1000x1000
        matrices.
    """
    complexity = ops[op][1]
    factor = ops_factors[op]
    const = ops_consts[op]

    # Find the type the computation would run in: promote the argument types
    # together, then pass a dummy empty array through a floating point
    # function and look at the type of the result.
    promoted = jnp.result_type(*types)
    computed = jnp.sin(jnp.empty(0, promoted)).dtype

    # The calibration used float32 matrices; double precision is assumed to
    # cost twice as much.
    if computed == jnp.float64:
        factor = factor * 2

    return const + factor * complexity(*shapes)

136 

def listops():
    """
    Enumerate the linear algebra operations that can be estimated.

    Returns
    -------
    ops : dict
        A dictionary operation identifier -> number of arguments.
    """
    nargs = {}
    for name, entry in ops.items():
        func = entry[0]
        # the operation takes as many arguments as its job has parameters
        nargs[name] = len(inspect.signature(func).parameters)
    return nargs

150 

# TODO I should estimate the cost under jit, and subtract the overhead
# estimated with a jitted no-op.