在网上找了几个帖子发现解释都比较冗余,于是自己总结了一下。
利用网络的output和真实标签target计算softmax+交叉熵损失的过程为:
(1)sm=softmax(output); 计算softmax
(2)logsm=log(sm); 取对数
(3)loss=-logsm.*target 取负后与onehot标签对应相乘完成交叉熵计算
NLLLoss相当于只执行(3),因此需要手动完成(1)和(2),其输入应为logsm;
CrossEntropyLoss相当于同时完成三步。