# pytorch: DiceLoss MulticlassDiceLoss

# Custom PyTorch Dice losses: multi-class (MulticlassDiceLoss) and single-class (DiceLoss).

import torch
import torch.nn as nn
from torch.autograd import Function

class DiceLoss(nn.Module):
    """Soft Dice loss for single-class (binary) predictions.

    For every sample k along dimension 0, all remaining dimensions are
    flattened and the Dice coefficient is computed as

        dice_k = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    The returned loss is ``1 - mean_k(dice_k)``.

    Args:
        smooth: additive smoothing term that keeps the ratio defined when
            both prediction and target are all zeros. Defaults to 1.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the batch-averaged Dice loss as a 0-d tensor.

        Args:
            input: predictions whose first dimension is the batch size.
                Presumably probabilities in [0, 1] (e.g. after a sigmoid);
                this is not checked here — TODO confirm with callers.
            target: ground truth with the same batch size and the same
                number of elements per sample as ``input``.
        """
        N = target.size(0)
        smooth = self.smooth

        # reshape (not view) so non-contiguous inputs, such as permuted
        # network outputs, are accepted.
        input_flat = input.reshape(N, -1)
        target_flat = target.reshape(N, -1)

        intersection = (input_flat * target_flat).sum(1)

        # Only the intersection is doubled, not the smoothing term. Identical
        # masks then give exactly 1 (and two empty masks give 1), so the
        # loss never goes negative.
        dice = (2 * intersection + smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + smooth
        )
        return 1 - dice.sum() / N

class MulticlassDiceLoss(nn.Module):
    """Sum of per-class Dice losses for one-hot encoded targets.

    Requires ``input.shape[0:2]`` and ``target.shape[0:2]`` to be ``(N, C)``,
    where N is the batch size and C is the number of classes. ``DiceLoss`` is
    applied to each class channel ``[:, i]`` in turn, and the (optionally
    weighted) per-class losses are summed, not averaged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, weights=None):
        """Return the summed per-class Dice loss.

        Args:
            input: predictions, indexed as ``input[:, i]`` per class.
            target: one-hot ground truth; ``target.shape[1]`` gives C.
            weights: optional per-class multipliers indexable by class id
                (``weights[i]``). ``None`` means each class contributes with
                weight 1.
        """
        num_classes = target.shape[1]
        dice = DiceLoss()
        total_loss = 0

        for i in range(num_classes):
            class_loss = dice(input[:, i], target[:, i])
            if weights is not None:
                # Out-of-place multiply: never mutate an autograd output.
                class_loss = class_loss * weights[i]
            total_loss = total_loss + class_loss

        return total_loss
class DiceCoeff(Function):
    """Dice coefficient for a single example, with a hand-written gradient.

    The value is ``(2 * <input, target> + eps) / (sum(input) + sum(target) + eps)``,
    which is 1 for identical masks (including two all-zero masks).

    NOTE(review): this uses the legacy autograd API (instance-method forward
    and backward with state stored on ``self``). Newer PyTorch releases
    reject running such Functions through autograd; confirm against the
    installed version. ``DiceCoeff().forward(x, t)`` computes the value
    either way.
    """

    # Smoothing term that keeps the ratio defined when both inputs are zero.
    eps = 0.0001

    def forward(self, input, target):
        """Return the Dice coefficient of ``input`` and ``target`` as a float32 tensor."""
        self.save_for_backward(input, target)
        # reshape (not view) tolerates non-contiguous inputs.
        self.inter = torch.dot(input.reshape(-1), target.reshape(-1))
        self.union = torch.sum(input) + torch.sum(target)
        return (2 * self.inter.float() + self.eps) / (self.union.float() + self.eps)

    @staticmethod
    def _grad_input(grad_output, target, inter, union, eps):
        """Gradient of ``(2 * inter + eps) / (union + eps)`` w.r.t. each input element."""
        numer = 2 * inter + eps
        denom = union + eps
        # Quotient rule: d numer / dx_i = 2 * t_i and d denom / dx_i = 1.
        return grad_output * (2 * target * denom - numer) / (denom * denom)

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        _, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = self._grad_input(
                grad_output, target, self.inter, self.union, self.eps
            )
        # No gradient is propagated to the target; grad_target stays None.
        return grad_input, grad_target

def dice_coeff(input, target):
    """Mean Dice coefficient over a batch.

    Each pair ``(input[k], target[k])`` along the first dimension is scored
    with ``(2 * <x, t> + eps) / (sum(x) + sum(t) + eps)`` (eps = 0.0001), and
    the scores are averaged.

    Returns:
        A float32 tensor of shape (1,) on ``input``'s device.

    Raises:
        ValueError: if the batch (first dimension) is empty.
    """
    eps = 0.0001

    def _single(x, t):
        # Computed directly rather than through an autograd Function
        # instance, which is only needed when gradients flow through it.
        inter = torch.dot(x.reshape(-1), t.reshape(-1))
        union = torch.sum(x) + torch.sum(t)
        return (2 * inter.float() + eps) / (union.float() + eps)

    # Allocate on input's own device (e.g. cuda:1), not the current default
    # CUDA device, to avoid device mismatches.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)
    count = 0
    for x, t in zip(input, target):
        total = total + _single(x, t)
        count += 1

    if count == 0:
        raise ValueError("dice_coeff() requires a non-empty batch")
    return total / count