transformer输出的logits与target如何衡量误差?
一、交叉熵公式及代码
1. 交叉熵公式
H ( p , q ) = − ∑ i P ( i ) l o g Q ( i ) H(p, q) = -\sum_i P(i)logQ(i) H(p,q)=−∑iP(i)logQ(i)
P: target, Q: prediction
2. 代码实现
import torch.nn.functional as F
# logits.size=[B,L,C] C is dimention
# target.size=[B,L]
# after reshape, logits.size=[B*L, C]
# after reshape, target.size=[B*L]
loss = F.cross_entropy(logits.reshape(-1,logits.size(-1)), target.rashape(-1))
函数的常用参数:
参数名 | shape |
---|---|
input | [N,C], C 是类别 |
target | [N], 0<=target[i] <=C-1,target[i]是真实类别的index |
二、KL散度及代码
1. KL散度公式
D K L ( P ∣ ∣ Q ) = ∑ x ∈ χ P ( x ) l o g ( P ( x ) Q ( x ) ) D_{KL}(P||Q)=\sum_{x\in\chi}P(x)log(\frac{P(x)}{Q(x)}) DKL(P∣∣Q)=∑x∈χP(x)log(Q(x)P(x))
P&Q均为概率分布,P代表真实概率,Q代表预测概率,sum=1
2. 代码实现
import torch.nn.functional as F
x = torch.randn((4, 5))
y = torch.randn((4, 5))
# 负对数概率
# 如果不经过log_softmax会出现计算结果为负数的情况
logp_x = F.log_softmax(x, dim=-1)
p_y = F.softmax(y,dim=-1)
loss = F.kl_div(logp_x, p_y, reduction='batchmean')
# notice: reduction有很多种,具体在实现时查阅
- 第一个传入的参数需要是一个对数概率矩阵
- 第二个传入的参数是概率矩阵
- 设两个概率分布X,Y,若想用Y指导X,则第一个参数X,第二个参数Y
函数的常用参数:
reduction | 含义 |
---|---|
‘none’ | 输出损失与输入(x)形状相同,各点的损失单独计算,不会对结果做 reduction |
‘batchmean’ | 输出损失的形状为[],输出为所有损失的总和除以批量大小,与kl散度的数学定义一致 |
‘sum’ | 输出损失的形状为[],输出为所有损失的总和 |
‘mean’ | 输出损失的形状为[],输出为所有损失的平均值 |
‘default’ | ‘mean’ |
3. 异常情况分析
- kl散度为负数:检查X是否经过log_softmax, Y是否经过softmax
- kl散度结果为一个很大的值,不像loss:检查softmax是否指定了特定的维度。
如图所示,如果不显式指定dim,那么logits将会默认以dim=0的方式softmax,其实是有问题的。下图中只有dim=2时的kl_div正确。