機器學習框架ML.NET學習筆記【4】多元分類之手寫數字識別


一、問題與解決方案

通過多元分類算法進行手寫數字識別,手寫數字的圖片分辨率為8*8的灰度圖片、已經預先進行過處理,讀取了各像素點的灰度值,並進行了標記。

其中第0列是序號(不參與運算)、1-64列是像素值、65列是結果。

我們以64位像素值為特征進行多元分類,算法采用SDCA最大熵分類算法。

 

二、源碼

 先貼出全部代碼:

namespace MulticlassClassification_Mnist
{
    class Program
    {
        static readonly string TrainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "optdigits-full.csv");
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "SDCA-Model.zip");

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 1);
          
            TrainAndSaveModel(mlContext);
            TestSomePredictions(mlContext);

            Console.WriteLine("Hit any key to finish the app");
            Console.ReadKey();
        }
              

        public static void TrainAndSaveModel(MLContext mlContext)
        {
            // STEP 1: 准備數據
            var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath,
                    columns: new[]
                    {
                        new TextLoader.Column("Serial", DataKind.Single, 0),
                        new TextLoader.Column("PixelValues", DataKind.Single, 1, 64),
                        new TextLoader.Column("Number", DataKind.Single, 65)
                    },
                    hasHeader: true,
                    separatorChar: ','
                    );

            var trainTestData = mlContext.Data.TrainTestSplit(fulldata, testFraction: 0.2);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;

            // STEP 2: 配置數據處理管道        
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue);

            // STEP 3: 配置訓練算法
            var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");
            var trainingPipeline = dataProcessPipeline.Append(trainer)
              .Append(mlContext.Transforms.Conversion.MapKeyToValue("Number", "Label"));
            
            // STEP 4: 訓練模型使其與數據集擬合
            Console.WriteLine("=============== Train the model fitting to the DataSet ===============");           

            ITransformer trainedModel = trainingPipeline.Fit(trainData);         


            // STEP 5:評估模型的准確性
            Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
            var predictions = trainedModel.Transform(testData);
            var metrics = mlContext.MulticlassClassification.Evaluate(data: predictions, labelColumnName: "Number", scoreColumnName: "Score");
            PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
         
            // STEP 6:保存模型              
            mlContext.ComponentCatalog.RegisterAssembly(typeof(DebugConversion).Assembly);
            mlContext.Model.Save(trainedModel, trainData.Schema, ModelPath);
            Console.WriteLine("The model is saved to {0}", ModelPath);
        }

        private static void TestSomePredictions(MLContext mlContext)
        {
            // Load Model           
            ITransformer trainedModel = mlContext.Model.Load(ModelPath, out var modelInputSchema);

            // Create prediction engine 
            var predEngine = mlContext.Model.CreatePredictionEngine<InputData, OutPutData>(trainedModel);

            //num 1
            InputData MNIST1 = new InputData()
            {               
                PixelValues = new float[] { 0, 0, 0, 0, 14, 13, 1, 0, 0, 0, 0, 5, 16, 16, 2, 0, 0, 0, 0, 14, 16, 12, 0, 0, 0, 1, 10, 16, 16, 12, 0, 0, 0, 3, 12, 14, 16, 9, 0, 0, 0, 0, 0, 5, 16, 15, 0, 0, 0, 0, 0, 4, 16, 14, 0, 0, 0, 0, 0, 1, 13, 16, 1, 0 }
            }; 
            var resultprediction1 = predEngine.Predict(MNIST1);
            resultprediction1.PrintToConsole();           
        }      
    }

    class InputData
    {
        public float Serial;
        [VectorType(64)]
        public float[] PixelValues;               
        public float Number;       
    }

    class OutPutData : InputData
    {  
        public float[] Score;  
    }   
}
View Code

  

三、分析

 整體流程和二元分類沒有什么區別,下面解釋一下有差異的兩個地方。

 1、加載數據

      // STEP 1: 准備數據
            var fulldata = mlContext.Data.LoadFromTextFile(path: TrainDataPath,
                    columns: new[]
                    {
                        new TextLoader.Column("Serial", DataKind.Single, 0),
                        new TextLoader.Column("PixelValues", DataKind.Single, 1, 64),
                        new TextLoader.Column("Number", DataKind.Single, 65)
                    },
                    hasHeader: true,
                    separatorChar: ','
                    );

  這次我們不是通過實體對象來加載數據,而是通過列信息來進行加載,其中PixelValues是特征值,Number是標簽值。

 

2、訓練通道

            // STEP 2: 配置數據處理管道        
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Number", keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)

// STEP 3: 配置訓練算法 var trainer = mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(labelColumnName: "Label", featureColumnName: "PixelValues");
var trainingPipeline = dataProcessPipeline.Append(trainer)
.Append(mlContext.Transforms.Conversion.MapKeyToValue(
"Number", "Label"));

// STEP 4: 訓練模型使其與數據集擬合
ITransformer trainedModel
= trainingPipeline.Fit(trainData);

 首先通過MapValueToKey方法將Number值轉換為Key類型,多元分類算法要求標簽值必須是這種類型(類似枚舉類型,二元分類要求標簽為BOOL類型)。關於這個轉換的原因及編碼方式,下面詳細介紹。

 

四、鍵值類型編碼與獨熱編碼

 MapValueToKey功能是將(字符串)值類型轉換為KeyTpye類型。

有時候某些輸入字段用來表示類型(類別特征),但本身並沒有特別的含義,比如編號、電話號碼、行政區域名稱或編碼等,這里需要把這些類型轉換為1到一個整數如1-300來進行重新編號。

舉個簡單的例子,我們進行圖片識別的時候,目標結果可能是“貓咪”、“小狗”、“人物”這些分類,需要把這些分類轉換為1、2、3這樣的整數。但本文的標簽值本身就是1、2、3,為什么還要轉換呢?因為我們這里的一二三其實不是數學意義上的數字,而是一種標志,可以理解為壹、貳、叄,所以要進行編碼。

 MapKeyToValue和MapValueToKey相反,它把將鍵類型轉換回其原始值(字符串)。就是說標簽是文本格式,在運算前已經被轉換為數字枚舉類型了,此時預測結果為數字,通過MapKeyToValue將其結果轉換為對應文本。

MapValueToKey一般是對標簽值進行編碼,一般不用於特征值,如果是特征值為字符串類型的,建議采用獨熱編碼。獨熱編碼即 One-Hot 編碼,又稱一位有效編碼,其方法是使用N位狀態寄存器來對N個狀態進行編碼,每個狀態都由他獨立的寄存器位,並且在任意時候,其中只有一位有效。例如:

自然狀態碼為:0,1,2,3,4,5
獨熱編碼為:000001,000010,000100,001000,010000,100000

怎么理解這個事情呢?舉個例子,假如我們要進行人的身材的分析,但我們希望加入地域特征,比如:“黑龍江”、“山東”、“湖南”、“廣東”這種特征,但這種字符串機器學習是不認識的,必須轉換為浮點數,剛才提到MapKeyToValue可以把字符串轉換為數字,為什么這里要采用獨熱編碼呢?簡單來說,假設把地域名稱轉換為1到10幾個數字,在歐氏幾何中1到3的歐拉距離和1到9的歐拉距離是不等的,但經過獨熱編碼后,任意兩點間的歐拉距離都是相等的,而我們這里的地域特征僅僅是想表達分類關系,彼此之間沒有其他邏輯關系,所以應該采用獨熱編碼。

 

五、進度調試

一般機器算法的數據擬合過程時間都比較長,有時程序跑了兩個小時還沒結束,也不知道還需要多長時間,着實讓人着急,所以及時了解學習進度,是很有必要的。

由於機器學習算法一般都有“遞歸直到收斂”這種操作,所以我們是沒有辦法預先知道最終運算次數的,能做到的只能打印一些過程信息,看到程序在動,心里也有點底,當系統跑過一次之后,基本就大致知道需要多少次擬合了,后面再調試就可以大致了解進度了。補充一句,可不可以在測試階段先減少樣本數據進行快速調試,調試通過后再切換到全樣本進行訓練?其實不行,有時候樣本數量小,可能會引起指標震盪,時間反而長了。

之前在Githube上看到有人通過MLContext.LOG事件來打印調試信息,我試了一下,發現沒法控制篩選內容,不太方便,后來想到一個方法,就是新增一個自定義數據處理通道,這個通道不做具體事情,就打印調試信息。

類定義:

namespace MulticlassClassification_Mnist
{
    public class DebugConversionInput
    {
        public float Serial { get; set; }
    }
 
    public class DebugConversionOutput
    {
        public float DebugFeature { get; set; }
    }

    [CustomMappingFactoryAttribute("DebugConversionAction")]
    public class DebugConversion : CustomMappingFactory<DebugConversionInput, DebugConversionOutput>
    {       static long TotalCount = 0;

        public void CustomAction(DebugConversionInput input, DebugConversionOutput output)
        {
            output.DebugFeature = 1.0f;  
TotalCount++;
Console.WriteLine($"DebugConversion.CustomAction's debug info.TotalCount={TotalCount} "); } public override Action<DebugConversionInput, DebugConversionOutput> GetMapping() => CustomAction; } }

 使用方法:

 var dataProcessPipeline = mlContext.Transforms.CustomMapping(new DebugConversion().GetMapping(), contractName: "DebugConversionAction")
       .Append(...)
       .Append(mlContext.Transforms.Concatenate("Features", new string[] { "RealFeatures", "DebugFeature" }));

 通過CustomMapping加載我們自定義的數據處理通道,由於數據集是懶加載(Lazy)的,所以必須把我們自定義數據處理通道的輸出加入為特征值,才能參與運算,然后算法在操作每一條數據時都會調用到CustomAction方法,這樣就可以打印進度信息了。為了不影響運算結果,我們把這個數據處理通道的輸出值固定為1.0f 。

 

六、資源獲取

源碼下載地址:https://github.com/seabluescn/Study_ML.NET

工程名稱:MulticlassClassification_Mnist

點擊查看機器學習框架ML.NET學習筆記系列文章目錄


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM