TensorFlow.NET Machine Learning Getting Started [5]: Handwritten Digit Recognition (MNIST) with a Neural Network


Starting with this article we finally get to do some real work; everything so far has been preparation. This time we tackle a classic machine learning problem: MNIST handwritten digit recognition.

First, a word about the dataset. Please start by extracting the file TF_Net\Asset\mnist_png.tar.gz.

The archive contains two folders, training and validation. The training folder holds 60,000 training images and the validation folder holds 10,000 evaluation images. Each image is 28*28 pixels, and the images are sorted into ten sub-folders named 0 through 9.
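As a quick sanity check after extracting, a small sketch like the one below can list how many images each digit folder contains (the path is the same one used later in the article; adjust it to wherever you unpacked the archive):

    using System;
    using System.IO;
    using System.Linq;

    class MnistFolderCheck
    {
        static void Main()
        {
            // Training folder produced by extracting mnist_png.tar.gz; adjust to your own location.
            var trainingPath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\training";

            // Each sub-folder (0~9) is named after the digit its images represent.
            foreach (var dir in new DirectoryInfo(trainingPath).GetDirectories().OrderBy(d => d.Name))
            {
                Console.WriteLine($"Digit {dir.Name}: {dir.GetFiles("*.png").Length} images");
            }
        }
    }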

The overall program flow is essentially the same as the BMI analysis program described in the previous article; after all, both are multi-class classification. There are a few differences:

1. The feature data (input) of the BMI program is a one-dimensional array containing two numbers, whereas the MNIST feature data is a 28*28 two-dimensional array;

2. The BMI program has 3 outputs, whereas MNIST has 10.

 

The network model is built as follows:

        private readonly int img_rows = 28;
        private readonly int img_cols = 28;
        private readonly int num_classes = 10;  // total classes
        /// <summary>
        /// Build the network model
        /// </summary>
        private Model BuildModel()
        {
            // Network parameters
            int n_hidden_1 = 128;    // 1st layer number of neurons.     
            int n_hidden_2 = 128;    // 2nd layer number of neurons.                                
            float scale = 1.0f / 255;

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.InputLayer((img_rows,img_cols)),
                keras.layers.Flatten(),
                keras.layers.Rescaling(scale),
                keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
            });

            return model;
        }
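For reference, the three Dense layers above contain 784*128 + 128 = 100,480, 128*128 + 128 = 16,512 and 128*10 + 10 = 1,290 trainable parameters respectively (weights plus biases), roughly 118,000 in total, which is what model.summary() should report for this architecture.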

This network uses two new methods that deserve a brief explanation:

1. Flatten: flattens the 28*28 two-dimensional array into a one-dimensional array of 784 values, because the Dense layers that follow expect a flat feature vector rather than a 2D array;

2. Rescaling: multiplies every value by a constant. The data read from the images are per-pixel grayscale values in the range 0~255, so multiplying by 1/255 scales them into the range 0~1, which normalizes the inputs and keeps training numerically well-behaved. A minimal sketch of what these two layers do to a single image is shown below.
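Purely for illustration, here is the equivalent manual transformation on one image, written as plain array manipulation (the helper name is made up for this example and is not part of the project code):

    // Illustration only: what Flatten + Rescaling do to one 28x28 grayscale image.
    static float[] FlattenAndRescale(float[,] image)
    {
        // Flatten: lay the 28*28 two-dimensional array out as a one-dimensional vector of 784 values.
        var flattened = new float[28 * 28];
        for (int row = 0; row < 28; row++)
            for (int col = 0; col < 28; col++)
                flattened[row * 28 + col] = image[row, col];

        // Rescaling: multiply each grayscale value (0~255) by 1/255 so the range becomes 0~1.
        float scale = 1.0f / 255;
        for (int i = 0; i < flattened.Length; i++)
            flattened[i] *= scale;

        return flattened;
    }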

 

Everything else is much the same as in the previous article. The complete code is as follows:

    /// <summary>
    /// Handwritten digit recognition with a neural network
    /// </summary>
    public class NN_MultipleClassification_MNIST
    {
        private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_data.bin";
        private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\train_label.bin";

        private readonly int img_rows = 28;
        private readonly int img_cols = 28;
        private readonly int num_classes = 10;  // total classes

        public void Run()
        {
            var model = BuildModel();
            model.summary();

            // Integer class labels (0~9) pair with SparseCategoricalCrossentropy,
            // so the labels do not need to be one-hot encoded.
            model.compile(optimizer: keras.optimizers.Adam(0.001f),
                loss: keras.losses.SparseCategoricalCrossentropy(),
                metrics: new[] { "accuracy" });

            (NDArray train_x, NDArray train_y) = LoadTrainingData();
            model.fit(train_x, train_y, batch_size: 1024, epochs: 10);

            test(model);
        }

        /// <summary>
        /// Build the network model
        /// </summary>
        private Model BuildModel()
        {
            // Network parameters
            int n_hidden_1 = 128;    // 1st layer number of neurons.     
            int n_hidden_2 = 128;    // 2nd layer number of neurons.                                
            float scale = 1.0f / 255;

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.InputLayer((img_rows,img_cols)),
                keras.layers.Flatten(),
                keras.layers.Rescaling(scale),
                keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
            });

            return model;
        }

        /// <summary>
        /// Load the training data
        /// </summary>
        private (NDArray, NDArray) LoadTrainingData()
        {
            try
            {
                Console.WriteLine("Load data");
                IFormatter serializer = new BinaryFormatter();

                // Deserialize the cached feature and label arrays, closing each file when done.
                float[,,] arrx;
                using (var loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read))
                {
                    arrx = serializer.Deserialize(loadFile) as float[,,];
                }

                int[] arry;
                using (var loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read))
                {
                    arry = serializer.Deserialize(loadFile) as int[];
                }

                Console.WriteLine("Load data success");
                return (np.array(arrx), np.array(arry));
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Load data Exception:{ex.Message}");
                return LoadRawData();
            }
        }

        private (NDArray, NDArray) LoadRawData()
        {
            Console.WriteLine("LoadRawData");

            int total_size = 60000;
            float[,,] arrx = new float[total_size, img_rows, img_cols];
            int[] arry = new int[total_size];

            int count = 0;
            var TrainingImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\training";
            DirectoryInfo RootDir = new DirectoryInfo(TrainingImagePath);
            foreach (var Dir in RootDir.GetDirectories())
            {
                foreach (var file in Dir.GetFiles("*.png"))
                {
                    using (Bitmap bmp = (Bitmap)Image.FromFile(file.FullName))
                    {
                        if (bmp.Width != img_cols || bmp.Height != img_rows)
                        {
                            continue;
                        }

                        // Convert each pixel to a grayscale value (average of R, G and B).
                        for (int row = 0; row < img_rows; row++)
                            for (int col = 0; col < img_cols; col++)
                            {
                                var pixel = bmp.GetPixel(col, row);
                                int val = (pixel.R + pixel.G + pixel.B) / 3;
                                arrx[count, row, col] = val;
                            }

                        // The folder name is the digit label, so assign it once per image.
                        arry[count] = int.Parse(Dir.Name);
                        count++;
                    }
                }

                Console.WriteLine($"Load image data count={count}");
            }

            Console.WriteLine("LoadRawData finished");
            //Save Data
            Console.WriteLine("Save data");
            IFormatter serializer = new BinaryFormatter();

            // Serialize the arrays to binary files
            FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
            serializer.Serialize(saveFile, arrx);
            saveFile.Close();

            saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
            serializer.Serialize(saveFile, arry);
            saveFile.Close();
            Console.WriteLine("Save data finished");

            return (np.array(arrx), np.array(arry));
        }

        /// <summary>
        /// Consume the model (run predictions on the test images)
        /// </summary>
        private void test(Model model)
        {
            var TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\mnist_png\test";
            DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
            foreach (var image in TestDir.GetFiles("*.png"))
            {
                var x = LoadImage(image.FullName);
                var pred_y = model.Apply(x);
                var result = argmax(pred_y[0].numpy());

                Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
            }
        }

        private NDArray LoadImage(string filename)
        {
            float[,,] arrx = new float[1, img_rows, img_cols];
            using (Bitmap bmp = (Bitmap)Image.FromFile(filename))
            {
                for (int row = 0; row < img_rows; row++)
                    for (int col = 0; col < img_cols; col++)
                    {
                        var pixel = bmp.GetPixel(col, row);
                        int val = (pixel.R + pixel.G + pixel.B) / 3;
                        arrx[0, row, col] = val;
                    }
            }

            return np.array(arrx);
        }

        private int argmax(NDArray array)
        {
            var arr = array.reshape(-1);

            // Single pass over the class probabilities: track the index of the largest value.
            int maxIndex = 0;
            float max = arr[0];
            for (int i = 1; i < num_classes; i++)
            {
                if (arr[i] > max)
                {
                    max = arr[i];
                    maxIndex = i;
                }
            }

            return maxIndex;
        }
    }

Two more notes:

1. Reading all the images is fairly time-consuming, so I use a simple trick: once the data has been read, it is serialized to a binary file, and subsequent runs deserialize it straight from that file, which speeds things up considerably. If the bin files cannot be found, the data is read from the images again; I did not push the bin files to the git repository, so the first run after downloading the project takes a little while (an alternative caching sketch follows this list).

2. I did not use the validation images for evaluation; I simply picked 20 samples and ran a quick test.
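As an aside, BinaryFormatter has been marked obsolete in recent .NET versions, so if it gives you trouble, the same caching idea can be sketched with BinaryWriter/BinaryReader instead. The helpers below are an assumption, not part of the project code (they require System.IO), and the array dimensions follow the article:

    // Hypothetical alternative to BinaryFormatter: cache the pixel data as raw floats.
    static void SaveFeatures(string path, float[,,] arrx)
    {
        using (var writer = new BinaryWriter(new FileStream(path, FileMode.Create, FileAccess.Write)))
        {
            foreach (float v in arrx)   // multi-dimensional arrays enumerate in row-major order
                writer.Write(v);
        }
    }

    static float[,,] LoadFeatures(string path, int count, int rows, int cols)
    {
        var arrx = new float[count, rows, cols];
        using (var reader = new BinaryReader(new FileStream(path, FileMode.Open, FileAccess.Read)))
        {
            for (int i = 0; i < count; i++)
                for (int r = 0; r < rows; r++)
                    for (int c = 0; c < cols; c++)
                        arrx[i, r, c] = reader.ReadSingle();
        }
        return arrx;
    }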

 

[Related Resources]

Source code (Git): https://gitee.com/seabluescn/tf_not.git

Project name: NN_MultipleClassification_MNIST

Index: see the table of contents of the TensorFlow.NET Machine Learning Getting Started series

