多分类用的交叉熵损失函数,用这个 loss 前面不需要加 Softmax 层。
这里损害函数的计算,按理说应该也是原始交叉熵公式的形式,但是这里限制了 target 类型为 torch.LongTensr,而且不是多标签意味着标签是 one-hot 编码的形式,即只有一个位置是 1,其他位置都是 0,那么带入交叉熵公式中化简后就成了下面的简化形式。
loss(x,label)=−wlabellogexlabel∑Nj=1exj=wlabel[−xlabel+log∑j=1Nexj]
这里的 x∈ℝN,是没有经过 Softmax 的激活值,N 是 x 的维度大小(或者叫特征维度);label∈[0,C−1]是标量,是对应的标签,可以看到两者维度是不一样的。C 是要分类的个数。w∈ℝC是维度为 C的向量,表示标签的权重,样本少的类别,可以考虑把权重设置大一点。