C++基於armadillo im2col的實現


col2im的實現,這是im2col的逆過程
最近學習CNN,需要用到im2col這個函數,無奈網上沒有多少使用armadillo的例子,而且armadillo庫中似乎也沒有這個函數,因此自己寫了。
im2col的原理網上一大把,我懶得寫了。

1. field<某類>

field<class oT> 是armadillo庫中的類,類似於矩陣, 不過這個“矩陣”的每一個元素都是向量或者矩陣。因此用field可以作為四維輸入數據使用。

2. 矩陣展開

這個其實還挺簡單,使用reshape函數將矩陣變形。不過,armadillo中變形是按照豎向變形的。比如:

1 2 3
4 5 6
7 8 9

這樣的矩陣變形成1×9的向量的話:

1 4 7 2 5 8 3 6 9

會成這樣👆。。。
但是也不影響,濾波器也是這么變得,相對位置沒變唄。。

3. 排列組合

鄙人才疏學淺,只會用一堆for循環來排列組合。。。貌似沒找到更好的辦法。

4. 其他細節

像是步數、填充什么的,多注意一下就行了。

5. 實現代碼

mat im2col(field<mat> input_data, int filter_h, int filter_w, int stride, int pad)
{
	int N, C, H, W;
	N = input_data.n_rows;
	C = input_data.n_cols;
	H = input_data(0, 0).n_rows;
	W = input_data(0, 0).n_cols;
	int out_h = (H + 2 * pad - filter_h) / stride + 1;
	int out_w = (W + 2 * pad - filter_w) / stride + 1;
	field<mat> img = input_data;
	img.for_each([H, W, pad](mat& X) {X.insert_rows(0, pad); X.insert_rows(H + pad, pad); X.insert_cols(0, pad); X.insert_cols(W + pad, pad); });
	mat col(out_h * out_w * N, C * filter_h * filter_w, fill::zeros);
	for (int n = 0, z = 0; n < N; n++)
	{
		for (int i = 0; i < out_h; i++)
		{
			for (int j = 0; j < out_w; j++, z++)
			{
				for (int k = 0; k < C; k++)
				{
					mat filter(filter_h, filter_w, fill::zeros);
					filter = img(n, k)(span(i * stride, i * stride + filter_h - 1), span(j * stride, j * stride + filter_w - 1));
					filter.reshape(1, filter_h * filter_w);
					int x = z;
					int y0 = filter_h * filter_w * k;
					int y1 = filter_h * filter_w * k + filter_h * filter_w - 1;
					col(span(x, x), span(y0, y1)) = filter;
				}
			}
		}
	}
	return col;
}

頭文件就是聲明和引用。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM