Coverage for src/lsqfitgp/_jaxext/_fasthash.py: 100%

44 statements  

« prev     ^ index     » next       coverage.py v7.6.3, created at 2024-10-15 19:54 +0000

1# lsqfitgp/_jaxext/_fasthash.py 

2# 

3# Copyright (c) 2023, Giacomo Petrillo 

4# 

5# This file is part of lsqfitgp. 

6# 

7# lsqfitgp is free software: you can redistribute it and/or modify 

8# it under the terms of the GNU General Public License as published by 

9# the Free Software Foundation, either version 3 of the License, or 

10# (at your option) any later version. 

11# 

12# lsqfitgp is distributed in the hope that it will be useful, 

13# but WITHOUT ANY WARRANTY; without even the implied warranty of 

14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

15# GNU General Public License for more details. 

16# 

17# You should have received a copy of the GNU General Public License 

18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>. 

19 

20# JAX port of fasthash https://github.com/ztanml/fast-hash, original license: 

21 

22# The MIT License 

23# 

24# Copyright (C) 2012 Zilong Tan (eric.zltan@gmail.com) 

25# 

26# Permission is hereby granted, free of charge, to any person 

27# obtaining a copy of this software and associated documentation 

28# files (the "Software"), to deal in the Software without 

29# restriction, including without limitation the rights to use, copy, 

30# modify, merge, publish, distribute, sublicense, and/or sell copies 

31# of the Software, and to permit persons to whom the Software is 

32# furnished to do so, subject to the following conditions: 

33# 

34# The above copyright notice and this permission notice shall be 

35# included in all copies or substantial portions of the Software. 

36# 

37# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 

38# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 

39# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 

40# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 

41# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 

42# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 

43# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 

44# SOFTWARE. 

45 

46import functools 1feabcd

47 

48from jax import numpy as jnp 1feabcd

49from jax import lax 1feabcd

50import jax 1feabcd

51 

52# TODO this breaks down if jax_enable_x64=False. I have to write the operations 

53# on 64 bit integers in terms of operations on 32 bit integers, and always 

54# manipulate arrays with 2x uint32. I bet I can find such translations on the 

55# internet. I can make an interface for something that switches automatically 

56# between the two modes. 

57 

58# TODO move to _special 

59 

60# Compression function for Merkle-Damgard construction. 

61# This function was generated using the framework provided with the original C fasthash library. 

def mix(h):
    """
    Apply the fasthash mixing function to uint64 values.

    The value is xorshifted right by 23, multiplied by 0x2127599bf4325c37
    (wrapping modulo 2**64), and then xorshifted right by 47.

    Parameters
    ----------
    h : uint64 array
        Values to mix. A uint64 dtype is expected; this needs
        jax_enable_x64=True.

    Returns
    -------
    uint64 array with the same shape as `h`.
    """
    multiplier = jnp.array(0x2127599bf4325c37, jnp.uint64)
    shifted = h ^ (h >> 23)
    scrambled = shifted * multiplier
    return scrambled ^ (scrambled >> 47)

67 

@functools.partial(jax.jit, static_argnames=('unroll',))
def fasthash64(buf, seed, *, unroll=4):
    """
    Compute the 64-bit fasthash of the raw bytes of an array.

    JAX port of ``fasthash64`` from https://github.com/ztanml/fast-hash.
    The memory of `buf` is reinterpreted as bytes, so the hash depends on
    both the values and the dtype of `buf`.

    Parameters
    ----------
    buf : array
        One-dimensional array to hash.
    seed : int or array
        Seed, converted to uint64.
    unroll : int, default 4
        Unrolling factor passed to `lax.scan` for the loop over complete
        8-byte words. It is a static argument.

    Returns
    -------
    h : uint64 array
        The hash. Its shape follows the (broadcast) shape of `seed`, which
        is normally a scalar.

    Raises
    ------
    RuntimeError
        If 64-bit types are disabled (jax_enable_x64=False).
    ValueError
        If `buf` is not one-dimensional.
    """
    seed = jnp.array(seed, jnp.uint64)

    # Without jax_enable_x64 the requested uint64 dtype degrades to a
    # narrower one. Check explicitly instead of with `assert`, which is
    # stripped under `python -O` and would let wrong hashes through.
    if seed.dtype != jnp.uint64:
        raise RuntimeError('fasthash64 requires jax_enable_x64=True')
    if buf.ndim != 1:
        raise ValueError(f'buf must be 1-dimensional, got ndim={buf.ndim}')

    buf = buf.view(jnp.uint8)
    m = jnp.array(0x880355f21e6d1965, jnp.uint64)
    h = seed ^ (buf.size * m)

    # Main loop: reinterpret all complete 8-byte groups as uint64 words.
    nfull = buf.size - buf.size % 8
    words = buf[:nfull].view(jnp.uint64)

    def loop(carry, v):
        h, m = carry
        h ^= mix(v)
        h *= m
        return (h, m), None
    (h, _), _ = lax.scan(loop, (h, m), words, unroll=unroll)

    # Tail: the 0-7 leftover bytes, packed so byte i occupies bits
    # 8*i..8*i+7. This is the equivalent of the C version's switch
    # fallthrough; the XOR order is irrelevant. The size is static, so this
    # is a trace-time Python loop.
    tail = buf[nfull:]
    assert tail.size < 8
    if tail.size:
        v = jnp.array(0, jnp.uint64)
        for i in range(tail.size):
            v ^= tail[i].astype(jnp.uint64) << (8 * i)
        h ^= mix(v)
        h *= m

    assert h.dtype == jnp.uint64
    return mix(h)

104 

def fasthash32(buf, seed, *, unroll=4):
    """
    Compute the 32-bit fasthash of the raw bytes of an array.

    Folds the 64-bit hash from `fasthash64` onto 32 bits as in the C
    reference implementation: ``h - (h >> 32)`` converted to uint32.

    Parameters
    ----------
    buf : array
        One-dimensional array to hash; see `fasthash64`.
    seed : int or array
        Seed, converted to uint32 and then widened to uint64 by
        `fasthash64`.
    unroll : int, default 4
        Forwarded to `fasthash64`, where it controls `lax.scan` unrolling.

    Returns
    -------
    h : uint32 array
    """
    seed = jnp.array(seed, jnp.uint32)
    h = fasthash64(buf, seed, unroll=unroll)
    # The subtraction wraps in uint64. The cast to uint32 is intended to keep
    # the low 32 bits, like the C code's implicit uint64 -> uint32 conversion.
    return (h - (h >> 32)).astype(jnp.uint32)