题意
给出一个n和一个k,求1~n的每个区间的第k大的总和是多少,区间长度小于k的话,贡献为0.
思路
首先有一个关系:当一个数是第k大的时候,前面有x个比它大的数,那么后面就有k-x-1个比它大的数。
比赛的时候队友想出了用set来维护。一开始是一个空的set,先插入大的数,那么当之后插入数的时候,他们之间的pos距离就代表它有多少个小于它的,然后根据上面的关系,对于每个数最多使得迭代器跳k次,就可以快速维护了。其实想法和正解差不多,但是因为其迭代器使用不熟练,而且我还死磕自己错误的想法。
题解的思路其实差不多,一开始先维护一个满的链表,然后从小到大删除,每次算完一个数,就在链表里面删除,算x的时候,保证删除的数都比x小,都可以用来算贡献。i和pre[i]和nxt[i]的距离就是小于当前的数的数目+1。
链表
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 10;
typedef long long LL;
int pre[N], nxt[N], v[N], pos[N], n, k;
LL a[N], b[N];
LL solve(int x) {
int c1 = 0, c2 = 0;
for(int i = x; i && c1 <= k; i = pre[i])
a[++c1] = i - pre[i];
for(int i = x; i <= n && c2 <= k; i = nxt[i])
b[++c2] = nxt[i] - i;
LL ans = 0;
for(int i = 1; i <= c1; i++)
if(k - i + 1 <= c2 && k - i + 1 >= 1)
ans += a[i] * b[k-i+1];
return ans;
}
void del(int x) {
pre[nxt[x]] = pre[x];
nxt[pre[x]] = nxt[x];
}
int main() {
int t; scanf("%d", &t);
while(t--) {
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) scanf("%d", &v[i]), pos[v[i]] = i;
for(int i = 0; i <= n + 1; i++) pre[i] = i - 1, nxt[i] = i + 1;
pre[0] = 0; nxt[n+1] = n + 1;
LL ans = 0;
for(int i = 1; i <= n; i++) {
ans += solve(pos[i]) * i;
del(pos[i]);
} printf("%lld\n", ans);
} return 0;
}
set
超时了。
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 10;
const int INF = 0x3f3f3f3f;
const int MOD = 1e9 + 7;
typedef long long LL;
set<int> se;
int pos[N];
int x[N], y[N];
int main() {
int t; scanf("%d", &t);
while(t--) {
int n, k, z; scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) scanf("%d", &z), pos[z] = i;
LL ans = 0;
se.clear();
se.insert(0); se.insert(n+1);
for(int i = n; i; i--) {
se.insert(pos[i]);
int c1 = 0, c2 = 0;
set<int>::iterator it = lower_bound(se.begin(), se.end(), pos[i]), it1, it2;
it1 = it, it2 = it;
while(c1 <= k && *it2 != 0) { // 向前面找k个
it2--;
x[++c1] = *it1 - *it2;
it1--;
}
it1 = it, it2 = it;
while(c2 <= k && *it2 != n + 1) { // 向后面找k个
it2++;
y[++c2] = *it2 - *it1;
it1++;
}
for(int j = 1; j <= c1; j++)
if(k - j + 1 <= c2) ans += 1LL * i * x[j] * y[k-j+1];
// printf("%d : %I64d\n", i, ans);
} printf("%I64d\n", ans);
} return 0;
}