深度學習中噪聲標簽的影響和識別


問題導入

在機器學習領域中,常見的一類工作是使用帶標簽數據訓練神經網絡實現分類、回歸或其他目的,這種訓練模型學習規律的方法一般稱之為監督學習。在監督學習中,訓練數據所對應的標簽質量對於學習效果至關重要。如果學習時使用的標簽數據都是錯誤的,那么不可能訓練出有效的預測模型。同時,深度學習使用的神經網絡往往結構復雜,為了得到良好的學習效果,對於帶標簽的訓練數據的數量也有較高要求,即常被提到的大數據或海量數據。

矛盾在於:給數據打標簽這個工作在很多場景下需要人工實現,海量、高質量標簽本身費時費力,在經濟上相對昂貴。因此,實際應用中的機器學習問題必須面對噪音標簽的影響,即我們拿到的每一個帶標簽數據集都要假定其中是包含噪聲的。進一步,由於樣本量很大,對於每一個帶標簽數據集,我們不可能人工逐個檢查並校正標簽。

基於上述矛盾現狀,在實際工作中必須面對以下兩點問題

1. 訓練集帶標簽樣本中噪音達到什么水平對於模型預測結果會有致命影響

2. 對於任意給定帶標簽訓練集,如何快速找出可能是噪音的樣本

本文接下來將圍繞這兩點通過實驗給出分析

數據、神經網絡設計和代碼

本文以Tensorflow教程中提及的MNIST問題[1]為數據來源和問題定義。此問題為圖像識別問題,圖片為手寫的0-9字符,每個圖片格式為28*28灰度圖。訓練集數據包括55000張手寫數字和標簽,驗證集包括約10000張圖片和標簽。通過訓練神經網絡從而實現當輸入一張驗證集中的圖片后,神經網絡能夠正確預測這張圖片的標簽。

對於MNIST問題本身,Tensorflow教程[2]描述的包含2個卷積池化層的CNN網絡已經足以實現99%左右的預測精度,因此在本實驗中,筆者直接引用Tensorflow官方樣例中的CNN網絡[3]作為預測模型的神經網絡。

本文所有代碼可以在筆者的Github項目中獲得:wangyaobupt/NoisyLabels

噪聲標簽對於分類器性能的影響

考慮到MNIST是機器學習領域使用多年的數據庫,且在其數據上訓練的模型已經得到了較好的結果,由此可以合理推斷其標簽本身的噪聲含量較低(這個推理將在下一個章節通過實驗證實)。因此,在這一節的實驗中,我們假定原始的MNIST的訓練集和驗證集標簽都是無噪聲的。

使用如下步驟給標簽添加噪聲

1. 根據給定的噪聲比例noiseLevel,從N個總樣本中選擇出K個樣本,K = N*noiseLevel

2. 對於選出的K個樣本中的每一個樣本,將其原始標簽替換為0-9之間扣除原始標簽之外的隨機數

上述算法的代碼實現如下,testcase2.py提供了完整的可執行程序

# Add random noise to MNIST training set
# input:
# mnist_data: data structure that follow tensorflow MNIST demo
# noise_level: a percentage from 0 to 1, indicate how many percentage of labels are wrong
def addRandomNoiseToTrainingSet(mnist_data, noise_level):
    # the data structure of labels refer to DataSet in tensorflow/tensorflow/contrib/learn/python/learn/datasets/mnist.py 
    label_data_set = mnist_data.train.labels
    #print label_data_set.shape

    totalNum = label_data_set.shape[0]
    corruptedIdxList = randomSelectKFromN(int(noise_level*totalNum),totalNum)
    #print 'DEBUG: 1st elements in corruptedIdxList is: ', corruptedIdxList[0], ' length = ', len(corruptedIdxList)

    for cIdx in corruptedIdxList:
        #print "DEBUG: convert index = ", cIdx
        correctLabel = label_data_set[cIdx]
        #print 'DEBUG: Correct label = ', correctLabel
        wrongLabel = convertCorrectLabelToCorruptedLabel(correctLabel)
        #print 'DEBUG: Wrong label = ', wrongLabel
        label_data_set[cIdx] = wrongLabel


# uniform randomly select K integers from range [0,N-1]
def randomSelectKFromN(K, N):
    #print 'DEBUG: K = ',K, ' N = ', N
    resultList =[]
    seqList = range(N)
    while (len(resultList) < K):
        index = (int)(np.random.rand(1)[0] * len(seqList))
        #index = 0 # for DEBUG ONLY
        resultList.append(seqList[index])
        seqList.remove(seqList[index])
    #print resultList
    return resultList

# Convert correct ont-hot vector label to a wrong label, the error pattern is randomly selected, i.e. not considering the content of image
def convertCorrectLabelToCorruptedLabel(correctLabel):
    correct_value = np.argmax(correctLabel, 0)
    target_value = int(np.random.rand(1)[0]*10)%10
    if target_value == correct_value:
        target_value = ((target_value+1) % 10)
    result = np.zeros(correctLabel.shape)
    result[target_value] = 1.0 
    return result


這樣,當給定噪聲水平之后,上述算法完成添加噪聲,進一步用帶噪聲的訓練集訓練出模型,最終在驗證集上對模型評價精度。下圖是噪聲標簽比例在0-100%范圍內變化時,模型精度的變化。



從上圖可以看出,在噪聲標簽占比不超過60%的情況下,驗證集精度保持在96%以上,即便噪聲標簽占比達到70%,驗證集精度仍然能達到93%。在噪聲標簽占比超過70%之后,精度結果快速下降,當噪聲占比達到88%時,預測精度已經下降到7%。這個水平已經低於純隨機預測,考慮到此問題為10分類問題,在完全隨機的情況下,預期精度的數學期望也在10%左右。

這里就引出了兩個問題:

1. 為什么在噪音標簽占比70%的情況下,模型抗噪聲性能這么好?

2. 70%之后的快速下降又是由什么導致的

為了回答上述問題,要重新審視此前加噪聲標簽的方法。在加噪聲的第一步,我們均勻的隨機抽取出一定比例的標簽,考慮到原始數據10類標簽的分布是基本均勻的,那么抽出來的K個樣本中10類標簽的數量基本一致。在第二步,對於每個標簽,我們將正確標簽抹去,從正確標簽之外的9個字符中選擇一個作為標簽,由於選擇算法本身也是隨機的,那么,錯誤標簽是均勻分布在其他9類的。

上述解釋如果還不夠直觀,那么可以看下圖。假設有1000條正確標簽為2的數據,在70%的噪聲條件下,只有300條數據標簽為2,其余700條數據的標簽均勻分布在其他9類。這樣,正確標簽(300條‘2’標簽)相比其他任何一個類別,仍然占有明顯數量優勢,所以CNN才可以根據這個數量優勢學習到正確標簽2.

而當噪聲比例進一步增加后,數量對比優勢會逐漸弱化,例如下圖。這種情況下正確標簽雖然占比仍然多於其他分類,但是數量上已經只有2倍的差異。在模型訓練中,正確標簽帶來的梯度下降增益不足以對抗錯誤標簽的影響,神經網絡傾向於學習到隨機標簽。

由上述兩張圖可以看出,如果在多分類問題中噪音標簽是均勻分布的,同時正確標簽相對於每個類別的錯誤標簽有數倍的數量優勢,那么訓練過程有可能承受較高的噪聲標簽水平得到相對精確的模型。但如果噪音標簽已經與正確標簽數量接近,那么很難訓練出有意義的模型。


如何快速識別出疑似噪聲的標簽

在真實應用中,我們顯然不會人工在訓練數據集上添加噪聲。但如前文所述,訓練數據集本身是含有噪聲的,除了人工逐個審查,有沒有辦法快速找出疑似是噪聲的標簽呢?

為了解決這個問題,我們回到基於CNN網絡的MNIST分類器最后一層來看。在分類器的最后一層,全連接網絡包含10個神經元,輸出10個運算結果,可以看作一個10維向量。這個10維向量經過softmax運算可以轉為離散概率分布,其和為1,每個維度代表分類器預測當前圖片屬於某一類的概率。最終的預測結果就是取離散概率分布中概率值最高的一類作為預測結果。

在實驗中觀察不同樣本的概率分布,可以看到有以下兩種情況

  • 當一張圖片清晰且無歧義時,神經網絡輸出的離散概率分布是集中在一個標簽的,例如正確標簽概率為0.999,其余9種類別的概率接近於0.
  • 當一張圖片存在歧義時,神經網絡輸出的離散概率分布就不會只集中在一個標簽,有可能最強的標簽概率只有0.6,第二強的標簽概率0.39,其余8個類別概率為0 這樣的結果意味着神經網絡認為這張標簽有二義性。

基於這個認識,就可以設計出一種方法,讓神經網絡把自己認為存在二義性的樣本和標簽篩選出來,即實現了非人工快速找出疑似噪音標簽。

下面是二義性判斷的代碼實現,二義性在這里定量的定義為:分類器認為最有可能類別的概率低於70%,同時第二可能類別概率高於15%。下列代碼是挑選二義性概率分布的實現,是simpleCNN.py的一部分,testcase3.py提供了篩選二義性樣本的可執行程序

   # Filter out images with low SNR.
    # The term 'low SNR' is defined as: in the probability distribution of this sample, the largest value is <= 0.7, while the 2nd largest value >= 0.15
    # the raw images data (in shape of 1*784 vector), labels, and top 2 possibilities by CNN will be returned
    # Parameter:
    # train_or_test, 0 means train data, 1 means test data
    def filterLowSNRSamples(self, mnist, train_or_test=0):
        if train_or_test == 1:
            data = mnist.test
        else:
            data = mnist.train

        resultList = []

        for sample_idx in range(data.images.shape[0]):
            prob_dist, label=self.sess.run([self.output_prob_distribution, self.label], feed_dict={
                    self.x: np.reshape(data.images[sample_idx], (1, 784)), self.y_: np.reshape(data.labels[sample_idx], (1,10)), self.keep_prob:1.0})

            raw_prob_array = prob_dist[0]

            #search for position of the largest value and the 2nd largest value
            top_1_pos , top_2_pos = findPosOfLargestTwoElement(raw_prob_array, 10)

            #Low SNR criteria
            if raw_prob_array[top_1_pos] <= 0.7 and raw_prob_array[top_2_pos] >= 0.15:
                resultList.append((sample_idx, data.images[sample_idx], label, top_1_pos, top_2_pos))

            if (sample_idx % 1000 == 0):
                print "DEBUG, current idx = %d, num_of_low_SNR = %d" % (sample_idx, len(resultList))
        return resultList


使用這套方法,在MNIST的55000個訓練數據和標簽中篩選出408個疑似有二義性的圖片,下圖是部分典型圖片。由此來看,MNIST本身的標簽質量是較高的。下圖中不少標簽人工識別也存在困難,這恰恰說明了找出的標簽很大程度上就是“疑似噪聲標簽”



小結

本文對於MNIST數據集,使用CNN分類器,考察了噪聲對模型預測精度的影響,實驗結果表明,在均勻分布的隨機噪聲條件下,CNN模型可以在噪聲標簽占比70%的情況下預測精度無明顯下降。進一步,為了識別原始訓練集中的疑似噪聲樣本,文中使用訓練好的CNN模型通過預測向量的概率分布,識別存在二義性的標簽,實現了低代價找出訓練集噪聲標簽的目的。


參考文獻

[1] tensorflow.org/get_star

[2] tensorflow.org/get_star

[3] TF MNIST code







免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM