數據挖掘作業,要實現決策樹,現記錄學習過程
win10系統,Python 3.7.0
構建一個決策樹,在鳶尾花數據集上訓練一個DecisionTreeClassifier:
from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier iris = load_iris() X = iris.data[:,2:] y = iris.target tree_clf = DecisionTreeClassifier(max_depth=2) tree_clf.fit(X,y)
要將決策樹可視化,首先,使用export_graphviz()方法輸出一個圖形定義文件,命名為iris_tree.dot
這里需要安裝graphviz
安裝方式:
① conda install python-graphviz
② pip install graphviz
在當前目錄下新建images/decision_trees目錄
不然會報錯
Traceback (most recent call last):
File "decisiontree.py", line 21, in <module>
filled=True)
File "E:\Anaconda\lib\site-packages\sklearn\tree\export.py", line 762, in export_graphviz
out_file = open(out_file, "w", encoding="utf-8")
FileNotFoundError: [Errno 2] No such file or directory: '.\\images\\decision_trees\\iris_tree.dot'
from sklearn.tree import export_graphviz import os PROJECT_ROOT_DIR = "." CHAPTER_ID = "decision_trees" def image_path(fig_id): return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)
export_graphviz(tree_clf, out_file=image_path("iris_tree.dot"), feature_names=iris.feature_names[2:], class_names=iris.target_names, rounded=True, filled=True)
運行過后生成了一個dot文件
使用命令dot -Tpng iris_tree.dot -o iris_tree.png 將dot文件轉換為png文件方便顯示
決策樹如上圖所示
petal length:花瓣長度 petal width:花瓣寬度
samples:統計出它應用於多少個訓練樣本實例
value:這個節點對於每一個類別的樣例有多少個 這個葉結點顯示包含0 個 Iris-Setosa,1 個 Iris-Versicolor 和 45 個 Iris-Virginica
Gini:用於測量它的純度,如果一個節點包含的所有訓練樣例全都是同一類別的,我們就說這個節點是純的( Gini=0 )
Gini公式:
Pik是第i個節點上,類別為k的訓練實例占比
進行預測
當找到了一朵鳶尾花並且想對它進行分類時,從根節點開始,詢問花朵的花瓣長度是否小於2.45厘米。如果是,將向下移動到根的左側子節點,在這種情況下,它是一片葉子節點,它不會再繼續問任何問題,決策樹預測你的花是iris-setosa
假設你找到了另一朵花,但這次的花瓣長度是大於2.45厘米的。必須向下移動到根的右側子節點,而這個節點不是葉節點,它會問另一個問題,花瓣寬度是否小於1.75厘米?如果是,則將這朵花分類成iris-versicolor ,不是,則分類成iris-versicolor
注意:scikit-learn使用的是CART算法,該算法僅生成二叉樹;非葉節點永遠只有兩個子節點。
估計分類概率
新樣本:花瓣長5厘米,花瓣寬1.5厘米,預測具體的類
print(tree_clf.predict_proba([[5,1.5]])) print(tree_clf.predict([[5,1.5]]))
此處說明分類為iris-setosa的概率為0,分類為iris-versicolor的概率為0.90740741,分類為iris-virginica的概率為0.09259259
通過predict預測該花為iris-versicolor
完整代碼
#在鳶尾花數據集上進行一個決策樹分類器的訓練 from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier from sklearn.tree import export_graphviz import os PROJECT_ROOT_DIR = "." CHAPTER_ID = "decision_trees" def image_path(fig_id): return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id) iris = load_iris() X = iris.data[:,2:] y = iris.target tree_clf = DecisionTreeClassifier(max_depth=2) tree_clf.fit(X,y) export_graphviz(tree_clf, out_file=image_path("iris_tree.dot"), feature_names=iris.feature_names[2:], class_names=iris.target_names, rounded=True, filled=True) print(tree_clf.predict_proba([[5,1.5]])) #[0]:iris-setosa, [1]:iris-versicolor, [2]:iris-virginica" print(tree_clf.predict([[5,1.5]]))
CART訓練算法原理介紹: