3473: 字符串
Time Limit: 20 Sec Memory Limit: 256 MBSubmit: 354 Solved: 160
[Submit][Status][Discuss]
Description
給定n個字符串,詢問每個字符串有多少子串(不包括空串)是所有n個字符串中至少k個字符串的子串?
Input
第一行兩個整數n,k。
接下來n行每行一個字符串。
Output
一行n個整數,第i個整數表示第i個字符串的答案。
字符串總長度L
n,k,L<=1e5
研究了兩節多課廣義后綴自動機是什么,還看了2015國家隊論文,然后發現,廣義后綴自動機不就是把很多串的SAM建到了一個SAM上,建每個串的時候都從root開始(last=root)就行了........
廣義后綴自動機是Trie樹的后綴自動機,可以解決多主串問題
這樣的在線構造算法復雜度為O(G(T)),G(T)為Trie樹上所有葉子節點深度和,發現G(T)<=所有主串總長度
還有一種離線算法,復雜度O(|T||A|) ,不學了吧
對於本題,建出廣義SAM后,只要得到每個狀態出現在不同串中的次數就好做了
我們跑每個子串,然后更新狀態
狀態維護cou和cur分別為出現次數及上一次出現是哪個串,然后就可以不重復的統計啦
出現次數向父親傳遞,所以要沿着Parent向上跑更新,遇到cur=當前串的就不用繼續跑了,這樣最壞情況下復雜度為O(L^3/2),發生在n=L的時候(均值不等式啊)
剩下的只要DP出f[i]為i及其Parent祖先出現次數>=k有多少字符串(注意一個狀態貢獻的字符串為t[par].val-t[u].val),然后在跑一遍每個字符串得到答案就行了
注意sz也要=1啊啊啊啊啊啊啊啊啊再讓你作死寫新模板
#include <iostream> #include <cstdio> #include <algorithm> #include <cstring> #include <cmath> #include <string> using namespace std; const int N=2e5+5; typedef long long ll; int n,k; string s[N>>1]; char ss[N>>1]; struct node{ int ch[26],par,val; int cou,cur; }t[N]; int sz=1,root=1,last=1; void extend(int c){ int p=last,np=++sz; t[np].val=t[p].val+1; for(;p&&!t[p].ch[c];p=t[p].par) t[p].ch[c]=np; if(!p) t[np].par=root; else{ int q=t[p].ch[c]; if(t[q].val==t[p].val+1) t[np].par=q; else{ int nq=++sz; t[nq]=t[q];t[nq].val=t[p].val+1; t[q].par=t[np].par=nq; for(;p&&t[p].ch[c]==q;p=t[p].par) t[p].ch[c]=nq; } } last=np; } int c[N],a[N]; ll f[N]; void RadixSort(){ for(int i=1;i<=sz;i++) c[t[i].val]++; for(int i=1;i<=sz;i++) c[i]+=c[i-1]; for(int i=sz;i>=1;i--) a[c[t[i].val]--]=i; } void solve(){ int u;ll ans; for(int i=1;i<=n;i++){//printf("i %d\n",i); u=root; for(int j=0;j<s[i].size();j++){ u=t[u].ch[s[i][j]-'a'];//printf("u %d %d %d\n",u,t[u].cou,j); int p=u; for(;p&&t[p].cur!=i;p=t[p].par) t[p].cou++,t[p].cur=i; } } RadixSort(); for(int i=1;i<=sz;i++) u=a[i]; t[1].cou=0; for(int i=1;i<=sz;i++) u=a[i],f[u]=f[t[u].par]+(t[u].cou>=k?t[u].val-t[t[u].par].val:0); for(int i=1;i<=n;i++){ u=root;ans=0; for(int j=0;j<s[i].size();j++){ u=t[u].ch[s[i][j]-'a']; ans+=f[u]; } printf("%lld ",ans); } } int main(){ freopen("in","r",stdin); scanf("%d%d",&n,&k); for(int i=1;i<=n;i++){ scanf("%s",ss),s[i]=string(ss); last=root; for(int j=0;j<s[i].size();j++) extend(s[i][j]-'a'); } solve(); }