What is cross-entropy loss actually doing?
Look at the formula:
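This is what PyTorch's nn.CrossEntropyLoss computes with the default reduction='mean':

$$\mathrm{loss}(x, y) = -\frac{1}{N}\sum_{n=1}^{N} \log\frac{\exp(x_{n,\,y_n})}{\sum_{c}\exp(x_{n,\,c})}$$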
x is the prediction (raw scores, no softmax normalization needed), y is the label, and N is the number of samples along the batch dimension. Cross-entropy loss does three things:
1. Apply softmax to the input along the class dimension.
2. Take the log of the softmax output.
3. Compute NLLLoss on the resulting tensor of shape (num_samples, num_classes).
Here NLLLoss simply negates the log values, multiplies them element-wise with the one-hot encoding of the labels, and sums the result to get the total loss; because reduction defaults to 'mean', it then divides by the number of samples. See the code below.
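For a concrete single-sample illustration (the numbers are made up): if the softmax output is [0.7, 0.2, 0.1] and the true class is 0, the one-hot label is [1, 0, 0], so the per-sample loss is

$$-(1\cdot\log 0.7 + 0\cdot\log 0.2 + 0\cdot\log 0.1) = -\log 0.7 \approx 0.357$$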
Code implementation
import torch
import torch.nn as nn

# cross entropy loss = softmax + log + nllloss
# first softmax, then log, then nllloss
# initialize input_ and target
input_ = torch.randn(3, 3)
target = torch.tensor([0, 2, 1])
# mask is the one-hot encoding of target
mask = torch.zeros(3, 3)
mask[0, 0] = 1
mask[1, 2] = 1
mask[2, 1] = 1
# 1.0 softmax over the input
sft_ = nn.Softmax(dim=-1)(input_)
# 2.0 log
log_ = torch.log(sft_)
# 3.0 nllloss
loss = nn.NLLLoss()
print("split loss")
print(loss(log_, target))
# 4.0 crossentropy
print("ce loss")
loss = nn.CrossEntropyLoss()
print(loss(input_, target))
print("manual loss")
neg_log = 0 - log_
print(torch.sum(mask * neg_log) / 3)
# ---------- output --------------
>> loss_function python crossEntropyLoss.py
>> split loss
>> tensor(1.2294)
>> ce loss
>> tensor(1.2294)
>> manual loss
>> tensor(1.2294)
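As an aside, steps 1 and 2 are usually fused in practice: torch.nn.functional.log_softmax is numerically more stable than calling softmax and then log separately, and NLLLoss can equivalently be viewed as picking out the log-probability at each sample's target index, negating it, and averaging. A minimal sketch of both views (fresh random tensors, so the printed value differs from the run above, but the three numbers agree with each other):

import torch
import torch.nn as nn
import torch.nn.functional as F

input_ = torch.randn(3, 3)
target = torch.tensor([0, 2, 1])
# fused softmax + log in one numerically stable op
log_prob = F.log_softmax(input_, dim=-1)
print(nn.NLLLoss()(log_prob, target))          # log_softmax + NLLLoss
print(nn.CrossEntropyLoss()(input_, target))   # built-in cross entropy
# NLLLoss as indexing: pick the log-prob at each target index, negate, average
print((-log_prob[torch.arange(3), target]).mean())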