from sklearn import datasets from sklearn.tree import DecisionTreeClassifier
1.載入iris數據集(from sklearn import datasets)
x = iris.data[:,[0,2]] # x = iris.data[:, 0:2] y = iris.target
2.設置訓練集中的數據和標簽(x是特征集合,二維數組,y是標簽值集合,一維數組)
clf = DecisionTreeClassifier(max_depth = 3)
clf.fit(x,y)
3.訓練模型(DecisionTreeClassifier涉及到參數max_depth及其他,參考sklearn)
最后,是決策樹的可視化,預備工作為:
scikit-learn中決策樹的可視化一般需要安裝graphviz。主要包括graphviz的安裝和python的graphviz插件的安裝。
第一步是安裝graphviz。下載地址在:http://www.graphviz.org/。如果你是linux,可以用apt-get或者yum的方法安裝。如果是windows,就在官網下載msi文件安裝。無論是linux還是windows,裝完后都要設置環境變量,將graphviz的bin目錄加到PATH,比如我是windows,將C:/Program Files (x86)/Graphviz2.38/bin/加入了PATH
第二步是安裝python插件graphviz: pip install graphviz
第三步是安裝python插件pydotplus。這個沒有什么好說的: pip install pydotplus
這樣環境就搭好了,若仍然找不到graphviz,可以在代碼里面加入這一行:
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
兩種方法:
(1)生成pdf
import pydotplus dot_data = tree.export_graphviz(clf, out_file=None) graph = pydotplus.graph_from_dot_data(dot_data) graph.write_pdf("iris.pdf")
(2)直接在jupyter中顯示
from IPython.display import Image from sklearn import tree import pydotplus import os os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/' dot_data = tree.export_graphviz(clf, out_file=None, feature_names=["sepal length","sepal width"], class_names=iris.target_names, filled=True, rounded=True, special_characters=True) graph = pydotplus.graph_from_dot_data(dot_data) Image(graph.create_png())