Coverage for src/lsqfitgp/_array.py: 94%

404 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-12 12:42 +0000

1# lsqfitgp/_array.py 

2# 

3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import textwrap 1feabcd

21import math 1feabcd

22 

23import numpy 1feabcd

24from numpy.lib import recfunctions 1feabcd

25import jax 1feabcd

26from jax import numpy as jnp 1feabcd

27from jax import tree_util 1feabcd

28 

# TODO use register_pytree_with_keys
@tree_util.register_pytree_node_class
class StructuredArray:
    """
    JAX-friendly imitation of a numpy structured array.

    It behaves like a read-only numpy structured array, and you can create
    a copy with a modified field with a jax-like syntax.

    Examples
    --------
    >>> a = numpy.empty(3, dtype=[('f', float), ('g', float)])
    >>> a = StructuredArray(a)
    >>> a = a.at['f'].set(numpy.arange(3.))
    ... # is equivalent to a['f'] = numpy.arange(3.) on a numpy array

    Parameters
    ----------
    array : numpy array, StructuredArray
        A structured array. An array qualifies as structured if
        ``array.dtype.names is not None``.

    Notes
    -----
    The StructuredArray is a readonly view on the input array. When you
    change the content of a field of the StructuredArray, however, the
    reference to the original array for that field is lost.

    """

    @classmethod
    def _readonlyview_wrapifstructured(cls, x):
        """
        Prepare an array to be stored as a field of a StructuredArray.

        Structured arrays are converted to StructuredArray, numpy arrays are
        replaced by a read-only view of themselves, anything else is returned
        unchanged.
        """
        if x.dtype.names is not None:
            x = cls(x)
        if isinstance(x, numpy.ndarray):
            x = x.view()
            x.flags.writeable = False
        # jax arrays and StructuredArrays are already readonly
        return x

    @classmethod
    def _array(cls, s, t, d, *, check=True):
        """
        Create a new StructuredArray.

        All methods and functions that create a new StructuredArray object
        should use this method.

        Parameters
        ----------
        s : tuple or None
            The shape. If None, it is determined automatically from the arrays.
        t : dtype or None
            The dtype of the array. If None, it is determined automatically
            (before the shape).
        d : dict str -> array
            The _dict of the array. The arrays, if structured, must be already
            StructuredArrays. The order of the keys must match the order of the
            fields.
        check : bool
            If True (default), check the passed values are consistent.

        Returns
        -------
        out : StructuredArray
            A new StructuredArray object.
        """

        if t is None:
            # infer the data type from the arrays in the dictionary
            ndim = min((x.ndim for x in d.values()), default=None)
            t = numpy.dtype([
                (name, x.dtype, x.shape[ndim:])
                for name, x in d.items()
            ])
            # TODO infer the least common head shape instead of counting dims

        # remove offset info since this is actually a columnar format
        t = recfunctions.repack_fields(t, align=False, recurse=True)

        if s is None:
            # infer the shape from the arrays in the dictionary
            assert d, 'can not infer array shape with no fields'
            f = t.names[0]
            a = d[f]
            s = a.shape[:a.ndim - t[0].ndim]
        # accept any sequence as shape, but always store a tuple
        s = tuple(s)

        if check:
            assert len(t) == len(d)
            assert t.names == tuple(d.keys())
            assert all(
                x.dtype == t[f].base and x.ndim >= t[f].ndim
                for f, x in d.items()
            )
            # the leading axes of each field (excluding the subarray axes of
            # the field dtype) must match the array shape
            shapes = [
                x.shape[:x.ndim - t[f].ndim]
                for f, x in d.items()
            ]
            assert all(s == s1 for s1 in shapes)

        out = super().__new__(cls)
        out.shape = s
        out.dtype = t
        out._dict = d

        return out

    def __new__(cls, array):
        if isinstance(array, cls):
            return array
        d = {
            name: cls._readonlyview_wrapifstructured(array[name])
            for name in array.dtype.names
        }
        return cls._array(array.shape, array.dtype, d)

    @classmethod
    def from_dataframe(cls, df):
        """
        Make a StructuredArray from a DataFrame. Data is not copied if not
        necessary.
        """
        # TODO support polars structured dtypes
        # TODO polars has a parameter Series.to_numpy(zero_copy_only: bool),
        # default False. Maybe make it accessible through kw or options.
        d = {
            col: cls._readonlyview_wrapifstructured(df[col].to_numpy())
            for col in df.columns
        }
        return cls._array(None, None, d)

    @classmethod
    def from_dict(cls, mapping):
        """
        Make a StructuredArray from a dictionary of arrays. Data is not copied.
        """
        d = {
            name: cls._readonlyview_wrapifstructured(value)
            for name, value in mapping.items()
        }
        return cls._array(None, None, d)

    @property
    def size(self):
        """ Number of elements, i.e., the product of `shape`. """
        return math.prod(self.shape)

    @property
    def ndim(self):
        """ Number of axes. """
        return len(self.shape)

    @property
    def nbytes(self):
        """ Total bytes used by the field arrays. """
        return sum(x.nbytes for x in self._dict.values())

    @property
    def T(self):
        """
        Swap the last two axes; arrays with less than 2 axes are returned
        as is.
        """
        if self.ndim < 2:
            return self
        return self.swapaxes(self.ndim - 2, self.ndim - 1)

    # TODO this is mT, not T! Make a unit test and correct it.

    def swapaxes(self, i, j):
        """
        Interchange axes `i` and `j`. See numpy.ndarray.swapaxes. Negative
        axes count from the end of the array shape.
        """
        # computes the new shape and raises if the axes are out of range
        shape = jax.eval_shape(lambda: jnp.empty(self.shape).swapaxes(i, j)).shape
        # make the axes non-negative: the field arrays may have additional
        # trailing axes from subarray dtypes, so a negative axis would select
        # the wrong axis on them
        i %= self.ndim
        j %= self.ndim
        d = {k: v.swapaxes(i, j) for k, v in self._dict.items()}
        return self._array(shape, self.dtype, d)

    def __len__(self):
        if self.shape:
            return self.shape[0]
        else:
            raise TypeError('len() of unsized object')

    def __getitem__(self, key):
        """
        A string selects a field, a non-empty list of strings selects a
        subset of the fields, anything else indexes the array shape, leaving
        the subarray axes of the fields untouched.
        """
        if isinstance(key, str):
            return self._dict[key]
        elif isinstance(key, list) and key and all(isinstance(k, str) for k in key):
            d = {
                name: self._dict[name]
                for name in key
            }
            return self._array(self.shape, self.dtype[key], d)
        else:
            # append full slices to leave the subarray axes of each field
            # alone
            d = {
                name: x[
                    (key if isinstance(key, tuple) else (key,))
                    + (slice(None),) * self.dtype[name].ndim
                ]
                for name, x in self._dict.items()
            }
            shape = jax.eval_shape(lambda: jnp.empty(self.shape)[key]).shape
            return self._array(shape, self.dtype, d)

    @property
    def at(self):
        """ Entry point of the ``a.at[field][...].set(value)`` syntax. """
        return self._Getter(self)

    class _Getter:

        def __init__(self, array):
            self.array = array

        def __getitem__(self, key):
            if key not in self.array.dtype.names:
                raise KeyError(key)
            return self.Setter(self.array, key)

        class Setter:

            def __init__(self, array, key, parent=None):
                self.array = array
                self.key = key
                # Setter for the enclosing field, when selecting nested fields
                self.parent = parent

            def __getitem__(self, subkey):
                if subkey not in self.array.dtype[self.key].names:
                    raise KeyError(subkey)
                return self.__class__(self.array[self.key], subkey, self)

            def set(self, val):
                """
                Return a copy of the top-level array with the selected field
                replaced by `val`, which must have the same dtype and shape
                as the current field.
                """
                assert isinstance(val, (numpy.ndarray, jnp.ndarray, StructuredArray))
                prev = self.array._dict[self.key]
                # TODO support casting and broadcasting
                assert prev.dtype == val.dtype
                assert prev.shape == val.shape
                d = dict(self.array._dict)
                d[self.key] = self.array._readonlyview_wrapifstructured(val)
                out = self.array._array(self.array.shape, self.array.dtype, d)
                if self.parent is not None:
                    # propagate the modified substructure to the outer level
                    return self.parent.set(out)
                else:
                    return out

    def reshape(self, *shape):
        """
        Reshape the array without changing its contents. See
        numpy.ndarray.reshape.
        """
        if len(shape) == 1 and hasattr(shape[0], '__len__'):
            shape = shape[0]
        shape = tuple(shape)
        d = {
            name: x.reshape(shape + self.dtype[name].shape)
            for name, x in self._dict.items()
        }
        # resolve -1 entries using an array with empty dtype
        shape = numpy.empty(self.shape, []).reshape(shape).shape
        return self._array(shape, self.dtype, d)

    def squeeze(self, axis=None):
        """
        Remove axes of length 1. See numpy.ndarray.squeeze. Negative axes
        count from the end of the array shape.
        """
        if axis is None:
            axis = tuple(i for i, size in enumerate(self.shape) if size == 1)
        if not hasattr(axis, '__len__'):
            axis = (axis,)
        # raises IndexError if an axis is out of range
        assert all(self.shape[i] == 1 for i in axis)
        # normalize negative axes to compare them with enumerate indices
        axis = {i % self.ndim for i in axis}
        newshape = [size for i, size in enumerate(self.shape) if i not in axis]
        return self.reshape(newshape)

    def astype(self, dtype):
        """ Only the identity conversion is supported. """
        if dtype != self.dtype:
            raise NotImplementedError
        return self

    def broadcast_to(self, shape, **kw):
        """
        Return a view of the array broadcasted to another shape. See
        numpy.broadcast_to.
        """
        # raises if not broadcastable, and normalizes `shape` (which may be
        # an integer or any sequence) to a tuple
        shape = numpy.broadcast_to(numpy.empty(self.shape, []), shape, **kw).shape
        d = {
            name: broadcast_to(x, shape + self.dtype[name].shape, **kw)
            for name, x in self._dict.items()
        }
        return self._array(shape, self.dtype, d)

    # TODO implement flatten_with_keys
    def tree_flatten(self):
        """ JAX PyTree encoder. See `jax.tree_util.tree_flatten`. """
        children = tuple(self._dict[key] for key in self.dtype.names)
        aux = dict(shape=self.shape, dtype=self.dtype)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """ JAX PyTree decoder. See `jax.tree_util.tree_unflatten`. """

        # if there are no fields, keep original shape
        if not children:
            return cls._array(aux['shape'], aux['dtype'], {})

        # convert children to arrays because tree_util.tree_flatten unpacks 0d
        # arrays
        children = list(map(asarray, children))

        # if possible, keep original dtype shapes
        oldtype = aux['dtype']
        compatible_tail_shapes = all(
            x.shape[max(0, x.ndim - oldtype[i].ndim):] == oldtype[i].shape
            for i, x in enumerate(children)
        )
        head_shapes = [
            x.shape[:max(0, x.ndim - oldtype[i].ndim)]
            for i, x in enumerate(children)
        ]
        compatible_head_shapes = all(head_shapes[0] == s for s in head_shapes)
        if compatible_tail_shapes and compatible_head_shapes:
            dtype = numpy.dtype([
                (oldtype.names[i], x.dtype, oldtype[i].shape)
                for i, x in enumerate(children)
            ])
        else:
            dtype = None

        d = dict(zip(oldtype.names, children))

        return cls._array(None, dtype, d)

    # TODO this breaks jax.jit(...).lower(...).compile()(...) because
    # apparently `lower` saves the pytree def after a step of dummyfication,
    # so the shape and dtype bases of the StructuredArray are () and object.
    # JAX expects pytrees to have a structure which does not depend on what
    # they store. => Quick hack: preserve the shape and dtype
    # unconditionally, i.e., tree_unflatten can produce malformed
    # StructuredArrays. The dictionary will contain whatever JAX puts into
    # it. => Quicker hack: it seems to me that jax always uses None as
    # dummy, so I could detect if all childrens are None or StructuredArray.

    def __repr__(self):
        # code from gvar https://github.com/gplepage/gvar
        # bufferdict.pyx:BufferDict:__str__
        out = 'StructuredArray({'

        listrepr = [(repr(k), repr(v)) for k, v in self._dict.items()]
        newlinemode = any('\n' in rv for _, rv in listrepr)

        for rk, rv in listrepr:
            if not newlinemode:
                out += '{}: {}, '.format(rk, rv)
            elif '\n' in rv:
                rv = rv.replace('\n', '\n    ')
                out += '\n    {}:\n        {},'.format(rk, rv)
            else:
                out += '\n    {}: {},'.format(rk, rv)

        if out.endswith(', '):
            out = out[:-2]
        elif newlinemode:
            out += '\n'
        out += '})'

        return out

    # TODO try simply using the __repr__ of self._dict

    def __array__(self, dtype=None, copy=None):
        """
        Convert to a numpy structured array, following the numpy
        ``__array__(dtype=None, copy=None)`` protocol. The data is always
        copied, and the dtype can not be changed.
        """
        if copy is False:
            raise ValueError('StructuredArray has to be copied when converted to a numpy array')
        if dtype is not None:
            dtype = numpy.dtype(dtype)
            if dtype != self.dtype:
                raise ValueError('StructuredArray can not be converted to a numpy array with a different dtype')
        array = numpy.empty(self.shape, self.dtype)
        self._copy_into_array(array)
        return array

    def _copy_into_array(self, dest):
        """ Copy the fields into the numpy structured array `dest`. """
        assert self.dtype == dest.dtype
        assert self.shape == dest.shape
        for name, src in self._dict.items():
            if isinstance(src, StructuredArray):
                src._copy_into_array(dest[name])
            else:
                dest[name][...] = src

    def __array_function__(self, func, types, args, kwargs):
        if func not in self._handled_functions:
            return NotImplemented
        return self._handled_functions[func](*args, **kwargs)

    # numpy function -> implementation for StructuredArray
    _handled_functions = {}

    @classmethod
    def _implements(cls, np_function):
        """ Register an __array_function__ implementation """
        def decorator(func):
            cls._handled_functions[np_function] = func
            newdoc = f"""\
Implementation of `{np_function.__module__}.{np_function.__name__}` for `StructuredArray`.

"""
            if func.__doc__:
                newdoc += textwrap.dedent(func.__doc__) + '\n'
            # docstrings are None when running with python -OO
            if np_function.__doc__:
                newdoc += 'Original docstring below:\n\n'
                newdoc += textwrap.dedent(np_function.__doc__)
            func.__doc__ = newdoc
            return func
        return decorator

432 

@StructuredArray._implements(numpy.broadcast_to)
def broadcast_to(x, shape, **kw):
    """
    Broadcast `x` to `shape`, like numpy.broadcast_to, accepting also
    StructuredArray and JAX arrays.
    """
    if isinstance(x, StructuredArray):
        return x.broadcast_to(shape, **kw)
    # pick the backend matching the input array type
    impl = jnp.broadcast_to if isinstance(x, jnp.ndarray) else numpy.broadcast_to
    return impl(x, shape, **kw)

445 

@StructuredArray._implements(numpy.broadcast_arrays)
def broadcast_arrays(*arrays, **kw):
    """
    Broadcast all `arrays` against each other, like numpy.broadcast_arrays,
    accepting also StructuredArray and JAX arrays.
    """
    target = numpy.broadcast_shapes(*map(lambda a: a.shape, arrays))
    result = []
    for a in arrays:
        result.append(broadcast_to(a, target, **kw))
    # the result is a list, not a tuple
    return result

456 

class broadcast:
    """
    Minimal replacement for numpy.broadcast that also accepts
    StructuredArray. Only the broadcasted `shape` attribute is computed.
    """

    # not handled by __array_function__

    def __init__(self, *arrays):
        shapes = [a.shape for a in arrays]
        self.shape = numpy.broadcast_shapes(*shapes)

466 

def asarray(x, dtype=None):
    """
    Version of `numpy.asarray` that works with `StructuredArray` and JAX arrays.
    If `x` is not an array already, returns a JAX array if possible.
    """
    if isinstance(x, (StructuredArray, jnp.ndarray, numpy.ndarray)):
        if dtype is None:
            return x
        return x.astype(dtype)
    # partial workaround for jax issue #14506: jax would interpret None as
    # nan, so None goes straight to numpy
    if x is None:
        return numpy.asarray(x, dtype)
    try:
        return jnp.asarray(x, dtype)
    except (TypeError, ValueError):
        # not representable by jax, fall back to numpy
        return numpy.asarray(x, dtype)

482 

def _asarray_jaxifpossible(x):
    """
    Convert `x` with `asarray`, then turn numpy arrays into JAX arrays where
    JAX supports the dtype. Structured arrays are converted field by field.
    """
    arr = asarray(x)
    if arr.dtype.names:
        return tree_util.tree_map(_asarray_jaxifpossible, StructuredArray(arr))
    if not isinstance(arr, numpy.ndarray):
        return arr
    try:
        return jnp.asarray(arr)
    except (TypeError, ValueError):
        # dtype not supported by jax, keep the numpy array
        return arr

493 

@StructuredArray._implements(numpy.squeeze)
def _squeeze(a, axis=None):
    """ Delegates to `StructuredArray.squeeze`. """
    squeeze = a.squeeze
    return squeeze(axis)

497 

@StructuredArray._implements(numpy.ix_)
def _ix(*args):
    """
    Reshape each of the 1-D inputs so that the k-th one varies along axis k
    of an open mesh. Boolean masks are not converted to indices.
    """
    arrays = [asarray(a) for a in args]
    assert all(a.ndim == 1 for a in arrays)
    nargs = len(arrays)
    result = []
    for pos, a in enumerate(arrays):
        newshape = [1] * nargs
        newshape[pos] = -1
        result.append(a.reshape(tuple(newshape)))
    return tuple(result)

507 

def unstructured_to_structured(arr,
    dtype=None,
    names=None,
    align=False, # TODO maybe align is totally inapplicable even with numpy arrays? What does it mean?
    copy=False,
    casting='unsafe'):
    """ Like `numpy.lib.recfunctions.unstructured_to_structured`, but outputs a
    `StructuredArray`. """
    arr = asarray(arr)
    if arr.ndim == 0:
        raise ValueError('arr must have at least one dimension')
    # let numpy validate the arguments and work out the output dtype, using
    # an empty array with the same number of columns
    template = numpy.empty((0,) + arr.shape[-1:], arr.dtype)
    outtype = recfunctions.unstructured_to_structured(template,
        dtype=dtype, names=names, align=align, copy=copy, casting=casting).dtype
    result, consumed = _unstructured_to_structured_recursive(0, (), arr, outtype, copy, casting)
    # every column must have been used exactly once
    assert consumed == arr.shape[-1]
    return result

525 

def _unstructured_to_structured_recursive(idx, shape, arr, dtype, copy, casting, *strides):
    """
    Recursive worker of `unstructured_to_structured`.

    Parameters
    ----------
    idx : int
        Column of `arr` (last axis) where the first scalar of `dtype` is read.
    shape : tuple
        Accumulated subarray shape of the enclosing fields.
    arr : array
        The unstructured source; its last axis enumerates the scalars.
    dtype : structured dtype
        The dtype to build at this level.
    copy, casting :
        Passed to ``astype`` for non-JAX arrays.
    *strides : (size, stride) pairs
        For each enclosing subarray field, its number of elements and the
        number of scalars in each element.

    Returns
    -------
    out : StructuredArray
    idx : int
        The column following those consumed by one element of `dtype`.
    """
    fields = {}
    for pos, name in enumerate(dtype.names):
        field = dtype[pos]
        base = field.base
        fieldshape = shape + field.shape
        nelem = math.prod(field.shape)
        width = _nd(base)
        fieldstrides = strides + ((nelem, width),)
        if base.names is not None:
            value, end = _unstructured_to_structured_recursive(idx, fieldshape, arr, base, copy, casting, *fieldstrides)
            # the recursion consumes the columns of a single element
            assert end - idx == width
            idx += nelem * width
        else:
            assert width == 1
            if all(n == 1 for n, _ in strides):
                # the columns of the field are contiguous
                cols = numpy.s_[idx:idx + nelem]
            else:
                # scattered columns: offset by stride along each enclosing
                # subarray, with the outermost axis varying slowest
                cols = sum((
                    st * numpy.arange(n)[numpy.s_[:,] + (None,) * k]
                    for k, (n, st) in enumerate(reversed(fieldstrides))
                ), start=idx)
                cols = cols.reshape(-1)
            x = arr[..., cols]
            x = x.reshape(arr.shape[:-1] + fieldshape)
            if isinstance(x, jnp.ndarray):
                value = x.astype(base)
            else:
                value = x.astype(base, copy=copy, casting=casting)
            idx += nelem
        fields[name] = value
    return StructuredArray._array(arr.shape[:-1] + shape, dtype, fields), idx

561 

@StructuredArray._implements(recfunctions.structured_to_unstructured)
def _structured_to_unstructured(arr, dtype=None, casting='unsafe', *, copy=False):
    """
    Convert a StructuredArray to an unstructured array whose last axis
    enumerates the scalars of each element. A JAX array is produced if the
    output dtype is supported by JAX, otherwise a numpy array.

    Parameters
    ----------
    arr : StructuredArray
    dtype, casting :
        As in `numpy.lib.recfunctions.structured_to_unstructured`.
    copy : bool
        Accepted for compatibility with the numpy signature and ignored: the
        output is always a new array.
    """
    # let numpy validate the arguments and compute the output dtype and the
    # number of columns on an empty array
    mockup = numpy.empty(0, arr.dtype)
    dummy = recfunctions.structured_to_unstructured(mockup, dtype=dtype, casting=casting)
    args = (arr.shape + dummy.shape[-1:], dummy.dtype)
    try:
        out = jnp.empty(*args)
    except TypeError:
        # dtype not supported by jax
        out = numpy.empty(*args)
    # TODO can I make out column-major w.r.t. only the last column?
    out, length = _structured_to_unstructured_recursive(0, arr, out)
    assert length == dummy.shape[-1]
    return out

575 

576def _nd(dtype): 1feabcd

577 """ Count the number of scalars in a dtype """ 

578 base = dtype.base 1feabcd

579 shape = dtype.shape 1feabcd

580 size = math.prod(shape) 1feabcd

581 if base.names is None: 1feabcd

582 return size 1feabcd

583 else: 

584 return size * sum(_nd(base[name]) for name in base.names) 1feabcd

585 

586 # I use this function in many parts of the package so it should not have an 

587 # underscore, even if I don't export it in the main namespace. And move it 

588 # to utils, it's not specific to StructuredArray. 

589 

def _structured_to_unstructured_recursive(idx, arr, out, *strides):
    """
    Recursive worker of `_structured_to_unstructured`.

    Writes the fields of `arr` into the columns (last axis) of `out` starting
    at column `idx`. `strides` holds, for each enclosing subarray field, its
    number of elements and the number of scalars in each element. Returns
    the updated `out` (JAX arrays are updated functionally, numpy arrays in
    place) and the column following those written for one element.
    """
    for pos, name in enumerate(arr.dtype.names):
        column = arr[name]
        field = arr.dtype[pos]
        base = field.base
        nelem = math.prod(field.shape)
        width = _nd(base)
        fieldstrides = strides + ((nelem, width),)
        if base.names is not None:
            out, end = _structured_to_unstructured_recursive(idx, column, out, *fieldstrides)
            # the recursion fills the columns of a single element
            assert end - idx == width
            idx += nelem * width
        else:
            assert width == 1
            if all(n == 1 for n, _ in strides):
                # the columns of the field are contiguous
                cols = numpy.s_[idx:idx + nelem]
                ncols = nelem
            else:
                # scattered columns: offset by stride along each enclosing
                # subarray, with the outermost axis varying slowest
                cols = sum((
                    st * numpy.arange(n)[numpy.s_[:,] + (None,) * k]
                    for k, (n, st) in enumerate(reversed(fieldstrides))
                ), start=idx)
                cols = cols.reshape(-1)
                ncols = cols.size
            src = column.reshape(out.shape[:-1] + (ncols,))
            if hasattr(out, 'at'):
                out = out.at[..., cols].set(src)
            else:
                out[..., cols] = src
            idx += nelem
    return out, idx

623 

@StructuredArray._implements(numpy.empty_like)
def _empty_like(prototype, dtype=None, *, shape=None):
    """ Shape and dtype default to those of `prototype`. """
    if shape is None:
        shape = prototype.shape
    if dtype is None:
        dtype = prototype.dtype
    return _empty(shape, dtype)

629 

@StructuredArray._implements(numpy.empty)
def _empty(shape, dtype=float):
    """
    Allocate an uninitialized StructuredArray. Leaf fields are JAX arrays
    when JAX supports their dtype, numpy arrays otherwise.
    """
    shape = tuple(shape) if hasattr(shape, '__len__') else (int(shape),)
    dtype = numpy.dtype(dtype)
    fields = {}
    for pos, name in enumerate(dtype.names):
        field = dtype[pos]
        fieldshape = shape + field.shape
        if field.base.names is not None:
            fields[name] = _empty(fieldshape, field.base)
            continue
        try:
            fields[name] = jnp.empty(fieldshape, field.base)
        except TypeError:
            # dtype not supported by jax
            fields[name] = numpy.empty(fieldshape, field.base)
    return StructuredArray._array(shape, dtype, fields)

650 

@StructuredArray._implements(numpy.concatenate)
def _concatenate(arrays, axis=0, dtype=None, casting='same_kind'):
    """
    Concatenate StructuredArrays along an axis.

    Parameters
    ----------
    arrays : sequence of StructuredArray
        Non-empty sequence of arrays with the same number of axes and the
        same shape except along `axis`.
    axis : int or None
        The axis along which to concatenate. If None, the arrays are
        flattened first.
    dtype : dtype, optional
        The dtype of the output. If not specified, the result type of the
        input dtypes is used.
    casting : str
        Allowed casting from the input dtypes to the output dtype.
    """

    # checks arrays is a non-empty sequence
    arrays = list(arrays)
    if not arrays:
        raise ValueError('need at least one array to concatenate')

    # parse axis argument
    if axis is None:
        axis = 0
        arrays = [a.reshape(-1) for a in arrays]
    else:
        ndim = arrays[0].ndim
        assert all(a.ndim == ndim for a in arrays)
        assert -ndim <= axis < ndim
        axis %= ndim
    shape = arrays[0].shape
    assert all(a.shape[:axis] == shape[:axis] for a in arrays)
    assert all(a.shape[axis + 1:] == shape[axis + 1:] for a in arrays)

    # honor the requested dtype, falling back to the promoted input dtype
    if dtype is None:
        dtype = numpy.result_type(*(a.dtype for a in arrays))
    else:
        dtype = numpy.dtype(dtype)
    # StructuredArray dtypes are always packed, match that to pass the final
    # consistency check
    dtype = recfunctions.repack_fields(dtype, align=False, recurse=True)
    assert all(numpy.can_cast(a.dtype, dtype, casting) for a in arrays)
    shape = (
        *arrays[0].shape[:axis],
        sum(a.shape[axis] for a in arrays),
        *arrays[0].shape[axis + 1:],
    )

    out = _concatenate_recursive(arrays, axis, dtype, shape, casting)
    assert out.shape == shape and out.dtype == dtype
    return out

683 

def _concatenate_recursive(arrays, axis, dtype, shape, casting):
    """
    Concatenate field by field, recursing into structured fields. `shape`
    is the output shape at this level (without the field subarray shape).
    """
    fields = {}
    for name in dtype.names:
        parts = [a[name] for a in arrays]
        field = dtype[name]
        if field.base.names is not None:
            fields[name] = _concatenate_recursive(parts, axis, field.base, shape + field.shape, casting)
            continue
        try:
            joined = jnp.concatenate(parts, axis=axis, dtype=field.base)
        except TypeError:
            # dtype not supported by jax
            joined = numpy.concatenate(parts, axis=axis, dtype=field.base, casting=casting)
        fields[name] = joined
    return StructuredArray._array(shape, dtype, fields)

699 

@StructuredArray._implements(recfunctions.append_fields)
def _append_fields(base, names, data, usemask=True):
    """
    Only ``usemask=False`` is supported. Each new field is converted with
    `asarray`, and structured data is wrapped into a read-only
    StructuredArray. Axes of a new field beyond ``base.ndim`` become the
    subarray shape of the field dtype.
    """
    assert not usemask, 'masked arrays not supported, set usemask=False'
    if isinstance(names, str):
        names = [names]
        data = [data]
    assert len(names) == len(data)
    # same treatment the other constructors apply to field arrays
    data = [
        StructuredArray._readonlyview_wrapifstructured(asarray(x))
        for x in data
    ]
    arrays = base._dict.copy()
    arrays.update(zip(names, data))
    dtype = numpy.dtype(base.dtype.descr + [
        (name, array.dtype, array.shape[base.ndim:])
        for name, array in zip(names, data)
    ])
    return StructuredArray._array(base.shape, dtype, arrays)

713 

@StructuredArray._implements(numpy.swapaxes)
def _swapaxes(x, i, j):
    """ Delegates to `StructuredArray.swapaxes`. """
    swap = x.swapaxes
    return swap(i, j)