FFT什么的


目錄

  這里只有公式&做法,沒有復雜的證明(其實是因為弱雞yww不會)

  參考自國家集訓隊論文&各個博客

多項式

​  一個以\(x\)為變量的多項式定義在一個代數域\(F\)上,將函數\(A(x)\)表示為形式和:

\[A(x)=\sum_{j=0}^{n-1}a_jx^j \]

我們稱\(a_0,a_1,\ldots,a_{n-1}\)為多項式的系數,所有系數都屬於數域\(F\),典型的情形是負數集合\(C\)

  如果一個多項式的最高次的非零系數是\(a_k\),則稱\(A(x)\)的次數是\(k\)。任何嚴格大於一個多項式次數的整數都是該多項式的次數界。因此,對於次數界為\(n\)的多項式\(C(x)\),其次數可以是\(0\)~\(n-1\)之間的任何整數,包括\(0\)\(n-1\)

​  我們在多項式上可以定義很多不同的運算。

多項式加法

​  如果\(A(x)\)\(B(x)\)是次數界為\(n\)的多項式,那么他們的和也是一個次數界為\(n\)的多項式\(C(x)\)。對於所有屬於定義域的\(x\),都有\(C(x)=A(x)+B(x)\)。也就是說,若

\[A(x)=\sum_{j=0}^{n-1}a_jx^j\\ B(x)=\sum_{j=0}^{n-1}b_jx^j \]

\[C(x)=\sum_{j=0}^{n-1}c_jx^j\\ \]

其中

\[c_j=a_j+b_j \]

​  例如,如果

\[A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]

\[C(x)=4x^3+7x^2-6x+4 \]

多項式乘法

​  如果\(A(x)\)是次數界為\(n\)的多項式,\(B(x)\)是次數界為\(m\)的多項式,那么他們的乘積是一個次數界為\(n+m\)的多項式\(C(x)\)。其中

\[c_j=\sum_{k=0}^ja_kb_{j-k} \]

​  例如,如果

\[A(x)=6x^3+7x^2-10x+9,B(x)=-2x^3+4x-5 \]

​  則

\[C(x)=-12x^6-14x^5+44x^4-20x^3-75x^2+86x-45 \]

多項式的表示

系數表達

​  對一個次數界為\(n\)的多項式\(A(x)=\sum_{j=0}^{n-1}a_jx^j\)而言,其系數表達式一個由系數組成得到向量\(a=(a_0,a_1,\cdots,a_{n-1})\)

​  我們可以用秦久韶算法在\(O(n)\)的時間內求出多項式在給定點\(x_0\)的值,即求值運算:

\[A(x_0)=a_0+x_0(a_1+a_0(a_2+\cdots+x_0(a_{n-1}+x_0(a_{n-1})\cdots)) \]

​  類似的,對於兩個分別用系數向量\(a=(a_0,a_1,\cdots,a_{n-1}),b=(b_0,b_1,\cdots,b_{n-1})\)表示的多項式進行相加時,所需的時間是\(O(n)\)。我們只用輸出系數向量\(c=(c_0,c_1,\cdots,c_{n-1})\),其中\(c_i=a_i+b_i\)

​  現在來考慮兩個用系數形式表達的次數界為\(n\)的多項式\(A(x),B(x)\)的乘法運算,所需要的時間是\(O(n^2)\)。系數向量\(c\)也稱為輸入向量\(a,b\)的卷積。\(c=a\otimes b\)

點值表達

​  一個次數界為\(n\)的多項式的點值表達就是一個有\(n\)個點值對所組成的集合。

\[\{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]

使得對\(k=0,1,\cdots,n-1\),所有\(x_k\)各不相同且\(y_k=A(x_k)\)

​  一個多項式可以有很多不同的點值表達,因為可以采用\(n\)個不同的點構成的集合作為這種表示方法的基。

​  朴素的求值是\(O(n^2)\)的。

​  求值的逆稱為插值。當插值多項式的次數界等於已知的點值對的數目時,插值才是明確的。

​  我們可以在用高斯消元在\(O(n^3)\)內插值,也可以用拉格朗日插值\(O(n^2)\)內插值。

​  以上求值和插值可以將多項式的系數表達和點值表達進行相互轉化,上面給出的算法的時間復雜度是\(O(n^2)\),但我們可以巧妙地選取\(x_k\)來加速這一過程,使其運行時間變為\(O(nlogn)\)

​  對於許多多項式相關的操作,點值表達式很便利的。

​  對於加法,如果\(C(x)=A(x)+B(x)\)。給定\(A\)的點值表達

\[\{(x_0,y_0),(x_1,y_1),\cdots,(x_{n-1},y_{n-1})\} \]

\(B\)的點值表達

\[\{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{n-1},y'_{n-1})\} \]

(注意,\(A\)\(B\)在相同的\(n\)個位置求值),則\(C\)的點值表達是

\[\{(x_0,y_0+y'_0),(x_1,y_1+y'_1),\cdots,(x_{n-1},y_{n-1}+y'_{n-1})\} \]

因此,對兩個點值形式表示的次數界為\(n\)的多項式相加,時間復雜度是\(O(n)\)

​  類似的,如果\(C(x)=A(x)B(x)\),我們需要\(2n\)個點值對才能插出\(C\)。給定\(A\)的點值表達

\[\{(x_0,y_0),(x_1,y_1),\cdots,(x_{2n-1},y_{2n-1})\} \]

\(B\)的點值表達

\[\{(x_0,y'_0),(x_1,y'_1),\cdots,(x_{2n-1},y'_{2n-1})\} \]

(注意,\(A\)\(B\)在相同的\(2n\)個位置求值),則\(C\)的點值表達是

\[\{(x_0,y_0y'_0),(x_1,y_1y'_1),\cdots,(x_{2n-1},y_{2n-1}y'_{2n-1})\} \]

因此,對兩個點值形式表示的次數界為\(n\)的多項式相乘,時間復雜度是\(O(n)\)

​  最后,我們考慮一個采用點值表達的多項式,如何求其在某個新點上的值。最簡單的方法是把該多項式轉成系數形式表達,然后在新點處求值。

系數形式表示的多項式的快速乘法

​  如果我們選\(n\)次單位復數根作為求值點,我們可以在\(O(nlogn)\)內求值和插值。我們先在對這兩個多項式\(A,B\)求值之前添加\(n\)\(0\),使其次數界加倍為\(2n\)。現在我們采用“\(2n\)次單位復數根”作為求值點。

DFT&FFT&IDFT

單位復數根

​  \(n\)次單位復數根是滿足\(w^n=1\)的復數\(w\)\(n\)次單位復數根恰好有\(n\)個,對於\(k=0,1,\cdots,n-1\),這些根是\(e^{\frac{2\pi ik}{n}}\)\(w_n=e^\frac{2\pi i}{n}\)稱為主\(n\)次單位根,所有其他\(n\)次單位復數根都是\(w_n\)的冪次。這\(n\)\(n\)次單位復數根在乘法意義下形成了一個群,即\(w_n^jw_n^k=w_n^{(j+k)mod~n}\),而且這\(n\)\(n\)次單位復數根均勻分布在以復平面的原點為圓心的單位半徑的圓周上。(圖片from zjt)

  

​  消去引理:對任何整數\(n\geq 0,k\geq 0,d>0\)

\[w_{dn}^{dk}=w_n^k \]

DFT

​  回顧一下,我們希望計算次數界為\(n\)的多項式\(A(x)\)\(w_n^0,w_n^1,\cdots,w_n^{n-1}\)處的值(即在\(n\)\(n\)次單位復數根處)。對於\(k=0,1,\cdots,n-1\),定義結果\(y_k\)

\[y_k=A(w_n^k)=\sum_{j=0}^{n-1}a_jw_n^{kj} \]

向量\(y=(y_0,y_1,\cdots,y_{n-1})\)就是系數向量\(a\)的離散傅里葉變換(DFT),我們也記為\(y=DFT_n(a)\)

FFT

​  利用單位復數根的特殊性質,我們可以在\(O(nlogn)\)內計算出\(DFT_n(a)\)。這里假設\(n\)\(2\)的冪。

  FFT利用了分治策略。

  我們令\(a=(a_0,a_1,\cdots,a_{n-1}),a_1=(a_0,a_2,\cdots,a_{n-2}),a_2=(a_1,a_3,\cdots,a_{n-1})\)

  對於\(k<\frac n2\)有:

\[\begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &={y_1}_k+w_n^k{y_2}_k \end{align} \]

  對於\(k\geq \frac n2\)有:

\[\begin{align} y_k&=A(w_n^k)\\ &=\sum_{j=0}^{n-1}a_jw_n^{kj}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj+k}\\ &=\sum_{j=0}^{\frac n2-1}a_{2j}w_n^{2kj}+w_n^k\sum_{j=0}^{\frac n2-1}a_{2j+1}w_n^{2kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{kj}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{kj}\\ &=\sum_{j=0}^{\frac n2-1}{a_1}_{j}w_{\frac n2}^{(k-\frac n2)j}+w_n^k\sum_{j=0}^{\frac n2-1}{a_2}_{j}w_{\frac n2}^{(k-\frac n2)j}\\ &={y_1}_{k-\frac n2}+w_n^k{y_2}_{k-\frac n2}\\ &={y_1}_{k-\frac n2}-w_n^{k-\frac n2}{y_2}_{k-\frac n2} \end{align} \]

  這樣我們把\(y_1,y_2\)合並為\(y\)的時間復雜度是\(O(n)\)。所以總的時間復雜度是

\[T(n)=2T(\frac n2)+O(n)=O(n\log n) \]

IDFT

​  通過推導公式,我們得到:

\[a_k=\frac1n\sum_{j=0}^{n-1}y_jw_n^{-kj} \]

​  所以我們可以用類似FFT的方法在\(O(n\log n)\)內求出\(IDFT_n(y)\)

多項式乘法

​  我們可以在\(O(n)\)內補\(0\)\(O(n\log n)\)內求值,\(O(n)\)內點值乘法,\(O(n\log n)\)內插值。所以我們可以在\(O(n\log n)\)內求出\(a\otimes b\)

\[a\otimes b=IDFT_{2n}(DFT_{2n}(a)\cdot DFT_{2n}(b)) \]

蝶形運算

  我們把由\({y_1}_k,{y_2}_k,w_n^k\)得到\(y_k,y_{k+\frac n2}\)的過程稱為蝴蝶操作。

​  我們發現,遞歸時\(a\)是長這樣的:

\[0~~~1~~~2~~~3~~~4~~~5~~~6~~~7\\ 0~~~2~~~4~~~6~|~1~~~3~~~5~~~7\\ 0~~~4~|~2~~~6~|~1~~~5~|~3~~~7\\ 0~|~4~|~2~|~6~|~1~|~5~|~3~|~7 \]

  總的蝶形運算是長這樣的:
  
  

​  可以發現,最后\(a_i\)是原來的\(a_{rev(i)}\)。所以我們可以交換\(a_i,a_{rev(i)}\),然后一層層來做。這樣可以減小常數。

NTT

​  在某些時候,我們需要求模\(p\)意義下的卷積。

​  先求出\(p\)的原根\(g\),可以發現,\(g^{\frac{p-1}{n}}\)\(w_n\)的性質類似。所以我們可以用\(g^{\frac{p-1}{n}}\)來代替\(w_n\)

時間上的優化

  當我們要算兩個多項式 \(A(x), B(x)\) 的乘積的時候,普通的做法是先把 \(a,b\) 兩個序列 DFT,再點乘,再 IDFT 回去。

  但是我們還有一種方法:

​  令\(t_j=(a_j+b_j)+(a_j-b_j)i,S=T\times T\)

​  \(s_j\)的實部為

\[\begin{align} \sum_{k=0}^j(a_k+b_k)(a_{j-k}+b_{j-k})-(a_k-b_k)(a_{j-k}-b_{j-k})&=\sum_{k=0}^j4a_kb_{j-k}=4\sum_{k=0}^ja_kb_{j-k} \end{align} \]

  這樣我們就可以求出\(S=T\times T\),然后把\(s_j\)除以\(4\)

  這個方法可以把\(3\)次DFT改成\(2\)次DFT。

多項式求導

  給定\(A(x)=\sum_{i\geq 0}a_ix^i\),定義\(A(x)\)的形式導數為

\[A'(x)=\sum_{i\geq 1}ia_ix^{i-1} \]

多項式積分

  給定\(A(x)=\sum_{i\geq 0}a_ix^i\),則

\[\int A(x)=\sum_{i\geq 1}\frac{a_{i-1}}{i}x^i \]

多項式求逆

​  多項式\(A(x)\)存在乘法逆元的充要條件是\(A(x)\)的常數項存在乘法逆元。

​  下面介紹一個\(O(n~log~n)\)計算乘法逆元的算法,它的本質是牛頓迭代法

​  首先求出\(A(x)\)常數項的逆元\(b\),令\(B(x)\)的初始值為\(b\)

​  假設已求出滿足

\[A(x)B(x)\equiv1~(mod~x^n) \]

\(B(x)\),則

\[\begin{align} A(x)B(x)-1&\equiv0~(mod~x^n)\\ {(A(x)B(x)-1)}^2&\equiv 0~(mod~x^{2n})\\ A(x)(2B(x)-B(x)^2A(x))&\equiv 1~(mod~x^{2n}) \end{align} \]

​  我們可以用\(O(n~log~n)\)的時間計算出\(2B(x)-B(x)^2A(x)\),並將它賦值給\(B(x)\)進行下一次迭代。每迭代一次,\(B(x)\)的有效項數\(n\)都會增加一倍。於是該算法的時間復雜度為

\[T(n)=T(n/2)+O(n\log n)=O(n\log n) \]

多項式開根

  已知\(A(x)\),求\(B(x)\)使得

\[B(x)^2\equiv A(x)~(mod~x^n) \]

  先求出\(A(x)\)常數項的平方根\(b\)(可以用二次剩余的東西來算,但我只會暴力算),令\(B(x)\)的初始值為\(b\)

  假設已求出滿足

\[B(x)^2\equiv A(x)~(mod~x^n) \]

\(B(x)\),則

\[\begin{align} B(x)^2-A(x)&\equiv 0~(mod~x^n)\\ {(B(x)^2-A(x))}^2&\equiv 0~(mod~x^{2n})\\ B(x)^4-2B(x)^2A(x)+A(x)^2&\equiv 0~(mod~x^{2n})\\ B(x)^4+2B(x)^2A(x)+A(x)^2&\equiv 4B(x)^2A(x)~(mod~x^{2n})\\ {(B(x)^2+A(x))}^2&\equiv {(2B(x))}^2A(x)~(mod~x^{2n})\\ {(\frac{B(x)^2+A(x)}{2B(x)})}^2&\equiv A(x)~(mod~x^{2n}) \end{align} \]

  我們可以在\(O(n\log n)\)內算出\(\frac{B(x)^2+A(x)}{2B(x)}=\frac{B(x)}{2}+\frac{A(x)}{2B(x)}\),並把它賦值給\(B(x)\)

  時間復雜度:\(O(n\log n)\)

多項式ln

  給定形式冪級數\(A(x)=\sum_{i\geq 1}a_ix^i\),定義

\[\ln(1-A(x))=-\sum_{i\geq 1}\frac{{A(x)}^i}{i} \]

  給定多項式\(A(x)=1+\sum_{i\geq 1}a_ix^i\),令

\[B(x)=\ln(A(x)) \]

\[B'(x)=\frac{A'(x)}{A(x)} \]

  只需要求出\(A(x)\)的乘法逆元,就可以求出\(\ln(A(x))\)

多項式exp

  給定形式冪級數\(A(x)=\sum_{i\geq 1}a_ix^i\),定義

\[\exp(A(x))=\sum_{i\geq 0}\frac{{A(x)}^i}{i!} \]

  令\(f(x)=e^{A(x)}\),可得到一個關於\(f(x)\)的方程

\[g(f(x))=\ln(f(x))-A(x)=0 \]

  考慮用牛頓迭代解這一方程。首先\(f(x)\)的常數項是容易確定的(就是\(1\))。

  設以求得\(f(x)\)的前\(n\)\(f_0(x)\),即

\[f(x)\equiv f_0(x)~~~(mod~~~x^n) \]

  作泰勒展開得

\[\begin{align} 0&=g(f(x))\\ &=g(f_0(x))+g'(f_0(x))(f(x)-f_0(x))~~~~~(mod~~~x^{2n}) \end{align} \]

\[f(x)\equiv f_0(x)-\frac{g(f_0(x))}{g'(f_0(x))}~~~~(mod~~~x^{2n}) \]

  把上面那個式子帶入得

\[\begin{align} f(x)&=f_0(x)-\frac{\ln(f_0(x))-A(x)}{\frac{1}{f_0(x)}}\\ &=f_0(x)(1-\ln(f_0(x))+A(x)) \end{align} \]

  時間復雜度:\(O(n\log n)\)
  

多項式求冪

  給你\(A(x),k\),求\(A^k(x)\)

  設\(A(x)\)中最低次數項是\(cx^d\),那么先把整個多項式除以\(cx^d\),再求\(\ln\),把整個多項式乘以\(k\),再求\(\exp\),再乘上\(c^kx^{kd}\)

\[A^k(x)=\exp(k\ln\frac{A(x)}{cx^d}))c^kx^{kd} \]

  時間復雜度:\(O(n\log n)\)

多項式除法

​  給你\(A(x),B(x)\),求兩個多項式\(D(x),R(x)\)滿足

\[A(x)=D(x)B(x)+R(x) \]

​  若\(A(x)\)是一個\(n\)階多項式,則

\[A^R(x)=x^nA(\frac1x) \]

  舉個例子:比如說

\[A(x)=x^3+2x^2+3x+4\\ A^R(x)=1+2x+3x^2+4x^3 \]

​  相當於把\(A(x)\)的系數反轉。

  我們設\(A(x)\)\(n\)階多項式,\(B(x)\)\(m\)階多項式,\(D(x)\)\(n-m\)階多項式,\(R(x)\)\(m-1\)階多項式。我們把上個式子的\(x\)\(\frac1x\),然后全部乘上\(x^n\)

\[x^nA(\frac1x)=x^{n-m}D(\frac1x)x^mB(\frac1x)+x^{n-m+1}x^{m-1}R(\frac1x)\\ A^R(x)=D^R(x)B^R(x)+x^{n-m+1}R^R(x) \]

  然后我們把這個式子放在模\(x^{n-m+1}\)意義下,得到

\[A^R(x)=D^R(x)B^R(x)~(mod~x^{n-m+1})\\ D^R(x)=A^R(x){(B^R(x))}^{-1}~(mod~x^{n-m+1}) \]

  因為\(D(x)\)的次數是\(n-m\),所以不會受模意義的影響。

  然后把\(D(x)\)帶入到原來的式子中,就可以算出\(R(x)\)了。

  時間復雜度:\(O(n\log n)\)

多點求值

  給你一個多項式\(A(x)\)\(n\)個點\(x_0,x_1,\cdots,x_{n-1}\),求這個多項式在這\(n\)個點處的值,即求\(A(x_0),A(x_1),\cdots,A(x_{n-1})\)

  考慮一個簡單的做法:構造\(B_i(x)=x-x_i,C_i(x)=A(x)~mod~B_i(x)\),那么\(B_i(x_i)=0\)。所以\(A(x_i)=C_i(x_i)\)。但是計算\(B_i(x)\)\(C_i(x)\)\(O(n)\)的,必須加速這個過程。

  設當前求值的點為\(X=\{x_0,x_1,\cdots,x_{n-1}\}\),我們可以把這\(n\)個點分為兩半:

\[X_0=\{x_0,x_1,\cdots,x_{\frac n2-1}\}\\ X_1=\{x_{\frac n2},x_{\frac n2+1},\cdots,x_{n-1}\} \]

  構造多項式

\[B_0=\prod_{i=0}^{\frac n2-1}(x-x_i)\\ B_1=\prod_{i=\frac n2}^{n-1}(x-x_i)\\ A_0=A~mod~B_0\\ A_1=A~mod~B_1 \]

  那么當\(x\in X_0\)\(A(x)=A_0(x)\),可以遞歸計算。當\(x\in X_1\)時同理。

  每一層計算\(B_0,B_1,A_0,A_1\)的時間復雜度都是\(O(n\log n)\)

  總的時間復雜度就是

\[T(n)=2T(\frac n2)+O(n\log n)=O(n\log^2n) \]

快速插值

  考慮怎么求\(g_i=\prod_{j=0,j\neq i}^n (x_i-x_j)\),也就是分母。

\[\begin{align} g_i&=\prod_{j=0,j\neq i}^n (x_i-x_j)\\ &=\lim_{x \to x_i}\frac{\prod_{j=0}^n (x-x_j)}{x-x_i}\\ &=(\prod_{j=0}^n (x-x_j))'|_{x=x_i} \end{align} \]

  可以分治求出\(\prod_{j=0}^n (x-x_j)\)再求導后在所有\(x_i\)處多點求值。

  分子直接分治求出。

  時間復雜度:\(O(n\log^2n)\)

小技巧1

  比如我們要計算兩個實數序列的卷積\(A\times B=C\),記\(D_i=(a_i+b_i)+(a_i-b_i)i\),那么\(C_i=\frac{1}{4}real({D^2}_i)\)
  
  這樣就可以把三次DFT減少到兩次DFT。
  
  當然,如果\(A=B\)那么這個優化是沒有效果的。

任意模數FFT

模板

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<ctime>
#include<utility>
#include<cmath>
#include<functional>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
void sort(int &a,int &b)
{
	if(a>b)
		swap(a,b);
}
void open(const char *s)
{
#ifndef ONLINE_JUDGE
	char str[100];
	sprintf(str,"%s.in",s);
	freopen(str,"r",stdin);
	sprintf(str,"%s.out",s);
	freopen(str,"w",stdout);
#endif
}
int rd()
{
	int s=0,c;
	while((c=getchar())<'0'||c>'9');
	do
	{
		s=s*10+c-'0';
	}
	while((c=getchar())>='0'&&c<='9');
	return s;
}
int upmin(int &a,int b)
{
	if(b<a)
	{
		a=b;
		return 1;
	}
	return 0;
}
int upmax(int &a,int b)
{
	if(b>a)
	{
		a=b;
		return 1;
	}
	return 0;
}
const ll p=998244353;
const ll g=3;
ll fp(ll a,ll b)
{
    ll s=1;
    while(b)
    {
        if(b&1)
            s=s*a%p;
        a=a*a%p;
        b>>=1;
    }
    return s;
}
const int maxn=600000;
ll inv[maxn];
namespace ntt
{
    ll w1[maxn];
    ll w2[maxn];
    int rev[maxn];
    int n;
    void init(int m)
    {
        n=1;
        while(n<m)
            n<<=1;
        int i;
        for(i=2;i<=n;i<<=1)
        {
            w1[i]=fp(g,(p-1)/i);
            w2[i]=fp(w1[i],p-2);
        }
        rev[0]=0;
        for(i=1;i<n;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
    }
    void ntt(ll *a,int t)
    {
        int i,j,k;
        ll u,v,w,wn;
        for(i=0;i<n;i++)
            if(rev[i]<i)
                swap(a[i],a[rev[i]]);
        for(i=2;i<=n;i<<=1)
        {
            wn=(t==1?w1[i]:w2[i]);
            for(j=0;j<n;j+=i)
            {
                w=1;
                for(k=j;k<j+i/2;k++)
                {
                    u=a[k];
                    v=a[k+i/2]*w%p;
					a[k]=(u+v)%p;
					a[k+i/2]=(u-v)%p;
                    w=w*wn%p;
                }
            }
        }
        if(t==-1)
        {
            u=fp(n,p-2);    
            for(i=0;i<n;i++)
                a[i]=a[i]*u%p;
        }
    }
    ll x[maxn];
    ll y[maxn];
    ll z[maxn];
    void copy_clear(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
        for(i=m;i<n;i++)
            a[i]=0;
    }
    void copy(ll *a,ll *b,int m)
    {
        int i;
        for(i=0;i<m;i++)
            a[i]=b[i];
    }
    void mul(ll *a,ll *b,ll *c,int m)
    {
    	init(m<<1);
    	copy_clear(x,a,m);
    	copy_clear(y,b,m);
    	ntt(x,1);
    	ntt(y,1);
    	int i;
    	for(i=0;i<n;i++)
    		x[i]=x[i]*y[i]%p;
    	ntt(x,-1);
    	copy(c,x,m);
    }
    void inverse(ll *a,ll *b,int m)
    {
        if(m==1)
        {
            b[0]=fp(a[0],p-2);
            return;
        }
        inverse(a,b,m>>1);
        init(m<<1);
        copy_clear(x,a,m);
        copy_clear(y,b,m>>1);
        ntt(x,1);
        ntt(y,1);
        int i;
        for(i=0;i<n;i++)
            x[i]=y[i]*(2-x[i]*y[i]%p)%p;
    	ntt(x,-1);
    	copy(b,x,m);
    }
    ll c[maxn],d[maxn],e[maxn],f[maxn];
    void sqrt(ll *a,ll *b,int m)
    {
    	if(m==1)
    	{
    		if(a[0]==1)
    			b[0]=1;
    		else if(a[0]==0)
    			b[0]=0;
    		else
    			//我也不會
				;
			return;
		}
		sqrt(a,b,m>>1);
//		copy_clear(c,b,m>>1);
		int i;
		for(i=m;i<m<<1;i++)
			b[i]=0;
		inverse(b,d,m);
		init(m<<1);
		for(i=m;i<m<<1;i++)
			b[i]=d[i]=0;
		ll inv2=fp(2,p-2);
		copy_clear(x,a,m);
		ntt(x,1);
		ntt(d,1);
		for(i=0;i<n;i++)
			x[i]=x[i]*d[i]%p;
		ntt(x,-1);
		for(i=0;i<m;i++)
			b[i]=((b[i]+x[i])%p*inv2)%p;
	}
    void derivative(ll *a,ll *b,int m)
	{
		int i;
		for(i=0;i<m-1;i++)
			b[i]=(i+1)*a[i+1]%p;
		b[m-1]=0;
	}
    void differential(ll *a,ll *b,int m)
    {
    	int i;
    	for(i=m-1;i>=1;i--)
    		b[i]=a[i-1]*inv[i]%p;
    	b[0]=0;
    }
    void ln(ll *a,ll *b,int m)
    {
    	static ll c[maxn],d[maxn];
    	derivative(a,c,m);
    	inverse(a,d,m);
    	init(m<<1);
    	int i;
    	for(i=m;i<n;i++)
    		c[i]=d[i]=0;
    	ntt(c,1);
    	ntt(d,1);
    	for(i=0;i<n;i++)
    		c[i]=c[i]*d[i]%p;
    	ntt(c,-1);
    	differential(c,b,m);
    }
    void exp(ll *a,ll *b,int m)
    {
    	if(m==1)
    	{
    		b[0]=1;
    		return;
    	}
    	exp(a,b,m>>1);
    	int i;
    	for(i=m>>1;i<m;i++)
    		b[i]=0;
    	ln(b,y,m);
    	init(m<<1);
    	copy_clear(x,a,m);
    	x[0]++;
    	for(i=0;i<m;i++)
    		x[i]=(x[i]-y[i])%p;
    	copy_clear(y,b,m);
    	ntt(x,1);
    	ntt(y,1);
    	for(i=0;i<n;i++)
    		x[i]=x[i]*y[i]%p;
    	ntt(x,-1);
    	copy(b,x,m);
    }
    void module(ll *a,ll *b,ll *c,int n1,int n2)
    {
    	int k=1;
    	while(k<=n1-n2+1)
    		k<<=1;
    	int i;
    	for(i=0;i<=n1;i++)
    		d[i]=a[i];
    	for(i=0;i<=n2;i++)
    		e[i]=b[i];
    	reverse(d,d+n1+1);
    	reverse(e,e+n2+1);
    	for(i=n1-n2+1;i<k<<1;i++)
    		d[i]=e[i]=0;
    	inverse(e,f,k);
    	for(i=n1-n2+1;i<k<<1;i++)
    		f[i]=0;
    	init(k<<1);
    	ntt::ntt(d,1);
    	ntt::ntt(f,1);
    	for(i=0;i<n;i++)
    		e[i]=d[i]*f[i]%p;
    	ntt::ntt(e,-1);
    	for(i=0;i<=n1-n2;i++)
    		c[i]=e[i];
    	reverse(c,c+n1-n2+1);
    }
};
ll b[maxn];
ll a[maxn];
ll c[maxn];
void get(ll *a,int n)
{
	int i;
	for(i=0;i<n;i++)
		a[i]=rand();
}
int main()
{
//	freopen("fft.txt","w",stdout);
//	srand(time(0));
//	int n=262144;
//	int bg,ed;
//	int i;
//	int times=100,j;
//	double s,s1;
//	inv[0]=inv[1]=1;
//	for(i=2;i<=n;i++)
//		inv[i]=-(p/i)*inv[p%i]%p;
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		bg=clock();
//		ntt::init(n);
//		ntt::ntt(a,1);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("ntt :%.10lf\n",s/times);
//	s1=s;
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		get(b,n);
//		bg=clock();
//		ntt::mul(a,b,c,n);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("mul :%.10lf %.10lf\n",s/times,s/s1);
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		bg=clock();
//		ntt::inverse(a,b,n);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("inv :%.10lf %.10lf\n",s/times,s/s1);
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		a[0]=1;
//		bg=clock();
//		ntt::sqrt(a,b,n);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("sqrt:%.10lf %.10lf\n",s/times,s/s1);
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		a[0]=1;
//		bg=clock();
//		ntt::ln(a,b,n);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("ln  :%.10lf %.10lf\n",s/times,s/s1);
//	s=0;
//	for(j=1;j<=times;j++)
//	{
//		get(a,n);
//		bg=clock();
//		ntt::exp(a,b,n);
//		ed=clock();
//		s+=double(ed-bg)/CLOCKS_PER_SEC;
//	}
//	printf("exp :%.10lf %.10lf\n",s/times,s/s1);
//	return 0;
}

多點求值+快速插值

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const ll p=998244353;
const ll g=3;
const int maxw=131072;
const int maxn=150000;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
int rt,cnt,ls[1000010],rs[1000010];
ll vx[100010],vy[100010],va[100010];
ll inv[maxn],w1[maxn],w2[maxn];
int rev[maxn];
void init()
{
	inv[0]=inv[1]=1;
	for(int i=2;i<=maxw;i++)
		inv[i]=-p/i*inv[p%i]%p;
	for(int i=2;i<=maxw;i<<=1)
	{
		w1[i]=fp(g,(p-1)/i);
		w2[i]=fp(w1[i],p-2);
	}
}
ll *f[1000010];
int len[maxn];
void clear(ll *a,int n)
{
	memset(a,0,(sizeof a[0])*n);
}
void ntt(ll *a,int n,int t)
{
	for(int i=1;i<n;i++)
	{
		rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
		if(i>rev[i])
			swap(a[i],a[rev[i]]);
	}
	for(int i=2;i<=n;i<<=1)
	{
		ll wn=(t==1?w1[i]:w2[i]);
		for(int j=0;j<n;j+=i)
		{
			ll w=1;
			for(int k=j;k<j+i/2;k++)
			{
				ll u=a[k];
				ll v=a[k+i/2]*w%p;
				a[k]=(u+v)%p;
				a[k+i/2]=(u-v)%p;
				w=w*wn%p;
			}
		}
	}
	if(t==-1)
	{
		ll inv=fp(n,p-2);
		for(int i=0;i<n;i++)
			a[i]=a[i]*inv%p;
	}
}
void mul(ll *a,ll *b,ll *c,int n,int m)
{
	int k=1;
	while(k<=n+m)
		k<<=1;
	static ll a1[maxn],a2[maxn];
	clear(a1,k);
	clear(a2,k);
	for(int i=0;i<=n;i++)
		a1[i]=a[i];
	for(int i=0;i<=m;i++)
		a2[i]=b[i];
	ntt(a1,k,1);
	ntt(a2,k,1);
	for(int i=0;i<k;i++)
		a1[i]=a1[i]*a2[i]%p;
	ntt(a1,k,-1);
	for(int i=0;i<=n+m;i++)
		c[i]=a1[i];
}
void getinv(ll *a,ll *b,int n)
{
	if(n==1)
	{
		b[0]=fp(a[0],p-2);
		return;
	}
	getinv(a,b,n>>1);
	static ll a1[maxn],a2[maxn];
	clear(a1,n<<1);
	clear(a2,n<<1);
	for(int i=0;i<n;i++)
		a1[i]=a[i];
	for(int i=0;i<n>>1;i++)
		a2[i]=b[i];
	ntt(a1,n<<1,1);
	ntt(a2,n<<1,1);
	for(int i=0;i<n<<1;i++)
		a1[i]=a2[i]*(2-a2[i]*a1[i]%p)%p;
	ntt(a1,n<<1,-1);
	for(int i=0;i<n;i++)
		b[i]=a1[i];
}
void div(ll *a,ll *b,ll *c,int n,int m)
{
	static ll a1[maxn],a2[maxn],a3[maxn];
	int k=1;
	while(k<=2*(n-m))
		k<<=1;
	for(int i=0;i<=n;i++)
		a1[i]=a[i];
	for(int i=0;i<=m;i++)
		a2[i]=b[i];
	reverse(a1,a1+n+1);
	reverse(a2,a2+m+1);
	clear(a1+n-m+1,k-(n-m+1));
	clear(a2+n-m+1,k-(n-m+1));
	getinv(a2,a3,k);
	clear(a3+n-m+1,k-(n-m+1));
	ntt(a1,k,1);
	ntt(a3,k,1);
	for(int i=0;i<k;i++)
		a1[i]=a1[i]*a3[i]%p;
	ntt(a1,k,-1);
	for(int i=0;i<=n-m;i++)
		c[i]=a1[i];
	reverse(c,c+n-m+1);
}
void getmod(ll *a,ll *b,ll *c,int n,int m)
{
	static ll a1[maxn],a2[maxn];
	int k=1;
	while(k<=n)
		k<<=1;
	clear(a1,k);
	clear(a2,k);
	for(int i=0;i<=m;i++)
		a1[i]=b[i];
	div(a,b,a2,n,m);
	ntt(a1,k,1);
	ntt(a2,k,1);
	for(int i=0;i<k;i++)
		a1[i]=a1[i]*a2[i]%p;
	ntt(a1,k,-1);
	for(int i=0;i<m;i++)
		c[i]=(a[i]-a1[i])%p;
}
void divide(int l,int r,int &now)
{
	now=++cnt;
	len[now]=r-l+1;
	f[now]=new ll[len[now]+1];
	if(l==r)
	{
		f[now][1]=1;
		f[now][0]=-vx[l];
		return;
	}
	int mid=(l+r)>>1;
	divide(l,mid,ls[now]);
	divide(mid+1,r,rs[now]);
	mul(f[ls[now]],f[rs[now]],f[now],len[ls[now]],len[rs[now]]);
}
void getv(ll *a,int n,int l,int r,int now)
{
	ll *a1=new ll[len[now]];
	getmod(a,f[now],a1,n,len[now]);
	if(l==r)
	{
		va[l]=a1[0];
		return;
	}
	int mid=(l+r)>>1;
	getv(a1,len[now]-1,l,mid,ls[now]);
	getv(a1,len[now]-1,mid+1,r,rs[now]);
}
ll *s[1000010];
void getpoly(int l,int r,int now)
{
	s[now]=new ll[len[now]];
	if(l==r)
	{
		s[now][0]=va[l];
		return;
	}
	int mid=(l+r)>>1;
	getpoly(l,mid,ls[now]);
	getpoly(mid+1,r,rs[now]);
	int k=1;
	while(k<=len[now])
		k<<=1;
	static ll a1[maxn],a2[maxn],a3[maxn],a4[maxn];
	clear(a1,k);
	clear(a2,k);
	clear(a3,k);
	clear(a4,k);
	for(int i=0;i<len[ls[now]];i++)
		a1[i]=s[ls[now]][i];
	for(int i=0;i<=len[rs[now]];i++)
		a2[i]=f[rs[now]][i];
	for(int i=0;i<len[rs[now]];i++)
		a3[i]=s[rs[now]][i];
	for(int i=0;i<=len[ls[now]];i++)
		a4[i]=f[ls[now]][i];
	ntt(a1,k,1);
	ntt(a2,k,1);
	ntt(a3,k,1);
	ntt(a4,k,1);
	for(int i=0;i<k;i++)
		a1[i]=(a1[i]*a2[i]+a3[i]*a4[i])%p;
	ntt(a1,k,-1);
	for(int i=0;i<len[now];i++)
		s[now][i]=a1[i];
}
int n;
ll a[maxn],b[maxn],c[maxn];
int main()
{
	init();
	scanf("%d",&n);
	for(int i=0;i<=n;i++)
		scanf("%lld%lld",&vx[i],&vy[i]);
	divide(0,n,rt);
	for(int i=0;i<=n;i++)
		a[i]=f[rt][i+1]*(i+1)%p;
	getv(a,n,0,n,rt);
//	for(int i=0;i<=n;i++)
//		printf("%lld ",(va[i]+p)%p);
//	printf("\n");
	for(int i=0;i<=n;i++)
		va[i]=fp(va[i],p-2)*vy[i]%p;
	getpoly(0,n,rt);
	for(int i=0;i<=n;i++)
		printf("%lld ",(s[rt][i]+p)%p);
	printf("\n");
	return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM