XGBoost判斷蘑菇是否有毒示例


數據文件說明

本示例的數據集文件可以在https://github.com/dmlc/xgboost/tree/master/demo/data這里獲得。
該數據集描述的是不同蘑菇的相關特征,如大小、顏色等,並且每一種蘑菇都會被標記為可食用的(標記為0)或有毒的(標記為1)。

LibSVM 格式說明

這個數據是LibSVM格式的

LibSVM 使用的訓練數據和檢驗數據文件格式如下:

[label] [index1]:[value1] [index2]:[value2] …
[label] [index1]:[value1] [index2]:[value2] …

label 目標值,就是說class(屬於哪一類),就是你要分類的種類,通常是一些整數。

index 是有順序的索引,通常是連續的整數。就是指特征編號,必須按照升序排列

value 就是特征值,用來train的數據,通常是一堆實數組成。

格式特征:

  • 每行包含一個實例,並以“ \ n”字符結尾。
  • 對於分類,
  • 是一個從1開始的整數, 是一個實數。唯一的例外是預先計算的內核, 從0開始;

參考:libsvm的數據格式及制作

我們這個例子中的數據文件

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
0 1:1 10:1 19:1 21:1 24:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 122:1
1 3:1 9:1 19:1 21:1 30:1 34:1 36:1 40:1 42:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 14:1 22:1 29:1 34:1 37:1 39:1 41:1 54:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 98:1 106:1 114:1 120:1
0 3:1 9:1 20:1 21:1 23:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 116:1 120:1

每一行的label值,標記該蘑菇可食用的(標記為0)或有毒的(標記為1)。

數據源說明

這個例子的數據源來自:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/
數據中包括蘑菇對形狀、顏色等特征,以及是否有毒的標簽。

agaricus-lepiota.data

原始數據存放在agaricus-lepiota.data里,內容如下所示。它有23列,其中第一列是標簽列,p表示有毒,e表示沒有毒。后面的22列是22個特征對應的特征值。

agaricus-lepiota.names

agaricus-lepiota.names 文件里存放特征映射關系,比如蘑菇頭形狀(cap-shap)為鍾型(bell)的用b表示,圓錐型(conical)的用c表示;蘑菇頭顏色(cap-color)為棕色(brown)的用n表示,淺黃色(buff)的用b表示,等等。總共22個特征映射,對應agaricus-lepiota.data里的第1~22列(第0列為標簽)。

數據准備

這里我們已經把這個數據變化成了LibSVM格式。

另外我們還把數據隨機分成訓練集(agaricus.txt.train)和測試集(agaricus.txt.test)兩部分,80%的數據分配給訓練集,20%分配給測試集。

參考:xgboost小試

訓練模型

我們的任務是對蘑菇特征數據進行學習,訓練相關模型,然后利用訓練好的模型預測未知的蘑菇樣本是否有毒。

import xgboost as xgb

# 數據讀取
xgb_train = xgb.DMatrix('./agaricus.txt.train')
xgb_test = xgb.DMatrix('./agaricus.txt.test')

# 定義模型訓練參數
params = {
    "objective":"binary:logistic",
    "booster":"gbtree",
    "max_depth":3
}

# 訓練輪數
num_round = 5

# 訓練過程中實時輸出評估結果
watchlist = [(xgb_train,'train'),(xgb_test,'test')]

# 模型訓練
model = xgb.train(params,xgb_train,num_round,watchlist)

輸出結果

% python ./xgb20.py 
[19:25:53] WARNING: /opt/concourse/worker/volumes/live/7a2b9f41-3287-451b-6691-43e9a6c0910f/volume/xgboost-split_1619728204606/work/src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
[0]     train-logloss:0.45224   test-logloss:0.45317
[1]     train-logloss:0.32281   test-logloss:0.32412
[2]     train-logloss:0.23637   test-logloss:0.23739
[3]     train-logloss:0.16933   test-logloss:0.16935
[4]     train-logloss:0.12386   test-logloss:0.12352

XGBoost訓練過程中實時輸出了訓練集和測試集的錯誤率評估結果。隨着訓練的進行,訓練集和測試集的錯誤率均在不斷下降,說明模型對於特征數據的學習是十分有效的。

參數說明

  • "objective":"binary:logistic" objective 該參數用來指定目標函數,XGBoost可以根據該參數判斷進行何種學習任務,binary:logistic和binary:logitraw都表示學習任務類型為二分類。binary:logistic輸出為概率,binary:logitraw輸出為邏輯轉換前的輸出分數。
  • booster為gbtree表示采用XGBoost中的樹模型。
  • 參數max_depth表示決策樹分裂的最大深度。

預測


# 對測試集進行預測
preds = model.predict(xgb_test)

參考:


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM