对PyTorch中F.cross_entropy()函数的理解

对PyTorch中F.cross_entropy()的理解

PyTorch提供了求交叉熵的两个常用函数,一个是F.cross_entropy(),另一个是F.nll_entropy(),在学这两个函数的使用的时候有一些问题,尤其是对F.cross_entropy(input, target)中参数target的理解很困难,现在好像弄懂了一些,故写一篇Blog进行记录,方便日后查阅。

一、交叉熵的公式及计算步骤

1、交叉熵的公式:

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)} H(p,q)=iP(i)logQ(i)

其中 P P P为真实值, Q Q Q为预测值

2、计算交叉熵的步骤:

1)步骤说明:

①将predict_scores进行softmax运算,将运算结果记为pred_scores_soft;
②将pred_scores_soft进行log运算,将运算结果记为pred_scores_soft_log;
③将pred_scores_soft_log与真实值进行计算处理。
思路即:
s c o r e s → s o f t m a x → l o g → c o m p u t e scores\to softmax\to log\to compute scoressoftmaxlogcompute

2)举一个例子对计算进行说明:

P 1 = [ 1 0 0 0 0 ] P_1=\begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ \end{bmatrix} P1=[10000]

Q 1 = [ 0.4 0.3 0.05 0.05 0.2 ] Q_1=\begin{bmatrix} 0.4 & 0.3 & 0.05 & 0.05 & 0.2 \\ \end{bmatrix} Q1=[0.40.30.050.050.2]

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) = − ( 1 ∗ l o g 0.4 + 0 ∗ l o g 0.3 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0.2 ) = − l o g 0.4 ≈ 0.916 H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)}=-(1*log0.4+0*log0.3+0*log0.05+0*log0.05+0*log0.2) \\=-log0.4 \approx 0.916 H(p,q)=iP(i)logQ(i)=(1log0.4+0log0.3+0log0.05+0log0.05+0log0.2)=log0.40.916
如果
Q 2 = [ 0.98 0.01 0 0 0.01 ] Q_2=\begin{bmatrix} 0.98 & 0.01 & 0 & 0 & 0.01 \\ \end{bmatrix} Q2=[0.980.01000.01]

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) = − ( 1 ∗ l o g 0.98 + 0 ∗ l o g 0.01 + 0 ∗ l o g 0.05 + 0 ∗ l o g 0 + 0 ∗ l o g 0.01 ) = − l o g 0.98 ≈ 0.02 H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)}=-(1*log0.98+0*log0.01+0*log0.05+0*log0+0*log0.01) \\=-log0.98 \approx 0.02 H(p,q)=iP(i)logQ(i)=(1log0.98+0log0.01+0log0.05+0log0+0log0.01)=log0.980.02

H ( p , q ) H(p,q) H(p,q)的计算结果和直观地观察 Q 1 Q_1 Q1 Q 2 Q_2 Q2 P 1 P_1 P1的相似度,均可看出 Q 2 Q_2 Q2 Q 1 Q_1 Q1更近似于 P 1 P_1 P1

二、官方文档的说明

在PyTorch的官方中文文档中F.cross_entropy()的记录如下:

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=True)

该函数使用了 log_softmaxnll_loss,详细请看CrossEntropyLoss

常用参数:

参数名shape
input(N,C)C是类别的个数
targetN0 <= targets[i] <= C-1

三、自己的困惑

在官方文档说明中,对于target参数的说明为,torch.shapetorch.Size([N]),0 <= targets[i] <= C-1。
我的困惑是:网络计算输出并送入函数中的inputtorch.shapetorch.Size([N,C]),它的torch.shape并不会因为softmaxlog的操作而改变,但是targettorch.shapetorch.Size([N]),是一个标量而不是一个矩阵,那么如何按照上面的例子中的运算方法进行交叉熵的计算?

例如:

import torch
import torch.nn.functional as F

pred_score = torch.tensor([[13., 3., 2., 5., 1.],
                           [1., 8., 20., 2., 3.],
                           [1., 14., 3., 5., 3.]])
print(pred_score)
pred_score_soft = F.softmax(pred_score, dim=1)
print(pred_score_soft)
pred_score_soft_log = pred_score_soft.log()
print(pred_score_soft_log)

它的结果为:

tensor([[13.,  3.,  2.,  5.,  1.],
        [ 1.,  8., 20.,  2.,  3.],
        [ 1., 14.,  3.,  5.,  3.]])
tensor([[9.9960e-01, 4.5382e-05, 1.6695e-05, 3.3533e-04, 6.1417e-06],
        [5.6028e-09, 6.1442e-06, 9.9999e-01, 1.5230e-08, 4.1399e-08],
        [2.2600e-06, 9.9984e-01, 1.6699e-05, 1.2339e-04, 1.6699e-05]])
tensor([[-4.0366e-04, -1.0000e+01, -1.1000e+01, -8.0004e+00, -1.2000e+01],
        [-1.9000e+01, -1.2000e+01, -6.1989e-06, -1.8000e+01, -1.7000e+01],
        [-1.3000e+01, -1.5904e-04, -1.1000e+01, -9.0002e+00, -1.1000e+01]])

如何与一个标量target进行计算?

四、分析

F.Cross_entropy(input, target)函数中包含了 s o f t m a x softmax softmax l o g log log的操作,即网络计算送入的input参数不需要进行这两个操作。

例如在分类问题中,input表示为一个torch.Size([N, C])的矩阵,其中, N N N为样本的个数, C C C是类别的个数,input[i][j]可以理解为第 i i i个样本的类别为 j j j的Scores,Scores值越大,类别为 j j j的可能性越高,就像在代码块中所体现的那样。

同时,一般我们将分类问题的结果作为lable表示时使用one-hot embedding,例如在手写数字识别的分类问题中,数字0的表示为 [ 1 0 0 0 0 0 0 0 0 0 ] \begin{bmatrix} 1 & 0 & 0 & 0 & 0& 0& 0& 0& 0& 0 \end{bmatrix} [1000000000]数字3的表示为 [ 0 0 0 1 0 0 0 0 0 0 ] \begin{bmatrix} 0 & 0 & 0 & 1 & 0& 0& 0& 0& 0& 0 \end{bmatrix} [0001000000]在手写数字识别的问题中,我们计算 l o s s loss loss的方法为 l o s s = ( y − y ^ ) 2 loss=(y-\hat y)^2 loss=(yy^)2,即求 y y y的embedding的矩阵减去pred_probability矩阵的结果矩阵的范数。

但是在这里,交叉熵的计算公式为

H ( p , q ) = − ∑ i P ( i ) log ⁡ Q ( i ) H\left( {p,q} \right) = - \sum\limits_i {P\left( i \right)\log Q\left( i \right)} H(p,q)=iP(i)logQ(i)

其中 P P P为真实值概率矩阵, Q Q Q为预测值概率矩阵

那么如果 P P P使用one-hot embedding的话,只有在 i i i为正确分类时 P ( i ) P(i) P(i)才等于 1 1 1,否则, P ( i ) P(i) P(i)等于0
例如在手写数字识别中,数字3的one-hot表示为 [ 0 0 0 1 0 0 0 0 0 0 ] \begin{bmatrix} 0 & 0 & 0 & 1 & 0& 0& 0& 0& 0& 0 \end{bmatrix} [0001000000]
对于交叉熵来说, H ( p , q ) = − ∑ i P ( i ) l o g Q ( i ) = − P ( 3 ) l o g Q ( 3 ) = − l o g Q ( 3 ) H(p,q)=- \sum\limits_iP(i)logQ(i)=-P(3)logQ(3)=-logQ(3) H(p,q)=iP(i)logQ(i)=P(3)logQ(3)=logQ(3)发现 H ( p , q ) H(p,q) H(p,q)的计算不依赖于 P P P矩阵,而仅仅与 P P P的真实类别的 i n d e x index index有关

五、总结

所以,我的理解是,在one-hot编码的前提下,在pytorch代码中target不需要以one-hot形式表示,而是直接用scalar,scalar的值则是真实类别的index。所以交叉熵的公式可表示为:
H ( p , q ) = − ∑ i P ( i ) l o g Q ( i ) = − P ( m ) l o g Q ( m ) = − l o g Q ( m ) H(p,q)=- \sum\limits_iP(i)logQ(i)=-P(m)logQ(m)=-logQ(m) H(p,q)=iP(i)logQ(i)=P(m)logQ(m)=logQ(m)
其中, m m m表示真实类别。

F.cross_entropy和F.binary_cross_entropyPyTorch的两个常用的损失函数,用于分类和二分类任务。 F.cross_entropy的输入包括两个参数:input和target。其,input是模型的输出,target是真实标签。input的形状为(N, C),N表示样本数量,C表示类别数量。target的形状为(N,),每个元素表示对应样本的真实类别索引。 F.binary_cross_entropy的输入也包括两个参数:input和target。其,input是模型的输出,target是真实标签。input的形状为(N, ),N表示样本数量,每个元素表示对应样本的预测概率或得分。target的形状为(N, ),每个元素表示对应样本的真实标签(0或1)。 要将F.cross_entropy的输入转化为F.binary_cross_entropy的输入,可以按照以下步骤进行: 1. 对于input,使用softmax函数将其转化为概率分布。可以使用torch.softmax(input, dim=1)。 2. 对于target,如果原来的target是类别索引,则需要将其转化为二分类标签。可以使用torch.eye(C)[target],其C表示类别数量。 具体代码如下: ``` import torch import torch.nn.functional as F # 假设input和target分别为F.cross_entropy的输入 input = torch.randn(10, 5) target = torch.tensor([2, 0, 1, 4, 3, 1, 2, 0, 3, 4]) # 将input转化为概率分布 input_prob = F.softmax(input, dim=1) # 将target转化为二分类标签 target_binary = torch.eye(5)[target] # 使用F.binary_cross_entropy计算损失 loss = F.binary_cross_entropy(input_prob, target_binary) ```
评论 21
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值