Image Super-Resolution with Non-Local Sparse Attention

具有非局部稀疏注意的图像超分辨率

参考:Image Super-Resolution with Non-Local Sparse Attention_非局部稀疏注意力_C站某不知名用户的博客-CSDN博客

非局部(NL)操作和稀疏表示对于单图像超分辨率(SISR)都至关重要。在本文中,我们研究了它们的组合,并提出了一种新的具有动态稀疏注意模式的非局部稀疏注意(NLSA)。NLSA旨在保留NL操作的长距离建模能力,同时享受稀疏表示的鲁棒性和高效性。具体而言,NLSA通过将输入空间划分为相关特征的哈希桶的球形位置敏感哈希(LSH)来纠正非局部注意力。对于每个查询信号,NLSA都会为其分配一个桶,并且只计算桶内的注意力。由此产生的稀疏注意力防止了模型关注有噪声和信息量较少的位置,同时将计算成本从空间大小的二次方降低到渐近线性。大量实验验证了NLSA的有效性和效率。通过几个非局部稀疏注意模块,我们的架构(称为非局部稀疏网络(NLSN))在数量和质量上达到了SISR的最先进性能。

单图像超分辨率(SISR)近年来引起了极大的关注。一般来说,SISR的目标是在给定低分辨率图像的情况下重建高分辨率图像。由于SISR任务的不适定性,提出了各种图像先验[12,14,24,36,40,46]作为正则化器,包括最具代表性的稀疏和非局部先验,这是本文的重点。几十年来,稀疏性约束已被充分探索为许多图像重建问题的强大驱动力[4,7,19],特别是SISR[46]。使用稀疏编码,图像可以很好地表示为预定义的超完整字典(如小波[11]和曲线[9]函数)中原子的稀疏线性组合。结合基于示例的方法,稀疏表示使用原始图像块[46]或从退化的图像块中学习的语义特征块来开发字典图像本身[17,19]或外部数据集[47]。随着SISR的深度卷积神经网络(CNN)的出现,层之间的非线性激活包含了稀疏先验的优点。Dong等人提出SRCNN[16]首先成功地将卷积桥接到经典稀疏编码,其中ReLU激活通过将所有负条目归零来大致强制50%的稀疏性。最近,Fan等人[21]通过明确地对隐藏神经元施加稀疏性约束来超越这一点,并得出结论,特征表示中的稀疏性确实是有益和有利的。广泛证明,稀疏性约束通过大大减少表示图像的元素的数量而导致高效率。在理论上[10,20]和实践中,它还产生了处理逆问题的更强大和鲁棒的表达式。

另一个被广泛探索的图像先验是非局部(NL)先验。对于SISR,采用非局部注意成为一种更普遍的方式[37,51],在小图案倾向于在同一图像内重复出现之前利用图像自相似性[5]。NL操作在全局范围内搜索那些相似的模式,并选择性地对这些相关特征求和以增强表示。尽管非局部注意是直观的,有希望融合特征,但将其直接应用于SISR任务将遇到一些不容忽视的问题。首先,深层特征的感受场往往是全局的,因此深层特征之间的互相关计算并不准确[33]。第二,全局NL关注需要计算所有像素位置之间的特征互相似性。它导致与图像大小相关的二次计算成本。为了缓解上述问题,一种策略是将NL搜索范围限制在本地邻域内。但它以丢失大量全局信息为代价,降低了计算成本。

在本文中,对于特定的SISR任务,我们的目标是在非局部注意模块中加强稀疏性,并大大降低其计算成本。具体而言,我们提出了一种新的非本地稀疏注意(NLSA),并将其嵌入到像EDSR[32]这样的剩余网络基线中,以形成非本地稀疏网络(NLSN)。为了增强NLSA的稀疏性,我们在空间上将深度特征像素划分为不同的组(称为关注桶在本文中)。同一桶内的特征像素被认为是内容密切相关的。然后,我们在查询像素所属的桶内或在排序后跨相邻桶应用非本地(NL)操作。我们通过在局部敏感散列(LSH)研究[23]上构建分区方法来实现这一点,该方法搜索产生最大内积的相似元素。所提出的NLSA将有可能将NL的计算复杂性从二次降低到空间维度的渐近线性。在较小的内容相关桶中搜索类似线索也将使模块关注信息量更大且相关的位置。结果,NLSA保留了标准NL操作的全局建模能力,同时从其稀疏表示中获得了鲁棒性和效率。总之,本文的主要贡献如下:我们提出通过一种新的非局部稀疏注意(NLSA)模块来增强SISR任务的非局部操作中的稀疏性。稀疏性约束迫使模块专注于相关和信息区域,而忽略不相关和有噪声的内容,我们通过首先对特征像素进行分组,并且仅在名为注意力桶的组内执行非局部操作来实现特征稀疏性。我们采用局部敏感哈希(LSH)进行分组,并为每个组分配一个哈希码。所提出的方法将计算复杂度从二次显著降低到渐近线性。在没有任何提示的情况下,几个NLSA模块可以将一个相当简单的ResNet主干驱动到最先进的水平。大量实验证明了NLSA优于标准的非局部注意(NLA)。

2. Related Work
2.1. Sparse representation.

在本节中,我们简要回顾了稀疏表示的关键概念。形式上,假设x1,x2。。。,xn∈Rd是超完备字典Dd×n(d<n)中的n个已知例子。对于查询信号y∈Rd,基于示例的方法[12,46]将其表示为D中元素的加权和:

方程2给出了一个未确定的线性系统。求解α成为一个不适定问题。为了缓解这种情况,稀疏表示假设y应该稀疏表示,即α应该稀疏:

其中,k.k0和k分别计数和限制α中非零元素的数量。给定稀疏性约束,OMP[38]等优化方法可以有效地近似方程3的解。所得到的稀疏表示已被证明在图像重建领域非常强大[19,47,53]。由于他们的成功,我们受到启发,将稀疏代表性纳入非本地关注

2.2. Non-Local Attention (NLA) for image SR

非局部操作假设小斑块倾向于在同一图像内再次出现,这已被充分证明是自然图像的强先验[5]。非局部方法旨在利用这些自回归来恢复基础信号。非局部操作已广泛应用于许多图像恢复问题,如超分辨率[22]、去噪[1、2、8、13]和修复[18]。Wang等人[45]首先将经典的非局部过滤与机器翻译的自我注意方法[43]联系起来,并进一步将非局部注意(NLA)引入深度神经网络,以捕获高级任务的全局语义关系。对于图像超分辨率,最近的方法,如NLRN[33]、SAN[15]、RNAN[51]和CSNLN[37],证明了通过采用NL关注来探索长距离特征相关性的显著好处。然而,为SISR任务设计的现有NLA要么局限于局部邻域,要么主要消耗计算资源。受语言建模中自我注意方法的最新进展[29,39,44]的启发,我们提出了非局部稀疏注意(NLSA)来接受长距离信息并降低复杂性。

3. Non-Local Sparse Attention (NLSA)
3.1. General Form of Sparse Attention

如上所述,图像SR的非局部关注的优点通常以限制其搜索范围为代价。为了缓解这个问题,我们建议将标准NLA连接到基于示例的方法,然后通过施加稀疏性约束来打破这种联系。

Non-Local Attention.通常,非局部注意通过汇总来自所有位置的信息来增强输入特征图X∈Rh×w×c。为了说明,我们将X重塑为一维特征X∈Rn×c,其中n=hw。给定查询位置i,相应的输出响应yi∈Rc可以表示为:

Sparsity Constraints on Non-Local Attention.给定等式4,通过将α的非零项的数量限制为常数k,可以对非局部注意施加稀疏性约束。因此,具有稀疏约束的非局部注意的一般形式可以推导为:

 Attention Bucket.

注意,索引集δi指示给定查询应该关注的像素位置组。换句话说,δi限制了可以从中计算非本地关注的识别位置。在本文中,我们将这组位置定义为注意力桶。图1显示了不同δi下注意力桶的一些示例。例如,标准的非局部注意力跨越了所有可能的位置,这使得聚集的特征变得嘈杂且信息量较少。如果注意力跨越长度为L的局部邻域,这指定了一个窗口δi={j||j−i|<L}。在这种情况下,某些远程上下文无法有效聚合。

直观地说,一个更强大的稀疏注意力有望覆盖信息最丰富、最接近的位置在全球范围内相关,因此忽略其他元素不会对性能造成损害。一种简单的方法是对所有的相似性进行排序,然后使用前k个条目。然而,这需要首先形成充分的关注,这不会带来效率的提高。在下面的部分中,我们将展示如何通过高效地对注意力进行全局建模来形成每个查询的注意力桶。

3.2. Attention Bucket from Locality Sensitive Hashing (LSH)

如上所述,期望的注意力不仅应保持稀疏,还应包含最相关的元素。在本节中,我们建议采用球形位置敏感散列(LSH)[3,42]来形成所需的关注桶,该关注桶包含与查询元素相关的全局元素。具体而言,我们建议根据角度距离将嵌入空间空间划分为具有相似特征的桶。因此,即使注意力只跨越一个桶而保持稀疏,它仍然可以捕获大部分相关元素。

回想一下,如果附近的元素很可能落入同一个哈希桶(哈希码),而远处的元素则不是,则哈希方案对位置敏感。球形LSH是为角距离设计的LSH的一个实例。人们可以直观地将其视为随机旋转一个刻在超球体中的十字多面体,如图2的顶部分支所示。哈希函数将张量投影到超球面上,并选择最近的多面体顶点作为其哈希代码。因此,如果两个向量具有较小的角距离,它们很可能落在同一个哈希桶中,这也是定义的注意力桶。形式上,假设我们想要获得m个哈希桶,我们必须首先将目标张量投影到超球面上,并用矩阵a∈Rc×m(一个带有i.i.d.高斯项的采样随机旋转矩阵)随机旋转它,即

在实践中,通过批量矩阵乘法对所有元素同时执行球形LSH,这只增加了可忽略的计算成本。预先知道要参与哪个存储桶,该模型可以通过忽略其他有噪声或相关性较小的分区来实现高效性和鲁棒性。

 3.3. Non-Local Sparse Attention

一旦确定了查询位置i的关注桶索引集δi,建议的非本地稀疏关注,(NLSA)可以很容易地从等式6导出。具体来说,如图2所示,NLSA根据内容相关性将X中的每个像素特征分配给共享相同哈希代码的桶,并且只有相应的桶元素对输出有贡献。在下文中,我们描述了在实际实现中使用的一些技术。

Dealing with Unbalanced Bucketing. 理想情况下,给定总共m个桶,每个哈希桶将平均包含n个元素。然而,这在实践中可能不成立,因为桶往往不平衡。这也使得并行计算非常困难。为了克服这一困难,我们首先根据特征的桶值(哈希码)对特征进行排序,然后将排列定义为π:i→ π(i)。在知道它们的新位置(由上标表示)后,我们将它们分成大小为k的块:

 述策略用于并行执行计算更加友好。尽管有其优点,但将原始桶拆分为固定大小的块作为更新的关注桶也会带来一个微妙的问题:一些新的块可能会跨越原始桶的边界,如图2所示。幸运的是,这一问题可以通过允许注意力也跨越相邻块而得到有效缓解。

Multi-round NLSA. 球形LSH的性质表明,某些相关元素被错误地散列到不同的散列桶中的可能性很小。幸运的是,这种机会可以通过独立地散列多个回合并合并所有结果来减少。基于这一观察结果,我们提出了多循环NLSA,以使哈希过程更加稳健。设δr,i表示第r次散列的xi的注意力桶,Att(xi,δr,i)是等式6中定义的相关稀疏注意力,即:

直观地说,多轮NLSA是每一轮关注结果的加权和,权重系数表示每一轮的查询与其分配的哈希桶中的元素之间的归一化相似度。作为一个副作用,这种增加线性地增加了相对于总哈希循环的计算成本。但我们仍然可以在评估期间动态调整该参数,以研究权衡。

Computational complexity. 我们分析了所提出的NLSA的时间复杂性。给定输入特征X∈Rn×c,具有m个桶的球面LSH的成本是矩阵乘法,即O(ncm)。具有稀疏性约束(注意力桶的大小)k的注意力操作(等式6)的成本为O(nck)。长度为n和m个不同数字(桶数)的序列的排序操作增加了快速排序的额外O(nm)(并且可以使用高级排序算法进一步优化)。因此,我们的非局部稀疏注意力的总体计算成本为O(nck+ncm+nm)。对r轮进行哈希运算会增加计算成本,增加系数r,从而导致O(rnck+rncm+rnm)。NLSA只考虑输入空间大小的线性计算复杂性

Instantiations.

为了实例化等式6中定义的非局部注意,我们为

3.4. Non-Local Sparse Network (NLSN)

为了证明非局部稀疏注意的有效性,我们在相当简单的EDSR[32]主干上构建了非局部稀疏网络(NLSN),该主干由32个残差块组成。如图3所示,网络总共使用5个关注块,每8个剩余块后插入一个关注块。网络仅使用L1重建损失进行训练

4. Experiments
4.1. Datasets and Metrics

......

class NonLocalSparseAttention(nn.Module):
    def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1):
        super(NonLocalSparseAttention,self).__init__()
        self.chunk_size = chunk_size#每个chunk有144个元素的hash值
        self.n_hashes = n_hashes#hash值为4维
        self.reduction = reduction
        self.res_scale = res_scale
        self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
        self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)
        # 作为对比 标准Non-local如下
        # self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        # self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act = nn.PReLU())
        # self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,bn=False, act=nn.PReLU())
 
    def LSH(self, hash_buckets, x):
        #x: [N,H*W,C]
        N = x.shape[0]#batch size
        device = x.device
        
        #generate random rotation matrix
        rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2]
        random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2]
        
        #locality sensitive hashing [n hw c]*[N, C, n_hashes, hash_buckets//2]
        rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2],把channel维度融掉了(hw乘以其对应的数进行旋转),对应于论文流程图中的求和步骤
        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets]
        #为什么要这样做呢(又有正,又有负)?可以参考
        # [42] Kengo Terasawa and Yuzuru Tanaka. Spherical lsh for approximate nearest neighbor search on unit hypersphere. In Workshop on Algorithms and Data Structures, pages 27–38. Springer, 2007. 3
        #的附录,对应于orthoplex情景
 
        #get hash codes
        hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N, n_hashes, H*W, hash_buckets]->[N,n_hashes,H*W]求得每个hash bucket中最大的值的位置 作为该feature map像素点的hash值
        
        #add offsets to avoid hash codes overlapping between hash rounds 加了一点偏移量,防止hash code重叠
        offsets = torch.arange(self.n_hashes, device=device) #生成【0,1,2,3】数组
        offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1)) #【0,1*hb,3*hb,3*hb】  形状是(1,4,1)
        hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes(这个维度和offsets一样),H*W]->[N,n_hashes*H*W]
    
        return hash_codes 
    
    def add_adjacent_buckets(self, x):
        #这个函数用于把相邻的bucket相连
        x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2)#把倒数第一行移到了第一行的位置 相当于向下移动一行
        x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2)#把第一行移到了倒数第一行的位置 相当于向上移动一行
        return torch.cat([x, x_extra_back,x_extra_forward], dim=3)#将这三个东西沿着行的方向进行拼接
        #这个操作十分巧妙地将第i组 第i-1和i+1组放在了一行里面 拼接了这三个组
 
    def forward(self, input):
        
        N,_,H,W = input.shape
        x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1)#channel数 ➗4了 [N,h*w,c/4]
        y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1)#channel数没有变 [N,h*w,c]
        #contiguous:view只能作用在contiguous的variable上,如果在view之前调用了transpose、permute等,就需要调用contiguous()来返回一个contiguous copy;
        #这儿为什么不是在permute之后采用contigious呢?不是很懂
        L,C = x_embed.shape[-2:] #L是H*W,且C是channel/4
 
        #number of hash buckets/hash bits 计算有多少个桶呢 最多128个
        hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128)#保障hash_buckets(bucket的数量)是偶数
        
        #get assigned hash codes/bucket number         
        hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W]
        hash_codes = hash_codes.detach()#计算过程不需要反向传播
 
        #group elements with same hash code by sorting
        _, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W] sort以升序排列,返回值为value-tensor和indice-tensor
        _, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order 
        #这里返回的是 【N,n_hashes*H*W】这一次返回值是原来的hash_codes中每一个值它的大小在整个序列里面的排名,如果给了这个序列按顺序排列的结果,那可以根据这个undo-sort列表,还原出原始的序列来。
        mod_indices = (indices % L) #now range from (0->H*W)
        x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]
        y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C*4]
        # def batched_index_select(values, indices):
        #     last_dim = values.shape[-1]
        #     return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))
        #None的作用是在最后增加一维,类似于np.newaxis
 
        #pad the embedding if it cannot be divided by chunk_size
        padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0
        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes,-1, C)) #[N, n_hashes, H*W,C]
        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes,-1, C*self.reduction)) 
        if padding:
            pad_x = x_att_buckets[:,:,-padding:,:].clone()
            pad_y = y_att_buckets[:,:,-padding:,:].clone()
            x_att_buckets = torch.cat([x_att_buckets,pad_x],dim=2)
            y_att_buckets = torch.cat([y_att_buckets,pad_y],dim=2)#把最后几个作为pad来补足
        
        x_att_buckets = torch.reshape(x_att_buckets,(N,self.n_hashes,-1,self.chunk_size,C)) #[N, n_hashes, num_chunks, chunk_size, C]
        y_att_buckets = torch.reshape(y_att_buckets,(N,self.n_hashes,-1,self.chunk_size, C*self.reduction))
        
        x_match = F.normalize(x_att_buckets, p=2, dim=-1,eps=5e-5)#L2归一化
        #[N, n_hashes, num_chunks, chunk_size, C]
 
        #allow attend to adjacent buckets
        #论文中We then apply the Non-Local (NL) operation within the bucket that the query pixel belongs to, or across adjacent buckets after sorting.
        #为了可以搜索相邻的组
        x_match = self.add_adjacent_buckets(x_match)
        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)
        
        #unormalized attention score
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3]
        
        #softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)#logsumexp实际上是针对max函数的一种平滑操作
        score = torch.exp(raw_score - bucket_score) #(after softmax)
        bucket_score = torch.reshape(bucket_score,[N,self.n_hashes,-1])
        
        #attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C]
        ret = torch.reshape(ret,(N,self.n_hashes,-1,C*self.reduction))
        
        #if padded, then remove extra elements
        if padding:
            ret = ret[:,:,:-padding,:].clone()
            bucket_score = bucket_score[:,:,:-padding].clone()
         
        #recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W,C]
        bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W]
        ret = common.batched_index_select(ret, undo_sort)#[N, n_hashes*H*W,C]
        bucket_score = bucket_score.gather(1, undo_sort)#[N,n_hashes*H*W]
        
        #weighted sum multi-round attention
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes*H*W,C]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
        probs = nn.functional.softmax(bucket_score,dim=1)
        ret = torch.sum(ret * probs, dim=1)
        
        ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale+input
        return ret

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值