Dice和IoU代码实现
class DiceLoss(nn.Module):#二分类
def __init__(self):
super(DiceLoss, self).__init__()
def forward(self, input, target):
N = target.size(0)
smooth = 0.001 #防止/0
input = nn.functional.sigmoid(input)
input = input[:, 1]
input_flat = input.view(N, -1)
target_flat = target.view(N, -1)
intersection = input_flat * target_flat #用相乘来求得交集
loss = 2.0 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.