題目
在一維坐標軸上,有\(n\)個點和\(m\)線段。每次可以花費1移動任意點向左或向右移動一個單位距離。問讓每個線段均被至少一個點訪問的最小代價。只要有點和線段有交集,該線段就被訪問過。
題解
有兩個比較簡單的處理:
- 如果某些點在線段內,該線段就可以被刪除。
- 如果線段內含了更小的線段,那么較大的線段可以被刪除。
這個處理可以用數組數組解決。講線段按照左區間遞減排序,右區間遞增排序。如果(L, R)包含(l, r),那么有\(L\le l \le r \le R\)。那么(l, r)必然排在(L,R)之前。
以下線段用區間代替。這樣剩余的區間就和點沒有交集,區間之間也不會互相包含。剩余的區間最優方案就是只被1個點訪問,並且是距離這個區間最近的兩個點中的一個。在兩個相鄰點之間夾着的區間,最優方案一定是前若干個區間被前一個點訪問,剩下的區間被后一個點訪問。(感性理解一下)
故最優方案中,每個點訪問的區間都是它周圍連續的若干個區間。假設點訪問左右側區間的所需的最遠距離分別為\(a,b\),那么最少代價為\(2\min(a,b)+\max(a,b)\)。這樣處理后,就可以用dp解決了。
設\(dp[i][j]\)代表第個\(i\)點前面的區間都訪問好了,並且第\(i\)點之后的連續\(j\)個區間會被\(i\)訪問的最小代價。轉移方程為
其中:
- \(x\)代表點\(i-1\)到點\(i\)之間的區間數
- \(a_{k+1,i}\)代表點\(i\)到它右側的,點\(i-1\)之后的第\(k+1\)個的區間的距離。
- \(b_{i,j}\)代表點\(i\)到它左側的,點\(i\)之后的第\(j\)個的區間的距離。
這個dp是\(O(n+m^2)\)的,需要優化。
\(a_{k+1,i}\)隨着\(k\)增加而減少,\(b_{i,j}\)隨着\(j\)的增加而增加,均是單調的。故可以從小到大枚舉\(j\),從大到小枚舉\(k\),找到臨界位置\(k'\),使得\(a_{k'+1} > b_{i,j}\),這樣就可以把\(\max\)和\(\min\)去掉了,維護一個前綴最小值和后綴最小值即可\(O(1)\)計算出當前的dp值。使用雙指針找出臨界位置。
最終時間復雜度O(n+m)
#include <bits/stdc++.h>
#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 1e6 + 10;
const double eps = 1e-5;
typedef pair<ll, ll> PII;
vector<int> num;
vector<PII> ret;
int pos[N];
PII segs[N];
ll arr[N];
int tarr[N];
bool del[N];
int mx;
int lowbit(int x) {
return x&-x;
}
void add(int p, int val) {
while(p < mx) {
tarr[p] += val;
p += lowbit(p);
}
}
int get(int p) {
int res = 0;
while(p) {
res += tarr[p];
p -= lowbit(p);
}
return res;
}
int sum(int l, int r) {
return get(r) - get(l - 1);
}
bool cmp(int a, int b) {
if(segs[a].first == segs[b].first) {
return segs[a].second < segs[b].second;
}
return segs[a].first > segs[b].first;
}
int id(int x) {
return lower_bound(num.begin(), num.end(), x) - num.begin() + 1;
}
ll dp[N], tmpdp[N], pre[N], las[N];
int main() {
IOS;
int t;
cin >> t;
while(t--) {
ret.clear();
num.clear();
int n, m;
cin >> n >> m;
for(int i = 1; i <= n; i++) {
cin >> arr[i];
num.push_back(arr[i]);
}
for(int i = 1; i <= m; i++) {
dp[i] = pre[i] = las[i] = 0;
del[i] = 0;
pos[i] = i;
int l, r;
cin >> l >> r;
num.push_back(l);
num.push_back(r);
segs[i] = {l, r};
}
sort(pos + 1, pos + 1 + m, cmp);
sort(num.begin(), num.end());
num.erase(unique(num.begin(), num.end()), num.end());
mx = num.size() + 1;
for(int i = 0; i <= mx; i++) tarr[i] = 0;
for(int i = 1; i <= n; i++) {
add(id(arr[i]), 1);
}
for(int i = 1; i <= m; i++) {
int p = pos[i];
int l = id(segs[p].first), r = id(segs[p].second);
if(sum(l, r)) del[p] = 1;
else add(r, 1);
}
for(int i = 1; i <= m; i++) {
int p = pos[i];
if(del[p]) continue;
ret.push_back(segs[p]);
}
sort(arr + 1, arr + 1 + n);
sort(ret.begin(), ret.end());
int p1 = 0, p2 = 0, prelen, len;
for(int i = 1; i <= n; i++) {
while(p1 < ret.size() && ret[p1].first <= arr[i]) p1++;
while(p2 < ret.size() && (i == n || ret[p2].second <= arr[i + 1])) p2++;
len = p2 - p1;
int sp = prelen;
for(int j = 0; j <= len; j++) {
ll d = j ? ret[p1 + j - 1].first - arr[i] : 0;
if(i == 1) {
ll a = p1 > 0 ? arr[i] - ret[0].second : 0;
tmpdp[j] = 2 * min(a, d) + max(a, d);
} else {
while(sp >= 0) {
ll a = sp < prelen ? arr[i] - ret[p1 - prelen + sp].second : 0;
if(a > d) break;
sp--;
}
tmpdp[j] = las[sp + 1] + d;
if(sp >= 0) tmpdp[j] = min(tmpdp[j], pre[sp] + 2 * d);
}
}
for(int j = 0; j <= len; j++) dp[j] = tmpdp[j];
if(i < n) {
for(int j = 0; j <= len; j++) {
ll d = j < len ? arr[i + 1] - ret[p1 + j].second : 0;
pre[j] = dp[j] + d;
las[j] = dp[j] + 2 * d;
}
for(int j = 1; j <= len; j++) pre[j] = min(pre[j], pre[j - 1]);
for(int j = len - 1; j >= 0; j--) las[j] = min(las[j], las[j + 1]);
}
prelen = len;
}
cout << dp[len] << endl;
}
}