CLIP 中不一定会被注意的细节(ResNet 网络的改进)
在一开始使用CLIP的时候,CLIP的 ResNet50 网络并不是直接从 torchvision.models 直接导入进来的,这一点对于CLIP的模型设计非常重要。
-
更改原因
-
1. 首先想要进行 CLIP 这样的对比学习,进行特征比较的过程需要的向量,仅仅是特征向量长度而不是序列,所以没有序列维度,而ResNet这样的网络去掉池化和全连接出来的特征如果输入的是标准大小的图片的话是一个(7*7)的特征向量。所以这个向量含有位置信息,而不能对齐局部信息
- There are now 3 “stem” convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
-
中文意思就是:
- 与之前仅有1个的不同,现在有 3个“茎”卷积层 ,并且采用 平均池化 而非最大池化。
- 执行抗锯齿的步幅卷积,当步幅大于1时,在卷积之前添加平均池化。
- 最后的 池化层 采用了QKV 注意力机制,而不是平均池化。
对比之下我们就可以看出来不同
ModifiedResNet(
(conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU(inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(inplace=True)
(conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu3): ReLU(inplace=True)
(avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
(layer1): Sequential(
...
)
(layer2): Sequential(
...
)
(layer3): Sequential(
...
)
(layer4): Sequential(
...
)
(attnpool): AttentionPool2d(
(k_proj): Linear(in_features=2048, out_features=2048, bias=True)
(q_proj): Linear(in_features=2048, out_features=2048, bias=True)
(v_proj): Linear(in_features=2048, out_features=2048, bias=True)
(c_proj): Linear(in_features=2048, out_features=512, bias=True)
)
)
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
...
)
(layer2): Sequential(
...
)
(layer3): Sequential(
...
)
(layer4): Sequential(
...
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1000, bias=True)
)
作者改进最吸引人的点是使用了 AttentionPooling 这样特殊的池化方式
这样使用注意力进行池化,很有意思,一开始甚至很难发现它和池化有什么关系
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, attn_weights = F.multi_head_attention_forward(
query=x[:1], key=x, value=x, # 这里是池化的原因,只对全局进行了注意力的查询
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=True
)
self.attn_weights = attn_weights
return x.squeeze(0)
x, attn_weights = F.multi_head_attention_forward( query=x[:1], key=x, value=x,)
这里是池化的原因,只对全局进行了注意力的查询,只对 第一个CLS标记,全局特征 进行查询注意力