论文地址:Barlow Twins: Self-Supervised Learning via Redundancy Reduction (ICML 2021)
代码地址:Barlow Twins Code
摘要
自监督学习(SSL)正迅速在大型计算机视觉的基础上缩小与监督方法之间的性能差距。SSL方法成功的关键是学习对输入样本的扭曲不变的嵌入。然而,这种方法的一个经常出现的问题是存在平凡常数解。大多数当前的方法通过精心的细节实现(调参?)来避免这样的解决方案。我们提出了一种客观函数,通过测量两个相同网络的输出之间的交叉相关矩阵来自然地避免模型坍塌,并使其尽可能接近单位矩阵。这导致经过数据增强的样本的嵌入向量相似,同时最小化这些向量的组件之间的冗余。该方法被称为Barlow Twins,这是因为将神经科学家H. Barlow的冗余减少原则应用于一对相同的网络。Barlow Twins不需要大批量处理,也不需要网络孪生体之间的不对称性,比如预测网络、梯度停止或权重更新的移动平均。有趣的是,它受益于非常高维的输出向量。Barlow Twins在ImageNet上在低数据情况下的半监督分类方面优于先前的方法,并在使用线性分类器头的ImageNet分类以及分类和目标检测的转移任务中与当前的最新技术相媲美。
方法
简单来说就是,一张图片X通过某种方法生成一对正样本(图中YA和YB),经过同样的编码器和MLP等结构得到嵌入表达(图中ZA和ZB)。因为他们的源是一样的,所以希望他们输出的表达尽可能相似,但在不同的方面提取的特征是不同的。
Loss计算
先看公式(2):上图中的
Z
A
,
Z
B
Z^A,Z^B
ZA,ZB大小为 [B, D], B是batchsize,D是dim是特征维度,分子
z
b
,
i
A
z_{b,i} ^{A}
zb,iA表示的是在一个batch中的来自图片A的第b个样本的第i个元素,分母做了一个归一化。整个大C大小就是[D, D]
再看公式(1):Loss由两部分构成:Invariance term指的是对角线上的元素,意为不同的Y也可以得到相似的Representation;Redundancy reduction term指的是非对角线元素,意为Representation中的每个元素表达含义不同。因为这两项的个数不一样,所以加了一个超参数lambd。(跟个人认为这里肯定有更好的方法计算出一个lambd)
pytorch伪代码
核心代码
class BarlowTwins(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
# 定义backbone
self.backbone = torchvision.models.resnet50(zero_init_residual=True)
self.backbone.fc = nn.Identity()
# 定义projector
sizes = [2048] + list(map(int, args.projector.split('-')))
layers = []
for i in range(len(sizes) - 2):
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
layers.append(nn.BatchNorm1d(sizes[i + 1]))
layers.append(nn.ReLU(inplace=True))
layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
self.projector = nn.Sequential(*layers)
# normalization layer for the representations z1 and z2
self.bn = nn.BatchNorm1d(sizes[-1], affine=False)
def forward(self, y1, y2):
z1 = self.projector(self.backbone(y1))
z2 = self.projector(self.backbone(y2))
# empirical cross-correlation matrix
c = self.bn(z1).T @ self.bn(z2) # [D, D]希望它更接近于单位对角矩阵
# sum the cross-correlation matrix between all gpus
c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)
on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() # 对角线上的损失
off_diag = off_diagonal(c).pow_(2).sum() # 非对角线损失
loss = on_diag + self.args.lambd * off_diag
return loss