manacher模板
今天考了個回文的題,於是在520巨佬的指導下學習了一波manacher.先推薦一波520大佬的博客
題目描述
給出一個只由小寫英文字符a,b,c...y,z組成的字符串S,求S中最長回文串的長度.
字符串長度為n
輸入輸出格式
輸入格式:
一行小寫英文字符a,b,c...y,z組成的字符串S
輸出格式:
一個整數表示答案
輸入輸出樣例
輸入樣例#1:
aaa
輸出樣例#1:
3
說明
字符串長度len <= 11000000
題意就是要求一段字符串中的最長回文串.
首先想一下朴素算法是如何實現的:我們可以枚舉一個字串的起點終點,然后用\(O(len)\)的時間驗證,這樣的總時間復雜度是\(O(n^3)\)的.
根據回文的性質,我們可以不用又枚舉起點又枚舉終點,因為回文是關於它的對稱軸對稱的,所以我們可以考慮直接枚舉它的對稱軸,然后向兩邊擴展.這樣的總時間復雜度是\(O(n^2)\)的.
因為一個字符串最多只有\(len\)個對稱軸,所以對稱軸相同的那些回文串都可以由對稱軸這一點拓展而來. 但是因為對稱軸有可能是在兩個字符的中間(也就是這個回文串的長度是偶數),這樣會使得接下來的操作很不方便,所以我們將原字符串中每兩個字符的中間都插入一個同一個特殊符號方便判斷,比如'#'什么的.打個栗子:
Brave_Cattle -> #B#r#a#v#e#_#C#a#t#t#l#e#
比如這個兩個t,顯然它們的對稱軸是在這兩個字符中間的,那么這個插入的特殊符號就給我們判斷省了很多事.
為了求出長度最長的回文串,顯然我們要求出回文串半徑最大的那一個.
這時候我們需要一個數組\(p[i]\)來記錄下從下標\(i\)開始最多能拓展的回文半徑. 因為在每兩個字符中間都插入了一個特殊字符所以記錄下所有\(p[i]\)之后最長的字串的長度就是最大的\(p[i]-1\).
那么問題的重點就來了:如何快速求出\(p[i]\)數組. 我們可以發現,如果有一個回文串是一個長回文串的子串的話,那么這個回文串的長度可以直接由之前記錄的對稱軸另一邊的那個對稱的字串推出(這個一定要看圖理解).
我們用\(id\)表示目前作為中間回文來推其他回文串的半徑的那個回文的對稱軸的下標(也就是表示的從下標 \(mx\)關於\(id\)的對稱點 到下標\(mx\) 的那個回文串,\(mx\)表示\(p[id]\),也就是該回文串的最右邊的下標).在圖中就是最下面這根直線所代表的區間,我們叫它\(A\)區間.
此時有一個以\(i\)為對稱軸的回文字串,也就是圖中\(i\)下面那條線所代表的區間范圍,我們叫它\(I\)區間.根據數學知識,我們可以得到\(i\)關於\(id\)的對稱點\(j=id*2-i\)(如果不知道就自己畫根數軸模擬一下吧),以這個對稱軸得到的回文串我們叫它\(J\)區間.因為\(I,J\)區間都是屬於\(A\)區間的回文串,而且他們關於\(id\)對稱,所以這兩個區間的半徑的長度是一樣的.
if(i < mx) p[i] = p[id*2-i];
但是我們還需要考慮一種情況: 如果\(I\)區間的右端點超過了\(A\)區間,那么此時\(I\)區間的半徑是取不到和\(J\)區間一樣大的半徑的,所以我們需要判斷是否會發生這種情況.如下圖:
if(i < mx) p[i] = min(p[pos*2-i],mx-i);
當然這樣得到了\(p[i]\)還需要再往后擴展一下,因為有可能后面還存在可以擴展的情況.就判斷一下是否處理后的字符串的第\(i-p[i]\)位和第\(i+p[i]\)位是否相同.
最后我們還需要在循環的時候更新一下作為長串來處理的那個回文串.我們要選最遠的那個,因為這樣就可以讓循環的次數減少.
下面看一下終點的代碼注釋吧:
void manacher(){
int pos = 0, mx = 0, ans = 0;
for(int i=1;i<=cnt;i++){
if(i < mx) p[i] = min(p[pos*2-i],mx-i);//處理
else p[i] = 1;//否則以i為對稱軸的這個串就屬於長串的范圍外,無法直接得到p[i]值.
while(ss[i+p[i]] == ss[i-p[i]]) p[i]++;//還要再擴展一次
if(mx < i+p[i]) mx = i+p[i], pos = i;//更新作為長串的回文串
ans = max(ans,p[i]-1);
}
}
大概內容就講的差不多了...如果不懂的話可能只能再多自己出一點數據模擬一下這個過程了.
下面貼一下完整代碼吧
#include<bits/stdc++.h>
using namespace std;
const int N=11000000+5;
const int inf=2147483647;
int cnt, len, ans = 0;
char s[N], ss[N*2];
int p[N*2];
void init(){//將每兩個字符中插入一個字符
len = strlen(s), cnt = 1;
ss[0] = '!'; ss[cnt] = '#';
for(int i=0;i<len;i++)
ss[++cnt] = s[i], ss[++cnt] = '#';
}
void manacher(){
int pos = 0, mx = 0;
for(int i=1;i<=cnt;i++){
if(i < mx) p[i] = min(p[pos*2-i],mx-i);
else p[i] = 1;
while(ss[i+p[i]] == ss[i-p[i]]) p[i]++;
if(mx < i+p[i]) mx = i+p[i], pos = i;
ans = max(ans,p[i]-1);
}
}
int main(){
scanf("%s",s);
init(); manacher();
printf("%d\n",ans);
return 0;
}