Coverage for src/lsqfitgp/_array.py: 94%
404 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-12 12:42 +0000
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-12 12:42 +0000
1# lsqfitgp/_array.py
2#
3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import textwrap 1feabcd
21import math 1feabcd
23import numpy 1feabcd
24from numpy.lib import recfunctions 1feabcd
25import jax 1feabcd
26from jax import numpy as jnp 1feabcd
27from jax import tree_util 1feabcd
29# TODO use register_pytree_with_keys
@tree_util.register_pytree_node_class
class StructuredArray:
    """
    JAX-friendly imitation of a numpy structured array.

    It behaves like a read-only numpy structured array, and you can create
    a copy with a modified field with a jax-like syntax.

    Examples
    --------
    >>> a = numpy.empty(3, dtype=[('f', float), ('g', float)])
    >>> a = StructuredArray(a)
    >>> a = a.at['f'].set(numpy.arange(3.))
    ... # is equivalent to a['f'] = numpy.arange(3.)

    Parameters
    ----------
    array : numpy array, StructuredArray
        A structured array. An array qualifies as structured if
        ``array.dtype.names is not None``.

    Notes
    -----
    The StructuredArray is a readonly view on the input array. When you
    change the content of a field of the StructuredArray, however, the
    reference to the original array for that field is lost.

    The fields are stored separately (one array per field) in the private
    dictionary ``_dict``; ``shape`` and ``dtype`` are plain attributes.
    """

    @classmethod
    def _readonlyview_wrapifstructured(cls, x):
        """
        Prepare an array to be stored as a field: structured arrays are
        wrapped in `StructuredArray`, numpy arrays are replaced by a
        non-writeable view.
        """
        if x.dtype.names is not None:
            x = cls(x)
        if isinstance(x, numpy.ndarray):
            x = x.view()
            x.flags.writeable = False
        # jax arrays and StructuredArrays are already readonly
        return x

    @classmethod
    def _array(cls, s, t, d, *, check=True):
        """
        Create a new StructuredArray.

        All methods and functions that create a new StructuredArray object
        should use this method.

        Parameters
        ----------
        s : tuple or None
            The shape. If None, it is determined automatically from the arrays.
        t : dtype or None
            The dtype of the array. If None, it is determined automatically
            (before the shape).
        d : dict str -> array
            The _dict of the array. The arrays, if structured, must be already
            StructuredArrays. The order of the keys must match the order of the
            fields.
        check : bool
            If True (default), check the passed values are consistent.

        Return
        ------
        out : StructuredArray
            A new StructuredArray object.
        """

        if t is None:
            # infer the data type from the arrays in the dictionary
            ndim = min((x.ndim for x in d.values()), default=None)
            t = numpy.dtype([
                (name, x.dtype, x.shape[ndim:])
                for name, x in d.items()
            ])
            # TODO infer the least common head shape instead of counting dims

        # remove offset info since this is actually a columnar format
        t = recfunctions.repack_fields(t, align=False, recurse=True)

        if s is None:
            # infer the shape from the arrays in the dictionary
            assert d, 'can not infer array shape with no fields'
            f = t.names[0]
            a = d[f]
            s = a.shape[:a.ndim - t[0].ndim]

        if check:
            assert len(t) == len(d)
            assert t.names == tuple(d.keys())
            assert all(
                x.dtype == t[f].base and x.ndim >= t[f].ndim
                for f, x in d.items()
            )
            shapes = [
                x.shape[:x.ndim - t[f].ndim]
                for f, x in d.items()
            ]
            assert all(s == s1 for s1 in shapes)

        out = super().__new__(cls)
        out.shape = s
        out.dtype = t
        out._dict = d

        return out

    def __new__(cls, array):
        # wrapping a StructuredArray is a no-op
        if isinstance(array, cls):
            return array
        d = {
            name: cls._readonlyview_wrapifstructured(array[name])
            for name in array.dtype.names
        }
        return cls._array(array.shape, array.dtype, d)

    @classmethod
    def from_dataframe(cls, df):
        """
        Make a StructuredArray from a DataFrame. Data is not copied if not
        necessary.
        """
        d = {
            col: cls._readonlyview_wrapifstructured(df[col].to_numpy())
            for col in df.columns
        }
        return cls._array(None, None, d)
        # TODO support polars structured dtypes
        # TODO polars has a parameter Series.to_numpy(zero_copy_only: bool),
        # default False. Maybe make it accessible through kw or options.

    @classmethod
    def from_dict(cls, mapping):
        """
        Make a StructuredArray from a dictionary of arrays. Data is not copied.
        """
        d = {
            name: cls._readonlyview_wrapifstructured(value)
            for name, value in mapping.items()
        }
        return cls._array(None, None, d)

    @property
    def size(self):
        """ Number of elements, i.e., the product of the shape. """
        return math.prod(self.shape)

    @property
    def ndim(self):
        """ Number of axes, i.e., the length of the shape. """
        return len(self.shape)

    @property
    def nbytes(self):
        """ Sum of the ``nbytes`` of the field arrays. """
        return sum(x.nbytes for x in self._dict.values())

    @property
    def T(self):
        """
        The array with the last two axes swapped; arrays with less than 2
        axes are returned as they are.
        """
        if self.ndim < 2:
            return self
        return self.swapaxes(self.ndim - 2, self.ndim - 1)

        # TODO this is mT, not T! Make a unit test and correct it.

    @staticmethod
    def _normalize_axis(axis, ndim):
        """ Convert a possibly negative axis index to the range [0, ndim). """
        try:
            return range(ndim)[axis]
        except IndexError:
            raise ValueError(
                f'axis {axis} is out of bounds for array of dimension {ndim}'
            ) from None

    def swapaxes(self, i, j):
        """
        Interchange two axes of the array. See numpy.ndarray.swapaxes.

        Negative indices count from the end of the shape of the
        StructuredArray, not from the end of the shape of the fields.
        """
        # The axes must be made non-negative before being applied to the
        # fields: a field with an array dtype has additional trailing axes,
        # so a negative index would refer to one of those instead.
        i = self._normalize_axis(i, self.ndim)
        j = self._normalize_axis(j, self.ndim)
        shape = list(self.shape)
        shape[i], shape[j] = shape[j], shape[i]
        d = {k: v.swapaxes(i, j) for k, v in self._dict.items()}
        return self._array(tuple(shape), self.dtype, d)

    def __len__(self):
        if self.shape:
            return self.shape[0]
        else:
            raise TypeError('len() of unsized object')

    def __getitem__(self, key):
        if isinstance(key, str):
            # single field: return the stored array
            return self._dict[key]
        elif isinstance(key, list) and key and all(isinstance(k, str) for k in key):
            # list of field names: StructuredArray with a subset of fields
            d = {
                name: self._dict[name]
                for name in key
            }
            return self._array(self.shape, self.dtype[key], d)
        else:
            # anything else indexes the elements; full slices are appended
            # to skip over the trailing axes of fields with an array dtype
            d = {
                name: x[
                    (key if isinstance(key, tuple) else (key,))
                    + (slice(None),) * self.dtype[name].ndim
                ]
                for name, x in self._dict.items()
            }
            shape = jax.eval_shape(lambda: jnp.empty(self.shape)[key]).shape
            return self._array(shape, self.dtype, d)

    @property
    def at(self):
        """ Helper to set a field: ``array.at['field'].set(value)``. """
        return self._Getter(self)

    class _Getter:
        """ Object returned by `StructuredArray.at`. """

        def __init__(self, array):
            self.array = array

        def __getitem__(self, key):
            if key not in self.array.dtype.names:
                raise KeyError(key)
            return self.Setter(self.array, key)

        class Setter:
            """
            Reference to a field of a StructuredArray, possibly nested, i.e.,
            a field of a field (in which case `parent` is the Setter of the
            containing field).
            """

            def __init__(self, array, key, parent=None):
                self.array = array
                self.key = key
                self.parent = parent

            def __getitem__(self, subkey):
                subarray = self.array[self.key]
                # names is None if the field is not structured
                names = subarray.dtype.names
                if names is None or subkey not in names:
                    raise KeyError(subkey)
                return self.__class__(subarray, subkey, self)

            def set(self, val):
                """
                Return a copy of the outermost array with the field replaced
                by `val`, which must have the same dtype and shape.
                """
                assert isinstance(val, (numpy.ndarray, jnp.ndarray, StructuredArray))
                prev = self.array._dict[self.key]
                # TODO support casting and broadcasting
                assert prev.dtype == val.dtype
                assert prev.shape == val.shape
                d = dict(self.array._dict)
                d[self.key] = self.array._readonlyview_wrapifstructured(val)
                out = self.array._array(self.array.shape, self.array.dtype, d)
                if self.parent is not None:
                    # propagate the modified subarray up to the containing array
                    return self.parent.set(out)
                else:
                    return out

    def reshape(self, *shape):
        """
        Reshape the array without changing its contents. See
        numpy.ndarray.reshape.
        """
        if len(shape) == 1 and hasattr(shape[0], '__len__'):
            shape = shape[0]
        shape = tuple(shape)
        d = {
            name: x.reshape(shape + self.dtype[name].shape)
            for name, x in self._dict.items()
        }
        # let numpy resolve the new shape on an array with empty dtype
        shape = numpy.empty(self.shape, []).reshape(shape).shape
        return self._array(shape, self.dtype, d)

    def squeeze(self, axis=None):
        """
        Remove axes of length 1. See numpy.ndarray.squeeze.
        """
        if axis is None:
            axis = tuple(i for i, size in enumerate(self.shape) if size == 1)
        if not hasattr(axis, '__len__'):
            axis = (axis,)
        # make the axes non-negative, otherwise they would not match the
        # positions compared below; raises IndexError if out of bounds
        axis = tuple(range(self.ndim)[i] for i in axis)
        assert all(self.shape[i] == 1 for i in axis)
        newshape = [size for i, size in enumerate(self.shape) if i not in axis]
        return self.reshape(newshape)

    def astype(self, dtype):
        """ No-op conversion; only the current dtype is accepted. """
        if dtype != self.dtype:
            raise NotImplementedError
        return self

    def broadcast_to(self, shape, **kw):
        """
        Return a view of the array broadcasted to another shape. See
        numpy.broadcast_to.
        """
        # raises if not broadcastable, and converts shape to a tuple
        shape = numpy.broadcast_to(numpy.empty(self.shape, []), shape, **kw).shape
        d = {
            name: broadcast_to(x, shape + self.dtype[name].shape, **kw)
            for name, x in self._dict.items()
        }
        return self._array(shape, self.dtype, d)

    # TODO implement flatten_with_keys
    def tree_flatten(self):
        """ JAX PyTree encoder. See `jax.tree_util.tree_flatten`. """
        children = tuple(self._dict[key] for key in self.dtype.names)
        aux = dict(shape=self.shape, dtype=self.dtype)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """ JAX PyTree decoder. See `jax.tree_util.tree_unflatten`. """

        # if there are no fields, keep original shape
        if not children:
            return cls._array(aux['shape'], aux['dtype'], {})

        # convert children to arrays because tree_util.tree_flatten unpacks 0d
        # arrays
        children = list(map(asarray, children))

        # if possible, keep original dtype shapes
        oldtype = aux['dtype']
        compatible_tail_shapes = all(
            x.shape[max(0, x.ndim - oldtype[i].ndim):] == oldtype[i].shape
            for i, x in enumerate(children)
        )
        head_shapes = [
            x.shape[:max(0, x.ndim - oldtype[i].ndim)]
            for i, x in enumerate(children)
        ]
        compatible_head_shapes = all(head_shapes[0] == s for s in head_shapes)
        if compatible_tail_shapes and compatible_head_shapes:
            dtype = numpy.dtype([
                (oldtype.names[i], x.dtype, oldtype[i].shape)
                for i, x in enumerate(children)
            ])
        else:
            dtype = None

        d = dict(zip(oldtype.names, children))

        return cls._array(None, dtype, d)

        # TODO this breaks jax.jit(...).lower(...).compile()(...) because
        # apparently `lower` saves the pytree def after a step of dummyfication,
        # so the shape and dtype bases of the StructuredArray are () and object.
        # JAX expects pytrees to have a structure which does not depend on what
        # they store. => Quick hack: preserve the shape and dtype
        # unconditionally, i.e., tree_unflatten can produce malformed
        # StructuredArrays. The dictionary will contain whatever JAX puts into
        # it. => Quicker hack: it seems to me that jax always uses None as
        # dummy, so I could detect if all childrens are None or StructuredArray.

    def __repr__(self):
        # code from gvar https://github.com/gplepage/gvar
        # bufferdict.pyx:BufferDict:__str__
        # NOTE(review): the indentation inside the string literals below is
        # reproduced as it appeared in the text this code was recovered from,
        # where runs of spaces may have been collapsed; confirm the widths.
        out = 'StructuredArray({'

        listrepr = [(repr(k), repr(v)) for k, v in self._dict.items()]
        newlinemode = any('\n' in rv for _, rv in listrepr)

        for rk, rv in listrepr:
            if not newlinemode:
                out += '{}: {}, '.format(rk, rv)
            elif '\n' in rv:
                rv = rv.replace('\n', '\n ')
                out += '\n {}:\n {},'.format(rk, rv)
            else:
                out += '\n {}: {},'.format(rk, rv)

        if out.endswith(', '):
            out = out[:-2]
        elif newlinemode:
            out += '\n'
        out += '})'

        return out

        # TODO try simply using the __repr__ of self._dict

    def __array__(self, dtype=None, copy=None):
        """
        Convert to a numpy structured array. The data is always copied.

        The parameters follow the numpy ``__array__`` protocol, in which
        `dtype` comes first and may be passed positionally.

        Raises
        ------
        ValueError
            If ``copy=False``, or if `dtype` differs from the dtype of the
            StructuredArray.
        """
        if copy is False:
            raise ValueError('StructuredArray has to be copied when converted to a numpy array')
        if dtype is not None:
            dtype = numpy.dtype(dtype)
            if dtype != self.dtype:
                raise ValueError('StructuredArray can not be converted to a numpy array with a different dtype')
        array = numpy.empty(self.shape, self.dtype)
        self._copy_into_array(array)
        return array

    def _copy_into_array(self, dest):
        """ Copy field by field, recursively, into the numpy array `dest`. """
        assert self.dtype == dest.dtype
        assert self.shape == dest.shape
        for name, src in self._dict.items():
            if isinstance(src, StructuredArray):
                src._copy_into_array(dest[name])
            else:
                dest[name][...] = src

    def __array_function__(self, func, types, args, kwargs):
        # dispatch to the implementations registered with `_implements`
        if func not in self._handled_functions:
            return NotImplemented
        return self._handled_functions[func](*args, **kwargs)

    # numpy function -> implementation for StructuredArray
    _handled_functions = {}

    @classmethod
    def _implements(cls, np_function):
        """ Register an __array_function__ implementation """
        def decorator(func):
            cls._handled_functions[np_function] = func
            newdoc = (
                f'Implementation of `{np_function.__module__}.'
                f'{np_function.__name__}` for `StructuredArray`.\n\n'
            )
            if func.__doc__:
                newdoc += textwrap.dedent(func.__doc__) + '\n'
            # docstrings can be missing, e.g., when running with -OO
            if np_function.__doc__:
                newdoc += 'Original docstring below:\n\n'
                newdoc += textwrap.dedent(np_function.__doc__)
            func.__doc__ = newdoc
            return func
        return decorator
@StructuredArray._implements(numpy.broadcast_to)
def broadcast_to(x, shape, **kw):
    """
    Version of numpy.broadcast_to that works with StructuredArray and JAX
    arrays.
    """
    # StructuredArray has its own method, the other arrays are handled by
    # the library they belong to
    if isinstance(x, StructuredArray):
        return x.broadcast_to(shape, **kw)
    backend = jnp if isinstance(x, jnp.ndarray) else numpy
    return backend.broadcast_to(x, shape, **kw)
@StructuredArray._implements(numpy.broadcast_arrays)
def broadcast_arrays(*arrays, **kw):
    """
    Version of numpy.broadcast_arrays that works with StructuredArray and JAX
    arrays.
    """
    target = numpy.broadcast_shapes(*(a.shape for a in arrays))
    # the result is a list, not a tuple, following numpy.broadcast_arrays
    return [broadcast_to(a, target, **kw) for a in arrays]
class broadcast:
    """
    Version of numpy.broadcast that works with StructuredArray.

    Only the `shape` attribute is provided.
    """

    # not handled by __array_function__

    def __init__(self, *arrays):
        shapes = [a.shape for a in arrays]
        self.shape = numpy.broadcast_shapes(*shapes)
def asarray(x, dtype=None):
    """
    Version of `numpy.asarray` that works with `StructuredArray` and JAX arrays.
    If `x` is not an array already, returns a JAX array if possible.
    """
    if isinstance(x, (StructuredArray, jnp.ndarray, numpy.ndarray)):
        if dtype is None:
            return x
        return x.astype(dtype)
    # partial workaround for jax issue #14506, None would be interpreted as
    # nan by jax, so None goes straight to numpy
    if x is not None:
        try:
            return jnp.asarray(x, dtype)
        except (TypeError, ValueError):
            pass
    # fall back to numpy for what jax does not accept
    return numpy.asarray(x, dtype)
def _asarray_jaxifpossible(x):
    """
    Convert `x` to an array, using JAX arrays where the conversion succeeds.
    Structured arrays become a `StructuredArray` with each leaf converted.
    """
    x = asarray(x)
    if x.dtype.names:
        return tree_util.tree_map(_asarray_jaxifpossible, StructuredArray(x))
    if not isinstance(x, numpy.ndarray):
        return x
    try:
        return jnp.asarray(x)
    except (TypeError, ValueError):
        # keep the numpy array if jax can not represent it
        return x
@StructuredArray._implements(numpy.squeeze)
def _squeeze(a, axis=None):
    # delegate to the method of the array
    squeezed = a.squeeze(axis)
    return squeezed
@StructuredArray._implements(numpy.ix_)
def _ix(*args):
    arrays = tuple(asarray(a) for a in args)
    assert all(a.ndim == 1 for a in arrays)
    count = len(arrays)
    out = []
    for pos, a in enumerate(arrays):
        # the array at position pos varies along axis pos only
        newshape = [1] * count
        newshape[pos] = -1
        out.append(a.reshape(tuple(newshape)))
    return tuple(out)
def unstructured_to_structured(arr,
    dtype=None,
    names=None,
    align=False, # TODO maybe align is totally inapplicable even with numpy arrays? What does it mean?
    copy=False,
    casting='unsafe'):
    """ Like `numpy.lib.recfunctions.unstructured_to_structured`, but outputs a
    `StructuredArray`. """
    arr = asarray(arr)
    if arr.ndim == 0:
        raise ValueError('arr must have at least one dimension')

    # let numpy work out the output dtype on an array without elements
    lastaxis = arr.shape[-1:]
    mockup = numpy.empty((0,) + lastaxis, arr.dtype)
    dummy = recfunctions.unstructured_to_structured(
        mockup,
        dtype=dtype,
        names=names,
        align=align,
        copy=copy,
        casting=casting,
    )

    out, length = _unstructured_to_structured_recursive(
        0, (), arr, dummy.dtype, copy, casting)
    # all the entries along the last axis must have been consumed
    assert length == arr.shape[-1]
    return out
def _unstructured_to_structured_recursive(idx, shape, arr, dtype, copy, casting, *strides):
    """
    Build the StructuredArray with the given dtype out of the last axis of
    `arr`, starting from position `idx`. `shape` and `strides` carry the
    array dtype shapes and (size, stride) pairs of the enclosing fields.

    Returns the StructuredArray and the updated position.
    """
    head = arr.shape[:-1]
    arrays = {}
    for pos, name in enumerate(dtype.names):
        field = dtype[pos]
        base = field.base
        subshape = shape + field.shape
        size = math.prod(field.shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)

        if base.names is not None:
            # nested structured field
            y, newidx = _unstructured_to_structured_recursive(
                idx, subshape, arr, base, copy, casting, *substrides)
            assert newidx - idx == stride
            idx += size * stride

        else:
            assert stride == 1
            if all(outersize == 1 for outersize, _ in strides):
                # no enclosing array fields: the entries are contiguous
                indices = slice(idx, idx + size)
            else:
                # one axis of indices per nesting level, innermost last
                indices = idx
                for depth, (levelsize, levelstride) in enumerate(reversed(substrides)):
                    axis = (slice(None),) + (None,) * depth
                    indices = indices + levelstride * numpy.arange(levelsize)[axis]
                indices = indices.reshape(-1)
            x = arr[..., indices].reshape(head + subshape)
            if isinstance(x, jnp.ndarray):
                y = x.astype(base)
            else:
                y = x.astype(base, copy=copy, casting=casting)
            idx += size

        arrays[name] = y

    return StructuredArray._array(head + shape, dtype, arrays), idx
@StructuredArray._implements(recfunctions.structured_to_unstructured)
def _structured_to_unstructured(arr, dtype=None, casting='unsafe'):
    # let numpy work out the output dtype and the length of the last axis
    # on an array without elements
    mockup = numpy.empty(0, arr.dtype)
    dummy = recfunctions.structured_to_unstructured(
        mockup, dtype=dtype, casting=casting)
    ncols = dummy.shape[-1]
    outshape = arr.shape + dummy.shape[-1:]
    try:
        out = jnp.empty(outshape, dummy.dtype)
    except TypeError:
        # fall back to numpy when jax raises TypeError
        out = numpy.empty(outshape, dummy.dtype)
    # TODO can I make out column-major w.r.t. only the last column?
    out, length = _structured_to_unstructured_recursive(0, arr, out)
    assert length == ncols
    return out
576def _nd(dtype): 1feabcd
577 """ Count the number of scalars in a dtype """
578 base = dtype.base 1feabcd
579 shape = dtype.shape 1feabcd
580 size = math.prod(shape) 1feabcd
581 if base.names is None: 1feabcd
582 return size 1feabcd
583 else:
584 return size * sum(_nd(base[name]) for name in base.names) 1feabcd
586 # I use this function in many parts of the package so it should not have an
587 # underscore, even if I don't export it in the main namespace. And move it
588 # to utils, it's not specific to StructuredArray.
def _structured_to_unstructured_recursive(idx, arr, out, *strides):
    """
    Write the fields of `arr` along the last axis of `out`, starting from
    position `idx`. `strides` carries the (size, stride) pairs of the
    enclosing fields.

    Returns the filled array and the updated position.
    """
    dtype = arr.dtype
    head = out.shape[:-1]
    for pos, name in enumerate(dtype.names):
        subarr = arr[name]
        field = dtype[pos]
        base = field.base
        size = math.prod(field.shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)

        if base.names is not None:
            # nested structured field
            out, newidx = _structured_to_unstructured_recursive(
                idx, subarr, out, *substrides)
            assert newidx - idx == stride
            idx += size * stride
            continue

        assert stride == 1
        if all(outersize == 1 for outersize, _ in strides):
            # no enclosing array fields: the entries are contiguous
            indices = slice(idx, idx + size)
            srcsize = size
        else:
            # one axis of indices per nesting level, innermost last
            indices = idx
            for depth, (levelsize, levelstride) in enumerate(reversed(substrides)):
                axis = (slice(None),) + (None,) * depth
                indices = indices + levelstride * numpy.arange(levelsize)[axis]
            indices = indices.reshape(-1)
            srcsize = indices.size
        key = (Ellipsis, indices)
        src = subarr.reshape(head + (srcsize,))
        if hasattr(out, 'at'):
            # immutable array with jax-like update syntax
            out = out.at[key].set(src)
        else:
            out[key] = src
        idx += size

    return out, idx
@StructuredArray._implements(numpy.empty_like)
def _empty_like(prototype, dtype=None, *, shape=None):
    # shape and dtype default to the ones of the prototype
    if shape is None:
        shape = prototype.shape
    if dtype is None:
        dtype = prototype.dtype
    return _empty(shape, dtype)
@StructuredArray._implements(numpy.empty)
def _empty(shape, dtype=float):
    # a scalar shape stands for a 1d array
    shape = tuple(shape) if hasattr(shape, '__len__') else (int(shape),)
    dtype = numpy.dtype(dtype)
    fields = {}
    for pos, name in enumerate(dtype.names):
        field = dtype[pos]
        fieldshape = shape + field.shape
        if field.base.names is None:
            try:
                fields[name] = jnp.empty(fieldshape, field.base)
            except TypeError:
                # fall back to numpy when jax raises TypeError
                fields[name] = numpy.empty(fieldshape, field.base)
        else:
            # nested structured field
            fields[name] = _empty(fieldshape, field.base)
    return StructuredArray._array(shape, dtype, fields)
@StructuredArray._implements(numpy.concatenate)
def _concatenate(arrays, axis=0, dtype=None, casting='same_kind'):
    """
    Join a sequence of `StructuredArray` along an existing axis.

    Parameters
    ----------
    arrays : sequence of StructuredArray
        The arrays to join. They must have the same shape apart from the
        concatenation axis.
    axis : int or None
        The axis along which the arrays are joined. If None, the arrays are
        flattened first.
    dtype : dtype, optional
        The data type of the result. If not specified, it is determined
        with `numpy.result_type` from the data types of the arrays.
    casting : str
        The casting rule the conversion of the arrays to `dtype` must
        satisfy.

    Returns
    -------
    out : StructuredArray
        The concatenated array.
    """

    # checks arrays is a non-empty sequence
    arrays = list(arrays)
    if not arrays:
        raise ValueError('need at least one array to concatenate')

    # parse axis argument
    if axis is None:
        axis = 0
        arrays = [a.reshape(-1) for a in arrays]
    else:
        ndim = arrays[0].ndim
        assert all(a.ndim == ndim for a in arrays)
        assert -ndim <= axis < ndim
        axis %= ndim
        shape = arrays[0].shape
        assert all(a.shape[:axis] == shape[:axis] for a in arrays)
        assert all(a.shape[axis + 1:] == shape[axis + 1:] for a in arrays)

    # determine the output dtype, honoring the one requested by the caller
    if dtype is None:
        dtype = numpy.result_type(*(a.dtype for a in arrays))
    else:
        # StructuredArray dtypes carry no offsets, so drop them from the
        # requested one to make the final comparison meaningful
        dtype = recfunctions.repack_fields(
            numpy.dtype(dtype), align=False, recurse=True)
        assert dtype.names is not None
    assert all(numpy.can_cast(a.dtype, dtype, casting) for a in arrays)
    shape = (
        *arrays[0].shape[:axis],
        sum(a.shape[axis] for a in arrays),
        *arrays[0].shape[axis + 1:],
    )

    out = _concatenate_recursive(arrays, axis, dtype, shape, casting)
    assert out.shape == shape and out.dtype == dtype
    return out
def _concatenate_recursive(arrays, axis, dtype, shape, casting):
    """
    Concatenate field by field the arrays into a StructuredArray with the
    given dtype and shape, descending into nested structured fields.
    """

    def joinfield(name):
        parts = [a[name] for a in arrays]
        field = dtype[name]
        base = field.base
        if base.names is not None:
            # nested structured field
            return _concatenate_recursive(
                parts, axis, base, shape + field.shape, casting)
        try:
            return jnp.concatenate(parts, axis=axis, dtype=base)
        except TypeError:
            # fall back to numpy when jax raises TypeError
            return numpy.concatenate(
                parts, axis=axis, dtype=base, casting=casting)

    cat = {name: joinfield(name) for name in dtype.names}
    return StructuredArray._array(shape, dtype, cat)
@StructuredArray._implements(recfunctions.append_fields)
def _append_fields(base, names, data, usemask=True):
    """
    Add new fields to a `StructuredArray`.

    Parameters
    ----------
    base : StructuredArray
        The array to extend. It is not modified.
    names : str or sequence of str
        The name(s) of the new field(s).
    data : array or sequence of arrays
        The content of the new field(s). The leading axes of each array
        must match the shape of `base`; additional trailing axes become the
        shape of the field.
    usemask : bool
        Must be set to False, masked arrays are not supported.

    Returns
    -------
    out : StructuredArray
        A new array with the fields of `base` followed by the new ones.
    """
    assert not usemask, 'masked arrays not supported, set usemask=False'
    if isinstance(names, str):
        names = [names]
        data = [data]
    assert len(names) == len(data)
    # store the new fields like the existing ones: structured arrays wrapped
    # in StructuredArray, numpy arrays as readonly views
    data = [StructuredArray._readonlyview_wrapifstructured(x) for x in data]
    arrays = base._dict.copy()
    arrays.update(zip(names, data))
    dtype = numpy.dtype(base.dtype.descr + [
        # axes beyond those of base belong to the field dtype
        (name, array.dtype, array.shape[base.ndim:])
        for name, array in zip(names, data)
    ])
    return StructuredArray._array(base.shape, dtype, arrays)
@StructuredArray._implements(numpy.swapaxes)
def _swapaxes(x, i, j):
    # delegate to the method of the array
    swapped = x.swapaxes(i, j)
    return swapped