深度学习论文: LEDnet: A lightweight encoder-decoder network for real-time semantic segmentation及其PyTorch实现

class HalfSplit(nn.Module):
    def __init__(self, dim=1):
        super(HalfSplit, self).__init__()
        self.dim = dim

    def forward(self, input):
        splits = torch.chunk(input, 2, dim=self.dim)
        return splits[0], splits[1]

class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        return x.view(N, g, int(C / g), H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)


class SS_nbt(nn.Module):
    def __init__(self, channels, dilation=1, groups=4):
        super(SS_nbt, self).__init__()

        mid_channels = channels // 2
        self.half_split = HalfSplit(dim=1)

        self.first_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, padding=[1, 0]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, padding=[0, 1]),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, dilation=[dilation,1], padding=[dilation, 0]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, dilation=[1,dilation], padding=[0, dilation]),
        )

        self.second_bottleneck = nn.Sequential(
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, padding=[0, 1]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, padding=[1, 0]),
            ConvReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[1, 3], stride=1, dilation=[1,dilation], padding=[0, dilation]),
            ConvBNReLU(in_channels=mid_channels, out_channels=mid_channels, kernel_size=[3, 1], stride=1, dilation=[dilation,1], padding=[dilation, 0]),
        )

        self.channelShuffle = ChannelShuffle(groups)

    def forward(self, x):
        x1, x2 = self.half_split(x)
        x1 = self.first_bottleneck(x1)
        x2 = self.second_bottleneck(x2)
        out = torch.cat([x1, x2], dim=1)
        return self.channelShuffle(out+x)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值