HAT论文详解:Activating More Pixels in Image Super-Resolution Transformer

code:https://github.com/XPixelGroup/HAT
paper: https://arxiv.org/abs/2309.05239

1. 概述

本文是对Swinir的改进,目前很多图像超分Benchmark的SOTA。相对于SwinIR的改进主要有三个地方:1. 引入Channel Attention,以获得更好的全局能力;2. 提出了overlapping cross-attention模块,来进行跨window的信息交互;3. 提出一个预训练策略。

2. 引言

2.1 阐明swinir存在的问题

  • SwinIR在SR任务上取得了突破,然而为什么Transformer-based方法要比CNN-based方法好,却很难说清楚。一个直观的解释是Transformer方法可以受益于self-attention机制,并能够利用远距离信息。作者通过LAM分析发现,与RCAN相比,SwinIR并没有利用更大range的信息,这是反直觉的。同时可以说明SwinIR具备比CNN强的映射能力,可以利用更少的信息取得更好的效果。但是由于利用的pixel的范围有限,SwinIR可能会restore出错误的纹理。如下图所示。
    在这里插入图片描述

  • 尽管平均性能优于RCAN,但是有一些结果也比RCAN差

  • 这说明Swin transformer建模局部信息的能力很强,但是探索的信息范围需要扩大

  • 在SwinIR的特征图上发现了block artifacts,这是由于窗口划分造成的,这说明移动窗口机制并不能有效的建立跨窗口的交互。。
    在这里插入图片描述

2.2 本文的贡献:

  • 设计了一个Hybrid Attention,结合了channel attention, self-attention和overlapping cross-attention;
    channel attention:具备很好地获取全局信息的能力
    self-attention: 强大的表达能力(representative ability)

  • 提出一个预训练策略
    因为transformer不具备cnn的归纳偏置,所以需要大规模数据进行预训练,才能解锁潜力。

3. 方法介绍

HAT结构图
在这里插入图片描述
上面两张图分别是HAT和SwinIR的整体结构图,可以看出HAT延续了SwinIR的基本结构,将RSTB升级成RHAG,内部的STL也对应升级成HAB,并且在每个Block中加入了一个OCAB。下面具体来看这两处改动。

  • 向(STL)Swin Transformer Layer中加入了Channel Attention,也就是将(S)W-MSA与CAB的结果叠加起来。

    CAB的代码实现:

    class ChannelAttention(nn.Module):
        """Channel attention used in RCAN.
        Args:
            num_feat (int): Channel number of intermediate features.
            squeeze_factor (int): Channel squeeze factor. Default: 16.
        """
    
        def __init__(self, num_feat, squeeze_factor=16):
            super(ChannelAttention, self).__init__()
            self.attention = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
                nn.Sigmoid())
    
        def forward(self, x):
            y = self.attention(x)
            return x * y
    
    class CAB(nn.Module):
    
        def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
            super(CAB, self).__init__()
    
            self.cab = nn.Sequential(
                nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
                nn.GELU(),
                nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
                ChannelAttention(num_feat, squeeze_factor)
                )
    
        def forward(self, x):
            return self.cab(x)
    
  • 在每一个RHAG的最后引入一个Overlapping Cross-Attention Block (OCAB),直接建立跨窗口的连接,同时增强窗口自注意力的表达能力。实现方式仍是基于W-MSA,只是在窗口划分时,Q的窗口是正常的无overlap的窗口,窗口大小为M * M,而K和V的窗口大小是M0 * M0, M0 =(1+gamma) * M, gamma是用于控制重叠大小的参数。虽然窗口的大小不一样,但是窗口的数量是相同的,一一对应的。

      Q shape: (nums_of_windows, M*M, emb_dims)
      
      K shape: (nums_of_windows, M0*M0, emb_dims)
      
      V shape: (nums_of_windows, M0*M0, emb_dims)
      
      QK.T shape: (nums_of_windows, M*M, M0*M0)
      
      因此得到的结果仍是 (nums_of_windows, M*M, emb_dims),但其过程中获取了跨窗口的信息,因为OCA的key和value是从更大的区域中计算得到的,因此更多有用的信息将被query查询到;
    
    • 预训练

      使用Imagenet进行X4预训练,再在DF2K上进行finetune, 发现很有效,预训练的效果取决于数据的量级和多样性;同时,作者指出充分的iteration和合适的小学习率对于预训练来说非常重要;

  • 33
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值