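使用场景：解决样本（前景/背景类别）不均衡的问题，常用于前景像素占比很小的二值分割任务。
优点：直接优化预测与标签之间的重叠度（Dice 系数），对前景/背景比例悬殊不敏感。
缺点：
训练不稳定，容易出现梯度爆炸（前景很少时分母很小，梯度变化剧烈）。
实例化：dice_loss = loss_fns.SoftDiceLoss()（若模型输出为未经激活的 logits，传入 sigmoid=True）
使用：loss1 = dice_loss(out, y)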
class SoftDiceLoss(nn.Module):
    """Soft Dice loss (squared-denominator form), averaged over the batch.

    For each sample ``i`` the prediction ``p`` and target ``t`` are flattened
    and the Dice coefficient is computed as::

        dice_i = (2 * sum(p * t) + smooth) / (sum(p ** 2) + sum(t ** 2) + smooth)

    The returned loss is ``1 - mean_i(dice_i)``.

    Args:
        weight: Accepted for interface compatibility; currently unused.
        size_average: Accepted for interface compatibility; currently unused
            (the loss is always the mean over the batch).
        sigmoid: If True, apply ``torch.sigmoid`` to the predictions first
            (use when the model emits raw logits).
        smooth: Additive smoothing constant in numerator and denominator. It
            keeps the ratio defined for samples with no foreground and makes
            an empty prediction on an empty target score a Dice of 1.
    """

    def __init__(self, weight=None, size_average=True, sigmoid=False, smooth=1.0):
        super().__init__()
        self.sigmoid = sigmoid
        self.smooth = smooth

    def forward(self, logits, targets):
        """Return the scalar soft Dice loss for a batch.

        Args:
            logits: Predictions with the batch on dim 0. They are treated as
                probabilities unless ``sigmoid=True`` was given. All trailing
                dimensions are flattened.
            targets: Ground-truth masks with the same batch size and the same
                number of elements per sample as ``logits`` (presumably 0/1
                values; confirm with callers).

        Returns:
            A 0-dim tensor equal to ``1 - mean(per-sample dice)``.
        """
        num = targets.size(0)
        probs = torch.sigmoid(logits) if self.sigmoid else logits

        # reshape (unlike view) also accepts non-contiguous tensors.
        m1 = probs.reshape(num, -1)
        # Cast targets to the prediction dtype so integer masks work too.
        m2 = targets.reshape(num, -1).to(m1.dtype)

        # Reduce over dim 1: one Dice score per sample, then average.
        intersection = (m1 * m2).sum(dim=1)
        denominator = (m1 ** 2).sum(dim=1) + (m2 ** 2).sum(dim=1)
        # smooth is added after the factor of 2, so empty-vs-empty gives dice 1.
        dice = (2.0 * intersection + self.smooth) / (denominator + self.smooth)
        return 1.0 - dice.mean()