Visualizing Decision Trees (an sklearn visualization example)


Visualization

Dataset

The Iris dataset.
Import the Python libraries and the dataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import tree
# Load the dataset
data = load_iris()
# Convert it to a DataFrame
df = pd.DataFrame(data.data, columns = data.feature_names)
# Add the species column (integer codes 0, 1, 2)
df['Species'] = data.target

# Replace the integer codes with the species names so the labels are readable
target = np.unique(data.target)
target_names = data.target_names  # already ordered to match codes 0, 1, 2
targets = dict(zip(target, target_names))
df['Species'] = df['Species'].map(targets)

# Separate the features and the labels
X = df.drop(columns="Species")
y = df["Species"]
feature_names = X.columns
labels = y.unique()
# Split the data: 60% training set, 40% test set
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size = 0.4,
                                                    random_state = 42)
model = DecisionTreeClassifier(max_depth = 3, random_state = 42)
model.fit(X_train, y_train)
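As a quick sanity check of the preparation steps above, the sketch below builds the same label column in one step by indexing the array of class names with the integer targets (data.target_names[data.target], an alternative to the dictionary mapping, not the author's code) and confirms the 60/40 split sizes on Iris's 150 samples.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
df = pd.DataFrame(data.data, columns=data.feature_names)

# Fancy indexing maps each code to its name: 0 -> setosa, 1 -> versicolor, 2 -> virginica
df["Species"] = data.target_names[data.target]

X = df.drop(columns="Species")
y = df["Species"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=42
)

# 150 samples: 40% gives 60 test rows, the remaining 90 are used for training
print(len(X_train), len(X_test))  # 90 60
```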

Four ways to visualize a decision tree

1. Text representation

# Print the tree as text (without feature_names, features appear as feature_0 ... feature_3)
text_representation = tree.export_text(model)
print(text_representation)

|--- feature_2 <= 2.45
|   |--- class: setosa
|--- feature_2 >  2.45
|   |--- feature_3 <= 1.75
|   |   |--- feature_2 <= 5.35
|   |   |   |--- class: versicolor
|   |   |--- feature_2 >  5.35
|   |   |   |--- class: virginica
|   |--- feature_3 >  1.75
|   |   |--- feature_2 <= 4.85
|   |   |   |--- class: virginica
|   |   |--- feature_2 >  4.85
|   |   |   |--- class: virginica
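The generic names feature_2 and feature_3 above are hard to read. A minimal sketch, assuming the same depth-3 tree but trained on the full dataset for brevity (not the author's split): passing feature_names to export_text prints the real column names, and show_weights=True adds the per-class sample counts at each leaf.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_iris()
# Assumption: fit on all 150 samples instead of the 60% training split
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(data.data, data.target)

# feature_names replaces the feature_0 ... feature_3 placeholders;
# show_weights=True prints the weighted class counts next to each leaf's class
text = export_text(
    model,
    feature_names=list(data.feature_names),
    show_weights=True,
)
print(text)
```

Because this model was fitted on the integer targets, the leaves read "class: 0", "class: 1" and so on; fitting on the string labels, as in the main example, prints the species names instead.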

2. Drawing the tree with the plot_tree function

# Draw the tree as a matplotlib figure
plt.figure(figsize=(15, 10))
tree.plot_tree(model,
               feature_names = list(feature_names),
               class_names = list(model.classes_),  # classes_ is sorted, the order plot_tree expects
               rounded = True,
               filled = True,
               fontsize = 16)
plt.show()


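Note that plot_tree does not call Graphviz: it draws the tree directly with matplotlib, so Graphviz does not have to be installed, yet the figure looks much the same as the one produced by the Graphviz method.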
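Since plot_tree only needs matplotlib, the result can be saved like any other figure. A minimal sketch, assuming a non-interactive environment (hence the Agg backend) and a hypothetical output file name iris_tree_plot_tree.png; the model is fitted on the full dataset for brevity.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to a file, no window
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

data = load_iris()
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(data.data, data.target)

fig, ax = plt.subplots(figsize=(15, 10))
# plot_tree draws into the given axes and returns the text annotations it created
annotations = plot_tree(
    model,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
    filled=True,
    rounded=True,
    ax=ax,
)
out_path = "iris_tree_plot_tree.png"  # hypothetical file name
fig.savefig(out_path, dpi=100, bbox_inches="tight")
plt.close(fig)
print(os.path.abspath(out_path))
```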

3. Drawing with Graphviz

sklearn.tree.export_graphviz exports the decision tree model in DOT format, which the graphviz Python package can then render.

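import graphviz

# Export the tree as DOT source (out_file=None returns it as a string)
dot_data = tree.export_graphviz(model, out_file=None,
                                feature_names=data.feature_names,
                                class_names=data.target_names,
                                filled=True,
                                rounded=True)

# Draw the graph (requires the Graphviz executables on PATH)
graph = graphviz.Source(dot_data, format="png")
graph  # displays in Jupyter; in a plain script use graph.render("iris_tree") to write iris_tree.png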
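The DOT text itself can be produced and inspected without the Graphviz executables; only rendering it to an image needs them. A minimal sketch, assuming a hypothetical output file iris_tree.dot and a model fitted on the full dataset for brevity:

```python
import os

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

data = load_iris()
model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(data.data, data.target)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_data = export_graphviz(
    model,
    out_file=None,
    feature_names=list(data.feature_names),
    class_names=list(data.target_names),
    filled=True,
    rounded=True,
)
print(dot_data.splitlines()[0])  # digraph Tree {

# Save the DOT file; once Graphviz is installed it can be rendered from the shell:
#   dot -Tpng iris_tree.dot -o iris_tree.png
dot_path = "iris_tree.dot"  # hypothetical file name
with open(dot_path, "w", encoding="utf-8") as f:
    f.write(dot_data)
print(os.path.abspath(dot_path))
```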

4. Visualizing with the plot_decision_region function
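The heading presumably refers to plot_decision_regions from the third-party mlxtend package; that is an assumption, and the sketch below does not depend on it. It draws the same kind of plot with plain NumPy and matplotlib: train a depth-3 tree on two features (petal length and petal width, an assumption so the regions fit on a 2-D plot), predict a class for every point of a grid, and shade the grid by prediction. The output file name is hypothetical.

```python
import os

import matplotlib
matplotlib.use("Agg")  # render to a file, no window
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
# Assumption: keep only petal length and petal width (columns 2 and 3)
X = data.data[:, [2, 3]]
y = data.target

model = DecisionTreeClassifier(max_depth=3, random_state=42)
model.fit(X, y)

# Grid covering the data range with a small margin
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))

# Predict a class for every grid point, then reshape back to the grid
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

fig, ax = plt.subplots(figsize=(8, 6))
ax.contourf(xx, yy, Z, alpha=0.3, cmap="viridis")  # shaded decision regions
scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="viridis", edgecolor="k")
ax.set_xlabel(data.feature_names[2])
ax.set_ylabel(data.feature_names[3])
ax.legend(scatter.legend_elements()[0], list(data.target_names), title="Species")
out_path = "iris_decision_regions.png"  # hypothetical file name
fig.savefig(out_path, dpi=100, bbox_inches="tight")
plt.close(fig)
print(os.path.abspath(out_path))
```

The axis-aligned rectangles in the plot are characteristic of a decision tree: every split compares a single feature with a threshold.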

Problems encountered and solutions

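Installing the Graphviz module and setting the environment variable (the typical symptom is the error "failed to execute ['dot', '-Tsvg'], make sure the Graphviz executables are on your systems' PATH")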

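Installation steps:

  1. Download and install Graphviz from the official website.

  2. Configure the environment variable: add the installation path\bin folder (the folder that contains dot.exe) to PATH.

  3. Install the Python package: pip install graphviz

  4. Restart (recommended) so that the new PATH takes effect.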
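After the steps above, a quick way to check whether Python can see the Graphviz executable is to look up dot on PATH. A minimal sketch using only the standard library; the helper name find_dot is hypothetical.

```python
import shutil
import subprocess


def find_dot():
    """Return the full path of the Graphviz `dot` executable, or None if it is not on PATH."""
    return shutil.which("dot")


dot_path = find_dot()
if dot_path is None:
    print("dot not found: add the Graphviz bin folder to PATH, then restart the terminal/IDE")
else:
    # `dot -V` writes its version banner to stderr rather than stdout
    result = subprocess.run([dot_path, "-V"], capture_output=True, text=True)
    print(dot_path)
    print(result.stderr.strip())
```

If this prints "dot not found" while Graphviz is installed, the PATH entry is usually missing or the terminal/IDE was started before PATH was changed.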


