損失函數Center Loss 代碼解析


center loss來自ECCV2016的一篇論文:A Discriminative Feature Learning Approach for Deep Face Recognition。 
論文鏈接:http://ydwen.github.io/papers/WenECCV16.pdf 
代碼鏈接:https://github.com/davidsandberg/facenet

理論解析請參看 https://blog.csdn.net/u014380165/article/details/76946339

下面給出centerloss的計算公式以及更新公式

 

 

下面的代碼是facenet作者利用tensorflow實現的centerloss代碼

def center_loss(features, label, alfa, nrof_classes):
    """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"
       (http://ydwen.github.io/papers/WenECCV16.pdf)
       https://blog.csdn.net/u014380165/article/details/76946339
    """
    nrof_features = features.get_shape()[1]
  #訓練過程中,需要保存當前所有類中心的全連接預測特征centers, 每個batch的計算都要先讀取已經保存的centers centers
= tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False) label = tf.reshape(label, [-1]) centers_batch = tf.gather(centers, label)#獲取當前batch對應的類中心特征 diff = (1 - alfa) * (centers_batch - features)#計算當前的類中心與特征的差異,用於Cj的的梯度更新,這里facenet的作者做了一個 1-alfa操作,比較奇怪,和原論文不同 centers = tf.scatter_sub(centers, label, diff)#更新梯度Cj,對於上圖中步驟6,tensorflow會將該變量centers保留下來,用於計算下一個batch的centerloss loss = tf.reduce_mean(tf.square(features - centers_batch))#計算當前的centerloss 對應於Lc return loss, centers

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM