CLIP的loss计算
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
ground_truth = torch.arange(len(logits_per_image)).long()
ground_truth = ground_truth.cuda(args.local_device_rank, non_blocking=True)
total_loss = (
loss_img(logits_per_image, ground_truth)
+ loss_txt(logits_per_text, ground_truth)
) / 2
logit_scale参数的作用是在计算图像和文本之间的相似度分数时进行缩放,它本质上是一个可学习的温度参数(temperature parameter)。 np.log(1 / 0.07) 约等于2.659为经验值。
图像特征和文本特征矩阵的转置相乘,得到M*M的矩阵,M为batch size,代表每个图像特征向量和所有文本向量的相似度。
文本特征矩阵乘以图像特征的转置矩阵,得到每个文本特征向量和所有图像特征向量的相似度。
ground_truth为正确类别的索引值,即第一行图像特征对应的正确类别为第一个文本特征,第二行图像特征对应第二个文本特征,正好是对角线的位置。
最终loss为这两个loss的均值。
nn.CrossEntropyLoss() 的目标标签有两种格式,类别的索引值或者类别的概率值。
import torch
import torch.nn as nn
logits = torch.rand(10, 10)
targets = torch.arange(10)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print("Cross Entropy Loss:", loss.item())
#等价于
m1 = nn.LogSoftmax(dim=-1)
m2 = nn.NLLLoss()
y1 = m1(logits)
y2 = m2(y1,targets)
print(y2.item())
参考:
深入理解交叉熵损失CrossEntropyLoss - nn.NLLLoss(Negative Log-Likelihood Loss)