Coverage for src/lsqfitgp/_jaxext/_fasthash.py: 100%
44 statements
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
« prev ^ index » next coverage.py v7.6.3, created at 2024-10-15 19:54 +0000
1# lsqfitgp/_jaxext/_fasthash.py
2#
3# Copyright (c) 2023, Giacomo Petrillo
4#
5# This file is part of lsqfitgp.
6#
7# lsqfitgp is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# lsqfitgp is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with lsqfitgp. If not, see <http://www.gnu.org/licenses/>.
20# JAX port of fasthash https://github.com/ztanml/fast-hash, original license:
22# The MIT License
23#
24# Copyright (C) 2012 Zilong Tan (eric.zltan@gmail.com)
25#
26# Permission is hereby granted, free of charge, to any person
27# obtaining a copy of this software and associated documentation
28# files (the "Software"), to deal in the Software without
29# restriction, including without limitation the rights to use, copy,
30# modify, merge, publish, distribute, sublicense, and/or sell copies
31# of the Software, and to permit persons to whom the Software is
32# furnished to do so, subject to the following conditions:
33#
34# The above copyright notice and this permission notice shall be
35# included in all copies or substantial portions of the Software.
36#
37# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
38# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
39# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
40# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
41# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
42# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
43# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
44# SOFTWARE.
46import functools 1feabcd
48from jax import numpy as jnp 1feabcd
49from jax import lax 1feabcd
50import jax 1feabcd
# TODO: this breaks down if jax_enable_x64=False. To support that mode, the
# operations on 64-bit integers would have to be rewritten in terms of
# operations on 32-bit integers, always manipulating pairs of uint32 arrays;
# standard translations for this likely exist. An interface could then switch
# automatically between the two modes.
58# TODO move to _special
def mix(h):
    """
    Scramble the bits of 64-bit words.

    Compression function of fasthash's Merkle-Damgard construction:
    xorshift right by 23, multiply by a fixed odd 64-bit constant (wrapping
    modulo 2**64), xorshift right by 47.

    Parameters
    ----------
    h : array
        Words to mix, expected to be ``uint64`` (requires 64-bit types to
        be enabled in JAX).

    Returns
    -------
    array
        The mixed words, elementwise. The input is not modified.
    """
    # Plain (non-augmented) assignments, so that a mutable input such as a
    # numpy array is never modified in place.
    h = h ^ (h >> 23)
    h = h * jnp.array(0x2127599bf4325c37, jnp.uint64)
    return h ^ (h >> 47)
@functools.partial(jax.jit, static_argnames=('unroll',))
def fasthash64(buf, seed, *, unroll=4):
    """
    Compute the 64-bit fasthash of the raw bytes of a 1D array.

    The memory of `buf` is reinterpreted as bytes with ``view(jnp.uint8)``,
    so the hash depends on the dtype and bit representation of the
    elements, not only on their values. Full 8-byte words are consumed with
    a `jax.lax.scan`; the 0-7 leftover bytes are packed into one final word.

    Parameters
    ----------
    buf : array
        1D array to hash.
    seed : int or array
        Seed, converted to ``uint64``.
    unroll : int, default 4
        Unrolling factor passed to `jax.lax.scan`. Static: changing it
        triggers recompilation.

    Returns
    -------
    h : uint64 array
        The hash (a scalar when `seed` is a scalar).

    Raises
    ------
    RuntimeError
        If 64-bit types are not enabled (``jax_enable_x64=False``).
    ValueError
        If `buf` is not 1-dimensional.
    """
    seed = jnp.array(seed, jnp.uint64)
    # Without jax_enable_x64 the conversion above does not yield uint64 and
    # all the 64-bit arithmetic below would be wrong. Use explicit raises
    # rather than asserts, which are stripped under `python -O`.
    if seed.dtype != jnp.uint64:
        raise RuntimeError('fasthash64 requires jax_enable_x64=True')
    if buf.ndim != 1:
        raise ValueError(f'buf must be 1-dimensional, got ndim={buf.ndim}')

    buf = buf.view(jnp.uint8)
    m = jnp.array(0x880355f21e6d1965, jnp.uint64)
    nwords = buf.size // 8
    words = buf[:8 * nwords].view(jnp.uint64)
    h = seed ^ (buf.size * m)

    # Absorb the full 8-byte words.
    def absorb(h, v):
        h ^= mix(v)
        h *= m
        return h, None
    h, _ = lax.scan(absorb, h, words, unroll=unroll)

    # Pack the leftover bytes into one word, byte i at bits 8i..8i+7, and
    # absorb it only if there is at least one leftover byte.
    tail = buf[8 * nwords:]
    assert tail.size < 8
    if tail.size:
        v = jnp.array(0, jnp.uint64)
        for i in range(tail.size):
            v ^= tail[i].astype(jnp.uint64) << (8 * i)
        h ^= mix(v)
        h *= m

    assert h.dtype == jnp.uint64
    return mix(h)
def fasthash32(buf, seed):
    """
    Compute the 32-bit fasthash of the raw bytes of a 1D array.

    The seed is cast to uint32 and the 64-bit hash ``h`` is computed with
    `fasthash64`. It is then folded to 32 bits as ``h - (h >> 32)``,
    truncated to uint32, so both halves of ``h`` contribute.

    Parameters
    ----------
    buf : array
        1D array to hash (see `fasthash64`).
    seed : int or array
        Seed, converted to ``uint32``.

    Returns
    -------
    uint32 array
        The 32-bit hash.
    """
    seed32 = jnp.array(seed, jnp.uint32)
    h64 = fasthash64(buf, seed32)
    high = h64 >> 32
    folded = h64 - high
    return folded.astype(jnp.uint32)