斜率優化DP


我們知道,有些DP方程可以轉化成DP[i]=f[j]+x[i]的形式,其中f[j]中保存了只與j相關的量。這樣的DP方程我們可以用單調隊列進行優化,從而使得O(n^2)的復雜度降到O(n)。

 

可是並不是所有的方程都可以轉化成上面的形式,舉個例子:dp[i]=dp[j]+(x[i]-x[j])*(x[i]-x[j])。如果把右邊的乘法化開的話,會得到x[i]*x[j]的項。這就沒辦法使得f[j]里只存在於j相關的量了。於是上面的單調隊列優化方法就不好使了。

 

這里學習一種新的優化方法,叫做斜率優化,其實和凸包差不多,下面會解釋。

 

舉例子說明是最好的!HDU 3507,很適合的一個入門題。http://acm.hdu.edu.cn/showproblem.php?pid=3507

大概題意就是要輸出N個數字a[N],輸出的時候可以連續連續的輸出,每連續輸出一串,它的費用是 “這串數字和的平方加上一個常數M”。

我們設dp[i]表示輸出到i的時候最少的花費,sum[i]表示從a[1]到a[i]的數字和。於是方程就是:

dp[i]=dp[j]+M+(sum[i]-sum[j])^2;

很顯然這個是一個二維的。題目的數字有500000個,不用試了,二維鐵定超時了。那我們就來試試斜率優化吧,看看是如何做到從O(n^2)復雜度降到O(n)的。

 

分析:

我們假設k<j<i。如果在j的時候決策要比在k的時候決策好,那么也是就是dp[j]+M+(sum[i]-sum[j])^2<dp[k]+M+(sum[i]-sum[k])^2。(因為是最小花費嘛,所以優就是小於)

兩邊移項一下,得到:(dp[j]+num[j]^2-(dp[k]+num[k]^2))/(2*(num[j]-num[k]))<sum[i]。我們把dp[j]-num[j]^2看做是yj,把2*num[j]看成是xj。

那么不就是yj-yk/xj-xk<sum[i]么?   左邊是不是斜率的表示? 

那么yj-yk/xj-xk<sum[i]說明了什么呢?  我們前面是不是假設j的決策比k的決策要好才得到這個表示的? 如果是的話,那么就說明g[j,k]=yj-jk/xj-xk<sum[i]代表這j的決策比k的決策要更優。

 

關鍵的來了:現在從左到右,還是設k<j<i,如果g[i,j]<g[j,k],那么j點便永遠不可能成為最優解,可以直接將它踢出我們的最優解集。為什么呢?

我們假設g[i,j]<sum[i],那么就是說i點要比j點優,排除j點。

如果g[i,j]>=sum[i],那么j點此時是比i點要更優,但是同時g[j,k]>g[i,j]>sum[i]。這說明還有k點會比j點更優,同樣排除j點。

排除多余的點,這便是一種優化!

 

接下來看看如何找最優解。

設k<j<i。

由於我們排除了g[i,j]<g[j,k]的情況,所以整個有效點集呈現一種上凸性質,即k j的斜率要大於j i的斜率。

這樣,從左到右,斜率之間就是單調遞減的了。當我們的最優解取得在j點的時候,那么k點不可能再取得比j點更優的解了,於是k點也可以排除。換句話說,j點之前的點全部不可能再比j點更優了,可以全部從解集中排除。

 

於是對於這題我們對於斜率優化做法可以總結如下:

1,用一個單調隊列來維護解集。

2,假設隊列中從頭到尾已經有元素a b c。那么當d要入隊的時候,我們維護隊列的上凸性質,即如果g[d,c]<g[c,b],那么就將c點刪除。直到找到g[d,x]>=g[x,y]為止,並將d點加入在該位置中。

3,求解時候,從隊頭開始,如果已有元素a b c,當i點要求解時,如果g[b,a]<sum[i],那么說明b點比a點更優,a點可以排除,於是a出隊。最后dp[i]=getDp(q[head])。

View Code
 1 #include<iostream>
 2 #include<string>
 3 using namespace std;
 4 
 5 int dp[500005];
 6 int q[500005];
 7 int sum[500005];
 8 int head,tail,n,m;
 9 
10 int getDP(int i,int j)
11 {
12     return dp[j]+m+(sum[i]-sum[j])*(sum[i]-sum[j]);
13 }
14 
15 int getUP(int j,int k)  //yj-yk的部分
16 {
17     return dp[j]+sum[j]*sum[j]-(dp[k]+sum[k]*sum[k]);
18 }
19 
20 int getDOWN(int j,int k) //xj-xk的部分
21 {
22     return 2*(sum[j]-sum[k]);
23 }
24 
25 int main()
26 {
27     int i;
28     freopen("D:\\in.txt","r",stdin);
29     while(scanf("%d%d",&n,&m)==2)
30     {
31         for(i=1;i<=n;i++)
32             scanf("%d",&sum[i]);
33         sum[0]=dp[0]=0;
34         for(i=1;i<=n;i++)
35             sum[i]+=sum[i-1];
36         head=tail=0;
37         q[tail++]=0;
38         for(i=1;i<=n;i++)
39         {
40             while(head+1<tail && getUP(q[head+1],q[head])<=sum[i]*getDOWN(q[head+1],q[head]))
41                 head++;
42             dp[i]=getDP(i,q[head]);
43             while(head+1<tail && getUP(i,q[tail-1])*getDOWN(q[tail-1],q[tail-2])<=getUP(q[tail-1],q[tail-2])*getDOWN(i,q[tail-1]))
44                 tail--;
45             q[tail++]=i;
46         }
47         printf("%d\n",dp[n]);
48     }
49     return 0;
50 }

 

 

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM