F.cross_entropy是PyTorch中计算交叉熵损失的函数。来看一个简单的例子来说明它是如何计算的。
首先,了解F.cross_entropy的输入:
输入(input):这通常是模型的输出,shape为(N, C),其中N是batch size,C是类别的数量。
目标(target):这是每个样本的类别索引,shape为(N,),即一个长度为N的向量。
输出的shape是一个标量,代表了整个batch的平均损失。
举个例子,假设有一个3分类问题,batch size是2。
模型输出了每个类别的未归一化的分数(即logits),而不是概率。
输出是一个2x3的矩阵,因为我们有两个样本(N=2)和三个类别(C=3)。目标值(target)是每个样本的真实类别。
假设模型对于两个样本的输出(logits)是:
[[2.0, 1.0, 0.1], # 第一个样本的logits
[0.1, 3.0, 0.2]] # 第二个样本的logits
假设类别标签是:
[0, 2] # 第一个样本的真实类别是0,第二个样本的真实类别是2
那么,F.cross_entropy的计算步骤如下:
首先,对每个样本的logits应用softmax函数,将logits转化为概率。
对于第一个样本:
e^2.0 / (e^2.0 + e^1.0 + e^0.1) ≈ 0.659,
e^1.0 / (e^2.0 + e^1.0 + e^0.1) ≈ 0.242,
e^0.1 / (e^2.0 + e^1.0 + e^0.1) ≈ 0.099
对于第二个样本:
e^0.1 / (e^0.1 + e^3.0 + e^0.2) ≈ 0.018,
e^3.0 / (e^0.1 + e^3.0 + e^0.2) ≈ 0.864,
e^0.2 / (e^0.1 + e^3.0 + e^0.2) ≈ 0.118
然后,使用这个概率分布和真实的类别标签来计算交叉熵损失。
交叉熵损失的公式为-sum(y_i * log(p_i)),其中y_i是目标标签的one-hot编码(对应类别处为1,其余为0)。
对于第一个样本,我们只关心类别0的损失,因为真实类别是0:
-log(0.659) ≈ -(-0.417) ≈ 0.417
对于第二个样本,我们只关心类别2的损失,因为真实类别是2:
-log(0.118) ≈ -(-2.136) ≈ 2.136
最后,取所有样本的交叉熵损失的平均值。
(0.417 + 2.136) / 2 ≈ 1.277
使用F.cross_entropy计算上述例子的平均损失的PyTorch代码:
import torch
import torch.nn.functional as F
# 模型输出
logits = torch.tensor([[2.0, 1.0, 0.1],
[0.1, 3.0, 0.2]])
# 真实类别
targets = torch.tensor([0, 2])
# 计算交叉熵损失
loss = F.cross_entropy(logits, targets)
print(loss) # 输出将近似于1.277
最终的输出 loss 即为这个batch的平均交叉熵损失。在上面的例子中,它应该接近于计算得到的1.277。