代碼實現:
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Tue Sep 4 09:38:57 2018 4 5 @author: zhen 6 """ 7 8 from sklearn.ensemble import RandomForestClassifier 9 from sklearn.model_selection import train_test_split 10 from sklearn.metrics import accuracy_score 11 from sklearn.datasets import load_iris 12 import matplotlib.pyplot as plt 13 14 iris = load_iris() 15 x = iris.data[:, :2] 16 y = iris.target 17 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42) 18 19 # n_estimators:森林中樹的個數(默認為10),建議為奇數 20 # n_jobs:並行執行任務的個數(包括模型訓練和預測),默認值為-1,表示根據核數 21 rnd_clf = RandomForestClassifier(n_estimators=15, max_leaf_nodes=16, n_jobs=1) 22 rnd_clf.fit(x_train, y_train) 23 24 y_predict_rf = rnd_clf.predict(x_test) 25 26 print(accuracy_score(y_test, y_predict_rf)) 27 28 for name, score in zip(iris['feature_names'], rnd_clf.feature_importances_): 29 print(name, score) 30 31 # 可視化 32 plt.plot(x_test[:, 0], y_test, 'r.', label='real') 33 plt.plot(x_test[:, 0], y_predict_rf, 'b.', label='predict') 34 plt.xlabel('sepal-length', fontsize=15) 35 plt.ylabel('type', fontsize=15) 36 plt.legend(loc="upper left") 37 plt.show() 38 39 plt.plot(x_test[:, 1], y_test, 'r.', label='real') 40 plt.plot(x_test[:, 1], y_predict_rf, 'b.', label='predict') 41 plt.xlabel('sepal-width', fontsize=15) 42 plt.ylabel('type', fontsize=15) 43 plt.legend(loc="upper right") 44 plt.show()
結果:
可視化(查看每個預測條件的影響):
分析:鳶尾花的花萼長度在小於6時預測准確率很高,隨着長度的增加,在6~7這段中,預測出現較大錯誤率,當大於7時,預測會恢復到較好的情況。寬度也出現類似的情況,在3~3.5這個范圍出現較高錯誤,因此在訓練中建議在訓練數據中適量增加中間部分數據的訓練量(該部分不容易區分),以便得到較好的訓練模型!