經典的1D1D動態規划題目,標准做法是平衡樹維護凸殼,但實際上還有更簡潔的分治法。
首先分析一下題目,對於任意一天,一定是貪心地買入所有貨幣或者賣出所有貨幣是最優的,因為有便宜我們就要盡量去占,有虧損就一點也不去碰。於是我們得到方程:
f[i]=max{f[j]/(a[j]*rate[j]+b[j])*rate[j]*a[i]+f[j]/(a[j]*rate[j]+b[j])*b[i]}
其中,x[j]=f[j]/(a[j]*rate[j]+b[j])*rate[j]表示第j天最多可以擁有的A貨幣的數量
y[j]=f[j]/(a[j]*rate[j]+b[j])表示第j天最多可以擁有的B貨幣的數量
那么方程可化簡為f[i]=max{x[j]*a[i]+y[j]*b[i]},那么我們就是要選擇一個最優的決策點(x[j],y[j])來更新f[i]得到最優解。
變形:y[j]=f[i]/b[i]-x[j]*a[i]/b[i],這是一個直線的斜截式方程,由於我們是用j去更新i,那么就相當於每次用一條斜率為-a[i]/b[i]的直線去切由若干(x[j],y[j])點組成的集合,能得到的最大截距的點,就是最優決策點,進一步,就是要維護一個由若干(x[j],y[j])點組成凸殼,因為最優決策點一定在凸殼上。
但是對於斜率-a[i]/b[i]和點(x[j],y[j])都是無序的,於是我們只能用一棵平衡樹來維護凸殼,每次找到斜率能卡到的點(此點左側的斜率和右側的斜率恰好夾住-a[i]/b[i]斜率)。
具體splay實現:我們維護x坐標遞增的點集,每次把新點插入到相應位置,更新凸殼的時候,分別找到新點左右能與它組成新的凸殼的點,把中間的點刪掉;如果這個點完全在舊的凸殼內,那么把這個點刪掉。每次找最優決策點的時候就拿-a[i]/b[i]去切凸殼就行了。
但這樣搞實在是麻煩了許多,而且許多人得splay代碼非常的長,在考場上就非常不容易寫出來,於是出現了神一般的陳丹琪分治!
這個神級分治的精髓在於:變在線為離線,化無序為有序。
上面我們分析了,因為點和斜率都不是單調的,所以我們只能用一棵平衡樹去維護。我們考慮導致無序的原因,是我們按照順序依次回答了1..n的關於f值的詢問。但是事實上我們並沒有必要這么做,因為每個1..n的f[i]值,可能成為最優決策點一定在1..i范圍內,而對於每個在1..i范圍內的決策點,一定都有機會成為i+1..n的f值得最優決策點。這樣1..i的f值一定不會受1..i的決策點的影響,i+1..n的點一定不會i+1..n的f值。於是可以分治!
對於一個分治過程solve(l,r),我們用l..mid的決策點去更新mid+1..r這部分的f值,這樣遞歸地更新的話,我們一定可以保證在遞歸到i點的時候,1..i-1的點都已經更新過i點的f值了。我們看到,分治的過程中,左半區(l..mid)和右半區(mid+1..r)這兩個區間的作用是不同的,我們要用左半區已經更新好的f值去求出點(x,y),然后用右半區的斜率去切左半區的點集更新f值。對於左半邊我們需要的只是點(x,y),右半區我們需要的只是斜率-a[i]/b[i],兩部分的順序互不影響。於是,我們在處理好左半邊的東西的時候保證點集按坐標排好序,在處理右半區之前保證詢問按照斜率排好序,這樣相當於用一系列連續變化的直線去切一些連續點組成的凸殼,那么我們就可以簡單地用一個棧來維護連續點組成的凸殼,用掃描的方法更新f值。我們一開始就排好詢問的順序,然后保證在solve之前還原左半區詢問集合的順序,這樣就保證了按照原順序得到f值;在solve之后把兩部分點集歸並,這樣就保證了每個過程中的點集是有序的。
雖然我敘述的比較煩,但是我們看到這個分治的過程是非常優美的,對於詢問我們是先排序,后還原;而對於點集我們是不斷地歸並,恰好是對稱的過程。為什么呢?上面已經說了,因為我們對左右兩個半區的需求是不一樣的,於是這樣就得到了兩個不同的有序序列,把無序化為有序。
這樣看來,分治算法取代一些復雜的數據結構是一種強有力的趨勢。
splay代碼:

1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #define maxn 120000 7 #define eps 1e-9 8 #define inf 1e9 9 using namespace std; 10 int fa[maxn],c[maxn][2]; 11 double f[maxn],x[maxn],y[maxn],lk[maxn],rk[maxn],a[maxn],b[maxn],rate[maxn]; 12 int n,m,rot,num; 13 14 inline double fabs(double x) 15 { 16 return (x>0)?x:-x; 17 } 18 19 inline void zigzag(int x,int &rot) 20 { 21 int y=fa[x],z=fa[y]; 22 int p=(c[y][1]==x),q=p^1; 23 if (y==rot) rot=x; 24 else if (c[z][0]==y) c[z][0]=x; else c[z][1]=x; 25 fa[x]=z; fa[y]=x; fa[c[x][q]]=y; 26 c[y][p]=c[x][q]; c[x][q]=y; 27 } 28 29 inline void splay(int x,int &rot) 30 { 31 while (x!=rot) 32 { 33 int y=fa[x],z=fa[y]; 34 if (y!=rot) 35 if ((c[y][0]==x)xor(c[z][0]==y)) zigzag(x,rot); else zigzag(y,rot); 36 zigzag(x,rot); 37 } 38 } 39 40 inline void insert(int &t,int anc,int now)//加入平衡樹 41 { 42 if (t==0) 43 { 44 t=now; 45 fa[t]=anc; 46 return ; 47 } 48 if (x[now]<=x[t]+eps) insert(c[t][0],t,now); 49 else insert(c[t][1],t,now); 50 } 51 52 inline double getk(int i,int j)//求斜率 53 { 54 if (fabs(x[i]-x[j])<eps) return -inf; 55 else return (y[j]-y[i])/(x[j]-x[i]); 56 } 57 58 inline int prev(int rot)//求可以和當前點組成凸包的右邊第一個點 59 { 60 int t=c[rot][0],tmp=t; 61 while (t) 62 { 63 if (getk(t,rot)<=lk[t]+eps) tmp=t,t=c[t][1]; 64 else t=c[t][0]; 65 } 66 return tmp; 67 } 68 inline int succ(int rot)//求可以和當前點組成凸包的左邊第一個點 69 { 70 int t=c[rot][1],tmp=t; 71 while (t) 72 { 73 if (getk(rot,t)+eps>=rk[t]) tmp=t,t=c[t][0]; 74 else t=c[t][1]; 75 } 76 return tmp; 77 } 78 79 inline void update(int t)//加入t點 80 { 81 splay(t,rot); 82 if (c[t][0])//向左求凸包 83 { 84 int left=prev(rot); 85 splay(left,c[rot][0]); c[left][1]=0; 86 lk[t]=rk[left]=getk(left,t); 87 } 88 else lk[t]=inf; 89 if (c[t][1])//向右求凸包 90 { 91 int right=succ(rot); 92 splay(right,c[rot][1]); c[right][0]=0; 93 rk[t]=lk[right]=getk(t,right); 94 } 95 else rk[t]=-inf; 96 if (lk[t]<=rk[t]+eps)//在原凸包內部的情況,直接刪掉該點 97 { 98 rot=c[t][0]; c[rot][1]=c[t][1]; fa[c[t][1]]=rot; fa[rot]=0; 99 lk[rot]=rk[c[t][1]]=getk(rot,c[t][1]); 100 } 101 } 102 103 inline int find(int t,double k)//找到當前斜率的位置,即找到最優值 104 { 105 if (t==0) return 0; 106 if (lk[t]+eps>=k&&k+eps>=rk[t]) return t; 107 if (k+eps>lk[t]) return find(c[t][0],k); 108 else return find(c[t][1],k); 109 } 110 111 int main() 112 { 113 //freopen("cash.in","r",stdin); 114 //freopen("cash.out","w",stdout); 115 scanf("%d%lf",&n,&f[0]); 116 for (int i=1;i<=n;i++) scanf("%lf%lf%lf",&a[i],&b[i],&rate[i]); 117 for (int i=1;i<=n;i++) 118 { 119 int j=find(rot,-a[i]/b[i]); 120 f[i]=max(f[i-1],x[j]*a[i]+y[j]*b[i]); 121 y[i]=f[i]/(a[i]*rate[i]+b[i]); 122 x[i]=y[i]*rate[i]; 123 insert(rot,0,i); 124 update(i); 125 } 126 printf("%.3lf\n",f[n]); 127 return 0; 128 }
分治代碼:

1 #include<iostream> 2 #include<cstdio> 3 #include<algorithm> 4 #include<cmath> 5 #include<cstring> 6 #define maxn 120000 7 #define eps 1e-9 8 #define inf 1e9 9 using namespace std; 10 struct query 11 { 12 double q,a,b,rate,k; 13 int pos; 14 }q[maxn],nq[maxn]; 15 double fabs(double x) 16 { 17 return (x>0)?x:-x; 18 } 19 struct point 20 { 21 double x,y; 22 friend bool operator <(const point &a,const point &b) 23 { 24 return (a.x<b.x+eps)||(fabs(a.x-b.x)<=eps&&a.y<b.y+eps); 25 } 26 }p[maxn],np[maxn]; 27 int st[maxn]; 28 double f[maxn]; 29 int n,m; 30 31 double getk(int i,int j) 32 { 33 if (i==0) return -inf; 34 if (j==0) return inf; 35 if (fabs(p[i].x-p[j].x)<=eps) return -inf; 36 return (p[i].y-p[j].y)/(p[i].x-p[j].x); 37 } 38 39 void solve(int l,int r) 40 { 41 if (l==r)//此時l之前包括l的f值已經達到最優,計算出對應的點即可 42 { 43 f[l]=max(f[l-1],f[l]); 44 p[l].y=f[l]/(q[l].a*q[l].rate+q[l].b); 45 p[l].x=p[l].y*q[l].rate; 46 return ; 47 } 48 int mid=(l+r)>>1,l1=l,l2=mid+1; 49 //對詢問集合排序,1位置2斜率 50 for (int i=l;i<=r;i++) 51 if (q[i].pos<=mid) nq[l1++]=q[i]; 52 else nq[l2++]=q[i]; 53 for (int i=l;i<=r;i++) q[i]=nq[i]; 54 //遞歸左區間 55 solve(l,mid); 56 //左半區所有點都以計算好,把它們入棧,維護凸殼 57 int top=0; 58 for (int i=l;i<=mid;i++) 59 { 60 while (top>=2&&getk(i,st[top])+eps>getk(st[top],st[top-1])) top--; 61 st[++top]=i; 62 } 63 //拿左半區更新右半區 64 int j=1; 65 for (int i=r;i>=mid+1;i--)//保證詢問斜率遞減 66 { 67 while (j<top&&q[i].k<getk(st[j],st[j+1])+eps) j++; 68 f[q[i].pos]=max(f[q[i].pos],p[st[j]].x*q[i].a+p[st[j]].y*q[i].b); 69 } 70 //遞歸右區間 71 solve(mid+1,r); 72 //合並左右區間的點,按照x,y排序 73 l1=l,l2=mid+1; 74 for (int i=l;i<=r;i++) 75 if ((p[l1]<p[l2]||l2>r)&&l1<=mid) np[i]=p[l1++]; 76 else np[i]=p[l2++]; 77 for (int i=l;i<=r;i++) p[i]=np[i]; 78 } 79 80 bool cmp(query a,query b) 81 { 82 return a.k<b.k; 83 } 84 85 int main() 86 { 87 //freopen("cash.in","r",stdin); 88 //freopen("cash.out","w",stdout); 89 scanf("%d%lf",&n,&f[0]); 90 for (int i=1;i<=n;i++) 91 { 92 scanf("%lf%lf%lf",&q[i].a,&q[i].b,&q[i].rate); 93 q[i].k=-q[i].a/q[i].b; 94 q[i].pos=i; 95 } 96 sort(q+1,q+n+1,cmp); 97 solve(1,n); 98 printf("%.3lf\n",f[n]); 99 return 0; 100 }