(即插即用模块-Attention部分) 六十一、(2024 ACCV) LIA 基于局部重要性的注意力

在这里插入图片描述

paper:PlainUSR: Chasing Faster ConvNet for Efficient Super-Resolution

Code:https://github.com/icandle/PlainUSR


1、Local Importance-based Attention

现有空间注意力机制的缺陷:1-order 注意力(如 ESA): 性能较弱,无法充分利用图像信息。2-order 注意力(如 Self-Attention): 计算复杂度高,运行速度慢,不适合轻量级 SR 模型。而这篇论文提出一种 基于局部重要性的注意力(Local Importance-based Attention),旨在保证性能的前提下,降低计算复杂度,实现高效的 2-order 信息交互。LIA d的原理主要有两点:局部重要性: 通过计算每个像素周围区域的局部重要性,识别图像中关键信息的位置。注意力图: 利用局部重要性生成注意力图,对特征图进行加权,增强重要信息,抑制无关信息。

LIA 的实现过程:

  1. 局部重要性计算:使用 SoftPool 和 3x3 卷积对特征图进行下采样,扩大感受野,减少计算量。然后通过 Sigmoid 激活函数将下采样后的特征图转换为局部重要性图。
  2. 注意力图生成:使用第一个通道的特征图作为门控信号,对局部重要性图进行加权。使用 Bilinear 插值将注意力图缩放到原始特征图的尺寸。
  3. 特征图加权:将注意力图与原始特征图进行逐元素相乘,得到加权后的特征图。

优势:

  • 性能: LIA 能够有效地捕捉图像中的关键信息,提升 SR 模型的性能。
  • 效率: 相比于 2-order 注意力机制,LIA 计算复杂度更低,运行速度更快。
  • 可扩展性: LIA 可以灵活地与其他网络结构结合,适用于不同的 SR 任务。

Local Importance-based Attention 结构图:
在这里插入图片描述


2、代码实现

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftPooling2D(torch.nn.Module):
    def __init__(self,kernel_size,stride=None,padding=0):
        super(SoftPooling2D, self).__init__()

        self.avgpool = torch.nn.AvgPool2d(kernel_size,stride,padding, count_include_pad=False)
    def forward(self, x):
        # return self.avgpool(x)
        x_exp = torch.exp(x)
        x_exp_pool = self.avgpool(x_exp)
        x = self.avgpool(x_exp*x)
        return x/x_exp_pool


class LIA(nn.Module):
    ''' attention based on local importance'''
    def __init__(self, channels, f=16):
        super().__init__()
        f = f
        self.body = nn.Sequential(
            # sample importance
            nn.Conv2d(channels, f, 1),
            SoftPooling2D(7, stride=3),
            nn.Conv2d(f, f, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(f, channels, 3, padding=1),
            # to heatmap
            nn.Sigmoid(),
        )
        self.gate = nn.Sequential(
            nn.Sigmoid(),
        )

    def forward(self, x):
        ''' forward '''
        # interpolate the heat map
        g = self.gate(x[:,:1])
        w = F.interpolate(self.body(x), (x.size(2), x.size(3)), mode='bilinear', align_corners=False)

        return x * w * g


if __name__ == '__main__':
    x = torch.randn(4, 64, 128, 128).cuda()
    model = LIA(64).cuda()
    out = model(x)
    print(out.shape)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值