算法筆記_003:矩陣相乘問題【分治法】


目錄

1 問題描述 

1.1實驗題目 

1.2實驗目的 

1.3實驗要求 

2 解決方案 

2.1 分治法原理簡述 

2.2 分治法求解矩陣相乘原理 

2.3 具體實現源碼 

2.4 運算結果截圖 

 


1 問題描述

1.1實驗題目

    M1M2是兩個n×n的矩陣,設計算法計算M1×M2 的乘積。

1.2實驗目的

    (1)提高應用蠻力法設計算法的技能;

    (2)深刻理解並掌握分治法的設計思想;

    (3)理解這樣一個觀點:用蠻力法設計的算法,一般來說,經過適度的努力后,都可以對其進行改進,以提高算法的效率。

1.3實驗要求

    (1)設計並實現用BF(Brute-Force,即蠻力法)方法求解矩陣相乘問題的算法;

    (2)設計並實現用DACDivide-And-Conquer,即分治法)方法求解矩陣相乘問題的算法;

    (3)以上兩種算法的輸入既可以手動輸入,也可以自動生成;

    (4)對上述兩個算法進行時間復雜性分析,並設計實驗程序驗證分析結果;

    (5)設計可供用戶選擇算法的交互式菜單(放在相應的主菜單下)

 


2 解決方案

2.1 分治法原理簡述

    分治法的設計思想是將一個難以直接解決的大問題,分割成一些規模較小的相同問題,以便各個擊破,分而治之。

    分治策略是:對於一個規模為n的問題,若該問題可以容易地解決(比如說規模n較小)則直接解決,否則將其分解為k個規模較小的子問題,這些子問題互相獨立且與原問題形式相同,遞歸地解這些子問題,然后將各子問題的解合並得到原問題的解。這種算法設計策略叫做分治法。

    如果原問題可分割成k個子問題,1<k≤n ,且這些子問題都可解並可利用這些子問題的解求出原問題的解,那么這種分治法就是可行的。由分治法產生的子問題往往是原問題的較小模式,這就為使用遞歸技術提供了方便。在這種情況下,反復應用分治手段,可以使子問題與原問題類型一致而其規模卻不斷縮小,最終使子問題縮小到很容易直接求出其解。這自然導致遞歸過程的產生。分治與遞歸像一對孿生兄弟,經常同時應用在算法設計之中,並由此產生許多高效算法。

    分治法所能解決的問題一般具有以下幾個特征:

1) 該問題的規模縮小到一定的程度就可以容易地解決

2) 該問題可以分解為若干個規模較小的相同問題,即該問題具有最優子結構性質。

3) 利用該問題分解出的子問題的解可以合並為該問題的解;

4) 該問題所分解出的各個子問題是相互獨立的,即子問題之間不包含公共的子問題。

2.2 分治法求解矩陣相乘原理

首先了解一下傳統計算矩陣相乘的原理:

 

 

其次,看一下優化后的矩陣相乘法原理:

 

 

最后,看一下本文利用分治法求解矩陣相乘的原理(PS:本文求解其效率不是最高,主要是體驗一下分治法,重點在於分治法):

注意:使用分治法求解兩個nxn階矩陣相乘,其中n值為2的冪值,否則只能使用蠻力法計算。

本文具體源碼主要根據以上分塊矩陣方法,先分塊(即使用分治法),然后遞歸求解。

 

2.3 具體實現源碼

package com.liuzhen.dac;

public class Matrix {
    
    //初始化一個隨機nxn階矩陣
    public static int[][] initializationMatrix(int n){
        int[][] result = new int[n][n];
        for(int i = 0;i < n;i++){
            for(int j = 0;j < n;j++){
                result[i][j] = (int)(Math.random()*10); //采用隨機函數隨機生成1~10之間的數
            }
        }            
        return result;            
    }
    
    //蠻力法求解兩個nxn和nxn階矩陣相乘
    public static int[][] BruteForce(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = 0;
                for(int k=0;k<n;k++){
                    result[i][j] += p[i][k]*q[k][j];
                }
            }
        }                
        return result;
    }
    
    //分治法求解兩個nxn和nxn階矩陣相乘
    public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        //當n為2時,返回矩陣相乘結果
        if(n == 2){
            result = BruteForce(p,q,n);            
            return result;
        }
        
        //當n大於3時,采用采用分治法,遞歸求最終結果
        if(n > 2){
            int m = n/2;
            
            int[][] p1 = QuarterMatrix(p,n,1);
            int[][] p2 = QuarterMatrix(p,n,2);
            int[][] p3 = QuarterMatrix(p,n,3);
            int[][] p4 = QuarterMatrix(p,n,4);
//            System.out.println();
//            System.out.print("矩陣p1值為:");
//            PrintfMatrix(p1,m);
//            System.out.println();
//            System.out.print("矩陣p2值為:");
//            PrintfMatrix(p2,m);
//            System.out.println();
//            System.out.print("矩陣p3值為:");
//            PrintfMatrix(p3,m);
//            System.out.println();
//            System.out.print("矩陣p4值為:");
//            PrintfMatrix(p4,m);
            
            int[][] q1 = QuarterMatrix(q,n,1);
            int[][] q2 = QuarterMatrix(q,n,2);
            int[][] q3 = QuarterMatrix(q,n,3);
            int[][] q4 = QuarterMatrix(q,n,4);
            
            int[][] result1 = QuarterMatrix(result,n,1);
            int[][] result2 = QuarterMatrix(result,n,2);
            int[][] result3 = QuarterMatrix(result,n,3);
            int[][] result4 = QuarterMatrix(result,n,4);
            
            
            result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);
            result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);
            result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);
            result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);
            
            
            result = TogetherMatrix(result1,result2,result3,result4,m);
        }
        return result;
    }
    
    //獲取矩陣的四分之一,並決定返回哪一個四分之一
    public static int[][] QuarterMatrix(int[][] p,int n,int number){
        int rows = n/2;   //行數減半
        int cols = n/2;   //列數減半
        int[][] result = new int[rows][cols];
        switch(number){
           case 1 :
           {
              // result = new int[rows][cols];
               for(int i=0;i<rows;i++){
                   for(int j=0;j<cols;j++){
                       result[i][j] = p[i][j];
                   }
               }
               break;
           }
            
           case 2 :
           {
              // result = new int[rows][n-cols];
               for(int i=0;i<rows;i++){
                   for(int j=0;j<n-cols;j++){
                       result[i][j] = p[i][j+cols];
                   }
               }
               break;
           }
           
           case 3 :
           {
              // result = new int[n-rows][cols];
               for(int i=0;i<n-rows;i++){
                   for(int j=0;j<cols;j++){
                       result[i][j] = p[i+rows][j];
                   }
               }
               break;
           }
           
           case 4 :
           {
              // result = new int[n-rows][n-cols];
               for(int i=0;i<n-rows;i++){
                   for(int j=0;j<n-cols;j++){
                       result[i][j] = p[i+rows][j+cols];
                   }
               }
               break;
           }
           
           default:
               break;
        }
        
        return result;
     }
    
    //把均分為四分之一的矩陣,聚合成一個矩陣,其中矩陣a,b,c,d分別對應原完整矩陣的四分中1、2、3、4
    public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
        int[][] result = new int[2*n][2*n];
        for(int i=0;i<2*n;i++){
            for(int j=0;j<2*n;j++){
                if(i<n){
                    if(j<n){
                        result[i][j] = a[i][j];
                    }
                    else
                        result[i][j] = b[i][j-n];
                }
                else{
                    if(j<n){
                        result[i][j] = c[i-n][j];
                    }
                    else{
                        result[i][j] = d[i-n][j-n];
                    }
                }
            }
        }
        
        return result;
    }
    
    
    //求兩個矩陣相加結果
    public static int[][] AddMatrix(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = p[i][j]+q[i][j];
            }
        }
        return result;
    }
    
    //控制台輸出矩陣
    public static void PrintfMatrix(int[][] matrix,int n){
        for(int i=0;i<n;i++){
            System.out.println();
            for(int j=0;j<n;j++){
                System.out.print("\t");
                System.out.print(matrix[i][j]);
            }
        }
        
    }
    
    public static void main(String args[]){
        int[][] p = initializationMatrix(8);
        int[][] q = initializationMatrix(8);
        System.out.print("矩陣p初始化值為:");
        PrintfMatrix(p,8);
        System.out.println();
        System.out.print("矩陣q初始化值為:");
        PrintfMatrix(q,8);
        
        int[][] bf_result = BruteForce(p,q,8);
        System.out.println();
        System.out.print("蠻力法計算矩陣p*q結果為:");
        PrintfMatrix(bf_result,8);
    
        int[][] dac_result = DivideAndConquer(p,q,8);
        System.out.println();
        System.out.print("分治法計算矩陣p*q結果為:");
        PrintfMatrix(dac_result,8);
    }

}

 

2.4 運算結果截圖

 

 

參考資料:

    1、009-矩陣乘法-分治法-《算法設計技巧與分析》M.H.A學習筆記

     2、 Strassen矩陣乘法(分治法續)


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM