CrossEntropyLoss和手写交叉熵的区别
总结
torch.nn.CrossEntropyLoss(weight, reduction=“mean”)的计算方式和一般做法不一样,进行加权计算后,不是直接loss除以batch数量,而是batch中每个数据的标签对应的类别权重 把batch个权重加起来作为除数,被除数是loss
交叉熵损失函数原理
这里就不做介绍了,很多博客和视频已经讲的很清楚了。这里引用一下@b站up同济子豪兄之前的一次动态截图
不带权重的交叉熵计算
这里已经也有很多介绍了,不带权重的情况下,torch.nn.CrossEntropyLoss(reduction=“mean”)的计算和公式里一样。
另外,虽然pytorch里面CrossEntropyLoss的target输入要求是标签或者Probabilities,但是target从标签(batch_size, 1(num_classes))转换成独热编码(torch.nn.functional.one_hot),输出结果是一样的。我感觉是转换成独热编码就相当于target是那个Probabilities了,具体计算过程没细看,结果和手搓的代码一样。
import torch.nn as nn
import torch.nn.functional as F
weight=torch.tensor([1,1,1,1]).float() # 这里权重都是1,所以结果一样
loss_func_mean = nn.CrossEntropyLoss(weight, reduction="mean")
def ce_loss(y_pred, y_true, weight):
# 计算 log_softmax,这是更稳定的方式
log_probs = F.log_softmax(y_pred, dim=-1)
# 计算加权损失
loss = -(weight * y_true * log_probs).sum(dim=-1).mean()
return loss
pre_data = torch.tensor([[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index_data = torch.tensor([0,
1,
2,
3], dtype=torch.long)
tgt_onehot_data = torch.tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]], dtype=torch.float)
print(loss_func_mean(pre_data, tgt_index_data))
print(loss_func_mean(pre_data, tgt_onehot_data))
print(ce_loss(pre_data, tgt_onehot_data, weight))
print(F.cross_entropy(pre_data, tgt_onehot_data, weight))
输出结果为
tensor(1.0315)
tensor(1.0315)
tensor(1.0315)
tensor(1.0315)
带权重的交叉熵
而如果将weight的权重改变,比如
weight=torch.tensor([1,2,3,4]).float()
# 那么结果就会变为
tensor(1.0226)
tensor(2.5566)
tensor(2.5566)
tensor(2.5566)
找了很久的知乎、博客等讲解,都没有对比torch.nn.CrossEntropyLoss()加上权重后和公式计算出来的结果是否一致。去看源代码,发现计算公式在pycharm中显示的比较抽象,后面又才看了官方文档介绍,才弄懂。现在放上官方计算公式介绍图(官方文档链接: CrossEntropyLoss.)
注意看,这里的reduction="mean"的计算公式和平常我们认为的不一样,分母
∑
n
=
1
N
w
y
n
×
1
{
y
n
≠
i
g
n
o
r
e
_
i
n
d
e
x
}
\sum_{n=1}^N w_{y_n} \times 1\{y_n \neq \mathrm{ignore\_index}\}
∑n=1Nwyn×1{yn=ignore_index}中计算的是n从1到N,本来还以为是写错了,算了之后发现确实是N(N=batch)。
因此将手打公式改正之后结果就和reduction="mean"一样了。当然用"sum"的时候就是一样的,懒得放结果。
import torch.nn as nn
import torch.nn.functional as F
weight=torch.tensor([1,2,3,4]).float()
loss_func_mean = nn.CrossEntropyLoss(weight, reduction="mean")
def ce_loss(y_pred, y_true, weight):
# 计算 log_softmax,这是更稳定的方式
log_probs = F.log_softmax(y_pred, dim=-1)
# 计算加权损失
loss = -(weight * y_true * log_probs).sum(dim=-1)
fenmu = [weight[torch.argmax(y_true, dim=-1)[i]] for i in range(y_true.shape[0])]
fenmu_sum = torch.tensor(fenmu).sum()
mean_loss = loss.sum() / fenmu_sum
return mean_loss
pre_data = torch.tensor([[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]], dtype=torch.float)
tgt_index_data = torch.tensor([0,
1,
2,
3], dtype=torch.long)
tgt_onehot_data = torch.tensor([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]], dtype=torch.float)
print(loss_func_mean(pre_data, tgt_index_data))
print(loss_func_mean(pre_data, tgt_onehot_data))
print(ce_loss(pre_data, tgt_onehot_data, weight))
print(F.cross_entropy(pre_data, tgt_onehot_data, weight))
# 结果为:
tensor(1.0226)
tensor(2.5566)
tensor(1.0226)
tensor(2.5566)
当然其实感觉用哪个都问题不大,只是说pytorch这种计算方式的话,如果四类数量是[1, 2, 3, 4],那么weight=torch.tensor([1., 1/2., 1/3., 1/4.])和weight=torch.tensor([4., 2., 4/3., 1.])结果都一样。
但是对网络训练来说,不知道哪种会好一点,直接都是用的手打的代码训练,自己都试试看吧