前置技能:AC自動機
假設我們有了一個AC自動機,然后在上面進行字符串匹配。
上面是一個有四個字符串的AC自動機(abcde、aacdf、cdf、cde),虛線是fail指針,實線是轉移。
這是上一次講AC自動機的時候的匹配代碼:
int match(char* s) { int cur=rot,ans=0; for(int i=0;s[i];i++) { int c=s[i]-'a'; cur=ch[cur][c]; for(int f=cur;f!=rot;f=fail[f]) ans+=cnt[f], cnt[f]=0; } return ans; }
出題人嘿嘿一笑,給了你一個“aaaaaaaaaaaaaaaaaaa”。這樣的字符串fail鏈長度為O(n)的,這就很尷尬了。
我們發現,如果我們把每個x與fail[x]連邊,好像形成了一個樹結構(fail[x]是x的父節點)。
我們就是要查詢一個點到根路徑上的cnt之和!我們只要dfs一下預處理出來就行了。
我們來分析一下這個樹結構是什么東西。
首先,這棵樹是在一個trie的基礎上產生的,所以這棵樹上的每個點都是一個字符串的前綴,而且每個字符串的每個前綴在這棵樹上都對應着一個點。
其次,由於fail指針,每個點父節點的字符串都是這個點字符串的后綴,並且樹上沒有更長的它的后綴。
例1 bzoj3172 單詞
給出n個字符串,詢問每個字符串在所有字符串中的出現次數之和。
在建ac自動機的時候,我們把經過的那條鏈的cnt值全部+1,那么我們就是要查詢子樹和。
為什么?例如有一個串abcde,匹配了123abcde123,那么AC自動機上123abcde就會在abcde的子樹中。
#include <iostream> #include <stdio.h> #include <math.h> #include <string.h> #include <time.h> #include <stdlib.h> #include <string> #include <bitset> #include <vector> #include <set> #include <map> #include <queue> #include <algorithm> #include <sstream> #include <stack> #include <iomanip> using namespace std; #define pb push_back #define mp make_pair typedef pair<int,int> pii; typedef long long ll; typedef double ld; typedef vector<int> vi; #define fi first #define se second #define fe first #define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);} #define Edg int M=0,fst[SZ],vb[SZ],nxt[SZ];void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;}void adde(int a,int b){ad_de(a,b);ad_de(b,a);} #define Edgc int M=0,fst[SZ],vb[SZ],nxt[SZ],vc[SZ];void ad_de(int a,int b,int c){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;vc[M]=c;}void adde(int a,int b,int c){ad_de(a,b,c);ad_de(b,a,c);} #define es(x,e) (int e=fst[x];e;e=nxt[e]) #define esb(x,e,b) (int e=fst[x],b=vb[e];e;e=nxt[e],b=vb[e]) #define VIZ {printf("digraph G{\n"); for(int i=1;i<=n;i++) for es(i,e) printf("%d->%d;\n",i,vb[e]); puts("}");} #define VIZ2 {printf("graph G{\n"); for(int i=1;i<=n;i++) for es(i,e) if(vb[e]>=i)printf("%d--%d;\n",i,vb[e]); puts("}");} using namespace std; #define SZ 1000099 int rot=1,ch[SZ][27],fail[SZ],e=1; ll cnt[SZ]; int M=0,fst[SZ],vb[SZ+SZ],nxt[SZ+SZ]; void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;} void adde(int a,int b){ad_de(a,b);ad_de(b,a);} int insert(char* s) { int cur=rot; for(int i=0;s[i];i++) { int c=s[i]-'a'; if(!ch[cur][c]) ch[cur][c]=++e; cur=ch[cur][c]; ++cnt[cur]; } return cur; } int qs[SZ],h=0,t=0; void bfail() { h=t=0; fail[rot]=rot; for(int i=0;i<26;i++) { if(!ch[rot][i]) { ch[rot][i]=rot; continue; } fail[ch[rot][i]]=rot; qs[t++]=ch[rot][i]; } while(h!=t) { int cur=qs[h++]; for(int c=0;c<26;c++) { if(!ch[cur][c]) ch[cur][c]=ch[fail[cur]][c]; else { fail[ch[cur][c]]=ch[fail[cur]][c]; qs[t++]=ch[cur][c]; } } } } int n,T,orz[SZ]; char str[SZ]; void dfs(int x,int f=0) { for esb(x,e,b) { if(b==f) continue; dfs(b,x); cnt[x]+=cnt[b]; } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) { scanf("%s",str); orz[i]=insert(str); } bfail(); for(int i=2;i<=e;i++) adde(i,fail[i]); dfs(1); for(int i=1;i<=n;i++) printf("%lld\n",cnt[orz[i]]); }
例2 bzoj2434 阿狸的打字機
有一個打字機,上面有一個可以容納字符串的凹槽。
打字機上有26個英文字母和'B'、'P'。輸入字母,打字機的一個凹槽中會加入這個字母。按下'B',打字機凹槽中會刪掉最后一個字母。按下'P',打字機會在紙上打印出凹槽中現有的所有字母並換行。按一下印有'B'的按鍵,打字機凹槽中最后一個字母會消失。
有m次詢問,每次詢問第x個打印的字符串在第y個打印的字符串中出現了多少次。
我們可以發現B就相當於回到trie上的父節點,P就相當於記錄一下這個節點編號。
我們建出AC自動機和fail樹。“第x個打印的字符串在第y個打印的字符串中出現了多少次”,如果樹中只有第y個字符串那就是要統計x字符串這個點的子樹和。
那么我們可以重新進行一次建樹一樣的操作,等到當前字符串為“第y個字符串”時再處理詢問。我們只要在fail樹上用dfs序+樹狀數組維護就行了。
#include <iostream> #include <stdio.h> #include <math.h> #include <string.h> #include <time.h> #include <stdlib.h> #include <string> #include <bitset> #include <vector> #include <set> #include <map> #include <queue> #include <algorithm> #include <sstream> #include <stack> #include <iomanip> using namespace std; #define pb push_back #define mp make_pair typedef pair<int,int> pii; typedef long long ll; typedef double ld; typedef vector<int> vi; #define fi first #define se second #define fe first #define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);} #define Edg int M=0,fst[SZ],vb[SZ],nxt[SZ];void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;}void adde(int a,int b){ad_de(a,b);ad_de(b,a);} #define Edgc int M=0,fst[SZ],vb[SZ],nxt[SZ],vc[SZ];void ad_de(int a,int b,int c){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;vc[M]=c;}void adde(int a,int b,int c){ad_de(a,b,c);ad_de(b,a,c);} #define es(x,e) (int e=fst[x];e;e=nxt[e]) #define esb(x,e,b) (int e=fst[x],b=vb[e];e;e=nxt[e],b=vb[e]) #define VIZ {printf("digraph G{\n"); for(int i=1;i<=n;i++) for es(i,e) printf("%d->%d;\n",i,vb[e]); puts("}");} #define VIZ2 {printf("graph G{\n"); for(int i=1;i<=n;i++) for es(i,e) if(vb[e]>=i)printf("%d--%d;\n",i,vb[e]); puts("}");} using namespace std; #define SZ 233333 int rot=1,ch[SZ][27],fail[SZ],fa[SZ],e=1; int M=0,fst[SZ],vb[SZ+SZ],nxt[SZ+SZ]; void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;} void adde(int a,int b){ad_de(a,b);ad_de(b,a);} char str[SZ]; int al(int& s) {if(!s) s=++e; return s;} int n,rp[SZ]; void pre(char* s) { int cur=rot; for(int i=0;s[i];i++) { if(s[i]=='B') cur=fa[cur]; else if(s[i]=='P') rp[++n]=cur; else { char c=s[i]-'a'; int nx=al(ch[cur][c]); fa[nx]=cur; cur=nx; } } } int qs[SZ],h=0,t=0; void bfail() { h=t=0; fail[rot]=rot; for(int i=0;i<26;i++) { if(!ch[rot][i]) { ch[rot][i]=rot; continue; } fail[ch[rot][i]]=rot; qs[t++]=ch[rot][i]; } while(h!=t) { int cur=qs[h++]; for(int c=0;c<26;c++) { if(!ch[cur][c]) ch[cur][c]=ch[fail[cur]][c]; else { fail[ch[cur][c]]=ch[fail[cur]][c]; qs[t++]=ch[cur][c]; } } } for(int i=2;i<=e;i++) adde(i,fail[i]); } int dfn[SZ],D=0,ls[SZ]; void dfs(int x,int f=0) { dfn[x]=++D; for esb(x,e,b) { if(b==f) continue; dfs(b,x); } ls[x]=D; } int ss[SZ]; int sum(int x) { int ans=0; for(;x>=1;x-=x&-x) ans+=ss[x]; return ans; } void edt(int x,int y) { for(;x<=D;x+=x&-x) ss[x]+=y; } int m,qa[SZ],qb[SZ],nq[SZ],fq[SZ],anss[SZ]; int main() { scanf("%s",str); pre(str); bfail(); scanf("%d",&m); for(int i=1;i<=m;i++) { scanf("%d%d",qa+i,qb+i); nq[i]=fq[qb[i]]; fq[qb[i]]=i; } dfs(rot); int cur=rot,ci=0; for(int i=0;str[i];i++) { if(str[i]=='B') edt(dfn[cur],-1), cur=fa[cur]; else if(str[i]=='P') { ++ci; for(int q=fq[ci];q;q=nq[q]) anss[q]=sum(ls[rp[qa[q]]])-sum(dfn[rp[qa[q]]]-1); } else { char c=str[i]-'a'; int nx=ch[cur][c]; fa[nx]=cur; cur=nx; edt(dfn[cur],1); } } for(int i=1;i<=m;i++) printf("%d\n",anss[i]); }
例3 bzoj3881 Divljak
Alice有n個字符串s1...sn,Bob有一個字符串集合,一開始集合是空的。
若干個操作,每個操作是往集合里添加一個字符串或者給定x,查詢集合中有多少個字符串包含sx。
注意這里要求的是有多少個包含,而不是出現了幾次。
例如ababab和ab,如果算子樹和的話就會被多算一次。
我們考慮上一題我們實際上干的事情是對於每個前綴對應的點,把根到這個點的路徑全部+1。
那為了不重復統計,我們只要保證沒有點被多次+1就行了。
上一題我們轉化為了子樹和,那么這一題轉化為子樹和只要在每個重復的lca處-1就行了,具體實現就類似虛樹那樣用一個棧來維護。
好像把虛樹的板子抄下來就過了
#include <iostream> #include <stdio.h> #include <math.h> #include <string.h> #include <time.h> #include <stdlib.h> #include <string> #include <bitset> #include <vector> #include <set> #include <map> #include <queue> #include <algorithm> #include <sstream> #include <stack> #include <iomanip> using namespace std; #define pb push_back #define mp make_pair typedef pair<int,int> pii; typedef long long ll; typedef double ld; typedef vector<int> vi; #define fi first #define se second #define fe first #define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);} #define Edg int M=0,fst[SZ],vb[SZ],nxt[SZ];void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;}void adde(int a,int b){ad_de(a,b);ad_de(b,a);} #define Edgc int M=0,fst[SZ],vb[SZ],nxt[SZ],vc[SZ];void ad_de(int a,int b,int c){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;vc[M]=c;}void adde(int a,int b,int c){ad_de(a,b,c);ad_de(b,a,c);} #define es(x,e) (int e=fst[x];e;e=nxt[e]) #define esb(x,e,b) (int e=fst[x],b=vb[e];e;e=nxt[e],b=vb[e]) #define VIZ {printf("digraph G{\n"); for(int i=1;i<=n;i++) for es(i,e) printf("%d->%d;\n",i,vb[e]); puts("}");} #define VIZ2 {printf("graph G{\n"); for(int i=1;i<=n;i++) for es(i,e) if(vb[e]>=i)printf("%d--%d;\n",i,vb[e]); puts("}");} using namespace std; #define SZ 2000099 int rot=1,ch[SZ][27],fail[SZ],e=1; int M=0,fst[SZ],vb[SZ+SZ],nxt[SZ+SZ]; void ad_de(int a,int b){++M;nxt[M]=fst[a];fst[a]=M;vb[M]=b;} void adde(int a,int b){ad_de(a,b);ad_de(b,a);} int insert(char* s) { int cur=rot; for(int i=0;s[i];i++) { int c=s[i]-'a'; if(!ch[cur][c]) ch[cur][c]=++e; cur=ch[cur][c]; } return cur; } int qs[SZ],h=0,t=0; void bfail() { h=t=0; fail[rot]=rot; for(int i=0;i<26;i++) { if(!ch[rot][i]) { ch[rot][i]=rot; continue; } fail[ch[rot][i]]=rot; qs[t++]=ch[rot][i]; } while(h!=t) { int cur=qs[h++]; for(int c=0;c<26;c++) { if(!ch[cur][c]) ch[cur][c]=ch[fail[cur]][c]; else { fail[ch[cur][c]]=ch[fail[cur]][c]; qs[t++]=ch[cur][c]; } } } } #define S 22 int n,T,up[SZ][S],dep[SZ],orz[SZ]; int dfn[SZ],ls[SZ],D=0; char str[SZ]; void dfs(int x,int f=0) { dfn[x]=++D; up[x][0]=f; for(int i=1;i<=S-1;i++) up[x][i]=up[up[x][i-1]][i-1]; for esb(x,e,b) { if(b==f) continue; dep[b]=dep[x]+1; dfs(b,x); } ls[x]=D; } int jump(int x,int d) { for(int i=S-1;i>=0;i--) { if(up[x][i]&&dep[up[x][i]]>=d) x=up[x][i]; } return x; } int lca(int a,int b) { if(dep[a]>dep[b])swap(a,b); //dep[a]<=dep[b] b=jump(b,dep[a]); if(a==b) return a; for(int i=S-1;i>=0;i--) { if(up[a][i]==up[b][i]) continue; a=up[a][i]; b=up[b][i]; } return up[a][0]; } int bs[SZ]; int sum(int x) { int ans=0; for(;x>=1;x-=x&-x) ans+=bs[x]; return ans; } int sum(int l,int r) {return sum(r)-sum(l-1);} void edt(int x,int y) { for(;x<=D;x+=x&-x) bs[x]+=y; } int ss[SZ],sn=0,vfa[SZ]; bool cmpdfn(int a,int b) { if(a!=b) return dfn[a]<dfn[b]; return a<b; } int st[SZ],stn=0; int vs[SZ],vn=0; void inss(char* s) { int l=strlen(s),cur=1; sn=stn=vn=0; ss[++sn]=1; for(int i=0;i<l;i++) { cur=ch[cur][s[i]-'a']; ss[++sn]=cur; } sort(ss+1,ss+1+sn,cmpdfn); sn=unique(ss+1,ss+1+sn)-ss-1; for(int i=1;i<=sn;i++) vs[++vn]=ss[i]; for(int i=1;i<=sn;i++) { int x=ss[i]; if(!stn) {st[++stn]=x; vfa[x]=0; continue;} int lc=lca(x,st[stn]); while(stn&&dep[st[stn]]>dep[lc]) { if(dep[st[stn-1]]<=dep[lc]) vfa[st[stn]]=lc; --stn; } if(st[stn]!=lc) { vs[++vn]=lc; vfa[lc]=st[stn]; st[++stn]=lc; } vfa[x]=lc; st[++stn]=x; } for(int i=1;i<=vn;i++) { int v=vs[i]; if(!vfa[v]) continue; edt(dfn[v],1); edt(dfn[vfa[v]],-1); } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++) { scanf("%s",str); orz[i]=insert(str); } bfail(); for(int i=2;i<=e;i++) adde(i,fail[i]); dfs(1); int q; scanf("%d",&q); while(q--) { int g,x; scanf("%d",&g); if(g==1) scanf("%s",str), inss(str); else scanf("%d",&x), printf("%d\n",sum(dfn[orz[x]],ls[orz[x]])); } }