少樣本學習(Few-shot Learning)


1

與傳統的監督學習不同,few-shot leaning的目標是讓機器學會學習;使用一個大型的數據集訓練模型,訓練完成后,給出兩張圖片,讓模型分辨這兩張圖片是否屬於同一種事物。比如訓練數據集中有老虎、大象、汽車、鸚鵡等圖片樣本,訓練完畢后給模型輸入兩張兔子的圖片讓模型判斷是否是同一種事物,或者給模型兔子和狗的圖片去判斷。

2

訓練的目的是靠着Support Set提供的一點信息,讓模型判斷出Query中的圖片是otter這個類別,盡管訓練數據集中沒有otter這個類別。
k-way n-shot Support Set

k-way: the support set has k classes;

n-shot: every class has n samples.
k way表示支撐集中的類別,n shot表示支撐集中每個類別包含的樣本數量

3

隨着Support Set中類別增加,分類准確率會降低

因為3選1比6選1更容易,准確率更高;

同樣地,Support Set中shot數量增加,分類准確率會提高

4

idea:學習一個相似度函數

sim函數來計算兩張圖片x和x'的相似度,

例如兩張狗的圖片x1和x2,一張貓的圖片x3,sim(x1,x2)=1, sim(x1,x3)=0,sim(x2,x3)=0

基本思想:

​ (1)首先,從一個大樣本數據集中學習一個相似度函數

​ (2)然后,用相似度函數來做預測

​ ①用query和support set的每一個樣本逐一作比較;

​ ②找出相似度得分最高的樣本

5 常用的數據集

(1)Omniglot

https://github.com/brendenlake/omniglot or https://www.tensorflow.org/datasets/catalog/omniglot

(2)Mini-ImageNet

二 連體網絡Siamese Network

兩種訓練Siamese Network的方法

1 每次取兩個樣本,比較他們的相似度

需要用到一個大的帶標簽的數據集來訓練神經網絡,利用訓練集來構造正樣本Positive Samples和負樣本Negative Samples

Positive Samples:每次從一個類別中隨機抽取兩張圖片,把標簽設置為1,即相似度滿分,用這樣的方法,也從其他類別中抽取圖片,標簽都設置為1

Negative Samples:隨機抽取一個類中的一張圖片,排除掉這個類,再從其他類中隨機抽取一張圖片,把標簽設置為0,即相似度為0,這樣構造負樣本。

搭建一個卷積神經網絡來提取特征,輸入圖片記為x,輸出特征向量記作f(x)

訓練神經網絡,將准備好的圖片輸入神經網絡f,提取的兩個特征向量記作h1,h2,z = |h1-h2|,再通過一個全連接層輸出一個標量,最后使用sigmoid函數得到一個0~1之間的輸出,這個輸出就可以衡量兩個圖片之間的相似度,sim(x1,x2)。兩張圖片屬於同一個類別,那么輸出應該接近1,如果兩張圖片屬於不同類別,那么輸出應該接近0。損失函數是標簽Target=1與sim(x1,x2)之間的差別,用來更新全連接層和神經網絡f的參數(注意這里的圖片輸入的是同一個神經網絡)之所以叫做連體網絡,是這個網絡的結構頭部連在一起,如下圖所示

這樣就完成了一輪訓練

負樣本訓練過程與之類似,只是輸入時兩張不同類別的照片,標簽Target=0

訓練完成后就可以做one-shot prediction,Support Set中的六個類別都不在訓練集里,將Query與Support Set逐一對比,相似度最高的就是預測結果

2 Triplet Loss

每次從訓練集中選出3張圖片,在這3個圖片中選擇一個記為xa, anchor(錨點),選出同類別的另一張圖片,記作正樣本x+,選出其他類別中的一張圖片,記作負樣本x-

把三張圖片輸入卷積神經網絡f提取特征向量f(xa), f(x+), f(x-),計算f(xa), f(x+)之間的二范數距離d+,和f(xa),f(x-)之間的二范數距離d-,d+應該很小,d-應該很大;

設置超參數α為margin,如果d-很大,d->d++α,那么損失函數為0,因為很好的區分開了兩類圖片,反之,損失函數為d++α-d-


在預測時,把圖片都變為特征向量,計算query與他們之間的距離,找出距離最小的 **總結**

三 Pretraining + Fine Tuning

大規模數據上做pretraining,小樣本上fine tuning

1 Pretraining

神經網絡的結構

用3-way 2-shot的SUpport Set做few-shot分類,用與訓練的神經網絡提取特征,將每個類別提取的兩個特征向量求平均,歸一化得到 μ1,μ2,μ3

提取query的特征向量,歸一化得到q,將 μ1,μ2,μ3堆疊起來,得到矩陣M,M與q相乘通過softmax函數得到輸出p,顯然μ1與q的內積是最大的,所以會將query識別為第一類。

2.Fine Tuning

上一過程中,我們假定的W=M,b=0,其實我們可以在Support Set上學習W和b,計算Support Set所有的pj和真實標簽yj之間的CrossEntropy,並使之最小,加上Regularization防止過擬合。

Trick1: 在初始化分類器時,可以把W初始化為M,b=0

Trick2: Regularization使用Entropy Regularization
因為輸出的q是類別的概率值,左邊這種情況說明分類器無法判別query屬於哪一類,這種情況的entropy很高;我們希望的情況是右邊的這種情況,分類器認為query屬於第二類

Trick3: 在通過softmax函數時,將WTq變為cos(W,q),能夠提升准確率

總結

也可以在插入中間步驟Fine Tuning


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM