引子
你某天在洛谷里刷題,夢想着有一天AK IOI(@DXR),這時,你看到了一個橙題,但是AC率僅僅只有 \(\frac{1}{3}\) ,你尋思着一道橙題會有多難,於是決定寫這道題
題目
-
對於一個遞歸函數\(w(a,b,c)\)
-
如果\(a \le 0\) or \(b \le 0\) or \(c \le 0\)就返回值\(1\).
-
如果\(a > 20\) or \(b > 20\) or \(c > 20\)就返回\(w(20,20,20)\)
-
如果\(a < b\)並且\(b < c\) 就返回\(w(a,b,c-1)+w(a,b-1,c-1)-w(a,b-1,c)\)
-
其它的情況就返回\(w(a-1,b,c)+w(a-1,b-1,c)+w(a-1,b,c-1)-w(a-1,b-1,c-1)\)
你不屑的把所有的公式都打上去
#include<cstdio>
#include<iostream>
using namespace std;
inline long long w(long long a,long long b,long long c)
{
if(a<=0||b<=0||c<=0) return 1;
else if(a>20||b>20||c>20) return w(20,20,20);
else if(a<b&&b<c) return w(a,b,c-1)+w(a,b-1,c-1)-w(a,b-1,c);
else return w(a-1,b,c)+w(a-1,b-1,c)+w(a-1,b,c-1)-w(a-1,b-1,c-1);
return w(a-1,b,c)+w(a-1,b-1,c)+w(a-1,b,c-1)-w(a-1,b-1,c-1);
}
int main()
{
long long a,b,c;
for(register int i=1; ;i++)
{
scanf("%lld%lld%lld",&a,&b,&c);
if(a==-1&&b==-1&&c==-1) break;
printf("w(%lld, %lld, %lld) = ",a,b,c);
printf("%lld\n",w(a,b,c));
}
return 0;
}
然后看到的會是
然后你陷入了沉思...
正題
很顯然,我們打的這個程序十分的垃圾,當輸入50,50,50的數據時,就會瞬間爆炸,全程TLE,比賽時是絕對不能出現這種情況的
那么我們如何解決這種問題呢
這個題目全都是遞歸公式,也許我們可以在這上面下手......
你用你聰明的大腦想到,有時候遞歸的a,b,c會是和之前某個時候相等的,我們可以開一個數組存儲一下......
然后......
#include<cstdio>
#include<iostream>
using namespace std;
long long m[25][25][25];
inline long long w(long long a,long long b,long long c)
{
if(a<=0||b<=0||c<=0) return 1;
else if(m[a][b][c]!=0) return m[a][b][c];
else if(a>20||b>20||c>20) m[a][b][c]=w(20,20,20);
else if(a<b&&b<c) m[a][b][c]=w(a,b,c-1)+w(a,b-1,c-1)-w(a,b-1,c);
else m[a][b][c]=w(a-1,b,c)+w(a-1,b-1,c)+w(a-1,b,c-1)-w(a-1,b-1,c-1);
return m[a][b][c];
}
int main()
{
long long a,b,c;
for(register int i=1; ;i++)
{
scanf("%lld%lld%lld",&a,&b,&c);
if(a==-1&&b==-1&&c==-1) break;
printf("w(%lld, %lld, %lld) = ",a,b,c);
printf("%lld\n",w(a,b,c));
}
return 0;
}
光榮\({RE}\)
但是這個時候,我們的時間復雜度會下降不少,起碼當數據為 \(50,50,50\) 時,我們不會爆炸,RE的原因只是因為題目范圍,經過一番 玄學 分析之后,我們發現當 \(a,b,c\) 中任意一個值大於 \(20\) 時,返回值都是一樣的,我們只需要加個判斷即可
#include<cstdio>
#include<iostream>
using namespace std;
long long m[25][25][25];
inline long long w(long long a,long long b,long long c)
{
if(a<=0||b<=0||c<=0) return 1;
else if(m[a][b][c]!=0) return m[a][b][c];
else if(a>20||b>20||c>20) m[a][b][c]=w(20,20,20);
else if(a<b&&b<c) m[a][b][c]=w(a,b,c-1)+w(a,b-1,c-1)-w(a,b-1,c);
else m[a][b][c]=w(a-1,b,c)+w(a-1,b-1,c)+w(a-1,b,c-1)-w(a-1,b-1,c-1);
return m[a][b][c];
}
int main()
{
long long a,b,c;
for(register int i=1; ;i++)
{
scanf("%lld%lld%lld",&a,&b,&c);
if(a==-1&&b==-1&&c==-1) break;
printf("w(%lld, %lld, %lld) = ",a,b,c);
if(a>20) a=21;
if(b>20) b=21;
if(c>20) c=21;
printf("%lld\n",w(a,b,c));
}
return 0;
}
講解
看到這里,我相信大家對記憶化搜索已經有一個基本了解了,它就是一個換裝玩角色扮演的DFS,偽裝成高級的樣子和DP混在一起(這點我后面會講解)
我們結合另一道題理解
看完題目,你會想到一個個枚舉每一個點,進行大爆搜,如果你是這么想的,請重新回到文章篇頭再看一遍記憶化搜索的思路是什么,因為這題實際是記憶化搜索
對於每一個點,我們進行一次搜索,找出它滑雪距離的最大值,存儲下來,方便下一次使用
#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
int dx[4]={0,0,1,-1};
int dy[4]={1,-1,0,0};
int n,m,a[201][201],s[201][201],ans;
int dfs(int x,int y){
if(s[x][y])return s[x][y];//記憶化搜索
s[x][y]=1;//題目中答案是有包含這個點的
for(int i=0;i<4;i++)
{ int xx=dx[i]+x;
int yy=dy[i]+y;//四個方向
if(xx>0&&yy>0&&xx<=n&&yy<=m&&a[x][y]>a[xx][yy]){
dfs(xx,yy);
s[x][y]=max(s[x][y],s[xx][yy]+1);
}
}
return s[x][y];
}
int main()
{
scanf("%d%d",&n,&m);//同題目的R,C
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
scanf("%d",&a[i][j]);
for(int i=1;i<=n;i++)//找從每個出發的最長距離
for(int j=1;j<=m;j++)
ans=max(ans,dfs(i,j));//取最大值
printf("%d",ans);
return 0;
}
然后,我們發現一個神奇的地方......
s[x][y]=max(s[x][y],s[xx][yy]+1)
這行代碼,是不是特別熟悉?
沒錯,這就是一個狀態轉移方程,所以我說記憶化搜索就是玩角色扮演的DFS,裝作DP的亞子,所以實際上這題也可以用DP來解,和上面的代碼也差不太多
總結
再仔細觀察一下咱的代碼,發現記憶化搜索中的DFS函數幾乎不需要外部變量( 自力更生),這也是記憶化搜索的特點之一
所以我們得出記憶化搜索的總結
- 不依賴任何 外部變量
- 答案以返回值的形式存在,而不能以參數的形式存在(就是不能將 dfs 定義成 \(dfs( pos , tleft , nowans )\),這里面的 \(nowans\) 不符合要求)。
- 對於相同一組參數,dfs 返回值總是相同的
例題
既然說記憶化搜索就是玩角色扮演的DFS,裝作DP,那么幾乎所有的DP,都可以用記憶化求解(好耶)
DP Code
#include<stdio.h>
int max(int a,int b)
{
if (a>b) return a;
else return b;
}
int main()
{
int f[1000]={0},c[1000],w[1000];
int n,v,i,j;
scanf("%d%d",&v,&n);
for(i=1;i<=n;i++)scanf("%d%d",&c[i],&w[i]);
for(i=1;i<=n;i++)
for(j=v;j>=c[i];j--)
{
f[j]=max(f[j],f[j-c[i]]+w[i]);
}
printf("%d ",f[v]);
return 0;
}
記憶化 Code
int n,t;
int tcost[103],mget[103];
int mem[103][1003];
int dfs(int pos,int tleft){
if( mem[pos][tleft] != -1 ) return mem[pos][tleft];
if(pos == n+1)
return mem[pos][tleft] = 0;
int dfs1,dfs2 = -INF;
dfs1 = dfs(pos+1,tleft);
if( tleft >= tcost[pos] )
dfs2 = dfs(pos+1,tleft-tcost[pos]) + mget[pos];
return mem[pos][tleft] = max(dfs1,dfs2);
}
int main(){
memset(mem,-1,sizeof(mem));
cin >> t >> n;
for(int i = 1;i <= n;i++)
cin >> tcost[i] >> mget[i];
cout << dfs(1,t) << endl;
return 0;
}