Summary of Image Segmentation Loss Functions: Dice Loss and Focal Loss

1. Dice Loss

Dice loss comes from [1]; it is a loss function derived from the Dice coefficient.

The Dice coefficient is a set-similarity metric: it measures how similar two sets are from a region-overlap perspective (whereas CE loss works from a probability-distribution perspective).

The Dice coefficient takes values in [0, 1]: it is 1 when the two sets overlap completely and 0 when they do not overlap at all. It is computed as follows:

Dice = \frac{2\left| A \cap B \right|}{\left| A \right| + \left| B \right|}

where A is the ground-truth mask and B is the predicted mask.

Dice loss = 1 - \frac{2\left| A \cap B \right|}{\left| A \right| + \left| B \right|}, which also takes values in [0, 1]; the smaller the loss, the greater the overlap.

Denominator: |A| and |B| are the sums of the elements of A and B (for binary masks, the pixel counts).

Numerator: the intersection of A and B, computed as an element-wise product.
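A tiny worked example: for the 4-pixel binary masks A = [1, 1, 0, 0] and B = [1, 0, 1, 0], the element-wise product sums to 1 and |A| = |B| = 2, so Dice = (2 × 1) / (2 + 2) = 0.5 and Dice loss = 1 − 0.5 = 0.5.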

# -*- coding: utf-8 -*-
"""
# @file name  : dice_loss.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2020-03-12
# @brief      : dice loss
"""
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """
    soft dice loss, 直接使用预测概率而不是使用阈值或将它们转换为二进制mask
    """
    def __init__(self, epsilon=1e-5):
        super(DiceLoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, predict, target):
        assert predict.size() == target.size(), "the size of predict and target must be equal."
        num = predict.size(0)

        # pred does not need to be binarized, cf.
        # https://github.com/yassouali/pytorch-segmentation/blob/master/utils/losses.py#L44
        # soft Dice loss: use the predicted probabilities directly instead of
        # thresholding them into a binary mask
        pred = torch.sigmoid(predict).view(num, -1)
        targ = target.view(num, -1)

        intersection = (pred * targ).sum()  # element-wise product of predictions and labels acts as the intersection
        union = (pred + targ).sum()

        # epsilon is placed inside the ratio so that two empty masks yield a loss of 0
        score = 1 - (2 * intersection + self.epsilon) / (union + self.epsilon)

        return score


if __name__ == "__main__":

    fake_out = torch.tensor([7, 7, -5, -5], dtype=torch.float32)
    fake_label = torch.tensor([1, 1, 0, 0], dtype=torch.float32)
    loss_f = DiceLoss()
    loss = loss_f(fake_out, fake_label)

    print(loss)  # confident, correct predictions -> loss close to 0 (roughly 0.004)
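The same module works unchanged on batched segmentation logits, e.g. shape (N, 1, H, W), since forward flattens everything after the batch dimension. A minimal usage sketch with random tensors (values are illustrative only):

logits = torch.randn(2, 1, 4, 4)                 # fake logits for two 1-channel masks
masks = (torch.rand(2, 1, 4, 4) > 0.5).float()   # fake binary ground-truth masks
print(DiceLoss()(logits, masks))                 # scalar loss in [0, 1]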




2. Focal Loss

Focal Loss was proposed for one-stage (dense) object detection, where the dense set of candidate locations suffers from severe foreground/background imbalance and is dominated by easy examples (in two-stage detectors, the RPN already filters out most easy negatives). It is built on top of CE loss, which is:

CE(p_t) = -\log(p_t)

where p_t is the predicted probability of the ground-truth class.

To address the imbalance: add a class weight \alpha_t.

To address hard examples: add a difficulty weight with focusing parameter \gamma.

The final Focal loss is:

FL(p_t) = -\alpha_t \left(1 - p_t\right)^{\gamma} \log(p_t)
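For intuition, take \gamma = 2: an easy example with p_t = 0.9 contributes (1 − 0.9)² × (−log 0.9) ≈ 0.01 × 0.105 ≈ 0.001, about 100× less than its CE loss of 0.105, while a hard example with p_t = 0.1 contributes 0.81 × 2.303 ≈ 1.87, only modestly below its CE loss. The modulating factor thus shifts the total loss toward hard examples.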

# -*- coding: utf-8 -*-
"""
# @file name  : focal_loss.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2020-03-12
# @brief      : standard focal loss
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt


class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=None, ignore_index=255, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average
        # reduction='none' keeps a per-element loss so the focal weight is applied
        # before averaging (weighting an already-averaged scalar is only correct
        # for a single element)
        self.CE_loss = nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index, weight=alpha)

    def forward(self, output, target):
        logpt = self.CE_loss(output, target)
        # CE returns -log(pt), so exp(-logpt) recovers the probability pt.
        # The inputs are logits, not probabilities: CrossEntropyLoss applies softmax internally.
        pt = torch.exp(-logpt)
        loss = ((1 - pt) ** self.gamma) * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()


if __name__ == "__main__":

    target = torch.tensor([1], dtype=torch.long)
    gamma_lst = [0, 0.5, 1, 2, 5]
    loss_dict = {}
    for gamma in gamma_lst:
        focal_loss_func = FocalLoss(gamma=gamma)
        loss_dict.setdefault(gamma, [])

        for i in np.linspace(0.5, 10.0, num=30):
            outputs = torch.tensor([[5, i]], dtype=torch.float)  # sweep i to generate outputs with different predicted probabilities
            prob = F.softmax(outputs, dim=1)  # PyTorch's CE applies softmax internally, so apply softmax manually to read off the predicted probability
            loss = focal_loss_func(outputs, target)
            loss_dict[gamma].append((prob[0, 1].item(), loss.item()))

    for gamma, value in loss_dict.items():
        x_prob = [prob for prob, loss in value]
        y_loss = [loss for prob, loss in value]
        plt.plot(x_prob, y_loss, label="γ="+str(gamma))

    plt.title("Focal Loss")
    plt.xlabel("probability of ground truth class")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
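For class-imbalanced segmentation, \alpha_t can be supplied as per-class weights; in this implementation alpha is simply forwarded to CrossEntropyLoss's weight argument. A minimal sketch (the weight values below are illustrative, not tuned):

weights = torch.tensor([0.25, 0.75])  # hypothetical class weights: background, foreground
focal = FocalLoss(gamma=2, alpha=weights)
logits = torch.randn(4, 2)            # fake logits for 4 samples, 2 classes
labels = torch.randint(0, 2, (4,))    # fake labels
print(focal(logits, labels))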


3. BCE Loss

BCE (Binary Cross-Entropy) loss is the loss function for binary classification; it measures the discrepancy between the predicted probabilities and the true labels:

BCE = -\frac{1}{N}\sum_{i=1}^{N}\left[ y_i \log(p_i) + (1 - y_i)\log(1 - p_i) \right]

In PyTorch the computation looks like this:

import torch
import torch.nn as nn

# Example inputs: predicted probabilities (already in [0, 1]) and binary labels.
# nn.BCELoss expects probabilities; for raw logits use nn.BCEWithLogitsLoss instead.
predictions = torch.tensor([0.1, 0.9, 0.8, 0.3], dtype=torch.float32)
labels = torch.tensor([0, 1, 1, 0], dtype=torch.float32)

# Initialize BCELoss
criterion = nn.BCELoss()

# Compute the loss
loss = criterion(predictions, labels)

print(f"Binary Cross-Entropy Loss: {loss.item()}")

4. BCE + Dice (trains well in practice)

Combine BCE and Dice loss with a 1:1 weight ratio (a weighted variant is sketched after the code below).

import torch
import torch.nn as nn
from losses.dice_loss import DiceLoss


class BCEDiceLoss(nn.Module):
    def __init__(self, **kwargs):
        super(BCEDiceLoss, self).__init__()
        self.bce_func = nn.BCEWithLogitsLoss(**kwargs)  # **kwargs: Python variable-length keyword arguments, forwarded to BCEWithLogitsLoss
        self.dice_func = DiceLoss()
        # both terms take raw logits: BCEWithLogitsLoss applies sigmoid internally,
        # and DiceLoss applies sigmoid in its own forward

    def forward(self, predict, target):
        loss_bce = self.bce_func(predict, target)
        loss_dice = self.dice_func(predict, target)
        return loss_dice + loss_bce


if __name__ == "__main__":

    fake_out = torch.tensor([1, 1, -1, -1], dtype=torch.float32)
    fake_label = torch.tensor([1, 1, 0, 0], dtype=torch.float32)
    loss_f = BCEDiceLoss()
    loss = loss_f(fake_out, fake_label)

    print(loss)
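The 1:1 ratio is a common default; if one term dominates during training, the mix can be exposed as hyperparameters. A minimal sketch of a weighted variant (the parameter names and default values are illustrative, not tuned):

class WeightedBCEDiceLoss(nn.Module):
    def __init__(self, bce_weight=1.0, dice_weight=1.0, **kwargs):
        super(WeightedBCEDiceLoss, self).__init__()
        self.bce_weight = bce_weight    # weight on the BCE term
        self.dice_weight = dice_weight  # weight on the Dice term
        self.bce_func = nn.BCEWithLogitsLoss(**kwargs)
        self.dice_func = DiceLoss()

    def forward(self, predict, target):
        return (self.bce_weight * self.bce_func(predict, target)
                + self.dice_weight * self.dice_func(predict, target))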

A practical note on combining these losses: BCE + Dice helps when the data are reasonably balanced, but under extreme imbalance the cross-entropy term can fall far below the Dice term after a few epochs, and the combination degenerates into a plain Dice loss; pairing Dice with Focal loss then addresses the foreground/background imbalance better. Dice loss concentrates training on mining the foreground region (keeping FN low) but suffers from loss saturation, whereas CE loss weighs every pixel equally, so Dice loss alone rarely gives the best results and is usually combined, e.g. Dice + CE or Dice + Focal.

[1] Milletari F., Navab N., Ahmadi S.-A. "V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation." 3DV 2016.