Modular arithmetics and NTT (finite field DFT) optimizations

First off, thank you very much for posting and making it free to use. I really appreciate that.

I used some bit tricks to eliminate some branching, rearranged the main loop, and modified the assembly, which gave a 1.35x speedup.

Also, I added a preprocessor condition for 64-bit builds, since Visual Studio doesn't allow inline assembly in 64-bit mode (thank you Microsoft; feel free to go screw yourself).

Something strange happened when I was optimizing the modsub() function. I rewrote it using bit hacks, as I did for modadd() (where the bit-hack version was faster). But for some reason, the bitwise version of modsub() was slower. I'm not sure why; it might just be my computer.

//
// Mandalf The Beige
// Based on:
// Spektre
// http://stackoverflow.com/questions/18577076/modular-arithmetics-and-ntt-finite-field-dft-optimizations
//
// This code may be freely used however you choose, so long as it is accompanied by this notice.
//




#ifndef H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR
#define H__OPTIMIZED_NUMBER_THEORETIC_TRANSFORM__HDR

#include <string.h>

#ifndef uint32
#define uint32 unsigned long int
#endif

#ifndef uint64
#define uint64 unsigned long long int
#endif


class fast_ntt                                   // number theoretic transform
{
    public:
    fast_ntt()
    {
        r = 0; L = 0;
        W = 0; iW = 0; rN = 0;
    }
    // main interface
    void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])
    void INTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast INTT(uint32 src[n])
    // helper functions

    private:
    bool init(uint32 n);                                     // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])

    void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
    // only for testing
    void  NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow  NTT(uint32 src[n])
    void INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = slow INTT(uint32 src[n])
    // uint32 arithmetics


    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N;//, p;
    uint32 W, iW, rN;

    const uint32 p = 0xC0000001;
};

//---------------------------------------------------------------------------
// Forward transform: dst[0..N) = NTT(src[0..N)).
// A non-zero n reconfigures the object for length n first; n == 0 keeps the
// parameters left by the previous configuration. The result of init() is not
// checked here: a rejected n leaves N == 0 and NTT_fast then does nothing.
void fast_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n != 0)
        init(n);

    // NTT_slow(dst, src, N, W) takes the same arguments and can be swapped in for testing.
    NTT_fast(dst, src, N, W);
}

//---------------------------------------------------------------------------
// Inverse transform: dst[0..N) = INTT(src[0..N)).
// Runs the fast transform with the inverse twiddle base iW and then scales
// every output element by rN. A non-zero n reconfigures the object first;
// n == 0 reuses the previous configuration.
void fast_ntt::INTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n != 0)
        init(n);

    // INTT_slow(dst, src, N, W) can be swapped in for testing; it applies the rN scaling itself.
    NTT_fast(dst, src, N, iW);

    uint32 k = 0;
    while (k < N)
    {
        dst[k] = modmul(dst[k], rN);
        ++k;
    }
}

//---------------------------------------------------------------------------
// Configure the transform for vectors of n elements.
//
// Accepts only n that is a power of two with 2 <= n <= 0x10000000. NTT_calc
// halves the length recursively and, for an odd length, its odd-element copy
// loop reads src[n]; L = 0x30000000 / n is also only exact for such n.
//
// On success sets
//   r  = 2
//   L  = 0x30000000 / n
//   N  = n
//   W  = r^L       mod p   (twiddle base used by NTT)
//   iW = r^(p-1-L) mod p   (twiddle base used by INTT)
//   rN = n^(p-2)   mod p   (scale factor applied by INTT)
// and returns true. On rejection zeroes r, L, W, iW, rN and N and returns
// false, so a following NTT_fast call with N == 0 does nothing.
//
// Original author's note: (max(src[])^2)*n < p, else NTT overflow can occur.
//
// Alternative parameter sets listed by the original author (each would also
// require changing the class constant p):
//   p=0x78000001, n<=0x04000000, L=0x3c000000/n   31:27 bit, best for signed 32 bit
//   p=0x00010001, n<=0x00000020, L=0x00000020/n   17:16 bit, best for 16 bit
//   p=0x0a000001, n<=0x01000000, L=0x01000000/n   28:25 bit
bool fast_ntt::init(uint32 n)
{
    const bool sizeInRange = (n >= 2) && (n <= 0x10000000);
    const bool powerOfTwo  = (n & (n - 1)) == 0;

    if (!sizeInRange || !powerOfTwo)
    {
        r = 0; L = 0; W = 0;
        iW = 0; rN = 0; N = 0;
        return false;
    }

    r = 2;
    L = 0x30000000 / n;            // 32:30 bit, best for unsigned 32 bit
    N = n;                         // size of vectors [uint32s]
    W = modpow(r, L);              // Wn for NTT
    iW = modpow(r, p - 1 - L);     // Wn for INTT
    rN = modpow(n, p - 2);         // scale for INTT
    return true;
}

//---------------------------------------------------------------------------

// Transform n elements of src into dst using twiddle base w.
// n == 0 does nothing and n == 1 copies the single element. For larger n the
// work is done by NTT_calc, which needs distinct buffers and overwrites src
// as scratch; when dst and src are the same array, the output is produced in
// a temporary buffer and copied back.
void fast_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n == 0)
    {
        return;
    }
    if (n == 1)
    {
        dst[0] = src[0];
        return;
    }

    if (dst == src)
    {
        // In-place request: give NTT_calc a separate output buffer.
        uint32 *scratch = new uint32[n];
        NTT_calc(scratch, src, n, w);
        memcpy(dst, scratch, n * sizeof(uint32));
        delete [] scratch;
        return;
    }

    NTT_calc(dst, src, n, w);
}

// Overload for a read-only source: NTT_calc overwrites its input, so the
// source is first copied into a temporary buffer that is handed to NTT_calc.
// n == 0 does nothing and n == 1 copies the single element.
void fast_ntt::NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w)
{
    if (n == 0)
    {
        return;
    }
    if (n == 1)
    {
        dst[0] = src[0];
        return;
    }

    uint32 *scratch = new uint32[n];
    memcpy(scratch, src, n * sizeof(uint32));
    NTT_calc(dst, scratch, n, w);
    delete[] scratch;
}



// Recursive transform step: dst[0..n) = transform of src[0..n) with twiddle base w.
// dst and src must be different buffers; src is overwritten (it serves as the
// output of the two half-size sub-transforms). Does nothing for n <= 1.
void fast_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n <= 1)
    {
        return;
    }

    const uint32 half = n >> 1;
    const uint32 wSquared = modmul(w, w);   // twiddle base for the half-size transforms

    // Reorder: even-indexed inputs into dst[0..half), odd-indexed into dst[half..n).
    uint32 out = 0;
    for (uint32 in = 0; out < half; ++out, in += 2)
    {
        dst[out] = src[in];
    }
    for (uint32 in = 1; out < n; ++out, in += 2)
    {
        dst[out] = src[in];
    }

    // Transform both halves from dst back into src.
    if (half > 1)
    {
        NTT_calc(src, dst, half, wSquared);                 // even
        NTT_calc(src + half, dst + half, half, wSquared);   // odd
    }
    else
    {
        // half == 1: a one-element transform is the element itself.
        src[0] = dst[0];
        src[1] = dst[1];
    }

    // Combine. The first butterfly has twiddle 1, so the odd term is used
    // without a multiplication.
    {
        const uint32 even = src[0];
        const uint32 odd = src[half];
        dst[0] = modadd(even, odd);
        dst[half] = modsub(even, odd);
    }

    // Remaining butterflies: the twiddle advances by a factor of w each step.
    uint32 twiddle = 1;
    for (uint32 lo = 1, hi = half + 1; lo < half; ++lo, ++hi)
    {
        twiddle = modmul(twiddle, w);
        const uint32 even = src[lo];
        const uint32 odd = modmul(src[hi], twiddle);
        dst[lo] = modadd(even, odd);
        dst[hi] = modsub(even, odd);
    }
}

//---------------------------------------------------------------------------
// Reference transform (testing only): evaluates each output directly as
//   dst[j] = sum over i of src[i] * (w^j)^i   (mod p)
// using n*n modular multiplications.
void fast_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 rowBase = 1;                       // w^out
    for (uint32 out = 0; out < n; ++out)
    {
        uint32 acc = 0;
        uint32 power = 1;                     // rowBase^in
        for (uint32 in = 0; in < n; ++in)
        {
            acc = modadd(acc, modmul(power, src[in]));
            power = modmul(power, rowBase);
        }
        dst[out] = acc;
        rowBase = modmul(rowBase, w);
    }
}

//---------------------------------------------------------------------------
// Reference inverse transform (testing only): evaluates each output directly as
//   dst[j] = rN * sum over i of src[i] * (iW^j)^i   (mod p)
// Note that the parameter w is not used; the twiddle base is always the
// member iW and the scale factor is the member rN.
void fast_ntt::INTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 rowBase = 1;                       // iW^out
    for (uint32 out = 0; out < n; ++out)
    {
        uint32 acc = 0;
        uint32 power = 1;                     // rowBase^in
        for (uint32 in = 0; in < n; ++in)
        {
            acc = modadd(acc, modmul(power, src[in]));
            power = modmul(power, rowBase);
        }
        dst[out] = modmul(acc, rN);
        rowBase = modmul(rowBase, iW);
    }
}


//---------------------------------------------------------------------------
// Returns a + b reduced modulo p.
// The raw sum may wrap around the unsigned type; a wrapped sum is detected by
// it being smaller than an operand and is corrected by subtracting p. A final
// conditional subtraction brings the value below p.
uint32 fast_ntt::modadd(uint32 a, uint32 b)
{
    uint32 sum = a + b;
    const bool wrapped = sum < a;

    if (wrapped)
    {
        sum -= p;
    }
    if (sum >= p)
    {
        sum -= p;
    }
    return sum;
}

//---------------------------------------------------------------------------
// Returns a - b modulo p.
// If the unsigned subtraction borrows (the difference ends up larger than a),
// p is added back to bring the value into range.
uint32 fast_ntt::modsub(uint32 a, uint32 b)
{
    const uint32 diff = a - b;
    return (diff > a) ? (diff + p) : diff;
}

//---------------------------------------------------------------------------
// Returns (a * b) mod p.
//
// On 32-bit x86 MSVC the original inline assembly is used: MUL leaves the
// 64-bit product in EDX:EAX, DIV by p leaves the remainder in EDX, and the
// remainder is moved to EAX, which is where the function's return value is
// taken from (there is deliberately no return statement on that path).
// DIV faults if the quotient does not fit in 32 bits, so at least one operand
// should already be below p.
//
// Everywhere else (64-bit MSVC has no inline assembly, other compilers do not
// accept this __asm syntax) the same remainder is computed from a 64-bit
// product in plain C++.
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
#if defined(_MSC_VER) && defined(_M_IX86) && !defined(__clang__)
    uint32 _a = a;
    uint32 _b = b;

    // Original
    uint32 _p = p;
    __asm
    {
        mov eax, _a;
        mul _b;
        div _p;
        mov eax, edx;
    };
#else
    return static_cast<uint32>((static_cast<uint64>(a) * static_cast<uint64>(b)) % p);
#endif
}


// Returns a^b modulo p by binary (square-and-multiply) exponentiation,
// consuming the exponent from its least significant bit upwards.
// The result starts as a when bit 0 of b is set and as 1 otherwise; note that
// this initial value is not reduced, so modpow(a, 1) returns a unchanged.
uint32 fast_ntt::modpow(uint32 a, uint32 b)
{
    uint32 base = a;
    uint32 result = ((b & 1) != 0) ? a : 1;

    for (b >>= 1; b != 0; b >>= 1)
    {
        base = modmul(base, base);          // base = a^(2^k) for the current bit k
        if ((b & 1) != 0)
        {
            result = modmul(result, base);
        }
    }
    return result;
}

New modmul

// Alternative modmul: computes a * b reduced by the constant 3221225473
// (0xC0000001, the same value as the class constant p) without a DIV
// instruction. A quotient estimate is obtained by multiplying the high half of
// the product by 2863311530 (0xAAAAAAAA); estimate * p is then subtracted from
// the product and the remainder is adjusted by conditionally adding or
// subtracting p.
// There is no return statement: the result is left in EAX at the end of the
// __asm block (MSVC inline-assembly convention).
// NOTE(review): the error bound of the estimate, and therefore that the
// corrections below always suffice, was not verified here; confirm against
// the original modmul before relying on it.
uint32 fast_ntt::modmul(uint32 a, uint32 b)
{
    // _a and _b are not referenced by the assembly below, which reads a and b directly.
    uint32 _a = a;
    uint32 _b = b;   

    __asm
    {
    // edx:eax = a * b
    mov eax, a;
    mul b;
    // ebx = low half of the product, ecx = high half
    mov ebx, eax;
    mov eax, 2863311530;
    mov ecx, edx;
    // edx:eax = high half * 0xAAAAAAAA
    mul edx;
    // edx = (edx << 1) | top bit of eax: upper 32 bits of twice that product,
    // presumably approximating (high half * 2^32) / p
    shld edx, eax, 1;
    mov eax, 3221225473;

    // edx:eax = estimate * p
    mul edx;
    // ecx:ebx -= estimate * p (64-bit subtract; mov does not change the flags in between)
    sub ebx, eax;
    mov eax, 3221225473;
    sbb ecx, edx;
    // borrow out of the 64-bit subtract: result is ebx + p (eax still holds p)
    jc addback;

            // ecx = (-ecx) & p; subtract that from ebx
            neg ecx;
            and ecx, eax;
            sub ebx, ecx;

    // ebx -= p; edx = all ones if that borrowed, else 0; eax = p or 0 accordingly
    sub ebx, eax;
    sbb edx, edx;
    and eax, edx;
            addback:
    // eax = ebx + (p or 0): the value left in EAX is the function result
    add eax, ebx;          
    };  
}

[EDIT]
Spektre, based on your feedback I changed the modadd & modsub back to their original. I also realized I made some changes to the recursive NTT function I shouldn’t have.

[EDIT2]
Removed unneeded if statements and bitwise functions.

[EDIT3]
Added new modmul inline assembly.

Leave a Comment