pytorch04——自定義損失函數


pytroch在torch.nn模塊中本來就為我們提供了許多常用的損失函數,比如MSELoss,L1Loss,BCELoss.........但是在科研中還有實際一些運用場景中,我們需要通過自定義損失函數的方式來實現一些損失函數。

1.以函數的方式自定義損失函數

def my_loss(output,input):
    loss = torch.mean((output - target) ** 2)
return loss

2.以類的方式進行定義

雖然以函數定義的方式很簡單,但是以類方式定義更加常用,在以類的方式定義損失函數時,我們如果看每一個損失函數的繼承關系,我們就可以發現Loss函數部分繼承自_loss,部分繼承自_weightedLoss,而_WeightedLoss繼承自_loss._loss繼承自nn.Module。因此我們可以將 以類定義的損失函數當作神經網絡中的一層來對待,因此我們自定義的損失函數類就需要繼承自nn.Module類
1.例如在分割領域常見的損失函數,DiceLoss

class DiceLoss(nn.Module):
    def __init__(self,weight=None,size_average=True):
        super(DiceLoss,self).__init__()
        
	def forward(self,inputs,targets,smooth=1)
        inputs = F.sigmoid(inputs)       
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()                   
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        return 1 - dice

# 使用方法    
criterion = DiceLoss()
loss = criterion(input,targets)

2.DiceBCELoss

class DiceBCELoss(nn.Module):
    def __init__(self,weight=None,size_average = True):
        super(DiceBCELoss, self).__init__()
    def forward(self,inputs,targets,smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs*targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        return Dice_BCE

3.IouLoss

class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()
    def forward(self, inputs, targets, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection
        IoU = (intersection + smooth) / (union + smooth)
        return 1 - IoU

4.FocalLoss

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()
    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        inputs = F.sigmoid(inputs)
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1 - BCE_EXP) ** gamma * BCE
        return focal_loss

總結:

自定函數可以通過函數和類兩種方式進行實現,不過在實際運用中用類更多,我們全程使用PyTorch提供的張量計算接口,這樣集不需要我們去實現自動求導功能,並且可以直接進行調用cuda


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM