斜率DP個人理解


斜率DP

斜率DP的一版模式:給你一個序列,至多或分成m段,每段有花費和限制,問符合情況的最小花費是多少;

一版都用到sum[],所以符合單調,然后就可以用斜率優化了,很模板的東西;

如果看不懂可以先去看一下本博客----斜率DP題目,看一下第一道題目,然后在回來看push,pop是為什么這樣操作;

 

首先通過對方程的化簡得到如下遞推方程
DP[i] = min/max( -a[i]*x[j] + y[j] ) + w[i]; (1<=j<i)

一般情況下,x[j],y[j],a[i]都是單調遞增的,(求最小值,維護的是下右凸包)
當然也可以x[j]單調遞減,y[j]單調遞增,a[i]單調遞增;(求最小值,維護的是下左凸包)

對於DP[i],顯然只要找到一個j使a[i]*x[j]+y[j]最小就可以了,
注意對於DP[i]來說,a[i],w[i]都是常量;

一般對於DP[i] =min/max(-a[i]*x[j] + y[j] )+ w[i],最朴素的時間復雜度是O(n^2);
為什么可以優化呢


設G = -a[i]*x[j] + y[j],
移項: y[j] = a[i]*x[j] + G;
現在的問題就是:已知道a[i]也就是斜率,給你幾個點(x[j],y[j]),找一個點帶入使得G最小;
G是直線與Y軸的交點的縱坐標的值,顯然這個點一定在這些點形成的凸包上,

(圖是x[i],y[i],單調遞增,斜率為正的情況)

因為我們在從小到大遞推求解,求DP[i]的時候DP[j](0<=j<i)都是已知的
所以我們可以在求完DP[i]之后可以馬上把點(x[i],y[i])加入,來維護一個凸包;

這里還需要一個小知識點,就是凸包的維護,如果寫過凸包的話,我們都知道在維護前
都要先把點排序(不管是水平序,還是極角序)
這就是為什么要x[i],y[i]是單調的原因了,只有單調才可以按照遞推的順序直接維護凸包了;

 

但如果所有的點都在凸包上,那么這個優化也就不算優化了,

所以問題變成:
對於一條已知斜率的直線,如何從凸包上找一個點使它與Y軸的交點的縱坐標值最小;

對於一個下凸包,且斜率單調遞增:(求最小值的情況下)
我們現在假設直線和下凸包里斜率最小的直線重合,不斷的變大這條直線的斜率,
也就是沿着這個凸包旋轉,
我們發現,這條直線要么跟凸包的一條直線重合,要么經過凸包的一個點,
且一旦一個點被旋轉過去后,接下來斜率變大的直線都不可會再經過這個點重合,
也就是說一旦一個點被淘汰了,那么它在接下來的過程中也不會被用到,

 

 

這樣我們就有一個O(n)的算法,每次從凸包隊列里從頭比較相臨的倆個點,誰得到的G
比較小,如果后一個點得到的G小,說明前一個點在接下來的狀況下也不是最優的,所以
可以直接淘汰。

而所謂的單調隊列優化其實也是這樣,就是在隊列里維護可能提供最優值的那些狀態,
不斷的插入新的點,不斷的刪掉不符合或者不優的點;
然后在維護的隊列里快速的找到那個使當前狀態最優的那個狀態;

 

 1 #include<cstdio>
 2 #include<cstring>
 3 #include<cstdlib>
 4 #include<iostream>
 5 #include<algorithm>
 6 #include<cmath>
 7 #include<vector>
 8 #include<set>
 9 using namespace std;
10 const int N=50000+10;
11 typedef long long LL;
12 struct Point{
13     LL x,y;
15 Point (LL a=0,LL b=0):x(a),y(b){} 16 Point operator - (const Point &p) const{ 17 return Point(x-p.x,y-p.y); 18 } 19 }; 20 typedef Point Vector; 21 inline LL Cross(const Vector &u,const Vector &v){ 22 return u.x*v.y - u.y*v.x; 23 } 24 int n,M; 25 struct dequeue{ 26 Point q[N]; 27 int head,tail; 28 void init(){ 29 head = 1; tail = 0; 30 } 31 void push(const Point &u){ 32 while (head < tail && Cross(q[tail]-q[tail-1],u-q[tail-1]) <= 0 ) tail--; 33 q[++tail] = u; 34 } 35 Point pop(const LL &k){//斜率的大小 36 while (head < tail && k*q[head].x + q[head].y >= k*q[head+1].x + q[head+1].y ) head++; 37 return q[head]; 38 } 39 }H; 40 // dp[i] = -k*x[j] + y[j] + w; 41 // 寫成結構體常數比較大; 42 void solve(){ 43 44 H.init(); 45 //隊列里初始值得看情況,比如H.push(Point(0,0)); 46 for (int i=1;i<=n;i++){ 47 Point t = H.pop(k); 48 dp[i] = -k*t.x + t.y + W; 49 H.push(Point(x[i],y[i])); 50 } 51 }

 

 

還有就是不滿足單調的,首先是
斜率不滿足單調性,x[i],y[i]還是滿足單調;
這樣凸包還是可以直接維護的,但是找凸包上的點就不能在o(1)的時間找到;
但是我們可以用三分找,因為按照隊列里點的順序G值是先變小后變大的;

也可以二分斜率,因為在凸包上相鄰兩個點的斜率是單調遞增的;

 

 1     用find()代替pop();    
 2     int find(const LL &k){
 3         int l = head, r = tail;
 4         while (r - l >= 3){
 5             int m1 = l + (r-l)/3;
 6             int m2 = r - (r-l)/3;
 7             if (k*q[m1].x+q[m1].y >= k*q[m2].x+q[m2].y ) l = m1+1;
 8             else r = m2-1;
 9         }    
10         int ret = l;
11         for (int i = l+1; i <= r; i++) {
12             if (k*q[i].x+q[i].y <= k*q[ret].x+q[ret].y) ret = i;
13         }
14         return ret;
15     }

 

 

然后如果x[i],y[i]也不滿足單調,這樣就不能直接維護凸包了,需要動態維護凸包
簡單點的就是用set,但是set無法實現kth大,所以得自己寫平衡樹;


先找到插入點前驅,和后繼(水平序),然后分兩邊同時維護凸包,(如果還不太清楚可以看一下本博客的動態凸包的代碼)

再用三分找最小;

要用到的就是findPre(),findNext(),kth();當然也可以在插入的時候記錄下該點跟前驅的斜率,然后

直接查找第一個比讀入斜率大的點就可以,因為在平衡樹里斜率也是滿足二叉樹的性質的,這樣就不用kth()了,

代碼可以參看hust里;


因為一個點被刪除后就不會在進入凸包,時間O(logn),查找要logn;
所以總時間復雜度為O(logn*logn*n);

http://acm.hust.edu.cn/vjudge/problem/viewProblem.action?id=31649

貨幣兌換:splay  dp[i] = ai[i]*x[j]+bi[i]*y[j] ----->  dp[i]/bi[i] = ai[i]/bi[i] *x[j] +y[j];

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 #include<algorithm>
  5 #include<cmath>
  6 #include<vector>
  7 #include<cstdlib>
  8 using namespace std;
  9 const int N=100000+10;
 10 const  double eps=1e-8;
 11 inline int dcmp(double x){
 12     return x<-eps ? -1 : x>eps;
 13 }
 14 struct Point{
 15     double x,y;
 16     Point(double a=0,double b=0):x(a),y(b){}
 17     Point operator - (const Point &p)const{
 18         return Point(x-p.x,y-p.y);
 19     }
 20     double operator * (const Point &p)const{
 21         return x*p.y - y*p.x;
 22     }
 23     bool operator < (const Point &p)const{
 24         return dcmp(x-p.x)<0 || (dcmp(x-p.x)==0 && dcmp(y-p.y)<0);
 25     }
 26 };
 27 struct splay_tree{
 28     int sz,root,ch[N][2],pre[N],ss[N];
 29     Point val[N];
 30     void rotate(int x){
 31         int y = pre[x];
 32         int f = (ch[y][0]==x);
 33         ch[y][f^1] = ch[x][f];
 34         pre[ ch[x][f] ] = y;
 35         pre[ x ] = pre[ y ];
 36         ch[ pre[y] ][ ch[ pre[y] ][ 1 ] == y ] = x;
 37         ch[x][f] = y;
 38         pre[y] = x;
 39         pushup(y);
 40     }
 41     void splay(int x,int goal){
 42         while (pre[x] != goal ){
 43             int y = pre[x], z = pre[y];
 44             if (z==goal){
 45                 rotate(x);
 46             }else {
 47                 int f = (ch[z][0]==y);
 48                 if (ch[y][f] == x){
 49                     rotate(x); rotate(x);
 50                 }else {
 51                     rotate(y); rotate(x);
 52                 }
 53             }
 54         }
 55         pushup(x);
 56         if (goal == 0) root=x;
 57     }
 58     void init(){
 59         sz=0; ch[0][0]=ch[0][1]=pre[0]=0; val[0]=Point(0,0); ss[0]=0;
 60     }
 61     void pushup(int x){
 62         ss[x] = ss[ ch[x][0] ] + ss[ ch[x][1] ] + 1; 
 63     }
 64     void insert(Point x){
 65         val[++sz]=x; ss[sz]=1; 
 66         ch[sz][0]=ch[sz][1]=pre[sz]=0;
 67         if (sz==1){
 68             root=1; return;
 69         }
 70         int u,f;
 71         for (u=root; ch[u][f=val[u]<x]; u=ch[u][f]);
 72         ch[u][f] = sz;
 73         pre[sz] = u;
 74         splay(sz,0);
 75         if (sz<=2) return;
 76         ins(sz);    
 77     }
 78     void remove(int x){
 79         int u = findPre(x), v = findNext(x);
 80         splay(u,0); splay(v,u);
 81         ch[v][0]=0;
 82         splay(v,0);
 83     }
 84     int findPre(int x){
 85         splay(x,0);
 86         int u;
 87         if (ch[x][0]==0) return 0;
 88         for (u=ch[x][0]; ch[u][1]; u=ch[u][1]);
 89         return u;
 90     }
 91     int findNext(int x){
 92         splay(x,0);
 93         int u;
 94         if (ch[x][1]==0) return 0;
 95         for (u=ch[x][1]; ch[u][0]; u=ch[u][0]);
 96         return u;
 97     }
 98     void ins(int x){
 99         int u = findPre(x), v = findNext(x);
100         if (u!=0 && v!=0) {
101             double k= (val[u]-val[x])*(val[v]-val[x]);
102             if (dcmp(k)<=0) {
103                 remove(x); return;
104             }
105         }
106         while (1){
107             u=findNext(x);
108             if (u==0) break;
109             v=findNext(u);
110             if (v==0) break;
111             double k=(val[u]-val[x])*(val[v]-val[x]);
112             if (dcmp(k)>=0){
113                 remove(u);
114             }else break;
115         }
116         while (1){
117             u=findPre(x);
118             if (u==0) break;
119             v=findPre(u);
120             if (v==0) break;
121             double k=(val[u]-val[x])*(val[v]-val[x]);
122             if (dcmp(k)<=0){
123                 
124                 remove(u);
125             }else break;
126         }
127     }
128     int kth(int k){
129         int tmp=k;
130         if (k>ss[root]) return 0;
131         int x = root;
132         while (ss[ ch[x][0] ]+1!=k){
133             int c = ss[ ch[x][0] ];
134             if (k<=c) x = ch[x][0];
135             else {
136                 x = ch[x][1];
137                 k -= c+1;
138             }
139         }
140         splay(x,0);
141         return x;
142     }
143     double cal(double k,int x){
144         return k*val[x].x+val[x].y;
145     }
146     Point find(double k){
147         int l=1,r=ss[root];
148         while (r-l>3){
149             int m1= l+(r-l)/3;
150             int m2= r-(r-l)/3;
151             if (cal(k,kth(m1))>cal(k,kth(m2))) r=m2-1;
152             else l=m1+1;
153         }
154         int ret=kth(l);
155         double tmp=cal(k,ret);
156         for (int i=l+1;i<=r;i++){
157             int t=kth(i);
158             double t2=cal(k,t);
159             if (tmp<t2) {
160                 ret=t; tmp=t2;
161             }
162         }
163         return val[ret];
164     }
165     void debug(){
166         printf("root: %d\n",root);print_tree(root);
167     }
168     void print_tree(int x){
169         if (x){
170             print_tree(ch[x][0]);
171             printf("now: %d ,fa: %d ,son0: %d ,son1: %d ,size: %d\n",x,pre[x],ch[x][0],ch[x][1],ss[x]);
172             print_tree(ch[x][1]);
173         }
174     
175     }
176 }H;
177 int n,s;
178 double ak[N],bk[N],rk[N];
179 double dp[N];
180 void solve(){
181     H.init();
182     double x,y;
183     dp[1]=s;
184     y = (double)s/(rk[1]*ak[1]+bk[1]);
185     x = rk[1]*y;
186     H.insert(Point(x,y));
187     for (int i=2;i<=n;i++){
188         Point t = H.find(ak[i]/bk[i]);
189         dp[i] =max(dp[i-1], ak[i]*t.x+bk[i]*t.y);
190         y = dp[i]/(rk[i]*ak[i]+bk[i]);
191         x = rk[i]*y;
192         H.insert(Point(x,y));    
193     }
194     printf("%.3lf\n",dp[n]);
195 }
196 int main(){
197 //    freopen("in.txt","r",stdin);
198 //    freopen("1.out","w",stdout);
199     while (~scanf("%d%d",&n,&s)){
200         for (int i=1;i<=n;i++) scanf("%lf%lf%lf",&ak[i],&bk[i],&rk[i]);
201         solve();
202     }
203 
204     return 0;
205 }
View Code

 

 

這樣對於形如 DP[i] = min/max(-a[i]*x[j]+y[j])+w[i]; (1<=j<i)
的DP方程都可以解決了;


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM