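# PyTorch 多标签分类：处理类别不平衡的损失函数 (PyTorch loss functions for class-imbalanced multi-label classification)
# Focal Loss（多标签分类版） (Focal loss, multi-label version)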
def criterion(y_pred, y_true, weight=None, alpha=0.25, gamma=2):
    """Multi-label focal loss, computed independently for every entry.

    Each entry is treated as an independent binary problem via a sigmoid.
    With ``p = sigmoid(y_pred)``:

    * entries where ``y_true > 0`` contribute
      ``-alpha * (y_true - p) ** gamma * log(p)``;
    * all other entries contribute
      ``-(1 - alpha) * p ** gamma * log(1 - p)``.

    Args:
        y_pred: Raw logits (a sigmoid is applied here).
        y_true: Targets with the same shape as ``y_pred``, intended to be 0/1.
            NOTE(review): a soft target smaller than ``p`` makes
            ``y_true - p`` negative, which yields NaN for a non-integer
            ``gamma``.
        weight: Optional tensor broadcastable to the loss's shape (e.g. one
            weight per class) that multiplies each entry's loss before
            summation. ``None`` leaves the loss unweighted.
        alpha: Weight of the positive term; ``1 - alpha`` weights the
            negative term.
        gamma: Focusing exponent that down-weights well-classified entries.

    Returns:
        Scalar tensor: the sum (not the mean) of the per-entry losses.
    """
    # torch.sigmoid is the functional op. nn.Sigmoid is a Module class, and
    # nn.Sigmoid(y_pred) raises TypeError instead of applying the sigmoid.
    sigmoid_p = torch.sigmoid(y_pred)
    zeros = torch.zeros_like(sigmoid_p)
    is_pos = y_true > zeros
    # Positives use their distance to the target and negatives use p itself.
    # Each factor is zero on the entries where the other term applies.
    pos_p_sub = torch.where(is_pos, y_true - sigmoid_p, zeros)
    neg_p_sub = torch.where(is_pos, zeros, sigmoid_p)
    # Clamp before log so a saturated sigmoid (p == 0 or 1) never gives -inf.
    log_p = torch.log(torch.clamp(sigmoid_p, 1e-8, 1.0))
    log_1mp = torch.log(torch.clamp(1.0 - sigmoid_p, 1e-8, 1.0))
    per_entry_cross_ent = (
        -alpha * (pos_p_sub ** gamma) * log_p
        - (1 - alpha) * (neg_p_sub ** gamma) * log_1mp
    )
    if weight is not None:
        per_entry_cross_ent = per_entry_cross_ent * weight
    return per_entry_cross_ent.sum()
# 将 softmax 交叉熵应用于多标签分类 (Applying softmax cross-entropy to multi-label classification)
# 参考 (reference): https://mp.weixin.qq.com/s/Ii2sxJUGNvX4CnmtVmbFwA
def criterion2(y_pred, y_true):
    """Multi-label categorical cross-entropy ("softmax + CE" for multi-label).

    With ``s = y_pred``, the loss for each sample is
    ``log(1 + sum_{neg} exp(s_i)) + log(1 + sum_{pos} exp(-s_j))``,
    where positives are entries with ``y_true == 1`` and negatives are entries
    with ``y_true == 0``. The per-sample losses are then averaged.

    Args:
        y_pred: Raw scores (logits), with labels along the last dimension.
        y_true: Targets with the same shape as ``y_pred``; the masking below
            assumes 0/1 values.

    Returns:
        Scalar tensor: the mean loss over all leading dimensions.
    """
    # Negate the scores of positive labels so both groups become
    # "should be below 0" terms.
    signed = (1 - 2 * y_true) * y_pred
    # Subtracting a huge constant drives excluded entries to exp(...) ~ 0.
    # This masks positives out of the negative term and vice versa.
    masked_neg = signed - y_true * 1e12
    masked_pos = signed - (1 - y_true) * 1e12
    # An appended zero logit supplies the "1 +" inside each log.
    pad = torch.zeros_like(signed[..., :1])
    neg_term = torch.logsumexp(torch.cat([masked_neg, pad], dim=-1), dim=-1)
    pos_term = torch.logsumexp(torch.cat([masked_pos, pad], dim=-1), dim=-1)
    return (neg_term + pos_term).mean()