pytorch nn.CrossEntropyLoss 源代码复现

hi,我们先看一下nn.CrossEntropyLoss 应用例子:

import torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
# 方便理解,此处假设batch_size = 1
x_output = torch.randn(2, 3)   # 预测2个对象,每个对象分别属于三个类别分别的概率
# 需要的GT格式为(2)的tensor,其中的值范围必须在0-2(0<value<C-1)之间。
x_target = torch.tensor([0, 2])  # 这里给出两个对象所属的类别标签即可,此处的意思为第一个对象属于第0类,第二个我对象属于第2类
loss = loss_fn(x_output, x_target)
print('loss:\n', loss)

x_output代表模型的输出它的大小为[2,3],2代表批次,3代表维度(3分类任务)

x_target代表真值,大小为[2],代表每一个批次的真实输出

nn.CrossEntropyLoss 会对x_output在1维度上进行一个softmax,因此先写一个softmax函数

def softmax(output):
    #对所用的output取指数
    output=torch.exp(output)
    #在1维度上取sum
    sum=output.sum(dim=1,keepdim=True)
    output=output/sum
    return output

源码可能不对x_target 进行one_hot,但是我进行了,帮助诸位看官理解一下就好                  

def one_hot(label,numclass=3):
    shape=label.shape
    label_one_hot=torch.zeros((shape[0],numclass))
    index=torch.arange(shape[0])
    label_one_hot[index,label]=1
    return label_one_hot

 最后才是交叉熵损失函数

def CrossEntropyLoss(output,label):
    output= softmax(output)
    label = one_hot(label)
    result=torch.mul(torch.log(output),label)
    return -result.sum()/output.shape[0]

 总代码:

import torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
def softmax(output):
    #对所用的output取指数
    output=torch.exp(output)
    #在1维度上取sum
    sum=output.sum(dim=1,keepdim=True)
    output=output/sum
    return output
def one_hot(label,numclass=3):
    shape=label.shape
    label_one_hot=torch.zeros((shape[0],numclass))
    index=torch.arange(shape[0])
    label_one_hot[index,label]=1
    return label_one_hot
def CrossEntropyLoss(output,label):
    output= softmax(output)
    label = one_hot(label)
    result=torch.mul(torch.log(output),label)
    return -result.sum()/output.shape[0]
#构造数据
x_output = torch.tensor([[0,1.0,2.0],
                         [3,4.5,5]])
x_target = torch.tensor([0, 2])
loss = loss_fn(x_output, x_target)
print('pytorch loss:\n', loss)
#测试自己写的
loss=CrossEntropyLoss(x_output,x_target)
print('my loss:\n', loss)

结果:

希望可以帮助大家理解代码

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值