YOLO即插即用模块---MEGANet

MEGANet: Multi-Scale Edge-Guided Attention Network for Weak Boundary Polyp Segmentation

 论文地址:

解决问题:

解决方案细节:

解决方案用于目标检测:

即插即用代码:


 论文地址:

https://arxiv.org/pdf/2309.03329icon-default.png?t=O83Ahttps://arxiv.org/pdf/2309.03329

解决问题:

MEGANet 主要解决了弱边界息肉分割问题。息肉图像通常具有复杂的背景、多变的形状和模糊的边界,这给分割任务带来了挑战。

MEGANet 通过结合边缘信息和注意力机制,有效地保留了高频边缘信息,从而提高了分割精度。MEGANet 的解决方案主要包括三个模块:

  • 编码器: 从输入图像中提取特征。

  • 解码器: 利用编码器提取的特征生成分割结果。

  • 边缘引导注意力模块 (EGA): 利用拉普拉斯算子增强息肉边界信息,并引导模型关注边缘相关的特征。

 

解决方案细节:

  • EGA 模块:

    • 接收来自编码器的嵌入特征、来自拉普拉斯算子的高频特征以及来自解码器的预测特征。

    • 将高频特征与边界注意力图和反向注意力图进行元素级乘法,得到融合特征。

    • 使用注意力掩码引导模型关注重要区域,抑制背景噪声。

    • 通过 CBAM 模块进一步细化特征,捕捉边界与背景区域之间的特征相关性。

解决方案用于目标检测:

MEGANet 的 EGA 模块可以应用于目标检测任务,用于增强目标边界信息,提高检测精度。 具体应用位置可以参考以下几种方案:

  • 特征提取阶段: 将 EGA 模块添加到特征提取网络中,例如在 ResNet 或 EfficientNet 的某些层之间插入 EGA 模块,增强特征图中目标边界信息。

  • 目标框回归阶段: 将 EGA 模块添加到目标框回归网络中,例如在 RetinaNet 或 YOLO 的回归层之前添加 EGA 模块,引导模型更精确地回归目标边界。

  • 目标分类阶段: 将 EGA 模块添加到目标分类网络中,例如在 Faster R-CNN 的 RoI Pooling 层之后添加 EGA 模块,增强目标区域特征,提高分类准确率。

需要注意的是,将 EGA 模块应用于目标检测任务需要进行一些调整,例如

  • 选择合适的边缘检测方法: 拉普拉斯算子可能不适用于所有目标检测任务,需要根据任务特点选择合适的边缘检测方法。

  • 调整 EGA 模块结构: 根据目标检测网络的结构和任务需求,调整 EGA 模块的结构和参数。

  • 训练策略: 需要重新训练模型,并调整训练策略,例如学习率、优化器等。

总的来说,MEGANet 的 EGA 模块为解决弱边界目标分割问题提供了一种有效的方法,并且可以应用于目标检测任务,提高检测精度

即插即用代码:

import torch
import torch.nn.functional as F
import torch.nn as nn


def gauss_kernel(channels=3, cuda=True):
    kernel = torch.tensor([[1., 4., 6., 4., 1],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    kernel /= 256.
    kernel = kernel.repeat(channels, 1, 1, 1)
    if cuda:
        kernel = kernel.cuda()
    return kernel


def downsample(x):
    return x[:, :, ::2, ::2]


def conv_gauss(img, kernel):
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    out = F.conv2d(img, kernel, groups=img.shape[1])
    return out


def upsample(x, channels):
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4 * gauss_kernel(channels))


def make_laplace(img, channels):
    filtered = conv_gauss(img, gauss_kernel(channels))
    down = downsample(filtered)
    up = upsample(down, channels)
    if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]:
        up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3]))
    diff = img - up
    return diff


def make_laplace_pyramid(img, level, channels):
    current = img
    pyr = []
    for _ in range(level):
        filtered = conv_gauss(current, gauss_kernel(channels))
        down = downsample(filtered)
        up = upsample(down, channels)
        if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
            up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
        diff = current - up
        pyr.append(diff)
        current = down
    pyr.append(current)
    return pyr


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )

    def forward(self, x):
        avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        channel_att_sum = avg_out + max_out

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
        self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        x_out = self.SpatialGate(x_out)
        return x_out


# Edge-Guided Attention Module(EGA)
class EGA(nn.Module):
    def __init__(self, in_channels):
        super(EGA, self).__init__()

        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True))

        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid())

        self.cbam = CBAM(in_channels)

    def forward(self, edge_feature, x, pred):
        residual = x
        xsize = x.size()[2:]

        pred = torch.sigmoid(pred)

        # reverse attention
        background_att = 1 - pred
        background_x = x * background_att

        # boudary attention
        edge_pred = make_laplace(pred, 1)
        pred_feature = x * edge_pred

        # high-frequency feature
        edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True)
        input_feature = x * edge_input

        fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1)
        fusion_feature = self.fusion_conv(fusion_feature)

        attention_map = self.attention(fusion_feature)
        fusion_feature = fusion_feature * attention_map

        out = fusion_feature + residual
        out = self.cbam(out)
        return out


if __name__ == '__main__':


    # 模拟输入张量
    edge_feature = torch.randn(1, 1, 128, 128).cuda()
    x = torch.randn(1, 64, 128, 128).cuda()
    pred = torch.randn(1, 1, 128, 128).cuda()  # pred 通常是1通道

    # 实例化 EGA 类
    block = EGA(64).cuda()

    # 传递输入张量通过 EGA 实例
    output = block(edge_feature, x, pred)

    # 打印输入和输出的形状
    print(edge_feature.size())
    print(x.size())
    print(pred.size())
    print(output.size())

大家对于YOLO改进感兴趣的可以进群了解,群中有答疑,(QQ群:828370883)

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值