使用ML.NET實現NBA得分預測
導讀:ML.NET系列文章
ML.NET已經發布了v0.2版本,新增了聚類訓練器,執行性能進一步增強。本文將介紹一種特殊的回歸——泊松回歸,並以NBA比賽得分預測的案例來演練。
泊松回歸 Poisson regression
前面的文章已提過,回歸是用來預測連續值的,泊松回歸是其中一種,其特殊之處在於僅用於預測非負整數,通常為計數類的數值。泊松分布是離散分布,所以特徵值和標籤值應為相同(或接近相同)時間間隔下的獨立隨機事件。
那麼什麼場景是符合計數,可以適用泊松回歸呢?舉幾個例子,比如共享單車的調度,每一處地域中心,每隔1小時都要統計借車和還車數,根據這個統計我們就可以預測下一個小時此處地域需要調配多少車輛才能滿足需要。再比如,公司每個月都有離職員工,那麼人力資源部門就可以對月人員流失數進行計數,然後通過泊松回歸來預測下個月的流失情況,以便提早採取措施做好招聘計劃。
是不是有一點感覺了,本次我們用大家喜歡的NBA比賽得分來進行演練,因為比賽得分正好也是一種計數,也符合連續相同時間間隔(比賽時長的大體相近),比賽結果具有不確定性,所以也是泊松回歸大顯身手的地方,為了易於理解,我將示范預測的是主場球隊的得分。
NBA比賽數據
本案例數據來源Kaggle.com,內容是NBA Team Game Stats from 2014 to 2018,這份數據集收集了最近4年的NBA比賽,格式類似如下:
"","Team","Game","Date","Home","Opponent","WINorLOSS","TeamPoints","OpponentPoints","FieldGoals","FieldGoalsAttempted","FieldGoals.","X3PointShots","X3PointShotsAttempted","X3PointShots.","FreeThrows","FreeThrowsAttempted","FreeThrows.","OffRebounds","TotalRebounds","Assists","Steals","Blocks","Turnovers","TotalFouls","Opp.FieldGoals","Opp.FieldGoalsAttempted","Opp.FieldGoals.","Opp.3PointShots","Opp.3PointShotsAttempted","Opp.3PointShots.","Opp.FreeThrows","Opp.FreeThrowsAttempted","Opp.FreeThrows.","Opp.OffRebounds","Opp.TotalRebounds","Opp.Assists","Opp.Steals","Opp.Blocks","Opp.Turnovers","Opp.TotalFouls"
"1","ATL","1",2014-10-29,"Away","TOR","L","102","109","40","80",".500","13","22",".591","9","17",".529","10","42","26","6","8","17","24","37","90",".411","8","26",".308","27","33",".818","16","48","26","13","9","9","22"
"2","ATL","2",2014-11-01,"Home","IND","W","102","92","35","69",".507","7","20",".350","25","33",".758","3","37","26","10","6","12","20","31","81",".383","12","32",".375","18","21",".857","11","44","25","5","5","18","26"
"3","ATL","3",2014-11-05,"Away","SAS","L","92","94","38","92",".413","8","25",".320","8","11",".727","10","37","26","14","5","13","25","31","69",".449","5","17",".294","27","38",".711","11","50","25","7","9","19","15"
"4","ATL","4",2014-11-07,"Away","CHO","L","119","122","43","93",".462","13","33",".394","20","26",".769","7","38","28","8","3","19","33","48","97",".495","6","21",".286","20","27",".741","11","51","31","6","7","19","30"
"5","ATL","5",2014-11-08,"Home","NYK","W","103","96","33","81",".407","9","22",".409","28","36",".778","12","41","18","10","5","8","17","40","84",".476","8","21",".381","8","11",".727","13","44","26","2","6","15","29"
"6","ATL","6",2014-11-10,"Away","NYK","W","91","85","27","71",".380","10","27",".370","27","28",".964","9","38","20","7","3","15","16","36","83",".434","6","26",".231","7","12",".583","11","40","23","4","2","15","26"
"7","ATL","7",2014-11-12,"Home","UTA","W","100","97","39","76",".513","9","20",".450","13","18",".722","13","46","23","8","4","18","12","43","86",".500","5","23",".217","6","12",".500","8","30","28","12","8","11","17"
"8","ATL","8",2014-11-14,"Home","MIA","W","114","103","42","75",".560","11","28",".393","19","23",".826","3","36","33","10","5","13","20","35","74",".473","10","21",".476","23","25",".920","5","32","27","10","3","14","20"
各字段如下:
比賽基本信息:球隊Team,比賽場次序號Game,比賽日期Date,主客場標記Home(取值為Home或Away,表示Team在該場比賽是主場還是客場),對手Opponent,該隊勝負WINorLOSS。
比賽主客隊技術數據:Team Points,Field Goals,Field Goals Attempted,Field Goals Percentage,3 Point Shots,3 Point Shots Attempted,3 Point Shots Percentage,Free Throws,Free Throws Attempted,Free Throws Percentage,Offensive Rebounds,Total Rebounds,Assists,Steals,Blocks,Turnovers,Total Fouls。
這些指標反映了主客隊投籃出手次數、命中數、命中率,三分球的出手次數、命中數、命中率,罰球的出手次數、命中數、命中率,助攻,搶斷,犯規等,這些都是我們在看NBA時常見的統計。
由於只有這一份數據,為了分別用於訓練、評估和預測,我將數據集按7:2:1的比例進行分割。
代碼片段分解
定義原始數據結構、預測數據結構,TeamPoints是該行球隊(Team)的得分,是本次示例要預測的目標,因此定義為標籤字段。
/// <summary>
/// One data row of the NBA team game stats CSV.
/// </summary>
/// <remarks>
/// Fields are declared in CSV column order, and the ordinal in each <c>[Column]</c> attribute
/// equals the index that <c>LoadData</c> reads the value from. Fields whose names end in
/// <c>_</c> correspond to CSV headers ending in "." (for example "FieldGoals."), and the
/// <c>Opp_</c> prefix corresponds to the "Opp." header prefix.
/// </remarks>
public class Match
{
// --- Game identification (kept as text) ---
[Column(ordinal: "0")]
public string Id;
[Column(ordinal: "1")]
public string Team;
[Column(ordinal: "2")]
public string Game;
[Column(ordinal: "3")]
public string Date;
// Holds "Home" or "Away" in the sample rows.
[Column(ordinal: "4")]
public string Home;
[Column(ordinal: "5")]
public string Opponent;
// Holds "W" or "L" in the sample rows.
[Column(ordinal: "6")]
public string WINorLOSS;
// Exposed under the column name "Label"; this is the value the example predicts.
[Column(ordinal: "7", name: "Label")]
public float TeamPoints;
[Column(ordinal: "8")]
public float OpponentPoints;
// --- Statistics for Team ---
[Column(ordinal: "9")]
public float FieldGoals;
[Column(ordinal: "10")]
public float FieldGoalsAttempted;
[Column(ordinal: "11")]
public float FieldGoals_;
[Column(ordinal: "12")]
public float X3PointShots;
[Column(ordinal: "13")]
public float X3PointShotsAttempted;
[Column(ordinal: "14")]
public float X3PointShots_;
[Column(ordinal: "15")]
public float FreeThrows;
[Column(ordinal: "16")]
public float FreeThrowsAttempted;
[Column(ordinal: "17")]
public float FreeThrows_;
[Column(ordinal: "18")]
public float OffRebounds;
[Column(ordinal: "19")]
public float TotalRebounds;
[Column(ordinal: "20")]
public float Assists;
[Column(ordinal: "21")]
public float Steals;
[Column(ordinal: "22")]
public float Blocks;
[Column(ordinal: "23")]
public float Turnovers;
[Column(ordinal: "24")]
public float TotalFouls;
// --- Statistics for the opponent ---
[Column(ordinal: "25")]
public float Opp_FieldGoals;
[Column(ordinal: "26")]
public float Opp_FieldGoalsAttempted;
[Column(ordinal: "27")]
public float Opp_FieldGoals_;
[Column(ordinal: "28")]
public float Opp_3PointShots;
[Column(ordinal: "29")]
public float Opp_3PointShotsAttempted;
[Column(ordinal: "30")]
public float Opp_3PointShots_;
[Column(ordinal: "31")]
public float Opp_FreeThrows;
[Column(ordinal: "32")]
public float Opp_FreeThrowsAttempted;
[Column(ordinal: "33")]
public float Opp_FreeThrows_;
[Column(ordinal: "34")]
public float Opp_OffRebounds;
[Column(ordinal: "35")]
public float Opp_TotalRebounds;
[Column(ordinal: "36")]
public float Opp_Assists;
[Column(ordinal: "37")]
public float Opp_Steals;
[Column(ordinal: "38")]
public float Opp_Blocks;
[Column(ordinal: "39")]
public float Opp_Turnovers;
[Column(ordinal: "40")]
public float Opp_TotalFouls;
}
/// <summary>
/// Output type of the trained model: the predicted value for <c>Match.TeamPoints</c>.
/// </summary>
public class MatchPrediction
{
// Bound to the column named "Score" (presumably the regressor's output column; confirm against the ML.NET version in use).
[ColumnName("Score")]
public float TeamPoints;
}
加載數據部分
const string DATA_PATH = "data/nba.games.stats.csv";

/// <summary>
/// Reads the game stats CSV at <c>DATA_PATH</c> into a list of <c>Match</c> records.
/// </summary>
/// <remarks>
/// The first line is treated as a header and skipped; blank lines are ignored. Each remaining
/// line is split on commas and surrounding double quotes are stripped from every field, so
/// quoted fields that themselves contain a comma are not supported. Numeric fields are parsed
/// with the invariant culture so that values such as ".500" load identically regardless of the
/// machine's regional settings.
/// </remarks>
/// <returns>All data rows, in file order.</returns>
/// <exception cref="InvalidDataException">A row has fewer than 41 fields.</exception>
/// <exception cref="FormatException">A numeric field cannot be parsed.</exception>
static ICollection<Match> LoadData()
{
    // Indices 0..40 are read below.
    const int ExpectedFieldCount = 41;
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    var matches = new List<Match>();
    using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
    {
        sr.ReadLine(); // skip the header row
        var lineNumber = 1;
        string line;
        while ((line = sr.ReadLine()) != null)
        {
            lineNumber++;
            if (string.IsNullOrWhiteSpace(line))
            {
                // A trailing newline or stray blank line is not a data row.
                continue;
            }

            var values = line.Split(',');
            if (values.Length < ExpectedFieldCount)
            {
                throw new InvalidDataException(
                    $"{DATA_PATH}, line {lineNumber}: expected at least {ExpectedFieldCount} fields but found {values.Length}.");
            }

            // Field accessors: strip the CSV quoting, and parse numbers culture-independently.
            string Text(int index) => values[index].Trim('"');
            float Number(int index) => Convert.ToSingle(Text(index), invariant);

            matches.Add(new Match
            {
                Id = Text(0),
                Team = Text(1),
                Game = Text(2),
                Date = Text(3),
                Home = Text(4),
                Opponent = Text(5),
                WINorLOSS = Text(6),
                TeamPoints = Number(7),
                OpponentPoints = Number(8),
                FieldGoals = Number(9),
                FieldGoalsAttempted = Number(10),
                FieldGoals_ = Number(11),
                X3PointShots = Number(12),
                X3PointShotsAttempted = Number(13),
                X3PointShots_ = Number(14),
                FreeThrows = Number(15),
                FreeThrowsAttempted = Number(16),
                FreeThrows_ = Number(17),
                OffRebounds = Number(18),
                TotalRebounds = Number(19),
                Assists = Number(20),
                Steals = Number(21),
                Blocks = Number(22),
                Turnovers = Number(23),
                TotalFouls = Number(24),
                Opp_FieldGoals = Number(25),
                Opp_FieldGoalsAttempted = Number(26),
                Opp_FieldGoals_ = Number(27),
                Opp_3PointShots = Number(28),
                Opp_3PointShotsAttempted = Number(29),
                Opp_3PointShots_ = Number(30),
                Opp_FreeThrows = Number(31),
                Opp_FreeThrowsAttempted = Number(32),
                Opp_FreeThrows_ = Number(33),
                Opp_OffRebounds = Number(34),
                Opp_TotalRebounds = Number(35),
                Opp_Assists = Number(36),
                Opp_Steals = Number(37),
                Opp_Blocks = Number(38),
                Opp_Turnovers = Number(39),
                Opp_TotalFouls = Number(40)
            });
        }
    }
    return matches;
}
訓練、評估、預測部分
/// <summary>
/// Builds and trains a Poisson regression pipeline over <paramref name="trainData"/>.
/// </summary>
/// <remarks>
/// Pipeline steps, in order: load the in-memory collection, drop the "Id" column, one-hot
/// encode the text columns, concatenate every input column into "Features", then fit a
/// <c>PoissonRegressor</c>.
/// NOTE(review): the feature set includes post-game values such as WINorLOSS and
/// OpponentPoints; confirm this is intended for the prediction scenario.
/// </remarks>
/// <param name="trainData">Rows used for fitting.</param>
/// <returns>The trained prediction model.</returns>
static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
{
    string[] categoricalColumns =
    {
        "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"
    };
    string[] numericColumns =
    {
        "OpponentPoints",
        "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
        "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
        "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
        "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
        "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
        "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
        "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
        "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals", "Opp_Blocks",
        "Opp_Turnovers", "Opp_TotalFouls"
    };
    // Categorical columns first, then numeric ones: the same order as the feature vector.
    string[] featureColumns = categoricalColumns.Concat(numericColumns).ToArray();

    var pipeline = new LearningPipeline
    {
        CollectionDataSource.Create(trainData),
        new ColumnDropper { Column = new[] { "Id" } },
        new CategoricalOneHotVectorizer(categoricalColumns),
        new ColumnConcatenator("Features", featureColumns),
        new PoissonRegressor()
    };

    return pipeline.Train<Match, MatchPrediction>();
}
/// <summary>
/// Evaluates <paramref name="model"/> on <paramref name="evaluateData"/> with a
/// <c>RegressionEvaluator</c> and prints the LossFn, RSquared and Rms metrics to the console.
/// </summary>
/// <param name="model">The trained model to score.</param>
/// <param name="evaluateData">Rows held out for evaluation.</param>
static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
{
    var testSet = CollectionDataSource.Create(evaluateData);
    var metrics = new RegressionEvaluator().Evaluate(model, testSet);

    Console.WriteLine("LossFn: {0}", metrics.LossFn);
    Console.WriteLine("RSquared: {0}", metrics.RSquared);
    Console.WriteLine("Rms: {0}", metrics.Rms);
}
/// <summary>
/// Runs <paramref name="model"/> over <paramref name="predictData"/> and prints, for each row,
/// the actual score next to the predicted <c>TeamPoints</c>.
/// </summary>
/// <param name="model">The trained model.</param>
/// <param name="predictData">Rows to predict; enumerated exactly once.</param>
static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
{
    // Materialize once. The input may be a deferred LINQ query, and it is needed both as model
    // input and for printing; enumerating it twice repeats the work and could pair rows with
    // the wrong predictions if the source does not replay identically.
    var matches = predictData.ToList();
    var predictions = model.Predict(matches);

    foreach (var (match, prediction) in matches.Zip(predictions, (m, p) => (m, p)))
    {
        Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
            match.Date, match.Team, match.Opponent, match.TeamPoints, match.OpponentPoints, prediction.TeamPoints);
    }
}
最后是Main調用部分
/// <summary>
/// Entry point: loads the data set, splits it 70% / 20% / remainder in file order, then
/// trains on the first part, evaluates on the second and prints predictions for the rest.
/// </summary>
static void Main(string[] args)
{
    var allMatches = LoadData();
    int total = allMatches.Count;

    int trainSize = Convert.ToInt32(total * 0.7);
    int evaluateSize = Convert.ToInt32(total * 0.2);
    int holdoutStart = trainSize + evaluateSize;

    var model = Train(allMatches.Take(trainSize));
    Evaluate(model, allMatches.Skip(trainSize).Take(evaluateSize));
    Predict(model, allMatches.Skip(holdoutStart));
}
執行結果
結尾
可以看到,最近的NBA比賽主隊預測得分與真實結果對比,正確率已相當可觀了,由於特征值都是比賽技術數據,用在以后的比賽時,可根據比賽進行的實時情況不斷更新,便可越來越接近結果。
對球迷來說這可是一件神器呀。想想2018世界杯也馬上要開始了,保羅、阿喀琉斯什么的都弱爆了,相信小伙伴們也要嘗試一下ML.NET的套路了吧,記得拿到歷年完整的數據喲!
完整代碼如下:
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Models;
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Learners;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
namespace NBAPrediction
{
class Program
{
const string DATA_PATH = "data/nba.games.stats.csv";

/// <summary>
/// Reads the game stats CSV at <c>DATA_PATH</c> into a list of <c>Match</c> records.
/// </summary>
/// <remarks>
/// The first line is treated as a header and skipped; blank lines are ignored. Each remaining
/// line is split on commas and surrounding double quotes are stripped from every field, so
/// quoted fields that themselves contain a comma are not supported. Numeric fields are parsed
/// with the invariant culture so that values such as ".500" load identically regardless of the
/// machine's regional settings.
/// </remarks>
/// <returns>All data rows, in file order.</returns>
/// <exception cref="InvalidDataException">A row has fewer than 41 fields.</exception>
/// <exception cref="FormatException">A numeric field cannot be parsed.</exception>
static ICollection<Match> LoadData()
{
    // Indices 0..40 are read below.
    const int ExpectedFieldCount = 41;
    var invariant = System.Globalization.CultureInfo.InvariantCulture;

    var matches = new List<Match>();
    using (var sr = new StreamReader(File.OpenRead(DATA_PATH)))
    {
        sr.ReadLine(); // skip the header row
        var lineNumber = 1;
        string line;
        while ((line = sr.ReadLine()) != null)
        {
            lineNumber++;
            if (string.IsNullOrWhiteSpace(line))
            {
                // A trailing newline or stray blank line is not a data row.
                continue;
            }

            var values = line.Split(',');
            if (values.Length < ExpectedFieldCount)
            {
                throw new InvalidDataException(
                    $"{DATA_PATH}, line {lineNumber}: expected at least {ExpectedFieldCount} fields but found {values.Length}.");
            }

            // Field accessors: strip the CSV quoting, and parse numbers culture-independently.
            string Text(int index) => values[index].Trim('"');
            float Number(int index) => Convert.ToSingle(Text(index), invariant);

            matches.Add(new Match
            {
                Id = Text(0),
                Team = Text(1),
                Game = Text(2),
                Date = Text(3),
                Home = Text(4),
                Opponent = Text(5),
                WINorLOSS = Text(6),
                TeamPoints = Number(7),
                OpponentPoints = Number(8),
                FieldGoals = Number(9),
                FieldGoalsAttempted = Number(10),
                FieldGoals_ = Number(11),
                X3PointShots = Number(12),
                X3PointShotsAttempted = Number(13),
                X3PointShots_ = Number(14),
                FreeThrows = Number(15),
                FreeThrowsAttempted = Number(16),
                FreeThrows_ = Number(17),
                OffRebounds = Number(18),
                TotalRebounds = Number(19),
                Assists = Number(20),
                Steals = Number(21),
                Blocks = Number(22),
                Turnovers = Number(23),
                TotalFouls = Number(24),
                Opp_FieldGoals = Number(25),
                Opp_FieldGoalsAttempted = Number(26),
                Opp_FieldGoals_ = Number(27),
                Opp_3PointShots = Number(28),
                Opp_3PointShotsAttempted = Number(29),
                Opp_3PointShots_ = Number(30),
                Opp_FreeThrows = Number(31),
                Opp_FreeThrowsAttempted = Number(32),
                Opp_FreeThrows_ = Number(33),
                Opp_OffRebounds = Number(34),
                Opp_TotalRebounds = Number(35),
                Opp_Assists = Number(36),
                Opp_Steals = Number(37),
                Opp_Blocks = Number(38),
                Opp_Turnovers = Number(39),
                Opp_TotalFouls = Number(40)
            });
        }
    }
    return matches;
}
/// <summary>
/// Builds and trains a Poisson regression pipeline over <paramref name="trainData"/>.
/// </summary>
/// <remarks>
/// Pipeline steps, in order: load the in-memory collection, drop the "Id" column, one-hot
/// encode the text columns, concatenate every input column into "Features", then fit a
/// <c>PoissonRegressor</c>.
/// NOTE(review): the feature set includes post-game values such as WINorLOSS and
/// OpponentPoints; confirm this is intended for the prediction scenario.
/// </remarks>
/// <param name="trainData">Rows used for fitting.</param>
/// <returns>The trained prediction model.</returns>
static PredictionModel<Match, MatchPrediction> Train(IEnumerable<Match> trainData)
{
    string[] categoricalColumns =
    {
        "Team", "Game", "Date", "Home", "Opponent", "WINorLOSS"
    };
    string[] numericColumns =
    {
        "OpponentPoints",
        "FieldGoals", "FieldGoalsAttempted", "FieldGoals_",
        "X3PointShots", "X3PointShotsAttempted", "X3PointShots_",
        "FreeThrows", "FreeThrowsAttempted", "FreeThrows_",
        "OffRebounds", "TotalRebounds", "Assists", "Steals", "Blocks", "Turnovers", "TotalFouls",
        "Opp_FieldGoals", "Opp_FieldGoalsAttempted", "Opp_FieldGoals_",
        "Opp_3PointShots", "Opp_3PointShotsAttempted", "Opp_3PointShots_",
        "Opp_FreeThrows", "Opp_FreeThrowsAttempted", "Opp_FreeThrows_",
        "Opp_OffRebounds", "Opp_TotalRebounds", "Opp_Assists", "Opp_Steals", "Opp_Blocks",
        "Opp_Turnovers", "Opp_TotalFouls"
    };
    // Categorical columns first, then numeric ones: the same order as the feature vector.
    string[] featureColumns = categoricalColumns.Concat(numericColumns).ToArray();

    var pipeline = new LearningPipeline
    {
        CollectionDataSource.Create(trainData),
        new ColumnDropper { Column = new[] { "Id" } },
        new CategoricalOneHotVectorizer(categoricalColumns),
        new ColumnConcatenator("Features", featureColumns),
        new PoissonRegressor()
    };

    return pipeline.Train<Match, MatchPrediction>();
}
/// <summary>
/// Evaluates <paramref name="model"/> on <paramref name="evaluateData"/> with a
/// <c>RegressionEvaluator</c> and prints the LossFn, RSquared and Rms metrics to the console.
/// </summary>
/// <param name="model">The trained model to score.</param>
/// <param name="evaluateData">Rows held out for evaluation.</param>
static void Evaluate(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> evaluateData)
{
    var testSet = CollectionDataSource.Create(evaluateData);
    var metrics = new RegressionEvaluator().Evaluate(model, testSet);

    Console.WriteLine("LossFn: {0}", metrics.LossFn);
    Console.WriteLine("RSquared: {0}", metrics.RSquared);
    Console.WriteLine("Rms: {0}", metrics.Rms);
}
/// <summary>
/// Runs <paramref name="model"/> over <paramref name="predictData"/> and prints, for each row,
/// the actual score next to the predicted <c>TeamPoints</c>.
/// </summary>
/// <param name="model">The trained model.</param>
/// <param name="predictData">Rows to predict; enumerated exactly once.</param>
static void Predict(PredictionModel<Match, MatchPrediction> model, IEnumerable<Match> predictData)
{
    // Materialize once. The input may be a deferred LINQ query, and it is needed both as model
    // input and for printing; enumerating it twice repeats the work and could pair rows with
    // the wrong predictions if the source does not replay identically.
    var matches = predictData.ToList();
    var predictions = model.Predict(matches);

    foreach (var (match, prediction) in matches.Zip(predictions, (m, p) => (m, p)))
    {
        Console.WriteLine("Date: {0}, Team: {1} Opponent: {2}, Score: {3}-{4}, Predict Home Score: {5}",
            match.Date, match.Team, match.Opponent, match.TeamPoints, match.OpponentPoints, prediction.TeamPoints);
    }
}
/// <summary>
/// Entry point: loads the data set, splits it 70% / 20% / remainder in file order, then
/// trains on the first part, evaluates on the second and prints predictions for the rest.
/// </summary>
static void Main(string[] args)
{
    var allMatches = LoadData();
    int total = allMatches.Count;

    int trainSize = Convert.ToInt32(total * 0.7);
    int evaluateSize = Convert.ToInt32(total * 0.2);
    int holdoutStart = trainSize + evaluateSize;

    var model = Train(allMatches.Take(trainSize));
    Evaluate(model, allMatches.Skip(trainSize).Take(evaluateSize));
    Predict(model, allMatches.Skip(holdoutStart));
}
}
/// <summary>
/// One data row of the NBA team game stats CSV.
/// </summary>
/// <remarks>
/// Fields are declared in CSV column order, and the ordinal in each <c>[Column]</c> attribute
/// equals the index that <c>LoadData</c> reads the value from. Fields whose names end in
/// <c>_</c> correspond to CSV headers ending in "." (for example "FieldGoals."), and the
/// <c>Opp_</c> prefix corresponds to the "Opp." header prefix.
/// </remarks>
public class Match
{
// --- Game identification (kept as text) ---
[Column(ordinal: "0")]
public string Id;
[Column(ordinal: "1")]
public string Team;
[Column(ordinal: "2")]
public string Game;
[Column(ordinal: "3")]
public string Date;
// Holds "Home" or "Away" in the sample rows.
[Column(ordinal: "4")]
public string Home;
[Column(ordinal: "5")]
public string Opponent;
// Holds "W" or "L" in the sample rows.
[Column(ordinal: "6")]
public string WINorLOSS;
// Exposed under the column name "Label"; this is the value the example predicts.
[Column(ordinal: "7", name: "Label")]
public float TeamPoints;
[Column(ordinal: "8")]
public float OpponentPoints;
// --- Statistics for Team ---
[Column(ordinal: "9")]
public float FieldGoals;
[Column(ordinal: "10")]
public float FieldGoalsAttempted;
[Column(ordinal: "11")]
public float FieldGoals_;
[Column(ordinal: "12")]
public float X3PointShots;
[Column(ordinal: "13")]
public float X3PointShotsAttempted;
[Column(ordinal: "14")]
public float X3PointShots_;
[Column(ordinal: "15")]
public float FreeThrows;
[Column(ordinal: "16")]
public float FreeThrowsAttempted;
[Column(ordinal: "17")]
public float FreeThrows_;
[Column(ordinal: "18")]
public float OffRebounds;
[Column(ordinal: "19")]
public float TotalRebounds;
[Column(ordinal: "20")]
public float Assists;
[Column(ordinal: "21")]
public float Steals;
[Column(ordinal: "22")]
public float Blocks;
[Column(ordinal: "23")]
public float Turnovers;
[Column(ordinal: "24")]
public float TotalFouls;
// --- Statistics for the opponent ---
[Column(ordinal: "25")]
public float Opp_FieldGoals;
[Column(ordinal: "26")]
public float Opp_FieldGoalsAttempted;
[Column(ordinal: "27")]
public float Opp_FieldGoals_;
[Column(ordinal: "28")]
public float Opp_3PointShots;
[Column(ordinal: "29")]
public float Opp_3PointShotsAttempted;
[Column(ordinal: "30")]
public float Opp_3PointShots_;
[Column(ordinal: "31")]
public float Opp_FreeThrows;
[Column(ordinal: "32")]
public float Opp_FreeThrowsAttempted;
[Column(ordinal: "33")]
public float Opp_FreeThrows_;
[Column(ordinal: "34")]
public float Opp_OffRebounds;
[Column(ordinal: "35")]
public float Opp_TotalRebounds;
[Column(ordinal: "36")]
public float Opp_Assists;
[Column(ordinal: "37")]
public float Opp_Steals;
[Column(ordinal: "38")]
public float Opp_Blocks;
[Column(ordinal: "39")]
public float Opp_Turnovers;
[Column(ordinal: "40")]
public float Opp_TotalFouls;
}
/// <summary>
/// Output type of the trained model: the predicted value for <c>Match.TeamPoints</c>.
/// </summary>
public class MatchPrediction
{
// Bound to the column named "Score" (presumably the regressor's output column; confirm against the ML.NET version in use).
[ColumnName("Score")]
public float TeamPoints;
}
}