Coverage for src/lsqfitgp/_linalg/_stdcplx.py: 44%
25 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
# lsqfitgp/_linalg/_stdcplx.py
#
# Copyright (c) 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
"""
Module to estimate the time taken by standard linear algebra operations.
"""
# TODO maybe I can replace this with jax compilation introspection, which
# provides flops estimates. See https://jax.readthedocs.io/en/latest/aot.html
# Standard library
import inspect
import timeit

# Third party
from jax import numpy as jnp
from jax import random
from jax.scipy import linalg as jlinalg
from scipy import sparse
def benchmark(func, *args, **kwargs):
    """
    Measure the execution time of ``func(*args, **kwargs)``.

    The loop count is chosen automatically, the timing is repeated 5 times,
    and the best (minimum) per-call time is returned, in seconds.
    """
    env = dict(func=func, args=args, kwargs=kwargs)
    timer = timeit.Timer('func(*args, **kwargs)', globals=env)
    loops, _ = timer.autorange()
    runs = timer.repeat(5, loops)
    return min(runs) / loops
# Catalog of benchmarkable linear algebra operations. Each identifier maps to
# a pair [job, est]:
#   job  -- callable running the operation; its number of parameters decides
#           how many (square, positive definite) matrices are generated for
#           calibration, so the lambdas must not be replaced by the bare
#           library functions (which have extra optional parameters)
#   est  -- asymptotic operation count as a function of the argument shapes
ops = {
    'chol': [
        lambda a: jnp.linalg.cholesky(a),
        lambda sa: sa[0] ** 3,
    ],
    'eigh': [
        lambda a: jnp.linalg.eigh(a),
        lambda sa: sa[0] ** 3,
    ],
    'qr-red': [
        lambda a: jnp.linalg.qr(a, mode='reduced'),
        lambda sa: min(sa) ** 2 * max(sa),
    ],
    'qr-full': [
        lambda a: jnp.linalg.qr(a, mode='full'),
        lambda sa: max(sa) ** 3,
    ],
    'svd-red': [
        lambda a: jnp.linalg.svd(a, full_matrices=False),
        lambda sa: min(sa) ** 2 * max(sa),
    ],
    'svd-full': [
        lambda a: jnp.linalg.svd(a, full_matrices=True),
        lambda sa: max(sa) ** 3,
    ],
    'solve_triangular': [
        lambda a, b: jlinalg.solve_triangular(a, b),
        lambda sa, sb: sa[0] ** 2 * sb[1],
    ],
    'matmul': [
        lambda a, b: jnp.matmul(a, b),
        lambda sa, sb: sa[0] * sa[1] * sb[1],
    ],
}
def gen_ops_factors(n): # pragma: no cover
    """
    Calibrate the time-per-operation factor of each entry of `ops`.

    For each operation, generate the required number of n x n positive
    definite matrices from a fixed seed, time the operation with `benchmark`,
    and divide the measured time by the complexity estimate. Progress is
    printed to stdout. Returns a dict operation identifier -> factor.
    """
    key = random.PRNGKey(202208101236)
    out = {}
    for name, (job, est) in ops.items():
        print(f'{name}({n})... ', end='', flush=True)
        nargs = len(inspect.signature(job).parameters)
        key, subkey = random.split(key)
        w = random.normal(subkey, (nargs, n, n), jnp.float32)
        mats = w @ jnp.swapaxes(w, -2, -1)
        elapsed = benchmark(job, *mats)
        print(f'{elapsed:.2g} s')
        out[name] = elapsed / est(*(m.shape for m in mats))
    return out
# Precomputed time-per-operation factors (seconds per scalar op) for each
# entry of `ops`, obtained as gen_ops_factors(1000) on the calibration
# machine (a particular laptop cpu, single precision).
ops_factors = {
    'chol': 6.03470915928483e-12,
    'eigh': 1.824986875290051e-10,
    'qr-red': 1.1241237493231893e-10,
    'qr-full': 1.2058762495871633e-10,
    'svd-red': 4.468000000342727e-10,
    'svd-full': 4.2561762500554324e-10,
    'solve_triangular': 4.1634716559201486e-12,
    'matmul': 5.6301691802218555e-12,
}

# Precomputed fixed overheads (seconds) for each entry of `ops`, obtained as
# gen_ops_factors(1) on the same machine: with 1 x 1 matrices the measured
# time is essentially all call overhead.
ops_consts = {
    'chol': 1.810961455339566e-06,
    'eigh': 2.390482500195503e-06,
    'qr-red': 2.6676162518560884e-06,
    'qr-full': 2.6932845800183714e-06,
    'svd-red': 3.7152979196980598e-06,
    'svd-full': 3.663789590355009e-06,
    'solve_triangular': 2.170706249307841e-06,
    'matmul': 1.718031665077433e-06,
}
def predtime(op, shapes, types):
    """
    Estimate the time taken by a linear algebra operation.

    Parameters
    ----------
    op : str
        The identifier of the operation, see `listops`.
    shapes : sequence of tuples of integers
        The shapes of the arguments.
    types : sequence of numpy data types
        The types of the arguments. They are promoted according to JAX rules.

    Returns
    -------
    time : float
        The estimated time. Not accurate. The unit of measure is seconds on a
        particular laptop cpu used to calibrate the estimate with 1000x1000
        matrices.
    """
    est = ops[op][1]
    scale = ops_factors[op]
    base = ops_consts[op]
    # Floating point type the operation would actually run in: promote the
    # argument types together, then to inexact (sin maps integers to floats)
    dtype = jnp.sin(jnp.empty(0, jnp.result_type(*types))).dtype
    if dtype == jnp.float64:
        scale *= 2  # double precision costs roughly twice single precision
    return base + scale * est(*shapes)
def listops():
    """
    List available linear algebra operations.

    Returns
    -------
    ops : dict
        A dictionary operation identifier -> number of arguments.
    """
    arity = {}
    for name, (job, _) in ops.items():
        arity[name] = len(inspect.signature(job).parameters)
    return arity
# TODO I should estimate the cost under jit, and subtract the overhead
# estimated with a jitted no-op.