基於 ONNX 在 ML.NET 中使用 Pytorch 訓練的垃圾分類模型


ML.NET 在經典機器學習范疇內,對分類、回歸、異常檢測等問題開發模型已經有非常棒的表現了,我之前的文章都有過介紹。當然我們希望在更高層次的領域加以使用,例如計算機視覺、自然語言處理和信號處理等等領域。

圖像識別是計算機視覺的一類分支,AI研發者們較為熟悉的是使用TensorFlow、Pytorch、Keras、MXNET等框架來訓練深度神經網絡模型,其中會涉及到CNN(卷積神經網絡)、DNN(深度神經網絡)的相關算法。

ML.NET 在較早期的版本是無法支持這類研究的,可喜的是最新的版本不但能很好地集成 TensorFlow 的模型做遷移學習,還可以直接導入 DNN 常見預編譯模型:AlexNet、ResNet18、ResNet50、ResNet101 實現對圖像的分類、識別等。

我特別想推薦的是,ML.NET 最新版本對 ONNX 的支持也是非常強勁,通過 ONNX 可以把眾多其他優秀深度學習框架的模型引入到 .NET Core 運行時中,極大地擴充了 .NET 應用在智能認知服務的豐富程度。在 Microsoft Docs 中已經提供了一個基於 ONNX 使用 Tiny YOLOv2 做對象檢測的例子。為了展現 ML.NET 在其他框架上的通用性,本文將介紹使用 Pytorch 訓練的垃圾分類的模型,基於 ONNX 導入到 ML.NET 中完成預測。

在2019年9月華為雲舉辦了一次人工智能大賽·垃圾分類挑戰杯,首次將AI與環保主題結合,展現人工智能技術在生活中的運用。有幸我看到了本次大賽亞軍方案的分享,並且在 github 上找到了開源代碼,按照 README 說明,我用 Pytorch 訓練出了一個模型,並保存為garbage.pt 文件。

生成 ONNX 模型

首先,我使用以下 Pytorch 代碼來生成一個garbage.pt 對應的文件,命名為 garbage.onnx

torch_model = torch.load("garbage.pt") # pytorch模型加載
    batch_size = 1  #批處理大小
    input_shape = (3,224,224)   #輸入數據

    # # set the model to inference mode
    torch_model.eval()

    x = torch.randn(batch_size, *input_shape, device='cuda')        # 生成張量
    export_onnx_file = "garbage.onnx"                    # 目的ONNX文件名
 
    
    torch.onnx.export(torch_model.module,
                        x,
                        export_onnx_file,
                        input_names=["input"],        # 輸入名
                        output_names=["output"]    # 輸出名

准備 ML.NET 項目

創建一個 .NET Core 控制台應用,按如下結構創建好合適的目錄。assets 目錄下的 images 子目錄將放置待預測的圖片,而 Model 子目錄放入前一個步驟生成的 garbage.onnx 模型文件。

ImageNetData 和 ImageNetPrediction 類定義了輸入和輸出的數據結構。

using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace GarbageDemo.DataStructures
{
    public class ImageNetData
    {
        [LoadColumn(0)]
        public string ImagePath;

        [LoadColumn(1)]
        public string Label;

        public static IEnumerable<ImageNetData> ReadFromFile(string imageFolder)
        {
            return Directory
               .GetFiles(imageFolder)
               .Where(filePath => Path.GetExtension(filePath) == ".jpg")
               .Select(filePath => new ImageNetData { ImagePath = filePath, Label = Path.GetFileName(filePath) });

        }
    }

    public class ImageNetPrediction : ImageNetData
    {
        public float[] Score;

        public string PredictedLabelValue;
    }
}

 OnnxModelScorer 類定義了 ONNX 模型加載、打分預測的過程。注意 ImageNetModelSettings 的屬性和第一步中指定的輸入輸出字段名要一致。

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms.Onnx;
using Microsoft.ML.Transforms.Image;
using GarbageDemo.DataStructures;

namespace GarbageDemo
{
    class OnnxModelScorer
    {
        private readonly string imagesFolder;
        private readonly string modelLocation;
        private readonly MLContext mlContext;


        public OnnxModelScorer(string imagesFolder, string modelLocation, MLContext mlContext)
        {
            this.imagesFolder = imagesFolder;
            this.modelLocation = modelLocation;
            this.mlContext = mlContext;
        }

        public struct ImageNetSettings
        {
            public const int imageHeight = 224;
            public const int imageWidth = 224;    
            public const float Mean = 127;
            public const float Scale = 1;
            public const bool ChannelsLast = false;
        } 
        
        public struct ImageNetModelSettings
        {
            // input tensor name
            public const string ModelInput = "input";

            // output tensor name
            public const string ModelOutput = "output";
        }

        private ITransformer LoadModel(string modelLocation)
        {
            Console.WriteLine("Read model");
            Console.WriteLine($"Model location: {modelLocation}");
            Console.WriteLine($"Default parameters: image size=({ImageNetSettings.imageWidth},{ImageNetSettings.imageHeight})");

            // Create IDataView from empty list to obtain input data schema
            var data = mlContext.Data.LoadFromEnumerable(new List<ImageNetData>());

            // Define scoring pipeline
            var pipeline = mlContext.Transforms.LoadImages(outputColumnName: ImageNetModelSettings.ModelInput, imageFolder: "", inputColumnName: nameof(ImageNetData.ImagePath))                           
                            .Append(mlContext.Transforms.ResizeImages(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                        imageWidth: ImageNetSettings.imageWidth, 
                                                                        imageHeight: ImageNetSettings.imageHeight, 
                                                                        inputColumnName: ImageNetModelSettings.ModelInput,
                                                                        resizing: ImageResizingEstimator.ResizingKind.IsoCrop,
                                                                        cropAnchor: ImageResizingEstimator.Anchor.Center
                                                                        ))
                            .Append(mlContext.Transforms.ExtractPixels(outputColumnName: ImageNetModelSettings.ModelInput, interleavePixelColors: ImageNetSettings.ChannelsLast))
                            .Append(mlContext.Transforms.NormalizeGlobalContrast(outputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 inputColumnName: ImageNetModelSettings.ModelInput, 
                                                                                 ensureZeroMean : true, 
                                                                                 ensureUnitStandardDeviation: true, 
                                                                                 scale: ImageNetSettings.Scale))
                            .Append(mlContext.Transforms.ApplyOnnxModel(modelFile: modelLocation, outputColumnNames: new[] { ImageNetModelSettings.ModelOutput }, inputColumnNames: new[] { ImageNetModelSettings.ModelInput }));

            // Fit scoring pipeline
            var model = pipeline.Fit(data);

            return model;
        }

        private IEnumerable<float[]> PredictDataUsingModel(IDataView testData, ITransformer model)
        {
            Console.WriteLine($"Images location: {imagesFolder}");
            Console.WriteLine("");
            Console.WriteLine("=====Identify the objects in the images=====");
            Console.WriteLine("");

            IDataView scoredData = model.Transform(testData);

            IEnumerable<float[]> probabilities = scoredData.GetColumn<float[]>(ImageNetModelSettings.ModelOutput);

            return probabilities;
        }

        public IEnumerable<float[]> Score(IDataView data)
        {
            var model = LoadModel(modelLocation);

            return PredictDataUsingModel(data, model);
        }
    }
}

Program 類中定義了調用過程,完成預測結果呈現。

using GarbageDemo.DataStructures;
using Microsoft.ML;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace GarbageDemo
{
    class Program
    {
        static void Main(string[] args)
        {
            var assetsRelativePath = @"../../../assets";
            string assetsPath = GetAbsolutePath(assetsRelativePath);
            var modelFilePath = Path.Combine(assetsPath, "Model", "garbage.onnx");
            var imagesFolder = Path.Combine(assetsPath, "images");// Initialize MLContext
            MLContext mlContext = new MLContext();

            try
            {
                // Load Data
                IEnumerable<ImageNetData> images = ImageNetData.ReadFromFile(imagesFolder);
                IDataView imageDataView = mlContext.Data.LoadFromEnumerable(images);

                // Create instance of model scorer
                var modelScorer = new OnnxModelScorer(imagesFolder, modelFilePath, mlContext);

                // Use model to score data
                IEnumerable<float[]> probabilities = modelScorer.Score(imageDataView);

                int index = 0;
                foreach (var probable in probabilities)
                {
                    var scores = Softmax(probable);

                    var (topResultIndex, topResultScore) = scores.Select((predictedClass, index) => (Index: index, Value: predictedClass))
                        .OrderByDescending(result => result.Value)
                        .First();
                    Console.WriteLine("圖片:{3} \r\n 分類{2} {0}:{1}", labels[topResultIndex], topResultScore, topResultIndex, images.ElementAt(index).ImagePath);
                    Console.WriteLine("=============================");
                    index++;
                }

            }
            catch (Exception ex)
            {
                Console.WriteLine(ex.ToString());
            }

            Console.WriteLine("========= End of Process..Hit any Key ========");
            Console.ReadLine();
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }

        private static float[] Softmax(float[] values)
        {
            var maxVal = values.Max();
            var exp = values.Select(v => Math.Exp(v - maxVal));
            var sumExp = exp.Sum();

            return exp.Select(v => (float)(v / sumExp)).ToArray();
        }

        private static string[] labels = new string[]
        {
            "其他垃圾/一次性快餐盒",
            "其他垃圾/污損塑料",
            "其他垃圾/煙蒂",
            "其他垃圾/牙簽",
            "其他垃圾/破碎花盆及碟碗",
            "其他垃圾/竹筷",
            "廚余垃圾/剩飯剩菜",
            "廚余垃圾/大骨頭",
            "廚余垃圾/水果果皮",
            "廚余垃圾/水果果肉",
            "廚余垃圾/茶葉渣",
            "廚余垃圾/菜葉菜根",
            "廚余垃圾/蛋殼",
            "廚余垃圾/魚骨",
            "可回收物/充電寶",
            "可回收物/包",
            "可回收物/化妝品瓶",
            "可回收物/塑料玩具",
            "可回收物/塑料碗盆",
            "可回收物/塑料衣架",
            "可回收物/快遞紙袋",
            "可回收物/插頭電線",
            "可回收物/舊衣服",
            "可回收物/易拉罐",
            "可回收物/枕頭",
            "可回收物/毛絨玩具",
            "可回收物/洗發水瓶",
            "可回收物/玻璃杯",
            "可回收物/皮鞋",
            "可回收物/砧板",
            "可回收物/紙板箱",
            "可回收物/調料瓶",
            "可回收物/酒瓶",
            "可回收物/金屬食品罐",
            "可回收物/鍋",
            "可回收物/食用油桶",
            "可回收物/飲料瓶",
            "有害垃圾/干電池",
            "有害垃圾/軟膏",
            "有害垃圾/過期葯物",
            "可回收物/毛巾",
            "可回收物/飲料盒",
            "可回收物/紙袋"
        };

選擇一張圖片放到 images 目錄中,運行結果如下:

有 0.88 的得分說明照片中的物品屬於污損塑料,讓我們看一下圖片真相。

果然是相當准確 ,並且把周邊的附屬物都過濾掉了。

對於 ML.NET 訓練深度神經網絡模型支持更復雜的場景是不是更有信心了!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM