Coverage for src/lsqfitgp/_kernels/_basic.py: 100%

113 statements  

coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

# lsqfitgp/_kernels/_basic.py
#
# Copyright (c) 2023, Giacomo Petrillo
#
# This file is part of lsqfitgp.
#
# lsqfitgp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# lsqfitgp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.

import sys
import re
import collections

import numpy
import jax
from jax import numpy as jnp
from jax.scipy import special as jspecial

from .. import _special
from .. import _jaxext
from .. import _Kernel
from .._Kernel import kernel, stationarykernel, isotropickernel


@isotropickernel(derivable=True, input='raw')
def Constant(x, y):
    """
    Constant kernel.

    .. math::
        k(x, y) = 1

    This means that all points are completely correlated, thus it is
    equivalent to fitting with a horizontal line. This can also be seen by
    observing that 1 = 1 × 1, i.e., the kernel factorizes as
    k(x, y) = f(x) f(y) with f ≡ 1.
    """
    return jnp.ones(jnp.broadcast_shapes(x.shape, y.shape))


@isotropickernel(derivable=False, input='raw')
def White(x, y):
    """
    White noise kernel.

    .. math::
        k(x, y) = \\begin{cases}
            1 & x = y \\\\
            0 & x \\neq y
        \\end{cases}
    """
    return _Kernel.prod_recurse_dtype(lambda x, y: x == y, x, y).astype(int)
    # TODO maybe StructuredArray should support equality and other operations


@isotropickernel(derivable=True)
def ExpQuad(r2):
    """
    Exponential quadratic kernel.

    .. math::
        k(r) = \\exp \\left( -\\frac 12 r^2 \\right)

    It is smooth and has a strict typical lengthscale, i.e., oscillations are
    strongly suppressed under a certain wavelength, and correlations are
    strongly suppressed over a certain distance.

    Reference: Rasmussen and Williams (2006, p. 83).
    """
    return jnp.exp(-1/2 * r2)


def _dot(x, y):
    return _Kernel.sum_recurse_dtype(lambda x, y: x * y, x, y)


@kernel(derivable=True)
def Linear(x, y):
    """
    Dot product kernel.

    .. math::
        k(x, y) = x \\cdot y = \\sum_i x_i y_i

    In 1D it is equivalent to fitting with a line passing through the origin.

    Reference: Rasmussen and Williams (2006, p. 89).
    """
    return _dot(x, y)

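# A quick numerical check of the "line through the origin" statement above
# (standalone sketch, not part of the library): a random line f(x) = w x with
# w ~ N(0, 1) has Cov[f(x), f(y)] = x y, which is this kernel in 1D.
#
#     import numpy as np
#     rng = np.random.default_rng(0)
#     w = rng.standard_normal(100_000)          # samples of the slope
#     x, y = 0.7, -1.3
#     np.cov(w * x, w * y)[0, 1]                # ≈ x * y = -0.91
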

@isotropickernel(derivable=lambda gamma=1: gamma == 2)
def GammaExp(r2, gamma=1):
    """
    Gamma exponential kernel.

    .. math::
        k(r) = \\exp(-r^\\gamma), \\quad
        \\gamma \\in (0, 2]

    For :math:`\\gamma = 2` it is the squared exponential kernel, for
    :math:`\\gamma = 1` (default) it is the Matérn 1/2 kernel, for
    :math:`\\gamma \\to 0` it tends to white noise plus a constant. The process
    is differentiable only for :math:`\\gamma = 2`, however as :math:`\\gamma`
    gets closer to 2 the variance of the non-derivable component goes to zero.

    Reference: Rasmussen and Williams (2006, p. 86).
    """
    with _jaxext.skipifabstract():
        assert 0 < gamma <= 2, gamma
    nondiff = jnp.exp(-(r2 ** (gamma / 2)))
    diff = jnp.exp(-r2)
    return jnp.where(gamma == 2, diff, nondiff)
    # I need to keep separate the case where derivatives w.r.t. r2 could be
    # computed because the second derivative at x=0 of x^p with floating
    # point p is nan.

    # TODO extend to gamma=0, the correct limit is
    # e^-1 constant + (1 - e^-1) white noise. Use lax.switch.

    # TODO derivatives w.r.t. gamma at gamma==2 are probably broken, although
    # I guess they are not needed since it's on the boundary of the domain

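# Standalone sketch of the issue the comment above refers to: differentiating
# r2 ** p twice at r2 = 0 with a floating-point exponent p applies the power
# rule p * (p - 1) * r2 ** (p - 2), which evaluates to 0 * inf = nan when
# p == 1, hence the separate branch for gamma == 2 (illustrative, based on the
# behavior described in the comment above).
#
#     import jax
#     d2 = jax.grad(jax.grad(lambda r2: r2 ** 1.0))
#     d2(0.0)                                   # nan (0 * inf), not the exact 0
#     jax.grad(jax.grad(lambda r2: r2))(0.0)    # 0.0
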

@kernel(derivable=True)
def NNKernel(x, y, sigma0=1):
    """
    Neural network kernel.

    .. math::
        k(x, y) = \\frac 2 \\pi
        \\arcsin \\left( \\frac
        {
            2 (q + x \\cdot y)
        }{
            (1 + 2 (q + x \\cdot x))
            (1 + 2 (q + y \\cdot y))
        }
        \\right),
        \\quad q = \\texttt{sigma0}^2

    Kernel which is equivalent to a neural network with one infinitely wide
    hidden layer with Gaussian priors on the weights and error function
    response. In other words, you can think of the process as a superposition
    of sigmoids where ``sigma0`` sets the dispersion of the centers of the
    sigmoids.

    Reference: Rasmussen and Williams (2006, p. 90).
    """

    # TODO the `2`s in the formula are a bit arbitrary. Remove them or give
    # motivation relative to the precise formulation of the neural network.
    with _jaxext.skipifabstract():
        assert 0 < sigma0 < jnp.inf
    q = sigma0 ** 2
    denom = (1 + 2 * (q + _dot(x, x))) * (1 + 2 * (q + _dot(y, y)))
    return 2/jnp.pi * jnp.arcsin(2 * (q + _dot(x, y)) / denom)

    # TODO this is not fully equivalent to an arbitrary transformation on the
    # augmented vector even if x and y are transformed, unless I support q
    # being a vector or an additional parameter.

    # TODO if arcsin has positive taylor coefficients, this can be obtained as
    # arcsin(1 + linear) * rescaling.


@kernel
def Gibbs(x, y, scalefun=lambda x: 1):
    """
    Gibbs kernel.

    .. math::
        k(x, y) = \\sqrt{ \\frac {2 s(x) s(y)} {s(x)^2 + s(y)^2} }
        \\exp \\left( -\\frac {(x - y)^2} {s(x)^2 + s(y)^2} \\right),
        \\quad s = \\texttt{scalefun}.

    Kernel which, loosely speaking, is a Gaussian kernel whose length scale
    changes from point to point. The scale is computed by the parameter
    `scalefun`, which must be a callable taking the x array and returning a
    scale for each point. By default ``scalefun`` returns 1, so the kernel
    reduces to the Gaussian kernel.

    Note that the default parameter ``scale`` acts before ``scalefun``, so
    for example if ``scalefun(x) = x`` then ``scale`` has no effect. You
    should include all rescalings in ``scalefun`` to avoid surprises.

    Reference: Rasmussen and Williams (2006, p. 93).
    """
    sx = scalefun(x)
    sy = scalefun(y)
    with _jaxext.skipifabstract():
        assert jnp.all(sx > 0)
        assert jnp.all(sy > 0)
    denom = sx ** 2 + sy ** 2
    factor = jnp.sqrt(2 * sx * sy / denom)
    distsq = _Kernel.sum_recurse_dtype(lambda x, y: (x - y) ** 2, x, y)
    return factor * jnp.exp(-distsq / denom)

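# A minimal usage sketch (assumptions, not shown in this file: the kernel is
# exposed publicly as lsqfitgp.Gibbs, kernel instances are callable on pairs
# of broadcastable arrays, and the scale function is made up for
# illustration).
#
#     import numpy as np
#     import lsqfitgp as lgp
#     k = lgp.Gibbs(scalefun=lambda x: 0.1 + abs(x))
#     x = np.linspace(-3, 3, 7)
#     k(x[:, None], x[None, :])   # correlations decay fast near 0, slowly far away
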

@stationarykernel(derivable=True, maxdim=1)
def Periodic(delta, outerscale=1):
    r"""
    Periodic Gaussian kernel.

    .. math::
        k(\Delta) = \exp \left(
        -2 \left(
        \frac {\sin(\Delta / 2)} {\texttt{outerscale}}
        \right)^2
        \right)

    A Gaussian kernel over a transformed periodic space. It represents a
    periodic process. The usual `scale` parameter sets the period, with the
    default ``scale=1`` giving a period of 2π, while `outerscale` sets the
    length scale of the correlations.

    Reference: Rasmussen and Williams (2006, p. 92).
    """
    with _jaxext.skipifabstract():
        assert 0 < outerscale < jnp.inf
    return jnp.exp(-2 * (jnp.sin(delta / 2) / outerscale) ** 2)

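# A minimal usage sketch of the period (assumptions, not shown in this file:
# the kernel is exposed publicly as lsqfitgp.Periodic and kernel instances
# are callable on pairs of arrays).
#
#     import numpy as np
#     import lsqfitgp as lgp
#     k = lgp.Periodic(outerscale=0.3)
#     x = np.linspace(0, 1, 5)
#     y = np.linspace(2, 3, 5)
#     np.allclose(k(x, y), k(x, y + 2 * np.pi))   # True: the period is 2π
#     k1 = lgp.Periodic(scale=1 / (2 * np.pi))    # rescaled to period 1
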

@kernel(derivable=False, maxdim=1)
def Categorical(x, y, cov=None):
    r"""
    Categorical kernel.

    .. math::
        k(x, y) = \texttt{cov}[x, y]

    A kernel over integers from 0 to N-1. The parameter `cov` is the
    covariance matrix of the values.
    """

    # TODO support sparse matrix for cov (replace jnp.asarray and numpy
    # check)

    assert jnp.issubdtype(x.dtype, jnp.integer)
    cov = jnp.asarray(cov)
    assert cov.ndim == 2
    assert cov.shape[0] == cov.shape[1]
    with _jaxext.skipifabstract():
        assert jnp.allclose(cov, cov.T)
    return cov[x, y]

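# A minimal usage sketch (assumptions, not shown in this file: the kernel is
# exposed publicly as lsqfitgp.Categorical, kernel instances are callable on
# pairs of integer arrays, and the covariance matrix is made up).
#
#     import numpy as np
#     import lsqfitgp as lgp
#     cov = np.array([[1.0, 0.5, 0.0],
#                     [0.5, 1.0, 0.0],
#                     [0.0, 0.0, 2.0]])             # 3 categories
#     k = lgp.Categorical(cov=cov)
#     k(np.array([0, 1, 2]), np.array([1, 1, 0]))   # [0.5, 1.0, 0.0]
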

@kernel
def Rescaling(x, y, stdfun=None):
    r"""
    Outer product kernel.

    .. math::
        k(x, y) = \texttt{stdfun}(x) \texttt{stdfun}(y)

    A totally correlated kernel with arbitrary variance. Parameter `stdfun`
    must be a function that takes ``x`` or ``y`` and computes the standard
    deviation at the point. It can yield negative values; points where
    `stdfun` has the same sign will be totally correlated, points where it
    has opposite signs will be totally anticorrelated. Use this kernel to
    modulate the variance of other kernels. By default `stdfun` returns a
    constant, so it is equivalent to `Constant`.

    """
    if stdfun is None:
        stdfun = lambda x: jnp.ones(x.shape)
        # do not use np.ones_like because it does not recognize StructuredArray
        # do not use x.dtype because it could be structured
    return stdfun(x) * stdfun(y)

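# A minimal sketch of variance modulation (assumptions, not shown in this
# file: the kernels are exposed publicly as lsqfitgp.Rescaling and
# lsqfitgp.ExpQuad, kernels support multiplication, and instances are
# callable on arrays).
#
#     import numpy as np
#     import lsqfitgp as lgp
#     k = lgp.Rescaling(stdfun=lambda x: 1 + x ** 2) * lgp.ExpQuad()
#     x = np.linspace(-2, 2, 5)
#     k(x, x)   # prior variance (1 + x**2)**2 instead of the constant 1
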

@stationarykernel(derivable=False, input='abs', maxdim=1)
def Expon(delta):
    """
    Exponential kernel.

    .. math::
        k(\\Delta) = \\exp(-|\\Delta|)

    In 1D it is equivalent to the Matérn 1/2 kernel, however in more
    dimensions it acts separately while the Matérn kernel is isotropic.

    Reference: Rasmussen and Williams (2006, p. 85).
    """
    return jnp.exp(-delta)

    # TODO rename Laplace, write it in terms of the 1-norm directly, then do
    # TruncLaplace that reaches zero over a prescribed box. (Or truncate with
    # an option truncbox=[(l0, r0), (l1, r1), ...]).


# the '-' is escaped so that it is a literal hyphen and not a character range
_bow_regexp = re.compile(r'\s|[!«»"“”‘’/()\'?¡¿„‚<>,;.:\-–—]')

@kernel(derivable=False, maxdim=1)
@numpy.vectorize
def BagOfWords(x, y):
    """
    Bag of words kernel.

    .. math::
        k(x, y) &= \\sum_{w \\in \\text{words}} c_w(x) c_w(y), \\\\
        c_w(x) &= \\text{number of times word $w$ appears in $x$}

    The words are defined as non-empty substrings delimited by spaces or one
    of the following punctuation characters: ! « » " “ ” ‘ ’ / ( ) ' ? ¡ ¿ „
    ‚ < > , ; . : - – —.

    Reference: Rasmussen and Williams (2006, p. 100).
    """

    # TODO precompute the bags for x and y, then call a vectorized private
    # function.

    # TODO iterate on the shorter bag and use get on the other instead of
    # computing set intersection? Or: convert words to integers and then do
    # set intersection with sorted arrays?

    xbag = collections.Counter(_bow_regexp.split(x))
    ybag = collections.Counter(_bow_regexp.split(y))
    # re.split yields '' for adjacent or leading/trailing delimiters, and
    # empty strings are not words, so zero out their count
    xbag[''] = 0
    ybag[''] = 0
    common = set(xbag) & set(ybag)
    return sum(xbag[k] * ybag[k] for k in common)

# TODO add bag of characters and maybe other text kernels

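# A worked example of the formula above, computed directly with Counter
# (standalone sketch mirroring the implementation; the strings are made up
# and contain only spaces, so plain str.split stands in for the regex):
#
#     import collections
#     x = 'the cat sat on the mat'     # counts: the 2, cat 1, sat 1, on 1, mat 1
#     y = 'the mat'                    # counts: the 1, mat 1
#     xbag = collections.Counter(x.split())
#     ybag = collections.Counter(y.split())
#     sum(xbag[w] * ybag[w] for w in set(xbag) & set(ybag))   # 2*1 + 1*1 = 3
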

@stationarykernel(derivable=False, input='abs', maxdim=1)
def HoleEffect(delta):
    """
    Hole effect kernel.

    .. math:: k(\\Delta) = (1 - \\Delta) \\exp(-\\Delta)

    Reference: Dietrich and Newsam (1997, p. 1096).
    """
    return (1 - delta) * jnp.exp(-delta)


def _cauchy_derivable(alpha=2, **_):
    return alpha == 2

@isotropickernel(derivable=_cauchy_derivable)
def Cauchy(r2, alpha=2, beta=2):
    r"""
    Generalized Cauchy kernel.

    .. math::
        k(r) = \left(1 + \frac{r^\alpha}{\beta} \right)^{-\beta/\alpha},
        \quad \alpha \in (0, 2], \beta > 0.

    In the geostatistics literature, the case :math:`\alpha=2` and
    :math:`\beta=2` (default) is known as the Cauchy kernel. In the machine
    learning literature, the case :math:`\alpha=2` (for any :math:`\beta`) is
    known as the rational quadratic kernel. For :math:`\beta\to\infty` it is
    equivalent to ``GammaExp(gamma=alpha, scale=alpha ** (1/alpha))``, while
    for :math:`\beta\to 0` to ``Constant``. It is smooth only for
    :math:`\alpha=2`.

    References: Gneiting and Schlather (2004, p. 273), Rasmussen and Williams
    (2006, p. 86).

    """
    with _jaxext.skipifabstract():
        assert 0 < alpha <= 2, alpha
        assert 0 < beta, beta
    power = jnp.where(alpha == 2, r2, r2 ** (alpha / 2))
    # I need to keep separate the case where derivatives w.r.t. r2 could be
    # computed because the second derivative at x=0 of x^p with floating
    # point p is nan.
    return (1 + power / beta) ** (-beta / alpha)

    # TODO derivatives w.r.t. alpha at alpha==2 are probably broken, although
    # I guess they are not needed since it's on the boundary of the domain

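# A quick numerical check of the beta -> infinity limit stated in the
# docstring, using the two formulas directly (standalone sketch, plain numpy):
#
#     import numpy as np
#     r = np.linspace(0.1, 3, 5)
#     alpha, beta = 1.5, 1e8
#     cauchy = (1 + r ** alpha / beta) ** (-beta / alpha)
#     gammaexp = np.exp(-(r / alpha ** (1 / alpha)) ** alpha)
#     np.allclose(cauchy, gammaexp, rtol=1e-6)   # True
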

@isotropickernel(derivable=lambda alpha=1: alpha == 0, input='posabs')
def CausalExpQuad(r, alpha=1):
    r"""
    Causal exponential quadratic kernel.

    .. math::
        k(r) = \big(1 - \operatorname{erf}(\alpha r/4)\big)
        \exp\left(-\frac12 r^2 \right)

    From https://github.com/wesselb/mlkernels.
    """
    with _jaxext.skipifabstract():
        assert alpha >= 0, alpha
    return jspecial.erfc(alpha / 4 * r) * jnp.exp(-1/2 * jnp.square(r))
    # TODO taylor-expand erfc near 0 and use r2

    # TODO is the erfc part a standalone valid kernel? If so, separate it,
    # since this can be obtained as the product


@kernel(derivable=True, maxdim=1)
def Decaying(x, y, alpha=1):
    r"""
    Decaying kernel.

    .. math::
        k(x, y) =
        \frac{1}{(1 + x + y)^\alpha},
        \quad x, y, \alpha \ge 0

    Reference: Swersky, Snoek and Adams (2014).
    """
    # TODO high dimensional version of this, see mlkernels issue #3
    with _jaxext.skipifabstract():
        assert jnp.all(x >= 0)
        assert jnp.all(y >= 0)
    return 1 / (x + y + 1) ** alpha
    # use x + y + 1 instead of 1 + x + y because the latter evaluates as
    # (1 + x) + y, which for small x and y is less accurate and not exactly
    # symmetric in x and y


@isotropickernel(derivable=False, input='posabs')
def Log(r):
    """
    Log kernel.

    .. math::
        k(r) = \\log(1 + r) / r

    From https://github.com/wesselb/mlkernels.
    """
    return jnp.log1p(r) / r


@kernel(derivable=True, maxdim=1)
def Taylor(x, y):
    """
    Exponential-like power series kernel.

    .. math::
        k(x, y) = \\sum_{k=0}^\\infty \\frac {x^k}{k!} \\frac {y^k}{k!}
        = I_0(2 \\sqrt{xy})

    It is equivalent to fitting with a Taylor series expansion around zero
    with independent priors on the coefficients, the coefficient of degree k
    having mean zero and standard deviation 1/k!.
    """

    mul = x * y
    val = 2 * jnp.sqrt(jnp.abs(mul))
    return jnp.where(mul >= 0, jspecial.i0(val), _special.j0(val))

    # TODO reference? Maybe it's called bessel kernel in the literature?
    # => nope, bessel kernel is the J_v one

    # TODO what is the "natural" extension of this to multidim?

    # TODO probably the rescaled version of this (e^-x) makes more sense
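
# A quick numerical check of the series identity in the docstring (standalone
# sketch, plain Python plus scipy):
#
#     import math
#     from scipy import special
#     x, y = 1.3, 0.7
#     series = sum((x * y) ** k / math.factorial(k) ** 2 for k in range(30))
#     abs(series - special.i0(2 * math.sqrt(x * y))) < 1e-12   # True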