多分類下的ROC曲線和AUC


 本文主要介紹一下多分類下的ROC曲線繪制和AUC計算,並以鳶尾花數據為例,簡單用python進行一下說明。如果對ROC和AUC二分類下的概念不是很了解,可以先參考下這篇文章:http://blog.csdn.net/ye1215172385/article/details/79448575

        由於ROC曲線是針對二分類的情況,對於多分類問題,ROC曲線的獲取主要有兩種方法:

        假設測試樣本個數為m,類別個數為n(假設類別標簽分別為:0,2,...,n-1)。在訓練完成后,計算出每個測試樣本的在各類別下的概率或置信度,得到一個[m, n]形狀的矩陣P,每一行表示一個測試樣本在各類別下概率值(按類別標簽排序)。相應地,將每個測試樣本的標簽轉換為類似二進制的形式,每個位置用來標記是否屬於對應的類別(也按標簽排序,這樣才和前面對應),由此也可以獲得一個[m, n]的標簽矩陣L。

         比如n等於3,標簽應轉換為:

        方法1:每種類別下,都可以得到m個測試樣本為該類別的概率(矩陣P中的列)。所以,根據概率矩陣P和標簽矩陣L中對應的每一列,可以計算出各個閾值下的假正例率(FPR)和真正例率(TPR),從而繪制出一條ROC曲線。這樣總共可以繪制出n條ROC曲線。最后對n條ROC曲線取平均,即可得到最終的ROC曲線。

        方法2:首先,對於一個測試樣本:1)標簽只由0和1組成,1的位置表明了它的類別(可對應二分類問題中的‘’正’’),0就表示其他類別(‘’負‘’);2)要是分類器對該測試樣本分類正確,則該樣本標簽中1對應的位置在概率矩陣P中的值是大於0對應的位置的概率值的。基於這兩點,將標簽矩陣L和概率矩陣P分別按行展開,轉置后形成兩列,這就得到了一個二分類的結果。所以,此方法經過計算后可以直接得到最終的ROC曲線。

       上面的兩個方法得到的ROC曲線是不同的,當然曲線下的面積AUC也是不一樣的。 在python中,方法1和方法2分別對應sklearn.metrics.roc_auc_score函數中參數average值為'macro'和'micro'的情況。

      下面以方法1為例,直接上代碼,概率矩陣P和標簽矩陣L分別對應代碼中的y_score和y_one_hot:

#!/usr/bin/python
# -*- coding:utf-8 -*-
 
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegressionCV
from sklearn import metrics
from sklearn.preprocessing import label_binarize
 
if __name__ == '__main__':
    np.random.seed(0)
    data = pd.read_csv('iris.data', header = None)  #讀取數據
    iris_types = data[4].unique()
    n_class = iris_types.size
    x = data.iloc[:, :2]  #只取前面兩個特征
    y = pd.Categorical(data[4]).codes  #將標簽轉換0,1,...
    x_train, x_test, y_train, y_test = train_test_split(x, y, train_size = 0.6, random_state = 0)
    y_one_hot = label_binarize(y_test, np.arange(n_class))  #裝換成類似二進制的編碼
    alpha = np.logspace(-2, 2, 20)  #設置超參數范圍
    model = LogisticRegressionCV(Cs = alpha, cv = 3, penalty = 'l2')  #使用L2正則化
    model.fit(x_train, y_train)
    print '超參數:', model.C_
    # 計算屬於各個類別的概率,返回值的shape = [n_samples, n_classes]
    y_score = model.predict_proba(x_test)
    # 1、調用函數計算micro類型的AUC
    print '調用函數auc:', metrics.roc_auc_score(y_one_hot, y_score, average='micro')
    # 2、手動計算micro類型的AUC
    #首先將矩陣y_one_hot和y_score展開,然后計算假正例率FPR和真正例率TPR
    fpr, tpr, thresholds = metrics.roc_curve(y_one_hot.ravel(),y_score.ravel())
    auc = metrics.auc(fpr, tpr)
    print '手動計算auc:', auc
    #繪圖
    mpl.rcParams['font.sans-serif'] = u'SimHei'
    mpl.rcParams['axes.unicode_minus'] = False
    #FPR就是橫坐標,TPR就是縱坐標
    plt.plot(fpr, tpr, c = 'r', lw = 2, alpha = 0.7, label = u'AUC=%.3f' % auc)
    plt.plot((0, 1), (0, 1), c = '#808080', lw = 1, ls = '--', alpha = 0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.grid(b=True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title(u'鳶尾花數據Logistic分類后的ROC和AUC', fontsize=17)
    plt.show()

我的實戰

Bnew_one1=[]
    for lis in Bnew4:
        bol=np.zeros(51)
        bol=bol.tolist()
        bol[lis[0]]=1
        Bnew_one1.append(bol)
    
    Blast_one=[]
    for lis in Blast:
        bol=np.zeros(51)
        bol=bol.tolist()
        bol[lis[0]]=1
        Blast_one.append(bol)
    
    Bnew_one1=np.array(Bnew_one1)
    Blast_one=np.array(Blast_one)
    Bnew_one=np.array(Bnew_one)
    
    print('調用函數auc:', metrics.roc_auc_score(Blast_one, Bnew_one1, average='micro'))
    
    fpr, tpr, thresholds = metrics.roc_curve(Blast_one.ravel(),Bnew_one1.ravel())
    auc = metrics.auc(fpr, tpr)
    print('手動計算auc:', auc)
    #繪圖
    mpl.rcParams['font.sans-serif'] = u'SimHei'
    mpl.rcParams['axes.unicode_minus'] = False
    #FPR就是橫坐標,TPR就是縱坐標
    plt.plot(fpr, tpr, c = 'r', lw = 2, alpha = 0.7, label = u'AUC=%.3f' % auc)
    plt.plot((0, 1), (0, 1), c = '#808080', lw = 1, ls = '--', alpha = 0.7)
    plt.xlim((-0.01, 1.02))
    plt.ylim((-0.01, 1.02))
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))
    plt.xlabel('False Positive Rate', fontsize=13)
    plt.ylabel('True Positive Rate', fontsize=13)
    plt.grid(b=True, ls=':')
    plt.legend(loc='lower right', fancybox=True, framealpha=0.8, fontsize=12)
    plt.title(u'大類問題一分類后的ROC和AUC', fontsize=17)
    plt.show()
    


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM