ML.NET Tutorial: Taxi Fare Prediction (a Regression Problem)


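Understanding the Problem

A taxi fare depends not only on the trip distance but also on other factors, such as the number of passengers and whether the fare is paid by credit card (the taxis here are New York City taxis). So this is not a simple single-variable equation. It is a regression problem with several input features.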

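Preparing the Data

Create a console application project, add a Data folder, and put the taxi-fare-train.csv and taxi-fare-test.csv files in it. Don't forget to change their Copy to Output Directory property to Copy if newer. Then add the Microsoft.ML NuGet package.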
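
A quick way to confirm that the Copy if newer setting took effect is to check that both files exist in a Data folder next to the compiled program. The following is a minimal standalone sketch (a C# 9+ program with top-level statements, base class library only). It assumes the files are copied to the executable's folder (AppContext.BaseDirectory), and the helper name MissingFiles is hypothetical. The complete example later builds its paths from Environment.CurrentDirectory, which is the working directory and is not always the same folder.

```csharp
using System;
using System.IO;
using System.Linq;

// Assumption: "Copy if newer" copies the CSV files into <output folder>/Data.
string dataDir = Path.Combine(AppContext.BaseDirectory, "Data");
string[] missing = MissingFiles(dataDir, "taxi-fare-train.csv", "taxi-fare-test.csv");

Console.WriteLine(missing.Length == 0
    ? "Both data files were found in " + dataDir
    : "Missing (check Copy to Output Directory): " + string.Join(", ", missing));

// Hypothetical helper: returns the names of the files that do not exist in dir.
static string[] MissingFiles(string dir, params string[] names) =>
    names.Where(name => !File.Exists(Path.Combine(dir, name))).ToArray();
```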

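Loading the Data

Create an MLContext object and then a TextLoader object, which reads data from files. Passing seed: 0 to MLContext makes its randomized operations repeatable. Each TextLoader.Column maps a column index in the CSV file to a column name and a type (DataKind.Text for strings, DataKind.R4 for 32-bit floats), and HasHeader = true tells the loader to skip the header line.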

MLContext mlContext = new MLContext(seed: 0);

_textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
    Separator = ",",
    HasHeader = true,
    Column = new[]
    {
        new TextLoader.Column("VendorId", DataKind.Text, 0),
        new TextLoader.Column("RateCode", DataKind.Text, 1),
        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
        new TextLoader.Column("TripTime", DataKind.R4, 3),
        new TextLoader.Column("TripDistance", DataKind.R4, 4),
        new TextLoader.Column("PaymentType", DataKind.Text, 5),
        new TextLoader.Column("FareAmount", DataKind.R4, 6)
    }
});
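
To make the column mapping concrete, here is a hand-written sketch that parses one line of the file with the same index, name and type for each column. It is not how TextLoader is implemented. It assumes plain comma-separated values without quoted fields. ParseRow is a hypothetical name, and the sample line reuses the values of the test trip used later in this article.

```csharp
using System;
using System.Globalization;
using System.Linq;

// Same layout as the TextLoader columns: index -> name, Text (string) or R4 (float).
string[] lines =
{
    "(header line)",                 // skipped, like HasHeader = true
    "VTS,1,1,1140,3.75,CRD,15.5",
};
var rows = lines.Skip(1).Select(ParseRow).ToList();
Console.WriteLine($"{rows[0].VendorId}: {rows[0].TripDistance} -> {rows[0].FareAmount}");

// Hypothetical parser for one data line; throws if the column count is wrong.
static (string VendorId, string RateCode, float PassengerCount, float TripTime,
        float TripDistance, string PaymentType, float FareAmount) ParseRow(string line)
{
    string[] f = line.Split(',');
    if (f.Length != 7)
        throw new FormatException($"Expected 7 columns but found {f.Length}: {line}");

    // R4 columns: parse with the invariant culture so "3.75" reads the same on every machine.
    static float R4(string s) => float.Parse(s, CultureInfo.InvariantCulture);

    return (f[0], f[1], R4(f[2]), R4(f[3]), R4(f[4]), f[5], R4(f[6]));
}
```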

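Extracting Features

The dataset file has seven columns. The first six are used as features and the last one (FareAmount) is the label. The TaxiTrip class describes one input row, and its Column attributes give each field's column index in the file. TaxiTripFarePrediction describes the model's output. The regression trainer writes its prediction to a column named Score, and the ColumnName attribute maps that column to the FareAmount field.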

public class TaxiTrip
{
    [Column("0")]
    public string VendorId;

    [Column("1")]
    public string RateCode;

    [Column("2")]
    public float PassengerCount;

    [Column("3")]
    public float TripTime;

    [Column("4")]
    public float TripDistance;

    [Column("5")]
    public string PaymentType;

    [Column("6")]
    public float FareAmount;
}

public class TaxiTripFarePrediction
{
    [ColumnName("Score")]
    public float FareAmount;
}

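Training the Model

First read the training dataset, then build the pipeline. The first step copies the FareAmount column to a column named Label, which is used as the label (trainers read the Label and Features columns by default). The second step uses one-hot encoding to convert the three string columns VendorId, RateCode and PaymentType into numeric columns. The third step concatenates the six data columns into a single Features column. The last step selects FastTreeRegressionTrainer, a gradient-boosted decision tree algorithm, as the trainer.
Once the pipeline is built, calling Fit trains the model.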

IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
    .Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);
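
Conceptually, the OneHotEncoding and Concatenate steps turn each trip into a single numeric vector. The following standalone sketch imitates that idea by hand. It is not ML.NET's implementation, and the helper names and the category values in the small vocabularies are illustrative assumptions. In this version of the API, OneHotEncoding("VendorId") writes its output back to a column with the same name, which is why Concatenate can still refer to VendorId, RateCode and PaymentType.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative vocabularies, as if collected from the training data.
var vendorVocab = BuildVocabulary(new[] { "VTS", "CMT", "VTS" });
var rateVocab = BuildVocabulary(new[] { "1", "2", "5" });
var paymentVocab = BuildVocabulary(new[] { "CRD", "CSH", "CRD", "NOC" });

Console.WriteLine(string.Join(", ", OneHot("CSH", paymentVocab))); // 0, 1, 0

// Same column order as the Concatenate step:
// VendorId, RateCode, PassengerCount, TripTime, TripDistance, PaymentType
float[] features = Concatenate(
    OneHot("VTS", vendorVocab),     // 2 slots
    OneHot("1", rateVocab),         // 3 slots
    new[] { 1f, 1140f, 3.75f },     // numeric columns pass through unchanged
    OneHot("CRD", paymentVocab));   // 3 slots
Console.WriteLine($"Feature vector length: {features.Length}"); // 11

// Assign each distinct category an index, in order of first appearance.
static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
{
    var vocab = new Dictionary<string, int>();
    foreach (string v in values)
        if (!vocab.ContainsKey(v))
            vocab[v] = vocab.Count;
    return vocab;
}

// Indicator vector with a single 1 at the category's index; unseen values stay all zeros.
static float[] OneHot(string value, Dictionary<string, int> vocab)
{
    var vector = new float[vocab.Count];
    if (vocab.TryGetValue(value, out int index))
        vector[index] = 1f;
    return vector;
}

// Join several vectors end to end into one feature vector.
static float[] Concatenate(params float[][] parts) => parts.SelectMany(p => p).ToArray();
```

Each string column becomes as many slots as it has distinct values, so the width of the Features vector grows with the number of categories seen in training.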

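Evaluating the Model

Evaluation uses the test dataset and the Evaluate method for regression problems. R2 Score is the coefficient of determination (closer to 1 is better). RMS loss is the root mean squared error, in the same units as the fare (lower is better).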

IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"*       Model quality metrics evaluation         ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
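
The two metrics printed above can be reproduced by hand. This sketch uses the standard textbook definitions, R2 = 1 - SS_res / SS_tot and RMS = square root of the mean squared error. ML.NET's own computation may differ in details, and the label/score arrays are made-up values, not output from this model.

```csharp
using System;
using System.Linq;

// Made-up actual fares (Label) and predictions (Score), for illustration only.
double[] labels = { 1, 2, 3 };
double[] scores = { 1, 2, 4 };

Console.WriteLine($"R2 Score: {RSquared(labels, scores):0.##}");
Console.WriteLine($"RMS loss: {Rms(labels, scores):0.##}");

// Coefficient of determination: 1 - (residual sum of squares / total sum of squares).
static double RSquared(double[] label, double[] score)
{
    double mean = label.Average();
    double ssRes = label.Zip(score, (y, p) => (y - p) * (y - p)).Sum();
    double ssTot = label.Sum(y => (y - mean) * (y - mean));
    return 1 - ssRes / ssTot;
}

// Root mean squared error, in the same units as the label.
static double Rms(double[] label, double[] score) =>
    Math.Sqrt(label.Zip(score, (y, p) => (y - p) * (y - p)).Average());
```

With these arrays the only error is 1 on the last item, so SS_res = 1, SS_tot = 2, R2 = 0.5 and RMS = sqrt(1/3), about 0.58.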

Saving the Model

The trained model can be saved as a zip file for later use.

using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    mlContext.Model.Save(model, fileStream);

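Using the Model

First load the saved model. Next, create a prediction function object with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Then call its Predict method, passing in the data to predict.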

ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
    loadedModel = mlContext.Model.Load(stream);
}

var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

var taxiTripSample = new TaxiTrip()
{
    VendorId = "VTS",
    RateCode = "1",
    PassengerCount = 1,
    TripTime = 1140,
    TripDistance = 3.75f,
    PaymentType = "CRD",
    FareAmount = 0 // To predict. Actual/Observed = 15.5
};

var prediction = predictionFunction.Predict(taxiTripSample);

Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");

Complete Sample Code

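using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System;
using System.IO;

// TaxiTrip and TaxiTripFarePrediction are the classes defined earlier in this article
// (for example, in a separate TaxiTrip.cs file in the same project).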
namespace TexiFarePredictor
{
    class Program
    {
        static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
        static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
        static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
        static TextLoader _textLoader;

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 0);

            _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = ",",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column("VendorId", DataKind.Text, 0),
                    new TextLoader.Column("RateCode", DataKind.Text, 1),
                    new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                    new TextLoader.Column("TripTime", DataKind.R4, 3),
                    new TextLoader.Column("TripDistance", DataKind.R4, 4),
                    new TextLoader.Column("PaymentType", DataKind.Text, 5),
                    new TextLoader.Column("FareAmount", DataKind.R4, 6)
                }
            });

            var model = Train(mlContext, _trainDataPath);

            Evaluate(mlContext, model);

            TestSinglePrediction(mlContext);

            Console.Read();
        }

        public static ITransformer Train(MLContext mlContext, string dataPath)
        {
            IDataView dataView = _textLoader.Read(dataPath);
            var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                .Append(mlContext.Regression.Trainers.FastTree());
            var model = pipeline.Fit(dataView);
            SaveModelAsFile(mlContext, model);
            return model;
        }

        private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
        {
            using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                mlContext.Model.Save(model, fileStream);
        }

        private static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = _textLoader.Read(_testDataPath);
            var predictions = model.Transform(dataView);
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
            Console.WriteLine();
            Console.WriteLine($"*************************************************");
            Console.WriteLine($"*       Model quality metrics evaluation         ");
            Console.WriteLine($"*------------------------------------------------");
            Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
            Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
        }

        private static void TestSinglePrediction(MLContext mlContext)
        {
            ITransformer loadedModel;
            using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                loadedModel = mlContext.Model.Load(stream);
            }

            var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

            var taxiTripSample = new TaxiTrip()
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripTime = 1140,
                TripDistance = 3.75f,
                PaymentType = "CRD",
                FareAmount = 0 // To predict. Actual/Observed = 15.5
            };

            var prediction = predictionFunction.Predict(taxiTripSample);

            Console.WriteLine($"**********************************************************************");
            Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
            Console.WriteLine($"**********************************************************************");
        }
    }
}

Output of the program:

*************************************************
*       Model quality metrics evaluation
*------------------------------------------------
*       R2 Score:      0.92
*       RMS loss:      2.81
**********************************************************************
Predicted fare: 15.7855, actual fare: 15.5
**********************************************************************

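For this sample, the predicted fare (15.7855) is within about 0.29 of the actual fare (15.5), which is in line with the overall RMS loss of 2.81 on the test set.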

