pytorch實戰:詳解查准率(Precision)、查全率(Recall)與F1


pytorch實戰:詳解查准率(Precision)、查全率(Recall)與F1

1、概述

本文首先介紹了機器學習分類問題的性能指標查准率(Precision)、查全率(Recall)與F1度量,闡述了多分類問題中的混淆矩陣及各項性能指標的計算方法,然后介紹了PyTorch中scatter函數的使用方法,借助該函數實現了對Precision、Recall、F1及正確率的計算,並對實現過程進行了解釋。

觀前提示:閱讀本文需要你對機器學習與PyTorch框架具有一定的了解。

Tips:如果你只是想利用PyTorch計算查准率(Precision)、查全率(Recall)、F1這幾個指標,不想深入了解,請直接跳到第3部分copy代碼使用即可

2、查准率(Precision)、查全率(Recall)與F1

2.1、二分類問題

對於一個二分類(正例與反例)問題,其分類結果的混淆矩陣(Confusion Matrix)如下:

預測的正例 預測的反例
真實的正例 TP(真正例) FN(假反例)
真實的反例 FP(假正例) TN(真 反例)

則查准率P定義為:

\[P=\frac{TP}{TP+FP} \]

查全率R定義為:

\[R=\frac{TP}{TP+FN} \]

可見,查准率與查全率是一對相互矛盾的量,前者表示的是預測的正例真的是正例的概率,而后者表示的是將真正的正例預測為正例的概率。聽上去有點繞,通俗地講,假設這個分類問題是從一批西瓜中辨別哪些是好瓜哪些是不好的瓜,查准率高則意味着你挑選的好瓜大概率真的是好瓜,但是由於你選瓜的標准比較高,這意味着你也錯失了一些好瓜(寧缺毋濫);而查全率高則意味着你能選到大部分的好瓜,但是由於你為了挑選到盡可能多的好瓜,降低了選瓜的標准,這樣許多不太好的瓜也被當成了好瓜選進來。通過調整你選瓜的“門檻”,就可以調整查准率與查全率,即:“門檻”高,則查准率高而查全率低;“門檻”低,則查准率低而查全率高。通常,查准率與查全率不可兼得,在不同的任務中,對查准率與查全率的的重視程度也會有所不同。這時,我們就需要一個綜合考慮查准率與查全率的性能指標了,比如\(F_{\beta}\),該值定義為:

\[F_{\beta}=\frac{(1+{\beta}^2) \times P \times R}{({\beta}^2 \times P)+R} \]

其中\(\beta\)度量了查全率對查准率的相對重要性,\(\beta >1\)時,查全率影響更大,\(\beta <1\)時,查准率影響更大,當\(\beta =1\)時,即是標准的F1度量:

\[F1=\frac{2 \times P \times R}{P+R} \]

此時查准率與查全率影響相同。

2.2、多分類問題

對於多分類問題,在計算查准率與查全率的時候,可以將其當作二分類問題,即正確類和其他類。為了方便闡述,還是借助一個例子來說明吧,假設有一個五分類的問題,每類分別標記為0、1、2、3、4,其分類結果的混淆矩陣如下:

image

在這個問題中,對於類別0,可以看作區分類別{0}與類別{1,2,3,4}的二分類問題,則對應的TP=m00,FP=m10+m20+m30+m40,FN=m01+m02+m03+m04,同樣可以計算出相應的指標。因此,對於一個n分類問題,對應的查准率與查全率應該是一個n維向量,向量中的元素表示每類的查准率與查全率。

3、PyTorch中scatter()函數的使用

在講解第二節中各個指標的計算方法的時候,首先來學習一下PyTorch中scatter()函數的使用方法。查看Pytorch官方文檔中torch.Tensor.scatter()的說明,發現最終跳轉到了torch.Tensor.scatter_(),有的朋友可能會犯迷糊,這兩個函數有什么區別?其實區別很簡單,scatter()不改變原來的張量,而scatter_()是在原張量上進行改變。官網的介紹為:

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

這里,self表示的是調用該函數的張量本身,其結果為:

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

看上去有點繞,用大白話講就是將參數index中給定元素作為一位索引(具體是哪位由dim參數決定),將src中的值填入self張量。這里的src也可以是個標量,如果是標量就直接填充就行。之所以用“填充”這個詞,就是因為並不是self中所有的元素都會改變,只有少數張量會改變,具體改變的位置由index與dim決定。話不多說,直接看代碼

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])

該代碼樣例中對一個維度為(3,5)的全0張量進行了計算,參數dim=0,結果張量中,值發生變化的位置分別是(0,0)、(1,1)、(2,2)、(0,3)。細心的朋友應該能發現,這四個位置的第一個維度(dim=0)剛好是張量index的各個元素,而第二個維度則是index中對應元素的另一個維度的值。再看另一個例子

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

該代碼中值發生變化的位置是(0,0)、(0,1)、(0,2)、(1,0)、(1,1)、(1,4),同樣的,這幾個位置的第二個維度(dim=1)是index中的各個元素,而第一個維度則是index中對應元素的另一個維度的值。在涉及到高維張量的時候,這個操作的空間意義我們可能難以想想,但是記住這點而不去追求理解其空間意義,再高的維度,也能很好地理解這個計算的具體操作。

4、PyTorch實戰與代碼解析

接下來就是在PyTorch實現對查准率、查全率與F1的計算了,在該實例中,我們用到了scatter_()函數,首先來看完整的數據集測試過程代碼,便於各位理解與取用,然后再對具體的計算代碼進行闡述。

def test(valid_queue, net, criterion):
    net.eval()
    test_loss = 0 
    target_num = torch.zeros((1, n_classes)) # n_classes為分類任務類別數量
    predict_num = torch.zeros((1, n_classes))
    acc_num = torch.zeros((1, n_classes))

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            
            pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)
            predict_num += pre_mask.sum(0)  # 得到數據中每類的預測量
            tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.)
            target_num += tar_mask.sum(0)  # 得到數據中每類的數量
            acc_mask = pre_mask * tar_mask 
            acc_num += acc_mask.sum(0) # 得到各類別分類正確的樣本數量

        recall = acc_num / target_num
        precision = acc_num / predict_num
        F1 = 2 * recall * precision / (recall + precision)
        accuracy = 100. * acc_num.sum(1) / target_num.sum(1)

        print('Test Acc {}, recal {}, precision {}, F1-score {}'.format(accuracy, recall, precision, F1))

    return accuracy

首先看下面這行代碼,產生一個大小為(batch_size , n_classes)的全0張量,然后將predicted的維度變成(batch_size,1),每個元素都代表的是其分類對應的編號,通過scatter_()函數,將1寫入了全0張量中的對應位置。得到的張量pre_mask就是每次預測結果的one-hot編碼。

pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)

然后將pre_mask的第一個維度上的所有值進行加和,就得到了一個維度為(1 ,n_classes)的張量,該張量中的每個元素都是預測結果中對應的類的數量,對其進行累加,便得到了整個測試數據集的預測結果predict_num。

predict_num += pre_mask.sum(0)

tar_mask與target_num同理,此處不再贅述。

acc_mask = pre_mask * tar_mask

然后通過pre_mask * tar_mask,得到了一個表示分類正確的樣本與對應類別的矩陣acc_mask,其維度為(batch_size , n_classes),對於其中一個元素acc_mask[i][j]=1,表示這個batch_size 中的第i個樣本分類正確,類別為j。

acc_num += acc_mask.sum(0)

將acc_mask的第一個維度上的所有值進行加和,便可得到該batch_size數據中每個類別預測正確的數量,累加即可得到整個驗證數據集中各類正確預測的樣本數。

recall = acc_num / target_num
precision = acc_num / predict_num
F1 = 2 * recall * precision / (recall + precision)
accuracy = 100. * acc_num.sum(1) / target_num.sum(1)

然后就可以計算各個指標了,很好理解,就不再解釋了。

希望對你有所幫助,也歡迎在評論區提出你的想法與意見。

5、參考

[1].《機器學習》,周志華

[2].https://blog.csdn.net/qq_16234613/article/details/80039080

[3].https://zhuanlan.zhihu.com/p/46204175

[4].https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html?highlight=scatter_


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM