寫在前面
最大子段和和 GSS1 的題解區還沒有下面這種做法,你們快上啊(
廣義矩陣乘法
對於一 \(p\times m\) 的矩陣 \(A\),與 \(m\times q\) 的矩陣 \(B\),定義廣義矩陣乘法 \(A\times B = C\) 的結果是一個 \(p\times q\) 的矩陣 \(C\),滿足:
其中 \(\oplus\) 與 \(\otimes\) 是兩種二元運算。
考察這種廣義矩陣乘法是否滿足結合律:
顯然,若 \(\otimes\) 運算滿足交換律、結合律,且 \(\otimes\) 對 \(\oplus\) 存在分配律,即存在 \(\left(\bigoplus a\right)\otimes b = \bigoplus \left( a\otimes b \right)\) 時,廣義矩陣乘法滿足結合律。根據上述運算規律,對二式進行 \(\oplus\) 的交換后有:
維護 DP
以 P1115 最大子段和 為例。
給定一個長度為 \(n\) 的數列 \(a\),選出其中連續且非空的一段使得這段和最大。
\(1\le n\le 2\times 10^5\),\(-10^4\le a_i\le 10^4\)。
1S,128MB。
記 \(f_i\) 表示以 \(a_i\) 結尾的最大子段和,初始化 \(f_0 = -\infin\)。轉移時考察是否要加上前面一段的貢獻。前面一段的最大貢獻為 \(f_{i-1}\)。則顯然有:
定義 \(g\) 為 \(f\) 的前綴最大值,答案即為 \(g_n\)。算法總時間復雜度 \(O(n)\) 級別。
考慮加法運算運算與取 \(\max\) 運算的性質:發現取 \(\max\) 滿足交換律與結合律,且加法對取 \(\max\) 滿足分配率,即有:
考慮定義一種廣義矩陣乘法 \(A\times B = C\),滿足:
考慮將上述狀態轉移方程寫成廣義矩陣乘法形式。當從 \(i-1\) 轉移到 \(i\) 時,顯然有:
根據上述分析,顯然該運算滿足結合律,則有:
其中 \(\prod\) 表示連續廣義矩陣乘法。預處理整個序列的廣義矩陣乘積后,根據上式即得 答案 \(g_{n}\)。總復雜度 \(O\left(3^3\times n\right)\) 級別。
靜態區間查詢
SP1043 GSS1 - Can you answer these queries I
給定一個長度為 \(n\) 的數列 \(a\),給定 \(m\) 次詢問。
每次詢問給定區間 \([l,r]\),要求選出區間 \([l,r]\) 中連續且非空的一段使得這段和最大,輸出最大子段和。
\(1\le n\le 5\times 10^4\),\(-15007\le a_i\le 15007\)。
原題面中並沒有給出 \(m\) 的范圍,此處根據實際測試情況推斷 \(m\) 與 \(n\) 同階。
230ms,1.46G。
發現上述題目中廣義矩陣乘法做法 復雜度比直接做還劣 有着很好的擴展性。對於任意區間,預處理區間對應的廣義矩陣乘積后即得該區間的最大子段和。
問題變為如何快速求得區間廣義矩陣乘積。廣義矩陣乘法滿足結合律,且本題中沒有修改操作,考慮對於每個位置 \(i\) 預處理以 \(i\) 為左端點的長度為 \(2\) 的冪的區間的廣義矩陣乘積。回答詢問時倍增拼湊區間即可。總時間復雜度 \(O\left(3^3 (n+m)\log n\right)\) 級別。
不用維護一堆亂七八糟的玩意,個人認為比隔壁直接上線段樹好寫(
//知識點:矩陣乘法,倍增
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const int kL = 3;
const LL kInf = 1e9 + 2077;
//=============================================================
int n, m;
struct Matrix {
LL a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
} f[kN][21];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
LL Query(int l, int r) {
Matrix ans;
ans.a[0][0] = ans.a[0][1] = -kInf;
for (int i = 20; i >= 0; -- i) {
if (l + (1 << i) - 1 <= r) {
ans = ans * f[l][i];
l += (1 << i);
}
}
return ans.a[0][1];
}
//=============================================================
int main() {
n = read();
for (int i = 1; i <= n; ++ i) {
f[i][0].a[0][0] = f[i][0].a[2][0] = f[i][0].a[0][1] = f[i][0].a[2][1]
= read();
f[i][0].a[1][0] = f[i][0].a[0][2] = f[i][0].a[1][2] = -kInf;
}
for (int i = 1; i <= 20; ++ i) {
for (int j = 1; j + (1 << i) - 1 <= n; ++ j) {
f[j][i] = f[j][i - 1] * f[j + (1 << (i - 1))][i - 1];
}
}
m = read();
for (int i = 1; i <= m; ++ i) {
int l = read(), r = read();
printf("%lld\n", Query(l, r));
}
return 0;
}
動態區間查詢
SP1716 GSS3 - Can you answer these queries III
給定一個長度為 \(n\) 的數列 \(a\),給定 \(m\) 次操作:
- 單點修改。
- 給定區間 \([l,r]\),要求選出區間 \([l,r]\) 中連續且非空的一段使得這段和最大,輸出最大子段和。
\(1\le n,m\le 5\times 10^4\),\(-10^4\le a_i\le 10^4\)。
330ms,1.46G。
在上題的基礎上加入了單點修改操作。發現每次修改僅會影響對應位置的矩陣,以及包含該位置的區間的廣義矩陣乘積,考慮線段樹維護廣義矩陣乘積,每次修改僅需更新自葉到根的 \(\log n\) 個位置的對應區間。總時間復雜度 \(O\left(3^3 (n+m)\log n\right)\)。
//知識點:矩陣乘法,線段樹
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 5e4 + 10;
const int kL = 3;
const LL kInf = 1e9 + 2077;
//=============================================================
int n, m, a[kN];
struct Matrix {
LL a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
};
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir_, int sec_) {
if (sec_ > fir_) fir_ = sec_;
}
void Chkmin(int &fir_, int sec_) {
if (sec_ < fir_) fir_ = sec_;
}
namespace Seg {
#define ls (now_<<1)
#define rs (now_<<1|1)
#define mid ((L_+R_)>>1)
Matrix sum[kN << 2];
void Pushup(int now_) {
sum[now_] = sum[ls] * sum[rs];
}
void Build(int now_, int L_, int R_) {
if (L_ == R_) {
sum[now_].a[0][0] = sum[now_].a[2][0] = sum[now_].a[0][1]
= sum[now_].a[2][1] = a[L_];
sum[now_].a[1][0] = sum[now_].a[0][2] = sum[now_].a[1][2] = -kInf;
return ;
}
Build(ls, L_, mid), Build(rs, mid + 1, R_);
Pushup(now_);
}
void Modify(int now_, int L_, int R_, int pos_, LL val_) {
if (L_ == R_) {
sum[now_].a[0][0] = sum[now_].a[2][0] = sum[now_].a[0][1]
= sum[now_].a[2][1] = val_;
return ;
}
if (pos_ <= mid) Modify(ls, L_, mid, pos_, val_);
else Modify(rs, mid + 1, R_, pos_, val_);
Pushup(now_);
}
Matrix Query(int now_, int L_, int R_, int l_, int r_) {
if (l_ == L_ && R_ == r_) return sum[now_];
if (r_ <= mid) return Query(ls, L_, mid, l_, r_);
if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
return Query(ls, L_, mid, l_, mid) * Query(rs, mid + 1, R_, mid + 1, r_);
}
#undef ls
#undef rs
#undef mid
}
int Query(int l_, int r_) {
Matrix ans;
ans.a[0][0] = ans.a[0][1] = -kInf;
return (ans * Seg::Query(1, 1, n, l_, r_)).a[0][1];
}
//=============================================================
int main() {
n = read();
for (int i = 1; i <= n; ++ i) a[i] = read();
Seg::Build(1, 1, n);
m = read();
for (int i = 1; i <= m; ++ i) {
int opt = read(), x = read(), y = read();
if (opt == 0) Seg::Modify(1, 1, n, x, y);
if (opt == 1) printf("%d\n", Query(x, y));
}
return 0;
}
動態樹形 DP
給定一棵 \(n\) 個點的樹,點有點權。給定 \(m\) 次點權修改操作,求每次操作后整棵樹的 最大點權獨立集 的權值。
一棵樹的獨立集定義為滿足任意一條邊的兩端點都不同時存在於集合中的樹的一個點集,一個獨立集的價值定義為集合中所有點的點權之和。
\(1\le n,m\le 10^5\),\(-100\le\) 點權 \(\le 100\)。
1S,256MB。
先考慮朴素 DP。欽定 1 為根,設 \(f_{u,0/1}\) 表示欽定點 \(u\) 不在/在 獨立集時以 \(u\) 為根的子樹的最大點權獨立集的權值,顯然有:
答案即為 \(\max (f_{1,0}, f_{1, 1})\)。
要求支持修改,又樹的形態不變,考慮用樹鏈剖分維護。但發現每個節點的 DP 值與其所有兒子有關,而樹剖只能支持修改重鏈/子樹信息。於是考慮對於每個節點,先將其輕兒子的貢獻求和,再考慮其重兒子的貢獻,使得可以通過對重鏈的修改/查詢來維護上述信息。這種思想在 LCT 維護子樹信息時也有所應用。
記 \(g_{u,0/1}\) 表示欽定 \(u\) 的重兒子不在獨立集,點 \(u\) 不在/在 獨立集時以 \(u\) 為根的子樹的最大點權獨立集的權值。記 \(\operatorname{s}_u\) 表示 \(u\) 的重兒子,顯然有:
則對 \(f\) 的轉移可以改寫成下列形式:
出現了一個熟悉的形式,套路地定義廣義矩陣乘法 \(A\times B = C\),滿足:
根據上述轉移方程,有下列關系成立。
於是可以考慮先預處理出 \(g\) 數組初始化轉移矩陣,再使用線段樹維護區間矩陣乘積。轉移矩陣寫在前面是因為 dfs 序列中深度較淺的點在前,轉移矩陣寫在前面可以直接按 dfs 序求得區間矩陣乘積並轉移。若轉移矩陣寫在后面,需要先將區間內的元素順序反轉。經過預處理后,求得以 1 為根的重鏈對應區間的矩陣乘積,即得 \(f_{u,0}\) 與 \(f_{u,1}\)。正確性顯然,重鏈一定以某葉節點為鏈底,以 1 為根的重鏈上所有輕兒子子樹信息的並即為整棵樹的信息。
考慮修改操作對哪些位置的 \(g\) 會產生影響。考慮其實際含義,\(g\) 維護的是輕兒子子樹信息。被影響的節點顯然為指定的修改位置 \(x\),以及子樹中包含被修改位置,且為輕兒子的節點的父親,后者可以通過從被修改位置不斷跳重鏈來進行遍歷。每次跳到的重鏈的頂的父親,即為對應節點。
每次更新上述節點時先求得修改前以該節點的對應輕兒子的子樹信息,修改子樹中的節點后再求得該節點的對應輕兒子子樹信息。根據兩次求得的子樹信息的差更新該節點的 \(g\),並將即將被修改的節點調整為當前節點。建議結合代碼理解。
總復雜度 \(O(8n\log n + 8m\log^2 n)\) 級別。
//知識點:樹形 DP,矩陣乘法,重鏈剖分,線段樹
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 1e5 + 10;
const int kL = 2;
const int kInf = 1e9 + 2077;
//=============================================================
int n, m, e_num, head[kN], val[kN], v[kN << 1], ne[kN << 1];
int dfn_num, dfn[kN], id[kN], f[kN][2], g[kN][2];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
void Chkmax(int &fir, int sec) {
if (sec > fir) fir = sec;
}
void Chkmin(int &fir, int sec) {
if (sec < fir) fir = sec;
}
struct Matrix {
int a[kL][kL];
Matrix() {
memset(a, 0, sizeof (a));
}
void build() {
for (int i = 1; i <= kL; ++ i) a[i][i] = 1;
}
Matrix operator * (const Matrix &b_) const {
Matrix ret;
memset(ret.a, 128, sizeof (ret.a));
for (int k = 0; k < kL; ++ k) {
for (int i = 0; i < kL; ++ i) {
for (int j = 0; j < kL; ++ j) {
ret.a[i][j] = std::max(ret.a[i][j], a[i][k] + b_.a[k][j]);
}
}
}
return ret;
}
} matrix[kN];
void Add(int u_, int v_) {
v[++ e_num] = v_;
ne[e_num] = head[u_];
head[u_] = e_num;
}
namespace Seg { //維護區間矩陣乘積
#define ls (now_<<1)
#define rs (now_<<1|1)
#define mid ((L_+R_)>>1)
Matrix sum[kN << 2];
void Pushup(int now_) {
sum[now_] = sum[ls] * sum[rs];
}
void Build(int now_, int L_, int R_) {
if (L_ == R_) {
sum[now_] = matrix[L_];
return ;
}
Build(ls, L_, mid), Build(rs, mid + 1, R_);
Pushup(now_);
}
void Modify(int now_, int L_, int R_, int pos_) {
if (L_ == R_) {
sum[now_] = matrix[pos_];
return ;
}
if (pos_ <= mid) Modify(ls, L_, mid, pos_);
else Modify(rs, mid + 1, R_, pos_);
Pushup(now_);
}
Matrix Query(int now_, int L_, int R_, int l_, int r_) {
if (l_ == L_ && R_ == r_) return sum[now_];
if (r_ <= mid) return Query(ls, L_, mid, l_, r_);
if (l_ > mid) return Query(rs, mid + 1, R_, l_, r_);
return Query(ls, L_, mid, l_, mid) * Query(rs, mid + 1, R_, mid + 1, r_);
}
#undef ls
#undef rs
#undef mid
}
namespace HLD {
int fa[kN], sz[kN], son[kN], dep[kN], top[kN], end[kN];
void Dfs1(int u_, int fa_) {
sz[u_] = 1;
fa[u_] = fa_;
f[u_][1] = val[u_];
dep[u_] = dep[fa_] + 1;
for (int i = head[u_]; i; i = ne[i]) { //預處理 f
int v_ = v[i];
if (v_ == fa_) continue ;
Dfs1(v_, u_);
sz[u_] += sz[v_];
if (sz[v_] > sz[son[u_]]) son[u_] = v_;
f[u_][0] += std::max(f[v_][0], f[v_][1]);
f[u_][1] += f[v_][0];
}
}
void Dfs2(int u_, int top_) {
dfn[u_] = ++ dfn_num;
id[dfn_num] = u_;
top[u_] = top_;
Chkmax(end[top_], dfn_num);
if (son[u_]) Dfs2(son[u_], top_);
g[u_][1] = val[u_];
for (int i = head[u_]; i; i = ne[i]) { //預處理 g
int v_ = v[i];
if (v_ == fa[u_] || v_ == son[u_]) continue ;
Dfs2(v_, v_);
g[u_][0] += std::max(f[v_][0], f[v_][1]);
g[u_][1] += f[v_][0];
}
}
void Modify(int u_, int val_) {
matrix[dfn[u_]].a[1][0] += val_ - val[u_]; //修改 u_ 的 g[u_][1]
val[u_] = val_; //更新點權
while (u_) { //u_ 不斷上跳
Matrix old = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //以 top[u_] 為根的子樹的信息
Seg::Modify(1, 1, n, dfn[u_]); //修改節點 u_ 的信息(單點修改矩陣)
Matrix newone = Seg::Query(1, 1, n, dfn[top[u_]], end[top[u_]]); //更新后以 top[u_] 為根的子樹的信息
u_ = fa[top[u_]]; //更新輕兒子 u_ 的父親的 g
//注意下文的賦值還未更新到線段樹上,上面需要求得未修改之前的信息,更新線段樹信息要在之后進行
matrix[dfn[u_]].a[0][0] += std::max(newone.a[0][0], newone.a[1][0]) -
std::max(old.a[0][0], old.a[1][0]);
matrix[dfn[u_]].a[0][1] = matrix[dfn[u_]].a[0][0];
matrix[dfn[u_]].a[1][0] += newone.a[0][0] - old.a[0][0];
}
}
int Query() { //求得以 1 為根的重鏈對應區間的矩陣乘積,即得答案
//重鏈一定以某葉節點為鏈底,以 1 為根的重鏈上所有輕兒子子樹信息的並即為整棵樹的信息。
Matrix ans = Seg::Query(1, 1, n, 1, end[1]);
return std::max(ans.a[0][0], ans.a[1][0]);
}
}
//=============================================================
int main() {
n = read(), m = read();
for (int i = 1; i <= n; ++ i) val[i] = read();
for (int i = 1; i < n; ++ i) {
int u_ = read(), v_ = read();
Add(u_, v_), Add(v_, u_);
}
HLD::Dfs1(1, 0), HLD::Dfs2(1, 1);
for (int i = 1; i <= n; ++ i) { //構造轉移矩陣
matrix[dfn[i]].a[0][0] = matrix[dfn[i]].a[0][1] = g[i][0];
matrix[dfn[i]].a[1][0] = g[i][1], matrix[dfn[i]].a[1][1] = -kInf;
}
Seg::Build(1, 1, n);
while (m --) {
int x_ = read(), y_ = read();
HLD::Modify(x_, y_);
printf("%d\n", HLD::Query());
}
return 0;
}
例題
先咕着。
寫在最后
鳴謝:
矩陣 - OI Wiki
矩陣乘法的結合律_jiongjiong 的專欄-CSDN博客_矩陣乘法結合律
動態 DP - OI Wiki
洛谷日報#130[GKxx]動態DP入門 - GKxx