Common Loss Functions in Deep Learning (PyTorch)

Definition

A loss function quantifies the gap between predicted and true values, and is used to check whether the model being trained is optimizing in the right direction. Multiple loss functions can also be combined to compute a model's loss.

Classification

Cross-Entropy Loss

Definition: $L_{CE}=-\frac{1}{N}\sum_{i=1}^{N}{g_i\log(p_i)}$

where $g_i$ is the ground-truth label and $p_i$ is the predicted probability of the $i$-th class. For binary classification, $N=2$.

Code:
torch.nn.CrossEntropyLoss (multi-class)
torch.nn.BCELoss (binary classification)
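
A minimal usage sketch (the shapes and values below are dummy data): nn.CrossEntropyLoss takes raw logits plus integer class indices, while nn.BCELoss expects probabilities, so a sigmoid is applied first.

#PyTorch
import torch
import torch.nn as nn

# Multi-class: logits of shape (batch, num_classes), targets are class indices
ce_loss = nn.CrossEntropyLoss()
logits = torch.randn(4, 3)              # 4 samples, 3 classes (dummy data)
labels = torch.tensor([0, 2, 1, 0])     # ground-truth class indices
print(ce_loss(logits, labels))

# Binary: BCELoss expects probabilities in [0, 1], so apply a sigmoid first
bce_loss = nn.BCELoss()
raw = torch.randn(4, 1)                 # raw model outputs (dummy data)
probs = torch.sigmoid(raw)
targets = torch.tensor([[1.], [0.], [1.], [0.]])
print(bce_loss(probs, targets))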

Dice Loss

Definition: $DSC=1-\frac{2\sum_{i}{\hat{y_i}y_i}+s}{\sum_{i}{\hat{y_i}}+\sum_{i}{y_i}+s}$, where $s$ is a smoothing term that avoids division by zero (the `smooth` argument in the code below).
Advantage: converges quickly.
Code:

#PyTorch
import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice
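
A quick sanity check of the DiceLoss class above, reusing its imports; the shapes are arbitrary dummy data and the model is assumed to output raw logits:

#PyTorch
# Dummy segmentation batch: 2 images, 1 channel, 8x8 masks
logits = torch.randn(2, 1, 8, 8)                    # raw model outputs
masks = (torch.rand(2, 1, 8, 8) > 0.5).float()      # binary ground-truth masks

criterion = DiceLoss()
loss = criterion(logits, masks)
print(loss)   # scalar; lower means better overlap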

Focal Loss

Definition: $FL=-\alpha\, y\,(1-\hat{y})^{\gamma}\log(\hat{y})$, where $\gamma$ is the focusing parameter and $\alpha$ is a class-balancing weight (both appear in the code below).

In medical image segmentation, focal loss can improve the model's performance on segmenting ambiguous cells. It can also be combined with Dice loss (see the combined sketch after the code below).

Code:

#PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F

ALPHA = 0.8
GAMMA = 2

class FocalLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(FocalLoss, self).__init__()

    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #first compute binary cross-entropy 
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        BCE_EXP = torch.exp(-BCE)
        focal_loss = alpha * (1-BCE_EXP)**gamma * BCE
                       
        return focal_loss
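
As noted above, focal loss is often combined with Dice loss. One common way is a simple weighted sum; the 0.5/0.5 weights below are chosen arbitrarily for illustration, and both classes are assumed to be defined as in the snippets above:

#PyTorch
dice_criterion = DiceLoss()
focal_criterion = FocalLoss()

logits = torch.randn(2, 1, 8, 8, requires_grad=True)    # dummy raw model outputs
masks = (torch.rand(2, 1, 8, 8) > 0.5).float()          # dummy binary ground-truth masks

# Weighted combination; the weights are a tunable hyperparameter
loss = 0.5 * dice_criterion(logits, masks) + 0.5 * focal_criterion(logits, masks)
loss.backward()   # gradients flow through both terms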

Jaccard/Intersection over Union (IoU) Loss

Definition: $L_{IoU}=1-\frac{\sum_{i}{\hat{y_i}y_i}}{\sum_{i}{\hat{y_i}}+\sum_{i}{y_i}-\sum_{i}{\hat{y_i}y_i}}$

IoU loss is commonly used in image segmentation; it reflects how similar the predicted mask is to the ground-truth segmentation.

Code:

#PyTorch
import torch
import torch.nn as nn

class IoULoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        
        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)
        
        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        #intersection is equivalent to True Positive count
        #union is the mutually inclusive area of all labels & predictions 
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection 
        
        IoU = (intersection + smooth)/(union + smooth)
                
        return 1 - IoU
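
Since the loss is 1 minus the IoU, the same class can double as a monitoring metric; a short sketch with dummy data, reusing the IoULoss class and imports above:

#PyTorch
criterion = IoULoss()

logits = torch.randn(2, 1, 8, 8)                    # dummy raw model outputs
masks = (torch.rand(2, 1, 8, 8) > 0.5).float()      # dummy binary ground-truth masks

loss = criterion(logits, masks)
iou_score = 1 - loss.item()    # higher is better; 1.0 means perfect overlap
print(loss, iou_score)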

Regression

Mean Square Error

Definition: $MSE=\frac{\sum_{i=1}^{n}{(y_i-\hat{y_i})^2}}{n}$

Code: torch.nn.MSELoss
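
A minimal usage sketch with dummy tensors:

#PyTorch
import torch
import torch.nn as nn

mse = nn.MSELoss()                          # averages over all elements by default
pred = torch.randn(5, requires_grad=True)   # dummy predictions
target = torch.randn(5)                     # dummy ground truth
loss = mse(pred, target)
loss.backward()
print(loss)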

Mean Absolute Error

Definition: $MAE=\frac{\sum_{i=1}^{n}{|y_i-\hat{y_i}|}}{n}$

Code: torch.nn.L1Loss
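
torch.nn.L1Loss follows the same call pattern; the sketch below (reusing the imports above) compares it with MSE on the same dummy data, showing that MAE penalizes the outlier less heavily:

#PyTorch
pred = torch.tensor([1.0, 2.0, 10.0])    # last value is an outlier
target = torch.tensor([1.0, 2.0, 2.0])

print(nn.L1Loss()(pred, target))          # MAE: (0 + 0 + 8) / 3 ≈ 2.67
print(nn.MSELoss()(pred, target))         # MSE: (0 + 0 + 64) / 3 ≈ 21.33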

References

  1. https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch#Jaccard/Intersection-over-Union-(IoU)-Loss
  2. https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
  3. https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html
  4. https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html