scikit-learn 機器學習(四):使用決策樹做分類、畫出決策樹,並與隨機森林對比


數據來自 UCI 的皮馬印第安人糖尿病數據集(Pima Indians Diabetes Dataset)。

 

載入數據

# -*- coding: utf-8 -*-
import pandas as pd
import matplotlib
matplotlib.rcParams['font.sans-serif']=[u'simHei']
matplotlib.rcParams['axes.unicode_minus']=False
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV

from sklearn.datasets import load_breast_cancer

# Load the Pima Indians Diabetes data. read_csv treats the first row as the
# header; later cells use data_set.columns as feature names.
# NOTE(review): the raw UCI file has no header row. If yours doesn't either,
# pass header=None, or the first sample is silently consumed as column names.
data_set = pd.read_csv('pima-indians-diabetes.csv')
data = data_set.values

# Columns 0-7 are the features; column 8 is the 0/1 class label.
y = data[:, 8]
X = data[:, :8]

# Seed the split so every score below is reproducible, and stratify on y so
# train and test keep the same class ratio. Use the same random_state/stratify
# in the random-forest section so both models are scored on one test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, stratify=y)

 

建立決策樹,並用網格搜索微調模型參數

# In[1] Tune the decision tree with a grid search.
# The single-step Pipeline means every hyper-parameter name carries the step
# prefix "clf__".
pipeline = Pipeline([
        ('clf', DecisionTreeClassifier(criterion='entropy')),
        ])
parameters = {
        'clf__max_depth': (3, 5, 10, 15, 20, 25, 30, 35, 40),
        'clf__min_samples_split': (2, 3),
        'clf__min_samples_leaf': (1, 2, 3),
        }
# GridSearchCV tries every parameter combination and keeps the one with the
# best cross-validated F1 score (scoring='f1' treats label 1 as positive).
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=0, scoring='f1')
grid_search.fit(X_train, y_train)

# Best cross-validated score and the parameter values that achieved it.
best_parameters = grid_search.best_params_
print("最好的F1值為:", grid_search.best_score_)
print('最好的參數為:')
for param_name in sorted(parameters.keys()):
    # The tab escape was previously written as a bare 't', which printed
    # lines such as "tclf__max_depth: 5".
    print('\t%s: %r' % (param_name, best_parameters[param_name]))

# In[2] Predict on the held-out test set and evaluate.
predictions = grid_search.predict(X_test)
print(classification_report(y_test, predictions))

 

最好的F1值為: 0.5573515325670498
最好的參數為:
tclf__max_depth: 5
tclf__min_samples_leaf: 1
tclf__min_samples_split: 2

 

 

 

評價模型

# In[2] Score the tuned model on the held-out test split and print a
# per-class precision / recall / F1 report.
predictions = grid_search.predict(X=X_test)
print(classification_report(y_true=y_test, y_pred=predictions))

 

              precision    recall  f1-score   support

         0.0       0.74      0.89      0.81       124
         1.0       0.67      0.43      0.52        68

 

 

畫出決策樹

# In[3] Plot the decision tree.
from sklearn import tree

# Feature names for the plot: every column except the last one, which is the
# label (y = data[:, 8]).
feature_name = data_set.columns.values.tolist()[:-1]

# Plot the model the grid search actually selected and evaluated above.
# A separately hard-coded tree (min_samples_leaf=5) did not match the reported
# best parameters, so the drawn tree differed from the model whose scores were
# printed. With the default refit=True, GridSearchCV refits best_estimator_ on
# X_train, so DT is already fitted.
DT = grid_search.best_estimator_.named_steps['clf']

# export_graphviz pairs class_names with DT.classes_, which sklearn sorts in
# ascending order: [0.0, 1.0]. Label 0 means no diabetes and 1 means diabetes,
# so the names must be listed in that order. The reversed list labelled every
# leaf with the wrong class.
class_name = ["無病", "有糖尿病"]

# Method 1 (alternative, needs pydotplus). With out_file=None,
# export_graphviz returns the DOT source directly, so the removed
# sklearn.externals.six.StringIO buffer is not needed:
# import pydotplus
# dot_data = tree.export_graphviz(DT, out_file=None, feature_names=feature_name,
#                                 class_names=class_name, filled=True,
#                                 rounded=True, special_characters=True)
# graph = pydotplus.graph_from_dot_data(dot_data)
# graph.write_pdf("Tree.pdf")
# print('Visible tree plot saved as pdf.')

# Method 2: render with the graphviz package.
import graphviz
# The model must be fitted before it can be exported.
dot_data = tree.export_graphviz(DT, out_file=None, feature_names=feature_name,
                                class_names=class_name)
graph = graphviz.Source(dot_data)
# Writes the DOT source "tree2" and the rendered "tree2.pdf" to the current
# working directory.
graph.render("tree2")

 

 

隨機森林

 

# -*- coding: utf-8 -*-
import pandas as pd
import matplotlib
matplotlib.rcParams['font.sans-serif']=[u'simHei']
matplotlib.rcParams['axes.unicode_minus']=False
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier

from sklearn.datasets import load_breast_cancer

# Reload the data; columns 0-7 are the features and column 8 is the label.
data_set = pd.read_csv('pima-indians-diabetes.csv')
data = data_set.values

y = data[:, 8]
X = data[:, :8]

# Seed and stratify the split so results are reproducible. Use the same
# random_state/stratify in the decision-tree section so both models are
# scored on an identical test set. With unseeded splits, the two reports came
# from different test samples (supports 124/68 vs 126/66) and were not
# directly comparable.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, stratify=y)

# Random-forest comparison: 10 trees with a fixed seed for reproducibility.
RF = RandomForestClassifier(n_estimators=10, random_state=11)
RF.fit(X_train, y_train)
predictions = RF.predict(X_test)
print(classification_report(y_test, predictions))

 

              precision    recall  f1-score   support

         0.0       0.82      0.91      0.86       126
         1.0       0.78      0.61      0.68        66

   micro avg       0.81      0.81      0.81       192
   macro avg       0.80      0.76      0.77       192
weighted avg       0.80      0.81      0.80       192

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM