(原)InsightFace及其mxnet代碼


轉載請注明出處:

http://www.cnblogs.com/darkknightzh/p/8525287.html

論文

InsightFace : Additive Angular Margin Loss for Deep Face Recognition

https://arxiv.org/abs/1801.07698

官方mxnet代碼:

https://github.com/deepinsight/insightface

 

說明:沒用過mxnet,下面的代碼注釋只是純粹從代碼的角度來分析並進行注釋,如有錯誤之處,敬請諒解,並歡迎指出。

 

先查看sphereface,查看$\psi (\theta )$的介紹:http://www.cnblogs.com/darkknightzh/p/8524937.html

論文arcface中,定義$\psi (\theta )$為:

$\psi (\theta )=\cos ({{\theta }_{yi}}+m)$

同時對w及x均進行了歸一化,為了使得訓練能收斂,增加了一個參數s=64,最終loss如下:

$L=-\frac{1}{m}\sum\limits_{i=1}^{m}{\log \frac{{{e}^{s(\cos ({{\theta }_{yi}}+m))}}}{{{e}^{s(\cos ({{\theta }_{yi}}+m))}}+\sum\nolimits_{j=n,j\ne yi}^{n}{{{e}^{s\cos {{\theta }_{j}}}}}}}$

其中,

${{W}_{j}}=\frac{{{W}_{j}}}{\left\| {{W}_{j}} \right\|}$,${{x}_{i}}=\frac{{{x}_{i}}}{\left\| {{x}_{i}} \right\|}$,$\cos {{\theta }_{j}}=W_{j}^{T}{{x}_{i}}$

程序中先對w及x歸一化,然后通過全連接層得到cosθ,再擴大s倍,得到scosθ。

對於yi處,由於

$\cos (\theta +m)=\cos \theta \cos m-\sin \theta \sin m$

以及

$\sin \theta =\sqrt{1-{{\cos }^{2}}\theta }$

得到sinθ。

由於$\cos (\theta +m)$非單調,設置了easy_margin標志,當其為真時,使用0作為閾值,當特征和權重的cos值小於0,直接截斷;當其為假時,使用cos(pi-m)=-cos(m)作為閾值。該閾值小於0。

之后判斷時,當easy_margin為真時,若s*cos(θ+m)小於0,直接使用s*cos(θ);當easy_margin為假時,若s*cos(θ+m)小於0,使用s*cos(θ)-s*m*sin(m)。

具體的代碼如下(完整代碼見參考網址):

 1     s = args.margin_s  # 參數s
 2     m = args.margin_m  # 參數m
 3 
 4     _weight = mx.symbol.Variable("fc7_weight", shape=(args.num_classes, args.emb_size), lr_mult=1.0) # (C,F)
 5     _weight = mx.symbol.L2Normalization(_weight, mode='instance')   # 對w進行歸一化
 6     nembedding = mx.symbol.L2Normalization(embedding, mode='instance', name='fc1n')*s # 對x進行歸一化,並得到s*x,(B,F)
 7     fc7 = mx.sym.FullyConnected(data=nembedding, weight = _weight, no_bias = True, num_hidden=args.num_classes, name='fc7') # Y=XW'+b,(B,F)*(C,F)'=(B,C),'為轉置,此處得到scos(theta)
 8     
 9     zy = mx.sym.pick(fc7, gt_label, axis=1)  # 得到fc7中gt_label位置的值。(B,1)或者(B),即當前batch中yi處的scos(theta)
10     cos_t = zy/s  # 由於fc7及zy均為cos的s倍,此處除以s,得到實際的cos值。(B,1)或者(B)
11     
12     cos_m = math.cos(m)
13     sin_m = math.sin(m)
14     mm = math.sin(math.pi-m)*m # sin(pi-m)*m = sin(m)*m
15     threshold = math.cos(math.pi-m)  # 閾值,避免theta + m >= pi,實際上threshold < 0
16     if args.easy_margin:
17       cond = mx.symbol.Activation(data=cos_t, act_type='relu') #easy_margin=True,直接使用0作為閾值,得到超過閾值的索引
18     else:
19       cond_v = cos_t - threshold #easy_margin=False,使用threshold(負數)作為閾值。
20       cond = mx.symbol.Activation(data=cond_v, act_type='relu') # 得到超過閾值的索引
21     body = cos_t*cos_t  # 通過cos*cos + sin * sin = 1, 來得到sin_theta
22     body = 1.0-body
23     sin_t = mx.sym.sqrt(body)  # sin_theta
24     new_zy = cos_t*cos_m # cos(theta+m)=cos(theta)*cos(m)-sin(theta)*sin(m),此處為cos(theta)*cos(m)
25     b = sin_t*sin_m # 此處為sin(theta)*sin(m)
26     new_zy = new_zy - b # 此處為cos(theta)*cos(m)-sin(theta)*sin(m)=cos(theta+m)
27     new_zy = new_zy*s # 此處為s*cos(theta+m),擴充了s倍
28     if args.easy_margin:
29       zy_keep = zy   # zy_keep為zy,即s*cos(theta)
30     else:
31       zy_keep = zy - s*mm  # zy_keep為zy-s*sin(m)*m=s*cos(theta)-s*m*sin(m)
32     new_zy = mx.sym.where(cond, new_zy, zy_keep) # cond中>0的保持new_zy=s*cos(theta+m)不變,<0的裁剪為zy_keep= s*cos(theta) or s*cos(theta)-s*m*sin(m)
33 
34     diff = new_zy - zy # 
35     diff = mx.sym.expand_dims(diff, 1)
36     gt_one_hot = mx.sym.one_hot(gt_label, depth = args.num_classes, on_value = 1.0, off_value = 0.0)
37     body = mx.sym.broadcast_mul(gt_one_hot, diff) # 對應yi處為new_zy - zy
38     fc7 = fc7+body # 對應yi處,fc7=zy + (new_zy - zy) = new_zy,即cond中>0的為s*cos(theta+m),<0的裁剪為s*cos(theta) or s*cos(theta)-s*m*sin(m)

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM