Introduction
- Paper: A Close Look at Spatial Modeling: From Attention to Convolution
- Model: FCViT
- Published: 23 Dec 2022
- Summary: inherits the strengths of both Transformers and convolutions, outperforming strong networks such as ConvNeXt
- Official code: https://github.com/ma-xu/FCViT.git
- Paper link: https://arxiv.org/abs/2212.12552
1. FCViT Overview
1.1 Background
After visualizing the attention maps of several Transformer-based models, the authors observed two things:
- Transformer-based models exhibit query-irrelevant behavior in their deep layers
- The attention maps are sparse
As the visualization in the paper shows, whether the query point is placed above the dog's right eye or below its left eye, the corresponding attention maps look essentially the same. Moreover, even though most patches contribute to the attention, the weights concentrate around a very small value (~0.005) with only a few slightly larger ones: ViT exhibits sparse attention, while knowledge from convolution (e.g., DeiT-B-Distill) can smooth the attention weights to a large extent.
1.2 Approach
From the observations above, there are two problems to address: the query-irrelevant behavior and the sparsity of the attention weights.
Recall the definition of self-attention, which can be written as Equation 1:

$$\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V \tag{1}$$

Here the authors make a very bold move: they delete the query entirely. Since deep-layer attention is query-irrelevant anyway, the per-query similarity is replaced by the similarity between each token and a single pooled global token $\bar{x}$ (this matches the GlobalContext implementation below), so every position shares one attention map:

$$\mathrm{sim}(X)=\mathrm{Norm}\!\left(\frac{X\,\bar{x}^{\top}}{\sqrt{d}}\right)$$

The authors also note that convolution can smooth the attention weights, and that it already brings a clear gain on DeiT-B, so a depthwise convolution is folded in. In terms of the implementation below, the resulting token mixer is roughly

$$\mathrm{TokenMixer}(X)=\mathrm{FC}\big(\mathrm{DWConv}\big(X+\mathrm{sim}(X)\odot g(\bar{x})\big)\big)$$

where $g(\cdot)$ is the FC bottleneck applied to the pooled token, and a GELU follows the FC.
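To make this evolution concrete, here is a minimal, self-contained sketch (my own illustration, not the official code) of the three steps; the token count, channel width, and the 3x3 smoothing kernel are assumptions:

```python
import paddle
import paddle.nn.functional as F

b, n, d = 2, 196, 64                    # batch, tokens, channels (assumed)
x = paddle.randn([b, n, d])

# (1) standard self-attention weights: every query gets its own row
attn = F.softmax(x @ x.transpose([0, 2, 1]) / d ** 0.5, axis=-1)  # [b, n, n]

# (2) delete the query: similarity against one pooled global token gives a
#     single attention map shared by all positions (query-irrelevant)
g = x.mean(axis=1, keepdim=True)        # global token, [b, 1, d]
sim = F.softmax(g @ x.transpose([0, 2, 1]) / d ** 0.5, axis=-1)   # [b, 1, n]

# (3) smooth the shared map with a small convolution over the 2D token grid
#     (14x14 here), mirroring how convolutional knowledge smooths attention
sim2d = sim.reshape([b, 1, 14, 14])
smooth = paddle.nn.Conv2D(1, 1, kernel_size=3, padding=1)
print(attn.shape, sim.shape, smooth(sim2d).shape)
```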
1.3 Model Structure
Based on the above, the authors propose a new model, FCViT. Its main building block consists of a token mixer and a channel mixer; the overall structure is shown in the architecture figure of the paper.
The token mixer is made up of a GlobalContext module plus two convolution layers, while the channel mixer consists of a few convolution layers with a GELU activation.
2. Model Implementation
2.0 Imports and Small Utilities
import paddle
import paddle.nn as nn
from fcvit import *  # provides params, DWConv2D, SpatialAtt, ChannelAtt used below

# to_2tuple: modeled on timm.models.layers.helpers.to_2tuple
def to_2tuple(x):
    if isinstance(x, int):
        return (x, x)
    else:
        return x
2.1 OverlapPatchEmbed
As in ViT, FCViT splits the image into patches before feeding it to the backbone; here this is implemented with a single convolution layer.
Parameters
- patch_size: the patch size; effectively the kernel size of the convolution
- stride: the stride of the convolution
class OverlapPatchEmbed(nn.Layer):
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.proj = nn.Conv2D(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.BatchNorm2D(embed_dim)
        # TODO init function and apply

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x
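A quick shape check of the stem (the 224x224 input and embed_dim=96 are illustrative assumptions):

```python
# a stride-4, kernel-7 stem downsamples a 224x224 image to a 56x56 token grid
embed = OverlapPatchEmbed(patch_size=7, stride=4, in_chans=3, embed_dim=96)
x = paddle.randn([1, 3, 224, 224])
print(embed(x).shape)  # [1, 96, 56, 56]
```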
2.2 LN
The features are normalized once before entering the token mixer.
Parameters
- num_channels: the number of input channels
Notes
- The official code actually implements this with GroupNorm: with num_groups=1, GroupNorm is equivalent to Layer Norm taken over all non-batch dimensions, so GroupNorm is used here as well to stay aligned with the official code.
class GroupNorm(nn.GroupNorm):
    def __init__(self, num_channels, **kwargs):
        super(GroupNorm, self).__init__(1, num_channels, **kwargs)
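The equivalence can be checked numerically; a small sketch (shapes are arbitrary assumptions; note the match is with LayerNorm taken over all non-batch dimensions):

```python
# GroupNorm with one group normalizes each sample over all of C, H, W,
# which matches LayerNorm over [C, H, W] up to the shape of the affine params
x = paddle.randn([2, 8, 4, 4])
gn = GroupNorm(8)
ln = nn.LayerNorm(x.shape[1:])  # normalized_shape = [8, 4, 4]
print(float((gn(x) - ln(x)).abs().max()))  # ~1e-6 at default initialization
```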
2.3 GlobalContext
A submodule of the token mixer.
Parameters
- dim: the number of input channels
- act_layer: the activation type
Pitfalls:
- torch's Tensor.max() and paddle's Tensor.max() behave differently (paddle returns only the values, torch returns values and indices); see both frameworks' docs for details
- einops' rearrange has to be replaced with reshape
Module overview:
- The right branch of GlobalContext first pools the input (global average pooling in the code) and then applies a two-layer FC information bottleneck; the optional "compete" variant instead takes an element-wise max over two competing FC branches.
- The left branch computes a similarity map between every token and the pooled global token, normalizes it, and multiplies it with the right branch's output.
- Finally the result is added back to the input through a residual connection (applied in TokenMixer below).
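The `params` default below comes from the star import of fcvit. For orientation, a hypothetical dictionary with the keys these classes read might look like the following; every value here is an illustrative assumption, and the real configuration lives in fcvit.py:

```python
# hypothetical config sketch; the actual `params` is imported from fcvit
example_params = {
    "global_context": {
        "compete": True,       # max over two competing FC branches
        "gc_reduction": 4,     # bottleneck reduction ratio
        "weighted_gc": True,   # weight the context by the similarity map
        "head": 4,             # heads used when computing the similarity
    },
    "spatial_mixer": {
        "use_globalcontext": True,
        "useSecondTokenMix": True,
        "useSpatialAtt": False,
        "mix_size_1": 3,       # kernel size of the first depthwise conv
        "mix_size_2": 3,       # kernel size of the second depthwise conv
        "fc_factor": 1,
        "fc_min_value": 16,
    },
    "channel_mixer": {
        "useChannelAtt": False,
        "useDWconv": True,
        "DWconv_size": 3,
    },
}
```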
class GlobalContext(nn.Layer):
    def __init__(self, dim, act_layer=nn.GELU, params=params):
        super().__init__()
        self.compete = params["global_context"]["compete"]
        if self.compete:
            self.fc1 = nn.Linear(dim, 2*dim//params["global_context"]["gc_reduction"])
            self.fc2 = nn.Linear(dim//params["global_context"]["gc_reduction"], dim)
        else:
            self.fc = nn.Sequential(
                nn.Linear(dim, dim//params["global_context"]["gc_reduction"]),
                act_layer(),
                nn.Linear(dim//params["global_context"]["gc_reduction"], dim)
            )
        self.weight_gc = params["global_context"]["weighted_gc"]
        if self.weight_gc:
            self.head = params["global_context"]["head"]
            self.scale = (dim//self.head) ** -0.5
            rescale_weight = self.create_parameter(shape=[self.head], default_initializer=nn.initializer.Constant(1.0))
            self.rescale_weight = self.add_parameter("rescale_weight", rescale_weight)
            rescale_bias = self.create_parameter(shape=[self.head], default_initializer=nn.initializer.Constant(0.0))
            self.rescale_bias = self.add_parameter("rescale_bias", rescale_bias)
            self.epsilon = 1e-5

    def _get_gc(self, gap):  # gap [b,c]
        if self.compete:
            b, c = gap.shape
            gc = self.fc1(gap).reshape([b, 2, -1])
            gc = gc.max(axis=1)  # element-wise max over the two competing branches
            gc = self.fc2(gc)
            return gc
        else:
            return self.fc(gap)

    def forward(self, x):
        if self.weight_gc:
            b, c, w, h = x.shape
            x = paddle.reshape(x, [b, c, -1])
            # global average pooling over the spatial dimension
            gap = x.mean(-1, keepdim=True)  # [b,c,1]
            # similarity between every token and the pooled global token
            q, g = map(lambda t: paddle.reshape(t, [b, self.head, -1, t.shape[2]]), [x, gap])
            sim = paddle.einsum("bhdi,bhjd->bhij", q, g.transpose([0, 1, 3, 2])).squeeze(-1) * self.scale
            # standardize the similarity map, then rescale it per head
            std, mean = paddle.std(sim, axis=[1, 2], keepdim=True), paddle.mean(sim, axis=[1, 2], keepdim=True)
            sim = (sim-mean)/(std+self.epsilon)
            sim = sim * self.rescale_weight.unsqueeze(axis=0).unsqueeze(axis=-1) + self.rescale_bias.unsqueeze(axis=0).unsqueeze(axis=-1)
            sim = sim.reshape([b, self.head, 1, w, h])
            # right branch of the GC module: FC bottleneck on the pooled token
            gc = self._get_gc(gap.squeeze(axis=-1)).reshape([b, self.head, -1]).unsqueeze(axis=-1).unsqueeze(axis=-1)  # [b, head, hdim, 1, 1]
            gc = paddle.reshape(sim*gc, [b, -1, w, h])
        else:
            gc = self._get_gc(x.mean(axis=-1).mean(axis=-1)).unsqueeze(axis=-1).unsqueeze(axis=-1)
        return gc
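A forward-shape sketch (dim and spatial size are assumptions; the head count from the imported config must divide dim):

```python
# with weighted_gc enabled the output matches the input shape, which the
# token mixer adds back as a residual; without it, the output is a
# broadcastable [b, c, 1, 1] context vector
gc = GlobalContext(dim=64)
x = paddle.randn([2, 64, 14, 14])
print(gc(x).shape)  # [2, 64, 14, 14] under a weighted_gc config
```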
2.4 Token-mixer
The token-mixer module itself.
Parameters
- dim: the number of input channels
- act_layer: the activation type
Module overview:
- Most of the work happens in the GlobalContext module above; the rest is a depthwise convolution followed by a pointwise (1x1) convolution, optionally applied a second time.
class TokenMixer(nn.Layer):
    def __init__(self, dim, act_layer=nn.GELU, params=params):
        super().__init__()
        self.act = act_layer()
        self.useSpatialAtt = params["spatial_mixer"]["useSpatialAtt"]
        if params["spatial_mixer"]["use_globalcontext"]:
            self.gc1 = GlobalContext(dim, act_layer=act_layer, params=params)
        self.dw1 = DWConv2D(dim, params["spatial_mixer"]["mix_size_1"])
        if params["spatial_mixer"]["fc_factor"] > 1:
            self.fc1 = nn.Sequential(
                nn.Conv2D(dim, max(dim//params["spatial_mixer"]["fc_factor"], params["spatial_mixer"]["fc_min_value"]), 1),
                self.act,
                nn.Conv2D(max(dim//params["spatial_mixer"]["fc_factor"], params["spatial_mixer"]["fc_min_value"]), dim, 1)
            )
        else:
            self.fc1 = nn.Conv2D(dim, dim, 1)
        if params["spatial_mixer"]["useSecondTokenMix"]:
            if params["spatial_mixer"]["use_globalcontext"]:
                self.gc2 = GlobalContext(dim, act_layer=act_layer, params=params)
            self.dw2 = DWConv2D(dim, params["spatial_mixer"]["mix_size_2"])
            if params["spatial_mixer"]["fc_factor"] > 1:
                self.fc2 = nn.Sequential(
                    nn.Conv2D(dim, max(dim//params["spatial_mixer"]["fc_factor"], params["spatial_mixer"]["fc_min_value"]), 1),
                    self.act,
                    nn.Conv2D(max(dim//params["spatial_mixer"]["fc_factor"], params["spatial_mixer"]["fc_min_value"]), dim, 1)
                )
            else:
                self.fc2 = nn.Conv2D(dim, dim, 1)
        if params["spatial_mixer"]["useSpatialAtt"]:
            self.spatial_att = SpatialAtt(dim=dim, act_layer=act_layer, params=params)

    def forward(self, x):
        if hasattr(self, "gc1"):
            gc1 = self.gc1(x)
            x = x + gc1
        x = self.act(self.fc1(self.dw1(x)))
        if hasattr(self, "fc2"):
            if hasattr(self, "gc2"):
                gc2 = self.gc2(x)
                x = x + gc2
            x = self.act(self.fc2(self.dw2(x)))
        if self.useSpatialAtt:
            x = self.spatial_att(x)
        return x
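As with the other modules, a quick shape check (input sizes are assumptions; the mixer preserves the [b, c, h, w] layout):

```python
# the token mixer keeps the channel count and spatial resolution unchanged
mixer = TokenMixer(dim=64)
x = paddle.randn([2, 64, 14, 14])
print(mixer(x).shape)  # [2, 64, 14, 14]
```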
2.5 Channel-mixer
The channel-mixer module.
Parameters
- dim: the number of input channels
- hidden_dim: the hidden (expansion) dimension
- act_layer: the activation type
- drop: the dropout probability
Module overview:
The channel mixer is straightforward: a 1x1 convolution that expands the channels, an optional depthwise convolution, a GELU activation, and a 1x1 convolution that projects back, with dropout after each projection.
class ChannelMixer(nn.Layer):
    def __init__(self, dim, hidden_dim=None, act_layer=nn.GELU, drop=0., params=params):
        super().__init__()
        hidden_dim = hidden_dim or dim
        self.useChannelAtt = params["channel_mixer"]["useChannelAtt"]
        self.act = act_layer()
        self.fc1 = nn.Conv2D(dim, hidden_dim, 1)
        if params["channel_mixer"]["useDWconv"]:
            ks = params["channel_mixer"]["DWconv_size"]
            self.dwconv = nn.Conv2D(hidden_dim, hidden_dim, ks, padding=ks//2, groups=hidden_dim)
        self.fc2 = nn.Conv2D(hidden_dim, dim, 1)
        self.drop = nn.Dropout(drop)
        if self.useChannelAtt:
            self.channel_att = ChannelAtt(act_layer=act_layer, params=params)
        # self.apply(self._init_weights)
        # TODO init function and apply

    def forward(self, x):
        x = self.fc1(x)
        if hasattr(self, "dwconv"):
            x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        if self.useChannelAtt:
            x = self.channel_att(x)
        return x
That covers the bulk of FCViT; see fcvit.py for the remaining details.
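For orientation, a hedged sketch of how the pieces above might compose into one FCViT block, following the token-mixer/channel-mixer structure from Section 1.3 (the residual wiring and mlp_ratio here are my assumptions; the official block in fcvit.py may add further details such as drop path):

```python
class FCViTBlock(nn.Layer):
    """Hypothetical composition of the modules defined above:
    norm -> token mixer -> residual, then norm -> channel mixer -> residual."""
    def __init__(self, dim, mlp_ratio=4, params=params):
        super().__init__()
        self.norm1 = GroupNorm(dim)
        self.token_mixer = TokenMixer(dim, params=params)
        self.norm2 = GroupNorm(dim)
        self.channel_mixer = ChannelMixer(dim, hidden_dim=dim * mlp_ratio, params=params)

    def forward(self, x):
        x = x + self.token_mixer(self.norm1(x))
        x = x + self.channel_mixer(self.norm2(x))
        return x
```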
3. Converting the Pretrained Weights
Why bring this up? The original plan was to convert the weights with the x2paddle tool, but that ran into many version-compatibility problems, so the conversion was rewritten by hand, following a reference article on assigning parameters between PaddlePaddle and PyTorch models.
The gist of the code is as follows:
# paddle's Tensor.shape differs slightly from torch's, so convert both to list before comparing
# torch stores norm statistics as running_mean/running_var, while paddle uses _mean/_variance
# some layers' weights (e.g., nn.Linear) need a transpose
for k, p in torch_checkpoint.items():
    torch_param = p.detach().cpu().numpy()
    if k in paddle_checkpoint:
        if list(torch_param.shape) == list(paddle_checkpoint[k].shape):
            paddle_checkpoint[k] = torch_param
        elif list(torch_param.shape) == list(paddle_checkpoint[k].transpose([1, 0]).shape):
            # linear weights are stored transposed between the two frameworks
            paddle_checkpoint[k] = torch_param.transpose(1, 0)
        else:
            print('torch param {} does not match paddle param {}'.format(k, k))
    elif 'running_mean' in k:
        # torch: *.running_mean  ->  paddle: *._mean
        if list(torch_param.shape) == list(paddle_checkpoint[k[:-12] + '_mean'].shape):
            paddle_checkpoint[k[:-12] + '_mean'] = torch_param
        else:
            print('torch param {} does not match paddle param {}'.format(k, k[:-12] + '_mean'))
    elif 'running_var' in k:
        # torch: *.running_var  ->  paddle: *._variance
        if list(torch_param.shape) == list(paddle_checkpoint[k[:-11] + '_variance'].shape):
            paddle_checkpoint[k[:-11] + '_variance'] = torch_param
        else:
            print('torch param {} does not match paddle param {}'.format(k, k[:-11] + '_variance'))
    else:
        print('torch param {} does not exist in the paddle model'.format(k))
Error check:
max error: 2.2411346435546875e-05
mean error: 3.471657691989094e-06
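These numbers came from feeding the same input through both frameworks' models and comparing the outputs; a minimal sketch of such a check (assuming `torch_model` and `paddle_model` are the loaded networks, both in eval mode):

```python
import numpy as np
import paddle
import torch

# same random input for both frameworks
x = np.random.rand(1, 3, 224, 224).astype('float32')
with torch.no_grad():
    torch_out = torch_model(torch.from_numpy(x)).cpu().numpy()
paddle_out = paddle_model(paddle.to_tensor(x)).numpy()

diff = np.abs(torch_out - paddle_out)
print('max error:', diff.max())
print('mean error:', diff.mean())
```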
4. Accuracy Validation
The evaluation follows the setup of the AI Studio project "Swin Transformer: Hierarchical Vision Transformer".
4.1 Extract the Dataset
!mkdir ~/data/ILSVRC2012
!tar -xf ~/data/data105740/ILSVRC2012_val.tar -C ~/data/ILSVRC2012
!mkdir ~/models
!unzip data/data185449/FCVIT.zip -d ~/models
Archive: data/data185449/FCVIT.zip
inflating: /home/aistudio/models/fcvit_b12.pdparams
inflating: /home/aistudio/models/fcvit_b24.pdparams
4.2 Model Validation
import os
import cv2
import numpy as np
import paddle
from PIL import Image

# build the ImageNet validation dataset
class ILSVRC2012(paddle.io.Dataset):
    def __init__(self, root, label_list, transform, backend='pil'):
        self.transform = transform
        self.root = root
        self.label_list = label_list
        self.backend = backend
        self.load_datas()

    def load_datas(self):
        self.imgs = []
        self.labels = []
        with open(self.label_list, 'r') as f:
            for line in f:
                img, label = line[:-1].split(' ')
                self.imgs.append(os.path.join(self.root, img))
                self.labels.append(int(label))

    def __getitem__(self, idx):
        label = self.labels[idx]
        image = self.imgs[idx]
        if self.backend == 'cv2':
            image = cv2.imread(image)
        else:
            image = Image.open(image).convert('RGB')
        image = self.transform(image)
        return image.astype('float32'), np.array(label).astype('int64')

    def __len__(self):
        return len(self.imgs)
from paddle.vision.transforms import Compose, Resize, Normalize, ToTensor, CenterCrop

val_transforms = Compose([
    Resize(256),
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# set up the dataset
val_dataset = ILSVRC2012('data/ILSVRC2012/ILSVRC2012_val/', transform=val_transforms, label_list='data/ILSVRC2012/ILSVRC2012_val/val_list.txt')

# set up the models
from fcvit import fcvit_b48, fcvit_b24, fcvit_b12, fcvit_tiny
models = [fcvit_b48, fcvit_b24, fcvit_b12, fcvit_tiny]
for model in models:
    print('********************Validating model: {}***********************'.format(model.__name__))
    checkpoint = paddle.load('models/{}.pdparams'.format(model.__name__))
    model = model()
    model.set_state_dict(checkpoint)
    model = paddle.Model(model)
    model.prepare(metrics=paddle.metric.Accuracy(topk=(1, 5)))
    dicts = model.evaluate(val_dataset, batch_size=256)
    print(f'******************{dicts}**********************')
********************Validating model: fcvit_b12***********************
Eval begin...
step 10/98 - acc_top1: 0.8066 - acc_top5: 0.9533 - 4s/step
step 20/98 - acc_top1: 0.8083 - acc_top5: 0.9528 - 4s/step
step 30/98 - acc_top1: 0.8059 - acc_top5: 0.9526 - 4s/step
step 40/98 - acc_top1: 0.8042 - acc_top5: 0.9529 - 4s/step
step 50/98 - acc_top1: 0.8050 - acc_top5: 0.9525 - 4s/step
step 60/98 - acc_top1: 0.8061 - acc_top5: 0.9529 - 4s/step
step 70/98 - acc_top1: 0.8046 - acc_top5: 0.9524 - 4s/step
step 80/98 - acc_top1: 0.8051 - acc_top5: 0.9525 - 4s/step
step 90/98 - acc_top1: 0.8048 - acc_top5: 0.9525 - 4s/step
step 98/98 - acc_top1: 0.8049 - acc_top5: 0.9526 - 4s/step
Eval samples: 50000
******************{'acc_top1': 0.80492, 'acc_top5': 0.95262}**********************
4.3 Accuracy Comparison
| Model | Official top-1 | Converted top-1 |
|---|---|---|
| fcvit_b48 | 83.6 | 82.60 |
| fcvit_b24 | 82.5 | 81.5 |
| fcvit_b12 | 80.9 | 80.5 |
| fcvit_tiny | 74.9 | 74.5 |
These numbers were measured with a batch size of 512; the official validation appears to use a batch size of 256, so I will re-test with that setting.
5. Summary
- Overall this is an interesting paper with a fairly simple model structure, and according to the results reported in the paper the performance is quite good.
- The authors' design is very bold; deleting the query outright turned out to be a surprisingly effective move.
- It would be worth testing how the model performs on downstream vision tasks in the future.
- Judging from the validation results, the converted model's accuracy is slightly below the official numbers.