題目描述
小A 被選為了\(ION2018\) 的出題人,他精心准備了一道質量十分高的題目,且已經 把除了題目命名以外的工作都做好了。
由於\(ION\) 已經舉辦了很多屆,所以在題目命名上也是有規定的,\(ION\) 命題手冊規 定:每年由命題委員會規定一個小寫字母字符串,我們稱之為那一年的命名串,要求每道題的名字必須是那一年的命名串的一個非空連續子串,且不能和前一年的任何一道題目的名字相同。
由於一些特殊的原因,小A 不知道\(ION2017\) 每道題的名字,但是他通過一些特殊 手段得到了\(ION2017\) 的命名串,現在小A 有\(Q\) 次詢問:每次給定\(ION2017\) 的命名串和\(ION2018\) 的命名串,求有幾種題目的命名,使得這個名字一定滿足命題委員會的規定,即是\(ION2018\) 的命名串的一個非空連續子串且一定不會和\(ION2017\) 的任何一道題目的名字相同。
由於一些特殊原因,所有詢問給出的\(ION2017\) 的命名串都是某個串的連續子串, 詳細可見輸入格式。
輸入格式:
第一行一個字符串\(S\) ,之后詢問給出的\(ION2017\) 的命名串都是\(S\) 的連續子串。 第二行一個正整數\(Q\),表示詢問次數。 接下來\(Q\) 行,每行有一個字符串\(T\) 和兩個正整數\(l,r\),表示詢問如果\(ION2017\) 的 命名串是\(S[l..r]\),\(ION2018\) 的命名串是\(T\) 的話,有幾種命名方式一定滿足規定。
輸出格式:
輸出\(Q\)行,第\(i\) 行一個非負整數表示第\(i\) 個詢問的答案。
先放一個亂搞一個做法(后面補了正解) :
首先考慮 \(l = 1, r = |S|\) 的部分分,我在同步賽上的亂搞做法是對 \(S\) 建 \(Sam\) ,每次詢問往 \(Sam\) 里面插入詢問串
取出新增的后綴節點,暴力在 \(parent\) 樹上跳父親,計算在 \(S\) 中出現的不同子串個數,以及總共的不同子串個數,相減就是答案
因為相同的子串只會被算一次,所以每一個節點的貢獻只會被算一次,單次復雜度是向上跳的期望節點數,根據 \(Sam\) 的一些奇奇怪怪的性質,復雜度上限是\(O(n\sqrt{n})\) ,但是在實際情況下根本卡不滿,這 \(68pt\) 中最慢的點是 \(0.8s\)
討論有了 \(l, r\) 的限制的情況,在原先的算法基礎上還需要求出每個節點在限制下能表示的最長的公共子串長度為 \(maxlen\)
設 \(r\) 在該節點 \(right\) 集合中的前驅是 \(r'\) ,那么 \(mxlen = r' - l + 1\) ,通過這個重新計算公共子串個數即可
考慮求出這個東西只需要在 \(Sam\) 上大力線段樹合並即可,但是無論多么不滿乘上 \(log\) 都會 \(Tle\)
考慮進行剪枝,對每一個節點維護 \(mx_u\) 和 \(mn_u\) 表示其 \(right\) 集合中在 \(S\) 串中出現的最靠前和最靠后的位置
如果有 \(r < mn_u\) 或者 \(l > mx_u\) 這個節點就不會有貢獻,可以剪掉
觀察發現,大部分情況下 \(mxlen > dep_u\) ,此時求前驅沒有任何用處,本質上是因為 \(mx_u > r\) 的緣故
所以在此可以再加上一個剪枝,這樣線段樹的查詢只會在很深的幾個節點被調用了.
測一下最大的數據驚奇的發現只需要\(2.5s\) ,交上去發現用了一個 \(O(n\sqrt{n}logn)\)的亂搞水過了此題,(震驚!)
/*pragram by mangoyang*/
#include<bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int f = 0, ch = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 3500005, MAXN = 1500005;
char s[N]; int rt[N], buf[N], n;
struct SegmentTree{
int lc[MAXN*25], rc[MAXN*25], sz[MAXN*25], size;
inline SegmentTree(){ size = 1; }
inline void ins(int &u, int l, int r, int pos){
if(!u) u = ++size;
if(l == r) return (void) (sz[u]++);
int mid = l + r >> 1;
if(pos <= mid) ins(lc[u], l, mid, pos);
else ins(rc[u], mid + 1, r, pos);
sz[u] = sz[lc[u]] + sz[rc[u]];
}
inline int merge(int x, int y, int l, int r){
if(!x || !y) return x + y;
int mid = l + r >> 1, o = ++size;
if(l == r) sz[o] = sz[x] + sz[y];
else{
lc[o] = merge(lc[x], lc[y], l, mid);
rc[o] = merge(rc[x], rc[y], mid + 1, r);
sz[o] = sz[lc[o]] + sz[rc[o]];
}
return o;
}
inline int query(int u, int l, int r, int pos){
if(!sz[u]) return 0;
if(l == r) return l;
int mid = l + r >> 1;
if(pos <= mid) return query(lc[u], l, mid, pos);
int rans = query(rc[u], mid + 1, r, pos);
return rans ? rans : query(lc[u], l, mid, pos);
}
}Seg;
struct SuffixAutomaton{
vector<int> g[N], v; ll dep[N];
int ch[N][26], fa[N], vis[N], mx[N], mn[N], tail, size;
inline SuffixAutomaton(){ size = tail = 1, rt[1] = 1; }
inline int newnode(int x){ return dep[++size] = x, size; }
inline void ins(int c, int ff, int pos){
int p = tail, np = newnode(dep[p] + 1);
if(ff) v.push_back(np); else{
Seg.ins(rt[np], 1, n, pos);
mx[np] = mn[np] = pos;
}
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(dep[q] == dep[p] + 1) fa[np] = q;
else{
int nq = newnode(dep[p] + 1);
fa[nq] = fa[q], fa[np] = fa[q] = nq;
if(ff) rt[nq] = rt[q], mx[nq] = mx[q], mn[nq] = mn[q];
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline void addedge(){
for(int i = 2; i <= size; i++) g[fa[i]].push_back(i);
}
inline void dfs(int u){
for(int i = 0; i < g[u].size(); i++){
int v = g[u][i];
dfs(v), rt[u] = Seg.merge(rt[u], rt[v], 1, n);
mx[u] = Max(mx[u], mx[v]), mn[u] = Min(mn[u], mn[v]);
}
}
inline void prepare(char *s){
for(int i = 0; i < n; i++) ins(s[i] - 'a', 0, i + 1);
addedge(), dfs(1);
}
inline ll calc(char *s, int l, int r){
tail = 1; v.clear();
ll len = strlen(s), ans = 0, all = 0;
for(int i = 0; i < len; i++) ins(s[i] - 'a', 1, 0);
for(int i = 0; i < v.size(); i++){
int u = v[i];
for(int p = u; p > 1; p = fa[p]) {
if(vis[p]) break; int OK = 0;
all += dep[p] - dep[fa[p]], vis[p] = 1;
if(rt[p]){
if((l == 1 && r == n) || OK)
{ ans += dep[p] - dep[fa[p]]; continue; }
if(mx[p] < l || mn[p] > r) continue;
int mxlen = mx[p] <= r ? mx[p] - l + 1 : Seg.query(rt[p], 1, n, r) - l + 1;
if(mxlen > dep[fa[p]])
ans += Min(dep[p], mxlen) - dep[fa[p]];
if(mxlen >= dep[p]) OK = 1;
}
}
}
for(int i = 0; i < v.size(); i++){
int u = v[i];
for(int p = u; p > 1; p = fa[p]){
if(!vis[p]) break; vis[p] = 0;
}
}
return all - ans;
}
}van;
int main(){
scanf("%s", s); n = strlen(s);
int Q; read(Q), van.prepare(s);
while(Q--){
int l, r;
scanf("%s", s), read(l), read(r);
printf("%lld\n", van.calc(s, l, r));
}
}
正解
補集轉換一步,問題變成求 \(T\) 與 \(S[l_i:r_i]\) 的本質不同的公共子串數,考慮讓 \(T\) 在 \(S\) 的 \(sam\) 上匹配,雙指針找出每一個前綴能在 \(S[l_i:r_i]\) 中能匹配上的后綴長度 \(len[i]\),然后在 \(T\) 的 \(sam\) 上統計答案,對於每一個節點隨便找一個出現的前綴,拿這個前綴的 \([0,len[i]]\) 和其所能表示的字符串長度區間取交集即可。
求 \(len[i]\) 可以先找到第一個能接收當前字符 \(c\) 的節點,然后不斷刪去首字母,直到能在 \([l_i:r_i]\) 放下,也就是找到當前匹配節點的 \(right\) 集合,判斷一段區間內是否有元素,這個用隨便維護一下就好了,線段樹合並蠻好寫的。
/*program by mangoyang*/
#include <bits/stdc++.h>
#define inf (0x7f7f7f7f)
#define Max(a, b) ((a) > (b) ? (a) : (b))
#define Min(a, b) ((a) < (b) ? (a) : (b))
typedef long long ll;
using namespace std;
template <class T>
inline void read(T &x){
int ch = 0, f = 0; x = 0;
for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1;
for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
if(f) x = -x;
}
const int N = 2000005;
char s[N];
int res[N], n, q;
namespace Seg{
#define mid ((l + r) >> 1)
int lc[N*25], rc[N*25], sz[N*25], size;
inline void ins(int &u, int l, int r, int pos){
if(!u) u = ++size;
if(l == r) return (void) (sz[u]++);
if(pos <= mid) ins(lc[u], l, mid, pos);
else ins(rc[u], mid + 1, r, pos);
sz[u] = sz[lc[u]] + sz[rc[u]];
}
inline int merge(int x, int y, int l, int r){
if(!x || !y) return x + y;
int o = ++size;
if(l == r) sz[o] = sz[x] + sz[y];
else{
lc[o] = merge(lc[x], lc[y], l, mid);
rc[o] = merge(rc[x], rc[y], mid + 1, r);
sz[o] = sz[lc[o]] + sz[rc[o]];
}
return o;
}
inline int query(int u, int l, int r, int L, int R){
if(l >= L && r <= R) return sz[u];
int res = 0;
if(L <= mid) res += query(lc[u], l, mid, L, R);
if(mid < R) res += query(rc[u], mid + 1, r, L, R);
return res;
}
#undef mid
}
vector<int> vec[N];
namespace SAM1{
vector<int> g[N];
int ch[N][26], rt[N], fa[N], len[N], size = 1, tail = 1;
inline int newnode(int x){ return len[++size] = x, size; }
inline void ins(int c, int x){
int p = tail, np = newnode(len[p] + 1);
Seg::ins(rt[np], 1, n, x);
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(len[q] == len[p] + 1) fa[np] = q;
else{
int nq = newnode(len[p] + 1);
fa[nq] = fa[q], fa[q] = fa[np] = nq;
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline void addedge(){
for(int i = 2; i <= size; i++) g[fa[i]].push_back(i);
}
inline void dfs(int u){
for(int i = 0; i < (int) g[u].size(); i++)
dfs(g[u][i]), rt[u] = Seg::merge(rt[u], rt[g[u][i]], 1, n);
}
inline void solve(char *s, int L, int R){
int lenth = strlen(s + 1);
for(int i = 1, p = 1, now = 0; i <= lenth; i++){
int c = s[i] - 'a';
while(!ch[p][c] && p) p = fa[p], now = len[p];
if(!p){ p = 1, now = 0; continue; };
p = ch[p][c], now++;
while(p > 1){
if(Seg::query(rt[p], 1, n, L + now - 1, R)) break;
if(--now == len[fa[p]]) p = fa[p];
}
if(p == 1) continue;
for(int j = 0; j < (int) vec[i].size(); j++)
res[vec[i][j]] = max(res[vec[i][j]], now);
}
}
}
namespace SAM2{
int fa[N], len[N], ch[N][26], size, tail;
inline void Clear(){
for(int i = 1; i <= size; i++){
fa[i] = len[i] = res[i] = 0;
memset(ch[i], 0, sizeof(ch[i]));
}
size = tail = 1;
}
inline int newnode(int x){ return len[++size] = x, size; }
inline void ins(int c, int x){
int p = tail, np = newnode(len[p] + 1);
vec[x].push_back(np);
for(; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
if(!p) return (void) (fa[np] = 1, tail = np);
int q = ch[p][c];
if(len[q] == len[p] + 1) fa[np] = q;
else{
int nq = newnode(len[p] + 1);
vec[x].push_back(nq);
fa[nq] = fa[q], fa[q] = fa[np] = nq;
for(int i = 0; i < 26; i++) ch[nq][i] = ch[q][i];
for(; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}tail = np;
}
inline ll solve(){
ll ans1 = 0, ans2 = 0;
for(int i = 1; i <= size; i++){
if(res[i] > len[fa[i]])
ans2 += 1ll * min(res[i], len[i]) - len[fa[i]];
ans1 += 1ll * len[i] - len[fa[i]];
}
return ans1 - ans2;
}
}
int main(){
scanf("%s", s + 1), n = strlen(s + 1);
for(int i = 1; i <= n; i++) SAM1::ins(s[i] - 'a', i);
SAM1::addedge(), SAM1::dfs(1);
read(q); int L, R;
while(q--){
scanf("%s", s + 1); int m = strlen(s + 1);
read(L), read(R);
for(int i = 1; i <= m; i++) vec[i].clear();
SAM2::Clear();
for(int i = 1; i <= m; i++) SAM2::ins(s[i] - 'a', i);
SAM1::solve(s, L, R);
printf("%lld\n", SAM2::solve());
}
return 0;
}