Focal Loss 的Pytorch 實現以及實驗


Focal Loss 的Pytorch 實現以及實驗

Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出對簡單樣本的進行decay的一種損失函數。是對標准的Cross Entropy Loss 的一種改進。 F L對於簡單樣本(p比較大)回應較小的loss。

如論文中的圖1, 在p=0.6時, 標准的CE然后又較大的loss, 但是對於FL就有相對較小的loss回應。這樣就是對簡單樣本的一種decay。其中alpha 是對每個類別在訓練數據中的頻率有關, 但是下面的實現我們是基於alpha=1進行實驗的。

標准的Cross Entropy 為:

[公式]

Focal Loss 為:

[公式]

[公式]

其中 [公式]

以上公式為下面實現代碼的基礎。

 

采用基於pytorch 的yolo2 在VOC的上的實驗結果如下:

 

在單純的替換了CrossEntropyLoss之后就有1個點左右的提升。效果還是比較顯著的。本實驗中采用的是darknet19, 要是采用更大的網絡就可能會有更好的性能提升。這個實驗結果已經能很好的說明的Focal Loss 的對於檢測的價值了。

 

一點沒做的但是可能會提升性能:

1. 采用soft - gamma: 在訓練的過程中階段性的增大gamma 可能會有更好的性能提升

 

 

本文實驗中采用的Focal Loss 代碼如下。

關於Focal Loss 的數學推倒在文章:Focal Loss 的前向與后向公式推導

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM