最長公共子序列(LCS)問題
你有兩個字符串 \(A,B\),字符集為 \(\Sigma\),求 \(A, B\) 的最長公共子序列。
簡單動態規划
首先有一個廣為人知的 dp:\(f_{i,j}\) 為 \(A\) 的長度為 \(j\) 的前綴與 \(B\) 長度為 \(i\) 的前綴的 LCS。(注意 \(i\) 和 \(j\) 分別對於那個串)
那么顯然有:
然而這是 \(O(n^2)\) 的,在略大的數據下就很容易 TLE。
還有一個 \(O(n\log n)\) 的算法,但只是針對排列的情況。
然后我們介紹一個基於 位運算 的優化方法。
這怎么就能位運算了呢?看着就不怎么 01。
但是有一個極其重要的性質:
即 \(f\) 的同一行內是 單調不減 並且 相鄰兩個相差不超過一。
矩陣 \(M\)
我們定義矩陣 \(M\) 為 \(f\) 數組每行分別 差分 的結果,即:
根據上述 \(f\) 的性質,不難發現 \(M\) 是個 01矩陣。那么可以直接 壓位(類似 std::bitset
)。
然后考慮直接轉移 \(M_i\) 整行,最后 \(\sum_{j}M_{|B|,j}\) 就是答案。這就是優化的基本思想。
字符比較串 \(p\)
我們定義 \(p(c)\) 為字符 \(c\) 在字符串 \(A\) 中出現的所有位置的集合,\(p(c)_i=1\) 表示 \(A_i=c\)。這是我們轉移的工具。
要預處理 \(p\) 我們需要 \(O(|\Sigma|\times |A|)\) 的空間。然而我們發現 \(p\) 中只有 \(0/1\),所以我們可以用類似於 \(M\) 進行壓位優化,那就只要 \(O\left(\frac{|\Sigma|\times |A|}{w}\right)\),一般來說還是一個可承受的量級。
\(M\) 的實際意義
上面只提到 \(M\) 是個差分數組,現在來考慮它的實際意義是什么,以便推出它的轉移方式。
考慮一個 \(M_{i,j}\) 什么時候會是 \(1\)。觀察原轉移方程,發現 \(f_{i,j-1}\) 方向必然不會使 \(f_{i,j}\) 加一,唯一兩個方向就是 \(f_{i-1,j-1}\) 或 \(f_{i-1,j}\)。
如果是從 \(f_{i-1,j-1}+1\) 而來,那么說明這個位置 \(A_j\) 發生了配對,從而答案 \(+1\);
如果是 \(f_{i-1,j}\),仔細思考一下還是一樣的,在下面總有一個位置會和上面一條相同。
總而言之就是 \(A_j\) 被計入答案 了,但注意這不意味着 \(M_i\) 中所有的 \(1\) 都對應一個被選中的 \(A_j\)。
正確的理解是 \(M_{i,j}\) 如果為 \(1\),設 \(k\) 為當前位到第一位之間 \(1\) 的個數,那就說明當前一個 LCS 長度為 \(k\) 的方案,最后的一位為 \(j\)。事實我們也是只需要考慮當前 LCS 的最后一位,添加時答案只要保證在當前方案的最后一位之后即可。
轉移方式
對於一整行 \(M_{i-1}\),我們對其分段,每段有前面一個極長 \(0\) 段,由一個單獨的 \(1\) 結尾,最后一整段 \(0\) 單獨成段。
然后用當前 \(B\) 的字符 \(p(B_i)\) 與之比對(注意這里是倒着的):
M[i - 1]: [1 0 0 0 0 0 0][1 0 0 0][1][1][1][1][1]
p[B[i]] : [0 1 0 1 1 0 0 0 1 0 0 0 1 1 0 0]
^ ^
| |
j = |A| j = 1
然后將兩者做 按位或 操作,再對於每個段按位或的結果取 段中的最后一個 \(1\),得:
M[i] : [0 0 0 0 1 0 0 0 1 0 0 1 1 1 1 1]
這個過程相當於 \(M_{i-1}\) 借助 \(p(B_i)\) 將這些 \(1\) 盡量向字符串的開頭移,以便為之后的匹配留足更大的機會。
至於其中的意義可以結合上面理解,大概就是對於每個長度的方案,都在不超過下一個長度的前提下前移。具體細節我也說不清楚
轉移實現
上面的轉移過於復雜,很難用我們熟知的位運算進行優化,於是嘗試將它翻譯成位運算。
我也不知道原論文作者怎么想到的,這里就說只一下做法吧。
我們記 \(X = M_{i-1}\ \texttt{OR}\ p(B_i)\),然后我們需要取其中最后一位:
X : [1 1 0 1 1 0 0 1 1 0 0 1 1 1 1 1]
然后將 \(M_{i-1}\) 右移一位,頭部補上 \(1\),並用 \(X\) 數值減 這個 01 串,得:
[1 1 0 1 1 0 0][1 1 0 0][1][1][1][1][1]
- [0 0 0 0 0 0 1 0 0 0 1 1 1 1 1 1]
--------------------------------------------------
[1 1 0 1 0 1 1][1 0 1 1][0][0][0][0][0]
這么做旨在將每段的末尾 \(0\) 段,然后將原來最右邊的 \(1\) 變成 \(0\)。
然后和 \(X\) 進行 異或 操作:
[0 0 0 0 1 1 1][0 1 1 1][1][1][1][1][1]
這樣就使最開始的最右邊的 \(1\) 到段尾變成 \(1\),其余變成 \(0\)。
最后只要保留第一個 \(1\),那么就剛好是 按位與 \(X\) 的結果。
於是得到:
那么在實現時,只要手寫一個 bitset
,支持按位與、或、異或、數值相減、位移即可。
復雜度
每次轉移需要 \(O\left(\frac {|A|} w\right)\),總時間復雜度為 \(O\left(\frac{|A|\times |B|}{w}\right)\)
空間瓶頸為 \(p\) 集合,為 \(O\left(\frac{|A|\times |\Sigma|}{w}\right)\),如果字符集 \(\Sigma\) 不確定可以離散化,空間為 \(O\left( \frac{|A|^2}{w} \right)\)。
參考代碼
下面的代碼實現 並不是倒着的(為了減法方便),於是位移什么的看着就有點詭異。
/*
* Author : _Wallace_
* Source : https://www.cnblogs.com/-Wallace-/
* Problem : LOJ #6564. 最長公共子序列
* Standard : GNU C++ 03
* Optimal : -Ofast
*/
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <cstring>
typedef unsigned long long ULL;
const int N = 7e4 + 5;
int n, m, u;
struct bitset {
ULL t[N / 64 + 5];
bitset() {
memset(t, 0, sizeof(t));
}
bitset(const bitset &rhs) {
memcpy(t, rhs.t, sizeof(t));
}
bitset& set(int p) {
t[p >> 6] |= 1llu << (p & 63);
return *this;
}
bitset& shift() {
ULL last = 0llu;
for (int i = 0; i < u; i++) {
ULL cur = t[i] >> 63;
(t[i] <<= 1) |= last, last = cur;
}
return *this;
}
int count() {
int ret = 0;
for (int i = 0; i < u; i++)
ret += __builtin_popcountll(t[i]);
return ret;
}
bitset& operator = (const bitset &rhs) {
memcpy(t, rhs.t, sizeof(t));
return *this;
}
bitset& operator &= (const bitset &rhs) {
for (int i = 0; i < u; i++) t[i] &= rhs.t[i];
return *this;
}
bitset& operator |= (const bitset &rhs) {
for (int i = 0; i < u; i++) t[i] |= rhs.t[i];
return *this;
}
bitset& operator ^= (const bitset &rhs) {
for (int i = 0; i < u; i++) t[i] ^= rhs.t[i];
return *this;
}
friend bitset operator - (const bitset &lhs, const bitset &rhs) {
ULL last = 0llu; bitset ret;
for (int i = 0; i < u; i++){
ULL cur = (lhs.t[i] < rhs.t[i] + last);
ret.t[i] = lhs.t[i] - rhs.t[i] - last;
last = cur;
}
return ret;
}
} p[N], f, g;
signed main() {
scanf("%d%d", &n, &m), u = n / 64 + 1;
for (int i = 1, c; i <= n; i++)
scanf("%d", &c), p[c].set(i);
for (int i = 1, c; i <= m; i++) {
scanf("%d", &c), (g = f) |= p[c];
f.shift(), f.set(0);
((f = g - f) ^= g) &= g;
}
printf("%d\n", f.count());
return 0;
}