论文地址:https://arxiv.org/pdf/1711.07971.pdf
内容简介
- 主要讨论下提出来的non-local block
non-local block
直译下来叫“非局部块”。长这样:
既然叫“非局部”那就看看什么叫做“局部”。所谓局部指的就是感受野,比如卷积核的大小。一般来说我们使用的都是如3×3的小尺寸卷积核,所以只能提取小范围的信息。
而为了克服这种局面,获取更大范围(比如两个较远像素)之间蕴含的信息,就可以用文章中提出的non local结构。
具体的流程是这样的:
- 对于输入的特征图(假设是1024通道),用1个1×1卷积将其降维至512通道(更通用的讲,降一半),以供备用,记为 g g g
- 同样,分别用2个1×1卷积处理输入特征图,记为 θ \theta θ和 ϕ \phi ϕ
- 将 θ \theta θ和 ϕ \phi ϕ做矩阵乘,得到的结果类似于协方差矩阵,包含自相关性信息
- 对该结果做softmax,获得的权重就相当于self attention系数
- 与g相乘,升维,再与原始输入x残差相加,得到最后结果
代码长这样:
class NonLocal(nn.Module):
def __init__(self, channel):
super(NonLocalBlock, self).__init__()
self.inter_channel = channel // 2
self.conv_phi = nn.Conv2d(channel, self.inter_channel, 1, 1, 0, False)
self.conv_theta = nn.Conv2d(channel, self.inter_channel, 1, 1, 0, False)
self.conv_g = nn.Conv2d(channel, self.inter_channel, 1, 1, 0, False)
self.softmax = nn.Softmax(dim=1)
self.conv_mask = nn.Conv2d(self.inter_channel, channel, 1, 1, 0, False)
def forward(self, x):
b, c, h, w = x.size()
x_phi = self.conv_phi(x).view(b, c, -1)
x_theta = self.conv_theta(x).view(b, c, -1).permute(0, 2, 1).contiguous()
x_g = self.conv_g(x).view(b, c, -1).permute(0, 2, 1).contiguous()
mul_theta_phi = torch.matmul(x_theta, x_phi)
mul_theta_phi = self.softmax(mul_theta_phi)
mul_theta_phi_g = torch.matmul(mul_theta_phi, x_g)
mul_theta_phi_g = mul_theta_phi_g.permute(0, 2, 1).contiguous().view(b, self.inter_channel, h, w)
mask = self.conv_mask(mul_theta_phi_g)
out = mask + x
return out
参考
https://mp.weixin.qq.com/s/FFEKWFgScdBZ8snZQLdfFg