定義
后綴平衡樹,就是動態的維護后綴數組,可以 \(O(\log n)\) 在末尾插入字符,\(O(\log n)\) 查詢 \(rank,SA\)。但是由於是維護的后綴信息,所以插入只能在末尾插入字符(然后轉化成在開頭加一個字符),相當於添加一個后綴。
在線構造
方法一:
我們需要一種能比較兩個后綴大小的方法,最簡單就是二分+Hash,\(O(\log n)\) 實現,再加上平衡樹插入復雜度,總復雜度 \(O(\log^2 n)\)。
方法二:
考慮另一種比較方法,因為每次只添加一個字符,也就是說如果把第一個字符刪掉,那剩下的字符串在之前已經插入過后綴平衡樹中了,我們只需要先比較一下兩個后綴的第一個字符,后面字符串的比較直接調用之前信息就好。
那怎么快速比較后綴平衡樹中兩個后綴大小呢?我們給每個點一個權值區間 \([l,r]\),定義這個點的權值 \(tag_i\) 為 \(mid=\frac{l+r}2\)。那它左子樹對應的區間就是 \([l,mid-1]\),右子樹對應的區間就是 \([mid+1,r]\)。發現如果按照中序遍歷的順序遍歷整顆平衡樹,那每個點的權值是單調遞增的。
可是這是平衡樹誒。如果是那種基於旋轉重構的平衡樹那豈不是每次旋轉都要重構一遍子樹內的權值? \(emmm\) 確實是這樣,所以要用到一種更高級的平衡樹---重量平衡樹。重量平衡樹就是要保證平衡不能是均攤平衡,然后要么沒有旋轉,要么旋轉影響的子樹大小是期望\(\log\)或者均攤\(\log\)。\(Treap\) 和替罪羊樹都滿足這個條件。所以我們直接拿 \(Treap\) 維護就好了。
代碼實現
嗯以上就是基本概念了,然后講一下怎么維護好我們需要的 \(rank,sa,height\) 數組。
再次強調這里把每次往末尾插入一個元素轉化成每次往開頭插入一個元素
回顧最開始說的比較方法,如果兩個字符串首字符不一樣那么直接比較,否則比較去掉首字符后兩個字符串的 \(tag\) 值。代碼長這樣:
bool cmp(int x,int y){ // 比較第x個插入的后綴和第y個插入的后綴哪個字典序小 x<y返回1
return s[x]<s[y] or s[x]==s[y] and tag[x-1]<tag[y-1];
}
如果需要往平衡樹里插入字符c,設當前要插入的元素是第 \(tot\) 個元素,即 \(s[tot]=c\)。先把 \(tot\) 扔進去,找到 \(tot\) 的前驅后繼,也就是它在 \(sa\) 數組上的前驅后繼 \(pre,nxt\)。因為要維護好 \(height\),之前的 \(height[nxt]=lcp(pre,nxt)\) ,如果要往中間插入一個 \(tot\) 的話,那就需要讓 \(height[tot]=lcp(pre,tot),height[nxt]=lcp(tot,nxt)\)。這樣就維護好了 \(height\) 數組。
\(sa\) 和 \(rank\) 比較簡單,一個是找第 \(k\) 大,一個是找排名為 \(k\),都是平衡樹的基本操作了。
還有就是如果要刪除怎么辦。同樣找到 $pre,nxt $ ,令 \(height[nxt]=lcp(pre,nxt)\) 即可。然后在平衡樹上刪除點 \(tot\) 的時候也要注意一下,如果當前已經找到了 \(tot\),那就不用像普通的 \(Treap\) 一樣旋轉到葉子結點再刪除,因為每次旋轉的時候都需要遍歷整個子樹然后重構權值,所以這里直接像 \(fhq\_Treap\) 一樣把 \(tot\) 的兩個孩子 \(merge\) 起來,最后遍歷一遍就好了。刪除代碼:
void remove(int &x,int l,int r){
if(x==tot){
x=merge(ch[x][0],ch[x][1]);
dfs(x,l,r);return;
} else{
sze[x]--;
int mid=l+r>>1;
if(cmp(x,tot)) remove(ch[x][1],mid+1,r);
else remove(ch[x][0],l,mid-1);
}
}
同時因為我們需要二分+哈希維護 \(height\),所以還需要動態維護哈希值。
別的就沒啥了。
也不知道有啥用。
Code
一道例題 要求往字符串末尾插入一個字符,撤銷一個插入,詢問當前字符串本質不同子串個數。
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
#define int long long
const int N=1e5+5;
const int inf=1e18;
const int base=9973;
char s[N];
int lcp[N],ans,prio[N];
int root,tag[N],ch[N][2];
int n,tot,sze[N];;ull hsh[N],pw[N];
int getint(){
int X=0,w=0;char ch=getchar();
while(!isdigit(ch))w|=ch=='-',ch=getchar();
while( isdigit(ch))X=X*10+ch-48,ch=getchar();
if(w) return -X;return X;
}
bool cmp(int x,int y){
return s[x]<s[y] or s[x]==s[y] and tag[x-1]<tag[y-1];
}
void dfs(int x,int l,int r){
if(!x) return;
int mid=l+r>>1;
tag[x]=mid;
dfs(ch[x][0],l,mid-1),dfs(ch[x][1],mid+1,r);
sze[x]=sze[ch[x][0]]+sze[ch[x][1]]+1;
}
void rotate(int &x,int d,int l,int r){
int y=ch[x][d],z=ch[y][d^1];
ch[y][d^1]=x;ch[x][d]=z;
x=y;dfs(x,l,r);
}
void insert(int &x,int l,int r){
if(!x) {
x=tot;tag[x]=l+r>>1;
sze[x]=1;ch[x][0]=ch[x][1]=0;
prio[x]=rand();return;
} int d=cmp(x,tot),mid=l+r>>1;
sze[x]++;
if(d) insert(ch[x][d],mid+1,r);
else insert(ch[x][d],l,mid-1);
if(prio[ch[x][d]]<prio[x]) rotate(x,d,l,r);
}
int find(int x,int now){
if(x==now) return sze[ch[x][0]]+1;
int d=cmp(x,now);
if(d) return sze[ch[x][0]]+1+find(ch[x][1],now);
else return find(ch[x][0],now);
}
int kth(int x,int k){
if(!x) return 0;
if(sze[ch[x][0]]==k-1) return x;
if(sze[ch[x][0]]>=k) return kth(ch[x][0],k);
return kth(ch[x][1],k-sze[ch[x][0]]-1);
}
bool eq(int l1,int l2,int len){
ull a=hsh[l1+len-1]-hsh[l1-1]*pw[len],b=hsh[l2+len-1]-hsh[l2-1]*pw[len];
return a==b;
}
int getlcp(int a,int b){
int l=1,r=min(a,b),ans=0;
while(l<=r){
int mid=l+r>>1;
if(eq(a-mid+1,b-mid+1,mid)) ans=mid,l=mid+1;
else r=mid-1;
} return ans;
}
void ins(int x){
s[++tot]=s[x];hsh[tot]=hsh[tot-1]*base+s[x]-'a';
insert(root,1,inf);
int a=find(root,tot),b=kth(root,a-1),c=kth(root,a+1);
ans-=lcp[c];
lcp[tot]=getlcp(b,tot),lcp[c]=getlcp(tot,c);
ans+=lcp[tot]+lcp[c];
}
int merge(int x,int y){
if(!x or !y) return x+y;
if(prio[x]<prio[y]) {
ch[x][1]=merge(ch[x][1],y);
return x;
}
else {
ch[y][0]=merge(x,ch[y][0]);
return y;
}
}
void remove(int &x,int l,int r){
if(x==tot){
x=merge(ch[x][0],ch[x][1]);
dfs(x,l,r);return;
} else{
sze[x]--;
int mid=l+r>>1,d=cmp(x,tot);
if(d) remove(ch[x][d],mid+1,r);
else remove(ch[x][d],l,mid-1);
}
}
void del(){
int rk=find(root,tot);
int b=kth(root,rk-1),c=kth(root,rk+1);
ans-=lcp[tot]+lcp[c];
lcp[c]=getlcp(b,c);
ans+=lcp[c];
remove(root,1,inf);
tot--;
}
signed main(){
srand(20020619);
scanf("%s",s+1);n=strlen(s+1);
pw[0]=1;for(int i=1;i<=n;i++) pw[i]=pw[i-1]*base;
for(int i=1;i<=n;i++){
if(s[i]=='-') del();
else ins(i);
printf("%lld\n",tot*(tot+1)/2-ans);
} return 0;
}