Coverage for src/lsqfitgp/_gvarext/_tabulate.py: 68%
48 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_gvarext/_tabulate.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import textwrap 1fabcde
22import gvar 1fabcde
23import numpy 1fabcde
def tabulate_together(*gs, headers=True, offset='', ndecimal=None, keys=None):
    """

    Format a table comparing side by side various collections of gvars.

    Parameters
    ----------
    *gs : sequence of arrays or dictionaries of gvars
        The variables to be tabulated. The structures of arrays and dictionaries
        must match.
    headers : bool or sequence of strings
        If True (default), automatically add a header. If False, don't add a
        header. If a sequence with length len(gs) + 1, it contains the column
        names for the keys/indices and for each set of variables.
    offset : str
        Prefix to each line, default empty.
    ndecimal : int, optional
        Number of decimal places. If not specified (default), keep two error
        digits.
    keys : iterable, optional
        If ``gs`` are dictionaries, a subset of keys to be extracted from each
        dictionary. Ignored if they are arrays.

    Returns
    -------
    table : str
        The formatted table, or an empty string if no collections are given.

    Examples
    --------
    >>> print(tabulate_together(gvar.gvar(dict(a=1)), gvar.gvar(dict(a=2))))
    key/index   value1   value2
    ---------------------------
            a    1 (0)    2 (0)

    See also
    --------
    gvar.tabulate

    """
    if not gs:
        return ''

    # anything without a `keys` attribute is treated as an array
    gs = [g if hasattr(g, 'keys') else numpy.asarray(g) for g in gs]
    are_dicts = [hasattr(g, 'keys') for g in gs]
    assert all(are_dicts) or not any(are_dicts), \
        'cannot mix dictionaries and arrays'

    if keys is not None and are_dicts[0]:
        # materialize once: `keys` is traversed for every collection, so a
        # one-shot iterable would be exhausted after the first one
        keys = tuple(keys)
        gs = [{k: g[k] for k in keys} for g in gs]

    g0 = gs[0]
    if hasattr(g0, 'keys'):
        assert all(set(g.keys()) == set(g0.keys()) for g in gs[1:]), \
            'dictionaries must have the same keys'
        # impose the key order of the first dictionary on all the others
        gs = [{k: g[k] for k in g0} for g in gs]
    else:
        assert all(g.shape == g0.shape for g in gs[1:]), \
            'arrays must have the same shape'
        if g0.shape == ():
            # wrap scalars into single-entry dictionaries with a dummy key
            gs = [{'--': g} for g in gs]

    # the '@' header marks where the key/index column ends, see _splittable
    tables = [
        _splittable(gvar.tabulate(g, headers=['@', ''], ndecimal=ndecimal))
        for g in gs
    ]
    # key column and value column of the first table, then only the value
    # column of each other table
    columns = list(tables[0]) + [t[1] for t in tables[1:]]

    if not hasattr(headers, '__len__'):
        if headers:
            headers = ['key/index'] + [f'value{i+1}' for i in range(len(gs))]
        else:
            headers = None
    else:
        assert len(headers) == len(columns), \
            'headers must have one entry per column'

    if headers is not None:
        columns = (_head(col, head) for col, head in zip(columns, headers))
    return textwrap.indent(_join(columns), offset)
90def _splittable(table): 1fabcde
91 lines = table.split('\n') 1abcde
92 header = lines[0] 1abcde
93 col = header.find('@') + 1 1abcde
94 contentlines = lines[2:] 1abcde
95 col1 = '\n'.join(line[:col] for line in contentlines) 1abcde
96 col2 = '\n'.join(line[col:] for line in contentlines) 1abcde
97 return col1, col2 1abcde
99def _head(col, head): 1fabcde
100 head = str(head) 1abcde
101 width = col.find('\n') 1abcde
102 if width < 0: 102 ↛ 103line 102 didn't jump to line 103 because the condition on line 102 was never true1abcde
103 width = len(col)
104 hwidth = len(head) 1abcde
105 if hwidth > width: 105 ↛ 106line 105 didn't jump to line 106 because the condition on line 105 was never true1abcde
106 col = textwrap.indent(col, (hwidth - width) * ' ')
107 else:
108 head = (width - hwidth) * ' ' + head 1abcde
109 return head + '\n' + len(head) * '-' + '\n' + col 1abcde
111def _join(cols): 1fabcde
112 split = (col.split('\n') for col in cols) 1abcde
113 return '\n'.join(''.join(lines) for lines in zip(*split)) 1abcde