本篇文章中我们将详细比较torch.nn中两个损失函数类NLLLoss与CrossEntropyLoss,首先我们将介绍负对数似然和交叉熵,其次我们再介绍在Pytorch中两个类具体的执行计算方式。
数学推导
我们来考虑一个
n
n
n分类问题, 为了使讨论更为简洁,我们这里只考虑一个样本(sample),输入为
x
\boldsymbol{x}
x ,经模型输出为
l
o
g
i
t
s
logits
logits,经过Softmax归一化后预测概率分布为
y
^
=
S
o
f
t
m
a
x
(
l
o
g
i
t
s
)
=
[
p
1
,
p
2
,
…
,
p
n
]
T
\hat{y}=Softmax(logits)=[p_1,p_2,\dots,p_n]^T
y^=Softmax(logits)=[p1,p2,…,pn]T,真实标签为
y
\boldsymbol{y}
y,假设该样本实际上属于第
c
c
c类,即
y
=
[
y
1
,
y
2
,
…
,
y
n
]
T
=
[
0
,
0
,
…
,
1
,
…
,
0
]
T
\boldsymbol{y}=[y_1,y_2,\dots,y_n]^T=[0,0,\dots,1,\dots,0]^T
y=[y1,y2,…,yn]T=[0,0,…,1,…,0]T为one-hot向量。
我们想要最大化样本属于真实类别
c
c
c的概率, 即最小化负对数似然(negetive log likelihood)
N
L
L
=
−
L
o
g
P
(
y
^
∣
x
)
=
−
l
o
g
p
c
(1)
\begin{aligned} NLL &= -LogP(\hat{y}|x)=-logp_c \tag{1} \end{aligned}
NLL=−LogP(y^∣x)=−logpc(1)
另外要注意深度学习中
l
o
g
log
log函数往往指的是
l
n
ln
ln函数,即自然对数。
而我们知道
y
\boldsymbol{y}
y为one-hot向量,只有第
c
c
c维位置为1,故
N
L
L
=
−
l
o
g
p
c
=
−
1
⋅
l
o
g
p
c
=
−
(
0
⋅
l
o
g
p
1
+
0
⋅
l
o
g
p
2
+
⋯
+
1
⋅
l
o
g
p
c
+
⋯
+
0
⋅
l
o
g
p
n
)
=
−
∑
i
=
1
n
y
i
l
o
g
p
i
=
−
y
⋅
l
o
g
y
^
(2)
\begin{aligned} NLL &= -logp_c \\ &= -1\cdot logp_c \\ &=-(0\cdot logp_1+0\cdot logp_2+\dots+1\cdot logp_c+\dots+0\cdot logp_n) \\ &= -\sum\limits_{i=1}^ny_ilogp_i \\ &= -\boldsymbol{y}\cdot log\hat{\textbf{y}} \tag{2} \end{aligned}
NLL=−logpc=−1⋅logpc=−(0⋅logp1+0⋅logp2+⋯+1⋅logpc+⋯+0⋅logpn)=−i=1∑nyilogpi=−y⋅logy^(2)
最后结果即为交叉熵(Cross Entropy)
C
E
=
−
y
⋅
l
o
g
y
^
(3)
CE = -\boldsymbol{y}\cdot log\hat{\textbf{y}} \tag{3}
CE=−y⋅logy^(3)
所以对于n分类问题,两者是等价的。
代码实践
但事实上在Pytorch中,具体的执行计算方式有所不同。
由公式(2)我们可得到
C
E
=
−
y
⋅
L
o
g
S
o
f
t
m
a
x
(
l
o
g
i
t
s
)
(4)
\begin{aligned} CE &= -\boldsymbol{y}\cdot LogSoftmax(logits) \tag{4} \end{aligned}
CE=−y⋅LogSoftmax(logits)(4)
而CrossEntropyLoss()事实上是对logits进行LogSoftmax计算交叉熵,但是NLLLoss()并没有这一步,需要对模型输出的logits外加LogSoftmax操作。
下面我们通过代码演示来展示在Pytorch框架中两种损失函数的实际应用区别。
import torch.nn as nn
import torch.nn.functional as F
nnl = nn.NLLLoss()
ce = nn.CrossEntropyLoss()
ls = nn.LogSoftmax(dim=-1)
logits = torch.rand(3)
target = torch.tensor(1)
print(logits)
loss1 = nnl(ls(logits), target)
loss2 = ce(logits, target)
print(loss1)
print(loss2)
# output
#tensor([0.0437, 0.1241, 0.2193])
#tensor(1.1061)
#tensor(1.1061)
所以我们最终可以总结为
n
n
.
L
o
g
S
o
f
t
m
a
x
(
)
&
n
n
.
N
L
L
L
o
s
s
(
)
⇔
n
n
.
C
r
o
s
s
E
n
t
r
o
p
y
L
o
s
s
(
)
\textcolor{red} {nn.LogSoftmax() \& nn.NLLLoss() \quad \Leftrightarrow \quad nn.CrossEntropyLoss()}
nn.LogSoftmax()&nn.NLLLoss()⇔nn.CrossEntropyLoss()