前言
CrossEntropyLoss 是分类任务中经常使用的损失函数,但是在某些情况下,其优化效果并不是很好,本文介绍了最近出现的对CrossEntropyLoss进行改进的新损失函数
一、CrossEntropyLoss
公式:
上图是pytorch版实现的CrossEntropyLoss,可以看出其主要作用是优化了正例对应的logits(logits介绍见上一篇博文)并使其无限大与其他类别的logits,这种过强的要求可能使得模型难以训练至收敛,因而出现了LabelSM版本的CrossEntropy,以及Sparse Softmax
顺带提一句,pytorch版本的CrossEntropyLoss是对dim=1进行的计算,
因而我们需要把各个类别的logits放到dim=1上来
二、SmoothCrossEntropy
公式:
SmoothCrossEntropy对应的公式为:
优势:
当 label smoothing 的 loss 函数为 cross entropy 时,如果 loss 取得极值点,则正确类和错误类的 logit 会保持一个常数距离,且正确类和所有错误类的 logits 相差的常数是一样的,都是
log
(
K
−
(
K
−
1
)
α
α
)
\log(\frac{K-(K-1)\alpha}{\alpha})
log(αK−(K−1)α)
证明见:知乎
code:
class SmoothCrossEntropy(nn.Module):
"""
loss = SmoothCrossEntropy()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)
"""
def __init__(self, alpha=0.1):
super(SmoothCrossEntropy, self).__init__()
self.alpha = alpha
def forward(self, logits, labels):
num_classes = logits.shape[-1]
alpha_div_k = self.alpha / num_classes
target_probs = F.one_hot(labels, num_classes=num_classes).float() * \
(1. - self.alpha) + alpha_div_k
loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
return loss.mean()
代码如下(示例):
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
warnings.filterwarnings('ignore')
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
三、Sparse Softmax
公式:
优势:
这是苏神在CAIL2020中提出的一个类别数过多的预测问题损失函数,我们只需要优化前topK项,使得
s
t
s_t
st 大于topk即可,不必要大于最小的
log
(
n
−
1
)
\log(n-1)
log(n−1)
,只需大于topk中最小的
l
o
g
(
k
)
log(k)
log(k)即可,可以防止过度训练
证明
pytoch版本:
def Sparse_Softmax(predictions, token_type_id, input_ids, vocab_size):
predictions = predictions[:, :-1].contiguous()
target_mask = token_type_id[:, 1:].contiguous()
"""
target_mask : 句子a部分和pad部分全为0, 而句子b部分为1
"""
predictions = predictions.view(-1, vocab_size)
labels = input_ids[:, 1:].contiguous()
labels = labels.view(-1)
target_mask = target_mask.view(-1).float()
# 正loss
pos_loss = predictions[list(range(predictions.shape[0])), labels]
# 负loss
y_pred = torch.topk(predictions, k=args.k_sparse)[0]
neg_loss = torch.logsumexp(y_pred, dim=-1)
loss = neg_loss - pos_loss
return (loss * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响
L-Softmax、SM-Softmax、AM-Softmax待补充