对PyTorch中F.cross_entropy()的理解
PyTorch提供了求交叉熵的两个常用函数,一个是F.cross_entropy()
,另一个是F.nll_entropy()
,在学这两个函数的使用的时候有一些问题,尤其是对F.cross_entropy(input, target)
中参数target
的理解很困难,现在好像弄懂了一些,故写一篇Blog进行记录,方便日后查阅。
一、交叉熵的公式及计算步骤
1、交叉熵的公式:
H ( p , q ) = − ∑ i P ( i ) log Q ( i ) H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)} H(p,q)=−i∑P(i)logQ(i)
其中 P P P为真实值, Q Q Q为预测值。
2、计算交叉熵的步骤:
1)步骤说明:
①将predict_scores进行softmax运算,将运算结果记为pred_scores_soft;
②将pred_scores_soft进行log运算,将运算结果记为pred_scores_soft_log;
③将pred_scores_soft_log与真实值进行计算处理。
思路即:
s
c
o
r
e
s
→
s
o
f
t
m
a
x
→
l
o
g
→
c
o
m
p
u
t
e
scores\to softmax\to log\to compute
scores→softmax→log→compute
2)举一个例子对计算进行说明:
P 1 = [ 1 0 0 0 0 ] P_1=\begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ \end{bmatrix} P1=[10000]
Q 1 = [ 0.4 0.3 0.05 0.05 0.2 ] Q_1=\begin{bmatrix} 0.4 & 0.3 & 0.05 & 0.05 & 0.2 \\ \end{bmatrix} Q1=[0.40.30.050.050.2]
H
(
p
,
q
)
=
−
∑
i
P
(
i
)
log
Q
(
i
)
=
−
(
1
∗
l
o
g
0.4
+
0
∗
l
o
g
0.3
+
0
∗
l
o
g
0.05
+
0
∗
l
o
g
0.05
+
0
∗
l
o
g
0.2
)
=
−
l
o
g
0.4
≈
0.916
H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)}=-(1*log0.4+0*log0.3+0*log0.05+0*log0.05+0*log0.2) \\=-log0.4 \approx 0.916
H(p,q)=−i∑P(i)logQ(i)=−(1∗log0.4+0∗log0.3+0∗log0.05+0∗log0.05+0∗log0.2)=−log0.4≈0.916
如果
Q
2
=
[
0.98
0.01
0
0
0.01
]
Q_2=\begin{bmatrix} 0.98 & 0.01 & 0 & 0 & 0.01 \\ \end{bmatrix}
Q2=[0.980.01000.01]
则
H
(
p
,
q
)
=
−
∑
i
P
(
i
)
log
Q
(
i
)
=
−
(
1
∗
l
o
g
0.98
+
0
∗
l
o
g
0.01
+
0
∗
l
o
g
0.05
+
0
∗
l
o
g
0
+
0
∗
l
o
g
0.01
)
=
−
l
o
g
0.98
≈
0.02
H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)}=-(1*log0.98+0*log0.01+0*log0.05+0*log0+0*log0.01) \\=-log0.98 \approx 0.02
H(p,q)=−i∑P(i)logQ(i)=−(1∗log0.98+0∗log0.01+0∗log0.05+0∗log0+0∗log0.01)=−log0.98≈0.02
由 H ( p , q ) H(p,q) H(p,q)的计算结果和直观地观察 Q 1 Q_1 Q1和 Q 2 Q_2 Q2与 P 1 P_1 P1的相似度,均可看出 Q 2 Q_2 Q2比 Q 1 Q_1 Q1更近似于 P 1 P_1 P1。
二、官方文档的说明
在PyTorch的官方中文文档中F.cross_entropy()的记录如下:
torch.nn.functional.cross_entropy(input, target, weight=None, size_average=True)
该函数使用了 log_softmax
和 nll_loss
,详细请看CrossEntropyLoss
常用参数:
参数名 | shape | 注 |
---|---|---|
input | (N,C) | C是类别的个数 |
target | N | 0 <= targets[i] <= C-1 |
三、自己的困惑
在官方文档说明中,对于target
参数的说明为,torch.shape
为torch.Size([N])
,0 <= targets[i] <= C-1。
我的困惑是:网络计算输出并送入函数中的input
的torch.shape
为torch.Size([N,C])
,它的torch.shape
并不会因为softmax
和log
的操作而改变,但是target
的torch.shape
为torch.Size([N])
,是一个标量而不是一个矩阵,那么如何按照上面的例子中的运算方法进行交叉熵的计算?
例如:
import torch
import torch.nn.functional as F
pred_score = torch.tensor([[13., 3., 2., 5., 1.],
[1., 8., 20., 2., 3.],
[1., 14., 3., 5., 3.]])
print(pred_score)
pred_score_soft = F.softmax(pred_score, dim=1)
print(pred_score_soft)
pred_score_soft_log = pred_score_soft.log()
print(pred_score_soft_log)
它的结果为:
tensor([[13., 3., 2., 5., 1.],
[ 1., 8., 20., 2., 3.],
[ 1., 14., 3., 5., 3.]])
tensor([[9.9960e-01, 4.5382e-05, 1.6695e-05, 3.3533e-04, 6.1417e-06],
[5.6028e-09, 6.1442e-06, 9.9999e-01, 1.5230e-08, 4.1399e-08],
[2.2600e-06, 9.9984e-01, 1.6699e-05, 1.2339e-04, 1.6699e-05]])
tensor([[-4.0366e-04, -1.0000e+01, -1.1000e+01, -8.0004e+00, -1.2000e+01],
[-1.9000e+01, -1.2000e+01, -6.1989e-06, -1.8000e+01, -1.7000e+01],
[-1.3000e+01, -1.5904e-04, -1.1000e+01, -9.0002e+00, -1.1000e+01]])
如何与一个标量target进行计算?
四、分析
F.Cross_entropy(input, target)
函数中包含了
s
o
f
t
m
a
x
softmax
softmax和
l
o
g
log
log的操作,即网络计算送入的input
参数不需要进行这两个操作。
例如在分类问题中,input
表示为一个torch.Size([N, C])
的矩阵,其中,
N
N
N为样本的个数,
C
C
C是类别的个数,input[i][j]
可以理解为第
i
i
i个样本的类别为
j
j
j的Scores,Scores值越大,类别为
j
j
j的可能性越高,就像在代码块中所体现的那样。
同时,一般我们将分类问题的结果作为lable表示时使用one-hot embedding,例如在手写数字识别的分类问题中,数字0的表示为 [ 1 0 0 0 0 0 0 0 0 0 ] \begin{bmatrix} 1 & 0 & 0 & 0 & 0& 0& 0& 0& 0& 0 \end{bmatrix} [1000000000]数字3的表示为 [ 0 0 0 1 0 0 0 0 0 0 ] \begin{bmatrix} 0 & 0 & 0 & 1 & 0& 0& 0& 0& 0& 0 \end{bmatrix} [0001000000]在手写数字识别的问题中,我们计算 l o s s loss loss的方法为 l o s s = ( y − y ^ ) 2 loss=(y-\hat y)^2 loss=(y−y^)2,即求 y y y的embedding的矩阵减去pred_probability矩阵的结果矩阵的范数。
但是在这里,交叉熵的计算公式为
H ( p , q ) = − ∑ i P ( i ) log Q ( i ) H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)} H(p,q)=−i∑P(i)logQ(i)
其中 P P P为真实值概率矩阵, Q Q Q为预测值概率矩阵。
那么如果
P
P
P使用one-hot embedding的话,只有在
i
i
i为正确分类时
P
(
i
)
P(i)
P(i)才等于
1
1
1,否则,
P
(
i
)
P(i)
P(i)等于0。
例如在手写数字识别中,数字3的one-hot表示为
[
0
0
0
1
0
0
0
0
0
0
]
\begin{bmatrix} 0 & 0 & 0 & 1 & 0& 0& 0& 0& 0& 0 \end{bmatrix}
[0001000000]
对于交叉熵来说,
H
(
p
,
q
)
=
−
∑
i
P
(
i
)
l
o
g
Q
(
i
)
=
−
P
(
3
)
l
o
g
Q
(
3
)
=
−
l
o
g
Q
(
3
)
H(p,q)=- \sum\limits_iP(i)logQ(i)=-P(3)logQ(3)=-logQ(3)
H(p,q)=−i∑P(i)logQ(i)=−P(3)logQ(3)=−logQ(3)发现
H
(
p
,
q
)
H(p,q)
H(p,q)的计算不依赖于
P
P
P矩阵,而仅仅与
P
P
P的真实类别的
i
n
d
e
x
index
index有关
五、总结
所以,我的理解是,在one-hot编码的前提下,在pytorch代码中target不需要以one-hot形式表示,而是直接用scalar,scalar的值则是真实类别的index。所以交叉熵的公式可表示为:
H
(
p
,
q
)
=
−
∑
i
P
(
i
)
l
o
g
Q
(
i
)
=
−
P
(
m
)
l
o
g
Q
(
m
)
=
−
l
o
g
Q
(
m
)
H(p,q)=- \sum\limits_iP(i)logQ(i)=-P(m)logQ(m)=-logQ(m)
H(p,q)=−i∑P(i)logQ(i)=−P(m)logQ(m)=−logQ(m)
其中,
m
m
m表示真实类别。