動態規划之矩陣鏈相乘問題(算法導論)


問題描述

給定n個矩陣序列,(A1,A2,A3,A4,...,An). 計算他們的乘積:A1A2A3...An.

由於矩陣的乘法運算符合結合律,因而可以通過調整計算順序,從而降低計算量。


樣例分析

比如有三個矩陣分別為:A1: 10*100,A2: 100*5,A3: 5*50

假如現在按照(A1A2)A3的順序計算需要的計算量為:10*100*5+10*5*50=7500次運算。

若按照A1(A2A3)的順序計算,需要的計算量為:100*5*50+10*100*50=75000次運算。

上面兩種不同的運算順序所有的計算量相差十倍。

因而,一種最優的計算順序將能很大程度的減少矩陣連乘的運算量。


問題解析

此問題的目的是尋找一種最優的括號化方案。下面用動態規划的思想來進行分析:

1、動態規划的第一步:尋找最優子結構。為方便起見,使用Ai..j表示AiAi+1...Aj的乘積結果矩陣。對於k(i<=k<j), 計算Ai..j所需要的計算量為:Ai..k 和 Ak+1..j 以及二者相乘的代價和。

2、設m[i][j]為Ai..j的最優計算順序所要花費的代價。則其求解公式為:

if i == j, m[i][j] = 0; //因為只有一個矩陣時計算代碼為0,即不需要計算。

m[i][j]=min{m[i][k] + m[k+1][j] + Pi-1PkPj} i<=k<j

3、為了能夠輸出求解順序,需要保存區間中的一些分割點。假如Ai..j中的最優分割點為k,則我們使用s[i][j]=k。即在Ai..j中,分別計算Ai..k 和 Ak+1..j 所用的計算開銷最小。

4、采用自底向上的表格法。依次求解矩陣長度為2,3,...,n的最優計算順序。


算法思想

1、對m[i][i]全部初始化為0.

2、在矩陣鏈A1..n中,依次計算長度len為2,3,...,n的m[i][j]大小。(j-i+1==長度len).

3、對於長度為len的m[i][j]初始化為+∞。然后根據以下公式計算m[i][j]的最小值。

m[i][j]=min{ m[i][k] + m[k+1][j] + Pi-1PkPj }

由於比長度len小的m[i][k],m[k+1][j]都已經提前計算了出來。所以就可以計算出最小的m[i][j],同時保存相應的最優點。如:s[i][j] = k; //k為i~j的最優計算分割點。

4、根據以上保存的結果,輸出。


具體代碼如下:(C代碼)

<span style="font-size:18px;">/**
	動態規划之矩陣鏈相乘,
	輸入:有N個矩陣連乘,用一行有n+1個數數組表示,表示是n個矩陣的行及第n個矩陣的列,它們之間用空格隔開. 
	輸出:每組測試數據的輸出占一行,它是計算出的矩陣最少連乘積次數,輸出最優全括號結構
	樣例輸入:10 100 5 50
    上面一組數據分別代表: A1:10*100, A2:100*5, A3:5*50
	樣例輸出:7500 ((A1A2)A3)

	30 35 15 5 10 20 25 --> 15125 ((A1(A2A3))((A4A5)A6))
**/
#include <stdio.h>
int m[1002][1002],s[1002][1002];
void matrix_chain(int a[], int n)
{
	int l, i, j, k, tmp;
	for(l=2; l<=n; l++)
	{
		for(i=1; i<=n-l+1; i++)		//長度為l的區間,其最小下標為1~n-l+1
		{
			j=i+l-1;
			m[i][j] = 0x7fffffff;
			for(k=i; k<j; k++)		//i~k, k+1~j, 所以k<j
			{
				tmp = m[i][k]+m[k+1][j]+a[i-1]*a[k]*a[j];
				if(tmp < m[i][j])
				{
					m[i][j] = tmp;
					s[i][j] = k;
				}
			}
		}
	}

}
void print(int i, int j)
{
	if(i == j)
		printf("A%d",i);
	else{
		printf("(");
		print(i, s[i][j]);
		print(s[i][j]+1, j);
		printf(")");
	}
}
int main()
{
	int n, a[1002];
	int i,j,l;
	while(scanf("%d",&n)==1)	//輸入有n個矩陣
	{
		for(i=0; i<n+1; i++)
			scanf("%d",&a[i]);
		
		//memset(m, 0x7fffffff,sizeof(m));
		for(i=0; i<n+1; i++)
			m[i][i] = 0;
		matrix_chain(a, n);
		printf("%d\n",m[1][n]);
		print(1, n);
		printf("\n");
	}

	return 0;
}
</span>






免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM