在Seal庫和HElib庫中都用到了NTT技術,用於加快多項式計算,而NTT又是FFT的優化,FFT又來自於DFT,現在具體學習一下這三個技術!
基礎概念
名詞區分
1、DFT:離散傅立葉變換
2、FFT:快速傅立葉變換
3、NTT:快速數論變換
4、MTT:NTT的擴展
5、多項式卷積:多項式乘法
6、根據多項式的系數表示法求點值表示法的過程叫做“求值”;根據點值表示法求系數表示法的過程稱為“插值”
7、求一個多項式的乘法,即求卷積,先通過傅立葉變換對系數表示法的多項式進行求值運算,其復雜度\(O(nlog^n)\),然后在\(O(n)\)的時間內點值相乘,在進行插值運算。
8、如果選取單位復根作為求值點,則可以對系數向量進行離散傅立葉變換(DFT),得到相應的點值表示;同樣可以通過對點值進行逆DFT運算,獲得相應的系數向量。DFT和逆DFT時間復雜度均為\(O(nlog^n)\)。
復數
定義
我們知道,一個復數可以這樣表示:\(a+bi\),a和b是實數,其中\(i\)叫做虛數單位,復數域是目前已知最大的域。
在復平面中,x軸代表實數,y軸(除原點外的點)代表虛數,從原點(0,0)到(a,b)的向量表示復數\(a+bi\)
模長:從原點(0,0)到(a,b)的距離,即\(\sqrt{a^2+b^2}\)
幅角:假設以逆時針為正方向,從x軸正半軸到已知向量的轉角的有向角叫做幅角
運算
1、加法
在復數平面,復數可以表示為向量,因為復數的加法和向量的加法相同。
2、乘法
幾何定義:復數相乘密,模長相乘,幅角相加
代數定義:
單位根
在復數平面上,以原點為圓心,1為半徑做圓,所得的圓叫做單位圓,以圓點為起點,圓的n等分為終點,做n個向量,設幅角為正且最小的向量對應的復數為w_n$,稱為n次單位根。
根據復數乘法的運算法則,其余n-1個復數為\(w_n^2,...,w_n^n\),注意\(w_n^0=w_n^n=1\)(對應復平面上以x軸為正方向的向量)
如何計算呢?
由歐拉公式解決\(w_n^k=cos(k*2\pi/n)+isin(k*2\pi/n)\)
例如:向量AB表示的是復數為4次單位根
n次單位根的幅角為周角的\(1/n\)。
在代數中,若\(z_n=1\),我們把z稱為n次單位根。具體請參考:n次單位根(n-th unit root)
單位根的性質
1、\(w_n^k=cos(k*2\pi/n)+isin(k*2\pi/n)\)
2、【相消引理】\(w_{dn}^{dk}=w_n^k\)
證明:以d=2為例
3、【折半引理】\(w_n^{k+n/2}=-w_n^k\)
證明:
4、\(w_n^0=w_n^k=1\)
5、\(w_n^{n-i} = 共軛(w_n^i)\)
6、\(w_n^{n+i}=w_n^i\)
多項式系數表示法
設\(A(x)\)表示一個d次多項式,則\(A(x)=a_1+a_2*x+...,+a_{d}*x^{d}\)
利用這種方法計算多項式卷積復雜度為\(O(d^2)\),其實就是直接對應相乘(暴力)。
例如:\(A(x)=1+2x+x^2\), \(B(x)=1-2x+x^2\)
多項式點值表示法
將n個值x帶入多項式,會得到d各不同的值y,則該多項式被這n個點值\((x_1,y_1),...,(x_d,y_x)\)唯一確定,其中\(\sum_{j=1}^{d}a_j*x^j_i\)
而利用點值法計算多項式卷積復雜度也為\(O(d^2)\)。(選點\(O(d)\),每次計算\(O(d)\))
例如上面的多項式用點值法表示:\(A(x)=[(-2,1),(-1,0),(0,1),(1,4),(2,9)],B(x)=[(-2,9),(-1,4),(0,1),(1,0),(2,1)]\),則
即有這個5個點就可以唯一確定一個4次多項式,而兩兩相乘的復雜度為\(O(d)\)
引理1:\(( d + 1 )\)個點值可以唯一確定一個d 階多項式
因此,我們可以將一個系數多項式轉換為一個點值多項式,然而進行復雜度為\(O(d)\)的乘法,再將結果的點值多項式恢復回系數多項式。
但是:如果我們采用下面這種矩陣形式計算點值的話【選點】,那么由系數轉為點值的復雜度也為\(O(d^2)\)。
接下來考慮對其優化:
1、對於系數表示法,每個點的系數都固定,優化困難
2、對於點值表示法,可以用FFT來解決!
DFT
已知\(A(x)\)的系數為\((a_0,a_1,...,a_{n-1})\),對於\(k=0,1,...,n-1\),定義:
其中向量\(y=(y_0,y_1,...,y_{n-1})\)是系數向量\(a=(a_0,a_1,...,a_{n-1})\)的離散傅立葉變換,記\(y=DFT_n(a)\),復雜度為\(O(n^2)\)
而使用下面的FFT方法,可以在\(O(nlog^n)\)時間內求出\(DFT_n(a)\)
FFT
用於加速系數多項式到點值多項式的運算!
首先觀察下面多項式:
例如:\(F(x)=x^2\),有對稱性\(F(-x)=F(x)\),相當於確定了一個點相當於確定兩個點。
同理又如\(F(x^3)\),有性質\(F(-x)=-F(x)\),也是確定了一個點相當於確定了兩個點。
所以對於有奇偶行的多項式,只需要找到原本一半的點就可以得到這個多現實了。
基於以上想法,假如有下面多項式:
把\(P_e\)和\(P_o\)分別看作兩個多項式,也就是對於一個點\(x_i\),我們只要計算出\(P_e(x_i^2)\)和\(P_o(x_i^2)\),就可以得到\(P(x_i)\)和\(P(-x_i)\),而且\(P_e\)和\(P_o\)還可以進一步拆分為奇偶兩部分!
假設原本我們需要n個點\(【\pm x_1,\pm x_12,...,\pm x_{n/2}】\)就能確定一個\(n-1\)階的多項式。現在變成了求\(P_e(x)\)和\(P_o(x)\)在\(x_1^2,x_2^2,...,x_{n/2}^2\)上面的點值【n/2個點】。
那如果這n/2個點兩兩之間滿足\(x_i^2=-x_j^2\),則就可以進一步拆分為一半了,就可以將原本的復雜度\(O(d^2)\)降為\(O(dlog^d)\)。這里可以看出FFT用到了分治思想。
問題是,$x_12,x_22,...,x_{n/2}^2並不滿足兩兩互為相反數。由此使用n次單位根,選用n個n次單文根
\([w^0,w^1,...,w^{n-1}]\)。
這樣,兩個點平方后依舊互為相反數!
可以看出,將以一個n個點的求值問題轉換為求n/2個點,在轉換為求n/4個點,以此迭代,從而達到\(O(dlog^d)\)。將上述思想轉換為為代碼如下:
FFT的逆
如何從點值多項式變為系數多項式呢?
對於點值計算,
實際上就是一個矩陣的乘法:
將點換為n個n次單位根,則矩陣變為:
其中中間的范德蒙德矩陣就成了一個DFT矩陣。
有了正向(系數到點值)的矩陣變換,求逆向(點值到系數)就是對上面矩陣求逆即可:
即:
從上面可以看出,FFT是將\(w\)作為點值傳入,IFFT就是將\({1/n}*w^{-1}\)作為點值傳入:
程序
下面程序用FFT計算兩個大數乘
題目:http://acm.hdu.edu.cn/showproblem.php?pid=1402
#include <iostream>
#include <string.h>
#include <stdio.h>
#include <math.h>
using namespace std;
const int N = 500005;
const double PI = acos(-1.0);
struct Virt
{
double r, i;
Virt(double r = 0.0,double i = 0.0)
{
this->r = r;
this->i = i;
}
Virt operator + (const Virt &x)
{
return Virt(r + x.r, i + x.i);
}
Virt operator - (const Virt &x)
{
return Virt(r - x.r, i - x.i);
}
Virt operator * (const Virt &x)
{
return Virt(r * x.r - i * x.i, i * x.r + r * x.i);
}
};
//雷德算法--倒位序
void Rader(Virt F[], int len)
{
int j = len >> 1;
for(int i=1; i<len-1; i++)
{
if(i < j) swap(F[i], F[j]);
int k = len >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k) j += k;
}
}
//FFT實現
void FFT(Virt F[], int len, int on)
{
Rader(F, len);
for(int h=2; h<=len; h<<=1) //分治后計算長度為h的DFT
{
Virt wn(cos(-on*2*PI/h), sin(-on*2*PI/h)); //單位復根e^(2*PI/m)用歐拉公式展開
for(int j=0; j<len; j+=h)
{
Virt w(1,0); //旋轉因子
for(int k=j; k<j+h/2; k++)
{
Virt u = F[k];
Virt t = w * F[k + h / 2];
F[k] = u + t; //蝴蝶合並操作
F[k + h / 2] = u - t;
w = w * wn; //更新旋轉因子
}
}
}
if(on == -1)
for(int i=0; i<len; i++)
F[i].r /= len;
}
//求卷積
void Conv(Virt a[],Virt b[],int len)
{
FFT(a,len,1);
FFT(b,len,1);
for(int i=0; i<len; i++)
a[i] = a[i]*b[i];
FFT(a,len,-1);
}
char str1[N],str2[N];
Virt va[N],vb[N];
int result[N];
int len;
void Init(char str1[],char str2[])
{
int len1 = strlen(str1);
int len2 = strlen(str2);
len = 1;
while(len < 2*len1 || len < 2*len2) len <<= 1;
int i;
for(i=0; i<len1; i++)
{
va[i].r = str1[len1-i-1] - '0';
va[i].i = 0.0;
}
while(i < len)
{
va[i].r = va[i].i = 0.0;
i++;
}
for(i=0; i<len2; i++)
{
vb[i].r = str2[len2-i-1] - '0';
vb[i].i = 0.0;
}
while(i < len)
{
vb[i].r = vb[i].i = 0.0;
i++;
}
}
void Work()
{
Conv(va,vb,len);
for(int i=0; i<len; i++)
result[i] = va[i].r+0.5;
}
void Export()
{
for(int i=0; i<len; i++)
{
result[i+1] += result[i]/10;
result[i] %= 10;
}
int high = 0;
for(int i=len-1; i>=0; i--)
{
if(result[i])
{
high = i;
break;
}
}
for(int i=high; i>=0; i--)
printf("%d",result[i]);
puts("");
}
int main()
{
while(~scanf("%s%s",str1,str2))
{
Init(str1,str2);
Work();
Export();
}
return 0;
}
NTT
在FFT中,我們需要用到復數,復數雖然很神奇,但是它也有自己的局限性——需要用double類型計算,精度太低,那有沒有什么東西能夠代替復數且解決精度問題呢?
這個東西,叫原根
階
若a,p互素,且p>1,對於\(a^n mod p =1\)滿足最小的n,叫做a模p的階,記\(\delta _p(a)\).
例如:
其中:
\(2^1 mod 7 =2\)
\(2^2 mod 7 =4\)
\(2^3 mod 7 =1\)
原根
設p是正整數,a是整數,若\(\delta _p(a)\)等於\(\phi(p)\),則a為模p的一個原根。
例如:
\(\delta _7(3)=6=\phi (7)\),所以3是模7的一個原根。
原根的個數不唯一
1、若模數p有原根,那么它一定有\(\phi(\phi(p))個原根\)
2、若p為素數,原根一定存在,假設g是P的一個原根,那么\(g^i mod p (1<g<p,0<i<p)\)的結果兩兩不同
簡單的說,就是
3、那如何求一個質數的原根呢?
對於指數p,\(p_i\)是p-1的因子,若\(g^{{p-1}/p_i} (mod p)\)恆成立,則g是p的原根。
下面就是為什么原根可以代替單位根計算?
因為原根具有和單位根相同的性質,FFT中,用到了單位根的四條性質,原根也滿足這四條性質:
最終可以得到:
然后只需將FFT中的\(w_n\)替換掉,就是NTT。即:
綜上,NTT的變換為:
這里P是素數且N必須是P-1的因子;由於N是2的方冪,所以可構造\(P=c.2^k+1\)的素數。
通常p取998244353,它的原根為3。
程序
使用NTT,計算兩個大數乘
#include <iostream>
#include <string.h>
#include <stdio.h>
#include <ctime>
using namespace std;
typedef long long LL;
const int N = 1 << 18;
const int P = (479 << 21) + 1;
const int G = 3;
const int NUM = 20;
LL wn[NUM];
LL a[N], b[N];
char A[N], B[N];
LL quick_mod(LL a, LL b, LL m)
{
LL ans = 1;
a %= m;
while(b)
{
if(b & 1)
{
ans = ans * a % m;
b--;
}
b >>= 1;
a = a * a % m;
}
return ans;
}
void GetWn()
{
for(int i = 0; i < NUM; i++)
{
int t = 1 << i;
wn[i] = quick_mod(G, (P - 1) / t, P);
}
}
void Prepare(char A[], char B[], LL a[], LL b[], int &len)
{
len = 1;
int L1 = strlen(A);
int L2 = strlen(B);
while(len <= 2 * L1 || len <= 2 * L2) len <<= 1;
for(int i = 0; i < len; i++)
{
if(i < L1) a[i] = A[L1 - i - 1] - '0';
else a[i] = 0;
if(i < L2) b[i] = B[L2 - i - 1] - '0';
else b[i] = 0;
}
}
void Rader(LL a[], int len)
{
int j = len >> 1;
for(int i = 1; i < len - 1; i++)
{
if(i < j) swap(a[i], a[j]);
int k = len >> 1;
while(j >= k)
{
j -= k;
k >>= 1;
}
if(j < k) j += k;
}
}
void NTT(LL a[], int len, int on)
{
Rader(a, len);
int id = 0;
for(int h = 2; h <= len; h <<= 1)
{
id++;
for(int j = 0; j < len; j += h)
{
LL w = 1;
for(int k = j; k < j + h / 2; k++)
{
LL u = a[k] % P;
LL t = w * a[k + h / 2] % P;
a[k] = (u + t) % P;
a[k + h / 2] = (u - t + P) % P;
w = w * wn[id] % P;
}
}
}
if(on == -1)
{
for(int i = 1; i < len / 2; i++)
swap(a[i], a[len - i]);
LL inv = quick_mod(len, P - 2, P);
for(int i = 0; i < len; i++)
a[i] = a[i] * inv % P;
}
}
void Conv(LL a[], LL b[], int n)
{
NTT(a, n, 1);
NTT(b, n, 1);
for(int i = 0; i < n; i++)
a[i] = a[i] * b[i] % P;
NTT(a, n, -1);
}
void Transfer(LL a[], int n)
{
int t = 0;
for(int i = 0; i < n; i++)
{
a[i] += t;
if(a[i] > 9)
{
t = a[i] / 10;
a[i] %= 10;
}
else t = 0;
}
}
void Print(LL a[], int n)
{
bool flag = 1;
for(int i = n - 1; i >= 0; i--)
{
if(a[i] != 0 && flag)
{
//使用putchar()速度快很多
putchar(a[i] + '0');
flag = 0;
}
else if(!flag)
putchar(a[i] + '0');
}
puts("");
}
int main()
{
GetWn();
//71992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625719926526229571992652622956228951755133337628984509631023267784474408034762818860068001094499324633747061026363216478453851391076257199265262295719926526229562289517551333376289845096310232677844744080347628188600680010944993246337470610263632164784538513910762571992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625719926526229571992652622956228951755133337628984509631023267784474408034762818860068001094499324633747061026363216478453851391076257199265262295719926526229562289517551333376289845096310232677844744080347628188600680010944993246337470610263632164784538513910762571992652622957199265262295622895175513333762898450963102326778447440803476281886006800109449932463374706102636321647845385139107625
while(scanf("%s %s", A, B) != EOF)
{
int len;
clock_t start_time = clock();//計時開始
Prepare(A, B, a, b, len);
Conv(a, b, len);
Transfer(a, len);
cout << "elapsed time:" << 1000*double(clock() - start_time) / CLOCKS_PER_SEC
<< 'ms' << endl;
Print(a, len);
}
return 0;
}
輸出:elapsed time:3.9328019
MTT
待學習!
參考
1、快速傅里葉變換(FFT)詳解
2、快速數論變換(NTT)小結
3、CKKS的Encoding(CKKS方案的編碼部分的筆記)
4、多項式乘法運算終極版
5、多項式乘法運算初級版