import torch
from torch import nn
from torch.nn import init
# 通道注意力+空间注意力的改进版
# 方法出处 2018 BMCV 《BAM: Bottleneck Attention Module》
# 展平层
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
# 将输入的x,假如它是[B,C,H,W]维度的特征图
# 其中B代表批大小
# C代表通道数
# H,W代表高和宽
# 展平层将特征图展平为[B,C*H*W]
# 其中每一个是一个行向量
# 方便输入到下一个全连接层中
def forward(self, x):
return x.view(x.shape[0], -1)
# 通道注意力
class ChannelAttention(nn.Module):
# 网络层的初始化
def __init__(self, channel, reduction=16, num_layers=3):
super(ChannelAttention, self).__init__()
# 自适应平均池化
# 将特征图的维度,假设是[B,C,H,W]
# 平均池化到[B,C,1,1]
# 相当于将切片矩阵H,W
# 先按行相加
# 在按列相加
self.avgpool = nn.AdaptiveAvgPool2d(1)
# 通道注意力中的多个全连接层的通道参数
# gate_channels是个列表
# 其中存放如下的数字
# [channel,channel//reduction,channel//reduction,...(这里由num_layers控制,就是你想有多少个中间层)
# channel]最后没有改变输入的通道数
# 因为最后要按照通道数乘以通道权重
gate_channels = [channel]
gate_channels += [channel // reduction] * num_layers
gate_channels += [channel]
# 搭建全连接层计算通道注意力
# Sequential以序列化的形式存储网络层
self.ca = nn.Sequential()
# 首先加入一个展平层,方便输入到后面的全连接层中
self.ca.add_module('flatten', Flatten())
# 循环,依次加入全连接层组合
# 这个全连接组合包括
# nn.Linear(channel,channel//reduction)或者
# nn.Linear(channel//reduction,channel//reduction)形式的隐藏层
# 紧接着全连接层之后的正则化层nn.BatchNorm1d
# 因为输出的是向量所以用1d的正则化层
# 然后是激活层
for i in range(len(gate_channels) - 2):
self.ca.add_module('fc%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
self.ca.add_module('relu%d' % i, nn.ReLU())
# 最后将通道数还原的全连接层
self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))
# 前向传递建立计算图
def forward(self, x):
# 首先进行池化
res = self.avgpool(x)
# 然后通过全连接层
# 计算出不同通道之间的相似性信息
res = self.ca(res)
# 改变通道注意力结果到与输入的特征图统一维度
# 方便后面的相乘运算
res = res.unsqueeze(-1).unsqueeze(-1).expand_as(x)
return res
# 空间注意力
class SpatialAttention(nn.Module):
def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):
super(SpatialAttention, self).__init__()
# 空间注意力中中间的卷积层
self.sa = nn.Sequential()
# 首先是1*1的卷积层
# 1*1的卷积层不改变卷积层的输入的宽高
# 只是改变输入的通道数
self.sa.add_module('conv_reduce1',
nn.Conv2d(kernel_size=1, in_channels=channel, out_channels=channel // reduction))
self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))
self.sa.add_module('relu_reduce1', nn.ReLU())
# 然后是3个3*3的卷积层
# 这里指定了使用空洞卷积
# 普通的卷积,卷积核之间的元素是相邻的
# 在空洞卷积中,卷积核之间的元素会间隔指定的距离
# 这个距离由我们自己指定
# 因为元素之间存在空隙
# 所以叫做空洞卷积
# 普通卷积输出宽高的计算公式为
# 输出的高=(输入的高+2*padding-卷积核大小)/卷积步幅+1
# 带入参数可知这些3*3普通的卷积核没有改变输入的宽高
# 但是这里的卷积层指定了空洞卷积
# 计算公式为
# 输出的高=(输入的高+2*padding-空洞距离(卷积核大小-1)-1)/卷积步幅+1
# 带入参数
# 输出的高=输入的高-2
# 3次之后宽,高就变成1*1
for i in range(num_layers):
self.sa.add_module('conv_%d' % i, nn.Conv2d(kernel_size=3, in_channels=channel // reduction,
out_channels=channel // reduction, padding=1, dilation=dia_val))
self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))
self.sa.add_module('relu_%d' % i, nn.ReLU())
# 最后是1*1的卷积层
# 输出通道是1
# 最后空间注意力维度是[B,1,1,1]
self.sa.add_module('last_conv', nn.Conv2d(channel // reduction, 1, kernel_size=1))‘’
#前向传递,建立计算图
def forward(self, x):
#计算空间注意力
res = self.sa(x)
#同理要转换成和输入相同的维度
#方便和通道注意力的结果相加
res = res.expand_as(x)
return res
# BAM整体模型
class BAMBlock(nn.Module):
#初始化层
def __init__(self, channel=512, reduction=16, dia_val=2):
super().__init__()
#计算通道注意力
self.ca = ChannelAttention(channel=channel, reduction=reduction)
#计算空间注意力
self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)
#激活层
self.sigmoid = nn.Sigmoid()
#初始化网络层权重
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
#前向传递
def forward(self, x):
b, c, _, _ = x.size()
#通道注意力结果
sa_out = self.sa(x)
#空间注意力结果
ca_out = self.ca(x)
#激活
weight = self.sigmoid(sa_out + ca_out)
#这里有个残差连接x+weight*x
out = (1 + weight) * x
return out
if __name__ == '__main__':
# 可以将input看作一个特征图
input = torch.randn(50, 512, 7, 7)
# 捕获不同特征图不同通道之间的关系
bam = BAMBlock(channel=512, reduction=16, dia_val=2)
output = bam(input)
print(output.shape)
2018 BMCV 《BAM: Bottleneck Attention Module》Pytorch实现
最新推荐文章于 2024-03-07 09:49:45 发布