基于深度学习使用Attention机制的图像超分辨率模型

使用Attention机制的图像超分辨率模型

本篇博客主要记录下使用基本的Attention机制在经典的深度学习超分辨率模型中的应用
 
模型选择;SRCNNEDSR
Attention机制选择:SENetSAM
 

SRCNN论文:https://arxiv.org/abs/1501.00092
EDSR论文:https://arxiv.org/abs/1707.02921
SAM论文:https://arxiv.org/abs/1807.06521
SENet论文:https://arxiv.org/abs/1709.01507

 

1. 什么是Attention机制?

深度学习与视觉注意力机制结合的研究工作,大多数是集中于使用掩码( mask )来形成注意力机制。掩码的原理在于通过另一层新的权重,将图片数据中关键的特征标识出来,通过学习训练,让深度神经网络学到每一张新图片中需要关注的区域,也就形成了注意力。(可以注意到,本质是希望通过学习得到一组可以作用在原图上的权重分布)。
根据这种思想,注意力有两个大的分类:分别是软注意力( soft attention )和强注意力( hard attention )。

所谓的权重分布,就是要得到注意力权重矩阵,然后用这个注意力权重矩阵去和相应的通道域或者空间域特征相乘,达到对通道域或空间域加权的目的

  1. 强注意力是一个随机的预测过程,更强调动态变化,同时其不可微,比如图像裁剪,直接“关注”某区域。训练往往需要通过增强学习来完成。

  2. 软注意力的关键在于其是可微的,也就意味着可以计算梯度,利用神经网络的训练方法获得。

目前Attention机制一般分为下面几种类型:

  • 空间域
  • 通道域
  • 层域
  • 混合域
  • 时间域

本篇博客目前只用到了空间域通道域
 

2. 通道域Attention:SENet

通道域的Attention机制,是指对不同维度的feature maps,让网络给予不同的权重,使得有效的feature mapa的权重大,无效或者效果小的feature maps小。要达到这样的目的,就要通过学习的方式自动获取每个feature maps通道的权重。SENet是一种子结构,可以获取feature maps中不用通道的权重大小。
 
最关键的步骤就是如何得到权重特征分布矩阵
图片来自https://arxiv.org/pdf/1709.01507.pdf

(图片摘自原论文:https://arxiv.org/pdf/1709.01507.pdf)

 
SENet的核心思想就是获得 1 × 1 × C 1 \times1\times C 1×1×C 的feature maps权重分布,相当于为feature maps的每个通道分配了一个标量的权重数字,最终和这个 C × H × W C \times H\times W C×H×W的做矩阵的乘积即可。
 
具体的操作如下:
 

  1. Ftr()标准卷积操作
    这一步就是标准卷积操作,没什么好说的。输入: C ′ × H ′ × W ′ C' \times H' \times W' C×H×W,输入: C × H × W C \times H \times W C×H×W

     

  2. Fsq()压缩操作

    压缩操作,用到全局池化操作(Global average pooling),从结果上来说,输入是 C C C维的feature maps每一维都得到一个标量数字权重。 输入 C × H × W C \times H \times W C×H×W,输入 1 × 1 × C 1 \times1\times C 1×1×C

    GAP的意义是对整个网络从结构上做正则化防止过拟合。既要参数少避免全连接带来的过拟合风险,又要能达到全连接一样的转换功能,怎么做呢?直接从feature map的通道上下手,如果我们最终有1000 channels,那么最后一层卷积输出的feature map就只有1000个channels,然后对这个feature map应用全局池化,输出长度为1000的向量,这就相当于剔除了全连接层黑箱子操作的特征,直接赋予了每个channel实际的类别意义。

#全局平均池化代码pytorch实现
import torch
a = torch.randn(3,4,5,6)
GAP = torch.nn.AdaptiveAvgPool2d(1)	#自适应池化,指定输出尺寸为1*1
b = GAP(a)	
print(b.shape)	#[3,4,1,1]

下图红色部分就是通过全局池化进行的压缩操作。

图 2

上图参考于图中链接,侵删。

 
3. Fex()扩展操作

输入: 1 × 1 × C 1 \times 1 \times C 1×1×C ,输出: 1 × 1 × C 1 \times 1\times C 1×1×C
论文通过Excitation操作(紫色框标注)来全面捕获通道依赖性(相互之间的重要性),论文提出需要满足两个标准:

  • 它必须是灵活的(特别是它必须能够学习通道之间的非线性交互);

  • 它必须学习一个非互斥的关系,因为独热激活相反,这里允许强调多个通道。
    为了满足这些要求,论文选择采用一个简单的gating mechanism,使用了sigmoid激活函数。

在这里插入图片描述

上图公式来源于个人 MathType 工具

  1. pytorch 完整代码
class SELayer(nn.Module):
  def __init__(self, channel, reduction=4):
      super(SELayer, self).__init__()
      self.avg_pool = nn.AdaptiveAvgPool2d(1)
      self.fc = nn.Sequential(
          nn.Linear(channel, channel // reduction, bias=False),
          nn.ReLU(inplace=True),
          nn.Linear(channel // reduction, channel, bias=False),
          nn.Sigmoid()
      )

  def forward(self, x):
      b, c, _, _ = x.size() #[-1,64,H,W]
      y = self.avg_pool(x).view(b, c)#[-1,64]
      y = self.fc(y).view(b, c, 1, 1)#[-1,64,1,1]

      return x * y.expand_as(x) #[-1,64,H,W]

 

3. 在SRCNN模型中添加SENet

 
SENet作为一种子结构,只需要修改相应的模型文件就可以,比如

#  SRCNN model.py
class SRCNN(nn.Module):
   def __init__(self,use_attention=None):
       super(SRCNN,self).__init__()
       self.conv1 = nn.Conv2d(1,64,kernel_size=9,padding=4)
       nn.init.kaiming_uniform(self.conv1.weight,mode='fan_in')
       self.relu1 = nn.ReLU()

       if use_attention:
           self.Senet = SELayer(64,reduction=4)

       self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
       nn.init.kaiming_uniform(self.conv2.weight,mode='fan_in')
       self.relu2 = nn.ReLU()
       self.conv3 = nn.Conv2d(32,1,kernel_size=5,padding=2)

   def forward(self,x):
       out = self.conv1(x)
       out = self.Senet(out)
       out = self.relu1(out)
       out = self.conv2(out)
       out = self.relu2(out)
       out = self.conv3(out)
       return out


class SELayer(nn.Module):
   def __init__(self, channel, reduction=4):
       super(SELayer, self).__init__()
       self.avg_pool = nn.AdaptiveAvgPool2d(1)
       self.fc = nn.Sequential(
           nn.Linear(channel, channel // reduction, bias=False),
           nn.ReLU(inplace=True),
           nn.Linear(channel // reduction, channel, bias=False),
           nn.Sigmoid()
       )
       
   def forward(self, x):
       b, c, _, _ = x.size() #[-1,64,H,W]
       y = self.avg_pool(x).view(b, c)
       y = self.fc(y).view(b, c, 1, 1)
       return x * y.expand_as(x) 

 
那输出的结果是怎么样的呢?我们可以写一段测试Demo查看输入和输出的张量

if __name__ == '__main__':
    import torch, torchvision
    from torchsummary import summary
    model = SRCNN( use_attention=True)
    summary(model, (1, 224, 224))

 
结果如下所示:可以看到从Conv2d-1作为SELayer子结构的输入,输出是SELayer-7的尺寸大小。两者尺寸大小是相同的,也反应了SELayer只是给通道加了不同的权重后输出。

在这里插入图片描述

 

4. 空间域Attention:SAM

 
采用 CBAM中的 SAM模块

在spatial层面上也需要网络能明白feature map中哪些部分应该有更高的响应。首先,还是使用average pooling和max pooling对输入feature map进行压缩操作,只不过这里的压缩变成了通道层面上的压缩,对输入特征分别在通道维度上做了mean和max操作。最后得到了两个二维的feature,将其按通道维度拼接在一起得到一个通道数为2的feature map,之后使用一个包含单个卷积核的隐藏层对其进行卷积操作,要保证最后得到的feature在spatial维度上与输入的feature map一致,如下图中紫色方框所示。

核心也就是如何得到注意力权重分布矩阵,再与自身的某个于中的特征相乘

在这里插入图片描述

上图摘自原论文:https://arxiv.org/abs/1807.06521

 
好了,基本的介绍不多说了,大家可以参考参考其他的博客,直接放pytorch代码


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):

        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
    return self.sigmoid(x)
    

 

5. 在EDSR中添加SAM机制

# Up_sacle_factor = 4,两次 nn.pixelshuffle(2)

class _Residual_Block(nn.Module):
    def __init__(self,channels = 64):
        super(_Residual_Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1,bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
    def forward(self, x):
        identity_data = x
        output = self.relu(self.conv1(x))
        output = self.conv2(output)
        output = torch.add(output, identity_data)
        return output


class EDSR(nn.Module):
    def __init__(self,channels = 64):
        super(EDSR, self).__init__()
        self.conv_input = nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=3, stride=1, padding=1,ias=False)

        self.residual = self.make_layer(_Residual_Block, 4)
        
        self.conv_mid = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.upscale4x = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
            nn.Conv2d(in_channels=channels, out_channels=channels * 4, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2),
        )
        
        self.spatial_attention = SpatialAttention()#添加attention模块
        
        self.conv_output = nn.Conv2d(in_channels=channels, out_channels=1, kernel_size=3, stride=1, padding=1,bias=False)
                  
    def make_layer(self, block, num_of_layer):
        layers = []
        for _ in range(num_of_layer):
            layers.append( block() )
        return nn.Sequential(*layers)

    def forward(self, x):
        #out = self.sub_mean(x)
        out = self.conv_input(x)
        residual = out
        out = self.conv_mid(self.residual(out))
        
        out= self.spatial_attention(out)*out#使用attention模块
        
        out = torch.add(out, residual)
        out = self.upscale4x(out)
        out = self.conv_output(out)
        return out

 

6. 添加了Attention后的模型训练

 

6.1 数据集的选择和处理

数据集的生成可以用两种方法:

方法1: 在DataSets类中对HR图像进行退化生成LR图像
 

HR数据集

我使用的是公开数据集 DIV2K,DIV2K:数据集有1000张高清图(2K分辨率),其中800张作为训练,100张作为验证,100张作为测试。
 

LR数据集

将输入的HR图像使用Bicubic +高斯随机noise退化为LR数据,和HR构成一个pairs对。

方法2: 预先使用数据集处理函数,生成LR图像,保存至本地文件夹 或者 保存至HDF5(.h5)格式,
 

一些处理细节:

1 ) 由于不同的模型采用放大倍数的逻辑是不一样的,可以主要分为下面2种:

1.1 Pre-Upsamoling : 网络的输入尺寸就是高分辨率尺寸,代表模型:SRCNN
在这里插入图片描述

上图摘自原论文

1.2 Post-Upsampling: 网络输入就是退化的LR,在网络后面部分进行放大尺寸,代表模型:ESPCN

在这里插入图片描述

上图摘自原论文

2)模型的输入并不是整幅2K分辨率的图像,这样的话,数据读取会很缓慢,模型训练和测试都会很慢,因为输入的数据维度太大了。解决方法之一:中心裁剪。

即在改写Dataset类获取HR,LR图像后,使用DataLoaderj进行读取的时候,使用pytorch里面的transform.CenterCrop()对图像进行中心裁剪,构成一个pair对。

具体代码:

def input_transform(crop_size,upscale_factor):
    return Compose([
        CenterCrop(crop_size),
        Resize(crop_size // upscale_factor),
        ToTensor(), ])
        
def target_transform(crop_size):
    return Compose([
        CenterCrop(crop_size),
        ToTensor(),])

 

6.2 用测试集可视化特征图

 
将基于Attention训练好的模型用测试集进行测试,并查看卷积层的特征图,看看加了注意力机制的网络“注意了”哪些特征?

  1. 获取模型中的特征子模块

首先得查看模型的有哪些特征子模块(除去池化层和激活函数层)

   model_layer= list(model.children())
  1. 如何获取卷积层中的feature maps?
#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
    with torch.no_grad():
        for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
            x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
            if k == index:
                return x

  1. 如何可视化卷积层中的特征图?
#  可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
                                                                        # feature_map[2].shape     out of bounds
   feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
   feature_map_num = feature_map.shape[0]#返回通道数
   print("the num of all features:",feature_map_num)
   row_num = np.ceil(np.sqrt(feature_map_num))#8

   if(os.path.exists('feature_map_save') == False):
       os.mkdir('feature_map_save')

   plt.title("feature maps visualize")
   for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出

       plt.subplot(row_num, row_num, index)
       plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
       plt.axis('off')
       plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])

   plt.show()

  1. 完整代码和特征图可视化

def get_image_info(image_dir):
   image_info = Image.open(image_dir).convert('RGB')
   y,cb,cr = image_info.split()
   # 数据预处理方法
   image_transform = transforms.Compose([
       #transforms.Resize(256),
       transforms.CenterCrop(224),
       transforms.ToTensor(),
       #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
   ])
   image_info = image_transform(y)#torch.Size([3, 224, 224])
   image_info = image_info.unsqueeze(0)#torch.Size([1, 3, 224, 224])因为model的输入要求是4维,所以变成4维
   return image_info#变成tensor数据

#2. 获取第k层的特征图
'''
k:定义提取第几层的feature map
x:图片的tensor
model_layer:特征层
'''
def get_k_layer_feature_map(model_layer, k, x):
   with torch.no_grad():
       for index, layer in enumerate(model_layer):#model的第一个Sequential()是有多层,所以遍历
           x = layer(x)#torch.Size([1, 64, 55, 55])生成了64个通道
           if k == index:
               return x

#  可视化特征图
def show_feature_map(feature_map):#feature_map=torch.Size([1, 64, 55, 55]),feature_map[0].shape=torch.Size([64, 55, 55])
                                                                        # feature_map[2].shape     out of bounds
   feature_map = feature_map.squeeze(0)#压缩成torch.Size([64, 55, 55])
   feature_map_num = feature_map.shape[0]#返回通道数
   print("the num of all features:",feature_map_num)
   row_num = np.ceil(np.sqrt(feature_map_num))#8

   if(os.path.exists('feature_map_save') == False):
       os.mkdir('feature_map_save')

   plt.title("feature maps visualize")
   for index in range(1, feature_map_num + 1):#通过遍历的方式,将64个通道的tensor拿出

       plt.subplot(row_num, row_num, index)
       plt.imshow(feature_map[index - 1])#feature_map[0].shape=torch.Size([55, 55])
       plt.axis('off')
       plt.imsave( 'feature_map_save//'+str(index) + ".png", feature_map[index - 1])
   plt.show()

if __name__ ==  '__main__':
   image_dir = "./bird.bmp"
   # 定义提取第几层的feature map
   k = 0
   image_info = get_image_info(image_dir)
   model = ESPCN(upscale_factor=4)
   model.load_state_dict('...your own model state_dict path...',map_location='cpu'))
   model_layer= list(model.children() )
   print("model childen part :",model_layer)
   feature_map = get_k_layer_feature_map(model_layer, 1, image_info)
   show_feature_map(feature_map)

网络中64个特征图的可视化:
在这里插入图片描述

图片来源自作者实验结果

 
网络中32个特征图的可视化
在这里插入图片描述

图片来源自作者实验结果

  • 8
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值