于写毕设时碰到,因pytorch不提供Dice损失函数,所以需要自写。。。于是照着公式写,写出如下代码
class DiceLoss(nn.Module):
def __init__(self, smooth=1.):
super(DiceLoss, self).__init__()
self.smooth = smooth
def forward(self, pred, target):
pred = torch.sigmoid(pred)
intersection = torch.sum(torch.logical_and(pred, target))
union = torch.sum(pred) + torch.sum(target)
dice_score = (2. * intersection + self.smooth) / (union + self.smooth)
return 1 - dice_score
我在Debug时,又报了一错,猛然发现intersection变成了整型,这样就导致计算图在这里断开了,因此不能进行逻辑与操作,那就使用乘法来代替,不过这样会导致实际loss和计算公式的loss不同。
intersection = (pred*target).sum()
希望对同样情况的人有用。在计算loss时不能改变模型的输出值,这样会导致输出值的计算图断裂。