DFT/FFT/NTT


在Seal庫和HElib庫中都用到了NTT技術,用於加快多項式計算,而NTT又是FFT的優化,FFT又來自於DFT,現在具體學習一下這三個技術!

基礎概念

名詞區分

1、DFT:離散傅立葉變換
2、FFT:快速傅立葉變換
3、NTT:快速數論變換
4、MTT:NTT的擴展
5、多項式卷積:多項式乘法
6、根據多項式的系數表示法求點值表示法的過程叫做“求值”;根據點值表示法求系數表示法的過程稱為“插值”
7、求一個多項式的乘法,即求卷積,先通過傅立葉變換對系數表示法的多項式進行求值運算,其復雜度\(O(nlog^n)\),然后在\(O(n)\)的時間內點值相乘,在進行插值運算。

8、如果選取單位復根作為求值點,則可以對系數向量進行離散傅立葉變換(DFT),得到相應的點值表示;同樣可以通過對點值進行逆DFT運算,獲得相應的系數向量。DFT和逆DFT時間復雜度均為\(O(nlog^n)\)

復數

定義

我們知道,一個復數可以這樣表示:\(a+bi\),a和b是實數,其中\(i\)叫做虛數單位,復數域是目前已知最大的域。
在復平面中,x軸代表實數,y軸(除原點外的點)代表虛數,從原點(0,0)到(a,b)的向量表示復數\(a+bi\)

模長:從原點(0,0)到(a,b)的距離,即\(\sqrt{a^2+b^2}\)
幅角:假設以逆時針為正方向,從x軸正半軸到已知向量的轉角的有向角叫做幅角

運算

1、加法
在復數平面,復數可以表示為向量,因為復數的加法和向量的加法相同。
2、乘法
幾何定義:復數相乘密,模長相乘,幅角相加
代數定義:

\[(a+bi)*(c+di)=ac+adi+bci+bdi^2+ac+adi+bci-bd=(ac-bd)+(bc+ad)i \]

單位根

在復數平面上,以原點為圓心,1為半徑做圓,所得的圓叫做單位圓,以圓點為起點,圓的n等分為終點,做n個向量,設幅角為正且最小的向量對應的復數為w_n$,稱為n次單位根。

根據復數乘法的運算法則,其余n-1個復數為\(w_n^2,...,w_n^n\),注意\(w_n^0=w_n^n=1\)(對應復平面上以x軸為正方向的向量)

如何計算呢?
由歐拉公式解決\(w_n^k=cos(k*2\pi/n)+isin(k*2\pi/n)\)
例如:向量AB表示的是復數為4次單位根

n次單位根的幅角為周角的\(1/n\)

在代數中,若\(z_n=1\),我們把z稱為n次單位根。具體請參考:n次單位根(n-th unit root)

單位根的性質

1、\(w_n^k=cos(k*2\pi/n)+isin(k*2\pi/n)\)

2、【相消引理】\(w_{dn}^{dk}=w_n^k\)
證明:以d=2為例

3、【折半引理】\(w_n^{k+n/2}=-w_n^k\)
證明:

4、\(w_n^0=w_n^k=1\)

5、\(w_n^{n-i} = 共軛(w_n^i)\)

6、\(w_n^{n+i}=w_n^i\)

多項式系數表示法

\(A(x)\)表示一個d次多項式,則\(A(x)=a_1+a_2*x+...,+a_{d}*x^{d}\)
利用這種方法計算多項式卷積復雜度為\(O(d^2)\),其實就是直接對應相乘(暴力)。
例如:\(A(x)=1+2x+x^2\), \(B(x)=1-2x+x^2\)

\[A(x)B(x)=(+2x+x^2)(1-2x+x^2)=1-2x^2+x^4 \]

多項式點值表示法

將n個值x帶入多項式,會得到d各不同的值y,則該多項式被這n個點值\((x_1,y_1),...,(x_d,y_x)\)唯一確定,其中\(\sum_{j=1}^{d}a_j*x^j_i\)
而利用點值法計算多項式卷積復雜度也為\(O(d^2)\)。(選點\(O(d)\),每次計算\(O(d)\)
例如上面的多項式用點值法表示:\(A(x)=[(-2,1),(-1,0),(0,1),(1,4),(2,9)],B(x)=[(-2,9),(-1,4),(0,1),(1,0),(2,1)]\),則

\[C(x)=[(-2,9),(-1,0),(0,1),(1,0),(2,9)] \]

即有這個5個點就可以唯一確定一個4次多項式,而兩兩相乘的復雜度為\(O(d)\)

引理1:\(( d + 1 )\)個點值可以唯一確定一個d 階多項式

因此,我們可以將一個系數多項式轉換為一個點值多項式,然而進行復雜度為\(O(d)\)的乘法,再將結果的點值多項式恢復回系數多項式。

但是:如果我們采用下面這種矩陣形式計算點值的話【選點】,那么由系數轉為點值的復雜度也為\(O(d^2)\)

接下來考慮對其優化:
1、對於系數表示法,每個點的系數都固定,優化困難
2、對於點值表示法,可以用FFT來解決!

DFT

已知\(A(x)\)的系數為\((a_0,a_1,...,a_{n-1})\),對於\(k=0,1,...,n-1\),定義:

\[y_k=A(w_n^k)=\sum_{i=0}^{n-1}a_iw_n^{ki} \]

其中向量\(y=(y_0,y_1,...,y_{n-1})\)是系數向量\(a=(a_0,a_1,...,a_{n-1})\)的離散傅立葉變換,記\(y=DFT_n(a)\),復雜度為\(O(n^2)\)
而使用下面的FFT方法,可以在\(O(nlog^n)\)時間內求出\(DFT_n(a)\)

FFT

用於加速系數多項式到點值多項式的運算!
首先觀察下面多項式:

例如:\(F(x)=x^2\),有對稱性\(F(-x)=F(x)\),相當於確定了一個點相當於確定兩個點。
同理又如\(F(x^3)\),有性質\(F(-x)=-F(x)\),也是確定了一個點相當於確定了兩個點。

所以對於有奇偶行的多項式,只需要找到原本一半的點就可以得到這個多現實了。
基於以上想法,假如有下面多項式:

\(P_e\)\(P_o\)分別看作兩個多項式,也就是對於一個點\(x_i\),我們只要計算出\(P_e(x_i^2)\)\(P_o(x_i^2)\),就可以得到\(P(x_i)\)\(P(-x_i)\),而且\(P_e\)\(P_o\)還可以進一步拆分為奇偶兩部分!

假設原本我們需要n個點\(【\pm x_1,\pm x_12,...,\pm x_{n/2}】\)就能確定一個\(n-1\)階的多項式。現在變成了求\(P_e(x)\)\(P_o(x)\)\(x_1^2,x_2^2,...,x_{n/2}^2\)上面的點值【n/2個點】。
那如果這n/2個點兩兩之間滿足\(x_i^2=-x_j^2\),則就可以進一步拆分為一半了,就可以將原本的復雜度\(O(d^2)\)降為\(O(dlog^d)\)。這里可以看出FFT用到了分治思想。

問題是,$x_12,x_22,...,x_{n/2}^2並不滿足兩兩互為相反數。由此使用n次單位根,選用n個n次單文根

\([w^0,w^1,...,w^{n-1}]\)

這樣,兩個點平方后依舊互為相反數!

可以看出,將以一個n個點的求值問題轉換為求n/2個點,在轉換為求n/4個點,以此迭代,從而達到\(O(dlog^d)\)。將上述思想轉換為為代碼如下:

FFT的逆

如何從點值多項式變為系數多項式呢?

對於點值計算,

實際上就是一個矩陣的乘法:

將點換為n個n次單位根,則矩陣變為:

其中中間的范德蒙德矩陣就成了一個DFT矩陣。
有了正向(系數到點值)的矩陣變換,求逆向(點值到系數)就是對上面矩陣求逆即可:

即:

從上面可以看出,FFT是將\(w\)作為點值傳入,IFFT就是將\({1/n}*w^{-1}\)作為點值傳入:

程序

下面程序用FFT計算兩個大數乘
題目:http://acm.hdu.edu.cn/showproblem.php?pid=1402

#include <iostream>
#include <string.h>
#include <stdio.h>
#include <math.h>

using namespace std;
const int N = 500005;
const double PI = acos(-1.0);

struct Virt
{
    double r, i;

    Virt(double r = 0.0,double i = 0.0)
    {
        this->r = r;
        this->i = i;
    }

    Virt operator + (const Virt &x)
    {
        return Virt(r + x.r, i + x.i);
    }

    Virt operator - (const Virt &x)
    {
        return Virt(r - x.r, i - x.i);
    }

    Virt operator * (const Virt &x)
    {
        return Virt(r * x.r - i * x.i, i * x.r + r * x.i);
    }
};

//雷德算法--倒位序
void Rader(Virt F[], int len)
{
    int j = len >> 1;
    for(int i=1; i<len-1; i++)
    {
        if(i < j) swap(F[i], F[j]);
        int k = len >> 1;
        while(j >= k)
        {
            j -= k;
            k >>= 1;
        }
        if(j < k) j += k;
    }
}

//FFT實現
void FFT(Virt F[], int len, int on)
{
    Rader(F, len);
    for(int h=2; h<=len; h<<=1) //分治后計算長度為h的DFT
    {
        Virt wn(cos(-on*2*PI/h), sin(-on*2*PI/h));  //單位復根e^(2*PI/m)用歐拉公式展開
        for(int j=0; j<len; j+=h)
        {
            Virt w(1,0);            //旋轉因子
            for(int k=j; k<j+h/2; k++)
            {
                Virt u = F[k];
                Virt t = w * F[k + h / 2];
                F[k] = u + t;     //蝴蝶合並操作
                F[k + h / 2] = u - t;
                w = w * wn;      //更新旋轉因子
            }
        }
    }
    if(on == -1)
        for(int i=0; i<len; i++)
            F[i].r /= len;
}

//求卷積
void Conv(Virt a[],Virt b[],int len)
{
    FFT(a,len,1);
    FFT(b,len,1);
    for(int i=0; i<len; i++)
        a[i] = a[i]*b[i];
    FFT(a,len,-1);
}

char str1[N],str2[N];
Virt va[N],vb[N];
int result[N];
int len;

void Init(char str1[],char str2[])
{
    int len1 = strlen(str1);
    int len2 = strlen(str2);
    len = 1;
    while(len < 2*len1 || len < 2*len2) len <<= 1;

    int i;
    for(i=0; i<len1; i++)
    {
        va[i].r = str1[len1-i-1] - '0';
        va[i].i = 0.0;
    }
    while(i < len)
    {
        va[i].r = va[i].i = 0.0;
        i++;
    }
    for(i=0; i<len2; i++)
    {
        vb[i].r = str2[len2-i-1] - '0';
        vb[i].i = 0.0;
    }
    while(i < len)
    {
        vb[i].r = vb[i].i = 0.0;
        i++;
    }
}

void Work()
{
    Conv(va,vb,len);
    for(int i=0; i<len; i++)
        result[i] = va[i].r+0.5;
}

void Export()
{
    for(int i=0; i<len; i++)
    {
        result[i+1] += result[i]/10;
        result[i] %= 10;
    }
    int high = 0;
    for(int i=len-1; i>=0; i--)
    {
        if(result[i])
        {
            high = i;
            break;
        }
    }
    for(int i=high; i>=0; i--)
        printf("%d",result[i]);
    puts("");
}

int main()
{
    while(~scanf("%s%s",str1,str2))
    {
        Init(str1,str2);
        Work();
        Export();
    }
    return 0;
}

NTT

在FFT中,我們需要用到復數,復數雖然很神奇,但是它也有自己的局限性——需要用double類型計算,精度太低,那有沒有什么東西能夠代替復數且解決精度問題呢?
這個東西,叫原根

若a,p互素,且p>1,對於\(a^n mod p =1\)滿足最小的n,叫做a模p的階,記\(\delta _p(a)\).
例如:

\[\delta _7(2)=3 \]

其中:
\(2^1 mod 7 =2\)
\(2^2 mod 7 =4\)
\(2^3 mod 7 =1\)

原根

設p是正整數,a是整數,若\(\delta _p(a)\)等於\(\phi(p)\),則a為模p的一個原根。

例如:
\(\delta _7(3)=6=\phi (7)\),所以3是模7的一個原根。

原根的個數不唯一

1、若模數p有原根,那么它一定有\(\phi(\phi(p))個原根\)
2、若p為素數,原根一定存在,假設g是P的一個原根,那么\(g^i mod p (1<g<p,0<i<p)\)的結果兩兩不同
簡單的說,就是

\[g^i mod p \neq g^j mod p,(1< i \neq j <p-1) \]

3、那如何求一個質數的原根呢?
對於指數p,\(p_i\)是p-1的因子,若\(g^{{p-1}/p_i} (mod p)\)恆成立,則g是p的原根。

下面就是為什么原根可以代替單位根計算?
因為原根具有和單位根相同的性質,FFT中,用到了單位根的四條性質,原根也滿足這四條性質:

最終可以得到:

\[w_n=g^{{p-1}/n} mod p \]

然后只需將FFT中的\(w_n\)替換掉,就是NTT。即:

綜上,NTT的變換為:

這里P是素數且N必須是P-1的因子;由於N是2的方冪,所以可構造\(P=c.2^k+1\)的素數。
通常p取998244353,它的原根為3。

程序

使用NTT,計算兩個大數乘

#include <iostream>
#include <string.h>
#include <stdio.h>
#include <ctime>
using namespace std;
typedef long long LL;

const int N = 1 << 18;
const int P = (479 << 21) + 1;
const int G = 3;
const int NUM = 20;

LL  wn[NUM];
LL  a[N], b[N];
char A[N], B[N];

LL quick_mod(LL a, LL b, LL m)
{
    LL ans = 1;
    a %= m;
    while(b)
    {
        if(b & 1)
        {
            ans = ans * a % m;
            b--;
        }
        b >>= 1;
        a = a * a % m;
    }
    return ans;
}

void GetWn()
{
    for(int i = 0; i < NUM; i++)
    {
        int t = 1 << i;
        wn[i] = quick_mod(G, (P - 1) / t, P);
    }
}

void Prepare(char A[], char B[], LL a[], LL b[], int &len)
{
    len = 1;
    int L1 = strlen(A);
    int L2 = strlen(B);
    while(len <= 2 * L1 || len <= 2 * L2) len <<= 1;
    for(int i = 0; i < len; i++)
    {
        if(i < L1) a[i] = A[L1 - i - 1] - '0';
        else a[i] = 0;
        if(i < L2) b[i] = B[L2 - i - 1] - '0';
        else b[i] = 0;
    }
}

void Rader(LL a[], int len)
{
    int j = len >> 1;
    for(int i = 1; i < len - 1; i++)
    {
        if(i < j) swap(a[i], a[j]);
        int k = len >> 1;
        while(j >= k)
        {
            j -= k;
            k >>= 1;
        }
        if(j < k) j += k;
    }
}

void NTT(LL a[], int len, int on)
{
    Rader(a, len);
    int id = 0;
    for(int h = 2; h <= len; h <<= 1)
    {
        id++;
        for(int j = 0; j < len; j += h)
        {
            LL w = 1;
            for(int k = j; k < j + h / 2; k++)
            {
                LL u = a[k] % P;
                LL t = w * a[k + h / 2] % P;
                a[k] = (u + t) % P;
                a[k + h / 2] = (u - t + P) % P;
                w = w * wn[id] % P;
            }
        }
    }
    if(on == -1)
    {
        for(int i = 1; i < len / 2; i++)
            swap(a[i], a[len - i]);
        LL inv = quick_mod(len, P - 2, P);
        for(int i = 0; i < len; i++)
            a[i] = a[i] * inv % P;
    }
}

void Conv(LL a[], LL b[], int n)
{
    NTT(a, n, 1);
    NTT(b, n, 1);
    for(int i = 0; i < n; i++)
        a[i] = a[i] * b[i] % P;
    NTT(a, n, -1);
}

void Transfer(LL a[], int n)
{
    int t = 0;
    for(int i = 0; i < n; i++)
    {
        a[i] += t;
        if(a[i] > 9)
        {
            t = a[i] / 10;
            a[i] %= 10;
        }
        else t = 0;
    }
}

void Print(LL a[], int n)
{
    bool flag = 1;
    for(int i = n - 1; i >= 0; i--)
    {
        if(a[i] != 0 && flag)
        {
            //使用putchar()速度快很多
            putchar(a[i] + '0');
            flag = 0;
        }
        else if(!flag)
            putchar(a[i] + '0');
    }
    puts("");
}


int main()
{
    GetWn();
//71992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625719926526229571992652622956228951755133337628984509631023267784474408034762818860068001094499324633747061026363216478453851391076257199265262295719926526229562289517551333376289845096310232677844744080347628188600680010944993246337470610263632164784538513910762571992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625719926526229571992652622956228951755133337628984509631023267784474408034762818860068001094499324633747061026363216478453851391076257199265262295719926526229562289517551333376289845096310232677844744080347628188600680010944993246337470610263632164784538513910762571992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625
    while(scanf("%s %s", A, B) != EOF)
    {
        int len;
        clock_t start_time = clock();//計時開始
        Prepare(A, B, a, b, len);
        Conv(a, b, len);
        Transfer(a, len);
        cout << "elapsed time:" << 1000*double(clock() - start_time) / CLOCKS_PER_SEC
             << 'ms' << endl;
        Print(a, len);
    }
    return 0;
}

輸出:elapsed time:3.9328019

MTT

待學習!

參考

1、快速傅里葉變換(FFT)詳解
2、快速數論變換(NTT)小結
3、CKKS的Encoding(CKKS方案的編碼部分的筆記)

4、多項式乘法運算終極版
5、多項式乘法運算初級版


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM