【pytorch】——crossentropy如何求出每个样本的loss

pytorch, crossentropy

思路:

  • 自己写crossentropy的计算过程
  • 采用pytorch自身的crossentropy,需要将input变为:bc-1, target:b*-1

1. 自己实现

def Self_cross_entropy(pred, target):
    probability=F.softmax(pred, dim=1)#shape [num_samples,num_classes]
    log_P=torch.log(probability)
    '''对输入的target标签进行 one-hot编码,使用_scatter方法'''
    a=torch.unsqueeze(target,dim=0)
    # print(a.shape,probability.shape)
    one_hot = torch.zeros(probability.shape, device=target.device).scatter_(1, torch.unsqueeze(target,dim=1), 1)
    loss3 = - one_hot * log_P
    loss3 = loss3.sum(dim=1)
    loss3 = loss3.mean(1).mean(1)
    return loss3

2. 采用pytorch的crossentropy

  • 将crossentropy的reduction='none'
  • 改变input,target的shape:需要将input变为:bc-1, target:b*-1

example具体代码:

import torch
import torch.nn as nn
import torch.nn.functional as F

def Self_cross_entropy(pred, target):
    probability=F.softmax(pred, dim=1)#shape [num_samples,num_classes]
    log_P=torch.log(probability)
    '''对输入的target标签进行 one-hot编码,使用_scatter方法'''
    a=torch.unsqueeze(target,dim=0)
    # print(a.shape,probability.shape)
    one_hot = torch.zeros(probability.shape, device=target.device).scatter_(1, torch.unsqueeze(target,dim=1), 1)
    loss3 = - one_hot * log_P
    loss3 = loss3.sum(dim=1)
    loss3 = loss3.mean(1).mean(1)
    return loss3

loss = nn.CrossEntropyLoss()
input = torch.randn(4, 5, 10, 10, requires_grad=True)
target = torch.empty((4, 10, 10), dtype=torch.long).random_(5)
output = loss(input, target)
print("pytorch CEloss:", output)

selfout = Self_cross_entropy(input, target)
print("self CE", selfout.mean())
print("every CE", selfout)

loss2= nn.CrossEntropyLoss(reduction='none')
b, c = input.shape[:2]
output3 = loss2(input.view(b, c, -1), target.view(b, -1))
print("pytorch CE", output3.mean())
print("pytorch every CE", output3.mean(-1))

输出

pytorch CEloss: tensor(1.9572, grad_fn=)
self CE tensor(1.9572, grad_fn=)
every CE tensor([1.8493, 2.0463, 1.9736, 1.9595], grad_fn=)
pytorch CE tensor(1.9572, grad_fn=)
pytorch every CE tensor([1.8493, 2.0463, 1.9736, 1.9595], grad_fn=)

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值