快速數論變換(NTT)


轉自ACdreamers (http://blog.csdn.net/acdreamers/article/details/39026505

 

 

在上一篇文章中 http://blog.csdn.net/acdreamers/article/details/39005227 介紹了用快速傅里葉變

換來求多項式的乘法。可以發現它是利用了單位復根的特殊性質,大大減少了運算,但是這種做法是對復數系數的矩陣

加以處理,每個復數系數的實部和虛部是一個正弦及余弦函數,因此大部分系數都是浮點數,我們必須做復數及浮點數

的計算,計算量會比較大,而且浮點數的計算可能會導致誤差增大。

 

今天,我將來介紹另一種計算多項式乘法的算法,叫做快速數論變換(NTT),在離散正交變換的理論中,已經證明在

復數域內,具有循環卷積特性的唯一變換是DFT,所以在復數域中不存在具有循環卷積性質的更簡單的離散正交變換。

因此提出了以數論為基礎的具有循環卷積性質的快速數論變換

 

回憶復數向量,其離散傅里葉變換公式如下

 

   

 

離散傅里葉逆變換公式為

 

   

 

今天的快速數論變換(NTT)是在上進行的,在快速傅里葉變換(FFT)中,通過次單位復根來運算的,即滿

,而對於快速數論變換來說,則是可以將看成是的等價,這里是模素數

的原根(由於是素數,那么原根一定存在)。即

 

        

 

所以綜上,我們得到數論變換的公式如下

 

    

 

數論變換的逆變換公式為

 

    

 

這樣就把復數對應到一個整數,之后一切都是在系統內考慮。

 

上述數論變換(NTT)公式中,要求是素數且必須是的因子。由於經常是2的方冪,所以可以構造形

的素數。通常來說可以選擇費馬素數,這樣的變換叫做費馬數數論變換

 

這里我們選擇,這樣得到模的原根值為

 

 

題目:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1028

 

分析:題目意思就是大數相乘,此處用快速數論變換(NTT)實現。

 

  1 #include <iostream>  
  2 #include <string.h>  
  3 #include <stdio.h>  
  4   
  5 using namespace std;  
  6 typedef long long LL;  
  7   
  8 const int N = 1 << 18;  
  9 const int P = (479 << 21) + 1;  
 10 const int G = 3;  
 11 const int NUM = 20;  
 12   
 13 LL  wn[NUM];  
 14 LL  a[N], b[N];  
 15 char A[N], B[N];  
 16   
 17 LL quick_mod(LL a, LL b, LL m)  
 18 {  
 19     LL ans = 1;  
 20     a %= m;  
 21     while(b)  
 22     {  
 23         if(b & 1)  
 24         {  
 25             ans = ans * a % m;  
 26             b--;  
 27         }  
 28         b >>= 1;  
 29         a = a * a % m;  
 30     }  
 31     return ans;  
 32 }  
 33   
 34 void GetWn()  
 35 {  
 36     for(int i=0; i<NUM; i++)  
 37     {  
 38         int t = 1 << i;  
 39         wn[i] = quick_mod(G, (P - 1) / t, P);  
 40     }  
 41 }  
 42   
 43 void Prepare(char A[], char B[], LL a[], LL b[], int &len)  
 44 {  
 45     len = 1;  
 46     int len_A = strlen(A);  
 47     int len_B = strlen(B);  
 48     while(len <= 2 * len_A || len <= 2 * len_B) len <<= 1;  
 49     for(int i=0; i<len_A; i++)  
 50         A[len - 1 - i] = A[len_A - 1 - i];  
 51     for(int i=0; i<len - len_A; i++)  
 52         A[i] = '0';  
 53     for(int i=0; i<len_B; i++)  
 54         B[len - 1 - i] = B[len_B - 1 - i];  
 55     for(int i=0; i<len - len_B; i++)  
 56         B[i] = '0';  
 57     for(int i=0; i<len; i++)  
 58         a[len - 1 - i] = A[i] - '0';  
 59     for(int i=0; i<len; i++)  
 60         b[len - 1 - i] = B[i] - '0';  
 61 }  
 62   
 63 void Rader(LL a[], int len)  
 64 {  
 65     int j = len >> 1;  
 66     for(int i=1; i<len-1; i++)  
 67     {  
 68         if(i < j) swap(a[i], a[j]);  
 69         int k = len >> 1;  
 70         while(j >= k)  
 71         {  
 72             j -= k;  
 73             k >>= 1;  
 74         }  
 75         if(j < k) j += k;  
 76     }  
 77 }  
 78   
 79 void NTT(LL a[], int len, int on)  
 80 {  
 81     Rader(a, len);  
 82     int id = 0;  
 83     for(int h = 2; h <= len; h <<= 1)  
 84     {  
 85         id++;  
 86         for(int j = 0; j < len; j += h)  
 87         {  
 88             LL w = 1;  
 89             for(int k = j; k < j + h / 2; k++)  
 90             {  
 91                 LL u = a[k] % P;  
 92                 LL t = w * (a[k + h / 2] % P) % P;  
 93                 a[k] = (u + t) % P;  
 94                 a[k + h / 2] = ((u - t) % P + P) % P;  
 95                 w = w * wn[id] % P;  
 96             }  
 97         }  
 98     }  
 99     if(on == -1)  
100     {  
101         for(int i = 1; i < len / 2; i++)  
102             swap(a[i], a[len - i]);  
103         LL Inv = quick_mod(len, P - 2, P);  
104         for(int i = 0; i < len; i++)  
105             a[i] = a[i] % P * Inv % P;  
106     }  
107 }  
108   
109 void Conv(LL a[], LL b[], int n)  
110 {  
111     NTT(a, n, 1);  
112     NTT(b, n, 1);  
113     for(int i = 0; i < n; i++)  
114         a[i] = a[i] * b[i] % P;  
115     NTT(a, n, -1);  
116 }  
117   
118 void Transfer(LL a[], int n)  
119 {  
120     int t = 0;  
121     for(int i = 0; i < n; i++)  
122     {  
123         a[i] += t;  
124         if(a[i] > 9)  
125         {  
126             t = a[i] / 10;  
127             a[i] %= 10;  
128         }  
129         else t = 0;  
130     }  
131 }  
132   
133 void Print(LL a[], int n)  
134 {  
135     bool flag = 1;  
136     for(int i = n - 1; i >= 0; i--)  
137     {  
138         if(a[i] != 0 && flag)  
139         {  
140             printf("%d", a[i]);  
141             flag = 0;  
142         }  
143         else if(!flag)  
144             printf("%d", a[i]);  
145     }  
146     puts("");  
147 }  
148   
149 int main()  
150 {  
151     GetWn();  
152     while(scanf("%s%s", A, B)!=EOF)  
153     {  
154         int len;  
155         Prepare(A, B, a, b, len);  
156         Conv(a, b, len);  
157         Transfer(a, len);  
158         Print(a, len);  
159     }  
160     return 0;  
161 }  

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM