我們知道,有些DP方程可以轉化成DP[i]=f[j]+x[i]的形式,其中f[j]中保存了只與j相關的量。這樣的DP方程我們可以用單調隊列進行優化,從而使得O(n^2)的復雜度降到O(n)。
可是並不是所有的方程都可以轉化成上面的形式,舉個例子:dp[i]=dp[j]+(x[i]-x[j])*(x[i]-x[j])。如果把右邊的乘法化開的話,會得到x[i]*x[j]的項。這就沒辦法使得f[j]里只存在於j相關的量了。於是上面的單調隊列優化方法就不好使了。
這里學習一種新的優化方法,叫做斜率優化,其實和凸包差不多,下面會解釋。
舉例子說明是最好的!HDU 3507,很適合的一個入門題。http://acm.hdu.edu.cn/showproblem.php?pid=3507
大概題意就是要輸出N個數字a[N],輸出的時候可以連續連續的輸出,每連續輸出一串,它的費用是 “這串數字和的平方加上一個常數M”。
我們設dp[i]表示輸出到i的時候最少的花費,sum[i]表示從a[1]到a[i]的數字和。於是方程就是:
dp[i]=dp[j]+M+(sum[i]-sum[j])^2;
很顯然這個是一個二維的。題目的數字有500000個,不用試了,二維鐵定超時了。那我們就來試試斜率優化吧,看看是如何做到從O(n^2)復雜度降到O(n)的。
分析:
我們假設k<j<i。如果在j的時候決策要比在k的時候決策好,那么也是就是dp[j]+M+(sum[i]-sum[j])^2<dp[k]+M+(sum[i]-sum[k])^2。(因為是最小花費嘛,所以優就是小於)
兩邊移項一下,得到:(dp[j]+num[j]^2-(dp[k]+num[k]^2))/(2*(num[j]-num[k]))<sum[i]。我們把dp[j]-num[j]^2看做是yj,把2*num[j]看成是xj。
那么不就是yj-yk/xj-xk<sum[i]么? 左邊是不是斜率的表示?
那么yj-yk/xj-xk<sum[i]說明了什么呢? 我們前面是不是假設j的決策比k的決策要好才得到這個表示的? 如果是的話,那么就說明g[j,k]=yj-jk/xj-xk<sum[i]代表這j的決策比k的決策要更優。
關鍵的來了:現在從左到右,還是設k<j<i,如果g[i,j]<g[j,k],那么j點便永遠不可能成為最優解,可以直接將它踢出我們的最優解集。為什么呢?
我們假設g[i,j]<sum[i],那么就是說i點要比j點優,排除j點。
如果g[i,j]>=sum[i],那么j點此時是比i點要更優,但是同時g[j,k]>g[i,j]>sum[i]。這說明還有k點會比j點更優,同樣排除j點。
排除多余的點,這便是一種優化!
接下來看看如何找最優解。
設k<j<i。
由於我們排除了g[i,j]<g[j,k]的情況,所以整個有效點集呈現一種上凸性質,即k j的斜率要大於j i的斜率。

這樣,從左到右,斜率之間就是單調遞減的了。當我們的最優解取得在j點的時候,那么k點不可能再取得比j點更優的解了,於是k點也可以排除。換句話說,j點之前的點全部不可能再比j點更優了,可以全部從解集中排除。
於是對於這題我們對於斜率優化做法可以總結如下:
1,用一個單調隊列來維護解集。
2,假設隊列中從頭到尾已經有元素a b c。那么當d要入隊的時候,我們維護隊列的上凸性質,即如果g[d,c]<g[c,b],那么就將c點刪除。直到找到g[d,x]>=g[x,y]為止,並將d點加入在該位置中。
3,求解時候,從隊頭開始,如果已有元素a b c,當i點要求解時,如果g[b,a]<sum[i],那么說明b點比a點更優,a點可以排除,於是a出隊。最后dp[i]=getDp(q[head])。
View Code
1 #include<iostream> 2 #include<string> 3 using namespace std; 4 5 int dp[500005]; 6 int q[500005]; 7 int sum[500005]; 8 int head,tail,n,m; 9 10 int getDP(int i,int j) 11 { 12 return dp[j]+m+(sum[i]-sum[j])*(sum[i]-sum[j]); 13 } 14 15 int getUP(int j,int k) //yj-yk的部分 16 { 17 return dp[j]+sum[j]*sum[j]-(dp[k]+sum[k]*sum[k]); 18 } 19 20 int getDOWN(int j,int k) //xj-xk的部分 21 { 22 return 2*(sum[j]-sum[k]); 23 } 24 25 int main() 26 { 27 int i; 28 freopen("D:\\in.txt","r",stdin); 29 while(scanf("%d%d",&n,&m)==2) 30 { 31 for(i=1;i<=n;i++) 32 scanf("%d",&sum[i]); 33 sum[0]=dp[0]=0; 34 for(i=1;i<=n;i++) 35 sum[i]+=sum[i-1]; 36 head=tail=0; 37 q[tail++]=0; 38 for(i=1;i<=n;i++) 39 { 40 while(head+1<tail && getUP(q[head+1],q[head])<=sum[i]*getDOWN(q[head+1],q[head])) 41 head++; 42 dp[i]=getDP(i,q[head]); 43 while(head+1<tail && getUP(i,q[tail-1])*getDOWN(q[tail-1],q[tail-2])<=getUP(q[tail-1],q[tail-2])*getDOWN(i,q[tail-1])) 44 tail--; 45 q[tail++]=i; 46 } 47 printf("%d\n",dp[n]); 48 } 49 return 0; 50 }
