pytorch=1.4标签平滑技术
标签平滑的作用是小幅度改变原有标签的值域,如 [0, 0, 1] -> [0.1, 0.1, 0.8],它适用于人工标注数据可能非完全正确的情况,可以使用标签平滑来弥补这种偏差,减小模型对某一条规律的绝对认知,防止过拟合。
在pytorch=1.4中,当使用原生的交叉熵损失(CrossEntropyLoss())时,要求标签值必须为整型。当使用标签平滑技术时候可以重新定义一个类似计算规则的标签平滑交叉熵损失函数。
## 自定义标签平滑后的交叉熵损失
class LabelSmoothingCELoss(nn.Module):
    """Cross-entropy loss with label smoothing.

    Replaces the hard one-hot target (e.g. [0, 0, 1]) with a softened
    distribution (e.g. [0.1, 0.1, 0.8]).  Useful when human-annotated
    labels may be imperfect: it weakens the model's absolute confidence
    in any single class and helps prevent overfitting.
    """

    def __init__(self, classes, smoothing=0.0, dim=-1):
        # classes:   number of classes; the smoothing mass is shared
        #            among the (classes - 1) non-target classes.
        # smoothing: fraction of probability mass moved off the true class.
        # dim:       dimension over which log_softmax / sum are taken.
        super(LabelSmoothingCELoss, self).__init__()
        self.confidence = 1.0 - smoothing  # probability kept on the true class
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Compute the smoothed cross-entropy.

        pred:   raw logits, shape (batch, num_classes) — assumed 2-D,
                since scatter_ below indexes along dim 1.
        target: integer class ids, shape (batch,).
        Returns a scalar tensor (mean over the batch).
        """
        pred = pred.log_softmax(dim=self.dim)
        # The smoothed target distribution is a constant w.r.t. the model,
        # so build it without tracking gradients.
        with torch.no_grad():
            # Every class first receives an equal share of `smoothing`...
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            # ...then scatter_ overwrites, for each row, the column given by
            # that row's target index with `confidence` (dim=1 → per-row column
            # selection; e.g. targets [[1], [2]] write column 1 of row 0 and
            # column 2 of row 1).
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        # Standard cross-entropy between the smoothed distribution and the
        # predicted log-probabilities, averaged over the batch.
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
if __name__ == "__main__":
    # 5 equal logits -> log_softmax is -log(5) everywhere; with
    # smoothing=0.03 the expected loss is 1.03 * log(5) ≈ 1.6577.
    predict = torch.FloatTensor([[1, 1, 1, 1, 1]])
    target = torch.LongTensor([2])
    # Fix: the class defined above is LabelSmoothingCELoss, not
    # LabelSmoothingLoss — the original name raised NameError.
    LSL = LabelSmoothingCELoss(3, 0.03)
    print(LSL(predict, target))
    # tensor(1.6577)
#修改huggingface transformer中的源码:
#路径:/usr/local/lib/python3.7/site-packages/transformers/modeling_bert.py
1036 pooled_output = self.dropout(pooled_output)
1037 logits = self.classifier(pooled_output)
1038
1039 outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
1040
1041 if labels is not None:
1042 if self.num_labels == 1:
1043 # We are doing regression
1044 loss_fct = MSELoss()
1045 loss = loss_fct(logits.view(-1), labels.view(-1))
1046 else:
# 注释掉之前的交叉熵损失函数
1047 # loss_fct = CrossEntropyLoss()
# 导入之前的LabelSmoothingCELoss,填入类别参数和平滑系数
1048 from smoothing import LabelSmoothingCELoss
1048 loss_fct = LabelSmoothingCELoss(self.num_labels, smoothing=0.1)  # 类别数应与下一行的 self.num_labels 保持一致,不要写死为 3
1050 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1051 outputs = (loss,) + outputs
1052
1053 return outputs # (loss), logits, (hidden_states), (attentions)