使用Attention机制的图像超分辨率模型
本篇博客主要记录下使用基本的Attention机制在经典的深度学习超分辨率模型中的应用
模型选择;SRCNN,EDSR
Attention机制选择:SENet,SAM
SRCNN论文:https://arxiv.org/abs/1501.00092
EDSR论文:https://arxiv.org/abs/1707.02921
SAM论文:https://arxiv.org/abs/1807.06521
SENet论文:https://arxiv.org/abs/1709.01507
1. 什么是Attention机制?
深度学习与视觉注意力机制结合的研究工作,大多数是集中于使用掩码( mask )来形成注意力机制。掩码的原理在于通过另一层新的权重,将图片数据中关键的特征标识出来,通过学习训练,让深度神经网络学到每一张新图片中需要关注的区域,也就形成了注意力。(可以注意到,本质是希望通过学习得到一组可以作用在原图上的权重分布)。
根据这种思想,注意力有两个大的分类:分别是软注意力( soft attention )和强注意力( hard attention )。
所谓的权重分布,就是要得到注意力权重矩阵,然后用这个注意力权重矩阵去和相应的通道域或者空间域特征相乘,达到对通道域或空间域加权的目的
-
强注意力是一个随机的预测过程,更强调动态变化,同时其不可微,比如图像裁剪,直接“关注”某区域。训练往往需要通过增强学习来完成。
-
软注意力的关键在于其是可微的,也就意味着可以计算梯度,利用神经网络的训练方法获得。
目前Attention机制一般分为下面几种类型:
- 空间域
- 通道域
- 层域
- 混合域
- 时间域
本篇博客目前只用到了空间域和通道域
2. 通道域Attention:SENet
通道域的Attention机制,是指对不同维度的feature maps,让网络给予不同的权重,使得有效的feature mapa的权重大,无效或者效果小的feature maps小。要达到这样的目的,就要通过学习的方式自动获取每个feature maps通道的权重。SENet是一种子结构,可以获取feature maps中不用通道的权重大小。
最关键的步骤就是如何得到权重特征分布矩阵
(图片摘自原论文:https://arxiv.org/pdf/1709.01507.pdf)
SENet的核心思想就是获得
1
×
1
×
C
1 \times1\times C
1×1×C 的feature maps权重分布,相当于为feature maps的每个通道分配了一个标量的权重数字,最终和这个
C
×
H
×
W
C \times H\times W
C×H×W的做矩阵的乘积即可。
具体的操作如下:
-
Ftr():标准卷积操作
这一步就是标准卷积操作,没什么好说的。输入: C ′ × H ′ × W ′ C' \times H' \times W' C′×H′×W′,输入: C × H × W C \times H \times W C×H×W -
Fsq():压缩操作
压缩操作,用到全局池化操作(Global average pooling),从结果上来说,输入是 C C C维的feature maps每一维都得到一个标量数字权重。 输入 C × H × W C \times H \times W C×H×W,输入 1 × 1 × C 1 \times1\times C 1×1×C
GAP的意义是对整个网络从结构上做正则化防止过拟合。既要参数少避免全连接带来的过拟合风险,又要能达到全连接一样的转换功能,怎么做呢?直接从feature map的通道上下手,如果我们最终有1000 channels,那么最后一层卷积输出的feature map就只有1000个channels,然后对这个feature map应用全局池化,输出长度为1000的向量,这就相当于剔除了全连接层黑箱子操作的特征,直接赋予了每个channel实际的类别意义。
#全局平均池化代码pytorch实现
import torch
a = torch.randn(3,4,5,6)
GAP = torch.nn.AdaptiveAvgPool2d(1) #自适应池化,指定输出尺寸为1*1
b = GAP(a)
print(b.shape) #[3,4,1,1]
下图红色部分就是通过全局池化进行的压缩操作。
上图参考于图中链接,侵删。
3. Fex():扩展操作
输入:
1
×
1
×
C
1 \times 1 \times C
1×1×C ,输出:
1
×
1
×
C
1 \times 1\times C
1×1×C
论文通过Excitation操作(紫色框标注)来全面捕获通道依赖性(相互之间的重要性),论文提出需要满足两个标准:
-
它必须是灵活的(特别是它必须能够学习通道之间的非线性交互);
-
它必须学习一个非互斥的关系,因为独热激活相反,这里允许强调多个通道。
为了满足这些要求,论文选择采用一个简单的gating mechanism,使用了sigmoid激活函数。
上图公式来源于个人 MathType 工具
- pytorch 完整代码
class SELayer(nn.Module):
def __init__(self, channel, reduction=4):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size() #[-1,64,H,W]
y = self.avg_pool(x).view(b, c)#[-1,64]
y = self.fc(y).view(b, c, 1, 1)#[-1,64,1,1]
return x * y.expand_as(x) #[-1,64,H,W]
3. 在SRCNN模型中添加SENet
SENet作为一种子结构,只需要修改相应的模型文件就可以,比如
# SRCNN model.py
class SRCNN(nn.Module):
def __init__(self,use_attention=None):
super(SRCNN,self).__init__()
self.conv1 = nn.Conv2d(1,64,kernel_size=9,padding=4)
nn.init.kaiming_uniform(self.conv1.weight,mode='fan_in')
self.relu1 = nn.ReLU()
if use_attention:
self.Senet = SELayer(64,reduction=4)
self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
nn.init.kaiming_uniform(self.conv2.weight,mode='fan_in')
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(32,1,kernel_size=5,padding=2)
def forward(self,x):
out = self.conv1(x)
out = self.Senet(out)
out = self.relu1(out)
out = self.conv2(out)
out = self.relu2(out)
out = self.conv3(out)
return out
class SELayer(nn.Module):
def __init__(self, channel, reduction=4):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size() #[-1,64,H,W]
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
那输出的结果是怎么样的呢?我们可以写一段测试Demo查看输入和输出的张量
if __name__ == '__main__':
import torch, torchvision
from torchsummary import summary
model = SRCNN( use_attention=True)
summary(model, (1, 224, 224))
结果如下所示:可以看到从Conv2d-1作为SELayer子结构的输入,输出是SELayer-7的尺寸大小。两者尺寸大小是相同的,也反应了SELayer只是给通道加了不同的权重后输出。
4. 空间域Attention:SAM
采用 CBAM中的 SAM模块
在spatial层面上也需要网络能明白feature map中哪些部分应该有更高的响应。首先,还是使用average pooling和max pooling对输入feature map进行压缩操作,只不过这里的压缩变成了通道层面上的压缩,对输入特征分别在通道维度上做了mean和max操作。最后得到了两个二维的feature,将其按通道维度拼接在一起得到一个通道数为2的feature map,之后使用一个包含单个卷积核的隐藏层对其进行卷积操作,要保证最后得到的feature在spatial维度上与输入的feature map一致,如下图中紫色方框所示。
核心也就是如何得到注意力权重分布矩阵,再与自身的某个于中的特征相乘
上图摘自原论文:https://arxiv.org/abs/1807.06521
好了,基本的介绍不多说了,大家可以参考参考其他的博客,直接放pytorch代码
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
5. 在EDSR中添加SAM机制
# Up_sacle_factor = 4,两次 nn.pixelshuffle(2)
class _Residual_Block(nn.Module):
def __init__(self,channels = 64):
super(_Residual_Block, self).__init__()
self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1,bias=False)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
def forward(self, x):
identity_data = x
output = self.relu(self.conv1(x))
output = self.conv2(output)
output = torch.add(output, identity_data)
return output
class EDSR(nn.Module):
def __init__(self,channels = 64):
super(EDSR, self).__init__()
self.conv_input = nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=3, stride=1, padding=1,ias=False)
self.residual = self.make_layer(_Residual_Block, 4)
self.conv_mid = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
self.upscale4x = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=3, stride=1, padding=1, bias=False),
nn.PixelShuffle(2),
nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=3, stride=1, padding=1, bias=False),
nn.PixelShuffle(2),
)
self.spatial_attention = SpatialAttention()#添加attention模块
self.conv_output = nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=3, stride=1, padding=1,bias=False)
def make_layer(self, block, num_of_layer):
layers = []
for _ in range(num_of_layer):
layers.append( block() )
return nn.Sequential(*layers)
def forward(self, x):
#out = self.sub_mean(x)
out = self.conv_input(x)
residual = out
out = self.conv_mid(self.residual(out))
out= self.spatial_attention(out)*out#使用attention模块
out = torch.add(out, residual)
out = self.upscale4x(out)
out = self.conv_output(out)
return out
6. 添加了Attention后的模型训练
6.1 数据集的选择和处理
数据集的生成可以用两种方法:
方法1: 在DataSets类中对HR图像进行退化生成LR图像
HR数据集:
我使用的是公开数据集 DIV2K,DIV2K:数据集有1000张高清图(2K分辨率),其中800张作为训练,100张作为验证,100张作为测试。
LR数据集:
将输入的HR图像使用Bicubic +高斯随机noise退化为LR数据,和HR构成一个pairs对。
方法2: 预先使用数据集处理函数,生成LR图像,保存至本地文件夹 或者 保存至HDF5(.h5)格式,
一些处理细节:
1 ) 由于不同的模型采用放大倍数的逻辑是不一样的,可以主要分为下面2种:
1.1
Pre-Upsamoling : 网络的输入尺寸就是高分辨率尺寸,代表模型:SRCNN
上图摘自原论文
1.2
Post-Upsampling: 网络输入就是退化的LR,在网络后面部分进行放大尺寸,代表模型:ESPCN
上图摘自原论文
2)模型的输入并不是整幅2K分辨率的图像,这样的话,数据读取会很缓慢,模型训练和测试都会很慢,因为输入的数据维度太大了。解决方法之一:中心裁剪。
即在改写Dataset类获取HR,LR图像后,使用DataLoaderj进行读取的时候,使用pytorch里面的transform.CenterCrop()对图像进行中心裁剪,构成一个pair对。
具体代码:
def input_transform(crop_size,upscale_factor):
return Compose([
CenterCrop(crop_size),
Resize(crop_size // upscale_factor),
ToTensor(), ])
def target_transform(crop_size):
return Compose([
CenterCrop(crop_size),
ToTensor(),])
6.2 用测试集可视化特征图
将基于Attention训练好的模型用测试集进行测试,并查看卷积层的特征图,看看加了注意力机制的网络“注意了”哪些特征?
- 获取模型中的特征子模块
首先得查看模型的有哪些特征子模块(除去池化层和激活函数层)
model_layer= list(model.children())
- 如何获取卷积层中的feature maps?
#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
with torch.no_grad():
for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
if k == index:
return x
- 如何可视化卷积层中的特征图?
# 可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
# feature_map[2].shape out of bounds
feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
feature_map_num = feature_map.shape[0]#返回通道数
print("the num of all features:",feature_map_num)
row_num = np.ceil(np.sqrt(feature_map_num))#8
if(os.path.exists('feature_map_save') == False):
os.mkdir('feature_map_save')
plt.title("feature maps visualize")
for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出
plt.subplot(row_num, row_num, index)
plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
plt.axis('off')
plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])
plt.show()
- 完整代码和特征图可视化
def get_image_info(image_dir):
image_info = Image.open(image_dir).convert('RGB')
y,cb,cr = image_info.split()
# 数据预处理方法
image_transform = transforms.Compose([
#transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
#transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_info = image_transform(y)#torch.Size([3, 224, 224])
image_info = image_info.unsqueeze(0)#torch.Size([1, 3, 224, 224])因为model的输入要求是4维,所以变成4维
return image_info#变成tensor数据
#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
with torch.no_grad():
for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
if k == index:
return x
# 可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
# feature_map[2].shape out of bounds
feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
feature_map_num = feature_map.shape[0]#返回通道数
print("the num of all features:",feature_map_num)
row_num = np.ceil(np.sqrt(feature_map_num))#8
if(os.path.exists('feature_map_save') == False):
os.mkdir('feature_map_save')
plt.title("feature maps visualize")
for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出
plt.subplot(row_num, row_num, index)
plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
plt.axis('off')
plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])
plt.show()
if __name__ == '__main__':
image_dir = "./bird.bmp"
# 定义提取第几层的feature map
k = 0
image_info = get_image_info(image_dir)
model = ESPCN(upscale_factor=4)
model.load_state_dict('...your own model state_dict path...',map_location='cpu'))
model_layer= list(model.children() )
print("model childen part :",model_layer)
feature_map = get_k_layer_feature_map(model_layer, 1, image_info)
show_feature_map(feature_map)
网络中64个特征图的可视化:
图片来源自作者实验结果
网络中32个特征图的可视化
图片来源自作者实验结果