本文内容:简单回顾交叉熵损失函数,解释pytorch中该函数的参数(reduce, reduction, size_average)含义。
交叉熵函数定义
背景知识:
给定一个batch的预测分数(softmax归一化后)
p
p
r
e
d
∈
R
B
×
C
p_{pred}\in \mathbb{R}^{B\times C}
ppred∈RB×C,与其真值标签(one-hot)
p
g
t
∈
R
B
×
C
p_{gt}\in \mathbb{R}^{B\times C}
pgt∈RB×C. 其中,
B
B
B为batchsize,
C
C
C为类别总数。那么,其中样本
b
b
b的交叉熵损失可以计算为:
l o s s b = p g t ( b ) ∑ c i = 1 C l o g ( p g t ( b , c i ) p p r e d ( b , c i ) ) loss_b=p_{gt}(b)\sum_{ci=1}^{C} log (\frac{p_{gt}(b,ci)} {p_{pred}(b,ci)}) lossb=pgt(b)ci=1∑Clog(ppred(b,ci)pgt(b,ci))
pytorch
中的参数定义
pytorch对应函数:torch.nn.CrossEntropyLoss
, documents 链接: link
CLASS torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
要解释的参数: reduce
,size_average
,reduction
- 在
torch.nn.CrossEntropyLoss
函数中,能够设置的只是不同样本的交叉熵返回形式,例如:沿batch求和 或 逐样本返回等。但每个样本的交叉熵Loss,都是沿着所有类别求和后的结果(即上述公式),这点无法自定义。(这点与KL散度不同,不过很符合loss本身的物理含义) reduce
参数为第一优先级的控制参数,控制返回的loss是 逐个样本的 (返回 B × 1 B\times 1 B×1),还是一整个batch的(返回 1 × 1 1\times 1 1×1)。
具体来说,如果 reduce
为False: 返回每个样本的loss:
{
l
o
s
s
b
}
b
=
1
,
⋯
,
B
B
\{loss_b\}^B_{b=1,\cdots,B}
{lossb}b=1,⋯,BB; 如果reduce
为True
: 将所有样本的loss融合,对一个batch返回一个标量(
1
×
1
1\times 1
1×1), 融合策略由参数size_average
控制。
size_average
为第二优先级的参数,在对一整个batch返回一个Loss的设置下,size_average
用来控制对各个样本Loss的融合方式。
具体来说,size_average
为bool
类型,设置为True
时,表示对各个样本的loss,求平均;设置为False
时,表示对各个样本的Loss,求和。此处,各个样本的loss,即为上式中的
l
o
s
s
b
loss_b
lossb.
注意:size_average
当且仅当在reduce
为True
时,才被考虑。在reduce
为False
时,size_average
被屏蔽。
reduction
为第三优先级的参数。它只有在size_average
和reduce
都没有额外指定时才生效。
reduction
是reduction
和size_average
的融合(功能上是等价的),可以单独用reduction
一个函数,起到这两个参数组合的效果。
reduction
为str
,可选的有mean, sum, none
。
none
: 返回每个样本的loss, 同reduce=False
mean
: 返回所有样本的loss均值,同reduce=True, size_average=True
mean
: 返回所有样本的loss之和,同reduce=True, size_average=False
函数使用注意:torch.nn.CrossEntropyLoss
对输入的预测,默认为未归一化前的logits,不要将softmax的输出送进去。
官方例子:
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()