多分类 focal loss 以及 dice loss 的 PyTorch 与 Keras 实现（两者均为逐元素的二值形式，用于多分类时需先将标签转为 one-hot 编码）
PyTorch 下的多分类 focal loss 以及 dice loss 实现
dice loss
class DiceLoss(nn.Module):
    """Soft Dice loss computed independently for every sample in the batch.

    Each sample is flattened and scored with

        dice = (2 * sum(x * y) + smooth) / (sum(x) + sum(y) + smooth)

    and the loss is ``1 - dice``.

    Args:
        smooth: Additive smoothing term. It keeps the ratio defined when both
            prediction and target sum to zero. Defaults to 1.
        reduction: ``None`` or ``"none"`` returns the per-sample loss of shape
            ``(N,)``; this is the default and the historical behaviour.
            ``"mean"`` or ``"sum"`` reduce it to a scalar.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = (None, "none", "mean", "sum")

    def __init__(self, smooth=1, reduction=None):
        super(DiceLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, input, target):
        """Return the Dice loss between ``input`` and ``target``.

        Args:
            input: Predictions with the batch on dim 0. Presumably these are
                probabilities in [0, 1] (TODO confirm against callers).
            target: Ground truth with the same batch size and the same number
                of elements per sample as ``input``.

        Returns:
            The per-sample loss of shape ``(N,)``, or a scalar when
            ``reduction`` is ``"mean"`` or ``"sum"``.
        """
        n = target.size(0)
        # reshape (unlike view) also accepts non-contiguous tensors,
        # e.g. ones produced by permute().
        input_flat = input.reshape(n, -1)
        target_flat = target.reshape(n, -1)
        intersection = (input_flat * target_flat).sum(1)
        # Smoothing is added exactly once to numerator and denominator.
        # The former 2 * (intersection + smooth) let the score exceed 1,
        # giving negative losses.
        dice = (2 * intersection + self.smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + self.smooth
        )
        loss = 1 - dice
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
focal loss
class FocalLoss(nn.Module):
    """Element-wise binary focal loss.

    Every element of ``y_pred`` is treated as an independent binary
    prediction ``p``:

    - Elements whose target equals 1 contribute
      ``-alpha * (1 - p) ** gamma * log(p)``.
    - Elements whose target equals 0 contribute
      ``-(1 - alpha) * p ** gamma * log(1 - p)``.
    - Elements whose target is neither 0 nor 1 contribute only a negligible
      residual, caused by clamping the filler values.

    Args:
        alpha: Weight of the positive term; negatives are weighted by
            ``1 - alpha``.
        gamma: Focusing exponent that down-weights well-classified elements.
        logits: If True, ``y_pred`` is passed through ``torch.sigmoid`` first.
        sampling: Reduction over elements. ``"mean"``, ``"sum"``, or ``None``
            to return the unreduced element-wise loss.
        eps: Probabilities are clamped to ``[eps, 1 - eps]`` before taking the
            log. The default 1e-3 reproduces the previously hard-coded bounds.
    """

    def __init__(self, alpha=0.25, gamma=2, logits=False, sampling='mean', eps=1e-3):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.sampling = sampling
        self.eps = eps

    def forward(self, y_pred, y_true):
        """Return the focal loss of ``y_pred`` against ``y_true``.

        Args:
            y_pred: Probabilities, or logits when ``self.logits`` is True.
            y_true: Targets of the same shape as ``y_pred``. Only entries
                equal to exactly 0 or 1 produce a meaningful loss.

        Returns:
            A scalar for ``"mean"`` or ``"sum"``; otherwise the element-wise
            loss with the shape of ``y_pred``.

        Raises:
            ValueError: If ``self.sampling`` is not ``"mean"``, ``"sum"`` or
                ``None``. Previously this case silently returned ``None``.
        """
        if self.sampling not in ("mean", "sum", None):
            raise ValueError(
                f"sampling must be 'mean', 'sum' or None, got {self.sampling!r}"
            )
        alpha = self.alpha
        alpha_ = 1 - alpha
        if self.logits:
            y_pred = torch.sigmoid(y_pred)
        # Positions that do not belong to a term are filled with 1 (positive
        # term) or 0 (negative term). After clamping, those terms are ~0.
        pt_positive = torch.where(y_true == 1, y_pred, torch.ones_like(y_pred))
        pt_negative = torch.where(y_true == 0, y_pred, torch.zeros_like(y_pred))
        # NOTE: clamp has zero gradient outside [eps, 1 - eps]. Predictions
        # that are wrong with more confidence than eps therefore receive no
        # gradient from this loss.
        pt_positive = torch.clamp(pt_positive, self.eps, 1 - self.eps)
        pt_negative = torch.clamp(pt_negative, self.eps, 1 - self.eps)
        pos_loss = -alpha * (1 - pt_positive) ** self.gamma * torch.log(pt_positive)
        neg_loss = -alpha_ * pt_negative ** self.gamma * torch.log(1 - pt_negative)
        loss = pos_loss + neg_loss
        if self.sampling == "mean":
            return loss.mean()
        if self.sampling == "sum":
            return loss.sum()
        return loss
Keras/TensorFlow 下的多分类 focal loss 以及 dice loss 实现
dice loss
def dice(y_true, y_pred, smooth=1.):
    """Return the smoothed soft Dice coefficient of ``y_pred`` vs ``y_true``.

    Both tensors are flattened whole with ``K.flatten``, so the score covers
    every element of the batch together rather than one score per sample.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor with the same number of elements.
        smooth: Additive smoothing applied once to numerator and denominator.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    total = K.sum(truth) + K.sum(guess)
    return (2. * overlap + smooth) / (total + smooth)
def dice_loss(y_true, y_pred):
    """Keras-compatible loss: one minus the Dice coefficient from ``dice``."""
    score = dice(y_true, y_pred)
    return 1 - score
focal loss
def focal_loss(y_true, y_pred, *, gamma=2, alpha=0.25, eps=1e-3):
    """Element-wise binary focal loss for Keras/TensorFlow, averaged to a scalar.

    The signature stays Keras-compatible: ``model.compile(loss=focal_loss)``
    calls it with ``(y_true, y_pred)``. To use other hyper-parameters, bind
    them, e.g. ``functools.partial(focal_loss, gamma=3)``.

    Args:
        y_true: Targets. Only entries equal to exactly 0 or 1 produce a
            meaningful loss.
        y_pred: Predictions, presumably probabilities already passed through
            a sigmoid (TODO confirm against the model's output layer).
        gamma: Focusing exponent. Defaults to the previously hard-coded 2.
        alpha: Weight of the positive term; negatives use ``1 - alpha``.
            Defaults to the previously hard-coded 0.25.
        eps: Probabilities are clipped to ``[eps, 1 - eps]`` before the log.
            Defaults to the previously hard-coded 1e-3.

    Returns:
        The mean loss over all elements, as a scalar tensor.
    """
    # tf.where(cond, a, b) takes a where cond is True and b elsewhere. The
    # filler values (1 for the positive term, 0 for the negative term) make
    # the non-matching term ~0 after clipping.
    pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
    pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
    pt_1 = K.clip(pt_1, eps, 1. - eps)
    pt_0 = K.clip(pt_0, eps, 1. - eps)
    return K.mean(
        -alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)
        - (1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)
    )