這個題直接暴力求解的話時間復雜度肯定是不行的,所以,我們要計算每個數值的貢獻,對每一個數求他當最小值當了多少次,當最大值當了多少次,最后當最大值的次數乘以這個數值減去當最小值的次數乘以數值就得到這個數的貢獻,依次把這n個數的貢獻加起來就是整個極差之和。
在計算一個數當了多少最值的時候,我們要理解問題,因為區間是連續的,所以,以最小值為例,如果一個數是當前這段區間的最小值,那么他一定是當前這段區間最小的(這不廢話),所以,我們就找到他往左做多能找到多少個連續的數都比他大,記錄這個位置,同理找他右邊有多少個大於它的,這樣就得到一個區間,這個區間是以這個數位最小值,如下圖示可以比較直觀的理解。
加入找以2為最小值的區間,那么他最多可以往左找到3,往右最多可以找到5,那么2作為最小值構成的區間數目為(2+1) * (1+1),如下:
[3, 9, 2], [9, 2], [2], [3, 9, 2, 5], [9, 2, 5], [2, 5]
同理如果2作為最大值也一樣求,最大值區間只有[2]這個區間
這個題目還有一個小技巧就是在預處理每個元素作為最值時,最左到什么位置和最右到什么位置,可以利用已知信息,就是前一個求出的位置來跳着加速,使得時間復雜度不是O(n^2)
代碼:

#include <bits/stdc++.h> using namespace std; const int maxn = 1e5 + 10; const int inf = 0x3f3f3f3f; int a[maxn]; int L[maxn], R[maxn]; void print(int L[], int n) { for (int i = 1; i <= n; i++) printf("%d ", L[i]); puts(""); } int main() { int n; scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%d", &a[i]); a[0] = -1, a[n + 1] = -1; int sum_min = 0; //求以當前元素作為最小值時,最左可以擴展到的元素位置. for (int i = 1; i <= n; i++) { if (a[i] >= a[i - 1]) L[i] = i; else { int tmp = i - 1; while (a[i] < a[tmp]) { if (tmp == L[tmp]) tmp--; else tmp = L[tmp]; } L[i] = tmp + 1; } } //print(L, n); //求以當前元素作為最小值時,最右可以擴展到的元素位置. for (int i = n; i >= 1; i--) { if (a[i] >= a[i + 1]) R[i] = i; else { int tmp = i + 1; while (a[i] < a[tmp]) { if (tmp == R[tmp]) tmp++; else tmp = R[tmp]; } R[i] = tmp - 1; } } //print(R, n); //求作為最小值時每個元素的貢獻,最后需要減去 for (int i = 1; i <= n; i++) { int tmp = (i - L[i] + 1) * (R[i] - i + 1); sum_min += tmp * a[i]; } a[0] = inf, a[n + 1] = inf; int sum_max = 0; //求以當前元素作為最大值時,最左可以擴展到的元素位置. for (int i = 1; i <= n; i++) { if (a[i] <= a[i - 1]) L[i] = i; else { int tmp = i - 1; while (a[i] > a[tmp]) { if (tmp == L[tmp]) tmp--; else tmp = L[tmp]; } L[i] = tmp + 1; } } //print(L, n); //求以當前元素作為最大值時,最右可以擴展到的元素位置. for (int i = n; i >= 1; i--) { if (a[i] <= a[i + 1]) R[i] = i; else { int tmp = i + 1; while (a[i] > a[tmp]) { if (tmp == R[tmp]) tmp++; else tmp = R[tmp]; } R[i] = tmp - 1; } } //print(R, n); //元素作為最大值時的貢獻 for (int i = 1; i <= n; i++) { int tmp = (i - L[i] + 1) * (R[i] - i + 1); sum_max += tmp * a[i]; } printf("%d\n", sum_max - sum_min); return 0; }