diff --git a/src/bitops.c b/src/bitops.c index 26a08acfa..8fb201168 100644 --- a/src/bitops.c +++ b/src/bitops.c @@ -382,6 +382,53 @@ static inline long long redisPopcountAuto(const unsigned char *p, long count) { #endif } +/* --------------------------------------------------------------------------- + * SIMD helpers for redisBitpos(): skip leading 64-byte (AVX512) or 32-byte + * (AVX2) chunks whose 64-bit lanes all equal the skip pattern (all zero + * bits when bit=1, all one bits when bit=0). Each returns the number of bytes + * skipped, a multiple of the chunk size; the scalar loop finds the exact word. + * ----------------------------------------------------------------------- */ +#ifdef HAVE_AVX512 +ATTRIBUTE_TARGET_AVX512 +static unsigned long redisBitposScanAVX512(unsigned char *p, + unsigned long count, int bit) { + unsigned long scanned = 0; + __m512i skip = bit ? _mm512_setzero_si512() + : _mm512_set1_epi64(-1LL); + + while (count >= 64) { + __m512i data = _mm512_loadu_si512(p); + __mmask8 eq = _mm512_cmpeq_epi64_mask(data, skip); + if (eq != 0xFF) break; + p += 64; + count -= 64; + scanned += 64; + } + return scanned; +} +#endif + +#ifdef HAVE_AVX2 +ATTRIBUTE_TARGET_AVX2 +static unsigned long redisBitposScanAVX2(unsigned char *p, + unsigned long count, int bit) { + unsigned long scanned = 0; + __m256i skip = bit ? _mm256_setzero_si256() + : _mm256_set1_epi64x(-1LL); + + while (count >= 32) { + __m256i data = _mm256_loadu_si256((__m256i *)p); + int eq = _mm256_movemask_pd(_mm256_castsi256_pd( + _mm256_cmpeq_epi64(data, skip))); + if (eq != 0xF) break; + p += 32; + count -= 32; + scanned += 32; + } + return scanned; +} +#endif + /* Return the position of the first bit set to one (if 'bit' is 1) or * zero (if 'bit' is 0) in the bitmap starting at 's' and long 'count' bytes. 
* @@ -392,7 +439,7 @@ static inline long long redisPopcountAuto(const unsigned char *p, long count) { long long redisBitpos(void *s, unsigned long count, int bit) { unsigned long *l; unsigned char *c; - unsigned long skipval, word = 0, one; + unsigned long skipval, word = 0; long long pos = 0; /* Position of bit, to return to the caller. */ unsigned long j; int found; @@ -420,16 +467,44 @@ long long redisBitpos(void *s, unsigned long count, int bit) { pos += 8; } - /* Skip bits with full word step. */ - l = (unsigned long*) c; + /* Skip bits with full word step. AVX512 is tried first and AVX2 only + * when the AVX512 branch was not taken; the scalar loop finishes the tail. */ if (!found) { skipval = bit ? 0 : ULONG_MAX; + +#if defined(HAVE_AVX512) || defined(HAVE_AVX2) + int useAVX = 0; +#endif + +#if defined(HAVE_AVX512) + if (BITOP_USE_AVX512 && count >= 64) { + unsigned long advanced = redisBitposScanAVX512(c, count, bit); + c += advanced; + count -= advanced; + pos += advanced * 8; + useAVX = 1; + } +#endif + +#if defined(HAVE_AVX2) + if (!useAVX && BITOP_USE_AVX2 && count >= 32) { + unsigned long advanced = redisBitposScanAVX2(c, count, bit); + c += advanced; + count -= advanced; + pos += advanced * 8; + } +#endif + + /* Scalar word-at-a-time scan: locates the exact non-matching word + * after any SIMD skip, or does the whole scan when SIMD is not used. */ + l = (unsigned long *)c; while (count >= sizeof(*l)) { if (*l != skipval) break; l++; count -= sizeof(*l); pos += sizeof(*l)*8; } + c = (unsigned char *)l; } /* Load bytes into "word" considering the first byte as the most significant @@ -439,7 +514,6 @@ long long redisBitpos(void *s, unsigned long count, int bit) { * * Note that the loading is designed to work even when the bytes left * (count) are less than a full word. We pad it with zero on the right. 
*/ - c = (unsigned char*)l; for (j = 0; j < sizeof(*l); j++) { word <<= 8; if (count) { @@ -456,24 +530,19 @@ long long redisBitpos(void *s, unsigned long count, int bit) { * that the right of the string is zero padded. */ if (bit == 1 && word == 0) return -1; - /* Last word left, scan bit by bit. The first thing we need is to - * have a single "1" set in the most significant position in an - * unsigned long. We don't know the size of the long so we use a - * simple trick. */ - one = ULONG_MAX; /* All bits set to 1.*/ - one >>= 1; /* All bits set to 1 but the MSB. */ - one = ~one; /* All bits set to 0 but the MSB. */ - - while(one) { - if (((one & word) != 0) == bit) return pos; - pos++; - one >>= 1; - } - - /* If we reached this point, there is a bug in the algorithm, since - * the case of no match is handled as a special case before. */ - serverPanic("End of redisBitpos() reached."); - return 0; /* Just to avoid warnings. */ + /* Last word left: locate the first matching bit with __builtin_clzl, + * whose count of leading zeros in an unsigned long is the bit offset + * from the MSB to the first set bit. For bit=0 the word is inverted + * first, so the first zero bit is found instead. + * + * Safety: __builtin_clzl is undefined for a zero argument; here: + * - bit==1: the 'if (bit == 1 && word == 0) return -1' above guards it. + * - bit==0: a full word reaching here stopped the skip loop, so it is + * not ULONG_MAX; a partial word is zero padded on the right. + * NOTE(review): when 'found' is set the skip loops do not run; confirm + * the code that sets 'found' leaves 'c' on a byte that is not all ones. */ + pos += bit ? __builtin_clzl(word) : __builtin_clzl(~word); + return pos; } /* The following set.*Bitfield and get.*Bitfield functions implement setting