Counting 1 bits (population count) on large data using AVX-512 or AVX-2

AVX-2

@HadiBreis’ comment links to an article on fast population-count with SSSE3, by Wojciech Muła; the article links to this GitHub repository; and the repository has the following AVX-2 implementation. It is based on a vectorized table-lookup (byte-shuffle) instruction, used with a 16-entry lookup table that holds the bit count of every possible nibble value.

#   include <immintrin.h>
#   include <x86intrin.h>

std::uint64_t popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {

    size_t i = 0;

    const __m256i lookup = _mm256_setr_epi8(
        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,

        /* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
        /* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
        /* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
        /* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
    );

    const __m256i low_mask = _mm256_set1_epi8(0x0f);

    __m256i acc = _mm256_setzero_si256();

#define ITER { \
        const __m256i vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data + i)); \
        const __m256i lo  = _mm256_and_si256(vec, low_mask); \
        const __m256i hi  = _mm256_and_si256(_mm256_srli_epi16(vec, 4), low_mask); \
        const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo); \
        const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi); \
        local = _mm256_add_epi8(local, popcnt1); \
        local = _mm256_add_epi8(local, popcnt2); \
        i += 32; \
    }

    while (i + 8*32 <= n) {
        __m256i local = _mm256_setzero_si256();
        ITER ITER ITER ITER
        ITER ITER ITER ITER
        acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
    }

    __m256i local = _mm256_setzero_si256();

    while (i + 32 <= n) {
        ITER;
    }

    acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));

#undef ITER

    uint64_t result = 0;

    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 0));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 1));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 2));
    result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 3));

    for (/**/; i < n; i++) {
        result += lookup8bit[data[i]];
    }

    return result;
}

AVX-512

The same repository also has a VPOPCNT-based AVX-512 implementation. Before listing the code for it, here’s the simplified and more readable pseudocode:

  • For every consecutive sequence of 64 bytes:

    • Load the sequence into a SIMD register with 64×8 = 512 bits
    • Perform 8 parallel population counts of 64 bits each on that register
    • Add the 8 population-count results in parallel, into an “accumulator” register holding 8 sums
  • Sum up the 8 values in the accumulator

  • If there’s a tail of less than 64 bytes, count the bits there in some simpler way

  • Return the main sum plus the tail sum

And now for the real deal:

#   include <cstddef>
#   include <cstdint>
#   include <cstring>

#   include <immintrin.h>
#   include <x86intrin.h>

/**
 * Counts the set bits in a byte buffer using the AVX-512 vector
 * population-count instruction for the bulk of the data.
 *
 * The buffer is processed in three stages:
 *   1. whole 64-byte chunks, eight 64-bit popcounts at a time;
 *   2. remaining whole 8-byte words, with the scalar 64-bit popcount;
 *   3. the last (fewer than 8) bytes, through `lookup8bit`, a per-byte
 *      bit-count table that is defined outside this function.
 *
 * @param data  pointer to the first byte to examine; no alignment is required
 * @param size  number of bytes to examine
 * @return      total number of set bits in data[0 .. size)
 */
uint64_t avx512_vpopcnt(const uint8_t* data, const size_t size) {

    const size_t chunks = size / 64;

    const uint8_t* ptr = data;
    const uint8_t* const end = data + size;

    // count using AVX512 registers
    __m512i accumulator = _mm512_setzero_si512();
    for (size_t i = 0; i < chunks; i++, ptr += 64) {

        // Note: a short chain of dependencies, likely unrolling will be needed.
        const __m512i v = _mm512_loadu_si512(ptr);  // unaligned load
        const __m512i p = _mm512_popcnt_epi64(v);

        accumulator = _mm512_add_epi64(accumulator, p);
    }

    // horizontal sum of a register (the aligned store needs a 64-byte
    // aligned destination)
    alignas(64) uint64_t tmp[8];
    _mm512_store_si512(tmp, accumulator);

    uint64_t total = 0;
    for (const uint64_t lane : tmp) {
        total += lane;
    }

    // popcount the tail, one 64-bit word at a time. The loop is driven by the
    // number of bytes left, so a tail of exactly 8 bytes is handled here and
    // no pointer beyond `end` is ever formed.
    while (static_cast<size_t>(end - ptr) >= sizeof(uint64_t)) {
        // memcpy instead of dereferencing a casted pointer: `ptr` need not be
        // 8-byte aligned, and reading bytes through a uint64_t lvalue would
        // violate the aliasing rules.
        uint64_t word;
        std::memcpy(&word, ptr, sizeof(word));
        total += static_cast<uint64_t>(_mm_popcnt_u64(word));
        ptr += sizeof(word);
    }

    // fewer than 8 bytes remain: count them byte by byte
    while (ptr < end) {
        total += lookup8bit[*ptr++];
    }

    return total;
}

`lookup8bit` is a population-count lookup table indexed by a whole byte value (rather than by a nibble, as in the 16-entry table used above), and is defined in the linked repository. Edit: As commenters note, using an 8-bit lookup table for the tail is not a very good idea and can be improved upon.

Leave a Comment