多分类问题,一个样本的交叉熵损失函数为:
L C E = − ∑ c = 1 M y o , c l o g ( p o , c ) L_{CE}=-\sum^M_{c=1} y_{o,c} log(p_{o,c}) LCE=−∑c=1Myo,clog(po,c)
其中:
- M:类别数
- y o , c : y o 是 o n e − h o t 编码的向量,代表这个样本的真实标签。 c 为某位置上分量 y_{o,c}:y_o是one-hot编码的向量,代表这个样本的真实标签。c为某位置上分量 yo,c:yo是one−hot编码的向量,代表这个样本的真实标签。c为某位置上分量
- p o , c : 模型预测样本 o 属于类别 c 的概率。 p_{o,c}:模型预测样本 o 属于类别 c 的概率。 po,c:模型预测样本o属于类别c的概率。
- p o , c : 通常是 s o f t m a x 计算。再 t o r c h 中的 C E 方法自动对输入的 l o g i t s 先算 s o f t m a x 再算 C E p_{o,c}:通常是softmax计算。再torch中的CE方法自动对输入的logits先算softmax再算CE po,c:通常是softmax计算。再torch中的CE方法自动对输入的logits先算softmax再算CE
举例计算
假设的真实标签和模型预测:
- 样本1的真实标签是类别3(one-hot编码向量为[0, 0, 1, 0, 0])。
- 样本2的真实标签是类别1(one-hot编码向量为[1, 0, 0, 0, 0])。
- 样本3的真实标签是类别5(one-hot编码向量为[0, 0, 0, 0, 1])。
- 模型预测的概率分布为:
- 样本1的预测概率:[0.1, 0.2, 0.5, 0.1, 0.1]。
- 样本2的预测概率:[0.3, 0.3, 0.1, 0.2, 0.1]。
- 样本3的预测概率:[0.2, 0.1, 0.1, 0.1, 0.5]。
计算交叉熵损失的步骤:
计算每个样本的损失:
计算对数概率:
计算最终损失:
求和取平均
所以一个batch内的多分类问题中的CE损失公式可以表示为:
再来看看单个样本的InfoNCE:
出自MoCo:
可以理解为,对一个q,计算它和batch(K+1个)中每个k的相似度
q
∗
k
i
q*k_{i}
q∗ki得到一个相似度向量。我们最大化q和自己正样本k+的相似度。相当于把向量
q
∗
k
i
q*k_{i}
q∗ki看作CE的输入,one-hot向量为:正样本上为1,其他位置为0的。也就是对相似度做K+1的CE。使得和正样本的相似度最高。
batch大小为N的InfoNCE: