1. 先说结论
nn.CrossEntropyLoss
(简称CEL)等于nn.NLLLoss
(简称NLL)+torch.log
+F.softmax
主要用于多分类问题的损失计算。
图为Pytorch对于CEL的描述,这个criterion(指标)结合了nn.LogSoftmax()和nn.NLLLoss()
nn.LogSoftmax()
其实也就是torch.log
+F.softmax
2.CrossEntropyLoss(CEL)公式分解
C E L ( x , c l a s s ) = − l o g ( exp ( x [ c l a s s ] ) ∑ j exp ( x [ j ] ) ) CEL(x, class) = -log(\frac{\exp(x[class])} {\sum_{j}\exp(x[j])}) CEL(x,class)=−log(∑jexp(x[j])exp(x[class]))
-
F.Softmax
对应:
S o f t m a x ( x i ) = exp ( x [ i ] ) ∑ j exp ( x [ j ] ) Softmax(x_i) = \frac{\exp(x[i])}{\sum_j\exp(x[j])} Softmax(xi)=∑jexp(x[j])exp(x[i])
求得每个标签的概率,且每个标签概率的和为1(因为分子求和后,和分母约分为1) -
torch.log
对应:
L o g ( x i ) = l o g ( x i ) Log(x_i) = log(x_i) Log(xi)=log(xi) -
nn.NLLLoss
对应:
N L L L o s s ( x , c l a s s ) = − F ( x [ c l a s s ] ) NLLLoss(x, class) = -F(x[class]) NLLLoss(x,class)=−F(x[class])
这里的F
指代上面的log
+softmax
具体是根据label中class的索引,去计算完log
和softmax
的序列中取值,并取反。
综上,CEL
嵌套如下:
C E L ( x , c l a s s ) = N L L L o s s ( L o g ( S o f t m a x ( x [ c l a s s ] ) ) ) CEL(x, class) = NLLLoss(Log(Softmax(x[class]))) CEL(x,class)=NLLLoss(Log(Softmax(x[class])))
看代码会有更直观的感受~
3.实验代码
F.softmax
+torch.log
+nn.NLLLoss
:
# label就是标签,序号为2
label = torch.tensor([2])
# x是模型预测的值,共10个
x = torch.randn((1, 10))
# tensor([[-0.9783, 1.5044, 1.5414, -0.6410, -0.5324,
# 1.0770, 0.8484, 0.9239, -0.0294, -1.6969]])
# 第一步:对x的10个数进行softmax
x = F.softmax(x, dim=1)
# tensor([[0.0192, 0.2296, 0.2382, 0.0269, 0.0299,
# 0.1497,0.1191, 0.1285, 0.0495, 0.0093]])
# 可以看到x已经归一化到 0~1
# 求和也为1,和公式相对应
print(x.sum())
# tensor(1.0000)
# 第二步:对每个数求log
x = torch.log(x)
# tensor([[-3.9542, -1.4715, -1.4345, -3.6169, -3.5084,
# -1.8989, -2.1275, -2.0520, -3.0053, -4.6728]])
# 因为x都是 0~1 之间的数,因此log后都为负数
# 第三步:NLLLoss求loss
nllloss = nn.NLLLoss()
nllloss(x, label)
# tensor(1.4345)
# 根据label=2,取得x[2]的值-1.4345,再取反
对比nn.CrossEntropyLoss
:
celoss = nn.CrossEntropyLoss()
# x2和x的初始值一致
x2 = torch.tensor([[-0.9783, 1.5044, 1.5414, -0.6410, -0.5324,
1.0770, 0.8484, 0.9239, -0.0294, -1.6969]])
celoss(x2, label)
# tensor(1.4345)
证明两者的计算流一致,CrossEntropyLoss
对三步进行了封装~