CLIP 中不一定会被注意的细节(ResNet 网络的改进)

CLIP 中不一定会被注意的细节(ResNet 网络的改进)

在这里插入图片描述

在一开始使用CLIP的时候,CLIP的 ResNet50 网络并不是直接从 torchvision.models 直接导入进来的,这一点对于CLIP的模型设计非常重要。

更改原因
1. 首先想要进行 CLIP 这样的对比学习,进行特征比较的过程需要的向量,仅仅是特征向量长度而不是序列,所以没有序列维度,而ResNet这样的网络去掉池化和全连接出来的特征如果输入的是标准大小的图片的话是一个(7*7)的特征向量。所以这个向量含有位置信息,而不能对齐局部信息
  • There are now 3 “stem” convolutions as opposed to 1, with an average pool instead of a max pool.
  • Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
  • The final pooling layer is a QKV attention instead of an average pool
中文意思就是:
  • 与之前仅有1个的不同,现在有 3个“茎”卷积层 ,并且采用 平均池化 而非最大池化。
  • 执行抗锯齿的步幅卷积,当步幅大于1时,在卷积之前添加平均池化。
  • 最后的 池化层 采用了QKV 注意力机制,而不是平均池化。

对比之下我们就可以看出来不同

ModifiedResNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU(inplace=True)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU(inplace=True)
  (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (layer1): Sequential(
    ...
  )
  (layer2): Sequential(
    ...
  )
  (layer3): Sequential(
   ...
  )
  (layer4): Sequential(
    ...
  )
  (attnpool): AttentionPool2d(
    (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
    (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
    (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
    (c_proj): Linear(in_features=2048, out_features=512, bias=True)
  )
)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    ...
  )
  (layer2): Sequential(
    ...
  )
  (layer3): Sequential(
    ...
  )
  (layer4): Sequential(
   ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

作者改进最吸引人的点是使用了 AttentionPooling 这样特殊的池化方式

这样使用注意力进行池化,很有意思,一开始甚至很难发现它和池化有什么关系

class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, attn_weights = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,  # 这里是池化的原因,只对全局进行了注意力的查询
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=True
        )
        self.attn_weights = attn_weights
        return x.squeeze(0)

x, attn_weights = F.multi_head_attention_forward( query=x[:1], key=x, value=x,)
这里是池化的原因,只对全局进行了注意力的查询,只对 第一个CLS标记,全局特征 进行查询注意力

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值