Python之隨機森林實戰


代碼實現:

 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Tue Sep  4 09:38:57 2018
 4 
 5 @author: zhen
 6 """
 7 
 8 from sklearn.ensemble import RandomForestClassifier
 9 from sklearn.model_selection import train_test_split
10 from sklearn.metrics import accuracy_score
11 from sklearn.datasets import load_iris
12 import matplotlib.pyplot as plt
13 
14 iris = load_iris()
15 x = iris.data[:, :2]
16 y = iris.target
17 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33, random_state=42)
18 
19 # n_estimators:森林中樹的個數(默認為10),建議為奇數
20 # n_jobs:並行執行任務的個數(包括模型訓練和預測),默認值為-1,表示根據核數
21 rnd_clf = RandomForestClassifier(n_estimators=15, max_leaf_nodes=16, n_jobs=1)
22 rnd_clf.fit(x_train, y_train)
23 
24 y_predict_rf = rnd_clf.predict(x_test)
25 
26 print(accuracy_score(y_test, y_predict_rf))
27 
28 for name, score in zip(iris['feature_names'], rnd_clf.feature_importances_):
29     print(name, score)
30     
31 # 可視化
32 plt.plot(x_test[:, 0], y_test, 'r.', label='real')
33 plt.plot(x_test[:, 0], y_predict_rf, 'b.', label='predict')
34 plt.xlabel('sepal-length', fontsize=15)
35 plt.ylabel('type', fontsize=15)
36 plt.legend(loc="upper left")
37 plt.show()
38 
39 plt.plot(x_test[:, 1], y_test, 'r.', label='real')
40 plt.plot(x_test[:, 1], y_predict_rf, 'b.', label='predict')
41 plt.xlabel('sepal-width', fontsize=15)
42 plt.ylabel('type', fontsize=15)
43 plt.legend(loc="upper right")
44 plt.show()

結果:

可視化(查看每個預測條件的影響):

   分析:鳶尾花的花萼長度在小於6時預測准確率很高,隨着長度的增加,在6~7這段中,預測出現較大錯誤率,當大於7時,預測會恢復到較好的情況。寬度也出現類似的情況,在3~3.5這個范圍出現較高錯誤,因此在訓練中建議在訓練數據中適量增加中間部分數據的訓練量(該部分不容易區分),以便得到較好的訓練模型!


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM