1. 對比損失函數(Contrastive Loss function)
孿生架構的目的不是對輸入圖像進行分類,而是區分它們。因此,分類損失函數(如交叉熵)不是最合適的選擇,這種架構更適合使用對比函數。對比損失函數如下:
(以判斷圖片相似度為例)其中Dw被定義為姐妹孿生網絡的輸出之間的歐氏距離。Y值為1或0。如果模型預測輸入是相似的,那么Y的值為0,否則Y為1。m是大於0的邊際價值(margin value)。有一個邊際價值表示超出該邊際價值的不同對不會造成損失。
Siamese網絡架構需要一個輸入對,以及標簽(類似/不相似)。
2. 孿生網絡的訓練過程
(1) 通過網絡傳遞圖像對的第一張圖像。
(2) 通過網絡傳遞圖像對的第二張圖像。
(3) 使用(1)和(2)中的輸出來計算損失。
其中,l12為標簽,用於表示x1的排名是否高於x2。
訓練過程中兩個分支網絡的輸出為高級特征,可以視為quality score。在訓練時,輸入是兩張圖像,分別得到對應的分數,將分數的差異嵌入loss層,再進行反向傳播。
(4) 返回傳播損失計算梯度。
(5) 使用優化器更新權重。
3. 基於Siamese網絡的無參考圖像質量評估:RankIQA
3.1 參考文獻
https://arxiv.org/abs/1707.08347
3.2 RankIQA的流程
(1) 合成失真圖像。
(2) 訓練Siamese網絡,使網絡輸出對圖像質量的排序。
(3) 提取Siamese網絡的一支,使用IQA數據集進行fine-tune。將網絡的輸出校正為IQA度量值。fine-tune階段的損失函數如下:
訓練階段使用Hinge loss,fine-tune階段使用MSE。
訓練時,每次從圖像中隨機提取224*224或者227*227大小的圖像塊。和AlexNet、VGG16有關。在訓練Siamese network時,初始的learning rate 是1e-4;fine-tuning是初始的learning rate是1e-6,每隔10k步,rate變成原來的0.1倍。訓練50k次。測試時,隨機在圖像中提取30個圖像塊,得到30個分數之后,進行均值操作。
本文如何提高Siamese網絡的效率:
假設有三張圖片,則每張圖片將被輸入網絡兩次,原因是含有某張圖片的排列數為2。為了減少計算量,每張圖片只輸入網絡一次,在網絡之后、損失函數之前新建一層,用於生成每個mini-batch中圖片的可能排列組合。
使用上述方法,每張圖片只前向傳播一次,只在loss計算時考慮所有的圖片組合方式。
本文使用的網絡架構:Shallow, AlexNet, and VGG-16。
4. Siamese網絡的開源實現
4.1 代碼地址
https://github.com/xialeiliu/RankIQA
4.2 RankIQA的運行過程
4.2.1 數據集
使用兩方面的數據集,一般性的非IQA數據集用於生成排序好的圖片,進而訓練Siamese網絡;IQA數據集用於微調和評估。
本文使用的IQA數據集:
(1) LIVE數據集:http://live.ece.utexas.edu/research/quality/ 對29張原始圖片進行五類失真處理,得到808張圖片。Ground Truth MOS在[0, 100]之間(人工評分)。
(2) TID2013:25張原始圖片,3000張失真圖片。MOS范圍是[0, 9]。
本文使用的用於生成ranked pairs的數據集:
(1) 為了測試LIVE數據集,人工生成了四類失真,GB(Gaussian Blur)、GN(Gaussian Noise)、JPEG、JPEG2K
(2) 為了在TID2013上測試,生成了17種失真(去掉了#3, #4,#12, #13, #20, #21, #24)
Waterloo數據集:
包含4744張高質量自然圖片。
Places2數據集:
作為驗證集(包含356種場景,http://places2.csail.mit.edu/ ),每類100張,共35600張。
兩種數據集的區別:
python generate_rank_txt_tid2013.py生成的是tid2013_train.txt,標簽只起到表示相對順序的作用,即,標簽為{1, 2, 3, 4, 5};python generate_ft_txt_tid2013.py生成的是ft_tid2013_test.txt,其中的標簽是浮點數,表示圖片的質量評分。
4.2.2 訓練和測試過程
從原始圖像中隨機采樣子圖(sub-images),避免因差值和過濾而產生的失真。輸入的子圖至少占原圖的1/3,以保留場景信息。本文采用227*227或者224*224的采樣圖像(根據使用的主干網絡而不同)。
訓練過程使用mini-batch SGD,初始學習率1e-4,fine-tune學習率1e-6。
共迭代50K次,每10K次減小學習率(乘以0.1),兩個訓練過程都是用l2權重衰減(正則化系數lambda=5e-4)。
實驗一:本文首先使用Places2數據集(使用五種失真進行處理)訓練網絡(不進行微調),然后在Waterloo數據及上進行預測IQA(使用同樣的五種失真進行處理)。實驗結果如圖2所示。
實驗二:hard negative mining
難分樣本挖掘,是當得到錯誤的檢測patch時,會明確的從這個patch中創建一個負樣本,並把這個負樣本添加到訓練集中去。重新訓練分類器后,分類器會表現的更好,並且不會像之前那樣產生多的錯誤的正樣本。
本實驗使用Alexnet進行。
實驗三:網絡性能分析
LIVE數據集,80%訓練集,評價指標LCC和SROCC。VGG-16的效果最好。
4.2.3 RankIQA對數據集的處理過程
將原始圖像文件放在data/rank_tid2013/pristine_images路徑下,然后運行data/rank_tid2013/路徑下的tid2013_main.m,進而生成排序數據集(17種失真形式)。
4.3 運行指令
4.3.1 Train RankIQA
To train the RankIQA models on tid2013 dataset:
./src/RankIQA/tid2013/train_vgg.sh
To train the RankIQA models on LIVE dataset:
./src/RankIQA/live/train_vgg.sh
FT
To train the RankIQA+FT models on tid2013 dataset:
./src/FT/tid2013/train_vgg.sh
To train the RankIQA+FT models on LIVE dataset:
./src/FT/live/train_live.sh
4.3.2 Evaluation for RankIQA
python src/eval/Rank_eval_each_tid2013.py # evaluation for each distortions in tid2013
python src/eval/Rank_eval_all_tid2013.py # evaluation for all distortions in tid2013
Evaluation for RankIQA+FT on tid2013:
python src/eval/FT_eval_each_tid2013.py # evaluation for each distortions in tid2013
python src/eval/FT_eval_all_tid2013.py # evaluation for all distortions in tid2013
Evaluation for RankIQA on LIVE:
python src/eval/Rank_eval_all_live.py # evaluation for all distortions in LIVE
Evaluation for RankIQA+FT on LIVE:
python src/eval/FT_eval_all_live.py # evaluation for all distortions in LIVE
5. 代碼調試過程
5.1 Python無法導入某個模塊ImportError:could not find module XXX
解決方案:
配置環境變量:export PYTHONPATH=path/to/modules