一、導入必要的工具包
# 導入必要的工具包
import xgboost as xgb
# 計算分類正確率
from sklearn.metrics import accuracy_score
二、數據讀取
XGBoost可以加載libsvm格式的文本數據,libsvm的文件格式(稀疏特征)如下:
1 101:1.2 102:0.03
0 1:2.1 10001:300 10002:400
...
每一行表示一個樣本,第一行的開頭的“1”是樣本的標簽。“101”和“102”為特征索引,'1.2'和'0.03' 為特征的值。
在兩類分類中,用“1”表示正樣本,用“0” 表示負樣本。也支持[0,1]表示概率用來做標簽,表示為正樣本的概率。
下面的示例數據需要我們通過一些蘑菇的若干屬性判斷這個品種是否有毒。
UCI數據描述:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/ ,
每個樣本描述了蘑菇的22個屬性,比如形狀、氣味等等(將22維原始特征用加工后變成了126維特征,
並存為libsvm格式),然后給出了這個蘑菇是否可食用。其中6513個樣本做訓練,1611個樣本做測試。
注:libsvm格式文件說明如下 https://www.cnblogs.com/codingmengmeng/p/6254325.html
XGBoost加載的數據存儲在對象DMatrix中
XGBoost自定義了一個數據矩陣類DMatrix,優化了存儲和運算速度
DMatrix文檔:http://xgboost.readthedocs.io/en/latest/python/python_api.html
數據下載地址:http://download.csdn.net/download/u011630575/10266113
# read in data,數據在xgboost安裝的路徑下的demo目錄,現在我們將其copy到當前代碼下的data目錄
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')
查看數據情況
dtrain.num_col()
dtrain.num_row()
dtest.num_row()
三、訓練參數設置
max_depth: 樹的最大深度。缺省值為6,取值范圍為:[1,∞]
eta:為了防止過擬合,更新過程中用到的收縮步長。在每次提升計算之后,算法會直接獲得新特征的權重。
eta通過縮減特征的權重使提升計算過程更加保守。缺省值為0.3,取值范圍為:[0,1]
silent:取0時表示打印出運行時信息,取1時表示以緘默方式運行,不打印運行時信息。缺省值為0
objective: 定義學習任務及相應的學習目標,“binary:logistic” 表示二分類的邏輯回歸問題,輸出為概率。
其他參數取默認值。
# specify parameters via map
param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }
print(param)
四、訓練模型
# 設置boosting迭代計算次數
num_round = 2
import time
starttime = time.clock()
bst = xgb.train(param, dtrain, num_round) # dtrain是訓練數據集
endtime = time.clock()
print (endtime - starttime)
XGBoost預測的輸出是概率。這里蘑菇分類是一個二類分類問題,輸出值是樣本為第一類的概率。
我們需要將概率值轉換為0或1。
train_preds = bst.predict(dtrain)
train_predictions = [round(value) for value in train_preds]
y_train = dtrain.get_label() #值為輸入數據的第一行
train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))
五、測試
模型訓練好后,可以用訓練好的模型對測試數據進行預測
# make prediction
preds = bst.predict(dtest)
檢查模型在測試集上的正確率
XGBoost預測的輸出是概率,輸出值是樣本為第一類的概率。我們需要將概率值轉換為0或1。
predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
六、模型可視化
調用XGBoost工具包中的plot_tree,在顯示
要可視化模型需要安裝graphviz軟件包
plot_tree()的三個參數:
1. 模型
2. 樹的索引,從0開始
3. 顯示方向,缺省為豎直,‘LR'是水平方向
from matplotlib import pyplot
import graphviz
xgb.plot_tree(bst, num_trees=0, rankdir= 'LR' )
pyplot.show()
#xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )
#pyplot.show()
#xgb.to_graphviz(bst,num_trees=0)
#xgb.to_graphviz(bst,num_trees=1)
七、代碼整理
# coding:utf-8
import xgboost as xgb
# 計算分類正確率
from sklearn.metrics import accuracy_score
# read in data,數據在xgboost安裝的路徑下的demo目錄,現在我們將其copy到當前代碼下的data目錄
my_workpath = './data/'
dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')
dtrain.num_col()
dtrain.num_row()
dtest.num_row()
# specify parameters via map
param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }
print(param)
# 設置boosting迭代計算次數
num_round = 2
import time
starttime = time.clock()
bst = xgb.train(param, dtrain, num_round) # dtrain是訓練數據集
endtime = time.clock()
print (endtime - starttime)
train_preds = bst.predict(dtrain) #
print ("train_preds",train_preds)
train_predictions = [round(value) for value in train_preds]
print ("train_predictions",train_predictions)
y_train = dtrain.get_label()
print ("y_train",y_train)
train_accuracy = accuracy_score(y_train, train_predictions)
print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))
# make prediction
preds = bst.predict(dtest)
predictions = [round(value) for value in preds]
y_test = dtest.get_label()
test_accuracy = accuracy_score(y_test, predictions)
print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
# from matplotlib import pyplot
# import graphviz
import graphviz
# xgb.plot_tree(bst, num_trees=0, rankdir='LR')
# pyplot.show()
# xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )
# pyplot.show()
# xgb.to_graphviz(bst,num_trees=0)
# xgb.to_graphviz(bst,num_trees=1)