決策樹與隨機森林分類算法(Python實現)


一、原理:

決策樹:能夠利用一些決策結點,使數據根據決策屬性進行路徑選擇,達到分類的目的。

一般決策樹常用於DFS配合剪枝,被用於處理一些單一算法問題,但也能進行分類 。

也就是通過每一個結點的決策進行分類,那么關於如何設置這些結點的決策方式:

熵:描述一個集合內元素混亂程度的因素。

熵的衡量公式

公式中的熵值 Entropy 會隨着集合中類別數量增加而快速增加,也就是說一個集合中類別越少,那么它的熵就小,整體就越穩定。

對於一個標記數據集,要合理的建立一棵決策樹,就需要合理的決定決策結點來使決策樹盡快的降低熵值。

如何選擇合適的決策:

(1)信息增溢 

對於當前的集合,對每一個決策屬性都嘗試設置為決策結點的目標,計算決策分類前的熵值 與 分類后的所有子集的熵值 的差。選擇最大的,作為當前的決策目標。

此方式有一些確定,就是當面對一些決策變量的分類子集很多,而子集卻很小的情況。這次辦法雖然會很快的降低熵,但這並不是我們想要的。

(2)信息增溢率

這是對熵增溢的一種改進,把原本的前后熵值的差,增加: 

決策分類前屬性的熵和 與 決策分類后的的熵 的比值,如果比值很小,說明分類分很多,損失值就會很大。

(3)gini系數: Gini = 1-\sum_{i=1}^{n}p^{2} (i) 

gini系數和信息增溢率比較像

決策樹的剪枝 :

預剪枝:設置max_depth來達到建樹過程中的剪枝,表示樹的最大深度

后剪枝:通過min_sample_split與min_sample_leaf來對已經建成的決策樹進行剪枝,分別是結點的元素個數與子樹的葉子結點個數

隨機森林 :

構建多個決策樹,從而得到更加符合期望的一些決策結果。以森林的結果眾數來表示結果。

往往采用生成子數據集,取60%隨機生成數據集

交叉驗證: 

幾折交叉驗證方式為,將訓練數據進行幾次對折,取一部分作為測試集,其他作為訓練集。並將每個部分輪流作為測試集,最后得到一個平均評分。 

網格超參數調優:

對分類器的參數進行調優評價,最后得到一個最優的參數組,並作為最終的分類器的參數。

二、實現 :

數據集:威斯康辛州乳腺癌數據集

import pandas as pd
df = pd.read_csv('文件所在路徑\\breast_cancer.csv',encoding='gbk')
df.head()
df.res.value_counts()
y=df.res
y.head()
df=df.drop(index=0)#修正數據集
x=df.drop('res',axis=1)#去掉標簽

數據標簽分布較為均衡

#導入決策樹
from sklearn.tree import DecisionTreeClassifier
#導入隨機森林
from sklearn.ensemble import RandomForestClassifier
#導入集合分割,交叉驗證,網格搜索
from sklearn.model_selection import train_test_split,cross_val_score,GridSearchCV
seed=5#隨機種子
#分割訓練集與測試集
xtrain,xtest,ytrain,ytest=train_test_split(x,y,test_size=0.3,random_state=seed)
#實例化隨機森林
rfc=RandomForestClassifier()
#訓練
rfc=rfc.fit(xtrain,ytrain)
測試評估
result=rfc.score(xtest,ytest)

print('所有樹:%s'%rfc.estimators_)
print(rfc.classes_)
print(rfc.n_classes)
print('判定結果:%s '%rfc.predict(xtest))
print('判定結果:%s'%rfc.predict_proba(xtest)[:,:])
print('判定結果:%s '%rfc.predict_proba(xtest)[:,1])
#d1與d2結果相同
d1=np.array(pd.Series(rfc.predict_proba(xtest)[:,1]>0.5).map({False:0,True:1}))
d2=rfc.predict(xtest)
np.array_equal(d1,d2)
#導入評價模塊
from sklearn.metrics import roc_auc_score,roc_curve,auc
#准確率
roc_auc_score(ytest,rfc.predict_proba(xtest)[:,1])
#結果:0.9935171385991058
print('各個feature的重要性:%s '%rfc.feature_importances_)
std=np.std([tree.feature_importances_ for tree in rfc.estimators_],axis=0)
從大到小排序
indices = np.argsort(importances)[::-1]
print('Feature Ranking:')
for f in range(min(20,xtrain.shape[1])):
    print("%2d)%-*s %f"%(f+1, 30, xtrain.columns[indices[f]],importances[indices[f]]))

 

繪圖
#黑線是標准差
plt.figure()
plt.title("Feature importances")
plt.bar(range(xtrain.shap[1]), importances[indices], color='r', yerr=std[indices], align="center")
plt.xticks(range(xtrain.shap[1]), indices)
plt.xlim([-1, xtrain.shap[1]])
plt.show()

predictions_validation = rfc.predict_proba(xtest)[:,1]
fpr, tqr, _=roc_curve(ytest, predictions_validation)
roc_auc = auc(fpr, tqr)
plt.title('ROC Validation')
plt.plot(fpr, tqr, 'b', label='AUC = %0.2f'%roc_auc)
plt.legend(loc='lower right')
plt.plot([0, 1], [0, 1], 'r--')
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Position Rate')
plt.xlabel('False Postion Rate')
plt.show()

 

'''交叉驗證'''
'''
sklearn.model_selection.cross_val_score(estimator, X,yscoring=None, cv=None,\
                                        n_jobs=1,verbose=0,fit_params=None,pre_dispatch='2*n_jobs')
estimator:估計方法對象(分類器)
X:數據特征(Featrues)
y:數據標簽(Labels)
soring:調用方法(包括accuracy和mean_squared_error等等)
cv:幾折交叉驗證(樣本等分成幾個部分,輪流作為驗證集來驗證模型)
n_jobs:同時工作的cpu個數(-1 代表全部)
'''
#兩種分類器的比較
#決策樹
clf = DecisionTreeClassifier(max_depth=None,min_samples_split=2,random_state=0)
scores = cross_val_score(clf, xtrain, ytrain)
print(scores.mean())
#0.932157394843962
#隨機森林
clf = RandomForestClassifier()
scores = cross_val_score(clf, xtrain, ytrain)
print(scores.mean())
#0.9471958389868838

 

參數調優過程:

#參數調優
param_test1 = {'n_estimators':range(25,500,25)}
gsearch1 = GridSearchCV(estimator = RandomForestClassifier(min_samples_split=100,
                                                          min_samples_leaf=20,
                                                          max_depth=8,random_state=10),
                       param_grid = param_test1,
                       scoring='roc_auc',
                       cv = 5)
gsearch1.fit(xtrain, ytrain)
'''調優結果'''
print(gsearch1.best_params_,gsearch1.best_score_)

param_test2 = {'min_samples_split':range(60,200,20), 'min_samples_leaf':range(10,110,10)}
gsearch2 = GridSearchCV(estimator = RandomForestClassifier(n_estimators=50,
                                                          max_depth=8,random_state=10),
                       param_grid = param_test2,
                       scoring='roc_auc',
                       cv = 5)
gsearch2.fit(xtrain, ytrain)
'''調優結果'''
print(gsearch2.best_params_,gsearch2.best_score_)

param_test3 = {'max_depth':range(3,30,2)}
gsearch1 = GridSearchCV(estimator = RandomForestClassifier(min_samples_split=60,
                                                          min_samples_leaf=10,
                                                           n_estimators=50,
                                                          random_state=10),
                       param_grid = param_test3,
                       scoring='roc_auc',
                       cv = 5)
gsearch3.fit(xtrain, ytrain)
'''調優結果'''
print(gsearch3.best_params_,gsearch3.best_score_)

param_test4 = {'criterion':['gini','entropy'], 'class_weight':[None, 'balanced']}
gsearch4 = GridSearchCV(estimator = RandomForestClassifier(n_estimators=50,
                                                           min_samples_split=60,
                                                           min_samples_leaf=10,
                                                           max_depth=3,
                                                           random_state=10),
                       param_grid = param_test4,
                       scoring='roc_auc',
                       cv = 5)
gsearch4.fit(xtrain, ytrain)
'''調優結果'''
print(gsearch4.best_params_,gsearch4.best_score_)
#gini,None

#整合所有最優參數值,得到最優評分
best_score = roc_auc_score(ytest, gsearch4.best_estimator_.predict_proba(xtest)[:,1])
print(best_score)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM