c++的矩陣乘法加速trick


最近讀RNNLM的源代碼,發現其實現矩陣乘法時使用了一個trick,這里描述一下這個trick。

首先是正常版的矩陣乘法(其實是矩陣乘向量)

void matrixXvector(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
	for(int row=0;row<srcmatrix_rownum;++row){
		destvect[row]=0;
		for(int col=0;col<srcmatrix_colnum;++col){
			destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];
		}
	}
}

就是最簡單的for循環,逐行逐列遍歷。

接下來是RNNLM中實現的trick版本

void matrixXvector2(float* destvect, float* srcmatrix, int srcmatrix_rownum, int srcmatrix_colnum, float* srcvect, int srcvect_size){
	int row, col;
	float val1, val2, val3, val4;
	float val5, val6, val7, val8;
	
	for(row=0;row<srcmatrix_rownum/8;++row){
		val1 = 0;
		val2 = 0;
		val3 = 0;
		val4 = 0;
		val5 = 0;
		val6 = 0;
		val7 = 0;
		val8 = 0;
		
		for(col=0;col<srcmatrix_colnum;++col){
			val1+=srcmatrix[(row*8+0)*srcmatrix_colnum+col]*srcvect[col];
			val2+=srcmatrix[(row*8+1)*srcmatrix_colnum+col]*srcvect[col];
			val3+=srcmatrix[(row*8+2)*srcmatrix_colnum+col]*srcvect[col];
			val4+=srcmatrix[(row*8+3)*srcmatrix_colnum+col]*srcvect[col];
			val5+=srcmatrix[(row*8+4)*srcmatrix_colnum+col]*srcvect[col];
			val6+=srcmatrix[(row*8+5)*srcmatrix_colnum+col]*srcvect[col];
			val7+=srcmatrix[(row*8+6)*srcmatrix_colnum+col]*srcvect[col];
			val8+=srcmatrix[(row*8+7)*srcmatrix_colnum+col]*srcvect[col];
		}
		
		destvect[row*8+0]+=val1;
		destvect[row*8+1]+=val2;
		destvect[row*8+2]+=val3;
		destvect[row*8+3]+=val4;
		destvect[row*8+4]+=val5;
		destvect[row*8+5]+=val6;
		destvect[row*8+6]+=val7;
		destvect[row*8+7]+=val8;
		
	}
	
	for(row=row*8;row<srcmatrix_rownum;++row){
		for(col=0;col<srcmatrix_colnum;++col){
			destvect[row]+=srcmatrix[row*srcmatrix_colnum+col]*srcvect[col];	
		}
	}
}

對比普通版,trick版把遍歷行的for循環分成了8份,同時進行列遍歷。

實際測試中,這個trick版比普通版快了接近2倍~這是編譯器優化造成的么……?


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM