Dice loss

"""
GDiceLoss
GDiceLossV2
SSLoss
SoftDiceLoss
IouLoss
TverskyLoss
FocalTversky_loss
AsymLoss
Dice_and_CE_loss
PenaltyGDiceLoss
Dice_and_Topk_loss
ExpLog_Loss
"""
import numpy as np
import torch
from torch import nn
from torch import einsum
from losses_pytorch.ND_Crossentropy import CrossEntropy, TopkLoss, Weight_CrossEntropy_Loss


def soft_max(x):
    """Numerically stable softmax over the channel dimension (dim=1).

    Equivalent to torch.softmax(x, dim=1): the maximum over channels is
    subtracted before exponentiation to avoid overflow.
    """
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(dim=1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(dim=1, keepdim=True).repeat(*rpt)


def sum_tensor(inp, axes, keepdim=False):
    """Sum `inp` over every axis in `axes`.

    Without keepdim the axes are reduced from the highest index down, so
    earlier reductions do not shift the positions of the remaining axes.
    """
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(dim=int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(dim=int(ax), keepdim=False)
    return inp


def tp_tn_fp_fn(net_out, target, axes=None, mask=None, square=False):
    """
    net_out : (b, c, h, w)
    targrt : (b, 1, h, w) or (b, h, w) or one_hot encoding (b, c, h, w)
    """
    num_class = net_out.size()[1]
    if axes is None:
        axes = tuple(range(2, len(net_out.shape)))

    shp_x = net_out.shape
    shp_y = target.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            target = target.view((shp_y[0], 1, *shp_y[1:]))
        if all([i == j for i, j in zip(net_out.shape, target.shape)]):
            one_hot = target
        else:
            idx = target.long()
            one_hot = torch.zeros(shp_x, device=net_out.device)
            one_hot.scatter_(1, idx, 1)

    tp = net_out * one_hot
    tn = (1 - net_out) * (1 - one_hot)
    fp = net_out * (1 - one_hot)
    fn = (1 - net_out) * one_hot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        tn = tn ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=True).view(-1, num_class)
    tn = sum_tensor(tn, axes, keepdim=True).view(-1, num_class)
    fp = sum_tensor(fp, axes, keepdim=True).view(-1, num_class)
    fn = sum_tensor(fn, axes, keepdim=True).view(-1, num_class)

    return tp, tn, fp, fn
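
# Note: with probabilistic `net_out` these counts are "soft"; a pixel
# predicted 0.7 for its true class contributes 0.7 to tp and 0.3 to fn.
# Every region loss below (Dice, IoU, Tversky, ...) is a ratio of them.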


class GDiceLoss(nn.Module):
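    """Generalized Dice loss (Sudre et al., 2017).

    Each class is weighted by w_c = 1 / (sum_i r_ci)^2, so small structures
    contribute as much as large ones; returns 1 - weighted Dice score.
    """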
    def __init__(self, apply_nonlin=None, smooth=1e-5):
        super(GDiceLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, net_out, target):

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)

        shp_x = net_out.shape
        shp_y = target.shape

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                target = target.view((shp_y[0], 1, *shp_y[1:]))
            if all([i == j for i, j in zip(net_out.shape, target.shape)]):
                one_hot = target
            else:
                idx = target.long()
                one_hot = torch.zeros(shp_x, device=net_out.device)
                one_hot.scatter_(1, idx, 1)

        w: torch.Tensor = 1 / (einsum('bcxy->bc', one_hot).type(torch.float32) + 1e-10) ** 2
        intersection: torch.Tensor = w * einsum('bcxy, bcxy->bc', net_out, one_hot)
        union: torch.Tensor = w * (einsum('bcxy->bc', net_out) + einsum('bcxy->bc', one_hot))
        gd_score: torch.Tensor = 2 * (einsum('bc->b', intersection) + self.smooth) / (einsum('bc->b', union) + self.smooth)

        return 1 - gd_score.mean()


class GDiceLossV2(nn.Module):
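    """Generalized Dice loss, batch-weighted variant.

    Flattens to (num_class, batch * H * W) before weighting, so the class
    weights are computed over the whole batch rather than per sample;
    returns the negated score (more negative is better).
    """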
    def __init__(self, apply_nonlin=None, smooth=1e-5):
        super(GDiceLossV2, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, net_out, target):
        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)

        shp_x = net_out.shape
        shp_y = target.shape

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                target = target.view(shp_y[0], 1, *shp_y[1:])
            if all([i == j for i, j in zip(net_out.shape, target.shape)]):
                one_hot = target
            else:
                idx = target.long()
                one_hot = torch.zeros(shp_x, device=net_out.device)
                one_hot = one_hot.scatter_(1, idx, 1)

        # flatten to (num_class, batch * H * W) so every reduction below runs
        # per class over the whole batch; torch.flatten would collapse the
        # class dimension and destroy the per-class weighting
        num_class = shp_x[1]
        net_flat = net_out.transpose(0, 1).contiguous().view(num_class, -1)
        target_flat = one_hot.transpose(0, 1).contiguous().view(num_class, -1).float()
        target_sum = target_flat.sum(dim=-1)

        class_weight = 1 / (target_sum * target_sum).clamp(min=self.smooth)
        intersection = ((net_flat * target_flat).sum(dim=-1) * class_weight).sum()
        denominator = ((net_flat + target_flat).sum(dim=-1) * class_weight).sum()
        divided = -2 * intersection / denominator.clamp(min=self.smooth)

        return divided


class SSLoss(nn.Module):
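    """Sensitivity-specificity loss.

    A weighted sum of the mean squared error on foreground pixels (weight
    `weight`) and on background pixels (weight 1 - weight); the default
    weight=0.1 puts most of the emphasis on background errors.
    """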
    def __init__(self, apply_nonlin=None, smooth=1e-5,
                 batch_dice=False, do_bg=True, square=False, weight=0.1, loss_mask=None):
        super(SSLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.square = square
        self.weight = weight
        self.loss_mask = loss_mask

    def forward(self, net_out, target):
        shp_x = net_out.shape
        shp_y = target.shape

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                target = target.view(shp_y[0], 1, *shp_y[1:])
            if all([i == j for i, j in zip(net_out.shape, target.shape)]):
                one_hot = target
            else:
                idx = target.long()
                one_hot = torch.zeros(shp_x, device=net_out.device)
                one_hot = one_hot.scatter_(1, idx, 1)

        squared_error = (one_hot - net_out) ** 2
        specificity = sum_tensor(squared_error * one_hot, axes) / (sum_tensor(one_hot, axes) + self.smooth)
        sensitivity = sum_tensor(squared_error * (1 - one_hot), axes) / (sum_tensor(1 - one_hot, axes) + self.smooth)
        ss = self.weight * specificity + (1 - self.weight) * sensitivity

        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]

        ss = ss.mean()
        return ss


class SoftDiceLoss(nn.Module):
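    """Soft Dice loss: 1 - 2*TP / (2*TP + FP + FN) on soft counts.

    `smooth` keeps the ratio defined when a class is absent from both
    prediction and target.
    """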
    def __init__(self, apply_nonlin=None, smooth=1e-5,
                 batch_dice=False, do_bg=True, square=False, loss_mask=None):
        super(SoftDiceLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.square = square
        self.loss_mask = loss_mask

    def forward(self, net_out, target):
        shp_x = net_out.shape

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        tp, tn, fp, fn = tp_tn_fp_fn(net_out, target, axes, mask=self.loss_mask, square=self.square)

        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

        if not self.do_bg:
            # tp_tn_fp_fn returns shape (-1, num_class), so the background
            # class always sits along dim 1, even when batch_dice is set
            dc = dc[:, 1:]

        dc = dc.mean()
        return 1 - dc


class IouLoss(nn.Module):
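    """IoU (Jaccard) loss: returns -TP / (TP + FP + FN), so minimising it
    maximises the IoU. Dice and IoU are monotonically related
    (DSC = 2*IoU / (1 + IoU)), but IoU penalises errors more harshly.
    """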
    def __init__(self, apply_nonlin=None, smooth=1e-5,
                 batch_dice=False, do_bg=True, square=False, loss_mask=None):
        super(IouLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.square = square
        self.loss_mask = loss_mask

    def forward(self, net_out, target):
        shp_x = net_out.shape

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        tp, tn, fp, fn = tp_tn_fp_fn(net_out, target, axes, self.loss_mask, self.square)

        iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)

        if not self.do_bg:
            # tp_tn_fp_fn returns shape (-1, num_class); drop background along dim 1
            iou = iou[:, 1:]
        iou = iou.mean()

        return -iou

class TverskyLoss(nn.Module):
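    """Tversky index: TP / (TP + alpha*FP + beta*FN).

    alpha = beta = 0.5 recovers Dice; the default alpha=0.3, beta=0.7 weights
    false negatives more, trading precision for recall. Note that forward
    returns the index itself; FocalTversky_loss converts it into a loss.
    """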
    def __init__(self, apply_nonlin=None, smooth=1e-5,
                 batch_dice=False, do_bg=True, square=False, loss_mask=None, alpha=0.3, beta=0.7):
        super(TverskyLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.square = square
        self.loss_mask = loss_mask
        self.alpha = alpha
        self.beta = beta

    def forward(self, net_out, target):
        shp_x = net_out.shape

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        tp, tn, fp, fn = tp_tn_fp_fn(net_out, target, axes, self.loss_mask, self.square)

        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)

        if not self.do_bg:
            # tp_tn_fp_fn returns shape (-1, num_class); drop background along dim 1
            tversky = tversky[:, 1:]
        tversky = tversky.mean()

        return tversky


class FocalTversky_loss(nn.Module):
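    """Focal Tversky loss (Abraham & Khan, 2019): (1 - TI) ** gamma.

    gamma=0.75 corresponds to the exponent 1/gamma with gamma = 4/3 in the
    paper's notation; it reshapes how strongly low-overlap examples
    dominate the loss.
    """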
    def __init__(self, gamma=0.75):
        super(FocalTversky_loss, self).__init__()

        self.tversky = TverskyLoss()
        self.gamma = gamma

    def forward(self, net_out, target):
        tversky_loss = 1 - self.tversky(net_out, target)
        focaltversky_loss = torch.pow(tversky_loss, self.gamma)
        return focaltversky_loss


class AsymLoss(nn.Module):
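    """Asymmetric similarity loss (F-beta style).

    Returns TP / (TP + w*FN + (1-w)*FP) with w = beta^2 / (1 + beta^2);
    beta=1.5 gives w ~ 0.69, weighting false negatives more (favours recall).
    """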
    def __init__(self, apply_nonlin=None, smooth=1e-5,
                 batch_dice=False, do_bg=True, square=False, loss_mask=None, beta=1.5):
        super(AsymLoss, self).__init__()

        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.batch_dice = batch_dice
        self.do_bg = do_bg
        self.square = square
        self.loss_mask = loss_mask
        self.beta = beta

    def forward(self, net_out, target):
        shp_x = net_out.shape

        if self.apply_nonlin is not None:
            net_out = self.apply_nonlin(net_out)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))

        tp, tn, fp, fn = tp_tn_fp_fn(net_out, target, axes, self.loss_mask, self.square)

        weight = (self.beta ** 2) / (1 + self.beta ** 2)
        asym_loss = (tp + self.smooth) / (tp + weight*fn + (1 - weight)*fp + self.smooth)

        if not self.do_bg:
            # tp_tn_fp_fn returns shape (-1, num_class); drop background along dim 1
            asym_loss = asym_loss[:, 1:]
        asym_loss = asym_loss.mean()

        return asym_loss


class Dice_and_CE_loss(nn.Module):
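    """Sum of soft Dice and cross-entropy, as popularised by nnU-Net: CE
    supplies dense per-pixel gradients while Dice targets region overlap.
    """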
    def __init__(self, aggregate='sum'):
        super(Dice_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = CrossEntropy()
        self.dice = SoftDiceLoss()

    def forward(self, net_out, target):
        dice_loss = self.dice(net_out, target)
        ce_loss = self.ce(net_out, target)

        if self.aggregate == 'sum':
            loss = dice_loss + ce_loss
        else:
            raise NotImplementedError(f"aggregate='{self.aggregate}' is not supported")

        return loss


class PenaltyGDiceLoss(nn.Module):
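    """Penalty Generalized Dice loss: pGD = GDL / (1 + k * (1 - GDL)).

    k = 0 recovers plain GDL; k > 0 penalises poorly segmented cases more
    heavily relative to well segmented ones.
    """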
    def __init__(self, k=2.5):
        super(PenaltyGDiceLoss, self).__init__()

        self.k = k
        self.GDice = GDiceLoss()

    def forward(self, net_out, target):
        GDice_loss = self.GDice(net_out, target)
        penalty_GDice_loss = GDice_loss / (1 + self.k * (1 - GDice_loss))

        return penalty_GDice_loss

class Dice_and_Topk_loss(nn.Module):
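    """Soft Dice plus top-k cross-entropy; the imported TopkLoss is assumed
    to average CE over only the hardest fraction of pixels, so the CE term
    focuses on difficult regions while Dice still sees the whole mask.
    """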
    def __init__(self, aggregate='sum'):
        super(Dice_and_Topk_loss, self).__init__()

        self.aggregate = aggregate
        self.dice = SoftDiceLoss()
        self.topk = TopkLoss()

    def forward(self, net_out, target):
        dice_loss = self.dice(net_out, target)
        topk_loss = self.topk(net_out, target)
        if self.aggregate == 'sum':
            dice_and_topk_loss = dice_loss + topk_loss
        else:
            raise NotImplementedError(f"aggregate='{self.aggregate}' is not supported")
        return dice_and_topk_loss


class ExpLog_Loss(nn.Module):
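    """Exponential-logarithmic style loss: 0.8 * dice_loss**gamma + 0.2 * WCE.

    Wong et al. (2018) apply the power to -log(Dice); this variant applies
    it to the soft Dice loss (1 - Dice) directly.
    """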
    def __init__(self, gamma=0.3):
        super(ExpLog_Loss, self).__init__()
        self.wce_loss = Weight_CrossEntropy_Loss(weight=[0.9, 0.1], balance_idx=0)
        self.dice = SoftDiceLoss()
        self.gamma = gamma

    def forward(self, net_out, target):
        dice_loss = self.dice(net_out, target)      # weighted 0.8 below
        wce_loss = self.wce_loss(net_out, target)   # weighted 0.2 below
        explog_loss = 0.8 * torch.pow(dice_loss, self.gamma) + 0.2 * wce_loss

        return explog_loss


if __name__ == '__main__':
    img = torch.tensor(
        [[[[0.2, 0.2, 0.3, 0.3],
           [0.2, 0.2, 0.3, 0.3],
           [0.2, 0.2, 0.3, 0.3],
           [0.2, 0.2, 0.3, 0.3]],

          [[0.8, 0.8, 0.7, 0.7],
           [0.8, 0.8, 0.7, 0.7],
           [0.8, 0.8, 0.7, 0.7],
           [0.8, 0.8, 0.7, 0.7]]]]
    )
    target = torch.tensor([[[1, 1, 0, 0],
                            [1, 1, 0, 0],
                            [1, 1, 0, 0],
                            [1, 1, 0, 0]]])
    # smoke test: per-class soft tp/tn/fp/fn on the toy example above
    out = tp_tn_fp_fn(img, target)
    print(out)

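    # A minimal usage sketch: `img` already holds probabilities, so
    # apply_nonlin is left as None; for raw logits, pass apply_nonlin=soft_max.
    # TverskyLoss prints an index and IouLoss a negated score (see docstrings).
    for loss_fn in (SoftDiceLoss(), IouLoss(), TverskyLoss(), GDiceLoss()):
        print(type(loss_fn).__name__, loss_fn(img, target).item())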
