pytorch | labelSmooth
labelSmooth
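labelSmooth, also known as label smoothing, improves a model's generalization and can raise accuracy on classification tasks, including tasks in unseen domains. It mainly addresses label noise: if noisy labels take part in training, the model tends to overfit to them.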
$$
\begin{array}{l}
\text{Loss} = -\sum_{i=1}^{K} p_{i} \log q_{i} \\
p_{i} = \begin{cases} 1, & \text{if } i = y \\ 0, & \text{if } i \neq y \end{cases}
\end{array}
$$
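The expression above is the cross-entropy loss, where $p_i$ is the one-hot target and $q_i$ is the predicted probability of class $i$. To fit the training data as closely as possible, the optimal prediction, written in terms of the network outputs $Z_i$, is: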
$$
Z_{i} = \begin{cases} +\infty, & \text{if } i = y \\ 0, & \text{if } i \neq y \end{cases}
$$
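To see why the optimum sits at infinity, here is a minimal sketch. The three-class setup and the helper name `ce_for_target_logit` are illustrative assumptions, not part of the original post. It keeps the wrong-class outputs at 0 and raises only the true-class output; the cross-entropy keeps shrinking but never reaches zero.

```python
import torch
import torch.nn.functional as F

# Toy setup (assumed for illustration): 3 classes, true class y = 0.
target = torch.tensor([0])

def ce_for_target_logit(z):
    # Outputs are [z, 0, 0]: only the true-class output varies.
    logits = torch.tensor([[z, 0.0, 0.0]], dtype=torch.float64)
    return F.cross_entropy(logits, target).item()

# The loss falls monotonically as z grows, so no finite z is optimal.
losses = [ce_for_target_logit(z) for z in (0.0, 2.0, 5.0, 10.0)]
print(losses)
```

With a one-hot target, training therefore keeps pushing the true-class output upward, which is the overconfidence that label smoothing is meant to curb.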
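With label smoothing, the updated target distribution is equivalent to adding noise to the true distribution. To keep the computation simple, the noise follows a uniform distribution: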
$$
\text{Loss} = -\sum_{i=1}^{K} p_{i} \log q_{i} \;\longmapsto\; \text{Loss}_{i} = \begin{cases} (1-\varepsilon) \cdot \text{Loss}, & \text{if } i = y \\ \varepsilon \cdot \text{Loss}, & \text{if } i \neq y \end{cases}
$$
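As a rough illustration of "adding uniform noise to the true distribution", the sketch below mixes a one-hot target with a uniform distribution over all $K$ classes. Spreading $\varepsilon$ over all $K$ classes is an assumption here, chosen to match the PyTorch implementation later in this post. The example logits and variable names are made up for the example. It also checks that the soft-target loss splits into a $(1-\varepsilon)$-weighted ordinary cross-entropy term plus an $\varepsilon$-weighted uniform term.

```python
import torch
import torch.nn.functional as F

K, eps = 4, 0.1  # illustrative values
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]], dtype=torch.float64)
target = torch.tensor([0])

# Smoothed target = (1 - eps) * one-hot + eps * uniform noise.
one_hot = F.one_hot(target, num_classes=K).to(logits.dtype)
uniform = torch.full_like(one_hot, 1.0 / K)
smoothed = (1.0 - eps) * one_hot + eps * uniform

# Cross-entropy against the smoothed (soft) target.
log_q = F.log_softmax(logits, dim=-1)
loss_soft = -(smoothed * log_q).sum(dim=-1)

# Same value, written as two weighted terms.
hard_term = F.nll_loss(log_q, target, reduction='none')  # usual cross-entropy
uniform_term = -log_q.mean(dim=-1)                       # loss against uniform noise
loss_split = (1.0 - eps) * hard_term + eps * uniform_term

print(loss_soft, loss_split)
```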
$$
Z_{i} = \begin{cases} +\infty, & \text{if } i = y \\ 0, & \text{if } i \neq y \end{cases}
\quad \longrightarrow \quad
Z_{i} = \begin{cases} \log \dfrac{(K-1)(1-\varepsilon)}{\varepsilon} + \alpha, & \text{if } i = y \\ \alpha, & \text{if } i \neq y \end{cases}
$$
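Here $\alpha$ can be any real number. By limiting the gap between the outputs for the correct class and the incorrect classes, label smoothing ultimately helps the network generalize better.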
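The finite optimum can be checked numerically. The sketch below assumes the smoothed target puts $1-\varepsilon$ on the true class and $\varepsilon/(K-1)$ on each other class, which is the variant the $(K-1)$ factor implies. The values of `K`, `eps`, `alpha`, and `y` are arbitrary choices for the example. It builds the outputs with a gap of $\log\frac{(K-1)(1-\varepsilon)}{\varepsilon}$ and confirms that their softmax reproduces the smoothed target exactly, which is the condition under which cross-entropy is minimized.

```python
import math
import torch

K, eps, alpha = 5, 0.1, 0.7  # illustrative values; alpha is arbitrary
y = 2                        # index of the true class

# Smoothed target: 1 - eps on the true class, eps / (K - 1) elsewhere.
p = torch.full((K,), eps / (K - 1), dtype=torch.float64)
p[y] = 1.0 - eps

# Outputs: alpha for wrong classes, alpha plus a finite gap for the true class.
z = torch.full((K,), alpha, dtype=torch.float64)
z[y] = math.log((K - 1) * (1 - eps) / eps) + alpha

# Softmax of these outputs matches the smoothed target.
q = torch.softmax(z, dim=0)
print(q)
print(p)
```

Changing `alpha` shifts every output by the same amount and leaves `q` unchanged, which is why it can be any real number.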
PyTorch implementation
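import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean'):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps              # smoothing factor
        self.reduction = reduction  # 'mean', 'sum' or 'none'

    def forward(self, output, target):
        c = output.size()[-1]  # number of classes
        log_preds = F.log_softmax(output, dim=-1)
        # Uniform-noise term: negative log-probabilities summed over all classes.
        if self.reduction == 'sum':
            loss = -log_preds.sum()
        else:
            loss = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                loss = loss.mean()
        # Weight eps / c on the uniform term and 1 - eps on the usual NLL term.
        return loss * self.eps / c + (1 - self.eps) * F.nll_loss(
            log_preds, target, reduction=self.reduction)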
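As a sanity check, the same arithmetic can be compared with PyTorch's built-in option. The sketch below assumes PyTorch 1.10 or later, where `F.cross_entropy` accepts a `label_smoothing` argument. The function `smoothed_ce` is a hypothetical stand-alone restatement of the class's `forward` for `reduction='mean'`, written out so the snippet runs on its own; the random inputs are made up.

```python
import torch
import torch.nn.functional as F

def smoothed_ce(output, target, eps=0.1):
    # Same arithmetic as LabelSmoothingCrossEntropy.forward with reduction='mean'.
    c = output.size(-1)
    log_preds = F.log_softmax(output, dim=-1)
    uniform_term = -log_preds.sum(dim=-1).mean()
    return uniform_term * eps / c + (1 - eps) * F.nll_loss(log_preds, target)

torch.manual_seed(0)
output = torch.randn(8, 5)          # batch of 8 samples, 5 classes
target = torch.randint(0, 5, (8,))  # integer class labels

ours = smoothed_ce(output, target, eps=0.1)
# Built-in label smoothing (assumes PyTorch >= 1.10).
builtin = F.cross_entropy(output, target, label_smoothing=0.1)
print(ours.item(), builtin.item())
```

The two values should agree up to floating-point error, since both spread $\varepsilon$ uniformly over all classes. On recent PyTorch versions the built-in argument may be the simpler choice.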