MKL庫矩陣乘法(cblas_?gemm)


MKL庫中基本線性代數子程序,BLAS(Basic Linear Algebra Subprograms)庫,是一個API標淮,用以規范發布基礎線性代數操作的數值庫(如向量或矩陣乘法)。其中CBLASBLASC語言接口。

庫中前綴用來區分所支持處理的數據類型。

前綴 描述 函數名系列 描述
s- 實數、單精度 ge... 一般矩陣
c- 復數、單精度 sy... 對稱矩陣
d- 實數、雙精度 he... Hermitian矩陣
z- 復數、雙精度 tr... 三角矩陣

基本矩陣、向量操作

函數(采用常規的前綴為d的接口) 描述
cblas_dasum 向量元素值模的總和
cblas_daxpy 縮放向量
cblas_dcopy 復制向量
cblas_ddot 向量點積
cblas_dswap 交換兩向量
cblas_dgemv 常規矩陣×向量

重點介紹矩陣的乘法運算。

此示例是利用Intel 的MKL庫函數完成計算矩陣(乘法)運算,計算式為:

\[C=\alpha*A*B+\beta*C \]

其通過調整\(A、B、C\)矩陣及其系數,同樣可以完成矩陣的加減;如若只需矩陣\(A\)\(B\)的乘積,設置\(\alpha=1,\beta=0\)即可。

其中\(A\)\(m\times k\)維矩陣,\(B\)\(k\times n\)維矩陣,\(C\)\(m\times n\)維矩陣。

使用的函數為cblas_?gemm(gemm表示GEneric Matrix Multiplication),完成一般的矩陣乘法。

根據輸入/輸出數據的類型可以分為cblas_dgemm,cblas_sgemm,cblas_cgemm,cblas_zgemm,具體類型參見上文,不再贅述,以下以cblas_dgemm為例介紹其用法。

1 cblas_dgemm參數詳解

fun cblas_dgemm(Layout,		//指定行優先(CblasRowMajor,C)或列優先(CblasColMajor,Fortran)數據排序
               TransA,		//指定是否轉置矩陣A,可以為CblasNoTrans或CblasTrans
               TransB,		//指定是否轉置矩陣B,可以為CblasNoTrans或CblasTrans
               M,		//矩陣A和C的行數
               N,		//矩陣B和C的列數
               K,		//矩陣A的列,B的行
               alpha,		//矩陣A和B乘積的比例因子
               A,		//A矩陣
               lda,		//矩陣A的第一維的大小
               B,		//B矩陣
               ldb,		//矩陣B的第一維的大小
               beta,		//矩陣C的比例因子
               C,		//(輸入/輸出) 矩陣C的地址
               ldc		//矩陣C的第一維的大小
               )		

2 定義待處理矩陣

#include <stdio.h>
#include <stdlib.h>
#include "mkl.h"		// 調用mkl頭文件

#define min(x,y) (((x) < (y)) ? (x) : (y))	
double* A, * B, * C;		//聲明三個矩陣變量,並分配內存
int m, n, k, i, j;			//聲明矩陣的維度,其中
double alpha, beta;

m = 2000, k = 200, n = 1000;
alpha = 1.0; beta = 0.0;

A = (double*)mkl_malloc(m * k * sizeof(double), 64);	//按照矩陣維度分配內存
B = (double*)mkl_malloc(k * n * sizeof(double), 64);	//mkl_malloc用法與malloc相似,64表示64位
C = (double*)mkl_malloc(m * n * sizeof(double), 64);
if (A == NULL || B == NULL || C == NULL) {		//判空

    mkl_free(A);				
    mkl_free(B);
    mkl_free(C);
    return 1;
}

for (i = 0; i < (m * k); i++) {		//賦值
    A[i] = (double)(i + 1);
}

for (i = 0; i < (k * n); i++) {
    B[i] = (double)(-i - 1);
}

for (i = 0; i < (m * n); i++) {
    C[i] = 0.0;
}

其中\(A\)\(B\)矩陣設置為:

\[\begin{array}{l} A = \left[ {\begin{array}{*{20}{c}} {1.0}&{2.0}& \cdots &{1000.0}\\ {1001.0}&{1002.0}& \cdots &{2000.0}\\ \vdots & \vdots & \ddots & \cdots \\ {999001.0}&{999002.0}& \cdots &{1000000.0} \end{array}} \right] \space B = \left[ {\begin{array}{*{20}{c}} {-1.0}&{-2.0}& \cdots &{-1000.0}\\ {-1001.0}&{-1002.0}& \cdots &{-2000.0}\\ \vdots & \vdots & \ddots & \cdots \\ {-999001.0}&{-999002.0}& \cdots &{-1000000.0} \end{array}} \right] \end{array} \]

\(C\)矩陣為全0。

3 執行矩陣乘法

回到例子中,對照上面的參數,將C矩陣用A與B的矩陣乘法表示:

cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha, A, k, B, n, beta, C, n);

執行后的得到結果如下:

完整代碼

#include <stdio.h>
#include <stdlib.h>
#include "mkl.h"

#define min(x,y) (((x) < (y)) ? (x) : (y))

int main()
{
    double* A, * B, * C;
    int m, n, k, i, j;
    double alpha, beta;


    m = 2000, k = 200, n = 1000;

    alpha = 1.0; beta = 0.0;

    A = (double*)mkl_malloc(m * k * sizeof(double), 64);
    B = (double*)mkl_malloc(k * n * sizeof(double), 64);
    C = (double*)mkl_malloc(m * n * sizeof(double), 64);
    if (A == NULL || B == NULL || C == NULL) {

        mkl_free(A);
        mkl_free(B);
        mkl_free(C);
        return 1;
    }


    for (i = 0; i < (m * k); i++) {
        A[i] = (double)(i + 1);
    }

    for (i = 0; i < (k * n); i++) {
        B[i] = (double)(-i - 1);
    }

    for (i = 0; i < (m * n); i++) {
        C[i] = 0.0;
    }

    cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
        m, n, k, alpha, A, k, B, n, beta, C, n);


    for (i = 0; i < min(m, 6); i++) {
        for (j = 0; j < min(k, 6); j++) {
            printf("%12.0f", A[j + i * k]);
        }
        printf("\n");
    }

    for (i = 0; i < min(k, 6); i++) {
        for (j = 0; j < min(n, 6); j++) {
            printf("%12.0f", B[j + i * n]);
        }
        printf("\n");
    }

    for (i = 0; i < min(m, 6); i++) {
        for (j = 0; j < min(n, 6); j++) {
            printf("%12.5G", C[j + i * n]);
        }
        printf("\n");
    }

    mkl_free(A);
    mkl_free(B);
    mkl_free(C);

    return 0;
}


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM