昨晚的比賽題。(像我這種蒟蒻只能打打div2)
題意
給你$n$個物品,每一個物品$i$,有一個權值$w_i$和一個位置$a_i$,定義移動一個物品$i$到位置$t$的代價為$w_i * \left |a_i - t \right |$,要求你寫一個數據結構支持以下兩種操作:
1、修改一個物品的權值
2、查詢把一個區間內全部移到相鄰的位置的最小值。
舉個栗子:如果要把$[l, r]$移到相鄰的位置,就是對於$\forall i \in [l, r]$,要有$pos_i = x + i - l\ (1 \leq x \leq n - (r - l))$,然后要確定這個$x$使移動的總代價最小,最后要求這個最小的代價對$1e9 + 7$取模的結果,每次詢問獨立。
注意:要先使總代價最小然后再取模,而不是取模后最小。
保證給出的$a_i$遞增。
兩個原題:
一個簡單題:
我們有很經典的貨倉選址的模型,就是在直線上有$n$個點,每一個點$i$有一個位置$pos_i$,每一個點有一個貨物。定義運輸貨物的代價是移動的距離。現在要在直線上選擇一個點建立貨倉,要把所有的貨物都運到這個點,要求使使代價最小,求這個最小代價。
很簡單吧,中位數。
抄一段lyd書上的證明:先把所有的點按照$pos_i$排序,假設貨倉建在$X$,左側的點有$P$個,右側的點有$Q$個。如果$P < Q$,那么把$X$往右移動會使答案變優,同理當$P > Q$使把$X$向左移動會使答案變優,所有最優解會在$P == Q$的地方產生。
再抄一句:當$n$是偶數的時候,這時$pos_{\frac{n}{2}}$和$pos_{\frac{n + 1}{2}}$中的點都可以是最優解。
稍微強化板:
現在每一個貨倉$i$里有$w_i$個貨物。
排序后,找到第一個$X$,使$\sum_{i = 1}^{X}a_i \geq \sum_{i = X + 1}^{n}a_i$,$X$就是最優解。
這個東西叫做帶權中位數。
丟一個百度百科的鏈接,里面有證明。 傳送門
我自己把不嚴謹的證明在這里再寫一遍:
假設最優答案在$T$取到,那么有(唔,這里$a_i$代表權值):
$\sum_{i = 1}^{n}a_i *dis(i, T) \leq \sum_{i = 1}^{n}a_i * dis(i, T + 1)$
變形一下:
$\sum_{i = 1}^{T - 1}a_i *dis(i, T) + \sum_{i = T + 1}^{n}a_i *dis(i, T) + a_{T + 1} * dis(T, T + 1)\leq \sum_{i = 1}^{T}a_i *dis(i, T + 1) + \sum_{i = T + 2}^{n}a_i *dis(i, T + 1) + a_T * dis(T, T + 1)$
發現$T$左邊的點走到$T + 1$與走到$T$比,多走了$dis(T, T + 1)$,而右邊的點則少走了$dis(T, T + 1)$。
消掉一模一樣的東西就得到了: $\sum_{i = 1}^{T}a_i \geq \sum_{i = T + 1}^{n}a_i$。
把$T$和$T - 1$代進去也是一樣的結果。
回到這題
那么這題要求移到相鄰的位置,可以理解為先移到同一個位置然后移回來,相對移動不變,我們只要找到這個帶權中位數的位置,就能得到最優解了。
帶上修改,我們可以用兩個樹狀數組來維護,一個維護$\sum_{i = 1}^{n}w_i$,另一個維護$\sum_{i = 1}^{n}w_i*(a_i - i)$,詢問的時候先二分一下找到帶權中位數的位置$pos$,然后對於$pos$左邊的點向右移,對於$pos$右邊的點向左移,就可以計算出答案了。
時間復雜度$O(nlog^2{n})$。
Code:

#include <cstdio> #include <cstring> using namespace std; typedef long long ll; const int N = 2e5 + 5; const ll P = 1e9 + 7; int n, qn; ll a[N], w[N]; template <typename T> inline void read(T &X) { X = 0; char ch = 0; T op = 1; for(; ch > '9'|| ch < '0'; ch = getchar()) if(ch == '-') op = -1; for(; ch >= '0' && ch <= '9'; ch = getchar()) X = (X << 3) + (X << 1) + ch - 48; X *= op; } namespace BitSum { ll s[N]; #define lowbit(p) (p & (-p)) inline void modify(int p, ll v) { for(; p <= n; p += lowbit(p)) s[p] += v; } inline ll query(int p) { ll res = 0LL; for(; p > 0; p -= lowbit(p)) res += s[p]; return res; } inline ll getSum(int l, int r) { if(r < l) return 0LL; return query(r) - query(l - 1); } } namespace BitMul { ll s[N]; #define lowbit(p) (p & (-p)) inline void modify(int p, ll v) { v %= P; for(; p <= n; p += lowbit(p)) (s[p] = s[p] + v + P) %= P; } inline ll query(int p) { ll res = 0LL; for(; p > 0; p -= lowbit(p)) (res += s[p]) %= P; return res; } inline ll getSum(int l, int r) { return (query(r) - query(l - 1) + P) % P; } } inline int getPos(int x, int y) { int ln = x, rn = y, mid, res; for(; ln <= rn; ) { mid = (ln + rn) / 2; if(BitSum :: getSum(x, mid) >= BitSum :: getSum(mid + 1, y)) res = mid, rn = mid - 1; else ln = mid + 1; } return res; } inline ll abs(ll x) { return x > 0 ? x : -x; } inline ll max(ll x, ll y) { return x > y ? x : y; } inline ll min(ll x, ll y) { return x > y ? y : x; } inline void solve(int x, int y) { if(x == y) { puts("0"); return; } int pos = getPos(x, y); /* ll res = BitMul :: getSum(x, y); //d1 = 0LL, d2 = 0LL; d1 = (d1 - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d1 = (d1 + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d2 = (d2 - (BitSum :: getSum(pos + 1, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d2 = (d2 + (BitSum :: getSum(x, pos) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; ll d = 0LL; d = (d - (BitSum :: getSum(pos, y) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; d = (d + (BitSum :: getSum(x, pos - 1) % P) * 1LL * abs(a[pos] - pos) % P + P) % P; */ ll res = 0LL; res = (-BitMul :: getSum(x, pos) + (BitSum :: getSum(x, pos) % P) * abs(a[pos] - pos) % P + P) % P; res = (res - (BitSum :: getSum(pos, y)) % P * abs(a[pos] - pos) % P + BitMul :: getSum(pos, y) + P) % P; // printf("%lld\n", (res + d + P) % P); printf("%lld\n", res); } int main() { read(n), read(qn); for(int i = 1; i <= n; i++) read(a[i]); for(int i = 1; i <= n; i++) { read(w[i]); BitMul :: modify(i, w[i] * (a[i] - i)); BitSum :: modify(i, w[i]); } for(int x, y; qn--; ) { read(x), read(y); if(x < 0) { x = -x; BitSum :: modify(x, -w[x]); BitMul :: modify(x, -1LL * w[x] * (a[x] - x)); w[x] = 1LL * y; BitSum :: modify(x, w[x]); BitMul :: modify(x, 1LL * w[x] * (a[x] - x)); } else solve(x, y); } /* for(int i = 1; i <= n; i++) printf("%lld ", BitSum :: getSum(i, i)); */ return 0; }