Deep Learning (PyTorch) - 3. A Detailed Walkthrough of sphereface-pytorch's lfw_eval.py


The original author's PyTorch implementation of SphereFace: https://github.com/clcarwin/sphereface_pytorch

Being fairly new to deep learning, I spent quite a while reading the source code. Below is a detailed explanation of the project's lfw_eval.py file.

(Whether because of version differences or mistakes in the author's code, the original script contains a number of bugs that have to be fixed one by one. Also, since this was run on Windows, GPU acceleration and multi-threaded data loading were removed.)

#-*- coding:utf-8 -*-
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
torch.backends.cudnn.benchmark = True

import os,sys,cv2,random,datetime
import argparse
import numpy as np
import zipfile

from dataset import ImageDataset
from matlab_cp2tform import get_similarity_transform_for_cv2
import net_sphere
from matplotlib import pyplot as plt

# Face alignment and cropping
def alignment(src_img,src_pts):
    # Warp the image so that its five landmarks line up with a set of canonical face coordinates
    ref_pts = [ [30.2946, 51.6963],[65.5318, 51.5014],
        [48.0252, 71.7366],[33.5493, 92.3655],[62.7299, 92.2041] ]
    crop_size = (96, 112)
    src_pts = np.array(src_pts).reshape(5,2)

    s = np.array(src_pts).astype(np.float32)
    r = np.array(ref_pts).astype(np.float32)

    tfm = get_similarity_transform_for_cv2(s, r)
    face_img = cv2.warpAffine(src_img, tfm, crop_size)
    return face_img

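# ****** Example (not in the original script) ******
# alignment() takes the BGR image plus the 10 landmark values (x1,y1,...,x5,y5) read
# from lfw_landmark.txt; the numbers below are made up purely for illustration.
# aligned = alignment(org_img, [105, 110, 145, 108, 126, 130, 110, 155, 143, 153])
# print(aligned.shape)   # -> (112, 96, 3): a 96x112 crop with 3 colour channels
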
# k-fold cross validation
# Split the n samples into n_folds folds; each fold in turn serves as the test set
# and the remaining samples as the training set
def KFold(n=200, n_folds=10, shuffle=False):
    folds = []
    base = list(range(n))
    for i in range(n_folds):
        test = base[(i*n//n_folds):((i+1)*n//n_folds)]
        train = list(set(base)-set(test))
        folds.append([train,test])
    return folds

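# ****** Example (not in the original script) ******
# KFold(n=6, n_folds=3) produces three [train, test] splits whose test parts are
# [0, 1], [2, 3] and [4, 5]; the train part of each split holds the remaining four
# indices (their order may vary because it is built from a set difference).
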
# Compute the accuracy at a given threshold
def eval_acc(threshold, diff):
    y_true = []
    y_predict = []
    for d in diff:
        same = 1 if float(d[2]) > threshold else 0
        y_predict.append(same)
        y_true.append(int(d[3]))
    y_true = np.array(y_true)
    y_predict = np.array(y_predict)
    accuracy = 1.0*np.count_nonzero(y_true==y_predict)/len(y_true)
    return accuracy

# eval_acc and find_best_threshold work together to search for the threshold
# that gives the highest accuracy
def find_best_threshold(thresholds, predicts):
    # thresholds: the candidate threshold values to try
    best_threshold = best_acc = 0
    for threshold in thresholds:
        accuracy = eval_acc(threshold, predicts)
        if accuracy >= best_acc:
            best_acc = accuracy
            best_threshold = threshold
    return best_threshold


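# ****** Example (not in the original script) ******
# Each record in predicts has the form [name1, name2, cosine_score, same_flag].
# With a toy list such as
#   toy = [['a', 'b', '0.80', '1'], ['c', 'd', '0.20', '0'], ['e', 'f', '0.55', '1']]
# find_best_threshold(np.arange(-1.0, 1.0, 0.005), toy) scans every candidate
# threshold and keeps the one whose eval_acc is highest; any threshold t with
# 0.20 <= t < 0.55 classifies all three toy pairs correctly (accuracy 1.0).
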
# Command-line arguments
parser = argparse.ArgumentParser(description='PyTorch sphereface lfw')
parser.add_argument('--net','-n', default='sphere20a', type=str)
parser.add_argument('--lfw', default='../DataSet/lfw.zip', type=str)
parser.add_argument('--model','-m', default='./sphere20a_20171020.pth', type=str)
args = parser.parse_args()

predicts=[]

# Build the network
net = getattr(net_sphere,args.net)()
# Load the pretrained weights
net.load_state_dict(torch.load(args.model))
# Switch to evaluation mode
net.eval()
# Have the network return the face feature vector instead of classification scores
net.feature = True

# Open the zip archive containing the LFW images
zfile = zipfile.ZipFile(args.lfw)

# Load the landmarks: each image has five facial key points, i.e. five coordinate pairs
landmark = {}
with open('data/lfw_landmark.txt') as f:
    landmark_lines = f.readlines()
# Process each line
for line in landmark_lines:
    l = line.replace('\n','').split('\t')
    # Store each record in a dictionary keyed by the image name
    landmark[l[0]] = [int(k) for k in l[1:]]

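# ****** Example (not in the original script) ******
# Each line of lfw_landmark.txt is an image path followed by tab-separated integers,
# the five (x, y) landmark positions. With a made-up line such as
#   Woody_Allen/Woody_Allen_0002.jpg  106 113 147 112 125 135 110 156 144 155
# the loop above yields
#   landmark['Woody_Allen/Woody_Allen_0002.jpg'] == [106, 113, 147, 112, 125, 135, 110, 156, 144, 155]
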
# Load the pairs file
with open('data/pairs.txt') as f:
    pairs_lines = f.readlines()[1:]

# The range controls how many image pairs are tested
for i in range(600):
    print(str(i)+" start")
    p = pairs_lines[i].replace('\n','').split('\t')
    # pairs.txt has 6000 data lines in two formats; each line names the two images
    # to compare. Format 1 means the same person, format 2 means two different people:
    # name number1 number2
    # name1 number1 name2 number2
    if 3==len(p):
        sameflag = 1
        # the resulting path looks like: Woody_Allen/Woody_Allen_0002.jpg
        name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
        name2 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[2]))
    if 4==len(p):
        sameflag = 0
        name1 = p[0]+'/'+p[0]+'_'+'{:04}.jpg'.format(int(p[1]))
        name2 = p[2]+'/'+p[2]+'_'+'{:04}.jpg'.format(int(p[3]))

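    # ****** Example (not in the original script) ******
    # A 3-field line such as "Woody_Allen  1  2" (tab-separated) parses to
    #   name1 = 'Woody_Allen/Woody_Allen_0001.jpg'
    #   name2 = 'Woody_Allen/Woody_Allen_0002.jpg', sameflag = 1
    # while a 4-field line such as "Woody_Allen  1  Al_Gore  3" parses to
    #   name1 = 'Woody_Allen/Woody_Allen_0001.jpg'
    #   name2 = 'Al_Gore/Al_Gore_0003.jpg', sameflag = 0
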
    # Load the two images and align each of them
    org_img1=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name1),np.uint8),1)
    org_img2=cv2.imdecode(np.frombuffer(zfile.read("lfw/lfw/"+name2),np.uint8),1)
    img1 = alignment(org_img1,landmark[name1])
    img2 = alignment(org_img2,landmark[name2])
    # 1. Display the images with cv2
    # cv2.imshow("org_img1", org_img1)
    # cv2.imshow("org_img2", org_img2)
    # cv2.imshow("img1",img1)
    # cv2.imshow("img2", img2)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
    # 2. Display the images with matplotlib
    fig_new=plt.figure()
    img_list=[[org_img1,221],[org_img2,222],[img1,223],[img2,224]]
    for p,q in img_list:
        ax=fig_new.add_subplot(q)
        # reorder channels from OpenCV's BGR to matplotlib's RGB
        p = p[:, :, (2, 1, 0)]
        ax.imshow(p)
    plt.show()

    # cv2.flip mirrors the image; second argument: 1 = horizontal, 0 = vertical, -1 = both
    imglist = [img1,cv2.flip(img1,1),img2,cv2.flip(img2,1)]
    # Convert each image to a 1x3x112x96 layout and normalize it
    for m in range(len(imglist)):
        imglist[m] = imglist[m].transpose(2, 0, 1).reshape((1,3,112,96))
        imglist[m] = (imglist[m]-127.5)/128.0

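    # ****** Example (not in the original script) ******
    # After this loop, each aligned 112x96 BGR image (uint8, values 0..255) has become
    # a float array of shape (1, 3, 112, 96) with values in roughly [-1, 1]:
    # (0 - 127.5) / 128.0 = -0.996...,  (255 - 127.5) / 128.0 = 0.996...
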
    # np.vstack: stack arrays vertically (row by row)
    # ****** Example ******
    # import numpy as np
    # a = [1, 2, 3]
    # b = [4, 5, 6]
    # print(np.vstack((a, b)))
    #
    # Output:
    # [[1 2 3]
    #  [4 5 6]]
    img = np.vstack(imglist)
    # Convert the numpy array into a Variable (no gradients needed at test time)
    img = Variable(torch.from_numpy(img).float(),volatile=True)
    output = net(img)
    # Read out the results; f1 and f2 are the 512-dimensional feature vectors of
    # img1 and img2 (indices 1 and 3 hold the features of the flipped copies)
    f = output.data
    f1,f2 = f[0],f[2]
    # Compute their cosine similarity; the small constant keeps the denominator from being 0
    # For background on cosine similarity, a short explanation is available at:
    # http://blog.csdn.net/huangfei711/article/details/78469614
    # a·b / (|a|·|b|)
    cosdistance = f1.dot(f2)/(f1.norm()*f2.norm()+1e-5)
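    # ****** Example (not in the original script) ******
    # Cosine similarity on two tiny made-up vectors:
    # u = torch.Tensor([1.0, 0.0]); v = torch.Tensor([1.0, 1.0])
    # u.dot(v) / (u.norm() * v.norm()) = 1 / (1 * 1.4142...) ≈ 0.707
    # Identical directions give 1, orthogonal vectors give 0, opposite directions give -1,
    # which is why the thresholds searched below range over [-1, 1].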
    predicts.append('{}\t{}\t{}\t{}\n'.format(name1,name2,cosdistance,sameflag))
    print(str(i) + " end")


# Per-fold accuracies
accuracy = []
# Per-fold (best) thresholds
thd = []
# k-fold cross validation
# folds has the form [[train,test],[train,test], ...]
folds = KFold(n=600, n_folds=10, shuffle=False)
# Candidate thresholds from -1 to 1 with a step of 0.005
thresholds = np.arange(-1.0, 1.0, 0.005)
# The original author's line below looks wrong, so it has been replaced with the list comprehension:
# predicts = np.array(map(lambda line:frd.append(line.strip('\n').split()), predicts))
predicts = np.array([k.strip('\n').split() for k in predicts])
for idx, (train, test) in enumerate(folds):
    # predicts[train/test] looks like:
    # [['Doris_Roberts/Doris_Roberts_0001.jpg'
    # 'Doris_Roberts/Doris_Roberts_0003.jpg' '0.6532696413605743' '1'], ...]
    # Find the best threshold on the training split
    best_thresh = find_best_threshold(thresholds, predicts[train])
    # Apply that threshold to the test split to get this fold's accuracy
    accuracy.append(eval_acc(best_thresh, predicts[test]))
    # Record this fold's threshold
    thd.append(best_thresh)
# np.mean: mean, np.std: standard deviation
# The reported numbers are: mean accuracy, accuracy standard deviation, mean threshold
print('LFWACC={:.4f} std={:.4f} thd={:.4f}'.format(np.mean(accuracy), np.std(accuracy), np.mean(thd)))
# For example, LFWACC=0.9800 std=0.0600 thd=0.3490 means the accuracy is 98%,
# its standard deviation is 0.06, and the mean threshold is 0.3490,
# so two images whose cosine similarity is above 0.3490 can be treated as the same person.
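
With the pretrained weights and the LFW archive downloaded, and data/pairs.txt plus data/lfw_landmark.txt in place, the script can then be run from the command line. The paths below are simply the argparse defaults shown above and will likely need adjusting to your own layout:

    python lfw_eval.py --net sphere20a --model ./sphere20a_20171020.pth --lfw ../DataSet/lfw.zip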

 

