pytorch中的cross_entropy函数

本文详细解析了PyTorch中的交叉熵函数cross_entropy的计算过程,包括输入参数input(模型预测概率)和target(真实类别标签)的处理。通过softmax将模型输出转换为概率分布,并对目标标签进行one-hot编码,简化计算公式。最终通过Python代码展示了从原始输出到交叉熵损失的计算步骤,并比较了函数和类的计算结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        cross_entropy函数是pytorch中计算交叉熵的函数。根据源码分析,输入主要包括两部分,一个是input,是维度为(batch_size,class)的矩阵,class表示分类的数量,这个就表示模型输出的预测结果;另一个是target,是维度为(batch_size)的一维向量,表示每个样本的真实值。输出是交叉熵的值。nn中的CrossEntropyLoss类与此函数的作用相同。


计算过程

        交叉熵是常见的损失函数,之前的文章中已经详细介绍了交叉熵的公式由来(交叉熵详解),公式如下:

L=-[y*log\hat{y}+(1-y)*log(1-\hat{y})]

        如果用在多分类问题中当做损失函数的话,一般会这样写:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}

        其中y是真实分类,是一个标签值;\hat{y}是模型预测结果,包含了属于每种标签的概率(此时这几个概率相加还不等于1)。在上面说了函数的输入分别是input和target,那么y就对应target这个向量,\hat{y}就对应input这个矩阵。但是现在就出现了两个问题:

         1、target就是一个标签值,无法与input直接进行运算,那么我们就对这个target先进行one-hot编码,使二者的维度相同。 比如target的值是[3],共有5个类,那么转换为one-hot编码之后就是:[0,0,0,1,0]。

        2、input中的预测值不能直接代表概率,而且这几个值相加不为1,这时就进行一个softmax操作,让模型的输出值满足上面两个条件,对输出结果的归一化公式如下:

\hat{y}=P(\hat{y}=i|x)=\frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}

        将\hat{y}带入到上面的损失函数公式中进行推导:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}=-\sum_{i=1}^{n}y\cdot \frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}=-\sum_{i=1}^{n}y(input[i]-log_{2}\sum_{j=1}^{n}e^{input_{[j]}})

        在第一点中我们已经将y转换为了[0,0,0,1,0]这样的编码,可见式子中只有target那一项的损失值需要计算,其他与0相乘就都消掉了,所以式子中最外层那个连加号就可以去掉了。最后式子就简化为:

L=-input[target]+log_{2}\sum_{j=1}^{n}e^{input_{[j]}}


Python实现

        现在已经清楚了cross_entropy这个函数的运算过程,现在用Python来模拟一下,首先构造一个随机的input和target:

output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
print("output:")
print(output)
print("label:")
print(label)

        输出结果如下:

        然后模拟函数的运算过程:

first = -input[0][target[0]]
second = 0
res=0
for j in range(5):
    second += math.exp(input[0][j])
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

         输出结果为:

         然后分别调用cross_entropy函数和CrossEntropyLoss类计算loss的结果:

criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)
loss2 = torch.nn.functional.cross_entropy(input=input,target=target)
print("cross_entropy函数计算loss的结果:")
print(loss)
print("CrossEntropyLoss类计算loss的结果:")
print(loss2)

        输出结果为:

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值