# NOTE(review): this chunk is the body of a focal-loss-style forward method —
# the enclosing `def` is outside this view. Indentation was lost in the paste
# and is reconstructed here from the control-flow keywords; verify the nesting
# against the original source.
bsz = pred.shape[0]  # batch size, used to average the summed loss below
# Convert integer class labels to a one-hot matrix only when the label tensor
# has fewer dims than the logits.
# NOTE(review): when pred.dim() == target.dim(), `one_hot_target` is never
# assigned, so the uses below would raise NameError — confirm callers always
# pass class-index targets.
if pred.dim() != target.dim():
    # one_hot_target, weight = _expand_onehot_labels(target, pred.size(-1))
    one_hot_target = F.one_hot(target).float()
    # pred_norm = pred.sigmoid() if self.require_sigmoid else pred
    # pred_norm = 1. / (torch.exp(2.*pred) + 1.0)
    # one_hot_target = one_hot_target.type_as(pred)
# NOTE(review): pred_norm is the raw logits clamped below at 0, NOT a
# probability (the sigmoid normalizations above are commented out), so the
# `pt` terms below are not bounded to [0, 1] — confirm this is intentional.
pred_norm = torch.clamp_min(pred, 0.)
if self.downweight_pos:
    # Focal-style modulating factor: pt = 1-p for positive classes, p for
    # negatives; then alpha-balanced focal weight (alpha on positives).
    pt = (1 - pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
    focal_weight = (self.alpha * one_hot_target + (1 - self.alpha) * (1 - one_hot_target)) * pt.pow(self.gamma)
else:
    # NOTE(review): `1 / pred_norm` divides by a value clamped to a MINIMUM
    # of 0, producing inf/NaN for any positive class whose logit <= 0.
    # Likely a typo for `1 - pred_norm` (cf. the branch above) — confirm.
    pt = (1 / pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
    focal_weight = pt.pow(self.gamma)
# NOTE(review): `focal_weight` is computed in both branches but never applied —
# the value returned below is a plain (unweighted) cross-entropy / NLL loss.
# Confirm whether the weighting is work-in-progress or should multiply the
# per-element loss before the sum.
pred_log_softmax = -F.log_softmax(pred, dim=1)  # per-class negative log-probabilities
loss = (one_hot_target*pred_log_softmax).sum() / bsz  # summed NLL averaged over the batch
# Debug prints comparing this loss against PyTorch reference implementations;
# presumably left over from development — consider removing or gating.
print('\n')
print('nll_loss', loss)
print('ce loss:', F.cross_entropy(pred, target))
print('our binary loss:', -(pred.sigmoid().log()*one_hot_target+(1-one_hot_target)*(1-pred.sigmoid()).log()).mean())
print('binary loss:', F.binary_cross_entropy_with_logits(pred, one_hot_target).mean())
return loss
# Provenance: scraped from a blog post titled "Various loss implementations"
# (各种loss实现); the trailing "latest recommended article posted
# 2023-03-05 14:58:12" line was page residue, not part of the code.