FFT/NTT/MTT學習筆記


FFT/NTT/MTT

Tags:數學

作業部落

評論地址


前言

這是網上的優秀博客
並不建議初學者看我的博客,因為我也不是很了解FFT的具體原理

一、概述

兩個多項式相乘,不用\(N^2\),通過\(FFT\)可以把復雜度優化到\(O(NlogN)\)\(NTT\)能夠取模,\(MTT\)可以對非\(NTT\)模數取模,相對來說\(FFT\)常數小些因為不要取模

二、我們來背板子(FFT)

先放一個板子(洛谷P3803 【模板】多項式乘法(FFT)

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
using namespace std;
const int MAXN=3000005;
const double pi=acos(-1); 
int N,M,r[MAXN],l;
struct Complex
{
	double rl,im;//real part / imaginary part
	Complex(){rl=im=0;}//以下是初始化的板子,雖然不懂為什么可以這樣寫
	Complex(double a,double b){rl=a,im=b;}
	Complex operator + (Complex B)
		{return Complex(rl+B.rl,im+B.im);}
	Complex operator - (Complex B)
		{return Complex(rl-B.rl,im-B.im);}
	Complex operator * (Complex B)
		{return Complex(rl*B.rl-im*B.im,rl*B.im+im*B.rl);}
}A[MAXN],B[MAXN];//對A,B兩個多項式進行乘法
void FFT(Complex *P,int op)
{
	for(int i=1;i<N;i++)//這個叫Rader排序
		/*
		  假設原來P[1...n].id=1..n
		  現在需要的序列是從1到n所對應的id分別為id[1..n],滿足r[id[i]]是升序
		  r[i]表示把i二進制上第1到l位的數反過來后的十進制數
		 */			 
		if(i<r[i]) swap(P[i],P[r[i]]);
	//接下來的這個叫做蝴蝶操作,算法導論上有一張圖較為清晰
	for(int i=1;i<N;i<<=1)//表示操作區間集的每個區間的長度
	{
		Complex W=(Complex){cos(pi/i),op*sin(pi/i)};
		for(int p=i<<1,j=0;j<N;j+=p)//表示每個區間集的最上端位置
		{
			Complex w=(Complex){1,0};//第0個單位復數根
			/*
			  轉角公式:將一個點(x,y)繞原點逆時針旋轉t后的點是(x*cost-y*sint,x*sint+y*cost)
			  用三角函數和差化積公式容易得證
			  單位復數根是把單位元分為若干等份,於是每次就要轉一定角度
			  用w=w*W實現轉角
			 */
			for(int k=0;k<i;k++,w=w*W)//每個區間的最上端位置
			{
				Complex X=P[j+k],Y=w*P[j+k+i];//j+k+i便是每個區間下端位置
				P[j+k]=X+Y;P[j+k+i]=X-Y;//所謂蝴蝶操作
			}
		}
	}
}
int main()
{
    cin>>N>>M;
	for(int i=0;i<=N;i++) cin>>A[i].rl;
	for(int i=0;i<=M;i++) cin>>B[i].rl;
	//讀入實部,便是系數
	M+=N;//最終位數
	for(N=1;N<=M;N<<=1) l++;l--;//FFT必須是2^k項才能做,這里把他補全
	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);//r是rader排序,將每個i的二進制位反過來
	FFT(A,1);FFT(B,1);//將AB化成點集形式
	//形如(w0,y0),(w1,y1)...(wn,yn)的這些點確定一條線
	for(int i=0;i<N;i++) A[i]=A[i]*B[i];//點集O(n)相乘
	FFT(A,-1);//再將點集轉化為系數表示的形式
	for(int i=0;i<=M;i++) printf("%d ",(int)(A[i].rl/N+0.5));//這時虛部都是0了
	return 0;
}

以下是預處理單位復數根的代碼
代碼長度會小些,精度也要高,建議使用這種寫法
三角函數比乘法慢

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<complex>
using namespace std;
const int MAXN=3e6+10;
const double pi=acos(-1);
int r[MAXN],N,M,l;
complex<double>A[MAXN],B[MAXN],w[MAXN];
void FFT(complex<double> *P,int op)
{
	for(int i=1;i<N;i++) if(i>r[i]) swap(P[i],P[r[i]]);
	for(int i=1;i<N;i<<=1)
		for(int p=i<<1,j=0;j<N;j+=p)
			for(int k=0;k<i;k++)
			{
				complex<double> W=w[N/i*k];W.imag()*=op;//實際要得到的是cos(pi/i*k)
				complex<double> X=P[j+k],Y=W*P[j+k+i];//QAQ這里總是忘記乘W
				P[j+k]=X+Y;P[j+k+i]=X-Y;
			}
}
int main()
{
	cin>>N>>M;
	for(int i=0;i<=N;i++) cin>>A[i].real();
	for(int i=0;i<=M;i++) cin>>B[i].real();
	M+=N;
	for(N=1;N<=M;N<<=1) l++;l--;
	for(int i=0;i<N;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
	for(int i=0;i<N;i++) w[i].real()=cos(pi/N*i),w[i]=imag()=sin(pi/N*i);
	FFT(A,1);FFT(B,1);
	for(int i=0;i<N;i++) A[i]=A[i]*B[i];
	FFT(A,-1);
	for(int i=0;i<=M;i++) printf("%d ",(int)(A[i].real()/N+0.5));
	puts("");return 0;
}

記憶方式:
循環的\(i\)枚舉當前處理的長度
\(j\)枚舉第幾組(兩組兩組進行)
\(k\)枚舉位置
於是\(j+k\)表示某組的第一小組的一個位置,\(i+j+k\)是某組第二小組與第一小組對應的位置
然后先加再減,記得乘上\(W\)

注意點:
1.最后要(int)(real()/N+0.5)
2.由於N要放大所以空間開兩倍!!

三、我們再來背板子(NTT)

還是那道題

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cmath>
using namespace std;
const int N=3000005;
const int mod=998244353;
int r[N],l,n,m,A[N],B[N],w[N];
int ksm(int a,int k)
{
    int s=1,b=a;
    for(;k;k>>=1,b=1ll*b*b%mod)
        if(k&1) s=1ll*s*b%mod;
    return s;
}
void NTT(int *P,int op)
{
    for(int i=0;i<n;i++) if(i<r[i]) swap(P[i],P[r[i]]);
    for(int i=1;i<n;i<<=1)
    {
        int W=ksm(3,(mod-1)/(i<<1));//3是998244353的一個原根
        if(op<0) W=ksm(W,mod-2);w[0]=1;
        for(int j=1;j<i;j++) w[j]=1ll*w[j-1]*W%mod;
        for(int j=0,p=i<<1;j<n;j+=p)
            for(int k=0;k<i;k++)
            {
                int X=P[j+k],Y=1ll*w[k]*P[i+j+k]%mod;
                P[j+k]=(X+Y)%mod;P[i+j+k]=((X-Y)%mod+mod)%mod;
            }
    }
}
int main()
{
    cin>>n>>m;
    for(int i=0;i<=n;i++) cin>>A[i];
    for(int i=0;i<=m;i++) cin>>B[i];
    m=n+m;for(n=1;n<=m;n<<=1) l++;l--;
    for(int i=0;i<n;i++) r[i]=(r[i>>1]>>1)|((i&1)<<l);
    NTT(A,1);NTT(B,1);
    for(int i=0;i<n;i++) A[i]=1ll*A[i]*B[i]%mod; NTT(A,-1);
    for(int i=0,inv=ksm(n,mod-2);i<n;i++) A[i]=1ll*A[i]*inv%mod;
    for(int i=0;i<=m;i++) printf("%d ",A[i]);
    return 0;
}

一個數\(k\)的原根\(x\)滿足\(x^1,x^2,x^3...x^{\phi(k)}\)各不相同且\(x^{\phi(k)}=1\)
對於且僅對於\(2,4,p,2p,p^r(p為奇質數)\)有原根存在
NTT的原根就代替了FFT中的單位復數根,要求形式是\(p=r*2^p+1\)
常用的\(NTT\)模數有\(998244353(3)\)\(1004535809(3)\)

找質數的原根

最暴力的方法是枚舉原根,然后判斷\(x^1...x^{p-1}\)是否相同
優化的話是檢查\(p-1\)的所有質因數中,是否存在一個質因子\(k\)使得\(x^{\frac{p-1}{k}}=1\),若存在,則該數不是原根,否則是原根

證明(Thanks GXY)

首先可以明確的是,若對於\(m\)屬於\([1,p-2]\),沒有\(g^m\equiv 1(mod\ p)\),則g是一個原根
因為\(g^{m1}\equiv g^{m2} \equiv k\),且\(m1>m2\),則一定有\(g^{m2-m1}\equiv 1\)
利用反證法,假設存在一個\(m\)使得\(g^m\equiv 1(mod\ p)\)

分兩種情況討論:
1.\(gcd(p-1,m)!=1\)
\(k=(p-1)/m=p_1^{a_1}p_2^{a_2}...p_i^{a_i}\)\(p_i\)為質數
\(g^{\frac{p-1}{p_i}}=g^{k*p_1^{a_1}*..*P_i^{a_i-1}}=(g^k)^{p_1^{a_1}..p_i^{a_i}}=1\),能夠通過上述方法判定出來

2.\(gcd(p-1,m)==1\)
\(g^m\equiv g^{2m}\equiv...\equiv g^{km}\equiv g^{p-1}\equiv 1(mod\ p)\)
\(km\equiv x(mod\ p-1)\),由於\(gcd(m,p-1)==1\),根據同余方程的EXGCD判斷,\(x\)可以在\([0,p-2]\)任意取值,都有符合條件的\(k\)使得式子成立
根據歐拉定理/費馬小定理得\(g^{km}\equiv g^{km\%(p-1)}\equiv g^x\equiv 1(mod\ p)\),使得所有的\(x\)屬於\([0,p-2]\)都模p余1,也會在之前的方法中判斷出來

四、有個可以講清的了(MTT)

處理任意模數\(NTT\)問題
\(M=\sqrt{mod}\)(這樣子好像復雜度最優)
然后多項式的每一項拆成\(AM+B\),於是\(A\)\(B\)都在\(int\)之內就不會爆\(double\)
所以兩個數相乘就成為了$$(A_1M+B_1)*(A_2M+B_2)=A_1A_2M^2+(A_1B_2+A_2B_1)M+B_1B_2$$分別進行\(4\)\(DFT\)\(4\)\(IDFT\)即可(一共8次,有些博客是7次,但是代碼比我長)

Code

洛谷P4245 【模板】任意模數NTT

#include<iostream>
#include<cstdio>
#include<cstring>
#include<complex>
#include<cmath>
using namespace std;
const double Pi=acos(-1);
const int N=400100;
const int M=30000;
int n,m,p,F[N],G[N];
int r[N],Ans[N],l,tt;
complex<double> A1[N],B1[N],A2[N],B2[N],A[N],w[N];
void FFT(complex<double> *P,int op)
{
    for(int i=0;i<l;i++) if(r[i]<i) swap(P[i],P[r[i]]);
    for(int i=1;i<l;i<<=1)
        for(int p=i<<1,j=0;j<l;j+=p)
            for(int k=0;k<i;k++)
            {
                complex<double> W=w[l/i*k];W.imag()*=op;
                complex<double> X=P[j+k],Y=W*P[j+k+i];
                P[j+k]=X+Y;P[j+k+i]=X-Y;
            }
}
void Work(complex<double> *P1,complex<double> *P2,int base)
{
    for(int i=0;i<l;i++) A[i]=P1[i]*P2[i];FFT(A,-1);
    for(int i=0;i<=m+n;i++) (Ans[i]+=(long long)(A[i].real()/l+0.5)%p*base%p)%=p;
}
int main()
{
    scanf("%d%d%d",&n,&m,&p);
    for(int i=0,x;i<=n;i++) scanf("%d",&x),A1[i].real()=x/M,B1[i].real()=x%M;
    for(int i=0,x;i<=m;i++) scanf("%d",&x),A2[i].real()=x/M,B2[i].real()=x%M;
    for(l=1;l<=n+m;l<<=1) tt++;tt--;
    for(int i=0;i<l;i++) r[i]=(r[i>>1]>>1)|((i&1)<<tt);
    for(int i=0;i<l;i++) w[i].real()=cos(Pi/l*i),w[i].imag()=sin(Pi/l*i);
    FFT(A1,1);FFT(A2,1);FFT(B1,1);FFT(B2,1);
    Work(A1,A2,M*M%p); Work(A1,B2,M%p);
    Work(A2,B1,M%p); Work(B1,B2,1);
    for(int i=0;i<=m+n;i++) printf("%d ",Ans[i]);
}

五、一些要點

這一部分還沒玩成,待博主把這些算法完全弄懂后再來填坑~
NTT時(X-Y+mod)%mod時,Y為負數就可能爆int,可以不加mod然后最后輸出的時候加
當乘起來不會超過mod(注意是乘后累加),那么NTT可以代替FFT,否則不行,例子見MTT
乘法通過原根變成加法再NTT
字符串匹配問題的兩種做法
組合數公式給拆成可以NTT的形式


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM