稀疏矩陣乘法


給定兩個 稀疏矩陣 A 和 B,返回AB的結果。
您可以假設A的列數等於B的行數。

題目地址:https://www.jiuzhang.com/solution/sparse-matrix-multiplication/#tag-other

本參考程序來自九章算法,由 @Roger 提供。

題目解法:

時間復雜度分析:
假設矩陣A,B均為 n x n 的矩陣,
矩陣A的稀疏系數為a,矩陣B的稀疏系數為b,
a,b∈[0, 1],矩陣越稀疏,系數越小。

方法一:暴力,不考慮稀疏性
Time (n^2 * (1 + n)) = O(n^2 + n^3)
Space O(1)

方法二:改進,僅考慮A的稀疏性
Time O(n^2 * (1 + a * n) = O(n^2 + a * n^3)
Space O(1)

方法三(最優):進一步改進,考慮A與B的稀疏性
Time O(n^2 * (1 + a * b * n)) = O(n^2 + a * b * n^3)
Space O(b * n^2)

方法四:另外一種思路,將矩陣A, B非0元素的坐標抽出,對非0元素進行運算和結果累加
Time O(2 * n^2 + a * b * n^4) = O(n^2 + a * b * n^4)
Space O(a * n^2 + b * n^2)

解讀:矩陣乘法的兩種形式,假設 A(n, t) * B(t, m) = C(n, m)

// 形式一:外層兩個循環遍歷C (常規解法)
for (int i = 0; i < n; i++) {
    for (int j = 0; j < m; j++) {
        for (int k = 0; k < t; k++) {
            C[i][j] += A[i][k] * B[k][j];
        }
    }
}

// 或者寫成下面這樣子
for (int i = 0; i < n; i++) {
    for (int j = 0; j < m; j++) {
        int sum = 0;
        for (int k = 0; k < t; k++) {
            sum += A[i][k] * B[k][j];
        }
        C[i][j] = sum;
    }
}
// 形式二:外層兩個循環遍歷A
for (int i = 0; i < n; i++) {
    for (int k = 0; k < t; k++) {
        for (int j = 0; j < m; j++) {
            C[i][j] += A[i][k] * B[k][j];
        }
    }
}

兩種方法的區別

代碼上的區別(表象):
調換了第二三層循環的順序

核心區別(內在):
形式一以C為核心進行遍歷,每個C[i][j]只會被計算一次,就是最終答案。
形式二以A為核心進行遍歷,每個A[i][k] 乘上 B[k][j]之后,會被累加到 C[i][j],每個C[i][j]將被累加t次。

 

舉個例子,若A矩陣2x3,B矩陣3x2,C矩陣2x2
       A                 B              C
a00 , a01 , a02      b00 , b01      c00 , c01
a10 , a11 , a12      b10 , b11      c10 , c11
                            b20 , b21

形式一的計算過程:遍歷C,假設遍歷到c00,計算c00 = a00 * b00 + a01 * b10 + a02 * b20
形式二的計算過程:遍歷A,
假設遍歷到a00,a00 * b00 累加到 c00, a00 * b01 累加到c01;
假設遍歷到a01,a01 * b10 累加到 c00, a01 * b11 累加到c01;

 

 再回到本題目,可以發現是否為稀疏矩陣,對於上述形式一來說,並無法進行優化,因為是以C為核心
但是對於形式二來說,以A為核心,若A[i][k]為0,那么該元素就不必進行對應相乘並累加的操作了。
故方法二,就是基於此進行優化的。

// 方法一
public class Solution {
    /**
     * @param A: a sparse matrix
     * @param B: a sparse matrix
     * @return: the result of A * B
     */
    public int[][] multiply(int[][] A, int[][] B) {
        // write your code here
        // A(n, t) * B(t, m) = C(n, m)
        int n = A.length;
        int t = A[0].length;
        int m = B[0].length;
        int[][] C = new int[n][m];
        
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                int sum = 0;
                for (int k = 0; k < t; k++) {
                    sum += A[i][k] * B[k][j];
                }
                C[i][j] = sum;
            }
        }
        
        return C;
    }
}

// 方法二
public class Solution {
    /**
     * @param A: a sparse matrix
     * @param B: a sparse matrix
     * @return: the result of A * B
     */
    public int[][] multiply(int[][] A, int[][] B) {
        // write your code here
        // A(n, t) * B(t, m) = C(n, m)
        int n = A.length;
        int t = A[0].length;
        int m = B[0].length;
        int[][] C = new int[n][m];
        
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < t; k++) {
                if (A[i][k] == 0) {
                    continue;
                }
                for (int j = 0; j < m; j++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        
        return C;
    }
}

// 方法三
public class Solution {
    /**
     * @param A: a sparse matrix
     * @param B: a sparse matrix
     * @return: the result of A * B
     */
    public int[][] multiply(int[][] A, int[][] B) {
        // write your code here
        // A(n, t) * B(t, m) = C(n, m)
        int n = A.length;
        int t = A[0].length;
        int m = B[0].length;
        int[][] C = new int[n][m];
        
        List<List<Integer>> B_nonZero_colIndices = new ArrayList<>();
        for (int k = 0; k < t; k++) {
            List<Integer> colIndices = new ArrayList<>();
            for (int j = 0; j < m; j++) {
                if (B[k][j] != 0) {
                    colIndices.add(j);
                } 
            }
            B_nonZero_colIndices.add(colIndices);
        }
        
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < t; k++) {
                if (A[i][k] == 0) {
                    continue;
                }
                for (int colIndex : B_nonZero_colIndices.get(k)) {
                    C[i][colIndex] += A[i][k] * B[k][colIndex];
                }
            }
        }
        
        return C;
    }
}

// 方法四
public class Solution {
    /**
     * @param A: a sparse matrix
     * @param B: a sparse matrix
     * @return: the result of A * B
     */
    public int[][] multiply(int[][] A, int[][] B) {
        // write your code here
        // A(n, t) * B(t, m) = C(n, m)
        int n = A.length;
        int t = A[0].length;
        int m = B[0].length;
        int[][] C = new int[n][m];
        
        List<Point> A_Points = getNonZeroPoints(A);
        List<Point> B_Points = getNonZeroPoints(B);
        
        for (Point pA : A_Points) {
            for (Point pB : B_Points) {
                if (pA.j == pB.i) {
                    C[pA.i][pB.j] += A[pA.i][pA.j] * B[pB.i][pB.j];
                }
            }
        }
    
        return C;
    }
    
    
    private List<Point> getNonZeroPoints(int[][] matrix) {
        List<Point> nonZeroPoints = new ArrayList<>();
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[0].length; j++) {
                if (matrix[i][j] != 0) {
                    nonZeroPoints.add(new Point(i, j));
                }
            }
        }
        return nonZeroPoints;
    }
    
    class Point {
        int i, j;
        Point(int i, int j) {
            this.i = i;
            this.j = j;
        }
    }
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM