C#中的深度學習(二):預處理識別硬幣的數據集


在文章中,我們將對輸入到機器學習模型中的數據集進行預處理。

這里我們將對一個硬幣數據集進行預處理,以便以后在監督學習模型中進行訓練。在機器學習中預處理數據集通常涉及以下任務:

  1. 清理數據——通過對周圍數據的平均值或使用其他策略來填補數據缺失或損壞造成的漏洞。
  2. 規范數據——將數據縮放值標准化到一個標准范圍,通常是0到1。具有廣泛值范圍的數據可能會導致不規范,因此我們將所有數據都放在一個公共范圍內。
  3. 一種熱編碼標簽——將數據集中對象的標簽或類編碼為N維二進制向量,其中N是類的總數。數組元素都被設置為0,除了與對象的類相對應的元素,它被設置為1。這意味着在每個數組中都有一個值為1的元素。
  4. 將輸入數據集分為訓練集和驗證集——訓練集被用於訓練模型,驗證集是用於檢查我們的訓練結果。

這個例子我們將使用Numpy.NET,它基本上是Python中流行的Numpy庫的.NET版本。

Numpy是一個專注於處理矩陣的庫。

為了實現我們的數據集處理器,我們在PreProcessing文件夾中創建Utils類和DataSet類。Utils類合並了一個靜態Normalize 方法,如下所示:

public class Utils
   {
       public static NDarray Normalize(string path)
       {
           var colorMode = Settings.Channels == 3 ? "rgb" : "grayscale";
           var img = ImageUtil.LoadImg(path, color_mode: colorMode, target_size: (Settings.ImgWidth, Settings.ImgHeight));
           return ImageUtil.ImageToArray(img) / 255;
       }

   }

在這種方法中,我們用給定的顏色模式(RGB或灰度)加載圖像,並將其調整為給定的寬度和高度。然后我們返回包含圖像的矩陣,每個元素除以255。每個元素除以255是使它們標准化,因為圖像中任何像素的值都在0到255之間,所以通過將它們除以255,我們確保了新的范圍是0到1,包括255。

我們還在代碼中使用了一個Settings類。該類包含用於跨應用程序使用的許多常量。另一個類DataSet,表示我們將要用來訓練機器學習模型的數據集。這里我們有以下字段:

  1. _pathToFolder—包含圖像的文件夾的路徑。
  2. _extList—要考慮的文件擴展名列表。
  3. _labels—_pathToFolder中圖像的標簽或類。
  4. _objs -圖像本身,表示為Numpy.NDarray。
  5. _validationSplit—用於將總圖像數划分為驗證集和訓練集的百分比,在本例中,百分比將定義驗證集與總圖像數之間的大小。
  6. NumberClasses-數據集中唯一類的總數。
  7. TrainX -訓練數據,表示為Numpy.NDarray。
  8. TrainY -訓練標簽,表示為Numpy.NDarray。
  9. ValidationX—驗證數據,表示為Numpy.NDarray。
  10. ValidationY-驗證標簽,表示為Numpy.NDarray。

這是DataSet類:

public class DataSet
    {
        private string _pathToFolder;
        private string[] _extList;
        private List<int> _labels;
        private List<NDarray> _objs;
        private double _validationSplit;
        public int NumberClasses { get; set; }
        public NDarray TrainX { get; set; }
        public NDarray ValidationX { get; set; }
        public NDarray TrainY { get; set; }
        public NDarray ValidationY { get; set; }

        public DataSet(string pathToFolder, string[] extList, int numberClasses, double validationSplit)
        {
            _pathToFolder = pathToFolder;
            _extList = extList;
            NumberClasses = numberClasses;
            _labels = new List<int>();
            _objs = new List<NDarray>();
            _validationSplit = validationSplit;
        }

        public void LoadDataSet()
        {
            // Process the list of files found in the directory.
            string[] fileEntries = Directory.GetFiles(_pathToFolder);
            foreach (string fileName in fileEntries)
                if (IsRequiredExtFile(fileName))
                    ProcessFile(fileName);

            MapToClassRange();
            GetTrainValidationData();
        }

        private bool IsRequiredExtFile(string fileName)
        {
            foreach (var ext in _extList)
            {
                if (fileName.Contains("." + ext))
                {
                    return true;
                }
            }

            return false;
        }

        private void MapToClassRange()
        {
            HashSet<int> uniqueLabels = _labels.ToHashSet();
            var uniqueLabelList = uniqueLabels.ToList();
            uniqueLabelList.Sort();

            _labels = _labels.Select(x => uniqueLabelList.IndexOf(x)).ToList();
        }

        private NDarray OneHotEncoding(List<int> labels)
        {
            var npLabels = np.array(labels.ToArray()).reshape(-1);
            return Util.ToCategorical(npLabels, num_classes: NumberClasses);
        }

        private void ProcessFile(string path)
        {
            _objs.Add(Utils.Normalize(path));
            ProcessLabel(Path.GetFileName(path));
        }

        private void ProcessLabel(string filename)
        {
            _labels.Add(int.Parse(ExtractClassFromFileName(filename)));
        }

        private string ExtractClassFromFileName(string filename)
        {
            return filename.Split('_')[0].Replace("class", "");
        }

        private void GetTrainValidationData()
        {
            var listIndices = Enumerable.Range(0, _labels.Count).ToList();
            var toValidate = _objs.Count * _validationSplit;
            var random = new Random();
            var xValResult = new List<NDarray>();
            var yValResult = new List<int>();
            var xTrainResult = new List<NDarray>();
            var yTrainResult = new List<int>();

            // Split validation data
            for (var i = 0; i < toValidate; i++)
            {
                var randomIndex = random.Next(0, listIndices.Count);
                var indexVal = listIndices[randomIndex];
                xValResult.Add(_objs[indexVal]);
                yValResult.Add(_labels[indexVal]);
                listIndices.RemoveAt(randomIndex);
            }

            // Split rest (training data)
            listIndices.ForEach(indexVal => 
            { 
                xTrainResult.Add(_objs[indexVal]);
                yTrainResult.Add(_labels[indexVal]);
            });

            TrainY = OneHotEncoding(yTrainResult);
            ValidationY = OneHotEncoding(yValResult);
            TrainX = np.array(xTrainResult);
            ValidationX = np.array(xValResult);
        }
}

下面是每個方法的說明:

  1. LoadDataSet()——類的主方法,我們調用它來加載_pathToFolder中的數據集。它調用下面列出的其他方法來完成此操作。
  2. IsRequiredExtFile(filename) - 檢查給定文件是否包含至少一個應該為該數據集處理的擴展名(在_extList中列出)。
  3. MapToClassRange() -獲取數據集中唯一標簽的列表。
  4. ProcessFile(path) -使用Utils.Normalize方法對圖像進行規格化,並調用ProcessLabel方法。
  5. ProcessLabel(filename)——將ExtractClassFromFileName方法的結果添加為標簽。
  6. ExtractClassFromFileName(filename) -從圖像的文件名中提取類。
  7. GetTrainValidationData()——將數據集划分為訓練子數據集和驗證子數據集。

在本系列中,我們將使用https://cvl.tuwien.ac.at/research/cvl-databases/coin-image-dataset/上的硬幣圖像數據集。

要加載數據集,我們可以在控制台應用程序的主類中包含以下內容:

var numberClasses = 60;
var fileExt = new string[] { ".png" };
var dataSetFilePath = @"C:/Users/arnal/Downloads/coin_dataset";
var dataSet = new PreProcessing.DataSet(dataSetFilePath, fileExt, numberClasses, 0.2);
dataSet.LoadDataSet();

我們的數據現在可以輸入到機器學習模型中。下一篇文章將介紹監督機器學習的基礎知識,以及訓練和驗證階段包括哪些內容。它是為沒有AI經驗的讀者准備的。

歡迎關注我的公眾號,如果你有喜歡的外文技術文章,可以通過公眾號留言推薦給我。

原文鏈接:https://www.codeproject.com/Articles/5284219/Deep-Learning-in-Csharp-Coin-Detection-Using-OpenC


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM