評估指標【交叉驗證&ROC曲線】


 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Mon Sep 10 11:21:27 2018
 4 
 5 @author: zhen
 6 """
 7 from sklearn.datasets import fetch_mldata
 8 import numpy as np
 9 from sklearn.linear_model import SGDClassifier
10 from sklearn.model_selection import cross_val_score
11 from sklearn.model_selection import cross_val_predict
12 from sklearn.metrics import precision_recall_curve
13 import matplotlib
14 import matplotlib.pyplot as plt
15 from sklearn.metrics import roc_curve
16 from sklearn.metrics import roc_auc_score
17 from sklearn.ensemble import RandomForestClassifier
18 
# NOTE(review): fetch_mldata was deprecated in scikit-learn 0.20 and removed in
# 0.22, and loss='log' was renamed to 'log_loss' in 1.1. On a current
# scikit-learn this script needs fetch_openml('mnist_784', version=1) instead;
# that loader returns samples in a different order and string targets, so the
# index 36000 and the `== 5` comparisons below would have to be revisited.
mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData學習資源庫/人工智能開發/分類評估/資料/test_data_home')

x, y = mnist['data'], mnist['target']
some_digit = x[36000]  # row 36000, reused below as a single test instance

# Each sample is a flat vector; reshape it to 28x28 so it can be shown as an image.
some_digit_image = some_digit.reshape(28, 28)

# NOTE(review): vmin=0 / vmax=1 saturates every pixel value >= 1, so the digit
# is rendered as a two-tone image - confirm this is intended.
plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
           interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show()

# First 60000 rows are used for training, the remainder for testing.
x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]

# Shuffle the training rows (unseeded, so the order differs between runs).
shuffle_index = np.random.permutation(60000)
x_train, y_train = x_train[shuffle_index], y_train[shuffle_index]

# Binary targets: "is this digit a 5?"
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5)

result = sgd_clf.predict([some_digit])

# 3-fold cross-validated accuracy, precision and recall of the 5-detector.
for metric in ('accuracy', 'precision', 'recall'):
    print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring=metric))

# cross_val_score fits clones of the estimator, so sgd_clf is still the model
# fitted above; refitting it on the same data with the same random_state would
# only repeat the training pass.
y_scores = sgd_clf.decision_function([some_digit])

# The same raw score classified under two different decision thresholds.
threshold = 0
y_some_digit_pred = (y_scores > threshold)

threshold = 200000
y_some_digit_pred = (y_scores > threshold)

# cv is the number of folds the training set is split into.
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function')

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)
62 
63 
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    """Plot precision and recall as functions of the decision threshold.

    Args:
        precisions: Sequence of precision values; its last element is not plotted.
        recalls: Sequence of recall values; its last element is not plotted.
        thresholds: Threshold values used as the x axis.

    The figure is displayed immediately via ``plt.show()``.
    """
    curves = (
        (precisions, 'b--', 'Precision'),
        (recalls, 'r--', 'Recall'),
    )
    for values, line_style, legend_name in curves:
        # Drop the final point of each series before plotting it against the
        # thresholds (presumably because the series are one element longer).
        plt.plot(thresholds, values[:-1], line_style, label=legend_name)

    plt.xlabel("Threshold")
    plt.legend(loc='upper left')
    plt.ylim([0, 1])
    plt.show()
71     
72     
def plot_roc_curve(fpr, tpr, label=None):
    """Plot a ROC curve together with the diagonal reference line and show it.

    Args:
        fpr: False-positive rates, used as x values.
        tpr: True-positive rates, used as y values.
        label: Legend text for the curve. When ``None`` the legend entry
            is ``'roc'``.
    """
    # Honour the caller-supplied label; fall back to the previous fixed text.
    curve_label = 'roc' if label is None else label
    plt.plot(fpr, tpr, linewidth=2, label=curve_label)
    # Dashed diagonal from (0, 0) to (1, 1) as a visual baseline.
    plt.plot([0, 1], [0, 1], 'k--', label='mid')
    plt.legend(loc='lower right')
    # To pin both axes to the unit square use plt.axis([0, 1, 0, 1]), whose
    # arguments are [xmin, xmax, ymin, ymax]. plt.axes() is a different call:
    # it creates a new Axes from a [left, bottom, width, height] rectangle.
    plt.xlabel('fpr')
    plt.ylabel('tpr')
    plt.show()
81 
82 
# Precision/recall trade-off of the SGD classifier across thresholds.
plot_precision_recall_vs_threshold(
    precisions=precisions,
    recalls=recalls,
    thresholds=thresholds,
)

# ROC curve and area under it for the SGD classifier's cross-validated scores.
sgd_roc = roc_curve(y_train_5, y_scores)
fpr, tpr, thresholds = sgd_roc
plot_roc_curve(fpr, tpr)

sgd_auc = roc_auc_score(y_train_5, y_scores)
print(sgd_auc)

# Same evaluation for a random forest, using class probabilities as scores.
forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(
    forest_clf,
    x_train,
    y_train_5,
    cv=3,
    method='predict_proba',
)
# Second column of the probability matrix is used as the positive-class score.
y_scores_forest = y_probas_forest[:, 1]
forest_roc = roc_curve(y_train_5, y_scores_forest)
fpr_forest, tpr_forest, thresholds_forest = forest_roc

# Overlay both ROC curves on one figure for a direct comparison.
roc_series = (
    (fpr, tpr, ('b:',), 'SGD'),
    (fpr_forest, tpr_forest, (), 'Random Forest'),
)
for false_pos, true_pos, fmt_args, legend_name in roc_series:
    plt.plot(false_pos, true_pos, *fmt_args, label=legend_name)
plt.legend(loc='lower right')
plt.show()

forest_auc = roc_auc_score(y_train_5, y_scores_forest)
print(forest_auc)

          

總結:精確率(precision)與召回率(recall)在整體上呈此消彼長的關係——調整閾值提高其中一個,通常會使另一個下降。在使用相同數據集、相同交叉驗證方式的情況下,從 ROC 曲線與 AUC 值來看,隨機森林的表現優於隨機梯度下降(SGD)分類器。


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM