決策樹算法對鳶尾花數據集進行分類


①導入相關擴展包

from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz 

②獲取數據集

iris = load_iris()

③划分數據集

x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=20)

④決策樹預估器(estimator)

estimator=DecisionTreeClassifier(criterion="entropy")   #criterion默認為'gini'系數,也可選擇信息增益熵'entropy'
estimator.fit(x_train,y_train)      #調用fit()方法進行訓練,()內為訓練集的特征值與目標值

⑤模型評估

方法一:直接對比測試集的真實值和預測值

y_predict=estimator.predict(x_test)     #傳入測試集特征值,預測所給測試集的目標值
print("y_predict:\n",y_predict)
print("直接對比真實值和預測值:\n",y_test==y_predict)

方法二:計算准確率

score=estimator.score(x_test,y_test)    #傳入測試集的特征值和目標值

⑥決策樹可視化(將結果寫入tree.dot文件中,然后將tree.dot文件中的內容粘貼在webgraphviz.com中進行可視化展示

 

export_graphviz(estimator, out_file="tree.dot", feature_names=iris.feature_names)

 

主要代碼:

 

def decision_demo():
#     1.獲取數據集
    iris = load_iris()
#     2.划分數據集
    x_train,x_test,y_train,y_test=train_test_split(iris.data,iris.target,random_state=20)
#     3.決策樹預估器(estimator)
    estimator=DecisionTreeClassifier(criterion="entropy")   #criterion默認為'gini'系數,也可選擇信息增益熵'entropy'
    estimator.fit(x_train,y_train)      #調用fit()方法進行訓練,()內為訓練集的特征值與目標值
#     4.模型評估
    #方法一:直接對比真實值和預測值
    y_predict=estimator.predict(x_test)     #傳入測試集特征值,預測所給測試集的目標值
    print("y_predict:\n",y_predict)
    print("直接對比真實值和預測值:\n",y_test==y_predict)

    #方法二:計算准確率
    score=estimator.score(x_test,y_test)    #傳入測試集的特征值和目標值
    print("准確率為:\n",score)

    #決策樹可視化
    export_graphviz(estimator,out_file="tree.dot",feature_names=iris.feature_names)

    return None

 

代碼運行結果:

 

 可視化展示結果:

 

 

 

 注:可視化展示中,feature_names=iris.feature_names缺省會出現特征值名稱缺失現象,如下圖所示:

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM