Label Smoothing Regularization理论和代码分析

文章目录

理论

优化策略5 Label Smoothing Regularization_LSR原理分析

代码

我选择的transformer的代码

class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        
    def forward(self, x, target):
        assert x.size(1) == self.size
        
        true_dist = x.data.clone() # 复制x
        
        true_dist.fill_(self.smoothing / (self.size - 2)) # e*u(k)填充
        
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) # scatter_(input, dim, index, src) 按行dim=1赋值,index以target为准 赋的值(1-smooth)*q(y|x)
        
        true_dist[:, self.padding_idx] = 0
        
        mask = torch.nonzero(target.data == self.padding_idx) # [[2]]获取target数据等于padding_的索引
        
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0) # index_fill_(dim,index,val)在dim维度填充index为2值为0
            
        self.true_dist = true_dist
        return self.criterion(x, Variable(true_dist, requires_grad=False)) # KL 散度

查看变换后的真值分布

#Example of label smoothing. 可视化真值分布
crit = LabelSmoothing(5, 0, 0.4)
predict = torch.FloatTensor([[0, 0.2, 0.7, 0.1, 0],
                             [0, 0.2, 0.7, 0.1, 0], 
                             [0, 0.2, 0.7, 0.1, 0]])
v = crit(Variable(predict.log()), 
         Variable(torch.LongTensor([2, 1, 0])))

#Show the target distributions expected by the system. 真值分布
plt.imshow(crit.true_dist)

在这里插入图片描述
查看对loss的影响,可以看到,随着confidence增大,这里是有一个惩罚,即loss增大。

#x增大在一定程度上 loss增大
crit = LabelSmoothing(5, 0, 0.2)
def loss(x):
    d = x + 3 * 1
    predict = torch.FloatTensor([[0, x / d, 1 / d, 1 / d, 1 / d],
                                 ])
    #print(predict)
    return crit(Variable(predict.log()),
                 Variable(torch.LongTensor([1]))).item()
plt.plot(np.arange(1, 100), [loss(x) for x in range(1, 100)])

在这里插入图片描述

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值