Siamese網絡


1.       對比損失函數(Contrastive Loss function)

孿生架構的目的不是對輸入圖像進行分類,而是區分它們。因此,分類損失函數(如交叉熵)不是最合適的選擇,這種架構更適合使用對比函數。對比損失函數如下:

 

 

(以判斷圖片相似度為例)其中Dw被定義為姐妹孿生網絡的輸出之間的歐氏距離。Y值為1或0。如果模型預測輸入是相似的,那么Y的值為0,否則Y為1。m是大於0的邊際價值(margin value)。有一個邊際價值表示超出該邊際價值的不同對不會造成損失。

Siamese網絡架構需要一個輸入對,以及標簽(類似/不相似)。

2.       孿生網絡的訓練過程

(1)    通過網絡傳遞圖像對的第一張圖像。

(2)    通過網絡傳遞圖像對的第二張圖像。

(3)    使用(1)和(2)中的輸出來計算損失。

 

 

  其中,l12為標簽,用於表示x1的排名是否高於x2。

訓練過程中兩個分支網絡的輸出為高級特征,可以視為quality score。在訓練時,輸入是兩張圖像,分別得到對應的分數,將分數的差異嵌入loss層,再進行反向傳播。

(4)    返回傳播損失計算梯度。

(5)    使用優化器更新權重。

3.       基於Siamese網絡的無參考圖像質量評估:RankIQA

3.1          參考文獻

https://arxiv.org/abs/1707.08347

3.2          RankIQA的流程

(1)    合成失真圖像。

(2)    訓練Siamese網絡,使網絡輸出對圖像質量的排序。

(3)    提取Siamese網絡的一支,使用IQA數據集進行fine-tune。將網絡的輸出校正為IQA度量值。fine-tune階段的損失函數如下:

 

 

訓練階段使用Hinge loss,fine-tune階段使用MSE。

訓練時,每次從圖像中隨機提取224*224或者227*227大小的圖像塊。和AlexNet、VGG16有關。在訓練Siamese network時,初始的learning rate 是1e-4;fine-tuning是初始的learning rate是1e-6,每隔10k步,rate變成原來的0.1倍。訓練50k次。測試時,隨機在圖像中提取30個圖像塊,得到30個分數之后,進行均值操作。

本文如何提高Siamese網絡的效率:

假設有三張圖片,則每張圖片將被輸入網絡兩次,原因是含有某張圖片的排列數為2。為了減少計算量,每張圖片只輸入網絡一次,在網絡之后、損失函數之前新建一層,用於生成每個mini-batch中圖片的可能排列組合。

使用上述方法,每張圖片只前向傳播一次,只在loss計算時考慮所有的圖片組合方式。

本文使用的網絡架構:Shallow, AlexNet, and VGG-16。

4.       Siamese網絡的開源實現

4.1          代碼地址

https://github.com/xialeiliu/RankIQA

4.2          RankIQA的運行過程

4.2.1            數據集

使用兩方面的數據集,一般性的非IQA數據集用於生成排序好的圖片,進而訓練Siamese網絡;IQA數據集用於微調和評估。

本文使用的IQA數據集:

(1)    LIVE數據集:http://live.ece.utexas.edu/research/quality/ 對29張原始圖片進行五類失真處理,得到808張圖片。Ground Truth MOS在[0, 100]之間(人工評分)。

(2)    TID2013:25張原始圖片,3000張失真圖片。MOS范圍是[0, 9]。

本文使用的用於生成ranked pairs的數據集:

(1)    為了測試LIVE數據集,人工生成了四類失真,GB(Gaussian Blur)、GN(Gaussian Noise)、JPEG、JPEG2K

(2)    為了在TID2013上測試,生成了17種失真(去掉了#3, #4,#12, #13, #20, #21, #24)

Waterloo數據集:

包含4744張高質量自然圖片。

Places2數據集:

作為驗證集(包含356種場景,http://places2.csail.mit.edu/ ),每類100張,共35600張。

兩種數據集的區別:

python generate_rank_txt_tid2013.py生成的是tid2013_train.txt,標簽只起到表示相對順序的作用,即,標簽為{1, 2, 3, 4, 5};python generate_ft_txt_tid2013.py生成的是ft_tid2013_test.txt,其中的標簽是浮點數,表示圖片的質量評分。

 

4.2.2            訓練和測試過程

從原始圖像中隨機采樣子圖(sub-images),避免因差值和過濾而產生的失真。輸入的子圖至少占原圖的1/3,以保留場景信息。本文采用227*227或者224*224的采樣圖像(根據使用的主干網絡而不同)。

訓練過程使用mini-batch SGD,初始學習率1e-4,fine-tune學習率1e-6。

共迭代50K次,每10K次減小學習率(乘以0.1),兩個訓練過程都是用l2權重衰減(正則化系數lambda=5e-4)。

實驗一:本文首先使用Places2數據集(使用五種失真進行處理)訓練網絡(不進行微調),然后在Waterloo數據及上進行預測IQA(使用同樣的五種失真進行處理)。實驗結果如圖2所示。

 

 

實驗二:hard negative mining

難分樣本挖掘,是當得到錯誤的檢測patch時,會明確的從這個patch中創建一個負樣本,並把這個負樣本添加到訓練集中去。重新訓練分類器后,分類器會表現的更好,並且不會像之前那樣產生多的錯誤的正樣本。

本實驗使用Alexnet進行。

實驗三:網絡性能分析

LIVE數據集,80%訓練集,評價指標LCC和SROCC。VGG-16的效果最好。

4.2.3            RankIQA對數據集的處理過程

將原始圖像文件放在data/rank_tid2013/pristine_images路徑下,然后運行data/rank_tid2013/路徑下的tid2013_main.m,進而生成排序數據集(17種失真形式)。

4.3          運行指令

4.3.1            Train RankIQA

To train the RankIQA models on tid2013 dataset:

./src/RankIQA/tid2013/train_vgg.sh

 

To train the RankIQA models on LIVE dataset:

./src/RankIQA/live/train_vgg.sh

 

FT

To train the RankIQA+FT models on tid2013 dataset:

./src/FT/tid2013/train_vgg.sh

 

To train the RankIQA+FT models on LIVE dataset:

./src/FT/live/train_live.sh

 

4.3.2            Evaluation for RankIQA

python src/eval/Rank_eval_each_tid2013.py  # evaluation for each distortions in tid2013

python src/eval/Rank_eval_all_tid2013.py   # evaluation for all distortions in tid2013

Evaluation for RankIQA+FT on tid2013:

python src/eval/FT_eval_each_tid2013.py  # evaluation for each distortions in tid2013

python src/eval/FT_eval_all_tid2013.py   # evaluation for all distortions in tid2013

Evaluation for RankIQA on LIVE:

python src/eval/Rank_eval_all_live.py   # evaluation for all distortions in LIVE

Evaluation for RankIQA+FT on LIVE:

python src/eval/FT_eval_all_live.py   # evaluation for all distortions in LIVE

5.       代碼調試過程

5.1          Python無法導入某個模塊ImportError:could not find module XXX

解決方案:

配置環境變量:export PYTHONPATH=path/to/modules

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM