原文鏈接:http://tecdat.cn/?p=9326
在這篇文章中,我將使用python中的決策樹(用於分類)。重點將放在基礎知識和對最終決策樹的理解上。
導入
因此,首先我們進行一些導入。
數據
接下來,我們需要考慮一些數據。我將使用著名的iris數據集,該數據集可對各種不同的iris類型進行各種測量。pandas和sckit-learn都可以輕松導入這些數據,我將使用pandas編寫一個從csv文件導入的函數。這樣做的目的是演示如何將scikit-learn與pandas一起使用。因此,我們定義了一個獲取iris數據的函數:
- 此函數首先嘗試在本地讀取數據。利用os.path.exists() 方法。如果在本地目錄中找到iris.csv文件,則使用pandas通過pd.read_csv()讀取文件。
- 如果本地iris.csv沒有發現,抓取URL數據來運行。
下一步是獲取數據,並使用head()和tail()方法查看數據的樣子。因此,首先獲取數據:
然后 :
從這些信息中,我們可以討論我們的目標:給定特征SepalLength, SepalWidth, PetalLength 和PetalWidth來預測iris類型。
預處理
為了將這些數據傳遞到scikit-learn,我們需要將Names編碼為整數。為此,我們將編寫另一個函數,並返回修改后的數據框以及目標(類)名稱的列表:
讓我們看看有什么:
接下來,我們獲得列的名稱:
用scikit-learn擬合決策樹
現在,我們可以使用 上面導入的DecisionTreeClassifier擬合決策樹,如下所示:
- 我們使用簡單的索引從數據框中提取X和y數據。
- 開始時導入的決策樹用兩個參數初始化:min_samples_split = 20需要一個節點中的20個樣本才能拆分,並且 random_state = 99進行種子隨機數生成器。
可視化樹
我們可以使用以下功能生成圖形:
- 從上面的scikit-learn導入的export_graphviz方法寫入一個點文件。此文件用於生成圖形。
- 生成圖形 dt.png。
運行函數:
結果
我們可以使用此圖來了解決策樹發現的模式:
- 所有數據(所有行)都從樹頂部開始。
- 考慮了所有功能,以了解如何以最有用的方式拆分數據-默認情況下使用基尼度量。
- 在頂部,我們看到最有用的條件是 PetalLength <= 2.4500。
- 這種分裂一直持續到
- 拆分后僅具有一個類別。
- 或者,結果中的樣本少於20個。
決策樹的偽代碼
最后,我們考慮生成代表學習的決策樹的偽代碼。
- 目標名稱可以傳遞給函數,並包含在輸出中。
- 使用spacer_base 參數,使輸出更容易閱讀。
應用於iris數據的結果輸出為:
將其與上面的圖形輸出進行比較-這只是決策樹的不同表示。
在python中進行決策樹交叉驗證
導入
首先,我們導入所有代碼:
主要添加的內容是sklearn.grid_search中的方法,它們可以:
- 時間搜索
- 使用itemgetter對結果進行排序
- 使用scipy.stats.randint生成隨機整數。
現在我們可以開始編寫函數了。
包括:
- get_code –為決策樹編寫偽代碼,
- visualize_tree –生成決策樹的圖形。
- encode_target –處理原始數據以與scikit-learn一起使用。
- get_iris_data –如果需要,從網絡上獲取 iris.csv,並將副本寫入本地目錄。
新功能
接下來,我們添加一些新功能來進行網格和隨機搜索,並報告找到的主要參數。首先是報告。此功能從網格或隨機搜索中獲取輸出,打印模型的報告並返回最佳參數設置。
網格搜索
接下來是run_gridsearch。該功能需要
- 特征X,
- 目標y,
- (決策樹)分類器clf,
- 嘗試參數字典的param_grid
- 交叉驗證cv的倍數,默認為5。
param_grid是一組參數,這將是作測試,要注意不要列表中有太多的選擇。
隨機搜尋
接下來是run_randomsearch函數,該函數從指定的列表或分布中采樣參數。與網格搜索類似,參數為:
- 功能X
- 目標y
- (決策樹)分類器clf
- 交叉驗證cv的倍數,默認為5
- n_iter_search的隨機參數設置數目,默認為20。
好的,我們已經定義了所有函數。
交叉驗證
獲取數據
接下來,讓我們使用上面設置的搜索方法來找到合適的參數設置。首先進行一些初步准備-獲取數據並構建目標數據:
第一次交叉驗證
在下面的所有示例中,我將使用10倍交叉驗證。
- 將數據分為10部分
- 擬合9個部分
- 其余部分的測試准確性
使用當前參數設置,在所有組合上重復此操作,以產生十個模型精度估計。通常會報告十個評分的平均值和標准偏差。
0.960還不錯。這意味着平均准確性(使用經過訓練的模型進行正確分類的百分比)為96%。該精度非常高,但是讓我們看看是否可以找到更好的參數。
網格搜索的應用
首先,我將嘗試網格搜索。字典para_grid提供了要測試的不同參數設置。
在大多數運行中,各種參數設置的平均值為0.967。這意味着從96%改善到96.7%!我們可以看到最佳的參數設置ts_gs,如下所示:
並復制交叉驗證結果:
接下來,讓我們使用獲取最佳樹的偽代碼:
我們還可以制作決策樹的圖形:
隨機搜索的應用
接下來,我們嘗試使用隨機搜索方法來查找參數。在此示例中,我使用288個樣本,以便測試的參數設置數量與上面的網格搜索相同:
與網格搜索一樣,這通常會找到平均精度為0.967或96.7%的多個參數設置。如上所述,最佳交叉驗證的參數為:
並且,我們可以再次測試最佳參數:
要查看決策樹是什么樣的,我們可以生成偽代碼以獲得最佳隨機搜索結果
並可視化樹
結論
因此,我們使用了帶有交叉驗證的網格和隨機搜索來調整決策樹的參數。在這兩種情況下,從96%到96.7%的改善都很小。當然,在更復雜的問題中,這種影響會更大。最后幾點注意事項:
- 通過交叉驗證搜索找到最佳參數設置后,通常使用找到的最佳參數對所有數據進行訓練。
- 傳統觀點認為,對於實際應用而言,隨機搜索比網格搜索更有效。網格搜索確實花費的時間太長,這當然是有意義的。
- 此處開發的基本交叉驗證想法可以應用於許多其他scikit學習模型-隨機森林,邏輯回歸,SVM等。
如果您有任何疑問,請在下面發表評論。