Coverage for src/lsqfitgp/_Kernel/_ops.py: 99%

167 statements

coverage.py v7.6.3, created at 2024-10-15 19:54 +0000


# lsqfitgp/_Kernel/_ops.py
#
# Copyright (c) 2020, 2022, 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.


20""" register linops on CrossKernel and AffineSpan """ 

21 

import functools
import numbers
import sys

from jax import numpy as jnp
import numpy

from .. import _jaxext
from .. import _Deriv
from .. import _array

from . import _util
from ._crosskernel import CrossKernel, AffineSpan


def rescale_argparser(fun):
    if not callable(fun):
        raise ValueError("argument to 'rescale' must be a function")
    return fun


@functools.partial(CrossKernel.register_corelinop, argparser=rescale_argparser)
def rescale(core, xfun, yfun):
    r"""

    Rescale the output of the function.

    .. math::
        T(f)(x) = \mathrm{fun}(x) f(x)

    Parameters
    ----------
    xfun, yfun : callable or None
        Functions from the type of the arguments of the kernel to scalar.

    """
    if not xfun:
        return lambda x, y, **kw: core(x, y, **kw) * yfun(y)
    elif not yfun:
        return lambda x, y, **kw: xfun(x) * core(x, y, **kw)
    else:
        return lambda x, y, **kw: xfun(x) * core(x, y, **kw) * yfun(y)

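# A minimal usage sketch, assuming the `.linop(name, xarg, yarg)` calling
# convention that appears in the TODO notes further down; `ExpQuad` is an
# illustrative kernel, not something defined in this module:
#
#     base = lsqfitgp.ExpQuad()
#     k = base.linop('rescale', lambda x: x, None)
#     # k(x, y) == x * base(x, y): only the left argument is rescaled
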

@CrossKernel.register_xtransf
def derivable(derivable):
    """
    Specify the degree of derivability of the function.

    Parameters
    ----------
    xderivable, yderivable : int or None
        Degree of derivability of the function. None means unknown.

    Notes
    -----
    The derivability check is hardcoded into the kernel core, and it is not
    possible to remove it afterwards by applying ``'derivable'`` again with a
    higher limit.

    """
    if isinstance(derivable, bool):
        derivable = sys.maxsize if derivable else 0
    elif not isinstance(derivable, numbers.Integral) or derivable < 0:
        raise ValueError(f'derivability degree {derivable!r} not valid')

    def error_func(current, n):
        raise ValueError(f'Took {current} derivatives > limit {n} on argument '
            'of a kernel. This error may be spurious if there are '
            'derivatives on values that define the input to the kernel, for '
            'example if a hyperparameter enters the calculation of x. To '
            'suppress the error, initialize the kernel with derivable=True.')

    def xtransf(x):
        if hasattr(x, 'dtype'):
            # this branch handles limit_derivatives not accepting non-jax types
            # because of being based on jax.custom_jvp; this restriction on
            # custom_jvp appeared in jax 0.4.17
            if x.dtype.names is not None:
                x = _array.StructuredArray(x) # structured arrays are not
                # compatible with jax but common in lsqfitgp, so I wrap them
            elif not _jaxext.is_jax_type(x.dtype):
                return x # since anyway there would be an error if a derivative
                # tried to pass through a non-jax type
        return _jaxext.limit_derivatives(x, n=derivable, error_func=error_func)

    return xtransf

    # TODO this system does not ignore additional derivatives that are not
    # taken by .transf('diff'). Plan:
    # - understand how to associate a jax transformation to a frame
    # - make a context manager with a global stack
    # - at initialization it takes the frames of the derivatives made by diff
    # - the context is around calling core in diff.newcore
    # - it adds the frames to the stack
    # - derivable asks the context manager class to find the frames in x
    # - raises if their number is above the limit

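# A minimal usage sketch (assumed calling convention; the kernel class is
# illustrative):
#
#     k = lsqfitgp.ExpQuad().linop('derivable', 2, 2)
#     # taking more than two derivatives per argument through 'diff' now
#     # raises at evaluation time
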

def _asfloat(x):
    return x.astype(_jaxext.float_type(x))


def diff_argparser(deriv):
    deriv = _Deriv.Deriv(deriv)
    return deriv if deriv else None


@functools.partial(CrossKernel.register_corelinop, argparser=diff_argparser)
def diff(core, xderiv, yderiv):
    r"""

    Differentiate the function.

    .. math::
        T(f)(x) = \frac{\partial^n f}{\partial x^n} (x)

    Parameters
    ----------
    xderiv, yderiv : Deriv_like
        A `Deriv` or something that can be converted to a `Deriv`.

    Raises
    ------
    RuntimeError
        The derivative orders are greater than the `derivable` attribute.

    """

    # reparse derivatives because they could be None
    xderiv = _Deriv.Deriv(xderiv)
    yderiv = _Deriv.Deriv(yderiv)

    # wrapper of the kernel with the derivable arguments unpacked
    def f(x, y, *args, **kw):
        i = -1
        if not xderiv.implicit:
            for i, dim in enumerate(xderiv):
                x = x.at[dim].set(args[i])
        if not yderiv.implicit:
            for j, dim in enumerate(yderiv):
                y = y.at[dim].set(args[1 + i + j])
        return core(x, y, **kw)

    # last x index in case the iteration on x does not run but the one on y does
    i = -1

    # differentiate w.r.t. the first argument
    if xderiv.implicit:
        for _ in range(xderiv.order):
            f = _jaxext.elementwise_grad(f, 0)
    else:
        for i, dim in enumerate(xderiv):
            for _ in range(xderiv[dim]):
                f = _jaxext.elementwise_grad(f, 2 + i)

    # differentiate w.r.t. the second argument
    if yderiv.implicit:
        for _ in range(yderiv.order):
            f = _jaxext.elementwise_grad(f, 1)
    else:
        for j, dim in enumerate(yderiv):
            for _ in range(yderiv[dim]):
                f = _jaxext.elementwise_grad(f, 2 + 1 + i + j)

    # check the derivatives are ok for the actual input arrays, wrap
    # structured arrays
    def process_arg(x, deriv, pos):
        if x.dtype.names is not None:
            for dim in deriv:
                if dim not in x.dtype.names:
                    raise ValueError(f'derivative along missing field {dim!r} '
                        f'on {pos} argument')
                if not jnp.issubdtype(x.dtype[dim], jnp.number):
                    raise TypeError(f'derivative along non-numeric field '
                        f'{dim!r} on {pos} argument')
            return _array.StructuredArray(x)
        elif not deriv.implicit:
            raise ValueError('derivative on named fields with non-structured '
                f'array on {pos} argument')
        elif not jnp.issubdtype(x.dtype, jnp.number):
            raise TypeError(f'derivative along non-numeric array on '
                f'{pos} argument')
        return x

    def newcore(x, y, **kw):
        x = process_arg(x, xderiv, 'left')
        y = process_arg(y, yderiv, 'right')

        args = []

        if not xderiv.implicit:
            for dim in xderiv:
                args.append(_asfloat(x[dim]))
        elif xderiv:
            x = _asfloat(x)

        if not yderiv.implicit:
            for dim in yderiv:
                args.append(_asfloat(y[dim]))
        elif yderiv:
            y = _asfloat(y)

        return f(x, y, *args, **kw)

    return newcore

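# A minimal usage sketch: the TODO note in `derivable` above spells this
# transf as .transf('diff'); the Deriv-like arguments below are illustrative
# assumptions:
#
#     k = lsqfitgp.ExpQuad()
#     dk = k.transf('diff', 1, 0)        # core computing d/dx k(x, y)
#     ddk = k.transf('diff', 'f0', 'f0') # derivative w.r.t. the
#                                        # (hypothetical) field 'f0' of a
#                                        # structured input, on both sides
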

@CrossKernel.register_xtransf
def xtransf(fun):
    r"""

    Transform the inputs of the function.

    .. math::
        T(f)(x) = f(\mathrm{fun}(x))

    Parameters
    ----------
    xfun, yfun : callable or None
        Functions mapping a new kind of input to the kind of input accepted
        by the kernel.

    """
    if not callable(fun):
        raise ValueError("argument to 'xtransf' must be a function")
    return fun

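# A minimal usage sketch (the log reparametrization is illustrative):
#
#     k = lsqfitgp.ExpQuad().linop('xtransf', jnp.log, jnp.log)
#     # k now takes positive inputs and evaluates the base kernel on log x,
#     # i.e., k(x, y) == ExpQuad()(log x, log y)
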

@CrossKernel.register_xtransf
def dim(dim):
    """
    Restrict the function to a field of a structured input::

        T(f)(x) = f(x[dim])

    If the array is not structured, an exception is raised. If the field for
    name `dim` has a nontrivial shape, the array passed to the kernel is
    still structured but has only field `dim`.

    Parameters
    ----------
    xdim, ydim : None, str, list of str
        Field names or lists of field names.

    """
    if not isinstance(dim, (str, list)):
        raise TypeError(f'dim must be a (list of) string, found {dim!r}')
    def fun(x):
        if x.dtype.names is None:
            raise ValueError(f'cannot get dim={dim!r} from non-structured input')
        elif x.dtype[dim].shape:
            return x[[dim]]
        else:
            return x[dim]
    return fun

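# A minimal usage sketch (the field name 'time' is illustrative):
#
#     k = lsqfitgp.ExpQuad().linop('dim', 'time', 'time')
#     # k reads only the 'time' field of structured inputs on both sides
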

@CrossKernel.register_xtransf
def maxdim(maxdim):
    """

    Restrict the process to a maximum input dimensionality.

    Parameters
    ----------
    xmaxdim, ymaxdim : None, int
        Maximum dimensionality of the input.

    Notes
    -----
    Once a restriction is applied, the check is hardcoded into the kernel
    core, and it is not possible to remove it by applying `maxdim` again
    with a larger limit.

    """
    if not isinstance(maxdim, numbers.Integral) or maxdim < 0:
        raise ValueError(f'maximum dimensionality {maxdim!r} not valid')

    def fun(x):
        nd = _array._nd(x.dtype)
        with _jaxext.skipifabstract():
            if nd > maxdim:
                raise ValueError(f'kernel applied to input with {nd} '
                    f'fields > maxdim={maxdim}')
        return x

    return fun

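# A minimal usage sketch: a kernel that is meaningful only in low dimension
# can refuse wider structured inputs (assumed calling convention):
#
#     k = lsqfitgp.ExpQuad().linop('maxdim', 1, 1)
#     # evaluating k on inputs with more than one field raises ValueError
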

@CrossKernel.register_xtransf
def loc(loc):
    r"""
    Translate the process inputs:

    .. math::
        T(f)(x) = f(x - \mathrm{loc})

    Parameters
    ----------
    xloc, yloc : None, number
        Translations.

    """
    with _jaxext.skipifabstract():
        assert -jnp.inf < loc < jnp.inf, loc
    return lambda x: _util.ufunc_recurse_dtype(lambda x: x - loc, x)


@CrossKernel.register_xtransf
def scale(scale):
    r"""
    Rescale the process inputs:

    .. math::
        T(f)(x) = f(x / \mathrm{scale})

    Parameters
    ----------
    xscale, yscale : None, number
        Rescaling factors.

    """
    with _jaxext.skipifabstract():
        assert 0 < scale < jnp.inf, scale
    return lambda x: _util.ufunc_recurse_dtype(lambda x: x / scale, x)

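# A minimal usage sketch: 'loc' and 'scale' compose into an affine input
# map. Applying 'scale' 5 and then 'loc' 3 evaluates the base kernel on
# (x - 3) / 5 (illustrative values; single arguments are assumed to apply
# to both sides):
#
#     k = lsqfitgp.ExpQuad().linop('scale', 5).linop('loc', 3)
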

def normalize_argparser(do):
    return do if do else None


@functools.partial(CrossKernel.register_corelinop, argparser=normalize_argparser)
def normalize(core, dox, doy):
    r"""
    Rescale the process to unit variance.

    .. math::
        T(f)(x) &= f(x) / \sqrt{\mathrm{Var}[f(x)]} \\
                &= f(x) / \sqrt{\mathrm{kernel}(x, x)}

    Parameters
    ----------
    dox, doy : bool
        Whether to rescale.

    """
    if dox and doy:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(x, x, **kw) * core(y, y, **kw))
    elif dox:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(x, x, **kw))
    else:
        return lambda x, y, **kw: core(x, y, **kw) / jnp.sqrt(core(y, y, **kw))

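# A minimal usage sketch (assumed calling convention): for any base kernel,
#
#     k = base.linop('normalize', True, True)
#     # k(x, x) == 1 for every x, i.e., the process has unit prior variance
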

@CrossKernel.register_corelinop
def cond(core, cond1, cond2, other):
    r"""

    Switch between two independent processes based on a condition.

    .. math::
        T(f, g)(x) = \begin{cases}
            f(x) & \text{if $\mathrm{cond}(x)$,} \\
            g(x) & \text{otherwise.}
        \end{cases}

    Parameters
    ----------
    cond1, cond2 : callable
        Functions that are applied to an array of points and must return a
        boolean array with the same shape.
    other : callable
        Kernel core of the process used where the condition is false.

    """
    def newcore(x, y, **kw):
        xcond = cond1(x)
        ycond = cond2(y)
        # use `core` where both conditions hold, `other` where neither does
        r = jnp.where(xcond & ycond, core(x, y, **kw), other(x, y, **kw))
        # zero where exactly one condition holds: the covariance between the
        # two processes vanishes because they are independent
        return jnp.where(xcond ^ ycond, 0, r)

    return newcore

    # TODO add a function `choose` to extend `cond`,
    # kernel0.linop('choose', kernel1, kernel2, ..., lambda x: x['index'])

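# A minimal usage sketch: two independent processes glued at zero (how the
# `other` operand is passed through the public transf is an assumption, and
# the kernels are illustrative):
#
#     k1, k2 = lsqfitgp.ExpQuad(), lsqfitgp.Cauchy()
#     neg = lambda x: x < 0
#     k = k1.linop('cond', neg, neg, k2)
#     # k(x, y) == k1(x, y) where both x, y < 0, k2(x, y) where neither is,
#     # and 0 across the boundary because the two processes are independent
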

# reuse the CrossKernel implementations for the transfs that need no affine
# bookkeeping
AffineSpan.inherit_transf('maxdim')
AffineSpan.inherit_transf('derivable')


@functools.partial(AffineSpan.register_linop, transfname='loc')
def affine_loc(tcls, self, xloc, yloc):
    dynkw = dict(self.dynkw)
    newself = tcls.super_transf('loc', self, xloc, yloc)
    # record the translation in the dynamic keywords, scaled by the current
    # rescaling factors
    dynkw['lloc'] = dynkw['lloc'] + xloc * dynkw['lscale']
    dynkw['rloc'] = dynkw['rloc'] + yloc * dynkw['rscale']
    return newself._clone(self.__class__, dynkw=dynkw)


@functools.partial(AffineSpan.register_linop, transfname='scale')
def affine_scale(tcls, self, xscale, yscale):
    dynkw = dict(self.dynkw)
    newself = tcls.super_transf('scale', self, xscale, yscale)
    # accumulate the rescaling factors in the dynamic keywords
    dynkw['lscale'] = dynkw['lscale'] * xscale
    dynkw['rscale'] = dynkw['rscale'] * yscale
    return newself._clone(self.__class__, dynkw=dynkw)
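
# A hedged note on the two overrides above: besides applying the underlying
# transf through super_transf, they fold each translation and rescaling into
# the kernel's dynamic keywords, so a chain of 'loc'/'scale' applications
# stays within the affine family instead of accumulating opaque closures.
# Illustratively (values and calling convention are assumptions):
#
#     k = affine_kernel.linop('scale', 2).linop('loc', 1)
#     # after 'scale': lscale doubles; after 'loc': lloc grows by 1 * lscale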