pytorch=1.4标签平滑技术

pytorch=1.4标签平滑技术

标签平滑的作用是小幅度改变原有标签的值域,如[0,0,1]–>[0.1,0.1,0.8],它适用于人工的标注数据可能非完全正确的情况,可以使用标签平滑来弥补这种偏差,减小模型对某一条规律的绝对认知,防止过拟合。

在pytorch=1.4中,当使用原生的交叉熵损失(CrossEntropyLoss())时,要求标签值必须为整型。当使用标签平滑技术时候可以重新定义一个类似计算规则的标签平滑交叉熵损失函数。

## 自定义标签平滑后的交叉熵损失
class LabelSmoothingCELoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
    # smoothing  标签平滑的百分比
        super(LabelSmoothingCELoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # .scatter_也是一种数据填充方法,目的仍然是将self.confidence填充到true_dist中
            # 第一个参数0/1代表填充的轴,大多数情况下使用scatter_都使用纵轴(1)填充
            # 第二个参数就是self.confidence的填充规则,即填充到第几列里面,如[[1], [2]]代表填充到第二列和第三列里面
            # 第三个参数就是填充的数值,int/float
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
if __name__ == "__main__":
    predict = torch.FloatTensor([[1, 1, 1, 1, 1]])
    target = torch.LongTensor([2])
    LSL = LabelSmoothingLoss(3, 0.03)
    print(LSL(predict, target))
# tensor(1.6577)
#修改huggingface transformer中的源码:
#路径:/usr/local/lib/python3.7/site-packages/transformers/modeling_bert.py
1036         pooled_output = self.dropout(pooled_output)
1037         logits = self.classifier(pooled_output)
1038
1039         outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
1040
1041         if labels is not None:
1042             if self.num_labels == 1:
1043                 #  We are doing regression
1044                 loss_fct = MSELoss()
1045                 loss = loss_fct(logits.view(-1), labels.view(-1))
1046             else:
                     # 注释掉之前的交叉熵损失函数
1047                 # loss_fct = CrossEntropyLoss()
                     # 导入之前的LabelSmoothingCELoss,填入类别参数和平滑系数
1048                 from smoothing import LabelSmoothingCELoss
1049                 loss_fct = LabelSmoothingCELoss(3, smoothing=0.1)
1050                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1051             outputs = (loss,) + outputs
1052
1053         return outputs  # (loss), logits, (hidden_states), (attentions)




  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值