pytorch 实现 SE Block

pytorch 实现 SE Block

论文模块图

在这里插入图片描述

代码

import torch.nn as nn
class SE_Block(nn.Module):
    def __init__(self, ch_in, reduction=16):
        super(SE_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)				# 全局自适应池化
        self.fc = nn.Sequential(
            nn.Linear(ch_in, ch_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch_in // reduction, ch_in, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

现在还有许多关于SE的变形,但大都大同小异

  • 10
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
PyTorch框架中,要在Yolov4中添加SE模块,可以按照以下步骤进行操作: 1. 导入必要的库和模块 ```python import torch import torch.nn as nn ``` 2. 定义SE模块 ```python class SEModule(nn.Module): def __init__(self, in_channels, reduction=16): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc = nn.Sequential( nn.Linear(in_channels, in_channels // reduction, bias=False), nn.ReLU(inplace=True), nn.Linear(in_channels // reduction, in_channels, bias=False), nn.Sigmoid() ) def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x).view(b, c) y = self.fc(y).view(b, c, 1, 1) return x * y ``` 3. 在Yolov4中应用SE模块 在Yolov4网络的定义中,可以在每个卷积层之后添加SE模块。例如,在CSPDarknet53中,可以按照以下方式添加SE模块: ```python class CSPBlock(nn.Module): def __init__(self, in_channels, out_channels, num_blocks, use_se=True): super(CSPBlock, self).__init__() self.downsample_conv = ConvBlock(in_channels, out_channels, kernel_size=3, stride=2) self.split_conv = ConvBlock(out_channels, out_channels, kernel_size=1, stride=1) self.blocks_conv = nn.Sequential(*[ResidualBlock(out_channels, use_se=use_se) for _ in range(num_blocks)]) self.concat_conv = ConvBlock(out_channels * 2, out_channels, kernel_size=1, stride=1) def forward(self, x): x = self.downsample_conv(x) x = torch.split(x, x.shape[1] // 2, dim=1) x = self.split_conv(x[0]), self.blocks_conv(x[1]) x = torch.cat(x, dim=1) return self.concat_conv(x) ``` 在ResidualBlock中,也可以添加SE模块: ```python class ResidualBlock(nn.Module): def __init__(self, channels, use_se=True): super(ResidualBlock, self).__init__() self.conv1 = ConvBlock(channels, channels // 2, kernel_size=1, stride=1) self.conv2 = ConvBlock(channels // 2, channels, kernel_size=3, stride=1) if use_se: self.se = SEModule(channels) else: self.se = None def forward(self, x): residual = x x = self.conv1(x) x = self.conv2(x) if self.se is not None: x = self.se(x) x += residual return x ``` 通过以上操作,就可以在Yolov4中添加SE模块了。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值