upc-9541 矩陣乘法 (矩陣分塊)


題目描述

深度學習算法很大程度上基於矩陣運算。例如神經網絡中的全連接本質上是一個矩陣乘法,而卷積運算也通常是用矩陣乘法來實現的。有一些科研工作者為了讓神經網絡的計算更快捷,提出了二值化網絡的方法,就是將網絡權重壓縮成只用兩種值表示的形式,這樣就可以用一些 trick 加速計算了。例如兩個二進制向量點乘,可以用計算機中的與運算代替,然后統計結果中 1 的個數即可。
然而有時候為了降低壓縮帶來的誤差,只允許其中一個矩陣被壓縮成二進制。這樣的情況下矩陣乘法運算還能否做進一步優化呢?給定一個整數矩陣A 和一個二值矩陣B,計算矩陣乘法 C=A×B。為了減少輸出,你只需要計算 C 中所有元素的的異或和即可。
 

 

輸入

第一行有三個整數 N,P,M, 表示矩陣 A,B 的大小分別是 N×P,P×M 。
接下來 N 行是矩陣 A 的值,每一行有 P 個數字。第 i+1 行第 j 列的數字為 Ai,j, Ai,j 用大寫的16進制表示(即只包含 0~9, A~F),每個數字后面都有一個空格。
接下來 M 行是矩陣 B 的值,每一行是一個長為 P 的 01字符串。第 i+N+1 行第 j 個字符表示 Bj,i 的值。
 

 

輸出

一個整數,矩陣 C 中所有元素的異或和。

 

樣例輸入

4 2 3
3 4
8 A
F 5
6 7
01
11
10

 

樣例輸出

2

 

提示

2≤N,M≤4096,1≤P≤64,0≤Ai,j<65536,0≤Bi,j≤1.

 
看起來是矩陣分塊,但是數據比較水,for for for暴力循環就能過題。
 
由於矩陣b是個01矩陣,所以如果按8位分塊,一塊最多有256種情況,預處理分塊后極限數據時間復雜度為5e8.
#include "bits/stdc++.h"

using namespace std;
const int maxn = 4100;
int a[maxn][70], b[maxn][70];
int ap[maxn][10][260], bp[maxn][10];

int main() {
    //freopen("input.txt", "r", stdin);
    int n, p, m;
    scanf("%d %d %d", &n, &p, &m);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < p; j++) {
            scanf("%x", &a[i][j]);
        }
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < p; j++) {
            scanf("%1d", &b[i][j]);
        }
    }
    p = (p + 7) / 8;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < p; j++) {
            int base = j * 8;
            for (int k = 0; k < 256; k++) {
                for (int l = 0; l < 8; l++) {
                    if (k & (1 << l)) {
                        ap[i][j][k] += a[i][base + l];
                    }
                }
            }
        }
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < p; j++) {
            int base = j * 8;
            for (int k = 0; k < 8; k++) {
                bp[i][j] += (b[i][base + k] << k);
            }
        }
    }
    int ans = 0, temp;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            temp = 0;
            for (int k = 0; k < p; k++) {
                temp += ap[i][k][bp[j][k]];
            }
            ans ^= temp;
        }
    }
    printf("%d\n", ans);
    return 0;
}

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM