凸優化小結


本文參考自 Wearry 在集訓的講解《DP及其優化》。

簡介

凸優化解決的是一類選擇恰好 \(K\) 個某種物品的最優化問題 , 一般來說這樣的題目在不考慮物品數量限制的條件下會有一個隱性的圖像 , 表示選擇的物品數量與問題最優解之間的關系 .

每個點就是選了 \(K\) 個物品的最優Dp值。(答案)也就是 \((K, f(K))\)

問題能夠用凸優化解決還需要滿足圖像是凸的 , 直觀地理解就是選的物品越多的情況下多選一個物品 , 最優解的增長速度會變慢 .

解法

解決凸優化類型的題目可以采用二分的方法 , 即二分隱性凸殼上最優值所在點的斜率 , 然后忽略恰好 \(K\) 個的限制做一次原問題 .

這樣每次選擇一個物品的時候要多付出斜率大小的代價 , 就能夠根據最優情況下選擇的物品數量來判斷二分的斜率與實際最優值的斜率的大小關系 .

理論上這個斜率一定是整數 , 由於題目性質可能會出現二分不出這個數的情況 , 這時就需要一些實現上的技巧保證能夠找到這個最優解 .

因為相鄰兩個點橫下標差 \(1\) (多選一個),縱坐標都是整數。(對於大部分的題目最優解都是整數)。

這個也就是 CTSC 上講的 帶權二分 啦。

例題

UOJ #104. 【APIO2014】Split the sequence

題意

將一個長為 \(n\) 的序列分成 \(k+1\) 個塊,每次分割得到分割處 左邊的和 與 右邊的和 乘積的分數。

保證序列中每個數非負。最后需要最大化分數,需要求出任意一組方案。

\(2 \le n \le 10^5, 1 \le k \le \min \{n - 1, 200\}\)

題解

直接做斜率優化是 \(O(nk)\) 的,那個十分 簡單 ,注意細節就行了。可以參考 我的代碼

雖然已經過了這題了,但是有更好的做法。也就是對於 \(k \le n - 1\) 也就是 \(k,n\) 同級的時候有更好的做法。

考慮前面講的凸優化,我們考慮二分那個斜率,也就是分數的增長率。

假設二分的值為 \(mid\) ,相當於轉化成沒有分段次數的限制,但是每次分段都要額外付出 \(mid\) 的代價 , 求最大化收益的前提下分段數是多少 .

具體化來說,就例如上圖,那個上凸殼就是答案的圖像,我們當前二分的那個斜率的直線就是那條紅線。

我們當前是最大化 \(f(x) - x\times mid\)

那么我們考慮把紅線向上不斷平移,那么最后接觸到的點就是這條直線與上凸殼的切點。此時答案最大。

那么我們算出的分段數就是 \(x\) ,也就是切點的下標。然后比較一下 \(x\)\(k\) 的關系,判斷應該向哪邊移動。

然后最后得到斜率算出的方案就是最優方案了。

我沒有寫 但聽說細節特別多,輸出方案很惡心。如果想寫的話,可以看下 UOJ 最快的代碼,來自同屆大佬 yww 的。

這個復雜度就是 \(O(n \log w)\) 的,十分優秀。

CF739E Gosha is hunting

題意

你要抓神奇寶貝! 現在一共有 \(n\) 只神奇寶貝。 你有 \(a\) 個『寶貝球』和 \(b\) 個『超級球』。 『寶貝球』抓到第 \(i\) 只神奇寶貝的概率是 \(p_i\) ,『超級球』抓到的概率則是 \(u_i\) 。 不能往同一只神奇寶貝上使用超過一個同種的『球』,但是可以往同一只上既使用『寶貝球』又使用『超級球』(都抓到算一個)。 請合理分配每個球抓誰,使得你抓到神奇寶貝的總個數期望最大,並輸出這個值。

\(n \le 2000\)

題解

不難發現用的球越多,期望增長率越低。這是很好理解的,一開始肯定選更優的神奇寶貝球,然后再選較劣的神奇寶貝球。

這就意味着這個隱性的圖像是上凸的,我們可以類似於上題的套路,我們二分那個斜率。

然后我們就可以忽略個數的限制了。但此處這里有兩個變量,那么我們二分套二分就行了。

假設當前二分的是 \(mid\) ,那么我們每次選擇一個神奇寶貝球就要付出 \(mid\) 的代價。

然后求出最大化收益時需要選多少個神奇寶貝球就行了,這個可以用一個很容易的 dp 求出。

但注意兩個同時選的時候,概率應該是 \(p_a + p_b - p_a \times p_b\)

但此時有一個重要的細節,就是二分到最后斜率求出的答案不一定是正確的。

但是在其中如果我們二分到 最優解要選的球和我最后用的球一樣的話,那么這樣就是一個最優的可行解。

至於原因?無可奉告!

似乎是可能有三點共線的情況,此時選的個數有問題。並且最后需要用給你的個數,不能用求出的個數。

代碼

具體看看代碼。。。反正我也不知道為什么這么多特殊情況。

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

inline bool chkmax(double &a, double b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
	freopen ("E.in", "r", stdin);
	freopen ("E.out", "w", stdout);
#endif
}

const double eps = 1e-10;

const int N = 2010;

int n, a, b;

double pa[N], pb[N]; int usea, useb; double f;

void Calc(double costa, double costb) {
	f = 0; usea = useb = 0;
	For (i, 1, n) {
		int cura = 0, curb = 0; double res = 0;
		if (chkmax(res, pa[i] - costa)) cura = 1, curb = 0;
		if (chkmax(res, pb[i] - costb)) cura = 0, curb = 1;
		if (chkmax(res, pa[i] + pb[i] - pa[i] * pb[i] - (costa + costb))) cura = curb = 1;
		usea += cura; useb += curb; f += res;
	}
}

int main () {

	File();

	n = read(); a = read(); b = read();
	For (i, 1, n) scanf("%lf", &pa[i]);
	For (i, 1, n) scanf("%lf", &pb[i]);

	double la = 0, ra = 1, lb, rb;
	while (la + eps < ra) {
		double mida = (la + ra) / 2.0; lb = 0, rb = 1;
		while (lb + eps < rb) {
			double midb = (lb + rb) / 2.0;
			Calc(mida, midb);
			if (useb == b) {lb = midb; break; }
			if (useb < b) rb = midb; else lb = midb;
		}
		if (usea == a) { la = mida; break; }
		if (usea < a) ra = mida; else la = mida;
	}
	Calc(la, lb);
	printf ("%.10lf\n", f + la * a + lb * b);

	return 0;
}

LOJ #2478. 「九省聯考 2018」林克卡特樹

題意

LOJ #2478. 「九省聯考 2018」林克卡特樹

請點上面鏈接qwq 題意很好理解的。(但要認真看題)

題解

題意等價於,恰好選 \(k\) 條鏈, 使得他們的長度和最大。

我們同樣可以使用凸優化對於這個來進行優化。

二分那個斜率 \(mid\) ,每次選擇多一條鏈就要減去 \(mid\) ,最后求使得答案最優的時候,需要分成幾段。

但這些都不是重點,重點是如何求出答案最優的時候有多少段。

我們令 dp[u][0/1/2]\(u\) 這個點,向子樹中延伸出 \(0,1,2\) 條鏈。

轉移的話,枚舉一下它從和哪個兒子的鏈相連,計算一下分的段數即可。

為了方便計算段數,在鏈的底部統計上段數,所以合並兩條鏈的時候需要減去一段,並且把權值加回來 \(mid\)

記得要統計上別的子樹的答案!!先掛下 \(dp\) 的代碼吧。

利用 std :: pair<ll, int> 寫的更加方便,第一維表示答案,第二維表示段數。

typedef pair<ll, int> PLI;
#define res first
#define num second
#define mp make_pair

inline PLI operator + (const PLI &lhs, const PLI &rhs) {
	return mp(lhs.res + rhs.res, lhs.num + rhs.num);
}

PLI f[N][3]; ll del;
void Dp(int u = 1, int fa = 0) {
	f[u][0] = mp(0, 0);
	f[u][1] = mp(- del, 1);
	f[u][2] = mp(- inf, 0);

	for (register int i = Head[u]; i; i = Next[i]) {
		register int v = to[i]; if (v == fa) continue ; Dp(v, u);
		PLI tmp = max(f[v][0], max(f[v][1], f[v][2]));

		chkmax(f[u][2], f[u][2] + tmp);
		chkmax(f[u][2], f[u][1] + f[v][1] + mp(val[i] + del, -1));

		chkmax(f[u][1], f[u][1] + tmp);
		chkmax(f[u][1], f[u][0] + f[v][1] + mp(val[i], 0));
		chkmax(f[u][1], f[u][0] + f[v][0] + mp(- del, 1));

		chkmax(f[u][0], f[u][0] + tmp);
	}
}

然后又會有三點共線的情況,也就是對於選擇連續幾個答案都是相同的。

我們發現,利用 std :: pair<ll, int> 的運算符 < ,會在第一維答案相同時優先第二維段數小的在前。

所以我們更新答案的時候就需要在 \(use > k\) 也就是需求大於供給 通貨膨脹 的時候進行更新,不然答案可能更新不到。

如果 \(use = k\) 那么就可以直接退出輸出答案就行啦。

代碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;
template<typename T> inline bool chkmax(T &a, T b) {return b > a ? a = b, 1 : 0;}

namespace pb_ds
{   
	namespace io
	{
		const int MaxBuff = 1 << 15;
		const int Output = 1 << 23;
		char B[MaxBuff], *S = B, *T = B;
#define getc() ((S == T) && (T = (S = B) + fread(B, 1, MaxBuff, stdin), S == T) ? 0 : *S++)
		char Out[Output], *iter = Out;
		inline void flush()
		{
			fwrite(Out, 1, iter - Out, stdout);
			iter = Out;
		}
	}

	inline int read()
	{
		using namespace io;
		register char ch; register int ans = 0; register bool neg = 0;
		while(ch = getc(), (ch < '0' || ch > '9') && ch != '-')     ;
		ch == '-' ? neg = 1 : ans = ch - '0';
		while(ch = getc(), '0' <= ch && ch <= '9') ans = ans * 10 + ch - '0';
		return neg ? -ans : ans;
	}
};

using namespace pb_ds;

void File () {
#ifdef zjp_shadow
	freopen ("2478.in", "r", stdin);
	freopen ("2478.out", "w", stdout);
#endif
}

const int N = 3e5 + 1e3, M = N << 1;

int Head[N], Next[M], to[M], val[M], e = 0;
inline void add_edge(int u, int v, int w) {
	to[++ e] = v; Next[e] = Head[u]; Head[u] = e; val[e] = w;
}

inline void Add(int u, int v, int w) {
	add_edge(u, v, w); add_edge(v, u, w);
}

typedef long long ll;
const ll inf = 1e18;

typedef pair<ll, int> PLI;
#define res first
#define num second
#define mp make_pair

inline PLI operator + (const PLI &lhs, const PLI &rhs) {
	return mp(lhs.res + rhs.res, lhs.num + rhs.num);
}

PLI f[N][3]; ll del;
void Dp(int u = 1, int fa = 0) {
	f[u][0] = mp(0, 0);
	f[u][1] = mp(- del, 1);
	f[u][2] = mp(- inf, 0);

	for (register int i = Head[u]; i; i = Next[i]) {
		register int v = to[i]; if (v == fa) continue ; Dp(v, u);
		PLI tmp = max(f[v][0], max(f[v][1], f[v][2]));

		chkmax(f[u][2], f[u][2] + tmp);
		chkmax(f[u][2], f[u][1] + f[v][1] + mp(val[i] + del, -1));

		chkmax(f[u][1], f[u][1] + tmp);
		chkmax(f[u][1], f[u][0] + f[v][1] + mp(val[i], 0));
		chkmax(f[u][1], f[u][0] + f[v][0] + mp(- del, 1));

		chkmax(f[u][0], f[u][0] + tmp);
	}
}

int n, k, use; PLI ans;

void Calc(ll cur) {
	ans = mp(-inf, 0); del = cur; Dp(); 
	For (i, 0, 2) chkmax(ans, f[1][i]); use = ans.num;
}

ll Ans;
int main () {

	File();

	n = read(), k = read() + 1;
	For (i, 1, n - 1) {
		register int u = read(), v = read(), w = read(); Add(u, v, w);
	}

	ll l = -1e6, r = 8e7;
	while (l <= r) {
		ll mid = (l + r) >> 1;
		Calc(mid);
		if (use == k) return printf ("%lld\n", ans.res + mid * k), 0;
		if (use < k) r = mid - 1;
		else l = mid + 1, Ans = ans.res + mid * k;
	}
	printf ("%lld\n", Ans);

    return 0;

}

LOJ #566. 「LibreOJ Round #10」yanQval 的生成樹

題意

戳進去 >> #566. 「LibreOJ Round #10」yanQval 的生成樹

題意簡單明了 qwq

題解

首先,顯然有 \(\mu\) 是這些數的中位數。

然后我們就很容易想到考慮枚舉中位數 \(mid\) ,然后在 \(w_i < mid\) (白邊)與 \(w_i \ge mid\) (黑邊)分別選 \(\displaystyle \lfloor \frac{n - 1}{2} \rfloor\) 條邊,組成最大生成樹。

這個就顯然可以進行凸優化了,二分斜率 \(k\) ,把白邊權值 \(+k\) ,然后做最大生成樹,看選出白邊的數量與需求的關系就行了。

這樣就得到了一個很好的 \(O(nm \log w ~\alpha (n))\) 的做法啦。(注意此處需要預處理排序,才能達到這個復雜度)

然后這樣顯然不夠,我們繼續考慮之前的權值是什么。白邊的權值為 \(mid + k - w_i\) ,黑邊的為 \(w_i - mid\) 。同時加上一個 \(mid\) 不會改變,那么就是 \(2\times mid + k - w_i\)\(w_i\) 。我們令 \(C=2\times mid + k\) ,那么白邊為 \(C - w_i\) ,黑邊為 \(w_i\)

嘗試一下二分 \(C\) ,然后直接判斷呢?這樣看起來很不真實,但卻是對的。

這樣可以保證在最大生成樹上 \(< mid\)\(\ge mid\) 都各有一半。為什么呢?因為你考慮不存在,那么多的一邊存在換到另外一邊會更優的情況。

具體看官方解釋:

首先對於 \(M\) 如果最大生成樹 \(T(M)\) 含有黑邊 \(w_1-M\) 和白邊 \(M-w_2\) 且 \(w_1<w_2\) ,顯然交換兩條邊為 \(w_2-M,M-w_1\) 更優(因為黑白邊對應重合,交換總是可行的)。故所有黑邊對應的 \(w\) 必然大於所有白邊。那么如果最大生成樹含有 \(w< M\) 的黑邊或 \(w\ge M\) 的白邊,必然只含一種,不妨設為黑邊。那么設最小黑邊原本的權值為 \(w'\) ,取 \(M'=w'\) ,可以發現其余邊的權值之和不變,而這條黑邊的權值從 \(w'-M<0\) 變成了 \(0\) ,增加了,故得到了一棵更大的生成樹,所以這一定不是全局最大生成樹。又由於方案數有限全局最大生成樹(或者 \(n-2\) 條邊生成森林)一定存在,其必然僅含有 \(w\ge M\) 的黑邊和 \(w<M\) 的白邊。

那么我們就除掉一個 \(O(n)\) 的復雜度啦。具體看代碼實現qwq

\(n\) 為偶數其實也是沒問題的,因為你總會選到中位數,不影響答案。

代碼

#include <bits/stdc++.h>

#define For(i, l, r) for(register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for(register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define DEBUG(...) fprintf(stderr, __VA_ARGS__)

using namespace std;

typedef long long ll;

inline bool chkmin(int &a, int b) {return b < a ? a = b, 1 : 0;}
inline bool chkmax(int &a, int b) {return b > a ? a = b, 1 : 0;}

inline int read() {
    int x = 0, fh = 1; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) if (ch == '-') fh = -1;
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
    return x * fh;
}

void File() {
#ifdef zjp_shadow
	freopen ("566.in", "r", stdin);
	freopen ("566.out", "w", stdout);
#endif
}

const int N = 2e5 + 1e3, M = 5e5 + 1e3;

int n, m;

namespace Union_Set {

	int fa[N], Size[N];

	void Init(int maxn) { For (i, 1, maxn) fa[i] = i, Size[i] = 0; }

	int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]); }

	inline bool Union(int x, int y) {
		int rtx = find(x), rty = find(y);
		if (rtx == rty) return false;
		if (Size[rtx] < Size[rty]) swap(rtx, rty);
		Size[rtx] += Size[rty]; fa[rty] = rtx; return true;
	}

}

struct Edge {

	int u, v, w;

	inline bool operator < (const Edge &rhs) const { return w > rhs.w; }

} lt[M];

ll ans, res; int use, need;
void Work(int lim) {
	Union_Set :: Init(n); res = use = 0;
	for (register int L = 1, R = m, cur = 0; L <= R; ) {
		Edge add; register bool choose = false;
		if (lt[L].w >= lim - lt[R].w) add = lt[L ++];
		else add = lt[R --], choose = true, add.w = lim - add.w;

		if (Union_Set :: Union(add.u, add.v)) {
			res += add.w; if (choose) ++ use;
			if (++ cur == need << 1) break;
		}
	}
	res -= 1ll * lim * need;
}

int main () {

	File();

	n = read(); m = read(); need = (n - 1) >> 1; if (!need) return puts("0"), 0;
	For (i, 1, m)
		lt[i] = (Edge) {read(), read(), read()};
	sort(lt + 1, lt + m + 1);

	int l = 0, r = min(lt[1].w * 2 + 1, (int) 1e9);
	while (l <= r) {
		int mid = (l + r) >> 1; Work(mid);
		if (use == need) return printf ("%lld\n", res), 0;
		if (use < need) l = mid + 1, ans = res; else r = mid - 1;
	}
	printf ("%lld\n", ans);

    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM