[FFT/NTT/MTT]總結


最近重新學了下卷積,簡單總結一下,不涉及細節內容:

 

1、FFT

朴素求法:$Coefficient-O(n^2)-CoefficientResult$

FFT:$Coefficient-O(nlogn)-Dot-O(n)-DotResult-O(nlogn)-CoefficientResult$

其中系數到點值的轉化稱為$DFT(離散傅里葉變換)$,而點值到系數的轉為稱為$IDFT(傅里葉逆變換)$

 

原本朴素的直接帶入$n$個值的$DFT$和直接使用拉格朗日插值公式的$IDFT$的復雜度仍為$O(n^2)$

但$FFT$通過帶入特定的值:單位根,使得兩者都能迭代/分治得解決,將復雜度降到了$O(nlogn)$

優化的技巧和注意事項:

1、預處理$w[i]$

2、求出最終數組從后往前迭代省去遞歸常數

3、數組長度要先擴成2的倍數用於分治

 

模板:

#include <bits/stdc++.h>

using namespace std;
#define X first
#define Y second
#define pb push_back
typedef double db;
typedef long long ll;
typedef pair<int,int> P;
const int MAXN=3e6+10;
struct Complex
{
    db x,y;
    Complex(db a=0,db b=0){x=a;y=b;}
    Complex operator + (const Complex& rhs)
    {return Complex(x+rhs.x,y+rhs.y);}
    Complex operator - (const Complex& rhs)
    {return Complex(x-rhs.x,y-rhs.y);}
    Complex operator * (const Complex& rhs)
    {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);}
}a[MAXN],b[MAXN];
int n,m,lmt=1,dgt,par[MAXN];

void FFT(Complex *a,int flag)
{
    for(int i=0;i<lmt;i++)
        if(i<par[i]) swap(a[i],a[par[i]]);
    
    for(int len=1;len<lmt;len<<=1)
    {
        Complex unit(cos(M_PI/len),flag*sin(M_PI/len));
        for(int st=0;st<lmt;st+=(len<<1))
        {
            Complex w(1,0);
            for(int k=st;k<st+len;k++,w=w*unit)
            {
                Complex A=a[k],B=w*a[k+len];
                a[k]=A+B;a[k+len]=A-B;
            }
        }
    }
    if(flag==-1)
        for(int i=0;i<=n+m;i++)
            a[i].x=floor(a[i].x/lmt+0.5);
}

int main()
{    
    scanf("%d%d",&n,&m);
    for(int i=0;i<=n;i++) scanf("%lf",&a[i].x);
    for(int i=0;i<=m;i++) scanf("%lf",&b[i].x);
    while(lmt<=n+m) lmt<<=1,dgt++;
    for(int i=0;i<lmt;i++)
        par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));    
    
    FFT(a,1);FFT(b,1);
    for(int i=0;i<lmt;i++) 
        a[i]=a[i]*b[i];
    FFT(a,-1);
    for(int i=0;i<=n+m;i++) 
        printf("%d ",(int)a[i].x);
    return 0;
}
FFT

 

2、NTT

單位根由於涉及了復數的運算,導致對精度要求高時會出錯

而$NTT$就能使得整個$FFT$都能在模意義下計算,從而滿足精度要求

 

考慮$FFT$引入單位根$w_n^k$是為了其什么性質來分治計算:

1、$w_n^k$互不相同,保證點值表示的合法

2、$w_{t*n}^{t*k}=w_n^k$且$w_n^{k+2/n}=-w_n^k$,使得計算可分治

3、$\sum_{i=0}^{n-1} {w_n^k}^i=n*[k==0]$,保證逆矩陣構造的正確性

在模意義下引入質數$p=kn+1$,其原根$g$滿足$g_t(t\in [0,p-1])$互不相同

這樣令$p$的$k$次單位根為$g^{\frac{p-1}{k}}$,易證上述$w_n^k$的性質其在模意義下均滿足

 

接下來考慮該怎樣選擇質數$p$

為了能夠分治時允許$k$每次乘2,$p-1$的質因數分解中要有很多的2

令$p=r*2^k+1$,其能處理的數據規模為$[0,2^k]$,常用質數有:傳送門

 

這樣,我們就在模意義下利用原根的性質找到了可做$FFT$的“單位根”

由於沒有了復數運算,$NTT$比$FFT$的常數也小了很多,一般是更好的選擇

 

模板:

#include <bits/stdc++.h>

using namespace std;
#define X first
#define Y second
#define pb push_back
typedef double db;
typedef long long ll;
typedef pair<int,int> P;
const int MAXN=4e6+10,MOD=998244353;
ll n,m,a[MAXN],b[MAXN],dgt,lmt=1,par[MAXN];

ll quick_pow(ll a,ll b)
{
    ll ret=1;
    for(;b;b>>=1,a=a*a%MOD)
        if(b&1) ret=ret*a%MOD;
    return ret;
}
void FFT(ll *a,int flag)
{
    for(int i=0;i<lmt;i++)
        if(i<par[i]) swap(a[i],a[par[i]]);
    for(int len=1;len<lmt;len<<=1)
    {
        ll unit=quick_pow(3,(MOD-1)/(len<<1));
        if(flag==-1) unit=quick_pow(unit,MOD-2);
        for(int st=0;st<lmt;st+=(len<<1))
        {
            ll w=1;
            for(int k=st;k<st+len;k++,w=w*unit%MOD)
            {
                ll A=a[k],B=w*a[k+len]%MOD;
                a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD;
            }
        }
    }
}

int main()
{
    scanf("%lld%lld",&n,&m);
    for(int i=0;i<=n;i++) scanf("%lld",&a[i]);
    for(int i=0;i<=m;i++) scanf("%lld",&b[i]);
    while(lmt<=n+m) lmt<<=1,dgt++;
    for(int i=0;i<lmt;i++)
        par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
    
    FFT(a,1);FFT(b,1);
    for(int i=0;i<lmt;i++)
        (a[i]*=b[i])%=MOD;
    FFT(a,-1);
    ll inv=quick_pow(lmt,MOD-2);
    for(int i=0;i<=n+m;i++)
        printf("%lld ",a[i]*inv%MOD);
    return 0;
}
NTT

 

3、MTT

如果答案需要取模且模數非質數該如何處理呢?

常見背景為:多項式長度$1e5$,模數$1e9$非質數,此時$FFT$爆$longlong$,沒法用$NTT$

 

(1)三模數$NTT$

根據上方的數據限制,可發現最終答案最多為$1e23$

這樣就能用多個乘積大於$1e23$的模數分別做$NTT$最后再用$CRT$合並答案即可

一般常用:469762049,998244353,1004535809

 

可如果直接用$CRT$合並會發現模數爆$longlong$還是不好處理

此時就可以先合並前兩個式子,得到

$res=k(mod(p_1*p_2)),res=a_3(mod(p_3))$

這樣設$res=p_1*p_2*c+k$再帶入二式就能得到$c=(a_3-k)*(p_1*p_2)^{-1}(mod(p_3))$

這樣類似$exCRT$的分步處理就避開了對$p_1*p_2*p_3$的取模

但這樣要進行9次$DFT/IDFT$,常數巨大無比

 

模板:

#include <bits/stdc++.h>

using namespace std;
#define X first
#define Y second
#define pb push_back
typedef double db;
typedef long long ll;
typedef pair<int,int> P;
const int MAXN=4e5+10;
ll p[]={469762049,998244353,1004535809};
int n,m,MOD,F[MAXN],G[MAXN],dgt,lmt=1;
ll a[3][MAXN],b[MAXN],res[MAXN],par[MAXN];

ll quickpow(ll a,ll b,ll MOD)
{
    a%=MOD;ll ret=1;
    for(;b;b>>=1,a=a*a%MOD)
        if(b&1) ret=ret*a%MOD;
    return ret;
}
ll mul(ll a,ll b,ll MOD)
{
    a=(a%MOD+MOD)%MOD;
    b=(b%MOD+MOD)%MOD;ll ret=0;
    for(;b;b>>=1,a=(a+a)%MOD)
        if(b&1) (ret+=a)%=MOD;
    return ret;
}
ll inv(ll a,ll MOD)
{return quickpow(a,MOD-2,MOD);}
void FFT(ll *a,int flag,ll MOD)
{
    for(int i=0;i<lmt;i++)
        if(i<par[i]) swap(a[i],a[par[i]]);
    for(int len=1;len<lmt;len<<=1)
    {
        ll unit=quickpow(3,(MOD-1)/(len<<1),MOD);
        if(flag==-1) unit=inv(unit,MOD);
        for(int st=0;st<lmt;st+=(len<<1))
        {
            ll w=1;
            for(int k=st;k<st+len;k++,w=w*unit%MOD)
            {
                ll A=a[k],B=w*a[k+len]%MOD;
                a[k]=(A+B)%MOD;a[k+len]=(A-B+MOD)%MOD;
            }
        }
    }
    if(flag==-1)
    {
        ll INV=inv(lmt,MOD);
        for(int i=0;i<lmt;i++)
            a[i]=a[i]*INV%MOD;
    }
}
void solve(ll *a,ll *b,ll MOD)
{
    for(int i=0;i<=n;i++) a[i]=F[i];
    for(int i=0;i<=m;i++) b[i]=G[i];
    for(int i=m+1;i<lmt;i++) b[i]=0;
    FFT(a,1,MOD);FFT(b,1,MOD);
    for(int i=0;i<lmt;i++) a[i]=a[i]*b[i]%MOD;
    FFT(a,-1,MOD);
}

int main()
{
    scanf("%d%d%d",&n,&m,&MOD);
    for(int i=0;i<=n;i++) scanf("%d",&F[i]);
    for(int i=0;i<=m;i++) scanf("%d",&G[i]);
    while(lmt<=n+m) lmt<<=1,dgt++;
    for(int i=0;i<lmt;i++) par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
    
    for(int i=0;i<3;i++) solve(a[i],b,p[i]);
    for(int i=0;i<=n+m;i++)
    {
        ll M=p[0]*p[1];
        ll A=(mul(a[0][i]*p[1],inv(p[1],p[0]),M)+
             mul(a[1][i]*p[0],inv(p[0],p[1]),M))%M;
        ll K=mul(a[2][i]-A,inv(M,p[2]),p[2]);
        res[i]=(mul(K,M,MOD)+A%MOD)%MOD;
    }
    for(int i=0;i<=n+m;i++)
        printf("%lld ",res[i]);
    return 0;
}
三模數NTT

 

(2)拆系數$FFT$

 不能用$FFT$僅僅因為最后答案會爆$longlong$,那么可以將原數拆分后分別計算

$A_i=a_i* \sqrt{P}+b_i,B_i=c_i* \sqrt{P}+d_i$

此時$A*B=P*(a*c)+\sqrt{P}*(a*d+b*c)+(b*d)$,每部分最大值為$1e14$,分別$DFT/IDFT$

這樣要做7次$DFT/IDFT$,效率未顯著提升

 

$myy$在論文里提到過對FFT的優化:

設$P_j=A_j+i*B_j,Q_j=A_j-i*B_j$,使得$DFT$前虛部不再為空

可推出$DFT$后的$DP,DQ$數組如下結論:

$DP_k=\sum_{j=0}^{lmt-1} (A_j+i*B_j)*w_{lmt}^{j*k},DQ_k=conj(DP_{lmt-k})$

 

這樣就能用1次對$P$的$DFT$算出$P,Q,A,B$的$DFT$,從而將上面的4次$DFT$化為2次

由於$IDFT$就能看成$DFT$的逆過程

因此可以合並算出$IDFT(DFT[a]*DFT[c]+i*DFT[b]*DFT[d])$,從而將$IDFT$也化為2次

這樣的常數經測試是三模數$NTT$的1/7左右

 

模板:

#include <bits/stdc++.h>

using namespace std;
#define X first
#define Y second
#define pb push_back
typedef double db;
typedef long long ll;
typedef pair<int,int> P;
const int MAXN=1e6+10;
struct Complex
{
    db x,y;
    Complex(db a=0,db b=0){x=a;y=b;}
    Complex operator +(const Complex& rhs)
    {return Complex(x+rhs.x,y+rhs.y);}
    Complex operator -(const Complex& rhs)
    {return Complex(x-rhs.x,y-rhs.y);}
    Complex operator *(const Complex& rhs)
    {return Complex(x*rhs.x-y*rhs.y,x*rhs.y+y*rhs.x);}
}a[MAXN],b[MAXN],w[MAXN],t1[MAXN],t2[MAXN],t3[MAXN];
int n,m,MOD,lmt=1,dgt,par[MAXN];ll x,res[MAXN];

void FFT(Complex *a,int flag)
{
    for(int i=0;i<lmt;i++)
        if(i<par[i]) swap(a[i],a[par[i]]);
    for(int len=1;len<lmt;len<<=1)
        for(int st=0;st<lmt;st+=(len<<1))
        {
            int cur=0;
            for(int k=st;k<st+len;k++)
            {
                Complex A=a[k],B=w[cur]*a[k+len];
                a[k]=A+B;a[k+len]=A-B;
                //預處理的寫法 
                cur=(cur+flag*lmt/(len<<1)+lmt)&(lmt-1);
            }
        }
    if(flag==-1)
        for(int i=0;i<lmt;i++)
            a[i].x=floor(a[i].x/lmt+0.5);
}
void solve()
{
    FFT(a,1);FFT(b,1);
    for(int i=0;i<lmt;i++)
    {
        Complex d1,d2,d3,d4;
        int j=(lmt-i)&(lmt-1);
        d1=(a[i]+Complex(a[j].x,-a[j].y))*Complex(0.5,0);
        d2=(a[i]-Complex(a[j].x,-a[j].y))*Complex(0,-0.5);
        d3=(b[i]+Complex(b[j].x,-b[j].y))*Complex(0.5,0);
        d4=(b[i]-Complex(b[j].x,-b[j].y))*Complex(0,-0.5);
        //必須先用臨時變量存,因為后面還要用 
        t1[i]=d1*d3;t2[i]=d1*d4+d2*d3;t3[i]=d2*d4;
    }
    for(int i=0;i<lmt;i++)
        //充分利用虛部空間(可看成逆過程) 
        b[i]=t2[i],a[i]=t1[i]+t3[i]*Complex(0,1);
    FFT(a,-1);FFT(b,-1);
    for(int i=0;i<lmt;i++)
    {
        ll k1=(ll)a[i].x%MOD,k2=(ll)b[i].x%MOD;
        ll k3=(ll)floor(a[i].y/lmt+0.5)%MOD;
        res[i]=((k3<<30)%MOD+(k2<<15)%MOD+k1)%MOD;
    }
}

int main()
{
    scanf("%d%d%d",&n,&m,&MOD);
    for(int i=0;i<=n;i++)
        scanf("%lld",&x),a[i]=Complex(x&32767,x>>15);
    for(int i=0;i<=m;i++)
        scanf("%lld",&x),b[i]=Complex(x&32767,x>>15);
    while(lmt<=n+m) lmt<<=1,dgt++;
    for(int i=0;i<lmt;i++)
        par[i]=(par[i>>1]>>1)|((i&1)<<(dgt-1));
    for(int i=0;i<lmt;i++)
        w[i]=Complex(cos(2*M_PI*i/lmt),sin(2*M_PI*i/lmt));
    
    solve();
    for(int i=0;i<=n+m;i++)
        printf("%lld ",res[i]);
    return 0;
}
拆系數FFT

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM