本文將介紹如何采用卷積神經網絡(CNN)來處理Fashion-MNIST數據集。
程序流程如下:
1、准備樣本數據
2、構建卷積神經網絡模型
3、網絡學習(訓練)
4、消費、測試
除了網絡模型的構建,其它步驟都和前面介紹的普通神經網絡的處理完全一致,本文就不重復介紹了,重點講一下模型的構建。
先看代碼:
/// <summary> /// 構建網絡模型 /// </summary> private Model BuildModel() { // 網絡參數 float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer> { keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)), keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Flatten(), keras.layers.Dense(128, activation: keras.activations.Relu), keras.layers.Dense(num_classes,activation:keras.activations.Softmax) }); return model; }
keras.layers.Conv2D方法創建一個卷積層
keras.layers.MaxPooling2D方法創建一個池化層
卷積層的含義:
如上圖所示,原始數據尺寸為5*5,卷積核大小為3*3,當卷積核滑過原始圖片時,卷積核和圖片對應的數據進行運算(先乘后加),並形成新的數據。
示例的卷積核為[[1,0,1],[0,1,0],[1,0,1]],和左上角數據卷積后結果為4,填寫到對應位置。對整改圖片全部滑動一遍,即形成最終結果。
采用卷積神經網絡,相對於前面介紹的普通神經網絡有什么優勢呢?
1、首先,圖像本身是一個二維數據,普通網絡首先要把數據拉平,這一點就不合理,而卷積網絡通過卷積核處理數據,保留了原始數據的基本特征;
2、其次,采用卷積網絡大大減小了參數的數量。假設原始圖片分辨率為100*100,拉平后長度為10000,后面跟一個全連接層,輸出為128,此時參數量為(10000+1)*128,超過128萬。這才一個全連接層。如果采用CNN,參數數量取決於卷積核的大小和數量。假設卷積核大小為5*5,數量為32,此時參數數量為:(5*5+1)*32=832。【計算方法下面會詳細介紹】
池化層的含義:
池化就是壓縮,就是圖片數據太大了,通過池化把分辨率減小一些。
池化有均值池化和最大值池化方法,這個很好理解,就是一推數據中取平均值或最大值。MaxPooling2D明顯是最大池化法。
我們再看一下這個代碼:
keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu),
32表示卷積核數量為32,卷積核大小為5*5,padding: "same"表示對圖像進行邊緣補零,不然卷積后的圖像尺寸會變小,補零后圖像尺寸不變。
整體模型摘要信息如下:
下面逐行解釋一下:
1、首先輸入層的數據Shape為:(28,28,1),28表示圖片像素,1表示灰度圖片,如果是彩色圖片,應該為(28,28,3)
2、Rescaling對數據進行處理,統一乘以一個系數,這里沒有需要訓練的參數
3、引入一個卷積層,卷積核數量為32,卷積核大小為5*5(圖上看不出來),此時參數數量為:(5*5+1)*32=832,這里卷積核尺寸為5*5,所以有25個參數,這很好理解,+1是因為作為卷積計算后還要加一個偏置b,所以每個卷積核共26個參數。由於有32個卷積核,要對同一個圖像采用不同的卷積核做32次計算,所以這一層輸出數據為(28,28,32)
4、池化層將數據從(28,28,32)壓縮到(14,14,32)
5、再引入一個卷積層,卷積核數量為64,卷積核大小為3*3(圖上看不出來),這次計算和第一次不太一樣:由於上一層數據共有32片,對每一片數據采用的卷積核是不一樣的,所以這里實際一共有32*9=288個卷積核。首先用32個卷積核和上述32片數據分別進行卷積形成32片數據,然后將32片數據疊加求和,最后再加一個偏置形成一片新數據,重復進行64次,形成64片新數據。此時參數數量為:(288+1)*64=18496
【注意:這里的算法其實是和第一層卷積算法完全一樣的,只是第一層輸入為灰度圖片,數據只有一片,如果輸入為彩色圖片,就一致了。】
6、池化層將數據從(14,14,64)壓縮到(7,7,64)
7、將數據拉平,拉平后的數據長度為:7*7*64=3136
8、引入全連接層,輸出神經元數量為128,此時參數數量為:(3136+1)*128=401536
9、最后為全連接層輸出,輸出神經元數量為10,參數數量為:(128+1)*10=1290
現在,由於參數數量已經很多了,訓練需要的時間也比較長了,所以需要把訓練完成后的參數保存下來,下次可以重新加載保存的參數接着訓練,不用從頭再來。
保存的模型也可以發布到生產系統用於實際的消費。
全部代碼如下:

/// <summary> /// 采用卷積神經網絡處理Fashion-MNIST數據集 /// </summary> public class CNN_Fashion_MNIST { private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train"; private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test"; private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\cnn_train_data.bin"; private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\cnn_train_label.bin"; private readonly string ModelFile = @"D:\Study\Blogs\TF_Net\Model\cnn_fashion_mnist.h5"; private readonly int img_rows = 28; private readonly int img_cols = 28; private readonly int channel = 1; private readonly int num_classes = 10; // total classes public void Run() { var model = BuildModel(); model.summary(); model.load_weights(ModelFile); Console.WriteLine("press any key"); Console.ReadKey(); model.compile(optimizer: keras.optimizers.Adam(0.0001f), loss: keras.losses.SparseCategoricalCrossentropy(), metrics: new[] { "accuracy" }); (NDArray train_x, NDArray train_y) = LoadTrainingData(); model.fit(train_x, train_y, batch_size: 512, epochs: 1); model.save_weights(ModelFile); test(model); } /// <summary> /// 構建網絡模型 /// </summary> private Model BuildModel() { // 網絡參數 float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer> { keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)), keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Flatten(), keras.layers.Dense(128, activation: keras.activations.Relu), keras.layers.Dense(num_classes,activation:keras.activations.Softmax) }); return model; } /// <summary> /// 加載訓練數據 /// </summary> /// <param name="total_size"></param> private (NDArray, NDArray) LoadTrainingData() { try { Console.WriteLine("Load data"); IFormatter serializer = new BinaryFormatter(); FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read); float[,,,] arrx = serializer.Deserialize(loadFile) as float[,,,]; loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read); int[] arry = serializer.Deserialize(loadFile) as int[]; Console.WriteLine("Load data success"); return (np.array(arrx), np.array(arry)); } catch (Exception ex) { Console.WriteLine($"Load data Exception:{ex.Message}"); return LoadRawData(); } } private (NDArray, NDArray) LoadRawData() { Console.WriteLine("LoadRawData"); int total_size = 60000; float[,,,] arrx = new float[total_size, img_rows, img_cols, channel]; int[] arry = new int[total_size]; int count = 0; DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath); foreach (var Dir in RootDir.GetDirectories()) { foreach (var file in Dir.GetFiles("*.png")) { Bitmap bmp = (Bitmap)Image.FromFile(file.FullName); if (bmp.Width != img_cols || bmp.Height != img_rows) { continue; } for (int row = 0; row < img_rows; row++) for (int col = 0; col < img_cols; col++) { var pixel = bmp.GetPixel(col, row); int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[count, row, col, 0] = val; arry[count] = int.Parse(Dir.Name); } count++; } Console.WriteLine($"Load image data count={count}"); } Console.WriteLine("LoadRawData finished"); //Save Data Console.WriteLine("Save data"); IFormatter serializer = new BinaryFormatter(); //開始序列化 FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write); serializer.Serialize(saveFile, arrx); saveFile.Close(); saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write); serializer.Serialize(saveFile, arry); saveFile.Close(); Console.WriteLine("Save data finished"); return (np.array(arrx), np.array(arry)); } /// <summary> /// 消費模型 /// </summary> private void test(Model model) { Random rand = new Random(1); DirectoryInfo TestDir = new DirectoryInfo(TestImagePath); foreach (var ChildDir in TestDir.GetDirectories()) { Console.WriteLine($"Folder:【{ChildDir.Name}】"); var Files = ChildDir.GetFiles("*.png"); for (int i = 0; i < 10; i++) { int index = rand.Next(1000); var image = Files[index]; var x = LoadImage(image.FullName); var pred_y = model.Apply(x); var result = argmax(pred_y[0].numpy()); Console.WriteLine($"FileName:{image.Name}\tPred:{result}"); } } } private NDArray LoadImage(string filename) { float[,,,] arrx = new float[1, img_rows, img_cols, channel]; Bitmap bmp = (Bitmap)Image.FromFile(filename); for (int row = 0; row < img_rows; row++) for (int col = 0; col < img_cols; col++) { var pixel = bmp.GetPixel(col, row); int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[0, row, col, 0] = val; } return np.array(arrx); } private int argmax(NDArray array) { var arr = array.reshape(-1); float max = 0; for (int i = 0; i < 10; i++) { if (arr[i] > max) { max = arr[i]; } } for (int i = 0; i < 10; i++) { if (arr[i] == max) { return i; } } return 0; } }
通過采用CNN的方法,我們可以把Fashion-MNIST識別率提高到大約94%左右,而且還有提高的空間。但是網絡的優化是一件非常困難的事情,特別是識別率已經很高的時候,想提高1個百分點都是很不容易的。
以下是一個優化過的網絡,我查閱了不少資料,也參考了很多代碼,才構建了這個網絡,它的識別率約為96%,再怎么調整也提高不上去了。
/// <summary> /// 構建網絡模型 /// </summary> private Model BuildModel() { // 網絡參數 float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer> { keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)), keras.layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Dropout(0.3f), keras.layers.BatchNormalization(), keras.layers.Conv2D(128, 3, padding: "same", activation: keras.activations.Relu), keras.layers.Conv2D(128, 3, padding: "same", activation: keras.activations.Relu), keras.layers.MaxPooling2D(), keras.layers.Dropout(0.4f), keras.layers.Flatten(), keras.layers.Dense(512, activation: keras.activations.Relu), keras.layers.Dropout(0.25f), keras.layers.Dense(num_classes,activation:keras.activations.Softmax) }); return model; }
【參考資料】
卷積神經網絡CNN總結 - Madcola - 博客園 (cnblogs.com)
卷積神經網絡(CNN)模型結構 - 劉建平Pinard - 博客園 (cnblogs.com)
【相關資源】
源碼:Git: https://gitee.com/seabluescn/tf_not.git
項目名稱:CNN_Fashion_MNIST,CNN_Fashion_MNIST_Plus