参考
从自监督到全监督!Google 提出新损失函数SupCon,准确率提升2%!-CSDN博客
背景
Google NeurIPS 2020
原理
损失公式如下,同一类为p,不同类为α,这样可以最大化负样本,最小化正样本。
源码
项目地址:https://github.com/HobbitLong/SupContrast/tree/master
修改的易于理解的supconloss的实现脚本,可以运行一下,看着各个变量能加深理解。这里的相似度度量用的是特征的直接相乘,且锚样本的负样本用的是除自身外的负样本,做Re-ID任务注意甄别。
import torch
import torch.nn as nn
import numpy as np
class SupConLoss(nn.Module):
def __init__(self, temperature=0.07,
base_temperature=0.07):
super(SupConLoss, self).__init__()
self.temperature = temperature
self.base_temperature = base_temperature
def forward(self, features, labels):
# 0. 确定需要的参数,包括:batch, labels, mask, 要比对的张量
batch_size = features.shape[0]
labels = labels.contiguous().view(-1, 1)
mask = torch.eq(labels, labels.T).float()
print(f"mask:\n{mask}")
# 0.1 将张量由(B, C, H*W)变为(B*C, H*W), 因为比对维度是B*C的维度,而不是只以B为维度
anchor_count = features.shape[1]
anchor_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
print(f"anchor_feature:\n{anchor_feature.shape}")
# 1. 确定每个特征与其他特征的交互,并确定正负样本
# 1.1 得到每个特征与其他特征相乘后相加的数值
# 矩阵叉乘后除一个系数,得到的矩阵以第一行举例,
# 这一行第一个值为第一个特征张量与自己交互得到的值,第二个值依次类推,第二行依次类推
# 打印了两个变量可以看一下交互关系
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, anchor_feature.T),
self.temperature)
print(f"anchor_feature:\n{anchor_feature}")
print(f"anchor_feature.T:\n{anchor_feature.T}")
print(f"anchor_dot_contrast:\n{anchor_dot_contrast.shape}")
# 1.2 得到每个特征与其他特征交互的最大值,并将每个值与最大值相除,做归一化使数值稳定
# 打印了变量的值可以理解一下
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
print(f"logits_max:\n{logits_max}")
print(f"logits:\n{logits}")
# 1.3 正负样本判断容器制作,自己与自己的张量交互要去掉,logits_mask干这事
mask = mask.repeat(anchor_count, anchor_count)
print(f"mask:\n{mask}")
logits_mask = torch.scatter(
torch.ones_like(mask),
1,
torch.arange(batch_size * anchor_count).view(-1, 1),
0
)
mask = mask * logits_mask
print(f"logits_mask:\n{logits_mask}")
print(f"mask:\n{mask}")
# 1.4 计算loss
exp_logits = torch.exp(logits) * logits_mask
# log_prob的结果相当于除负样本乘积之和,结合公式来看,
# 这里的负样本用的是除自己之外都是负样本,没有属于同一ID的概念,使用时注意甄别
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
print(f"log_prob:\n{log_prob}")
mask_pos_pairs = mask.sum(1)
print(f"mask_pos_pairs:\n{mask_pos_pairs}")
mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
print(f"mask_pos_pairs:\n{mask_pos_pairs}")
mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
print(f"mean_log_prob_pos:\n{mean_log_prob_pos}")
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
print(f"loss:\n{loss}")
loss = loss.view(anchor_count, batch_size).mean()
print(f"loss:\n{loss}")
return loss
if __name__ == "__main__":
sup_con = SupConLoss()
features = torch.randn((3, 2, 5))
labels = torch.tensor([7, 7, 6])
sup_con.forward(features, labels)
其他
(个人理解,不对的话,随时批评!!!)
参考
盘点检索任务中的损失函数_mask = torch.eye(scores.size(0)) > .5-CSDN博客
损失
NCE(2010):公式如下,将对比学习的多分类问题转换为二分类问题:
交叉熵(CE)公式如下,info NCE与CE是一样的,只不过内部K,info是同一batch内的其他类,而CE则是固定的从全连接层取出的类。
info NCE(2018,info Noise Contrastive Estimation loss): 正例之间相互吸引,负例之间相互排斥,温度超参越小,越能感知近似的负样本,将NCE的二分类问题再转换为多分类问题,对网络来说训练更加合理。
CLIP和MOCO用的损失是info NCE
info NCE与supcon_loss的区别在于下面 S的计算(info需要转换为概率,比如相乘后用一个softmax),supcon_loss除自己外都是负样本,而info是拥有了ID的概念,类属于同一ID的为正样本,类属于不同ID的为负样本。