传送地址:https://arxiv.org/pdf/1904.11492.pdf
主要思想
如何获取全文信息(global context) 或者叫长距离依赖性。从视觉角度直观理解,比如我们要认出一个人的话可能需要看整个脸部才能认的出来,单给你脸部的一块皮肤,鼻子相对比较难判断。极端一点来说,例如你的朋友离你很远的情况下,需要看身高,衣服来获取整体的信息才能做出判断。
那么对于图像卷积来说,如何获取全文信息呢?这对大物体的识别和分类是非常有帮助的。卷积网络中典型的就是通过卷积层的堆积 (不太懂的可以看这个两个33的卷积为什么能替代55),这种解决方案存在的问题是计算量效率低,更加难优化模型。那么为了缓解这种问题提出了自注意机制(self-attention mechanism),了解NLP的会比较熟悉这个名字。
这里用生活的语言来解释的话可以怎么理解,我们在识别一种物体的话首先是宏观的来看,然后聚焦中某一个点提取信息来识别。比如我们拿女明星来举例子,猛地一看长的都一样,仔细看他的鼻子,眼睛等。就能够分辨。
例子仅供参考,,,,比如像我这种脸盲的
实验结果
- 自注意力机制
对应的block主要有三种: SE,NL和GC,NL不是特别熟悉,有大佬了解的话还望留言赐教! - 结构图
- block实验对比图
自注意力机制的block分为两个模块,一个是获取全文信息,一个是进行信息的转化(transform)。从结构图上来看NL和GC的区别主要在transform模块,看似差别不是很大,但是在实验结果上GC比NL模块高出1.4个点。并且论文中提出通过可视化特征图,NL模块在注意点在不同的位置特征图基本是一样的,可以理解为没有注意到关键点上面??
从实现角度来讲,这三个block不同在哪里呢?transform相对比较好理解,就是在全文信息的基础上转换为0-1之间的权重值。简单直接的方式就是就是在全局求个最大值,平均值就可以在一定程度上代表全文的信息啦,恭喜你,答对了SE block就是怎么做的。
如果全文的信息相互的之间有交互,是不是效果就更好呢?那么GC block就是在该基础上实现的
# 汇集全文的信息 对应的像素点进行匹配,整个图像的像素点全部相加
context = torch.matmul(input_x, context_mask)
大家调试下面的代码,发现context_mask是0-1之间的值,我们可以理解成它代表这图像中每个像素点的重要性
matmul是点积的操作,大家可以用二维的tensor实验。点积的相乘再相加完成了信息的交互
核心代码
# -*- coding: utf-8 -*-
# @Time : 2019/5/29 15:30
# @Author : ljf
from __future__ import absolute_import
import torch
from torch import nn
from mmcv.cnn import constant_init, kaiming_init
import math
def last_zero_init(m):
if isinstance(m, nn.Sequential):
# nn.init.constant(m[-1].weight,val=0)
constant_init(m[-1], val=0)
m[-1].inited = True
else:
constant_init(m, val=0)
m.inited = True
class ContextBlock2d(nn.Module):
def __init__(self, inplanes, planes, pool, fusions):
super(ContextBlock2d, self).__init__()
assert pool in ['avg', 'att']
assert all([f in ['channel_add', 'channel_mul'] for f in fusions])
assert len(fusions) > 0, 'at least one fusion should be used'
self.inplanes = inplanes
self.planes = planes
self.pool = pool
self.fusions = fusions
if 'att' in pool:
self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2d(1)
if 'channel_add' in fusions:
self.channel_add_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True),
nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
)
else:
self.channel_add_conv = None
if 'channel_mul' in fusions:
self.channel_mul_conv = nn.Sequential(
nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(inplace=True),
nn.Conv2d(self.planes, self.inplanes, kernel_size=1)
)
else:
self.channel_mul_conv = None
self.reset_parameters()
def reset_parameters(self):
if self.pool == 'att':
kaiming_init(self.conv_mask, mode='fan_in')
self.conv_mask.inited = True
if self.channel_add_conv is not None:
last_zero_init(self.channel_add_conv)
if self.channel_mul_conv is not None:
last_zero_init(self.channel_mul_conv)
def spatial_pool(self, x):
batch, channel, height, width = x.size()
if self.pool == 'att':
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
context_mask = self.conv_mask(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = self.softmax(context_mask)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(3)
# [N, 1, C, 1]
# 汇集全文的信息 对应的像素点进行匹配,整个图像的像素点全部相加
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
if self.channel_mul_conv is not None:
# [N, C, 1, 1]
channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
out = x * channel_mul_term
else:
out = x
if self.channel_add_conv is not None:
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
return out
if __name__ == "__main__":
inputs = torch.randn(1,16,300,300)
block = ContextBlock2d(16,4,"att",["channel_add"])
out = block(inputs)
print(out.size())