"""Plot precision-recall (P-R) curves on the Iris dataset.

Iris has three classes, so the labels are binarized and a one-vs-rest
linear SVM is trained. One P-R curve is drawn per class, plus a
micro-averaged curve over all classes. Random noise features are appended
to the inputs so the classification task is harder.
"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
# One-vs-rest: each class in turn is treated as positive, all others as negative.
from sklearn.multiclass import OneVsRestClassifier

# sklearn.cross_validation was removed in scikit-learn 0.20; train_test_split
# lives in sklearn.model_selection (added in 0.18). Fall back only for very old versions.
try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18
    from sklearn.cross_validation import train_test_split

print(__doc__)  # print the module docstring above as a description of the script

# Number of random noise columns appended per original feature.
NOISE_FEATURES_PER_FEATURE = 200

# Load Iris: X is (150, 4); y holds the integer labels 0, 1, 2.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Binarize the labels into one column per class (e.g. class 0 -> [1, 0, 0]) so the
# multi-class problem can be handled as three binary problems by OneVsRestClassifier.
y = label_binarize(y, classes=[0, 1, 2])  # (150,) -> (150, 3)
n_classes = y.shape[1]  # number of columns, i.e. 3
print(y)

# Append 200 * n_features (= 800) standard-normal noise columns:
# only columns are added, (150, 4) -> (150, 804).
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, NOISE_FEATURES_PER_FEATURE * n_features)]

# 50/50 train/test split. Passing the seeded RandomState makes the split
# reproducible from run to run; pass random_state=None for a different split each run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=random_state
)

# One linear SVM per class (one-vs-rest). Only decision_function scores are used
# below, so probability=True is unnecessary: it only adds an internal 5-fold
# cross-validation (Platt scaling) to every fit and does not affect
# decision_function. SVC's random_state only matters when probability=True.
classifier = OneVsRestClassifier(svm.SVC(kernel="linear"))
# Fit on the training set, then score every test sample: y_score has one column per class.
y_score = classifier.fit(X_train, y_train).decision_function(X_test)

# Per-class precision/recall curves and average precision (AP).
# The thresholds returned by precision_recall_curve are not needed, hence "_".
precision = {}
recall = {}
average_precision = {}
for i in range(n_classes):
    # Column i holds the binary labels / scores for "class i vs. the rest".
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i], y_score[:, i])
    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])

# Micro-average: flatten (ravel) the labels and scores of all classes into one
# binary problem and compute a single curve / AP over it.
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(), y_score.ravel())
average_precision["micro"] = average_precision_score(y_test, y_score, average="micro")

# AP is a non-interpolated step-wise sum over the curve's points, so draw the curves
# as steps (where="post"): the area under each drawn curve then equals the AP shown
# in its legend, whereas straight lines between points can overstate it.
plt.clf()  # clear the current figure
plt.step(recall["micro"], precision["micro"], where="post",
         label='micro-average Precision-recall curve (area = {0:0.2f})'.format(average_precision["micro"]))
for i in range(n_classes):
    plt.step(recall[i], precision[i], where="post",
             label='Precision-recall curve of class {0} (area = {1:0.2f})'.format(i, average_precision[i]))

# Visible x/y ranges; y goes slightly above 1 so a precision of 1.0 stays visible.
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=16)
plt.legend(loc="lower right")
plt.show()
運行結果如下：