【自监督通用方法】Barlow Twins | 通过减少冗余进行的自监督学习

论文地址:Barlow Twins: Self-Supervised Learning via Redundancy Reduction (ICML 2021)
代码地址:Barlow Twins Code

摘要

自监督学习(SSL)正迅速在大型计算机视觉的基础上缩小与监督方法之间的性能差距。SSL方法成功的关键是学习对输入样本的扭曲不变的嵌入。然而,这种方法的一个经常出现的问题是存在平凡常数解。大多数当前的方法通过精心的细节实现(调参?)来避免这样的解决方案。我们提出了一种客观函数,通过测量两个相同网络的输出之间的交叉相关矩阵来自然地避免模型坍塌,并使其尽可能接近单位矩阵。这导致经过数据增强的样本的嵌入向量相似,同时最小化这些向量的组件之间的冗余。该方法被称为Barlow Twins,这是因为将神经科学家H. Barlow的冗余减少原则应用于一对相同的网络。Barlow Twins不需要大批量处理,也不需要网络孪生体之间的不对称性,比如预测网络、梯度停止或权重更新的移动平均。有趣的是,它受益于非常高维的输出向量。Barlow Twins在ImageNet上在低数据情况下的半监督分类方面优于先前的方法,并在使用线性分类器头的ImageNet分类以及分类和目标检测的转移任务中与当前的最新技术相媲美。

方法

在这里插入图片描述
简单来说就是,一张图片X通过某种方法生成一对正样本(图中YA和YB),经过同样的编码器和MLP等结构得到嵌入表达(图中ZA和ZB)。因为他们的源是一样的,所以希望他们输出的表达尽可能相似,但在不同的方面提取的特征是不同的。

Loss计算

Barlow Twins Loss
先看公式(2):上图中的 Z A , Z B Z^A,Z^B ZAZB大小为 [B, D], B是batchsize,D是dim是特征维度,分子 z b , i A z_{b,i} ^{A} zb,iA表示的是在一个batch中的来自图片A的第b个样本的第i个元素,分母做了一个归一化。整个大C大小就是[D, D]
再看公式(1):Loss由两部分构成:Invariance term指的是对角线上的元素,意为不同的Y也可以得到相似的Representation;Redundancy reduction term指的是非对角线元素,意为Representation中的每个元素表达含义不同。因为这两项的个数不一样,所以加了一个超参数lambd。(跟个人认为这里肯定有更好的方法计算出一个lambd)

pytorch伪代码

在这里插入图片描述

核心代码

class BarlowTwins(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        # 定义backbone
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()

        # 定义projector
        sizes = [2048] + list(map(int, args.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)  # [D, D]希望它更接近于单位对角矩阵

        # sum the cross-correlation matrix between all gpus
        c.div_(self.args.batch_size)
        torch.distributed.all_reduce(c)

        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()  # 对角线上的损失
        off_diag = off_diagonal(c).pow_(2).sum()  # 非对角线损失
        loss = on_diag + self.args.lambd * off_diag
        return loss
  • 19
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Willow输入中

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值