本文主要讲述CrossEntropyLoss函数内部的计算逻辑,我觉得摆一大堆公式都不如直接上代码来的实在。
首先,我们先设置好两个变量,input与y_target。
importtorchimporttorch.nnasnnimportnumpyasnpa=np.arange(1,13).reshape(3,4)b=torch.from_numpy(a)input=b.float()print('input:\n',input)
可以看到input矩阵如下:
tensor([[1.,2.,3.,4.],[5.,6.,7.,8.],[9.,10.,11.,12.]])
然后设置y_target:
y_target=torch.tensor([1,2,3])print('y_target:\n',y_target)
这个不用打印大家也应该知道是什么样了。
input是一个【3 4】的矩阵,y-target是一个【1 3】的矩阵。input是预测值,代表有三个样本,四个类别。y-target代表三个样本的真实标签。
crossentropyloss=nn.CrossEntropyLoss(reduction='none')crossentropyloss_output=crossentropyloss(x_input,y_target)print('crossentropyloss_output:\n',crossentropyloss_output)
经过CrossEntropyLoss后,最终结果为:
crossentropyloss_output:tensor([2.4402,1.4402,0.4402])
下面我们来剖析它的计算过程。其实CrossEntropyLoss相当于softmax + log + nllloss。
不信可以计算一遍:
softmax_func=nn.Softmax(dim=1)soft_output=softmax_func(input)print('soft_output:\n',soft_output)log_output=torch.log(soft_output)print('log_output:\n',log_output)nllloss_func=nn.NLLLoss(reduction='none')nllloss_output=nllloss_func(log_output,y_target)print('nllloss_output:\n',nlloss_output)
最终结果是一样的:
nllloss_output:tensor([2.4402,1.4402,0.4402])
softmax、log这两个函数应该都比较好理解。
下面主要讲解一下nllloss这个损失函数。
有了经过softmax与log的矩阵,我们叫它矩阵A。
tensor([[-3.4402,-2.4402,-1.4402,-0.4402],[-3.4402,-2.4402,-1.4402,-0.4402],[-3.4402,-2.4402,-1.4402,-0.4402]])
还有真实标签:
tensor([1, 2, 3])
y-target中第一个数字,代表第一个样本的真实标签,第一个数字是1,代表第一个样本的真实标签为1。所以取出矩阵A的第一行第二个数字-2.4402,并加上一个负号。之后也是这样,依次取出-1.4402、-0.4402。
所以有了最终结果:
crossentropyloss_output:
tensor([2.4402, 1.4402, 0.4402])
reduction这个参数着重提一下,它一般有none、sum、mean等几个选项,none就是没有别的附加操作,sum就是把这个几个损失加和,mean就是把这几个损失求平均。