简单的例子说明 F.cross_entropy用法

F.cross_entropy是PyTorch中计算交叉熵损失的函数。来看一个简单的例子来说明它是如何计算的。

首先,了解F.cross_entropy的输入:

输入(input):这通常是模型的输出,shape为(N, C),其中N是batch size,C是类别的数量。
目标(target):这是每个样本的类别索引,shape为(N,),即一个长度为N的向量。
输出的shape是一个标量,代表了整个batch的平均损失。

举个例子,假设有一个3分类问题,batch size是2。
模型输出了每个类别的未归一化的分数(即logits),而不是概率。
输出是一个2x3的矩阵,因为我们有两个样本(N=2)和三个类别(C=3)。目标值(target)是每个样本的真实类别。

假设模型对于两个样本的输出(logits)是:

[[2.0, 1.0, 0.1],   # 第一个样本的logits
 [0.1, 3.0, 0.2]]   # 第二个样本的logits

假设类别标签是:

[0, 2]  # 第一个样本的真实类别是0,第二个样本的真实类别是2

那么,F.cross_entropy的计算步骤如下:

首先,对每个样本的logits应用softmax函数,将logits转化为概率。

对于第一个样本:

e^2.0 / (e^2.0 + e^1.0 + e^0.1)0.659,
e^1.0 / (e^2.0 + e^1.0 + e^0.1)0.242,
e^0.1 / (e^2.0 + e^1.0 + e^0.1)0.099

对于第二个样本:

e^0.1 / (e^0.1 + e^3.0 + e^0.2)0.018,
e^3.0 / (e^0.1 + e^3.0 + e^0.2)0.864,
e^0.2 / (e^0.1 + e^3.0 + e^0.2)0.118

然后,使用这个概率分布和真实的类别标签来计算交叉熵损失。

交叉熵损失的公式为-sum(y_i * log(p_i)),其中y_i是目标标签的one-hot编码(对应类别处为1,其余为0)。

对于第一个样本,我们只关心类别0的损失,因为真实类别是0:

-log(0.659)-(-0.417)0.417

对于第二个样本,我们只关心类别2的损失,因为真实类别是2:

-log(0.118)-(-2.136)2.136

最后,取所有样本的交叉熵损失的平均值。

(0.417 + 2.136) / 21.277

使用F.cross_entropy计算上述例子的平均损失的PyTorch代码:

import torch
import torch.nn.functional as F
# 模型输出
logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.1, 3.0, 0.2]])
# 真实类别
targets = torch.tensor([0, 2])
# 计算交叉熵损失
loss = F.cross_entropy(logits, targets)
print(loss)  # 输出将近似于1.277

最终的输出 loss 即为这个batch的平均交叉熵损失。在上面的例子中,它应该接近于计算得到的1.277。

  • 8
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蓝羽飞鸟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值