各种loss实现

        bsz = pred.shape[0]
        if pred.dim() != target.dim():
            # one_hot_target, weight = _expand_onehot_labels(target, pred.size(-1))
            one_hot_target = F.one_hot(target).float()

        # pred_norm = pred.sigmoid() if self.require_sigmoid else pred
        # pred_norm = 1. / (torch.exp(2.*pred) + 1.0)
        # one_hot_target = one_hot_target.type_as(pred)
        pred_norm = torch.clamp_min(pred, 0.)

        if self.downweight_pos:
            pt = (1 - pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
            focal_weight = (self.alpha * one_hot_target + (1 - self.alpha) * (1 - one_hot_target)) * pt.pow(self.gamma)
        else:
            pt = (1 / pred_norm) * one_hot_target + pred_norm * (1 - one_hot_target)
            focal_weight = pt.pow(self.gamma)

        pred_log_softmax = -F.log_softmax(pred, dim=1)
        loss = (one_hot_target*pred_log_softmax).sum() / bsz
        print('\n')
        print('nll_loss', loss)
        print('ce loss:', F.cross_entropy(pred, target))
        print('our binary loss:', -(pred.sigmoid().log()*one_hot_target+(1-one_hot_target)*(1-pred.sigmoid()).log()).mean())
        print('binary loss:', F.binary_cross_entropy_with_logits(pred, one_hot_target).mean())

        return loss
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值