Understanding and Implementing the Island Loss Function


#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2020/02/04 20:08
# @Author  : dangxusheng
# @Email   : dangxusheng163@163.com
# @File    : isLand_loss.py
'''
Island loss aims to reduce intra-class variation while enlarging inter-class differences;
the goal is to further optimise the inter-class distances on top of center loss.
https://blog.csdn.net/heruili/article/details/88912074
L_island = L_center + lamda1 * penalty
Loss = L_softmax + lamda * L_island
'''
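# Written out (my reading of the formulas above together with the code below; treat the exact
# scaling constants as an assumption rather than the paper's authoritative form):
#   L_center = 1/2 * sum_i || x_i - c_{y_i} ||^2          (scaled by scale/batch_size in the code)
#   penalty  = sum_{j<k} ( c_j . c_k / (||c_j|| * ||c_k||) + 1 )
#   L_island = L_center + lamda1 * penalty
#   Loss     = L_softmax + lamda * L_island                (this module returns lamda * L_island)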

# from myToolsPkgs.pytorch_helper import *  # project-local helper; the explicit imports below are sufficient
import torch
import torch.nn as nn
from torch.autograd import Function


class IslandLoss(nn.Module):
    """
    paper: https://arxiv.org/pdf/1710.03144.pdf
    url: https://blog.csdn.net/u013841196/article/details/89920441
    """

    def __init__(self, features_dim, num_class=10, lamda=1., lamda1=10., scale=1.0, batch_size=64):
        """
        Initialisation
        :param features_dim: feature dimension = c*h*w
        :param num_class: number of classes
        :param lamda:   weight of the island loss term
        :param lamda1:  weight of the pairwise feature-center distance penalty inside the island loss
        :param scale:   scaling factor for the feature-center gradients
        :param batch_size:   batch size
        """
        super(IslandLoss, self).__init__()
        self.lamda = lamda
        self.lamda1 = lamda1
        self.num_class = num_class
        self.scale = scale
        self.batch_size = batch_size
        self.feat_dim = features_dim
        # stores the center of each class; shape should be (num_class, features_dim)
        self.feature_centers = nn.Parameter(torch.randn([num_class, features_dim]))

        # self.lossfunc = IslandLossFunc.apply

    def forward(self, output_features, y_truth):
        """
        Loss computation
        :param output_features: features output by the conv layer, [b,c,h,w]
        :param y_truth:  ground-truth labels, [b,]
        :return:
        """
        batch_size = y_truth.size(0)
        num_class = self.num_class
        output_features = output_features.view(batch_size, -1)
        assert output_features.size(-1) == self.feat_dim

        factor = self.scale / batch_size
        # # # Option 1: use a hand-written backward (relies on IslandLossFunc and self.alpha, which are not defined in this file)
        # return self.lossfunc(output_features, y_truth, self.feature_centers,
        #                      torch.Tensor([self.alpha, self.lamda, self.lamda1, self.scale]))

        # Option 2: rely on PyTorch's default autograd
        centers_batch = self.feature_centers.index_select(0, y_truth.long())  # [b,features_dim]
        diff = output_features - centers_batch
        # 1 first compute the center loss
        loss_center = 1 / 2.0 * (diff.pow(2).sum()) * factor
        # 2 then compute the pairwise cosine distance between class centers
        # For every pair of centers take the cosine similarity and add 1, so the value lies in [0, 2];
        # values closer to 0 mean larger inter-class differences, so minimising the loss pushes the centers apart.
        centers = self.feature_centers
        # norm of each center vector, ||Ci||
        centers_mod = torch.sum(centers * centers, dim=1, keepdim=True).sqrt()  # [num_class, 1]
        #  ====================== method 1 =======================
        item1_sum = 0
        for j in range(num_class):
            dis_sum_j_others = 0
            for k in range(j + 1, num_class):  # only pairs with k > j, so each pair is counted once
                dot_kj = torch.sum(centers[j] * centers[k])
                fenmu = centers_mod[j] * centers_mod[k] + 1e-9  # denominator ||Cj||*||Ck||, eps for numerical safety
                cos_dis = dot_kj / fenmu
                dis_sum_j_others += cos_dis + 1.
                # print(dis_sum_j_others)
            item1_sum += dis_sum_j_others
        loss_island = self.lamda * (loss_center + self.lamda1 * item1_sum)
        # ====================== method 2 =======================
        # # Ci X Ci.T
        # centers_mm = torch.matmul(centers,centers.t())  # [num_class, num_class]
        # centers_mod_mm = centers_mod.mm(centers_mod.t())  # [num_class,num_class]
        # # compute the cosine-distance matrix; it is symmetric
        # centers_cos_dis = centers_mm / centers_mod_mm
        # centers_cos_dis += 1.
        # # keep only the upper triangle: a class's distance to itself is not considered, and each pair is counted once
        # centers_cos_dis_1 = torch.triu(centers_cos_dis,diagonal=1)
        # print(centers_cos_dis_1)
        # sum_centers_cos_dis = torch.sum(centers_cos_dis_1)
        # loss_island = self.lamda * (loss_center + self.lamda1 * sum_centers_cos_dis)
        return loss_island
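
# --- Optional sketch, not part of the original code: "method 2" above written as a standalone
# --- helper so the loop-based and vectorised pairwise penalties can be compared. The name
# --- pairwise_cosine_penalty is illustrative; `centers` is assumed to have shape
# --- [num_class, features_dim], matching self.feature_centers.
def pairwise_cosine_penalty(centers, eps=1e-9):
    """Sum of (cos(c_j, c_k) + 1) over all unordered pairs j < k."""
    mod = torch.sum(centers * centers, dim=1, keepdim=True).sqrt()       # [num_class, 1], ||Ci||
    cos = torch.matmul(centers, centers.t()) / (mod.mm(mod.t()) + eps)   # [num_class, num_class]
    # keep only the strict upper triangle: drops self-distances and counts each pair once
    return torch.triu(cos + 1., diagonal=1).sum()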


torch.manual_seed(1000)
if __name__ == '__main__':
    import random

    # test 1
    num_class = 10
    batch_size = 10
    feat_dim = 2
    ct = IslandLoss(feat_dim, num_class, 0.1, 1., 1., batch_size)
    y = torch.Tensor([random.choice(range(num_class)) for i in range(batch_size)])
    feat = torch.randn(batch_size, feat_dim).requires_grad_()  # simulated batch of features
    print(feat)
    out = ct(feat, y)
    out.backward()
    print(ct.feature_centers.grad)
    print(feat.grad)
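
To see how this plugs into training, below is a minimal sketch of the combined objective Loss = L_softmax + lamda * L_island from the header docstring. The toy feature extractor, input shape, and optimizer settings are illustrative assumptions and not part of the original code; only IslandLoss itself comes from above.

import torch
import torch.nn as nn

feat_dim, num_class = 2, 10
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, feat_dim))   # toy feature extractor (assumption)
classifier = nn.Linear(feat_dim, num_class)                          # softmax head
island = IslandLoss(feat_dim, num_class, lamda=0.1, lamda1=1., scale=1.0)
ce = nn.CrossEntropyLoss()

# feature_centers is an nn.Parameter, so the optimizer updates the class centers as well
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(classifier.parameters()) + list(island.parameters()),
    lr=0.01)

x = torch.randn(64, 1, 28, 28)                # dummy batch (assumption)
y = torch.randint(0, num_class, (64,))

features = model(x)                            # [b, feat_dim]
logits = classifier(features)                  # [b, num_class]
loss = ce(logits, y) + island(features, y)     # L_softmax + lamda * L_island (lamda is applied inside forward)
optimizer.zero_grad()
loss.backward()
optimizer.step()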

 

