ECCV2022 论文 Contrastive Deep Supervision

论文链接https://arxiv.org/pdf/2207.05306.pdf

代码链接GitHub - ArchipLab-LinfengZhang/contrastive-deep-supervision: Codes for ECCV2022 paper - contrastive deep supervision

动机

近年来,由于大量数据的出现以及计算机算力的提升,深度学习统治了计算机视觉领域。然而,随着神经网络深度增加的同时,也带来了一些挑战。传统的有监督方法仅对模型的最后一层进行监督,然后再将误差反向传播到中间层。由于反向传播过程中可能会出现梯度消失、爆炸及弥漫等问题,怎么优化好模型中间层的参数成为了一个难点。

近期,深度监督被用于解决上述问题,它的做法是在中间层中添加辅助的分类器。在训练期间,辅助分类器与最终的分类器一同优化。大量实验证明,深度监督加速了模型的收敛。然而,通常来说,不同深度的特征学到的信息不同,底层特征往往含有丰富的纹理及颜色等信息,而深层特征往往含有丰富的语义信息,简单地将辅助分类器应用到中间层特征显然存在问题,因为底层特征没有丰富的语义信息,不适合进行分类 (底层特征往往用于目标定位,因为它含有较多的空间位置信息)。基于这些理论,就有了这篇文章 《Contrastive Deep Supervision》,以下简称 CDS。

创新点

这篇文章的作者认为:相比于有监督的任务损失,对比学习能给中间层的特征提供更好的监督。对比学习通常在同一张图片中使用两种不同的数据增强 (增强方法可以相同,但其中的参数不同),随后将增强后的两张图片视为正样本对,与其余图片构成负样本对。作者提出的方法如下图中的 (d) 所示,几个投影头会附件在中间层的后面,用于将特征映射到嵌入空间,以便进行对比学习,这些投影头在推理期间会被 kill 掉,这样就避免了额外的计算及额外的存储空间。与训练中间层特征去学习特定任务知识的深度监督不同,CDS 学习的是图片中的本质信息,这些信息不受数据增强的影响,这也使神经网络能更好地泛化。此外,由于对比学习可以在未标记的数据上进行,CDS 也可应用到半监督任务中。这篇文章的主要创新点如下:

(1) 提出了 CDS,这是一种神经网络训练方法,其中中间层直接通过对比学习进行优化。它使神经网络能够学习更好的视觉表示,且无需在推理过程中增加额外的开销

(2) 从深度监督的角度来看,作者第一个表明除了有监督任务损失之外,中间层还可以通过其他方式进行训练

(3) 从表示学习的角度来看,作者首个表明对比学习和监督学习可以以一阶段的深度监督的方式联合训练模型,而不是两阶段的 “pretrain-finetune” 方案 (先预训练,后微调)

方法论

CDS

假定一个 minibatch 有 N 张图片,对每张图片都进行两次随机的数据增强,增强后就有 2N 张图片。为了方便,作者把 x_{i} 和 x_{N+i} 作为来自同一图像的两个增强表示,这两张图片也被视为一个正样本对。z=c(x) 为经过投影层并标准化后的输出,对比学习的公式如下:

L_{Contra} 鼓励编码器网络从同一图像中学习不同增强的相似表示,同时增加来自不同图像的增强表示之间的差异。

CDS 与深度监督之间的主要区别在于深度监督通过交叉熵损失来训练辅助分类器,而 CDS 则通过对比学习来训练。CDS 整体损失函数公式如下:

这个公式表示有 K-1 个中间层使用了对比学习来训练,最后一层使用交叉熵损失来训练。

CDS 还可以推广到半监督学习和知识蒸馏中:

在半监督学习中,作者假设有 X_{1} 个带标签的图片,对应的标签为 Y_{1},无标签的数据为 X_{2}。在有标签数据中,可以直接使用 CDS。在无标签数据中,只能进行对比学习。整体的损失公式如下:

在知识蒸馏中,作者进一步提出通过将教师模型学到的图像在数据增强中的不变性传递给学生模型,来改进具有 CDS 的知识蒸馏。f^{S} 和 f^{T} 分别表示知识蒸馏中的学生模型和教师模型,原始的知识蒸馏直接最小化了学生和教师模型的骨干特征之间的距离,可以表示为:

与原始知识蒸馏不同,带有 CDS 的知识蒸馏最小化的是两个模型的嵌入向量 (经投影层得到) 之间的距离,公式如下:

知识蒸馏中的整体损失函数公式如下:

一些细节和 tricks

投影层的设计

在 CDS 的训练期间,将几个投影头添加到神经网络的中间层。这些投影头将骨干特征映射到归一化的嵌入空间,其中应用了对比学习损失。通常,投影头是由两个全连接层和一个 ReLU 函数堆叠而成的非线性投影。然而,在 CDS 中,输入特征来自中间层而不是最终层,因此需要修改投影层的设计。作者通过在非线性投影之前添加卷积层来增加这些投影头的复杂性。

对比学习

CDS 是一个通用的训练框架,不依赖于特定的对比学习方法。在这篇文章中,作者在大多数实验中采用 SimCLR 和 SupCon 作为对比学习的方法。如果使用更好的对比学习算法,模型最终的性能也会进一步提升。

负样本

以前的研究表明,负样本的数量对对比学习的表现有着重要的影响,因此在对比学习中通常使用大的 batch size。但在 CDS 中,作者认为诸如交叉熵之类的损失已经足以防止对比学习收敛到崩溃的解决方案。

实验结果

在 CIFAR100 和 CIFAR10 上的分类结果如下:

ImageNet 上的分类结果如下:

在目标检测数据集 COCO2017 上的结果如下:

在细粒度数据集上的结果如下:

代码

代码也比较简单,拿 resnet18 来举例:

import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
import torch

__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]


# model_urls = {
#     "resnet18": "./pretrain/resnet18-5c106cde.pth",
#     "resnet34": "./pretrain/resnet34-333f7ec4.pth",
#     "resnet50": "./pretrain/resnet50-19c8e357.pth",
#     "resnet101": "./pretrain/resnet101-5d3b4d8f.pth",
#     "resnet152": "./pretrain/resnet152-b121ed2d.pth",
# }

model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class SepConv(nn.Module):
    def __init__(
        self, channel_in, channel_out, kernel_size=3, stride=2, padding=1, affine=True
    ):
        #   depthwise and pointwise convolution, downsample by 2
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(
                channel_in,
                channel_in,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=channel_in,
                bias=False,
            ),
            nn.Conv2d(channel_in, channel_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(channel_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(
                channel_in,
                channel_in,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                groups=channel_in,
                bias=False,
            ),
            nn.Conv2d(channel_in, channel_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(channel_out, affine=affine),
            nn.ReLU(inplace=False),
        )

    def forward(self, x):
        return self.op(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self, block, layers, num_classes=100, zero_init_residual=False, align="CONV"
    ):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.align = align
        #   reduce the kernel-size and stride of ResNet on cifar datasets.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        #   remove maxpooling layer for ResNet on cifar datasets.
        #   self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.auxiliary1 = nn.Sequential(
            SepConv(channel_in=64 * block.expansion, channel_out=128 * block.expansion),
            SepConv(
                channel_in=128 * block.expansion, channel_out=256 * block.expansion
            ),
            SepConv(
                channel_in=256 * block.expansion, channel_out=512 * block.expansion
            ),
            nn.AvgPool2d(4, 4),
        )

        self.auxiliary2 = nn.Sequential(
            SepConv(
                channel_in=128 * block.expansion,
                channel_out=256 * block.expansion,
            ),
            SepConv(
                channel_in=256 * block.expansion,
                channel_out=512 * block.expansion,
            ),
            nn.AvgPool2d(4, 4),
        )
        self.auxiliary3 = nn.Sequential(
            SepConv(
                channel_in=256 * block.expansion,
                channel_out=512 * block.expansion,
            ),
            nn.AvgPool2d(4, 4),
        )
        self.auxiliary4 = nn.AvgPool2d(4, 4)

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        feature_list = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        feature_list.append(x)
        x = self.layer2(x)
        feature_list.append(x)
        x = self.layer3(x)
        feature_list.append(x)
        x = self.layer4(x)
        feature_list.append(x)

        out1_feature = self.auxiliary1(feature_list[0]).view(x.size(0), -1)
        out2_feature = self.auxiliary2(feature_list[1]).view(x.size(0), -1)
        out3_feature = self.auxiliary3(feature_list[2]).view(x.size(0), -1)
        out4_feature = self.auxiliary4(feature_list[3]).view(x.size(0), -1)
        out = self.fc(out4_feature)
        feat_list = [out4_feature, out3_feature, out2_feature, out1_feature]
        for index in range(len(feat_list)):
            feat_list[index] = F.normalize(feat_list[index], dim=1)
        if self.training:
            return out, feat_list
        else:
            return out


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(
            model_zoo.load_url(model_urls["resnet50"])
        )
    return model

就是在 resnet 的4个 layer 后添加了 auxiliary head,而 auxiliary head 又由深度可分离卷积与平均池化层构成,用于进一步提取特征 (因为作者认为 resnet 提取的特征的表达能力还不够强,需要进一步提取)

对比学习的损失函数代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]  #   2
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        #   256 x 512
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature   #   256 x   512
            anchor_count = contrast_count   #   2
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        #   print (anchor_dot_contrast.size())  256 x 256

        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)

        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss

在 CIFAR100 上,我使用 resnet18 复现的结果为 80.54%,与论文中的 80.84% 差别不大

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

chen_znn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值