Coverage for src/lsqfitgp/_array.py: 95%
398 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_array.py
2#
3# Copyright (c) 2020, 2022, 2023, 2024, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20import textwrap 1feabcd
21import math 1feabcd
23import numpy 1feabcd
24from numpy.lib import recfunctions 1feabcd
25import jax 1feabcd
26from jax import numpy as jnp 1feabcd
27from jax import tree_util 1feabcd
29# TODO use register_pytree_with_keys
@tree_util.register_pytree_node_class
class StructuredArray:
    """
    JAX-friendly imitation of a numpy structured array.

    It behaves like a read-only numpy structured array, and you can create
    a copy with a modified field with a jax-like syntax.

    Examples
    --------
    >>> a = numpy.empty(3, dtype=[('f', float), ('g', float)])
    >>> a = StructuredArray(a)
    >>> a = a.at['f'].set(numpy.arange(3.))
    ... # is equivalent to a['f'] = numpy.arange(3.)

    Parameters
    ----------
    array : numpy array, StructuredArray
        A structured array. An array qualifies as structured if
        ``array.dtype.names is not None``.

    Notes
    -----
    The StructuredArray is a readonly view on the input array. When you
    change the content of a field of the StructuredArray, however, the
    reference to the original array for that field is lost.

    The new value passed to ``.at[field].set(...)`` must have exactly the
    dtype and shape of the field it replaces (no casting or broadcasting).

    """

    @classmethod
    def _readonlyview_wrapifstructured(cls, x):
        """
        Return a read-only version of `x`.

        Arrays with a structured dtype are wrapped into a StructuredArray,
        numpy arrays are replaced by a non-writeable view, anything else is
        returned as is.
        """
        if x.dtype.names is not None:
            x = cls(x)
        if isinstance(x, numpy.ndarray):
            x = x.view()
            x.flags.writeable = False
        # jax arrays and StructuredArrays are already readonly
        return x

    @classmethod
    def _array(cls, s, t, d, *, check=True):
        """
        Create a new StructuredArray.

        All methods and functions that create a new StructuredArray object
        should use this method.

        Parameters
        ----------
        s : tuple or None
            The shape. If None, it is determined automatically from the arrays.
        t : dtype or None
            The dtype of the array. If None, it is determined automatically
            (before the shape).
        d : dict str -> array
            The _dict of the array. The arrays, if structured, must be already
            StructuredArrays. The order of the keys must match the order of the
            fields.
        check : bool
            If True (default), check the passed values are consistent.

        Return
        ------
        out : StructuredArray
            A new StructuredArray object.
        """

        if t is None:
            # infer the data type from the arrays in the dictionary: the
            # leading axes common to all fields (as many as the field with
            # the fewest axes has) are the array shape, the rest goes into
            # the field dtype
            ndim = min((x.ndim for x in d.values()), default=None)
            t = numpy.dtype([
                (name, x.dtype, x.shape[ndim:])
                for name, x in d.items()
            ])
            # TODO infer the least common head shape instead of counting dims

        # remove offset info since this is actually a columnar format
        t = recfunctions.repack_fields(t, align=False, recurse=True)

        if s is None:
            # infer the shape from the arrays in the dictionary
            assert d, 'can not infer array shape with no fields'
            f = t.names[0]
            a = d[f]
            s = a.shape[:a.ndim - t[0].ndim]

        if check:
            assert len(t) == len(d)
            assert t.names == tuple(d.keys())
            assert all(
                x.dtype == t[f].base and x.ndim >= t[f].ndim
                for f, x in d.items()
            )
            shapes = [
                x.shape[:x.ndim - t[f].ndim]
                for f, x in d.items()
            ]
            assert all(s == s1 for s1 in shapes)

        # bypass StructuredArray.__new__, which expects an array to convert
        out = super().__new__(cls)
        out.shape = s
        out.dtype = t
        out._dict = d

        return out

    def __new__(cls, array):
        # a StructuredArray is read-only, so it can be returned as is
        if isinstance(array, cls):
            return array
        d = {
            name: cls._readonlyview_wrapifstructured(array[name])
            for name in array.dtype.names
        }
        return cls._array(array.shape, array.dtype, d)

    @classmethod
    def from_dataframe(cls, df):
        """
        Make a StructuredArray from a DataFrame. Data is not copied if not
        necessary.
        """
        d = {
            col: cls._readonlyview_wrapifstructured(df[col].to_numpy())
            for col in df.columns
        }
        return cls._array(None, None, d)
        # TODO support polars structured dtypes
        # TODO polars has a parameter Series.to_numpy(zero_copy_only: bool),
        # default False. Maybe make it accessible through kw or options.

    @classmethod
    def from_dict(cls, mapping):
        """
        Make a StructuredArray from a dictionary of arrays. Data is not copied.
        """
        d = {
            name: cls._readonlyview_wrapifstructured(value)
            for name, value in mapping.items()
        }
        return cls._array(None, None, d)

    @property
    def size(self):
        """ Number of elements, i.e., the product of the shape. """
        return math.prod(self.shape)

    @property
    def ndim(self):
        """ Number of axes. """
        return len(self.shape)

    @property
    def nbytes(self):
        """ Sum of the `nbytes` of the arrays stored in the fields. """
        return sum(x.nbytes for x in self._dict.values())

    @property
    def T(self):
        """
        The array with the last two axes swapped.

        Arrays with less than 2 axes are returned unchanged. Note that, for
        arrays with more than 2 axes, this differs from `numpy.ndarray.T`,
        which reverses all the axes.
        """
        if self.ndim < 2:
            return self
        return self.swapaxes(self.ndim - 2, self.ndim - 1)

        # TODO this is mT, not T! Make a unit test and correct it.

    def swapaxes(self, i, j):
        """
        Interchange two axes of the array. See numpy.ndarray.swapaxes.
        """
        # computing the shape with jax is also expected to reject axes which
        # are out of bounds
        shape = jax.eval_shape(lambda: jnp.empty(self.shape).swapaxes(i, j)).shape

        # The arrays stored in the fields can have additional trailing axes
        # (array fields), so a negative index would count from the wrong end
        # if passed to them as is. Convert to non-negative indices.
        first = i + self.ndim if i < 0 else i
        second = j + self.ndim if j < 0 else j

        d = {k: v.swapaxes(first, second) for k, v in self._dict.items()}
        return self._array(shape, self.dtype, d)

    def __len__(self):
        if self.shape:
            return self.shape[0]
        else:
            raise TypeError('len() of unsized object')

    def __getitem__(self, key):
        if isinstance(key, str):
            # single field: return the stored array
            return self._dict[key]
        elif isinstance(key, list) and key and all(isinstance(k, str) for k in key):
            # list of fields: StructuredArray with a subset of the fields
            d = {
                name: self._dict[name]
                for name in key
            }
            return self._array(self.shape, self.dtype[key], d)
        else:
            # array indexing: apply the key to each field, taking all of the
            # trailing axes that belong to the field dtype
            d = {
                name: x[
                    (key if isinstance(key, tuple) else (key,))
                    + (slice(None),) * self.dtype[name].ndim
                ]
                for name, x in self._dict.items()
            }
            shape = jax.eval_shape(lambda: jnp.empty(self.shape)[key]).shape
            return self._array(shape, self.dtype, d)

    @property
    def at(self):
        """
        Helper to create a copy of the array with a field replaced, with the
        syntax ``array.at[field].set(value)``; nested fields are reached with
        ``array.at[field][subfield].set(value)``.
        """
        return self._Getter(self)

    class _Getter:
        # object returned by StructuredArray.at

        def __init__(self, array):
            self.array = array

        def __getitem__(self, key):
            if key not in self.array.dtype.names:
                raise KeyError(key)
            return self.Setter(self.array, key)

        class Setter:
            # object returned by StructuredArray.at[key]; `parent` is the
            # Setter of the enclosing field when setting a nested field

            def __init__(self, array, key, parent=None):
                self.array = array
                self.key = key
                self.parent = parent

            def __getitem__(self, subkey):
                # Use the base dtype: the names of a subarray dtype are None
                # even if its elements are structured. If the field is not
                # structured at all there are no subfields to look up.
                names = self.array.dtype[self.key].base.names
                if names is None or subkey not in names:
                    raise KeyError(subkey)
                return self.__class__(self.array[self.key], subkey, self)

            def set(self, val):
                assert isinstance(val, (numpy.ndarray, jnp.ndarray, StructuredArray))
                prev = self.array._dict[self.key]
                # TODO support casting and broadcasting
                assert prev.dtype == val.dtype
                assert prev.shape == val.shape
                d = dict(self.array._dict)
                d[self.key] = self.array._readonlyview_wrapifstructured(val)
                out = self.array._array(self.array.shape, self.array.dtype, d)
                if self.parent is not None:
                    # propagate the modified copy up to the enclosing array
                    return self.parent.set(out)
                else:
                    return out

    def reshape(self, *shape):
        """
        Reshape the array without changing its contents. See
        numpy.ndarray.reshape.
        """
        if len(shape) == 1 and hasattr(shape[0], '__len__'):
            shape = shape[0]
        shape = tuple(shape)
        d = {
            name: x.reshape(shape + self.dtype[name].shape)
            for name, x in self._dict.items()
        }
        # reshape a mockup array with an empty dtype to get the final shape
        shape = numpy.empty(self.shape, []).reshape(shape).shape
        return self._array(shape, self.dtype, d)

    def squeeze(self, axis=None):
        """
        Remove axes of length 1. See numpy.ndarray.squeeze.
        """
        if axis is None:
            axis = tuple(i for i, size in enumerate(self.shape) if size == 1)
        if not hasattr(axis, '__len__'):
            axis = (axis,)
        assert all(self.shape[i] == 1 for i in axis)
        # convert negative axes to non-negative ones, otherwise they would
        # not match the indices produced by enumerate below and the
        # corresponding axes would not be removed
        axis = {i + self.ndim if i < 0 else i for i in axis}
        newshape = [size for i, size in enumerate(self.shape) if i not in axis]
        return self.reshape(newshape)

    def astype(self, dtype):
        """
        Return the array itself if `dtype` is the dtype of the array. Actual
        conversions are not implemented.
        """
        if dtype != self.dtype:
            raise NotImplementedError
        return self

    def broadcast_to(self, shape, **kw):
        """
        Return a view of the array broadcasted to another shape. See
        numpy.broadcast_to.
        """
        # raises if not broadcastable; the shape of the result is always a
        # tuple, even if `shape` was passed as an integer or a list
        shape = numpy.broadcast_to(numpy.empty(self.shape, []), shape, **kw).shape
        d = {
            name: broadcast_to(x, shape + self.dtype[name].shape, **kw)
            for name, x in self._dict.items()
        }
        return self._array(shape, self.dtype, d)

    # TODO implement flatten_with_keys
    def tree_flatten(self):
        """ JAX PyTree encoder. See `jax.tree_util.tree_flatten`. """
        children = tuple(self._dict[key] for key in self.dtype.names)
        aux = dict(shape=self.shape, dtype=self.dtype)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        """ JAX PyTree decoder. See `jax.tree_util.tree_unflatten`. """

        # if there are no fields, keep original shape
        if not children:
            return cls._array(aux['shape'], aux['dtype'], {})

        # convert children to arrays because tree_util.tree_flatten unpacks 0d
        # arrays
        children = list(map(asarray, children))

        # if possible, keep original dtype shapes
        oldtype = aux['dtype']
        compatible_tail_shapes = all(
            x.shape[max(0, x.ndim - oldtype[i].ndim):] == oldtype[i].shape
            for i, x in enumerate(children)
        )
        head_shapes = [
            x.shape[:max(0, x.ndim - oldtype[i].ndim)]
            for i, x in enumerate(children)
        ]
        compatible_head_shapes = all(head_shapes[0] == s for s in head_shapes)
        if compatible_tail_shapes and compatible_head_shapes:
            dtype = numpy.dtype([
                (oldtype.names[i], x.dtype, oldtype[i].shape)
                for i, x in enumerate(children)
            ])
        else:
            # let _array infer the dtype from the children
            dtype = None

        d = dict(zip(oldtype.names, children))

        return cls._array(None, dtype, d)

        # TODO this breaks jax.jit(...).lower(...).compile()(...) because
        # apparently `lower` saves the pytree def after a step of dummyfication,
        # so the shape and dtype bases of the StructuredArray are () and object.
        # JAX expects pytrees to have a structure which does not depend on what
        # they store. => Quick hack: preserve the shape and dtype
        # unconditionally, i.e., tree_unflatten can produce malformed
        # StructuredArrays. The dictionary will contain whatever JAX puts into
        # it. => Quicker hack: it seems to me that jax always uses None as
        # dummy, so I could detect if all childrens are None or StructuredArray.

    def __repr__(self):
        # code from gvar https://github.com/gplepage/gvar
        # bufferdict.pyx:BufferDict:__str__
        out = 'StructuredArray({'

        listrepr = [(repr(k), repr(v)) for k, v in self._dict.items()]
        # if any field has a multiline repr, put each field on its own line
        newlinemode = any('\n' in rv for _, rv in listrepr)

        for rk, rv in listrepr:
            if not newlinemode:
                out += '{}: {}, '.format(rk, rv)
            elif '\n' in rv:
                rv = rv.replace('\n', '\n ')
                out += '\n {}:\n {},'.format(rk, rv)
            else:
                out += '\n {}: {},'.format(rk, rv)

        if out.endswith(', '):
            out = out[:-2]
        elif newlinemode:
            out += '\n'
        out += '})'

        return out

        # TODO try simply using the __repr__ of self._dict

    def __array__(self, dtype=None, copy=None):
        """
        Convert to a numpy structured array. The data is always copied.

        Parameters
        ----------
        dtype : dtype, optional
            If specified, the result is converted to this data type.
        copy : bool, optional
            Passed by numpy >= 2. Since the conversion always makes a copy,
            ``copy=False`` raises `ValueError`.
        """
        if copy is False:
            raise ValueError(
                'a StructuredArray can not be converted to a numpy array '
                'without copying'
            )
        array = numpy.empty(self.shape, self.dtype)
        self._copy_into_array(array)
        if dtype is not None:
            array = array.astype(dtype, copy=False)
        return array

    def _copy_into_array(self, dest):
        """ Copy the fields, recursively, into the numpy array `dest`. """
        assert self.dtype == dest.dtype
        assert self.shape == dest.shape
        for name, src in self._dict.items():
            if isinstance(src, StructuredArray):
                src._copy_into_array(dest[name])
            else:
                dest[name][...] = src

    def __array_function__(self, func, types, args, kwargs):
        if func not in self._handled_functions:
            return NotImplemented
        return self._handled_functions[func](*args, **kwargs)

    # maps numpy functions to their implementations for StructuredArray
    _handled_functions = {}

    @classmethod
    def _implements(cls, np_function):
        """ Register an __array_function__ implementation """
        def decorator(func):
            cls._handled_functions[np_function] = func
            newdoc = f"""\
Implementation of `{np_function.__module__}.{np_function.__name__}` for `StructuredArray`.

"""
            if func.__doc__:
                newdoc += textwrap.dedent(func.__doc__) + '\n'
            # the docstring can be missing, e.g., when running with -OO
            if np_function.__doc__:
                newdoc += 'Original docstring below:\n\n'
                newdoc += textwrap.dedent(np_function.__doc__)
            func.__doc__ = newdoc
            return func
        return decorator
@StructuredArray._implements(numpy.broadcast_to)
def broadcast_to(x, shape, **kw):
    """
    Version of numpy.broadcast_to that works with StructuredArray and JAX
    arrays.
    """
    # StructuredArray has its own implementation as a method
    if isinstance(x, StructuredArray):
        return x.broadcast_to(shape, **kw)
    # otherwise dispatch to the library the array belongs to
    backend = jnp if isinstance(x, jnp.ndarray) else numpy
    return backend.broadcast_to(x, shape, **kw)
@StructuredArray._implements(numpy.broadcast_arrays)
def broadcast_arrays(*arrays, **kw):
    """
    Version of numpy.broadcast_arrays that works with StructuredArray and JAX
    arrays.
    """
    # determine the common shape first, then broadcast each array to it
    target = numpy.broadcast_shapes(*(a.shape for a in arrays))
    # numpy.broadcast_arrays returns a list, not a tuple
    return [broadcast_to(a, target, **kw) for a in arrays]
class broadcast:
    """
    Version of numpy.broadcast that works with StructuredArray.

    Only the `shape` attribute, the shape obtained by broadcasting the
    arguments against each other, is provided.
    """

    # not handled by __array_function__

    def __init__(self, *arrays):
        shapes = [a.shape for a in arrays]
        self.shape = numpy.broadcast_shapes(*shapes)
def asarray(x, dtype=None):
    """
    Version of `numpy.asarray` that works with `StructuredArray` and JAX arrays.
    If `x` is not an array already, returns a JAX array if possible.
    """
    array_types = (StructuredArray, jnp.ndarray, numpy.ndarray)
    if isinstance(x, array_types):
        # already an array: at most convert the data type
        if dtype is None:
            return x
        return x.astype(dtype)
    if x is None:
        # partial workaround for jax issue #14506, None would be interpreted
        # as nan by jax
        return numpy.asarray(x, dtype)
    try:
        return jnp.asarray(x, dtype)
    except (TypeError, ValueError):
        # jax could not convert the object, fall back to numpy
        return numpy.asarray(x, dtype)
def _asarray_jaxifpossible(x):
    """
    Convert `x` to an array, using JAX arrays wherever the conversion
    succeeds. Structured arrays become StructuredArrays and the conversion
    is applied to their leaves.
    """
    x = asarray(x)
    if x.dtype.names:
        return tree_util.tree_map(_asarray_jaxifpossible, StructuredArray(x))
    if not isinstance(x, numpy.ndarray):
        return x
    try:
        return jnp.asarray(x)
    except (TypeError, ValueError):
        # not representable as a jax array, keep the numpy array
        return x
# numpy.squeeze for StructuredArray: delegate to the squeeze method of the array
@StructuredArray._implements(numpy.squeeze)
def _squeeze(a, axis=None):
    squeezed = a.squeeze(axis)
    return squeezed
# numpy.ix_ for StructuredArray: reshape the i-th 1d argument, out of n, to
# have length 1 along all of n axes but the i-th one
@StructuredArray._implements(numpy.ix_)
def _ix(*args):
    vectors = tuple(asarray(a) for a in args)
    assert all(v.ndim == 1 for v in vectors)
    count = len(vectors)
    out = []
    for pos, v in enumerate(vectors):
        newshape = (1,) * pos + (-1,) + (1,) * (count - pos - 1)
        out.append(v.reshape(newshape))
    return tuple(out)
def unstructured_to_structured(
    arr,
    dtype=None,
    names=None,
    align=False, # TODO maybe align is totally inapplicable even with numpy arrays? What does it mean?
    copy=False,
    casting='unsafe',
):
    """ Like `numpy.lib.recfunctions.unstructured_to_structured`, but outputs a
 `StructuredArray`. """
    arr = asarray(arr)
    if arr.ndim == 0:
        raise ValueError('arr must have at least one dimension')

    # run the numpy function on an empty array with the same last axis to
    # let it work out the output dtype
    template = numpy.empty((0,) + arr.shape[-1:], arr.dtype)
    converted = recfunctions.unstructured_to_structured(
        template,
        dtype=dtype,
        names=names,
        align=align,
        copy=copy,
        casting=casting,
    )

    result, consumed = _unstructured_to_structured_recursive(
        0, (), arr, converted.dtype, copy, casting)
    # all the entries along the last axis must have been used
    assert consumed == arr.shape[-1]
    return result
def _unstructured_to_structured_recursive(idx, shape, arr, dtype, copy, casting, *strides):
    """
    Build a StructuredArray with data type `dtype` out of the last axis of
    `arr`, recursing into nested structured fields.

    Parameters
    ----------
    idx : int
        The position along the last axis of `arr` of the first scalar
        belonging to `dtype`.
    shape : tuple
        The concatenated subarray shapes of the enclosing fields.
    arr : array
        The unstructured array.
    dtype : dtype
        A structured data type.
    copy, casting :
        Passed to `astype` when the extracted array is not a jax array.
    *strides : tuples (size, stride)
        One pair for each enclosing field: the number of elements of the
        field and the number of scalars in each element.

    Return
    ------
    out : StructuredArray
        Array with shape ``arr.shape[:-1] + shape`` and data type `dtype`.
    idx : int
        The starting position advanced by the number of scalars in one
        element of each field (times the field size).
    """
    arrays = {}
    for i, name in enumerate(dtype.names):
        base = dtype[i].base
        subshape = shape + dtype[i].shape
        # number of elements of the field, and number of scalars in each
        size = math.prod(dtype[i].shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)
        if base.names is not None:
            # nested structured field
            y, newidx = _unstructured_to_structured_recursive(idx, subshape, arr, base, copy, casting, *substrides)
            # the recursion walks through a single element of the field
            shift = newidx - idx
            assert shift == stride
            idx += size * stride
        else:
            assert stride == 1
            if all(size == 1 for size, _ in strides):
                # no enclosing field has more than one element: the scalars
                # of this field are taken as a contiguous slice
                indices = numpy.s_[idx:idx + size]
                srcsize = size
            else:
                # sum the offsets of each nesting level, each laid out along
                # its own axis with the innermost level last, then flatten
                indices = sum((
                    stride * numpy.arange(size)[numpy.s_[:,] + (None,) * i]
                    for i, (size, stride) in enumerate(reversed(substrides))
                ), start=idx)
                indices = indices.reshape(-1)
                srcsize = indices.size
            key = numpy.s_[..., indices]
            x = arr[key]
            x = x.reshape(arr.shape[:-1] + subshape)
            if isinstance(x, jnp.ndarray):
                # copy and casting are not passed to jax arrays
                y = x.astype(base)
            else:
                y = x.astype(base, copy=copy, casting=casting)
            idx += size
        arrays[name] = y
    return StructuredArray._array(arr.shape[:-1] + shape, dtype, arrays), idx
# recfunctions.structured_to_unstructured for StructuredArray
@StructuredArray._implements(recfunctions.structured_to_unstructured)
def _structured_to_unstructured(arr, dtype=None, casting='unsafe'):
    # convert an empty numpy array with the same data type to let numpy
    # determine the output data type and the length of the last axis
    template = numpy.empty(0, arr.dtype)
    converted = recfunctions.structured_to_unstructured(
        template, dtype=dtype, casting=casting)
    outshape = arr.shape + converted.shape[-1:]
    outdtype = converted.dtype

    # allocate the output with jax if it accepts the arguments
    try:
        out = jnp.empty(outshape, outdtype)
    except TypeError:
        out = numpy.empty(outshape, outdtype)
    # TODO can I make out column-major w.r.t. only the last column?

    out, filled = _structured_to_unstructured_recursive(0, arr, out)
    assert filled == converted.shape[-1]
    return out
570def _nd(dtype): 1feabcd
571 """ Count the number of scalars in a dtype """
572 base = dtype.base 1feabcd
573 shape = dtype.shape 1feabcd
574 size = math.prod(shape) 1feabcd
575 if base.names is None: 1feabcd
576 return size 1feabcd
577 else:
578 return size * sum(_nd(base[name]) for name in base.names) 1feabcd
580 # I use this function in many parts of the package so it should not have an
581 # underscore, even if I don't export it in the main namespace. And move it
582 # to utils, it's not specific to StructuredArray.
def _structured_to_unstructured_recursive(idx, arr, out, *strides):
    """
    Write the fields of the structured array `arr` into the last axis of
    `out`, recursing into nested structured fields.

    Parameters
    ----------
    idx : int
        The position along the last axis of `out` of the first scalar
        belonging to `arr`.
    arr : StructuredArray
        The array to read from.
    out : array
        The unstructured destination array.
    *strides : tuples (size, stride)
        One pair for each enclosing field: the number of elements of the
        field and the number of scalars in each element.

    Return
    ------
    out : array
        The filled array. If `out` has an `at` attribute it is updated with
        ``out.at[...].set(...)`` and the result is a new array, otherwise it
        is modified in-place.
    idx : int
        The starting position advanced by the number of scalars in one
        element of each field (times the field size).
    """
    dtype = arr.dtype
    for i, name in enumerate(dtype.names):
        subarr = arr[name]
        base = dtype[i].base
        # number of elements of the field, and number of scalars in each
        size = math.prod(dtype[i].shape)
        stride = _nd(base)
        substrides = strides + ((size, stride),)
        if base.names is not None:
            # nested structured field
            out, newidx = _structured_to_unstructured_recursive(idx, subarr, out, *substrides)
            # the recursion walks through a single element of the field
            shift = newidx - idx
            assert shift == stride
            idx += size * stride
        else:
            assert stride == 1
            if all(size == 1 for size, _ in strides):
                # no enclosing field has more than one element: the scalars
                # of this field go into a contiguous slice
                indices = numpy.s_[idx:idx + size]
                srcsize = size
            else:
                # sum the offsets of each nesting level, each laid out along
                # its own axis with the innermost level last, then flatten
                indices = sum((
                    stride * numpy.arange(size)[numpy.s_[:,] + (None,) * i]
                    for i, (size, stride) in enumerate(reversed(substrides))
                ), start=idx)
                indices = indices.reshape(-1)
                srcsize = indices.size
            key = numpy.s_[..., indices]
            # flatten the axes of the field to match the selected positions
            src = subarr.reshape(out.shape[:-1] + (srcsize,))
            if hasattr(out, 'at'):
                out = out.at[key].set(src)
            else:
                out[key] = src
            idx += size
    return out, idx
# numpy.empty_like for StructuredArray: shape and dtype default to the ones
# of the prototype
@StructuredArray._implements(numpy.empty_like)
def _empty_like(prototype, dtype=None, *, shape=None):
    if shape is None:
        shape = prototype.shape
    if dtype is None:
        dtype = prototype.dtype
    return _empty(shape, dtype)
def _empty_plain(shape, dtype):
    """ Allocate an uninitialized array, with jax if it accepts the dtype. """
    try:
        return jnp.empty(shape, dtype)
    except TypeError:
        return numpy.empty(shape, dtype)

@StructuredArray._implements(numpy.empty)
def _empty(shape, dtype=float):
    """
    Create an uninitialized array.

    Parameters
    ----------
    shape : int or sequence of int
        The shape of the array.
    dtype : dtype
        The data type. If it has fields, the output is a `StructuredArray`,
        with the fields allocated recursively; otherwise it is a plain JAX
        array, or a numpy array if JAX does not accept the data type.
    """
    if hasattr(shape, '__len__'):
        shape = tuple(shape)
    else:
        shape = (int(shape),)
    dtype = numpy.dtype(dtype)

    # A data type without fields (like the default) has names set to None,
    # there is nothing to iterate over: return a plain array.
    if dtype.names is None:
        return _empty_plain(shape, dtype)

    arrays = {}
    for i, name in enumerate(dtype.names):
        dtbase = dtype[i].base
        # the subarray shape of the field goes into trailing axes
        dtshape = shape + dtype[i].shape
        if dtbase.names is not None:
            y = _empty(dtshape, dtbase)
        else:
            y = _empty_plain(dtshape, dtbase)
        arrays[name] = y
    return StructuredArray._array(shape, dtype, arrays)
@StructuredArray._implements(numpy.concatenate)
def _concatenate(arrays, axis=0, dtype=None, casting='same_kind'):
    """
    Join a sequence of StructuredArrays along an existing axis.

    Parameters
    ----------
    arrays : sequence of StructuredArray
        The arrays to concatenate. They must have the same number of axes
        and the same shape apart from the concatenation axis.
    axis : int or None
        The axis along which the arrays are joined. If None, the arrays are
        flattened first.
    dtype : dtype, optional
        The data type of the output. If not specified, it is determined
        with `numpy.result_type` from the data types of the arrays.
    casting : str
        The kind of casting allowed from the data types of the arrays to the
        output data type, see `numpy.can_cast`.
    """

    # checks arrays is a non-empty sequence
    arrays = list(arrays)
    if not arrays:
        raise ValueError('need at least one array to concatenate')

    # parse axis argument
    if axis is None:
        axis = 0
        arrays = [a.reshape(-1) for a in arrays]
    else:
        ndim = arrays[0].ndim
        assert all(a.ndim == ndim for a in arrays)
        assert -ndim <= axis < ndim
        axis %= ndim
        shape = arrays[0].shape
        assert all(a.shape[:axis] == shape[:axis] for a in arrays)
        assert all(a.shape[axis + 1:] == shape[axis + 1:] for a in arrays)

    # determine the output data type, honoring the one requested by the user
    if dtype is None:
        dtype = numpy.result_type(*(a.dtype for a in arrays))
    else:
        # StructuredArray data types carry no offset information
        dtype = recfunctions.repack_fields(numpy.dtype(dtype), align=False, recurse=True)
    assert all(numpy.can_cast(a.dtype, dtype, casting) for a in arrays)

    # the output shape is the common one but for the concatenation axis
    shape = (
        *arrays[0].shape[:axis],
        sum(a.shape[axis] for a in arrays),
        *arrays[0].shape[axis + 1:],
    )

    out = _concatenate_recursive(arrays, axis, dtype, shape, casting)
    assert out.shape == shape and out.dtype == dtype
    return out
def _concatenate_recursive(arrays, axis, dtype, shape, casting):
    """
    Concatenate the StructuredArrays `arrays` along `axis` field by field,
    recursing into nested structured fields. `dtype` and `shape` are the
    data type and shape of the output.
    """

    def concatenate_field(name):
        parts = [a[name] for a in arrays]
        fieldtype = dtype[name]
        base = fieldtype.base
        if base.names is not None:
            # nested structured field: its shape includes the subarray shape
            subshape = shape + fieldtype.shape
            return _concatenate_recursive(parts, axis, base, subshape, casting)
        try:
            return jnp.concatenate(parts, axis=axis, dtype=base)
        except TypeError:
            return numpy.concatenate(parts, axis=axis, dtype=base, casting=casting)

    cat = {name: concatenate_field(name) for name in dtype.names}
    return StructuredArray._array(shape, dtype, cat)
# recfunctions.append_fields for StructuredArray; masked output is not
# supported, so usemask must be set to False explicitly
@StructuredArray._implements(recfunctions.append_fields)
def _append_fields(base, names, data, usemask=True):
    assert not usemask, 'masked arrays not supported, set usemask=False'
    # a single field can be passed without wrapping it in a list
    if isinstance(names, str):
        names, data = [names], [data]
    assert len(names) == len(data)
    newfields = list(zip(names, data))
    # copy the fields of the original array, then add the new ones
    arrays = dict(base._dict)
    for name, array in newfields:
        arrays[name] = array
    newdescr = [(name, array.dtype) for name, array in newfields]
    dtype = numpy.dtype(base.dtype.descr + newdescr)
    return StructuredArray._array(base.shape, dtype, arrays)
# numpy.swapaxes for StructuredArray: delegate to the swapaxes method of the
# array
@StructuredArray._implements(numpy.swapaxes)
def _swapaxes(x, i, j):
    swapped = x.swapaxes(i, j)
    return swapped