本文的目的主要在于改进自注意力计算的高昂计算成本,基于局部自注意力的形式进行了扩展,实现了一种更加高效的全局注意力形式,而免去了Swin那样的划窗操作(划窗操作需要进行padding和mask,以及划窗仅仅会覆盖不同局部区域的部分内容)或者其他更为复杂的例如token unfolding和rolling操作,甚至是对于key和value的额外计算。GCViT模型的整体结构如下图所示。
本文仍然基于windows-based attention的形式,从而保证了相对于图像大小线性的放缩关系;在类似于Swin中的local attention的基础上,作者们构建了一种新的global attention形式,来实现横跨不同local window之间的图像patch上的信息交流。
global attention的核心是对原本local attention的query的改进。其直接使用从原始图像特征上利用CNN结构(即文中的Global Query Generator)提取缩小到窗口区域对应尺寸嵌入,使其与image token窗口中的local key和local value进行计算,从而允许捕获跨区域交互的长距离信息。因此local attention与global attention的唯一差异在于query的来源,前者来自模块输入,而后者则来自于stage初生成,内部各个module共享的global token。
整个stage的流程如下:
特征送如一个stage中,即GCViTLayer
首先根据送入的特征计算global query。这里的计算中,会将输入特征下采样到和后续计算的window-based attention的window有着相同的尺寸
特征和global query送入多个交替安置的local attention(global query不起作用)和global attention模块中
在attention模块中,使用pre-norm形式,attention的计算需要先对输入特征划分窗口,对各个窗口执行local attention。计算输出结果后再将形状恢复,送入后续的基于FC的MLP中。其中local attention的q来自于输入特征,计算过程是正常的;global attention的q来自global query,在执行qk计算之前,会先对q执行repeat操作,使其形状与k和v保持一致,之后执行local attention
stage结尾后按需求进行下采样,下采样是借助于跨步卷积实现
GCViT变体参数配置如下:
在YOLOv5项目中添加模型作为Backbone使用的教程:
(1)将YOLOv5项目的models/yolo.py修改parse_model函数以及BaseModel的_forward_once函数
(2)在models/backbone(新建)文件下新建GCVIT.py,添加如下的代码:
(3)在models/yolo.py导入模型并在parse_model函数中修改如下(先导入文件):
(4)在model下面新建配置文件:yolov5_gcvit.yaml
(5)运行验证:在models/yolo.py文件指定–cfg参数为新建的yolov5_gcvit.yaml