Coverage for src/lsqfitgp/_Kernel/_crosskernel.py: 99%

401 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_Kernel/_crosskernel.py 

2# 

3# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20import enum 1feabcd

21import functools 1feabcd

22import sys 1feabcd

23import collections 1feabcd

24import types 1feabcd

25import abc 1feabcd

26import warnings 1feabcd

27 

28import numpy 1feabcd

29from jax import numpy as jnp 1feabcd

30 

31from .. import _array 1feabcd

32from .. import _jaxext 1feabcd

33from .. import _utils 1feabcd

34 

35from . import _util 1feabcd

36 

@functools.lru_cache(maxsize=None)
def least_common_superclass(*classes):
    """
    Find a "least" common superclass of `classes`.

    One cursor is kept per class, pointing into that class's MRO. For every
    ordered pair of distinct classes ``(a, b)``, the cursor of ``b`` is
    pushed forward until ``a`` is a subclass of the MRO entry under it. The
    test uses `issubclass`, so virtual (registered) inheritance counts. The
    entry under the cursor with the smallest final position is returned;
    ties go to the class listed first.

    Results are memoized with `functools.lru_cache`, so the arguments must
    be hashable.
    """
    mros = [klass.__mro__ for klass in classes]
    cursors = [0] * len(mros)
    for i, own_mro in enumerate(mros):
        probe = own_mro[0]  # an MRO starts with the class itself
        for j, other_mro in enumerate(mros):
            if j != i:
                while not issubclass(probe, other_mro[cursors[j]]):
                    cursors[j] += 1
    best = numpy.argmin(cursors)
    return mros[best][cursors[best]]

53 

54class CrossKernel: 1feabcd

55 r""" 

56  

57 Base class to represent kernels, i.e., covariance functions. 

58 

59 A kernel is a two-argument function that computes the covariance between 

60 two functions at some points according to some probability distribution: 

61 

62 .. math:: 

63 \mathrm{kernel}(x, y) = \mathrm{Cov}[f(x), g(y)]. 

64  

65 `CrossKernel` objects are callable, the signature is ``obj(x, y)``, and 

66 they can be summed and multiplied between them and with scalars. They 

67 are immutable; all operations return new objects. 

68 

69 Parameters 

70 ---------- 

71 core : callable 

72 A function with signature ``core(x, y)``, where ``x`` and ``y`` 

73 are two broadcastable numpy arrays, which computes the value of the 

74 kernel. 

75 scale, loc, derivable, maxdim, dim : 

76 If specified, these arguments are passed as arguments to the 

77 correspondingly named operators, in the order listed here. See `linop`. 

78 Briefly: the kernel selects only fields `dim` in the input, checks the 

79 dimensionality against `maxdim`, checks there are not too many 

80 derivatives taken on the arguments, then transforms as ``(x - loc) / 

81 scale``. If any argument is callable, it is passed `**kw` and must 

82 return the actual argument. If an argument is a tuple, it is interpreted 

83 as a pair of arguments. 

84 forcekron : bool, default False 

85 If True, apply ``.transf('forcekron')`` to the kernel, before the 

86 operations above. Available only for `Kernel`. 

87 batchbytes : number, optional 

88 If specified, apply ``.batch(batchbytes)`` to the kernel. 

89 dynkw : dict, optional 

90 Additional keyword arguments passed to `core` that can be modified 

91 by transformations. Deleted by transformations by default. 

92 **initkw : 

93 Additional keyword arguments passed to `core` that can be read but not 

94 changed by transformations. 

95 

96 Attributes 

97 ---------- 

98 initkw : dict 

99 The `initkw` argument. 

100 dynkw : dict 

101 The `dynkw` argument, or a modification of it if the object has been 

102 transformed. 

103 core : callable 

104 The `core` argument partially evaluated on `initkw`, or another 

105 function wrapping it if the object has been transformed. 

106  

107 Methods 

108 ------- 

109 batch 

110 linop 

111 algop 

112 transf 

113 register_transf 

114 register_linop 

115 register_corelinop 

116 register_xtransf 

117 register_algop 

118 register_ufuncalgop 

119 transf_help 

120 has_transf 

121 list_transf 

122 super_transf 

123 make_linop_family 

124 

125 See also 

126 -------- 

127 Kernel 

128 

129 Notes 

130 ----- 

131 The predefined class hierarchy and the class logic of the transformations 

132 assume that each kernel class corresponds to a subalgebra, i.e., addition 

133 and multiplication preserve the class. 

134  

135 """ 

136 

137 __slots__ = '_initkw', '_dynkw', '_core' 1feabcd

138 # only __new__ and _clone shall access these attributes 

139 

140 @property 1feabcd

141 def initkw(self): 1feabcd

142 return types.MappingProxyType(self._initkw) 1feabcd

143 

144 @property 1feabcd

145 def dynkw(self): 1feabcd

146 return types.MappingProxyType(self._dynkw) 1feabcd

147 

148 @property 1feabcd

149 def core(self): 1feabcd

150 return self._core 1feabcd

151 

def __new__(cls, core, *,
    scale=None,
    loc=None,
    derivable=None,
    maxdim=None,
    dim=None,
    forcekron=False,
    batchbytes=None,
    dynkw=None,
    **initkw,
    ):
    """
    Build a kernel from `core` and apply the optional standard operators.

    The parameters are described in the class docstring. `dynkw` defaults to
    an empty dict. `None` is used as the default instead of a literal ``{}``
    so the default is not a shared mutable object. Passing a mapping works
    as before, and it is copied.
    """
    self = super().__new__(cls)

    self._initkw = initkw
    self._dynkw = {} if dynkw is None else dict(dynkw)
    # initkw is frozen into the core; dynkw is supplied on each call
    self._core = lambda x, y, **dynkw: core(x, y, **initkw, **dynkw)

    if forcekron:
        self = self.transf('forcekron')

    # apply the operators in this fixed order
    linop_args = {
        'scale': scale,
        'loc': loc,
        'derivable': derivable,
        'maxdim': maxdim,
        'dim': dim,
    }
    for transfname, arg in linop_args.items():
        if callable(arg):
            # a callable computes the actual argument from initkw
            arg = arg(**initkw)
        if isinstance(arg, tuple):
            # a tuple holds separate arguments for the two sides
            self = self.linop(transfname, *arg)
        else:
            self = self.linop(transfname, arg)

    if batchbytes is not None:
        self = self.batch(batchbytes)

    return self

191 

def __call__(self, x, y):
    """
    Evaluate the kernel on the points `x` and `y`.

    Both inputs are converted with ``_array.asarray``. Assertions check that
    the result is a numerical array whose shape is the broadcast shape of
    the inputs.
    """
    x = _array.asarray(x)
    y = _array.asarray(y)
    expected = _array.broadcast(x, y).shape
    out = self.core(x, y, **self.dynkw)

    # sanity checks on what the core produced
    assert isinstance(out, (numpy.ndarray, jnp.number, jnp.ndarray))
    assert jnp.issubdtype(out.dtype, jnp.number), out.dtype
    assert out.shape == expected, (out.shape, expected)

    return out

201 

202 def _clone(self, cls=None, *, initkw=None, dynkw=None, core=None): 1feabcd

203 newself = object.__new__(self.__class__ if cls is None else cls) 1feabcd

204 newself._initkw = self._initkw if initkw is None else dict(initkw) 1feabcd

205 newself._dynkw = {} if dynkw is None else dict(dynkw) 1feabcd

206 newself._core = self._core if core is None else core 1feabcd

207 return newself 1feabcd

208 

209 class _side(enum.Enum): 1feabcd

210 LEFT = 0 1feabcd

211 RIGHT = 1 1feabcd

212 

@classmethod
def _nary(cls, op, kernels, side):
    """
    Combine kernel cores through an operator acting on one argument.

    Parameters
    ----------
    op : callable
        Called as ``op(*funcs)``, where each of ``funcs`` is one of the
        cores with the argument on the other side already fixed. It must
        return a one-argument function.
    kernels : sequence of CrossKernel
        The kernels whose cores are combined.
    side : CrossKernel._side
        The argument `op` acts on: `LEFT` is x, `RIGHT` is y.

    Returns
    -------
    kernel : CrossKernel
        A new `CrossKernel` built on the combined core.
    """
    if side is cls._side.LEFT:
        def fix_other(c, _, y, **kw):
            return lambda x: c(x, y, **kw)
        def pick(x, _):
            return x
    elif side is cls._side.RIGHT:
        def fix_other(c, x, _, **kw):
            return lambda y: c(x, y, **kw)
        def pick(_, y):
            return y
    else: # pragma: no cover
        raise KeyError(side)

    cores = [k.core for k in kernels]

    def core(x, y, **kw):
        partials = [fix_other(c, x, y, **kw) for c in cores]
        return op(*partials)(pick(x, y))

    return __class__(core)

232 

233 # TODO instead of wrapping the cores just by closing over an argument, 

234 # define a `PartialKernel` object that has linop defined, with one arg 

235 # fixed to None, and __call__ that closes over an argument. This way I 

236 # could apply ops directly with deflintransf. 

237 

238 # TODO move into the binary op methods the logic that decides whether to

239 # return NotImplemented, while the algops will raise an exception if the

240 # argument is not to their liking

241 

242 def __add__(self, other): 1feabcd

243 return self.algop('add', other) 1feabcd

244 

245 __radd__ = __add__ 1feabcd

246 

247 def __mul__(self, other): 1feabcd

248 return self.algop('mul', other) 1feabcd

249 

250 __rmul__ = __mul__ 1feabcd

251 

252 def __pow__(self, other): 1feabcd

253 return self.algop('pow', exponent=other) 1abcd

254 

255 def __rpow__(self, other): 1feabcd

256 return self.algop('rpow', base=other) 1abcd

257 

258 def _swap(self): 1feabcd

259 """ permute the arguments """ 

260 core = self.core 1feabcd

261 return self._clone( 1feabcd

262 __class__, 

263 core=lambda x, y, **kw: core(y, x, **kw), 

264 ) 

265 

266 # TODO make _swap a transf inherited by CrossIsotropicKernel => messes 

267 # up with Kernel subclasses. I want to make Kernel subclasses identity 

268 # on swap, so it can't be an inherited transf, because Kernel appears 

269 # as the second base. 

270 

def batch(self, maxnbytes):
    """
    Return a batched version of the kernel.

    The batched kernel processes its inputs in chunks to try to limit memory
    usage.

    Parameters
    ----------
    maxnbytes : number
        The maximum number of input bytes per chunk, counted after
        broadcasting the input shapes. Actual broadcasting may not occur if
        not induced by the operations in the kernel.

    Returns
    -------
    batched_kernel : CrossKernel
        The same kernel but with batched computations.
    """
    batched = _jaxext.batchufunc(self.core, maxnbytes=maxnbytes)
    return self._clone(core=batched)

292 

@classmethod
def _crossmro(cls):
    """
    Iterate over the MRO of `cls` up to `CrossKernel` itself, skipping
    classes that are subclasses of `Kernel`.
    """
    for base in cls.mro(): # pragma: no branch
        if not issubclass(base, Kernel):
            yield base
        if base is __class__:
            break

301 

302 _transf = {} 1feabcd

303 

304 _Transf = collections.namedtuple('_Transf', ['func', 'doc', 'kind']) 1feabcd

305 

306 def __init_subclass__(cls, **kw): 1feabcd

307 super().__init_subclass__(**kw) 1feabcd

308 cls._transf = {} 1feabcd

309 cls.__slots__ = () 1feabcd

310 

311 @classmethod 1feabcd

312 def _transfmro(cls): 1feabcd

313 """ Iterator of superclasses with a _transf attribute """ 

314 for c in cls.mro(): # pragma: no branch 1feabcd

315 yield c 1feabcd

316 if c is __class__: 1feabcd

317 break 1feabcd

318 

319 @classmethod 1feabcd

320 def _settransf(cls, transfname, transf): 1feabcd

321 if transfname in cls._transf: 1feabcd

322 raise KeyError(f'transformation {transfname!r} already registered ' 1abcd

323 f'for {cls.__name__}') 

324 cls._transf[transfname] = cls._Transf(*transf) 1feabcd

325 

326 @classmethod 1feabcd

327 def _alltransf(cls): 1feabcd

328 """ list all accessible transfs as dict name -> (tcls, transf) """ 

329 transfs = {} 1feabcd

330 for tcls in cls._transfmro(): 1feabcd

331 for name, transf in tcls._transf.items(): 1feabcd

332 transfs.setdefault(name, (tcls, transf)) 1feabcd

333 return transfs 1feabcd

334 

335 @classmethod 1feabcd

336 def _gettransf(cls, transfname, transfmro=None): 1feabcd

337 """ 

338 Find a transformation. 

339 

340 The transformation is searched following the MRO up to `CrossKernel`. 

341 

342 Parameters 

343 ---------- 

344 transfname : hashable 

345 The transformation name. 

346 

347 Returns 

348 ------- 

349 cls : type 

350 The class where the transformation was found. 

351 transf, doc, kind : tuple 

352 The objects set by `register_transf`. 

353 

354 Raises 

355 ------ 

356 KeyError : 

357 The transformation was not found. 

358 """ 

359 if transfmro is None: 1feabcd

360 transfmro = cls._transfmro() 1feabcd

361 for c in transfmro: 1feabcd

362 try: 1feabcd

363 return c, c._transf[transfname] 1feabcd

364 except KeyError: 1feabcd

365 pass 1feabcd

366 raise KeyError(transfname) 1abcd

367 

368 @classmethod 1feabcd

369 def inherit_transf(cls, transfname, *, intermediates=False): 1feabcd

370 """ 

371  

372 Inherit a transformation from a superclass. 

373 

374 Parameters 

375 ---------- 

376 transfname : hashable 

377 The name of the transformation. 

378 intermediates : bool, default False 

379 If True, make all superclasses up to the one definining the 

380 transformation inherit it too. 

381 

382 Raises 

383 ------ 

384 KeyError : 

385 The transformation was not found in any superclass, or the 

386 transformation is already registered on any of the target classes. 

387 

388 See also 

389 -------- 

390 transf 

391 

392 """ 

393 tcls, transf = cls._gettransf(transfname) 1feabcd

394 cls._settransf(transfname, transf) 1feabcd

395 if intermediates: 1feabcd

396 for c in cls.mro()[1:]: # pragma: no branch 1feabcd

397 if c is tcls: 1feabcd

398 break 1feabcd

399 c._settransf(transfname, transf) 1feabcd

400 

401 @classmethod 1feabcd

402 def inherit_all_algops(cls, intermediates=False): 1feabcd

403 """ 

404 

405 Inherit all algebraic operations from superclasses. 

406 

407 This makes sense if the class represents a subalgebra, i.e., it 

408 should be preserved by addition and multiplication. 

409 

410 Parameters 

411 ---------- 

412 intermediates : bool, default False 

413 If True, make all superclasses up to the one definining the 

414 transformation inherit it too. 

415 

416 Raises 

417 ------ 

418 KeyError : 

419 An algebraic operation is already registered for one of the target 

420 classes. 

421 

422 See also 

423 -------- 

424 transf 

425 

426 """ 

427 mro = cls._transfmro() 1feabcd

428 next(mro) 1feabcd

429 for name, (_, transf) in next(mro)._alltransf().items(): 1feabcd

430 if transf.kind is cls._algopmarker: 1feabcd

431 cls.inherit_transf(name, intermediates=intermediates) 1feabcd

432 

433 @classmethod 1feabcd

434 def list_transf(cls, superclasses=True): 1feabcd

435 """ 

436 List all the available transformations. 

437 

438 Parameters 

439 ---------- 

440 superclasses : bool, default True 

441 Include transformations defined in superclasses. 

442 

443 Returns 

444 ------- 

445 transfs: dict of Transf 

446 The dictionary keys are the transformation names, the values are 

447 named tuples ``(tcls, kind, impl, doc)`` where ``tcls`` is the class 

448 defining the transformation, ``kind`` is the kind of transformation, 

449 ``impl`` is the implementation with signature ``impl(tcls, self, 

450 *args, **kw)``, ``doc`` is the docstring. 

451 """ 

452 if superclasses: 1abcd

453 source = cls._alltransf().items 1abcd

454 else: 

455 def source(): 1abcd

456 for name, transf in cls._transf.items(): 1abcd

457 yield name, (cls, transf) 1abcd

458 return { 1abcd

459 name: cls.Transf(tcls, transf.kind, transf.func, transf.doc) 

460 for name, (tcls, transf) in source() 

461 } 

462 

463 Transf = collections.namedtuple('Transf', ['tcls', 'kind', 'func', 'doc']) 1feabcd

464 

465 @classmethod 1feabcd

466 def has_transf(cls, transfname): 1feabcd

467 """ 

468 Check if a transformation is registered. 

469 

470 Parameters 

471 ---------- 

472 transfname : hashable 

473 The transformation name. 

474 

475 Returns 

476 ------- 

477 has_transf : bool 

478 Whether the transformation is registered. 

479 

480 See also 

481 -------- 

482 transf 

483 """ 

484 try: 1abcd

485 cls._gettransf(transfname) 1abcd

486 except KeyError as exc: 1abcd

487 if exc.args == (transfname,): 1abcd

488 return False 1abcd

489 else: # pragma: no cover 

490 raise 

491 else: 

492 return True 1abcd

493 

494 @classmethod 1feabcd

495 def transf_help(cls, transfname): 1feabcd

496 """ 

497  

498 Return the documentation of a transformation. 

499 

500 Parameters 

501 ---------- 

502 transfname : hashable 

503 The name of the transformation. 

504 

505 Returns 

506 ------- 

507 doc : str 

508 The documentation of the transformation. 

509 

510 See also 

511 -------- 

512 transf 

513 

514 """ 

515 _, transf = cls._gettransf(transfname) 1abcd

516 return transf.doc 1abcd

517 

518 def transf(self, transfname, *args, **kw): 1feabcd

519 """ 

520 

521 Return a transformed kernel. 

522 

523 Parameters 

524 ---------- 

525 transfname : hashable 

526 A name identifying the transformation. 

527 *args, **kw : 

528 Arguments to the transformation. 

529 

530 Returns 

531 ------- 

532 newkernel : object 

533 The output of the transformation. 

534 

535 Raises 

536 ------ 

537 KeyError 

538 The transformation is not defined in this class or any superclass. 

539 

540 See also 

541 -------- 

542 linop, algop, transf_help, has_transf, list_transf, super_transf, register_transf, register_linop, register_corelinop, register_xtransf, register_algop, register_ufuncalgop 

543 

544 """ 

545 tcls, transf = self._gettransf(transfname) 1abcd

546 return transf.func(tcls, self, *args, **kw) 1abcd

547 

548 @classmethod 1feabcd

549 def super_transf(cls, transfname, self, *args, **kw): 1feabcd

550 """ 

551  

552 Transform the kernel using a superclass transformation. 

553 

554 This is equivalent to `transf` but is invoked on a class and the 

555 object is passed explicitly. The definition of the transformation is 

556 searched starting from the next class in the MRO of `self`. 

557 

558 Parameters 

559 ---------- 

560 transfname, *args, **kw : 

561 See `transf`. 

562 self : CrossKernel 

563 The object to transform. 

564 

565 Returns 

566 ------- 

567 newkernel : object 

568 The output of the transformation. 

569 

570 """ 

571 mro = list(self._transfmro()) 1eabcd

572 idx = mro.index(cls) 1eabcd

573 tcls, transf = self._gettransf(transfname, mro[idx + 1:]) 1eabcd

574 return transf.func(tcls, self, *args, **kw) 1eabcd

575 

576 def linop(self, transfname, *args, **kw): 1feabcd

577 r""" 

578 

579 Transform kernels to represent the application of a linear operator. 

580 

581 .. math:: 

582 \text{kernel}_1(x, y) &= \mathrm{Cov}[f_1(x), g_1(y)], \\ 

583 \text{kernel}_2(x, y) &= \mathrm{Cov}[f_2(x), g_2(y)], \\ 

584 &\ldots \\ 

585 \text{newkernel}(x, y) &= 

586 \mathrm{Cov}[T_f(f_1, f_2, \ldots)(x), T_g(g_1, g_2, \ldots)(y)] 

587 

588 Parameters 

589 ---------- 

590 transfname : hashable 

591 The name of the transformation. 

592 *args : 

593 A sequence of `CrossKernel` instances, representing the operands, 

594 followed by one or two non-kernel arguments, indicating how to act 

595 on each side of the kernels. If both arguments represent the 

596 identity, this is a no-op. If there is only one argument, it is 

597 intended that the two arguments are equal. `None` always represents 

598 the identity. 

599 

600 Returns 

601 ------- 

602 newkernel : CrossKernel 

603 The transformed kernel. 

604 

605 Raises 

606 ------ 

607 ValueError : 

608 The transformation exists but was not defined by `register_linop`. 

609 

610 See also 

611 -------- 

612 transf 

613 

614 Notes 

615 ----- 

616 The linear operator is defined on the function the kernel represents the 

617 distribution of, not in full generality on the kernel itself. When 

618 multiple kernels are involved, their distributions may be considered 

619 independent or not, depending on the specific operation. 

620 

621 If the result is a subclass of the class defining the transformation, 

622 the result is casted to the latter. Then, if the result and all the 

623 operands are instances of `Kernel`, but the two operator arguments 

624 differ, the result is casted to its first non-`Kernel` superclass. 

625  

626 """ 

627 tcls, transf = self._gettransf(transfname) 1feabcd

628 if transf.kind is not self._linopmarker: 1feabcd

629 raise ValueError(f'the transformation {transfname!r} was not ' 1abcd

630 f'defined with register_linop and so can not be invoked ' 

631 f'by linop') 

632 return transf.func(tcls, self, *args) 1feabcd

633 

634 def algop(self, transfname, *operands, **kw): 1feabcd

635 r""" 

636 

637 Return a nonnegative algebraic transformation of the input kernels. 

638 

639 .. math:: 

640 \mathrm{newkernel}(x, y) &= 

641 f(\mathrm{kernel}_1(x, y), \mathrm{kernel}_2(x, y), \ldots), \\ 

642 f(z_1, z_2, \ldots) &= \sum_{k_1,k_2,\ldots=0}^\infty 

643 a_{k_1 k_2 \ldots} z_1^{k_1} z_2^{k_2} \ldots, 

644 \quad a_* \ge 0. 

645 

646 Parameters 

647 ---------- 

648 transfname : hashable 

649 A name identifying the transformation. 

650 *operands : CrossKernel, scalar 

651 Arguments to the transformation in addition to self. 

652 **kw : 

653 Additional arguments to the transformation, not considered as 

654 operands. 

655 

656 Returns 

657 ------- 

658 newkernel : CrossKernel or NotImplemented 

659 The transformed kernel, or NotImplemented if the operation is 

660 not supported. 

661 

662 See also 

663 -------- 

664 transf 

665 

666 Notes 

667 ----- 

668 The class of `newkernel` is the least common superclass of: the 

669 "natural" output of the operation, the class defining the 

670 transformation, and the classes of the operands. 

671 

672 For class determination, scalars in the input count as `Constant` 

673 if nonnegative or traced by jax, else `CrossConstant`. 

674 

675 """ 

676 tcls, transf = self._gettransf(transfname) 1feabcd

677 if transf.kind is not self._algopmarker: 1feabcd

678 raise ValueError(f'the transformation {transfname!r} was not ' 1abcd

679 f'defined with register_algop and so can not be invoked ' 

680 f'by algop') 

681 return transf.func(tcls, self, *operands, **kw) 1feabcd

682 

683 @classmethod 1feabcd

684 def register_transf(cls, func, transfname=None, doc=None, kind=None): 1feabcd

685 """ 

686  

687 Register a transformation for use with `transf`. 

688 

689 The transformation will be accessible to subclasses. 

690 

691 Parameters 

692 ---------- 

693 func : callable 

694 A function ``func(tcls, self, *args, **kw) -> object`` implementing 

695 the transformation, where ``tcls`` is the class that defines the 

696 transformation. 

697 transfname : hashable, optional. 

698 The `transfname` parameter to `transf` this transformation will 

699 be accessible under. If not specified, use the name of `func`. 

700 doc : str, optional 

701 The documentation of the transformation for `transf_help`. If not 

702 specified, use the docstring of `func`. 

703 kind : object, optional 

704 An arbitrary marker. 

705 

706 Returns 

707 ------- 

708 func : callable 

709 The `func` argument as is. 

710 

711 Raises 

712 ------ 

713 KeyError : 

714 The name is already in use for another transformation in the same 

715 class. 

716 

717 See also 

718 -------- 

719 transf 

720 

721 """ 

722 if transfname is None: 1feabcd

723 transfname = func.__name__ 1feabcd

724 if doc is None: 1feabcd

725 doc = func.__doc__ 1feabcd

726 cls._settransf(transfname, (func, doc, kind)) 1feabcd

727 return func 1feabcd

728 

729 # TODO forbid to override with a different kind? 

730 

731 @classmethod 1feabcd

732 def register_linop(cls, op, transfname=None, doc=None, argparser=None): 1feabcd

733 """ 

734  

735 Register a transformation for use with `linop`. 

736 

737 Parameters 

738 ---------- 

739 op : callable 

740 A function ``op(tcls, self, arg1, arg2, *operands) -> CrossKernel`` 

741 that returns the new kernel, where ``arg1`` and ``arg2`` represent 

742 the operators acting on each side of the kernels, and ``operands`` 

743 are the other kernels beyond ``self``. 

744 transfname, doc : optional 

745 See `register_transf`. 

746 argparser : callable, optional 

747 A function applied to ``arg1`` and ``arg2``. Not called if the 

748 argument is `None`. It should map the identity to `None`. 

749 

750 Returns 

751 ------- 

752 op : callable 

753 The `op` argument as is. 

754 

755 Notes 

756 ----- 

757 The function `op` is called only if ``arg1`` or ``arg2`` is not `None` 

758 after potential conversion with `argparser`. 

759 

760 See also 

761 -------- 

762 transf 

763 

764 """ 

765 

766 if transfname is None: 1feabcd

767 transfname = op.__name__ # for result type error message 1feabcd

768 

769 @functools.wraps(op) 1feabcd

770 def func(tcls, self, *allargs): 1feabcd

771 

772 # split the arguments in kernels and non-kernels 

773 for pos, arg in enumerate(allargs): 1feabcd

774 if not isinstance(arg, __class__): 1feabcd

775 break 1feabcd

776 else: 

777 pos = len(allargs) 1abcd

778 operands = allargs[:pos] 1feabcd

779 args = allargs[pos:] 1feabcd

780 

781 # check the arguments from the first non-kernel onwards are 1 or 2 

782 if len(args) not in (1, 2): 1feabcd

783 raise ValueError(f'incorrect number of non-kernel tail ' 1abcd

784 f'arguments {len(args)}, expected 1 or 2') 

785 

786 # wrap argument parser to enforce preserving None 

787 if argparser: 1feabcd

788 conv = lambda x: None if x is None else argparser(x) 1feabcd

789 else: 

790 conv = lambda x: x 1eabcd

791 

792 # determine if the two arguments count as "identical" or not 

793 if len(args) == 1: 1feabcd

794 arg = conv(*args) 1feabcd

795 different = False 1feabcd

796 arg1 = arg2 = arg 1feabcd

797 else: 

798 arg1, arg2 = args 1feabcd

799 different = arg1 is not arg2 1feabcd

800 arg1 = conv(arg1) 1feabcd

801 arg2 = conv(arg2) 1feabcd

802 different &= arg1 is not arg2 1feabcd

803 # they must be not identical both before and after to handle 

804 # these cases: 

805 # - if the user passes identical arguments, but argparser 

806 # makes copies, it must still count as identical 

807 # - if the user passes arguments which are not identical, 

808 # but argparser sends them to the same object, then they 

809 # surely represent the same transf 

810 

811 # handle no-op case 

812 if arg1 is None and arg2 is None: 1feabcd

813 return self 1feabcd

814 

815 # invoke implementation 

816 result = op(tcls, self, arg1, arg2, *operands) 1feabcd

817 

818 # check result is a kernel 

819 if not isinstance(result, __class__): 1feabcd

820 raise TypeError(f'linop {transfname!r} returned ' 1abcd

821 f'object of type {result.__class__.__name__}, expected ' 

822 f'subclass of {__class__.__name__}') 

823 

824 # modify class of the result 

825 rcls = result.__class__ 1feabcd

826 if issubclass(rcls, tcls): 1feabcd

827 rcls = tcls 1feabcd

828 all_operands_kernel = all(isinstance(o, Kernel) for o in operands) 1feabcd

829 if isinstance(self, Kernel) and all_operands_kernel and different: 1feabcd

830 rcls = next(rcls._crossmro()) 1feabcd

831 if rcls is not result.__class__: 1feabcd

832 result = result._clone(rcls) 1feabcd

833 

834 return result 1feabcd

835 

836 cls.register_transf(func, transfname, doc, cls._linopmarker) 1feabcd

837 return op 1feabcd

838 

839 class _LinOpMarker(str): pass 1feabcd

840 _linopmarker = _LinOpMarker('linop') 1feabcd

841 

842 @classmethod 1feabcd

843 def register_corelinop(cls, corefunc, transfname=None, doc=None, argparser=None): 1feabcd

844 """ 

845 

846 Register a linear operator with a function that acts only on the core. 

847 

848 Parameters 

849 ---------- 

850 corefunc : callable 

851 A function ``corefunc(core, arg1, arg2, *cores) -> newcore``, where 

852 ``core`` is the function that implements the kernel passed at 

853 initialization, and ``cores`` for other operands. 

854 transfname, doc, argparser : 

855 See `register_linop`. 

856 

857 Returns 

858 ------- 

859 corefunc : callable 

860 The `corefunc` argument as is. 

861 

862 See also 

863 -------- 

864 transf 

865 

866 """ 

867 @functools.wraps(corefunc) 1feabcd

868 def op(_, self, arg1, arg2, *operands): 1feabcd

869 cores = (o.core for o in operands) 1feabcd

870 core = corefunc(self.core, arg1, arg2, *cores) 1feabcd

871 return self._clone(core=core) 1feabcd

872 cls.register_linop(op, transfname, doc, argparser) 1feabcd

873 return corefunc 1feabcd

874 

875 @classmethod 1feabcd

876 def register_xtransf(cls, xfunc, transfname=None, doc=None): 1feabcd

877 """ 

878 

879 Register a linear operator that acts only on the input. 

880 

881 Parameters 

882 ---------- 

883 xfunc : callable 

884 A function ``xfunc(arg) -> (lambda x: newx)`` that takes in a 

885 `linop` argument and produces a function to transform the input. Not 

886 called if ``arg`` is `None`. To indicate the identity, return 

887 `None`. 

888 transfname, doc : 

889 See `register_linop`. `argparser` is not provided because its 

890 functionality can be included in `xfunc`. 

891 

892 Returns 

893 ------- 

894 xfunc : callable 

895 The `xfunc` argument as is. 

896 

897 See also 

898 -------- 

899 transf 

900 

901 """ 

902 

903 @functools.wraps(xfunc) 1feabcd

904 def corefunc(core, xfun, yfun): 1feabcd

905 if not xfun: 1feabcd

906 return lambda x, y, **kw: core(x, yfun(y), **kw) 1abcd

907 elif not yfun: 1feabcd

908 return lambda x, y, **kw: core(xfun(x), y, **kw) 1abcd

909 else: 

910 return lambda x, y, **kw: core(xfun(x), yfun(y), **kw) 1feabcd

911 

912 cls.register_corelinop(corefunc, transfname, doc, xfunc) 1feabcd

913 return xfunc 1feabcd

914 

915 @classmethod 1feabcd

916 def register_algop(cls, op, transfname=None, doc=None): 1feabcd

917 """ 

918 

919 Register a transformation for use with `algop`. 

920 

921 Parameters 

922 ---------- 

923 op : callable 

924 A function ``op(tcls, *kernels, **kw) -> CrossKernel | 

925 NotImplemented`` that returns the new kernel. ``kernels`` may be 

926 scalars but for the first argument. 

927 transfname, doc : 

928 See `register_transf`. 

929 

930 Returns 

931 ------- 

932 op : callable 

933 The `op` argument as is. 

934 

935 See also 

936 -------- 

937 transf 

938 

939 """ 

940 

941 if transfname is None: 1feabcd

942 transfname = op.__name__ # for error message 1feabcd

943 

944 @functools.wraps(op) 1feabcd

945 def func(tcls, *operands, **kw): 1feabcd

946 result = op(tcls, *operands, **kw) 1feabcd

947 

948 if result is NotImplemented: 1feabcd

949 return result 1abcd

950 elif not isinstance(result, __class__): 1feabcd

951 raise TypeError(f'algop {transfname!r} returned ' 1abcd

952 f'object of type {result.__class__.__name__}, expected ' 

953 f'subclass of {__class__.__name__}') 

954 

955 def classes(): 1feabcd

956 yield tcls 1feabcd

957 for o in operands: 1feabcd

958 if isinstance(o, __class__): 1feabcd

959 yield o.__class__ 1feabcd

960 elif _util.is_nonnegative_scalar_trueontracer(o): 1feabcd

961 yield Constant 1feabcd

962 elif _util.is_numerical_scalar(o): 1abcd

963 yield CrossConstant 1abcd

964 else: 

965 raise TypeError(f'operands to algop {transfname!r} ' 1abcd

966 f'must be CrossKernel or numbers, found {o!r}') 

967 # this type check comes after letting the implementation 

968 # return NotImplemented, to support overloading 

969 yield result.__class__ 1feabcd

970 

971 lcs = least_common_superclass(*classes()) 1feabcd

972 return result._clone(lcs) 1feabcd

973 

974 cls.register_transf(func, transfname, doc, cls._algopmarker) 1feabcd

975 return op 1feabcd

976 

977 # TODO delete initkw (also in linop) if there's more than one kernel 

978 # operand or if the class changed? 

979 

980 # TODO consider adding an option domains=callable, returns list of 

981 # domains from the operands, or list of tuples right away, and the 

982 # impl checks at runtime (if not traced) that the output values are in 

983 # the domains, with an informative error message 

984 

985 # TODO consider escalating to ancestor definitions of the transf 

986 # until an impl does not return NotImplemented, and then raise 

987 # an error instead of letting the NotImplemented escape. This would be 

988 # useful for partial behavior override in subclasses, e.g., to handle 

989 # multiplication by scalar to rewire transformations. 

990 

class _AlgOpMarker(str):
    """ String marker handed to `register_transf` by `register_algop`. """
    # NOTE(review): presumably a dedicated str subclass so that algop
    # registrations can be told apart by type; confirm in `register_transf`.

_algopmarker = _AlgOpMarker('algop')

993 

@classmethod
def register_ufuncalgop(cls, ufunc, transfname=None, doc=None):
    """

    Register an algebraic operation with a function that acts only on the
    kernel value.

    The resulting kernel is a clone of the first operand whose core applies
    `ufunc` to the values of all the operands. Operands that are not
    kernels are treated as constant kernels returning the operand itself.

    Parameters
    ----------
    ufunc : callable
        A function ``ufunc(*values, **kw) -> value``, where ``values`` are
        the values yielded by the operands.
    transfname, doc :
        See `register_transf`.

    Returns
    -------
    ufunc : callable
        The `ufunc` argument as is.

    See also
    --------
    transf

    """

    def constant_core(value):
        # Bind `value` eagerly. A lambda written directly inside the
        # generator expression below would close over the loop variable and
        # every constant core would return the last operand instead.
        return lambda x, y: value

    @functools.wraps(ufunc)
    def op(_, self, *operands, **kw):
        # NOTE(review): the keyword arguments given to the transformation
        # (`kw` here) are not forwarded; `ufunc` receives the keyword
        # arguments the new core is called with. Confirm this is intended.
        cores = tuple(
            o.core if isinstance(o, __class__) else constant_core(o)
            for o in (self, *operands)
        )

        def core(x, y, **kw):
            values = (c(x, y, **kw) for c in cores)
            return ufunc(*values, **kw)

        return self._clone(core=core)

    cls.register_algop(op, transfname, doc)
    return ufunc

1032 

@classmethod
def make_linop_family(cls, transfname, bothker, leftker, rightker=None, *,
    doc=None, argparser=None, argnames=None, translkw=None):
    """

    Form a family of kernel classes related by linear operators.

    The class this method is called on is the seed class. A new
    transformation is registered to obtain other classes from the seed class
    by applying a newly defined linop transformation.

    Parameters
    ----------
    transfname : str
        The name of the new transformation.
    bothker, leftker, rightker : CrossKernel
        The kernel classes to be obtained by applying the operator to a seed
        class object respectively on both sides, only left, or only right.
        All classes are assumed to require no positional arguments at
        construction, and recognize the same set of keyword arguments. If
        `rightker` is not specified, it is defined by subclassing `leftker`
        and transposing the kernel on object construction.
    doc, argparser : optional
        See `register_linop`.
    argnames : pair of str, optional
        If specified, `leftker` is passed an additional keyword argument
        with name ``argnames[0]`` that specifies the argument of the
        operator, `rightker` is passed ``argnames[1]``, and `bothker`
        similarly is passed both.

        If `leftker` is used to implement `rightker`, it must accept both
        arguments, and may use them to determine if it is being invoked as
        `leftker` or `rightker`.
    translkw : callable, optional
        A function with signature ``translkw(*, dynkw, <argnames>,
        **initkw) -> dict`` that determines the constructor arguments for a
        new object starting from the ones of the object to be transformed.
        It is called with keyword arguments only; an ``argnames`` entry is
        present only for an operator argument that is not None. By default,
        ``initkw`` (together with any ``argnames`` entries) is passed over,
        and an error is raised if ``dynkw`` is not empty.

    Warns
    -----
    UserWarning
        If the seed, `bothker`, `leftker` and `rightker` classes do not
        follow the pattern Kernel, Kernel, non-Kernel, non-Kernel.

    Examples
    --------

    >>> @lgp.kernel
    ... def A(x, y, *, gatto):
    ...     ''' The renowned A kernel of order gatto '''
    ...     return ...
    ...
    >>> @lgp.kernel
    ... def T(n, m, *, gatto, op1, op2):
    ...     ''' The kernel of the Topo series of an A process of order
    ...     gatto '''
    ...     return ...
    ...
    >>> @lgp.crosskernel
    ... def CrossTA(n, x, *, gatto, op1=None, op2=None):
    ...     ''' The cross covariance between the Topo series of an A
    ...     process of order gatto and the process itself '''
    ...     return ...
    ...
    >>> A.make_linop_family('topo', T, CrossTA, argnames=('op1', 'op2'))
    >>> a = A(gatto=7)
    >>> ta = a.linop('topo', True, None)
    >>> at = a.linop('topo', None, True)
    >>> t = a.linop('topo', True)

    See also
    --------
    transf

    """

    if rightker is None:

        # invent a name for rightker
        rightname = f'Cross{cls.__name__}{bothker.__name__}'

        # define how to set up rightker
        def exec_body(ns):

            if leftker.__doc__:
                header = 'Automatically generated transposed version of:\n\n'
                ns['__doc__'] = _utils.append_to_docstring(leftker.__doc__, header, front=True)

            def __new__(cls, *args, **kw):
                # build the object as leftker would, then transpose it
                self = super(rightker, cls).__new__(cls, *args, **kw)

                if self.__class__ is cls:
                    self = self._swap()
                    if not isinstance(self, leftker):
                        raise TypeError(f'newly created instance of '
                            f'automatically defined {rightker.__name__} is not an '
                            f'instance of {leftker.__name__} after '
                            f'transposition. Either define transposition '
                            f'for {leftker.__name__}, or define '
                            f'{rightker.__name__} manually')
                    # restore the class lost (or changed) by transposition
                    return self._clone(cls)

                else:
                    return self._swap()

            ns['__new__'] = __new__

        # create rightker, evil twin of leftker separated at birth
        rightker = types.new_class(rightname, (leftker,), exec_body=exec_body)

    # check which classes are symmetric
    classes = cls, bothker, leftker, rightker
    sym = tuple(issubclass(c, Kernel) for c in classes)
    exp = True, True, False, False
    if sym != exp:
        desc = lambda t: 'Kernel' if t else 'non-Kernel'
        warnings.warn(f'Expected classes pattern {", ".join(map(desc, exp))}, '
            f'found {", ".join(map(desc, sym))}')

    # set translkw if not specified
    if translkw is None:
        def translkw(*, dynkw, **initkw):
            if dynkw:
                raise ValueError('found non-empty `dynkw`, the default '
                    'implementation of `translkw` does not support it')
            return initkw

    # function to produce the arguments to the transformed objects
    def makekw(self, arg1, arg2):
        kw = dict(dynkw=self.dynkw, **self.initkw)
        if argnames is not None:
            # dict(**a, **b) raises TypeError on duplicate keys, which
            # rejects an operator argument already present in initkw
            if arg1 is not None:
                kw = dict(**kw, **{argnames[0]: arg1})
            if arg2 is not None:
                kw = dict(**kw, **{argnames[1]: arg2})
        return translkw(**kw)

    # register linop mapping cls to either leftker, rightker or bothker
    regkw = dict(transfname=transfname, doc=doc, argparser=argparser)
    @functools.partial(cls.register_linop, **regkw)
    def op_seed_to_siblings(_, self, arg1, arg2):
        kw = makekw(self, arg1, arg2)
        if arg2 is None:
            return leftker(**kw)
        elif arg1 is None:
            return rightker(**kw)
        else:
            return bothker(**kw)

    # register linop mapping leftker to bothker
    @functools.partial(leftker.register_linop, **regkw)
    def op_left_to_both(_, self, arg1, arg2):
        # leftker is already transformed on the left, so only the right
        # side may still be transformed
        if arg1 is None:
            return bothker(**makekw(self, arg1, arg2))
        else:
            raise ValueError(f'cannot further transform '
                f'`{leftker.__name__}` on left side with linop '
                f'{transfname!r}')

    # register linop mapping rightker to bothker
    @functools.partial(rightker.register_linop, **regkw)
    def op_right_to_both(_, self, arg1, arg2):
        # rightker is already transformed on the right, so only the left
        # side may still be transformed
        if arg2 is None:
            return bothker(**makekw(self, arg1, arg2))
        else:
            raise ValueError(f'cannot further transform '
                f'`{rightker.__name__}` on right side with linop '
                f'{transfname!r}')

1198 

class AffineSpan(CrossKernel, abc.ABC):
    """

    Kernel that tracks affine transformations.

    An `AffineSpan` instance accumulates the overall affine transformation
    applied to its inputs and output.

    `AffineSpan` and its subclasses are preserved by the transformations
    'scale', 'loc', 'add' (with scalar) and 'mul' (with scalar).

    `AffineSpan` cannot be instantiated directly or used as standalone
    superclass. It must be the first base before concrete superclasses.

    """

    # Default values of the dynamic keyword arguments that hold the
    # accumulated affine transformation. Values given by the caller in
    # `dynkw` take precedence (see `__new__`).
    _affine_dynkw = dict(lloc=0, rloc=0, lscale=1, rscale=1, offset=0, ampl=1)

    def __new__(cls, *args, dynkw={}, **kw):
        # `dynkw` is only read here (copied into `new_dynkw`), so the shared
        # mutable default is never modified
        if cls is __class__:
            raise TypeError(f'cannot instantiate {__class__.__name__} directly')
        new_dynkw = dict(cls._affine_dynkw)
        new_dynkw.update(dynkw)
        return super().__new__(cls, *args, dynkw=new_dynkw, **kw)

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)
        # every subclass gets a copy of the transformations registered on
        # AffineSpan itself
        for name in __class__._transf:
            cls.inherit_transf(name)

        # TODO I would like to inherit conditional on there not being another
        # definition. However, since registrations are done after class
        # creation, this method is invoked too early to do that.
        #
        # Is it possible to reimplement the transformation system to have the
        # transformations defined in the class body?
        #
        # Option 1: use a decorator that marks the transf methods and makes
        # them staticmethods. Then an __init_subclass__ above CrossKernel goes
        # through the methods and registers the marked ones.
        #
        # Option 2: try to use descriptors (classes like property). Is it
        # possible to reimplement from scratch something like classmethod? If
        # so, I could make a transfmethod that binds two arguments: tcls
        # for the class that defines it, and self for the object that invokes
        # it. Then inheriting would be just copying the thing over. If I wanted
        # to keep the invocation through method interface, I could define the
        # methods with underscores and register them somewhere.
        #
        # However, since I define kernels with decorators, I can't actually
        # count on the transformations being defined into them at class
        # creation. But that would be quite a corner case.
        #
        # To make subclassing convenient without decorators, allow to define a
        # core(x, y) method, which __init_subclass__ converts to staticmethod,
        # and have __new__ use it to set _core (I can't use the same name with
        # slots).
        #
        # to use without post-class registering make_linop_family, make it a
        # decorator with attributed subdecorators for other members:
        #
        # @linop_family('T')
        # @kernel
        # def A(...)
        # @A.left('arg')
        # @crosskernel
        # def CrossTA(...)
        # @A.right # optional, generated from left if not specified
        # ...
        # @A.both('arg1', 'arg2')
        # @kernel
        # def T(...)
        #
        # Can also be stacked on top of left/right/both to chain families. Works
        # both with decorators and explicit class definitions.

    def _clone(self, *args, **kw):
        newself = super()._clone(*args, **kw)
        if isinstance(newself, __class__):
            # carry the accumulated affine parameters over to the clone
            for name in self._affine_dynkw:
                newself._dynkw[name] = self._dynkw[name]
        return newself

    @classmethod
    def __subclasshook__(cls, sub):
        if cls is __class__:
            return NotImplemented
        # to avoid algops promoting to unqualified AffineSpan
        if issubclass(cls, Kernel):
            if issubclass(sub, Constant):
                return True
            else:
                return NotImplemented
        elif issubclass(sub, CrossConstant):
            return True
        else:
            return NotImplemented

1296 

1297 # TODO I could do separately AffineLeft, AffineRight, AffineOut, and make 

1298 # this a subclass of those three. AffineOut would also allow keeping the 

1299 # class when adding two objects without keyword arguments in initkw and 

1300 # dynkw beyond those managed by Affine. => Make only AffineSide and have it 

1301 # switch side on _swap. 

1302 

1303 # TODO when I reimplement transformations as methods, make AffineSpan not 

1304 # a subclass of CrossKernel. Right now it has to be one, to avoid routing around

1305 # the assumption that all classes in the MRO implement the transformation 

1306 # management logic. 

1307 

class PreservedBySwap(CrossKernel):
    """

    Base class for kernels whose class survives transposition.

    `_swap` re-clones the transposed kernel into the class of the object it
    was called on. The class is meant to be subclassed: instantiating it
    directly raises `TypeError`.

    """

    def __new__(cls, *args, **kw):
        if cls is not __class__:
            return super().__new__(cls, *args, **kw)
        raise TypeError(f'cannot instantiate {__class__.__name__} directly')

    def _swap(self):
        transposed = super()._swap()
        return transposed._clone(self.__class__)

    # TODO when I implement transformations with methods, make this not an
    # instance of CrossKernel.