多項式全家桶學習筆記(How EI's poly works)


這里都是一些論文級別的玩意,基本不是給正常人類看的

注意:這里僅對模數是 \(998244353\) 的部分進行介紹。

零.讓我們開始

這里是一些基礎的東西,不怎么需要想,這里就略過了。

Code
#include<bits/stdc++.h>
#define endl '\n' 
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define Rep(i,a,b) for(int i=(a);i>=(b);--i)
using namespace std;
const int P=998244353,G=3,LIMIT=50;
typedef vector<int> vec;
struct IO_Tp {
    const static int _I_Buffer_Size = 2 << 22;
    char _I_Buffer[_I_Buffer_Size], *_I_pos = _I_Buffer;

    const static int _O_Buffer_Size = 2 << 22;
    char _O_Buffer[_O_Buffer_Size], *_O_pos = _O_Buffer;

    IO_Tp() { fread(_I_Buffer, 1, _I_Buffer_Size, stdin); }
    ~IO_Tp() { fwrite(_O_Buffer, 1, _O_pos - _O_Buffer, stdout); }

    IO_Tp &operator>>(int &res) {
    	int f=1;
        while (!isdigit(*_I_pos)&&(*_I_pos)!='-') ++_I_pos;
        if(*_I_pos=='-')f=-1,++_I_pos;
        res = *_I_pos++ - '0';
        while (isdigit(*_I_pos)) res = res * 10 + (*_I_pos++ - '0');
        res*=f;
        return *this;
    }

    IO_Tp &operator<<(int n) {
    	if(n<0)*_O_pos++='-',n=-n;
        static char _buf[10];
        char *_pos = _buf;
        do
            *_pos++ = '0' + n % 10;
        while (n /= 10);
        while (_pos != _buf) *_O_pos++ = *--_pos;
        return *this;
    }

    IO_Tp &operator<<(char ch) {
        *_O_pos++ = ch;
        return *this;
    }
} IO;//快讀
void chkmax(int &x,int y){if(x<y)x=y;}
void chkmin(int &x,int y){if(x>y)x=y;}
int qpow(int a,int k,int p=P){//快速冪
	int ret=1;
	while(k){
		if(k&1)ret=1ll*ret*a%p;
		a=1ll*a*a%p;
		k>>=1;
	}
	return ret;
}
int norm(int x){return x>=P?x-P:x;}
int reduce(int x){return x<0?x+P:x;}
void add(int&x,int y){if((x+=y)>=P)x-=P;}//取模
struct Maths{
	int n;
	vec fac,invfac,inv;
	void build(int n){
		this->n=n;
		fac.resize(n+1);
		invfac.resize(n+1);
		inv.resize(n+1);
		fac[0]=1;
		rep(k,1,n)fac[k]=1ll*fac[k-1]*k%P;
		inv[1]=inv[0]=1;
		rep(k,2,n)inv[k]=P-1ll*(P/k)*inv[P%k]%P;
		invfac[0]=1;
		rep(k,1,n)invfac[k]=1ll*invfac[k-1]*inv[k]%P;
	}
	Maths(){build(1);}
	void chk(int k){
		int lmt=n;
		if(k>lmt){while(k>lmt)lmt<<=1;build(lmt);}
	}
	int cfac(int k){return chk(k),fac[k];}
	int cifac(int k){return chk(k),invfac[k];}
	int cinv(int k){return chk(k),inv[k];}
	int binom(int n,int m){
		if(m<0||m>n)return 0;
		return 1ll*cfac(n)*cifac(m)%P*cifac(n-m)%P;
	}
}math;//普通數論部分
struct poly{
	vec a;
	poly(int v=0):a(1){
		if((v%=P)<0)v+=P;
		a[0]=v;
	}
	poly(const vec&a):a(a){}
	poly(initializer_list<int>init):a(init){}
	int operator[](int k)const{return k<a.size()?a[k]:0;}
	int&operator[](int k){
		if(k>=a.size())a.resize(k+1);
		return a[k];
	}
	int deg()const{return a.size()-1;}
	void redeg(int d){a.resize(d+1);}
	poly slice(int d)const{
		if(d<a.size())return vec(a.begin(),a.begin()+d+1);
		vec res(a);
		res.resize(d+1);
		return res;
	}
	int*base(){return a.data();}
	const int*base()const{return a.data();}
	poly println(FILE* fp)const{
		fprintf(fp,"%d",a[0]);
		rep(i,1,a.size()-1)fprintf(fp," %d",a[i]);
		fputc('\n',fp);
		return *this;
	}
	poly operator+(const poly&rhs)const{
		vec res(max(a.size(),rhs.a.size()));
		rep(i,0,res.size()-1)if((res[i]=operator[](i)+rhs[i])>=P)res[i]-=P;
		return res;
	}
	poly operator-()const{
		poly ret(a);
		rep(i,0,a.size()-1)if(ret[i])ret[i]=P-ret[i];
		return ret;
	}
	poly operator-(const poly&rhs)const{return operator+(-rhs);}
        /*
        這里應該有一堆屎山聲明,可是我懶得羅列了所以就沒寫
        */
    poly shift(int k)const;
};//聲明+部分簡單函數
poly zeroes(int deg){return vec(deg+1);}//0函數
poly operator "" _z(unsigned long long a){return {0,(int)a};}
poly operator+(int v,const poly&rhs){return poly(v)+rhs;}//多項式加整數
poly operator*(int v,const poly&rhs){//多項式乘整數
	poly ret=zeroes(rhs.deg());
	rep(i,0,rhs.deg())ret[i]=1ll*rhs[i]*v%P;
	return ret;
}
poly operator*(const poly&lhs,int v){return v*lhs;}
poly poly::shift(int k)const{//多項式乘 x^k
	poly g=zeroes(deg()+k);
	rep(i,0,k-1)g[i]=0;
	rep(i,min(0,-k),deg()-1)g[i+k]=a[i];
	return g;
}
template<class T>
IO_Tp& operator>>(IO_Tp& IO,vector<T>&v){//輸入 vector
	for(T&x:v)IO>>x;
	return IO;
}
template<class T>
IO_Tp& operator<<(IO_Tp& IO,vector<T>&v){//輸出 vector
	for(T&x:v)IO<<x;
	return IO;
}

一.多項式乘法

原理和普通的多項式乘法一致,沒啥好說的,我們看看哪些地方可以優化。

我們注意到可以在初始化的時候做一些預處理,這樣大概可以減少一定的常數,在 P3803 這道題上面總時間 \(1.6s\to 1s\),快了 0.6s。

補充:EI 認為 “經過測試,某些形式特殊的數組的 NTT 改良版本,看似省略了部分計算,實則緩存不友好,還不如直接做”。(雖然經過測試,某一版本的 NTT 比這一版塊 0.2s,但是碼量差不多翻了一倍(在 NTT 部分),因此不予使用。)

Code
struct NTT{
	int L,brev[1<<11];
	vec root;
	NTT():L(-1){
		rep(i,1,(1<<11)-1)brev[i]=brev[i>>1]>>1|((i&1)<<10);
	}
	void preproot(int l){
		L=l;
		root.resize(2<<L);
		rep(i,0,L){
			int *w=root.data()+(1<<i);
			w[0]=1;
			int omega=qpow(G,(P-1)>>i);
			rep(j,1,(1<<i)-1)w[j]=1ll*w[j-1]*omega%P;
		}
	}
	void dft(int*a,int lgn,int d=1){
		if(L<lgn)preproot(lgn);
		int n=1<<lgn;
		rep(i,0,n-1){
			int rev=(brev[i>>11]|(brev[i&((1<<11)-1)]<<11))>>((11<<1)-lgn);
			if(i<rev)swap(a[i],a[rev]);
		}
		for(int i=1;i<n;i<<=1){
			int *w=root.data()+(i<<1);
			for(int j=0;j<n;j+=i<<1)rep(k,0,i-1){
				int aa=1ll*w[k]*a[i+j+k]%P;
				a[i+j+k]=norm(a[j+k]+P-aa);
				add(a[j+k],aa);
			}
		}
		if(d==-1){
			reverse(a+1,a+n);
			int inv=nt.inv(n);
			rep(i,0,n-1)a[i]=1ll*a[i]*inv%P;
		}
	}
}ntt;

二.多項式乘法逆

閱讀參考資料

首先列一下牛迭的式子:假設 \(f\in \mathbb R[[x]],A\in \mathbb R[[x,y]]\),滿足 \(A(x,f)=0\),令 \(f_0=f\bmod x^n\),則

\[f\bmod x^{2n}=f_0-\frac{A(x,f_0)}{\frac{\delta A}{\delta y}(x,f_0)}\bmod x^{2n} \]

如果定義 \(\operatorname{ord}(f)=\min\{n|[x^n]f\neq 0\}\),那么我們可以觀察到 \(\operatorname{ord}(A(f_0))\ge n\),因此計算 \(A'(f_0)\) 的精度只需達到 \(n\) 即可。

下面討論操作的優化。在以下討論中,內容分為三部分:

  1. 直接按式子計算的時間。下面定義 \(E(n)\) 為一次長度為 \(n\) 的 DFT 所需的時間,\(M(n)\) 為一次兩個精度為 \(n\) 的形式冪級數的乘法所需要的時間,因此 \(M(n)=(3+o(1))E(2n)=(6+o(1))E(n)\)。在下文中,一切 \(o(1)\) 會被省略。
  2. 利用循環卷積優化。注意到在大部分情況下,我們已經得到了結果中的一部分系數,而長度為 \(n\) 的 DFT 解決了循環卷積問題 \(fg\bmod (x^n-1)\),僅用於計算卷積很浪費,所以可以考慮先計算循環卷積,必要時進行一些處理,最后得到所需的系數。
  3. 減少 DFT 次數。注意到很多時候計算了相同的 DFT,或者可以用線性變換的性質合並幾次 IDFT,考慮減少這些額外的開銷。

在這一部分,我們研究的問題是倒數。

1

\(f\in \mathbb R[[x]]\),令 \(g=1/f\),求 \(g\)

\(A(g)=fg-1\),代入牛迭式子可得:

\[g\bmod x^{2n}=2g_0-fg_0^2\bmod x^{2n} \]

這就是我們一般使用的方法。

這里一共使用了 \(1\) 次長度 \(2n\) 的乘法,\(1\) 次長度 \(4n\) 的乘法,用時 \(M(2n)+M(4n)=18E(n)\)。(其實我覺得做三次長度為 \(2n\) 的 DFT 就行了,這樣是 \(12E(n)\) 的,不知道為什么沒有這么寫)

UPD:破案了,這樣的優化方法被算在了 3 里面。

2

\(f\in \mathbb R[[x]]\),令 \(g=1/f\),求 \(g\)

考慮 \(g\bmod x^{2n}=g_0-(fg_0-1)g_0\bmod x^{2n}\),顯然 \(\deg((f\bmod x^2n)g_0-1)<3n,\operatorname{ord}((f\bmod x^2n)g_0-1)\ge n\),因此只需計算 \((f\bmod x^{2n})g_0\bmod (x^{2n}-1)\) 即可,同理 \((fg_0-1)g_0\bmod x^{2n}\) 也只需要長為 \(2n\) 的循環卷積,用時 \(12E(n)\)。所以計算 \(g\bmod x^n\) 的時間是 \(12E(n)\)

3

觀察上述過程,有兩次和 \(g_0\) 有關的長為 \(2n\) 的循環卷積,可以記錄下來而不是重新算,用時 \(10E(n)\)

EI 的代碼應該是按照這個實現的,不過把普通的遞歸換成了迭代,因此總時間 \(0.6s\to 0.2s\),優化了 \(0.4s\)

Code
struct Newton{
	void inv(const poly&f,const poly&nttf,poly&g,const poly&nttg,int t){//given f,g,nttf,nttg
		int n=1<<t;
		poly prod=nttf;
		rep(i,0,(n<<1)-1)prod[i]=1ll*prod[i]*nttg[i]%P;
		ntt.dft(prod.base(),t+1,-1);//calculate fg-1
		rep(i,0,n-1)prod[i]=0;//prod=
		ntt.dft(prod.base(),t+1,1);
		rep(i,0,(n<<1)-1)prod[i]=1ll*prod[i]*nttg[i]%P;
		ntt.dft(prod.base(),t+1,-1);//calculate (fg-1)g
		rep(i,0,n-1)prod[i]=0;
		g=g-prod;//calculate g-(fg-1)g
	}
	void inv(const poly&f,const poly&nttf,poly&g,int t){//given f,nttf,g
		poly nttg=g;
		nttg.redeg((2<<t)-1);
		ntt.dft(nttg.base(),t+1,1);//calc nttg
		inv(f,nttf,g,nttg,t);
	}
	void inv(const poly&f,poly&g,int t){//given f,g
		poly nttg=g;
		nttg.redeg((2<<t)-1);
		ntt.dft(nttg.base(),t+1,1);//calc nttg
		poly nttf=f;
		nttf.redeg((2<<t)-1);
		ntt.dft(nttf.base(),t+1,1);//calc nttf
		inv(f,nttf,g,nttg,t);
	}
}nit;
poly poly::inv()const{
	poly g=nt.inv(a[0]);
	for(int t=0;(1<<t)<=deg();++t)nit.inv(slice((2<<t)-1),g,t);
	g.redeg(deg());
	return g;
}

還可以繼續優化嗎?

如果允許長度為 \(3n\) 的 DFT,那么考慮 \(g\bmod x^{2n}=g_0-(fg_0^2-g_0)\bmod x^{2n}\),用長度為 \(3n\) 的循環卷積計算 \(fg_0^2\) 即可達到 \(9E(n)\) 的用時,可惜大部分時候(比如 \(998244353\))都不能做。

注意到我們不是必須要循環卷積才能解決問題,對 \(a\in \mathbb R,a^{2n}\neq 1\),考慮在 \(\mathbb R[x]/(x^{2n}-1)(x^n-a^n)\) 中計算卷積,即在 \(1,\zeta_{2n},\zeta_{2n}^2,\dots,\zeta_{2n}^{2n-1},a,a\zeta_n,a\zeta_n^2,\dots,a\zeta_n^{n-1}\) 上多點求值和插值。對 \(f\in \mathbb R[x]/(x^{2n}-1)(x^n-a^n)\) 進行多點求值只需用 FFT 計算 \(\mathcal F_{2n}(f)\)\(\mathcal F_{2n}(f(ax))\),而插值只需分別還原 CRT 合並。

容易發現,如果在 \(\mathbb R[x]/(x^{2n}-1)(x^n-a^n)\) 中進行卷積,仍然可以處理超出長度部分的影響,且不需要長度為 \(3n\) 的 DFT,同時也計算了 \(\mathcal F_{2n}(f\bmod x^{2n})\)\(\mathcal F_{2n}(g_0)\),所需時間仍是改進前的 \(9E(n)=\frac{3}{2}M(n)\),所以可以幾乎完全代替前一種做法。

簡單描述一下思路:為了算出所需結果 \(f\),先算出 \(f\bmod (x^{2n}−1)(x^n−a^n)\),考慮超出部分對前 \(n\) 項(本應全是 \(0\))的貢獻,利用這些信息還原出這一部分,然后即可把這一部分對所需部分的影響消除掉。

算出結果需在 \(1, \zeta_{2n}, \zeta_{2n}^2, \dots, \zeta_{2n}^{2n-1}, a, a\zeta_n, a\zeta_n^2, \dots, a\zeta_n^{n-1}\) 上多點求值和插值,多點求值即計算 \(\mathcal F_{2n}(f)\)\(\mathcal F_n(f(ax))\),插值即分別還原並 CRT 合並。

實際實現並不需要考慮這些,最后的推薦實現也可以這樣描述。考慮將所需結果 \(f\) 表示為 \(ax^n+bx^{2n}+cx^{3n}\),其中 \(a,b,c\in \mathbb R[x]\)\(\deg(a),\deg(b),\deg(c)<n\),那么可以用循環卷積計算出 \(f\bmod (x^{2n}-1)\)\(f\bmod (x^n-i)=f(\zeta_{4n}x)\bmod (x^n-1)\),也就相當於算出了 \(b,a+c,ia-b-ic\),還原出 \(a\) 即可。

三.多項式 ln

哈哈這個我熟,直接兩邊求導再積分,就可以得到 \(\ln f=\int f'/f\),然后利用多項式求逆就可以做到 \(9E(n)+6E(n)=15E(n)\) 了!

一翻 EI 的代碼:wtf 怎么這么長,怎么還有個 quo 函數???

EI:You are too naive.

翻了翻博客,發現求商數居然還有科技,學廢了。

1

商數:對於 \(f,h\in \mathbb R[[x]]\),令 \(g=1/f,q=hg=h/f\),求 \(q\)

顯然先求 \(g=1/f\),再求 \(q=hg\) 即可,總時間是 \(18E(n)+6E(n)=24E(n)\)

對數:略。

2

對於 \(f,h\in \mathbb R[[x]]\),令 \(g=1/f,q=hg=h/f\),求 \(q\)

直接求 \(g\) 然后卷積的總時間是 \(18E(n)\)

如果不求 \(g\),注意到 \(A(q)=fq-h=0\),令 \(g_0=g\bmod x^n,h_0=h\bmod x^n,q_0=q\bmod x^n=g_0h_0\bmod x^n\),有

\[q\bmod x^{2n}=q_0-(fq_0-h)g_0\bmod x^{2n} \]

計算 \(g_0\)\(12E(n)\),計算 \(q_0\)\(6E(n)\),計算 \((fq_0-h)\) 和倒數類似需要 \(12E(n)\) 的時間,因此計算 \(q\bmod x^{2n}\)\(30E(n)\),計算 \(q\bmod x^n\) 就是 \(15E(n)\)

3

對於 \(f,h\in \mathbb R[[x]]\),令 \(g=1/f,q=hg=h/f\),求 \(q\)

\(g_0=g\bmod x^n,g_1=(g\bmod x^{2n}-g_0)/x^n,h_0=h\bmod x^n,h_1=(h\bmod x^{2n}-h_0)/x^n\)

如果需要求 \(g\),考慮計算

\[q\bmod x^{2n}=(g\bmod x^{2n})(h\bmod x^{2n})\bmod x^{2n}=g_0h_0+(g_0h_1+g_1h_0)x^n\bmod x^{2n} \]

  • 計算 \(\mathcal F_{2n}(g_0),g_0,g_1\) 需要 \(18E(n)\) 時間。
  • 計算 \(\mathcal F_{2n}(g_1),\mathcal F_{2n}(h_0),\mathcal F_{2n}(h_1)\) 需要 \(6E(n)\) 時間。
  • 計算 \(g_0h_0,g_0h_1+g_1h_0\) 需要 \(4E(n)\) 時間。

總共需要 \(28E(n)\) 時間計算 \(q\bmod x^{2n}\),因此計算 \(q\bmod x^n\) 就需要 \(14E(n)\) 時間。

這里用到的技巧可以表述為:對於 \(f,g\in \mathbb R[[x]]\),已知 \(f\bmod x^n,g\bmod x^n,\mathcal F_n(f\bmod x^{n/2})\),則需要 \(5E(n)\) 時間計算出 \(fg\mod x^n\)

如果不求 \(g\),仍然考慮

\[q\bmod x^{2n}=q_0-(fq_0-h)g_0\bmod x^{2n} \]

與第二部分的做法相比,可以用更快計算倒數的方法,計算 \(q_0\) 時用的 \(\mathcal F_{2n}(g_0)\) 可以用於計算 \((fq_0-1)g_0\),所以計算 \(q\bmod x^{2n}\) 所用時間為 \(24E(n)\),計算 \(q\bmod x^n\) 總時間為 \(12E(n)=2M(n)\)

  • 計算 \(g_0\) \(\to\) \(9E(n)\)
  • 計算 \(\mathcal F_{2n}(g_0),q_0\) \(\to\) \(6E(n)\)
  • 計算 \(fq_0\) \(\to\) \(M(n)=6E(n)\)
  • 計算 \((fq_0-h)g_0\),已知 \(\mathcal F_{2n}(g_0)\) \(\to\) \(4E(n)\)

因此計算 \(q\bmod x^{2n}\) 總時間是 \(25E(n)\),計算 \(q\bmod x^n\) 總時間是 \(12.5E(n)\)

觀察 \(q=g_0h_0\bmod x^n,(fq_0-h)g_0\),可以發現符合上文描述的技巧的使用條件,其中 \((fq_0-h)g_0\) 可以視為計算 \(((fq_0-h)/x^n)g_0\),再考慮到相同的 DFT 只用計算一次,需要 \(9E(n)\) 時間計算,總時間就是 \(12E(n)=2M(n)\)

Code
poly poly::quo(const poly&rhs)const{
	if(rhs.deg()==0)return 1ll*a[0]*nt.inv(rhs[0])%P;
	poly g=nt.inv(rhs[0]);
	int t=0,n;
	for(n=1;(n<<1)<=rhs.deg();++t,n<<=1)nit.inv(rhs.slice((n<<1)-1),g,t);
	poly nttg=g;
	nttg.redeg((n<<1)-1);
	ntt.dft(nttg.base(),t+1,1);
	poly eps1=rhs.slice((n<<1)-1);
	ntt.dft(eps1.base(),t+1,1);
	rep(i,0,(n<<1)-1)eps1[i]=1ll*eps1[i]*nttg[i]%P;
	ntt.dft(eps1.base(),t+1,-1);
	memcpy(eps1.base(),eps1.base()+n,sizeof(int)<<t);
	memset(eps1.base()+n,0,sizeof(int)<<t);
	ntt.dft(eps1.base(),t+1,1);
	poly h0=slice(n-1);
	h0.redeg((n<<1)-1);
	ntt.dft(h0.base(),t+1,1);
	poly h0g0=zeroes((n<<1)-1);
	rep(i,0,(n<<1)-1)h0g0[i]=1ll*h0[i]*nttg[i]%P;
	ntt.dft(h0g0.base(),t+1,-1);
	poly h0eps1=zeroes((n<<1)-1);
	rep(i,0,(n<<1)-1)h0eps1[i]=1ll*h0[i]*eps1[i]%P;
	ntt.dft(h0eps1.base(),t+1,-1);
	rep(i,0,n-1)h0eps1[i]=reduce(operator[](i+n)-h0eps1[i]);
	memset(h0eps1.base()+n,0,sizeof(int)<<t);
	ntt.dft(h0eps1.base(),t+1,1);
	rep(i,0,(n<<1)-1)h0eps1[i]=1ll*h0eps1[i]*nttg[i]%P;
	ntt.dft(h0eps1.base(),t+1,-1);
	memcpy(h0eps1.base()+n,h0eps1.base(),sizeof(int)<<t);
	memset(h0eps1.base(),0,sizeof(int)<<t);
	return (h0g0+h0eps1).slice(rhs.deg());
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM