Coverage for src/lsqfitgp/_array.py: 95%

398 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_array.py 

2# 

3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import textwrap 1feabcd

21import math 1feabcd

22 

23import numpy 1feabcd

24from numpy.lib import recfunctions 1feabcd

25import jax 1feabcd

26from jax import numpy as jnp 1feabcd

27from jax import tree_util 1feabcd

28 

29# TODO use register_pytree_with_keys 

@tree_util.register_pytree_node_class
class StructuredArray:
    """
    JAX-friendly imitation of a numpy structured array.

    It behaves like a read-only numpy structured array, and you can create
    a copy with a modified field with a jax-like syntax.

    Examples
    --------
    >>> a = numpy.empty(3, dtype=[('f', float), ('g', float)])
    >>> a = StructuredArray(a)
    >>> a = a.at['f'].set(numpy.arange(3.))
    ... # is equivalent to a['f'] = numpy.arange(3.)

    Parameters
    ----------
    array : numpy array, StructuredArray
        A structured array. An array qualifies as structured if
        ``array.dtype.names is not None``.

    Notes
    -----
    The StructuredArray is a readonly view on the input array. When you
    change the content of a field of the StructuredArray, however, the
    reference to the original array for that field is lost.

    The value passed to ``.at[field].set(...)`` must have exactly the dtype
    and shape of the field it replaces; no casting or broadcasting is done.

    """

    @classmethod
    def _readonlyview_wrapifstructured(cls, x):
        """
        Prepare an array for storage as a field: structured arrays are wrapped
        into StructuredArray, numpy arrays are replaced by a read-only view.
        """
        if x.dtype.names is not None:
            x = cls(x)
        if isinstance(x, numpy.ndarray):
            x = x.view()
            x.flags.writeable = False
        # jax arrays and StructuredArrays are already readonly
        return x

    @classmethod
    def _array(cls, s, t, d, *, check=True):
        """
        Create a new StructuredArray.

        All methods and functions that create a new StructuredArray object
        should use this method.

        Parameters
        ----------
        s : tuple or None
            The shape. If None, it is determined automatically from the arrays.
        t : dtype or None
            The dtype of the array. If None, it is determined automatically
            (before the shape).
        d : dict str -> array
            The _dict of the array. The arrays, if structured, must be already
            StructuredArrays. The order of the keys must match the order of the
            fields.
        check : bool
            If True (default), check the passed values are consistent.

        Return
        ------
        out : StructuredArray
            A new StructuredArray object.
        """

        if t is None:
            # infer the data type from the arrays in the dictionary: the
            # array with the fewest axes fixes how many leading axes belong to
            # the StructuredArray, the remaining ones go into the field dtype
            ndim = min((x.ndim for x in d.values()), default=None)
            t = numpy.dtype([
                (name, x.dtype, x.shape[ndim:])
                for name, x in d.items()
            ])
            # TODO infer the least common head shape instead of counting dims

        # remove offset info since this is actually a columnar format
        t = recfunctions.repack_fields(t, align=False, recurse=True)

        if s is None:
            # infer the shape from the arrays in the dictionary
            assert d, 'can not infer array shape with no fields'
            f = t.names[0]
            a = d[f]
            s = a.shape[:a.ndim - t[0].ndim]

        if check:
            assert len(t) == len(d)
            assert t.names == tuple(d.keys())
            assert all(
                x.dtype == t[f].base and x.ndim >= t[f].ndim
                for f, x in d.items()
            )
            shapes = [
                x.shape[:x.ndim - t[f].ndim]
                for f, x in d.items()
            ]
            assert all(s == s1 for s1 in shapes)

        # bypass StructuredArray.__new__, which expects an array to convert
        out = super().__new__(cls)
        out.shape = s
        out.dtype = t
        out._dict = d

        return out

    def __new__(cls, array):
        # converting a StructuredArray is a no-op
        if isinstance(array, cls):
            return array
        d = {
            name: cls._readonlyview_wrapifstructured(array[name])
            for name in array.dtype.names
        }
        return cls._array(array.shape, array.dtype, d)

    @classmethod
    def from_dataframe(cls, df):
        """
        Make a StructuredArray from a DataFrame. Data is not copied if not
        necessary.
        """
        d = {
            col: cls._readonlyview_wrapifstructured(df[col].to_numpy())
            for col in df.columns
        }
        return cls._array(None, None, d)
        # TODO support polars structured dtypes
        # TODO polars has a parameter Series.to_numpy(zero_copy_only: bool),
        # default False. Maybe make it accessible through kw or options.

    @classmethod
    def from_dict(cls, mapping):
        """
        Make a StructuredArray from a dictionary of arrays. Data is not copied.
        """
        d = {
            name: cls._readonlyview_wrapifstructured(value)
            for name, value in mapping.items()
        }
        return cls._array(None, None, d)

    @property
    def size(self):
        """ Number of elements, i.e., the product of the shape. """
        return math.prod(self.shape)

    @property
    def ndim(self):
        """ Number of axes. """
        return len(self.shape)

    @property
    def nbytes(self):
        """ Sum of the `nbytes` of the fields. """
        return sum(x.nbytes for x in self._dict.values())

    @property
    def T(self):
        """
        The array with the last two axes swapped. Arrays with less than two
        axes are returned as they are.
        """
        if self.ndim < 2:
            return self
        return self.swapaxes(self.ndim - 2, self.ndim - 1)

        # TODO this is mT, not T! Make a unit test and correct it.

    def swapaxes(self, i, j):
        """
        Interchange two axes of the array. See numpy.ndarray.swapaxes.

        Negative indices count from the last axis of the StructuredArray,
        not from the last axis of the arrays stored in the fields.
        """
        # validates the indices and computes the output shape
        shape = jax.eval_shape(lambda: jnp.empty(self.shape).swapaxes(i, j)).shape
        # Fields with a subarray dtype have additional trailing axes, so a
        # negative index would pick the wrong axis if passed through as is.
        # Here ndim > 0, otherwise the line above would have raised.
        i %= self.ndim
        j %= self.ndim
        d = {k: v.swapaxes(i, j) for k, v in self._dict.items()}
        return self._array(shape, self.dtype, d)

    def __len__(self):
        if self.shape:
            return self.shape[0]
        else:
            raise TypeError('len() of unsized object')

    def __getitem__(self, key):
        """
        Index the array.

        A string returns the corresponding field, a non-empty list of strings
        returns a StructuredArray restricted to those fields, anything else is
        applied as an index to the leading axes of all the fields.
        """
        if isinstance(key, str):
            return self._dict[key]
        elif isinstance(key, list) and key and all(isinstance(k, str) for k in key):
            d = {
                name: self._dict[name]
                for name in key
            }
            return self._array(self.shape, self.dtype[key], d)
        else:
            # the trailing full slices keep the axes of subarray fields in
            # place when the key contains an ellipsis
            d = {
                name: x[
                    (key if isinstance(key, tuple) else (key,))
                    + (slice(None),) * self.dtype[name].ndim
                ]
                for name, x in self._dict.items()
            }
            shape = jax.eval_shape(lambda: jnp.empty(self.shape)[key]).shape
            return self._array(shape, self.dtype, d)

    @property
    def at(self):
        """ Helper for field replacement: ``array.at[field].set(value)``. """
        return self._Getter(self)

    class _Getter:
        """ Object returned by `StructuredArray.at`, indexed by field name. """

        def __init__(self, array):
            self.array = array

        def __getitem__(self, key):
            if key not in self.array.dtype.names:
                raise KeyError(key)
            return self.Setter(self.array, key)

        class Setter:
            """
            Refers to the field `key` of `array`. Indexing it again with a
            subfield name produces a Setter for the nested field, which keeps
            a reference to the outer one in `parent`.
            """

            def __init__(self, array, key, parent=None):
                self.array = array
                self.key = key
                self.parent = parent

            def __getitem__(self, subkey):
                if subkey not in self.array.dtype[self.key].names:
                    raise KeyError(subkey)
                return self.__class__(self.array[self.key], subkey, self)

            def set(self, val):
                """
                Return a copy of the outermost array with the field replaced
                by `val`, which must match the dtype and shape of the field.
                """
                assert isinstance(val, (numpy.ndarray, jnp.ndarray, StructuredArray))
                prev = self.array._dict[self.key]
                # TODO support casting and broadcasting
                assert prev.dtype == val.dtype
                assert prev.shape == val.shape
                d = dict(self.array._dict)
                d[self.key] = self.array._readonlyview_wrapifstructured(val)
                out = self.array._array(self.array.shape, self.array.dtype, d)
                if self.parent is not None:
                    # nested field: propagate the modified array upwards
                    return self.parent.set(out)
                else:
                    return out

    def reshape(self, *shape):
        """
        Reshape the array without changing its contents. See
        numpy.ndarray.reshape.
        """
        if len(shape) == 1 and hasattr(shape[0], '__len__'):
            shape = shape[0]
        shape = tuple(shape)
        d = {
            name: x.reshape(shape + self.dtype[name].shape)
            for name, x in self._dict.items()
        }
        # use a zero-itemsize numpy array to resolve -1 and validate the shape
        shape = numpy.empty(self.shape, []).reshape(shape).shape
        return self._array(shape, self.dtype, d)

    def squeeze(self, axis=None):
        """
        Remove axes of length 1. See numpy.ndarray.squeeze.
        """
        if axis is None:
            axis = tuple(i for i, size in enumerate(self.shape) if size == 1)
        if not hasattr(axis, '__len__'):
            axis = (axis,)
        assert all(self.shape[i] == 1 for i in axis)
        # convert negative indices to positive ones, otherwise they would not
        # match the positions produced by enumerate below and the
        # corresponding axes would be silently kept
        removed = {i % self.ndim for i in axis}
        newshape = [size for i, size in enumerate(self.shape) if i not in removed]
        return self.reshape(newshape)

    def astype(self, dtype):
        """ Only the trivial conversion to the same dtype is implemented. """
        if dtype != self.dtype:
            raise NotImplementedError
        return self

    def broadcast_to(self, shape, **kw):
        """
        Return a view of the array broadcasted to another shape. See
        numpy.broadcast_to.
        """
        # raises if not broadcastable, and converts the shape to a tuple such
        # that integers and lists are accepted like in numpy
        shape = numpy.broadcast_to(numpy.empty(self.shape, []), shape, **kw).shape
        d = {
            name: broadcast_to(x, shape + self.dtype[name].shape, **kw)
            for name, x in self._dict.items()
        }
        return self._array(shape, self.dtype, d)

    # TODO implement flatten_with_keys
    def tree_flatten(self):
        """ JAX PyTree encoder. See `jax.tree_util.tree_flatten`. """
        children = tuple(self._dict[key] for key in self.dtype.names)
        aux = dict(shape=self.shape, dtype=self.dtype)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """ JAX PyTree decoder. See `jax.tree_util.tree_unflatten`. """

        # if there are no fields, keep original shape
        if not children:
            return cls._array(aux['shape'], aux['dtype'], {})

        # convert children to arrays because tree_util.tree_flatten unpacks 0d
        # arrays
        children = list(map(asarray, children))

        # if possible, keep original dtype shapes
        oldtype = aux['dtype']
        compatible_tail_shapes = all(
            x.shape[max(0, x.ndim - oldtype[i].ndim):] == oldtype[i].shape
            for i, x in enumerate(children)
        )
        head_shapes = [
            x.shape[:max(0, x.ndim - oldtype[i].ndim)]
            for i, x in enumerate(children)
        ]
        compatible_head_shapes = all(head_shapes[0] == s for s in head_shapes)
        if compatible_tail_shapes and compatible_head_shapes:
            dtype = numpy.dtype([
                (oldtype.names[i], x.dtype, oldtype[i].shape)
                for i, x in enumerate(children)
            ])
        else:
            # let _array infer the dtype from the children
            dtype = None

        d = dict(zip(oldtype.names, children))

        return cls._array(None, dtype, d)

        # TODO this breaks jax.jit(...).lower(...).compile()(...) because
        # apparently `lower` saves the pytree def after a step of dummyfication,
        # so the shape and dtype bases of the StructuredArray are () and object.
        # JAX expects pytrees to have a structure which does not depend on what
        # they store. => Quick hack: preserve the shape and dtype
        # unconditionally, i.e., tree_unflatten can produce malformed
        # StructuredArrays. The dictionary will contain whatever JAX puts into
        # it. => Quicker hack: it seems to me that jax always uses None as
        # dummy, so I could detect if all childrens are None or StructuredArray.

    def __repr__(self):
        # code from gvar https://github.com/gplepage/gvar
        # bufferdict.pyx:BufferDict:__str__
        out = 'StructuredArray({'

        listrepr = [(repr(k), repr(v)) for k, v in self._dict.items()]
        # one field per line if the repr of any field spans multiple lines
        newlinemode = any('\n' in rv for _, rv in listrepr)

        for rk, rv in listrepr:
            if not newlinemode:
                out += '{}: {}, '.format(rk, rv)
            elif '\n' in rv:
                rv = rv.replace('\n', '\n ')
                out += '\n {}:\n {},'.format(rk, rv)
            else:
                out += '\n {}: {},'.format(rk, rv)

        if out.endswith(', '):
            out = out[:-2]
        elif newlinemode:
            out += '\n'
        out += '})'

        return out

    # TODO try simply using the __repr__ of self._dict

    def __array__(self, dtype=None, copy=None):
        """
        Convert to a numpy structured array. The data is always copied.

        Parameters
        ----------
        dtype : dtype, optional
            If specified, the output is converted to this dtype.
        copy : bool, optional
            Part of the numpy protocol. Since a copy is always necessary,
            ``copy=False`` raises `ValueError`.
        """
        if copy is False:
            raise ValueError(
                'a copy is always needed to convert a StructuredArray '
                'to a numpy array'
            )
        array = numpy.empty(self.shape, self.dtype)
        self._copy_into_array(array)
        if dtype is not None:
            array = array.astype(dtype, copy=False)
        return array

    def _copy_into_array(self, dest):
        """ Write the fields, recursively, into the numpy array `dest`. """
        assert self.dtype == dest.dtype
        assert self.shape == dest.shape
        for name, src in self._dict.items():
            if isinstance(src, StructuredArray):
                src._copy_into_array(dest[name])
            else:
                dest[name][...] = src

    def __array_function__(self, func, types, args, kwargs):
        # dispatch to the implementations registered with _implements
        if func not in self._handled_functions:
            return NotImplemented
        return self._handled_functions[func](*args, **kwargs)

    # maps numpy functions to their StructuredArray implementations
    _handled_functions = {}

    @classmethod
    def _implements(cls, np_function):
        """ Register an __array_function__ implementation """
        def decorator(func):
            cls._handled_functions[np_function] = func
            newdoc = f"""\
Implementation of `{np_function.__module__}.{np_function.__name__}` for `StructuredArray`.

"""
            if func.__doc__:
                newdoc += textwrap.dedent(func.__doc__) + '\n'
            # docstrings can be missing, e.g., when running with python -OO
            if np_function.__doc__:
                newdoc += 'Original docstring below:\n\n'
                newdoc += textwrap.dedent(np_function.__doc__)
            func.__doc__ = newdoc
            return func
        return decorator

426 

@StructuredArray._implements(numpy.broadcast_to)
def broadcast_to(x, shape, **kw):
    """
    Broadcast `x` to `shape`, whatever kind of array it is.

    A `StructuredArray` is broadcasted with its own method, a JAX array with
    `jax.numpy.broadcast_to`, anything else with `numpy.broadcast_to`. The
    additional keyword arguments are forwarded unchanged.
    """
    if isinstance(x, StructuredArray):
        return x.broadcast_to(shape, **kw)
    # pick the array library matching the input
    backend = jnp if isinstance(x, jnp.ndarray) else numpy
    return backend.broadcast_to(x, shape, **kw)

439 

@StructuredArray._implements(numpy.broadcast_arrays)
def broadcast_arrays(*arrays, **kw):
    """
    Broadcast all the arrays against each other. The arrays can be numpy
    arrays, JAX arrays or `StructuredArray`.

    The keyword arguments are forwarded to `broadcast_to`. The output is a
    list with one broadcasted array per input array.
    """
    common = numpy.broadcast_shapes(*(a.shape for a in arrays))
    # The result is deliberately a list, not a tuple, to mimic the return type
    # numpy.broadcast_arrays had when this was written.
    # NOTE(review): numpy >= 2.0 appears to return a tuple instead; confirm
    # before relying on the type.
    return [broadcast_to(a, common, **kw) for a in arrays]

450 

class broadcast:
    """
    Minimal stand-in for `numpy.broadcast` usable with `StructuredArray`.

    Only the `shape` attribute, i.e., the shape obtained broadcasting the
    input arrays together, is provided.
    """

    # not handled by __array_function__

    def __init__(self, *arrays):
        shapes = [a.shape for a in arrays]
        self.shape = numpy.broadcast_shapes(*shapes)

460 

def asarray(x, dtype=None):
    """
    Convert `x` to an array, like `numpy.asarray`, with support for
    `StructuredArray` and JAX arrays.

    Objects that are already arrays (numpy, JAX or `StructuredArray`) are
    returned as they are, or converted with their `astype` method if `dtype`
    is specified. Other objects are converted to a JAX array, or to a numpy
    array if JAX refuses them with `TypeError` or `ValueError`.
    """
    if isinstance(x, (StructuredArray, jnp.ndarray, numpy.ndarray)):
        if dtype is None:
            return x
        return x.astype(dtype)

    if x is None:
        # partial workaround for jax issue #14506, None would be interpreted
        # as nan by jax
        return numpy.asarray(x, dtype)

    try:
        out = jnp.asarray(x, dtype)
    except (TypeError, ValueError):
        out = numpy.asarray(x, dtype)
    return out

476 

def _asarray_jaxifpossible(x):
    """
    Convert `x` to an array with `asarray`, then try to turn numpy arrays
    into JAX arrays. Structured arrays are converted to `StructuredArray` and
    processed leaf by leaf.
    """
    x = asarray(x)
    if x.dtype.names:
        return tree_util.tree_map(_asarray_jaxifpossible, StructuredArray(x))
    if not isinstance(x, numpy.ndarray):
        return x
    try:
        return jnp.asarray(x)
    except (TypeError, ValueError):
        # not representable as JAX array, keep the numpy array
        return x

487 

@StructuredArray._implements(numpy.squeeze)
def _squeeze(a, axis=None):
    """
    Delegate to the `squeeze` method of the array.
    """
    squeezed = a.squeeze(axis)
    return squeezed

491 

@StructuredArray._implements(numpy.ix_)
def _ix(*args):
    """
    Reshape each of the N 1-dimensional input arrays to N dimensions, with
    the k-th array spanning the k-th axis and length 1 along the others.
    """
    seqs = tuple(asarray(a) for a in args)
    assert all(seq.ndim == 1 for seq in seqs)
    count = len(seqs)
    out = []
    for pos, seq in enumerate(seqs):
        newshape = (1,) * pos + (-1,) + (1,) * (count - pos - 1)
        out.append(seq.reshape(newshape))
    return tuple(out)

501 

def unstructured_to_structured(arr,
    dtype=None,
    names=None,
    align=False, # TODO maybe align is totally inapplicable even with numpy arrays? What does it mean?
    copy=False,
    casting='unsafe'):
    """ Like `numpy.lib.recfunctions.unstructured_to_structured`, but outputs a
    `StructuredArray`.

    The last axis of `arr` is split into the fields of the output. The
    arguments are interpreted by the numpy function, which is invoked on an
    empty array to determine the output dtype. """
    arr = asarray(arr)
    if arr.ndim == 0:
        raise ValueError('arr must have at least one dimension')

    # run the numpy implementation on an array with no rows, just to let it
    # validate the arguments and work out the structured dtype
    ncols = arr.shape[-1]
    mockup = numpy.empty((0, ncols), arr.dtype)
    dummy = recfunctions.unstructured_to_structured(
        mockup,
        dtype=dtype,
        names=names,
        align=align,
        copy=copy,
        casting=casting,
    )

    out, length = _unstructured_to_structured_recursive(
        0, (), arr, dummy.dtype, copy, casting)
    # all the columns must have been consumed
    assert length == ncols
    return out

519 

def _unstructured_to_structured_recursive(idx, shape, arr, dtype, copy, casting, *strides):
    # Build a StructuredArray with data type `dtype` out of the last axis of
    # `arr`, recursing into nested structured fields.
    #
    # idx : position along the last axis of `arr` of the first scalar to read.
    # shape : subarray shapes of the enclosing fields, concatenated; it is
    #     appended to arr.shape[:-1] to form the shape of the output.
    # arr : the unstructured array, the same object at all recursion levels.
    # dtype : structured dtype of the array to produce at this level.
    # copy, casting : passed to `astype` for arrays which are not JAX arrays.
    # strides : one (size, stride) pair per enclosing field, where size is
    #     the number of elements of the subarray dtype of the field and
    #     stride the number of scalars in each element (see `_nd`).
    #
    # Returns the StructuredArray and the value of idx after the last scalar
    # of the first element of `dtype`.
    arrays = {}
    for i, name in enumerate(dtype.names):
        base = dtype[i].base
        subshape = shape + dtype[i].shape
        size = math.prod(dtype[i].shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)
        if base.names is not None:
            # nested structured field: recurse, the callee reads the scalars
            # of all the `size` elements of the field
            y, newidx = _unstructured_to_structured_recursive(idx, subshape, arr, base, copy, casting, *substrides)
            shift = newidx - idx
            assert shift == stride
            idx += size * stride
        else:
            assert stride == 1
            if all(size == 1 for size, _ in strides):
                # no enclosing subarray: the scalars of the field are
                # contiguous along the last axis
                indices = numpy.s_[idx:idx + size]
                srcsize = size
            else:
                # the scalars of the field are interleaved with those of the
                # sibling fields: sum one broadcasted arange per nesting level,
                # each scaled by its stride, to enumerate their positions
                indices = sum((
                    stride * numpy.arange(size)[numpy.s_[:,] + (None,) * i]
                    for i, (size, stride) in enumerate(reversed(substrides))
                ), start=idx)
                indices = indices.reshape(-1)
                srcsize = indices.size
            key = numpy.s_[..., indices]
            x = arr[key]
            x = x.reshape(arr.shape[:-1] + subshape)
            if isinstance(x, jnp.ndarray):
                # copy and casting are not passed to the astype of JAX arrays
                y = x.astype(base)
            else:
                y = x.astype(base, copy=copy, casting=casting)
            idx += size
        arrays[name] = y
    return StructuredArray._array(arr.shape[:-1] + shape, dtype, arrays), idx

555 

@StructuredArray._implements(recfunctions.structured_to_unstructured)
def _structured_to_unstructured(arr, dtype=None, casting='unsafe'):
    """
    The output is a JAX array if its data type is supported by JAX, a numpy
    array otherwise. The output dtype and the length of the last axis are
    determined by invoking the numpy function on an empty array.
    """
    # let numpy work out the dtype and number of columns of the output
    mockup = numpy.empty(0, arr.dtype)
    dummy = recfunctions.structured_to_unstructured(
        mockup, dtype=dtype, casting=casting)
    out_shape = arr.shape + dummy.shape[-1:]
    out_dtype = dummy.dtype

    try:
        out = jnp.empty(out_shape, out_dtype)
    except TypeError:
        out = numpy.empty(out_shape, out_dtype)
    # TODO can I make out column-major w.r.t. only the last column?

    out, length = _structured_to_unstructured_recursive(0, arr, out)
    # all the columns must have been filled
    assert length == dummy.shape[-1]
    return out

569 

570def _nd(dtype): 1feabcd

571 """ Count the number of scalars in a dtype """ 

572 base = dtype.base 1feabcd

573 shape = dtype.shape 1feabcd

574 size = math.prod(shape) 1feabcd

575 if base.names is None: 1feabcd

576 return size 1feabcd

577 else: 

578 return size * sum(_nd(base[name]) for name in base.names) 1feabcd

579 

580 # I use this function in many parts of the package so it should not have an 

581 # underscore, even if I don't export it in the main namespace. And move it 

582 # to utils, it's not specific to StructuredArray. 

583 

def _structured_to_unstructured_recursive(idx, arr, out, *strides):
    # Copy the fields of the structured array `arr` into the last axis of the
    # unstructured array `out`, recursing into nested structured fields.
    #
    # idx : position along the last axis of `out` of the first scalar to
    #     write.
    # arr : the structured array to read at this level.
    # out : the output array; it is updated functionally if it has an `at`
    #     attribute (like JAX arrays), in-place otherwise.
    # strides : one (size, stride) pair per enclosing field, where size is
    #     the number of elements of the subarray dtype of the field and
    #     stride the number of scalars in each element (see `_nd`).
    #
    # Returns the (possibly new) output array and the value of idx after the
    # last scalar of the first element of `arr.dtype`.
    dtype = arr.dtype
    for i, name in enumerate(dtype.names):
        subarr = arr[name]
        base = dtype[i].base
        size = math.prod(dtype[i].shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)
        if base.names is not None:
            # nested structured field: recurse, the callee writes the scalars
            # of all the `size` elements of the field
            out, newidx = _structured_to_unstructured_recursive(idx, subarr, out, *substrides)
            shift = newidx - idx
            assert shift == stride
            idx += size * stride
        else:
            assert stride == 1
            if all(size == 1 for size, _ in strides):
                # no enclosing subarray: the scalars of the field are
                # contiguous along the last axis
                indices = numpy.s_[idx:idx + size]
                srcsize = size
            else:
                # the scalars of the field are interleaved with those of the
                # sibling fields: sum one broadcasted arange per nesting level,
                # each scaled by its stride, to enumerate their positions
                indices = sum((
                    stride * numpy.arange(size)[numpy.s_[:,] + (None,) * i]
                    for i, (size, stride) in enumerate(reversed(substrides))
                ), start=idx)
                indices = indices.reshape(-1)
                srcsize = indices.size
            key = numpy.s_[..., indices]
            # flatten the subarray axes to match the selected columns
            src = subarr.reshape(out.shape[:-1] + (srcsize,))
            if hasattr(out, 'at'):
                out = out.at[key].set(src)
            else:
                out[key] = src
            idx += size
    return out, idx

617 

@StructuredArray._implements(numpy.empty_like)
def _empty_like(prototype, dtype=None, *, shape=None):
    """
    The shape and dtype default to those of `prototype`.
    """
    if shape is None:
        shape = prototype.shape
    if dtype is None:
        dtype = prototype.dtype
    return _empty(shape, dtype)

623 

@StructuredArray._implements(numpy.empty)
def _empty(shape, dtype=float):
    """
    Create an uninitialized array.

    If `dtype` is structured, the output is a `StructuredArray` and nested
    structured fields are created recursively. If `dtype` is not structured
    (like the default), the output is a plain array. Non-structured arrays
    are JAX arrays if JAX supports the data type, numpy arrays otherwise.

    Parameters
    ----------
    shape : int or sequence of int
        The shape of the array.
    dtype : dtype
        The data type of the array.
    """

    def plain(shape, dtype):
        # prefer JAX, fall back to numpy for dtypes JAX does not support
        try:
            return jnp.empty(shape, dtype)
        except TypeError:
            return numpy.empty(shape, dtype)

    if hasattr(shape, '__len__'):
        shape = tuple(shape)
    else:
        shape = (int(shape),)
    dtype = numpy.dtype(dtype)

    # a non-structured dtype has no field names to iterate over
    if dtype.names is None:
        return plain(shape, dtype)

    arrays = {}
    for i, name in enumerate(dtype.names):
        dtbase = dtype[i].base
        # the axes of subarray fields are appended to the array shape
        dtshape = shape + dtype[i].shape
        if dtbase.names is not None:
            arrays[name] = _empty(dtshape, dtbase)
        else:
            arrays[name] = plain(dtshape, dtbase)
    return StructuredArray._array(shape, dtype, arrays)

644 

@StructuredArray._implements(numpy.concatenate)
def _concatenate(arrays, axis=0, dtype=None, casting='same_kind'):
    """
    Concatenate structured arrays field by field.

    The `dtype` argument is currently ignored: the output dtype is always
    `numpy.result_type` of the dtypes of the arrays, and `casting` is used to
    check that each array can be cast to it.
    """

    # the input must be a non-empty sequence
    arrays = list(arrays)
    if not arrays:
        raise ValueError('need at least one array to concatenate')

    if axis is None:
        # flatten everything and join along the only axis
        arrays = [a.reshape(-1) for a in arrays]
        axis = 0
    else:
        first = arrays[0]
        ndim = first.ndim
        assert all(a.ndim == ndim for a in arrays)
        assert -ndim <= axis < ndim
        axis %= ndim
        # the shapes can differ only along the concatenation axis
        head = first.shape[:axis]
        tail = first.shape[axis + 1:]
        assert all(a.shape[:axis] == head for a in arrays)
        assert all(a.shape[axis + 1:] == tail for a in arrays)

    dtype = numpy.result_type(*[a.dtype for a in arrays])
    assert all(numpy.can_cast(a.dtype, dtype, casting) for a in arrays)

    first = arrays[0]
    total = sum(a.shape[axis] for a in arrays)
    shape = (*first.shape[:axis], total, *first.shape[axis + 1:])

    out = _concatenate_recursive(arrays, axis, dtype, shape, casting)
    assert out.shape == shape and out.dtype == dtype
    return out

677 

def _concatenate_recursive(arrays, axis, dtype, shape, casting):
    """
    Concatenate each field of `arrays` along `axis` and assemble the results
    into a `StructuredArray` with the given `dtype` and `shape`. Nested
    structured fields are processed recursively.
    """
    fields = {}
    for name in dtype.names:
        pieces = [a[name] for a in arrays]
        base = dtype[name].base
        if base.names is None:
            # leaf field: use JAX, or numpy if JAX raises TypeError
            try:
                joined = jnp.concatenate(pieces, axis=axis, dtype=base)
            except TypeError:
                joined = numpy.concatenate(
                    pieces, axis=axis, dtype=base, casting=casting)
        else:
            # the subarray axes of the field trail the array axes
            subshape = shape + dtype[name].shape
            joined = _concatenate_recursive(
                pieces, axis, base, subshape, casting)
        fields[name] = joined
    return StructuredArray._array(shape, dtype, fields)

693 

@StructuredArray._implements(recfunctions.append_fields)
def _append_fields(base, names, data, usemask=True):
    """
    Masked arrays are not supported, so `usemask` must be set to False. The
    new fields are declared with the dtype of the corresponding array in
    `data`, and the arrays are stored without copying.
    """
    assert not usemask, 'masked arrays not supported, set usemask=False'
    if isinstance(names, str):
        # single field given without wrapping it in a list
        names, data = [names], [data]
    assert len(names) == len(data)

    arrays = dict(base._dict)
    newfields = []
    for name, array in zip(names, data):
        arrays[name] = array
        newfields.append((name, array.dtype))

    dtype = numpy.dtype(base.dtype.descr + newfields)
    return StructuredArray._array(base.shape, dtype, arrays)

707 

@StructuredArray._implements(numpy.swapaxes)
def _swapaxes(x, i, j):
    """
    Delegate to the `swapaxes` method of the array.
    """
    swapped = x.swapaxes(i, j)
    return swapped