「NOI2016」優秀的拆分
題目描述
如果一個字符串可以被拆分為 \(\text{AABB}\) 的形式,其中 \(\text{A}\) 和 \(\text{B}\) 是任意非空字符串,則我們稱該字符串的這種拆分是優秀的。
例如,對於字符串 \(\text {aabaabaa}\) ,如果令 \(\text{A}=\texttt{aab}\),\(\text{B}=\texttt{a}\),我們就找到了這個字符串拆分成 \(\text{AABB}\) 的一種方式。
一個字符串可能沒有優秀的拆分,也可能存在不止一種優秀的拆分。
比如我們令 \(\text{A}=\texttt{a}\),\(\text{B}=\texttt{baa}\),也可以用 \(\text{AABB}\) 表示出上述字符串;但是,字符串 \(\texttt{abaabaa}\) 就沒有優秀的拆分。
現在給出一個長度為 \(n\) 的字符串 \(S\),我們需要求出,在它所有子串的所有拆分方式中,優秀拆分的總個數。這里的子串是指字符串中連續的一段。
以下事項需要注意:
- 出現在不同位置的相同子串,我們認為是不同的子串,它們的優秀拆分均會被記入答案。
- 在一個拆分中,允許出現 \(\text{A}=\text{B}\)。例如 \(\texttt{cccc}\) 存在拆分 \(\text{A}=\text{B}=\texttt{c}\)。
- 字符串本身也是它的一個子串。
輸入格式
每個輸入文件包含多組數據。
輸入文件的第一行只有一個整數 \(T\),表示數據的組數。
接下來 \(T\) 行,每行包含一個僅由英文小寫字母構成的字符串 \(S\),意義如題所述。
輸出格式
輸出 \(T\) 行,每行包含一個整數,表示字符串 \(S\) 所有子串的所有拆分中,總共有多少個是優秀的拆分。
樣例
樣例輸入
4
aabbbb
cccccc
aabaabaabaa
bbaabaababaaba
樣例輸出
3
5
4
7
樣例解釋
我們用 \(S[i, j]\) 表示字符串 \(S\) 第 \(i\) 個字符到第 \(j\) 個字符的子串(從 \(1\) 開始計數)。
第一組數據中,共有三個子串存在優秀的拆分:
\(S[1,4]=\text{aabb}\),優秀的拆分為 \(\text{A}=\texttt{a}\),\(\text{B}=\texttt{b}\);
\(S[3,6]=\text{bbbb}\),優秀的拆分為 \(\text{A}=\texttt{b}\),\(\text{B}=\texttt{b}\);
\(S[1,6]=\text{aabbbb}\),優秀的拆分為 \(\text{A}=\texttt{a}\),\(\text{B}=\texttt{bb}\)。
而剩下的子串不存在優秀的拆分,所以第一組數據的答案是 \(3\)。
第二組數據中,有兩類,總共四個子串存在優秀的拆分:
對於子串 \(S[1,4]=S[2,5]=S[3,6]=\text{cccc}\),它們優秀的拆分相同,均為 \(\text{A}=\texttt{c}\),\(\text{B}=\texttt{c}\),但由於這些子串位置不同,因此要計算三次;
對於子串 \(S[1,6]=\text{cccccc}\),它優秀的拆分有兩種:\(\text{A}=\texttt{c}\),\(\text{B}=\texttt{cc}\) 和 \(\text{A}=\texttt{cc}\),\(\text{B}=\texttt{c}\),它們是相同子串的不同拆分,也都要計入答案。
所以第二組數據的答案是 \(3+2=5\)。
第三組數據中,\(S[1,8]\) 和 \(S[4,11]\) 各有兩種優秀的拆分,其中 \(S[1,8]\) 是問題描述中的例子,所以答案是 \(2+2=4\)。
第四組數據中,\(S[1,4]\),\(S[6,11]\),\(S[7,12]\),\(S[2,11]\),\(S[1,8]\) 各有一種優秀的拆分,\(S[3,14]\) 有兩種優秀的拆分,所以答案是 \(5+2=7\)。
數據范圍與提示
對於全部的測試點,\(1 \leq T \leq 10, \ n \leq 30000\)。
題解
\(95\)分hash暴力真的就是隨便寫...
我們處理出\(a[i]\)和\(b[i]\)表示以\(i\)為終點和起點的\(AA\)串的個數。那么答案即為\(\sum_{i=1}^{n-1}a[i]\times b[i + 1]\)。hash優化一下判定過程就是\(O(n^2)\)的。
\(100\)分不看題解真的沒有什么思路(即使知道了這是一道后綴數組題...)
我們可以思考一下如何優化處理\(AA\)串的過程。
枚舉\(A\)串的長度\(len\),然后對於相鄰的兩個長度間隔為\(len\)的點,如果他們的\(lcp(x,y)+lcs(x,y)\geq len\),那么中間則有一段長度為\(lcp+lcs-len+1\)的合法的\(AA\)串終點的區間。
為什么呢?可以通過把這句話畫出來,比如這樣:
那么中間那段紅色的區域就是合法的終點區間。
\(lcp(x,y)\)和\(lcs(x,y)\)可以直接用后綴數組來求。總復雜度為\(O(n \log n)\)。
當然也可以用hash實現這個過程,復雜度就是\(O(n \log^2 n)\)的。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 50010;
int n, a[N], b[N];
char s[N];
struct SA {
int sa[N], height[N], tong[N], rnk[N], tp[N], f[N][16], LG[N];
int m;
void radix_sort() {
for(int i = 1; i <= m; ++i) tong[i] = 0;
for(int i = 1; i <= n; ++i) tong[rnk[i]]++;
for(int i = 1; i <= m; ++i) tong[i] += tong[i - 1];
for(int i = n; i; --i) sa[tong[rnk[tp[i]]]--] = tp[i];
}
int query(int l, int r) {
l = rnk[l], r = rnk[r];
if(l > r) swap(l, r); ++l;
int k = LG[r - l + 1];
return min(f[l][k], f[r - (1 << k) + 1][k]);
}
void init() {
memset(sa, 0, sizeof(sa));
memset(height, 0, sizeof(height));
memset(tong, 0, sizeof(tong));
memset(rnk, 0, sizeof(rnk));
memset(tp, 0, sizeof(tp));
memset(f, 0, sizeof(f));
memset(LG, 0, sizeof(LG));
}
void build(char *A) {
init();
for(int i = 1; i <= n; ++i) rnk[i] = A[i], tp[i] = i;
m = 200; radix_sort();
for(int w = 1, p = 0; w <= n && p < n; m = p, w <<= 1) {
p = 0;
for(int i = 1; i <= w; ++i) tp[++p] = n - w + i;
for(int i = 1; i <= n; ++i) if(sa[i] > w) tp[++p] = sa[i] - w;
radix_sort(); swap(tp, rnk); rnk[sa[1]] = p = 1;
for(int i = 2; i <= n; ++i)
rnk[sa[i]] = (tp[sa[i]] == tp[sa[i - 1]] && tp[sa[i] + w] == tp[sa[i - 1] + w]) ? p : ++p;
}
for(int i = 1, k = 0; i <= n; ++i) {
if(k) --k; int j = sa[rnk[i] - 1];
while(A[i + k] == A[j + k] && i + k <= n && j + k <= n) ++k;
height[rnk[i]] = k;
}
for(int i = 2; i <= n; ++i) LG[i] = LG[i >> 1] + 1;
for(int i = 1; i <= n; ++i) f[i][0] = height[i];
for(int j = 1; j <= 15; ++j)
for(int i = 1; i + (1 << j) - 1 <= n; ++i) {
f[i][j] = min(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
}A, B;
int main() {
int T = 0; scanf("%d", &T); while(T--) {
memset(a, 0, sizeof(a));
memset(b, 0, sizeof(b));
scanf("%s", s + 1); n = strlen(s + 1);
A.build(s); reverse(s + 1, s + n + 1); B.build(s);
for(int len = 1; len <= (n >> 1); ++len) {
for(int i = len, j = i + len; j <= n; i += len, j += len) {
int LCS = min(len - 1, B.query(n - i + 2, n - j + 2)), LCP = min(len, A.query(i, j));
if(LCS + LCP >= len) {
int t = LCP + LCS - len + 1;
a[i - LCS]++; a[i - LCS + t]--;
b[j + LCP - t]++; b[j + LCP]--;
}
}
}
for(int i = 1; i <= n; ++i) a[i] += a[i - 1], b[i] += b[i - 1];
ll ans = 0;
for(int i = 1; i < n; ++i) ans += 1LL * b[i] * a[i + 1];
printf("%lld\n", ans);
}
return 0;
}