# PyTorch: DiceLoss and MulticlassDiceLoss
#
# Custom multi-class and single-class Dice losses for PyTorch
# (pytorch的自定义多类dice_loss 和单类dice_loss).

import torch
import torch.nn as nn
from torch.autograd import Function

class DiceLoss(nn.Module):
    """Soft Dice loss for single-class (binary) segmentation.

    Each sample is flattened and scored with the smoothed Dice coefficient
    ``(2 * sum(input * target) + smooth) / (sum(input) + sum(target) + smooth)``;
    the loss is ``1 - mean(dice)`` over the batch.

    Args:
        smooth: Additive smoothing constant applied once to the numerator and
            once to the denominator. It keeps the ratio defined when both
            tensors are all zero (dice == 1, loss == 0). Defaults to 1.
    """

    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar Dice loss for a batch.

        Args:
            input: Predictions with the batch on dim 0. Presumably
                probabilities (e.g. after a sigmoid) -- not checked here.
            target: Ground truth with the same number of elements per sample
                as ``input``.

        Returns:
            0-dim tensor equal to ``1 - mean_over_batch(dice)``.
        """
        batch_size = target.size(0)

        # reshape (unlike view) also accepts non-contiguous tensors such as
        # transposed or channels-last inputs.
        input_flat = input.reshape(batch_size, -1)
        target_flat = target.reshape(batch_size, -1)

        intersection = (input_flat * target_flat).sum(1)
        denominator = input_flat.sum(1) + target_flat.sum(1)

        # Smoothing is added exactly once to each side so that a perfect or an
        # empty-vs-empty prediction scores dice == 1 (loss 0) instead of > 1.
        dice = (2 * intersection + self.smooth) / (denominator + self.smooth)
        return 1 - dice.sum() / batch_size

class MulticlassDiceLoss(nn.Module):
    """Sum of per-class ``DiceLoss`` values for one-hot encoded targets.

    Requires ``input.shape[0:2]`` and ``target.shape[0:2]`` to be ``(N, C)``,
    where N is the batch size and C is the number of classes. ``DiceLoss`` is
    applied to each class channel ``[:, i]`` in turn and the per-class losses
    are summed, optionally scaled by per-class weights.
    """

    def __init__(self):
        super(MulticlassDiceLoss, self).__init__()

    def forward(self, input, target, weights=None):
        """Return the (optionally weighted) sum of per-class Dice losses.

        Args:
            input: Predictions, indexed per class as ``input[:, i]``.
            target: One-hot ground truth; ``target.shape[1]`` is taken as C.
            weights: Optional per-class multipliers, indexed as ``weights[i]``
                for ``i`` in ``range(C)``. When omitted every class counts
                with weight 1.

        Returns:
            ``sum_i weights[i] * DiceLoss()(input[:, i], target[:, i])``.
            The sum is not divided by C, so its upper range grows with the
            number of classes (or with the weights).
        """
        num_classes = target.shape[1]
        dice = DiceLoss()

        total_loss = 0
        for i in range(num_classes):
            class_loss = dice(input[:, i], target[:, i])
            if weights is not None:
                class_loss *= weights[i]
            total_loss += class_loss

        return total_loss
class DiceCoeff(Function):
    """Soft Dice coefficient of a single example (legacy autograd Function).

    ``forward`` flattens both tensors and returns
    ``2 * (dot(input, target) + 0.0001) / (sum(input) + sum(target) + 0.0001)``.

    NOTE(review): this is written against the old instance-method autograd API
    (non-static ``forward``/``backward``). Recent PyTorch releases reject such
    Functions when invoked via ``DiceCoeff()(...)``/``apply`` -- confirm for the
    targeted version. When ``forward`` is called directly on an instance
    (``DiceCoeff().forward(a, b)``), autograd differentiates the plain tensor
    ops inside it and the custom ``backward`` below is not used.
    """

    def forward(self, input, target):
        """Return the Dice coefficient of ``input`` and ``target`` (0-dim)."""
        self.save_for_backward(input, target)
        # The 0.0001 terms keep the ratio defined when both tensors are zero.
        self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001

        t = 2 * self.inter.float() / self.union.float()
        return t

    @staticmethod
    def _grad_wrt_input(target, inter, union, grad_output):
        """Gradient of ``2 * inter / union`` with respect to ``input``.

        Since d(inter)/d(input) = target and d(union)/d(input) = 1, the
        quotient rule gives ``2 * (target * union - inter) / union**2``.
        """
        return grad_output * 2 * (target * union - inter) / (union * union)

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Return gradients for ``(input, target)``; ``target`` gets none."""
        _input, target = self.saved_tensors
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            grad_input = self._grad_wrt_input(
                target, self.inter, self.union, grad_output
            )
        if self.needs_input_grad[1]:
            grad_target = None

        return grad_input, grad_target


def dice_coeff(input, target):
    """Mean Dice coefficient over a batch.

    Iterates over dim 0 of ``input`` and ``target`` in lockstep, scores each
    pair with ``DiceCoeff().forward`` and averages the per-example values.

    Args:
        input: Batch of predictions, iterated along dim 0.
        target: Batch of ground truth, paired with ``input`` via ``zip``
            (surplus items in the longer batch are ignored).

    Returns:
        1-element float32 tensor on ``input``'s device holding the mean.

    Raises:
        ValueError: if the batch is empty.
    """
    # Allocate the accumulator on input's own device; `.cuda()` would pick the
    # *current* CUDA device, which mismatches tensors living on another GPU.
    total = torch.zeros(1, dtype=torch.float32, device=input.device)

    count = 0
    for count, (pred, truth) in enumerate(zip(input, target), start=1):
        total = total + DiceCoeff().forward(pred, truth)

    if count == 0:
        raise ValueError("dice_coeff() requires a non-empty batch")
    return total / count