ML.NET 示例:二元分類之垃圾短信檢測


寫在前面

准備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn

垃圾短信檢測

ML.NET 版本 API 類型 狀態 應用程序類型 數據類型 場景 機器學習任務 算法
v0.7 動態API 可能需要更新項目結構以匹配模板 控制台應用程序 .tsv 文件 垃圾信息檢測 二元分類 SDCA(線性學習器),還展示了CustomMapping評估器,它可以將自定義代碼添加到ML.NET管道

在這個示例中,您將看到如何使用ML.NET來預測短信是否是垃圾信息。在機器學習領域中,這種類型的預測被稱為二元分類

問題

我們的目標是預測一個短信是否是垃圾信息(一個不相關的/不想要的消息)。我們將使用UCI的SMS Spam Collection Data Set,其中包含近6000條被分類為“垃圾信息”或“ham”(不是垃圾信息)的消息。我們將使用這個數據集來訓練一個模型,該模型可以接收新消息並預測它們是否是垃圾信息。

這是一個二元分類的示例,因為我們將短信分類為兩個類別。

解決方案

要解決這個問題,首先我們將建立一個評估器來定義我們想要使用的機器學習管道。 然后,我們將在現有數據上訓練這個評估器,評估其有多好,最后我們將使用該模型來預測一些示例消息是否是垃圾信息。

建立 -> 訓練 -> 評估 -> 使用

1. 建立評估器

為了建立評估器,我們將:

  • 定義如何讀取從 https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection 下載的垃圾信息數據集。

  • 應用多個數據轉換:

    • 將標簽(“spam”或“ham”)轉換為布爾值(“true”表示垃圾信息),這樣我們就可以在二元分類器中使用它。
    • 將短信轉換為數字向量,以便機器學習訓練器可以使用它
  • 添加一個訓練器(如StochasticDualCoordinateAscent)。

初始代碼類似以下內容:

// Set up the MLContext, which is a catalog of components in ML.NET.
var mlContext = new MLContext();

// Create the reader and define which columns from the file should be read.
var reader = new TextLoader(mlContext, new TextLoader.Arguments()
{
    Separator = "tab",
    HasHeader = true,
    Column = new[]
        {
            new TextLoader.Column("Label", DataKind.Text, 0),
            new TextLoader.Column("Message", DataKind.Text, 1)
        }
});

var data = reader.Read(new MultiFileSource(TrainDataPath));

// Create the estimator which converts the text label to boolean, featurizes the text, and adds a linear trainer.
var estimator = mlContext.Transforms.CustomMapping<MyInput, MyOutput>(MyLambda.MyAction, "MyLambda")
    .Append(mlContext.Transforms.Text.FeaturizeText("Message", "Features"))
    .Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent());

2. 評估模型

對於這個數據集,我們將使用交叉驗證來評估我們的模型。將數據集划分成5個不相交的子集,訓練5個模型(每個模型使用其中4個子集),並在訓練中沒有使用的數據子集上測試模型。

var cvResults = mlContext.BinaryClassification.CrossValidate(data, estimator, numFolds: 5);
var aucs = cvResults.Select(r => r.metrics.Auc);
Console.WriteLine("The AUC is {0}", aucs.Average());

請注意,通常我們在訓練后評估模型。 但是,交叉驗證包括模型訓練部分,因此我們不需要先執行Fit()。 但是,我們稍后將在完整數據集上訓練模型以利用其他數據。

3. 訓練模型

為了訓練模型,我們將調用評估器的Fit()方法,同時提供完整的訓練數據。

var model = estimator.Fit(data);

4. 使用模型

訓練完模型后,您可以使用Predict()API來預測新文本是否垃圾信息。 在這種情況下,我們更改模型的閾值以獲得更好的預測。 我們這樣做是因為我們的數據有偏差,大多數消息都不是垃圾信息。

// The dataset we have is skewed, as there are many more non-spam messages than spam messages.
// While our model is relatively good at detecting the difference, this skewness leads it to always
// say the message is not spam. We deal with this by lowering the threshold of the predictor. In reality,
// it is useful to look at the precision-recall curve to identify the best possible threshold.
var inPipe = new TransformerChain<ITransformer>(model.Take(model.Count() - 1).ToArray());
var lastTransformer = new BinaryPredictionTransformer<IPredictorProducing<float>>(mlContext, model.LastTransformer.Model, inPipe.GetOutputSchema(data.Schema), model.LastTransformer.FeatureColumn, threshold: 0.15f, thresholdColumn: DefaultColumnNames.Probability);

ITransformer[] parts = model.ToArray();
parts[parts.Length - 1] = lastTransformer;
var newModel = new TransformerChain<ITransformer>(parts);

// Create a PredictionFunction from our model 
var predictor = newModel.MakePredictionFunction<SpamInput, SpamPrediction>(mlContext);

var input = new SpamInput { Message = "free medicine winner! congratulations" };
Console.WriteLine("The message '{0}' is {1}", input.Message, predictor.Predict(input).isSpam ? "spam" : "not spam");


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM