# CSP-ResNeXt (Cross Stage Partial network with ResNeXt blocks) in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


class BN_Conv2d_Leaky(nn.Module):
    """
    Conv2d -> BatchNorm2d -> LeakyReLU building block.

    The convolution is bias-free by default because the BatchNorm layer that
    follows has its own learnable shift.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding, dilation, groups, bias: forwarded
            unchanged to ``nn.Conv2d``.
        negative_slope: LeakyReLU slope for negative inputs; the default
            0.01 matches ``F.leaky_relu``'s own default.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size, stride, padding,
                 dilation=1, groups=1, bias=False, negative_slope: float = 0.01) -> None:
        super(BN_Conv2d_Leaky, self).__init__()
        self.negative_slope = negative_slope
        # Kept as a Sequential named ``seq`` so state_dict keys
        # (seq.0.*, seq.1.*) stay stable for existing checkpoints.
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return F.leaky_relu(self.seq(x), negative_slope=self.negative_slope)

class ResidualBlock(nn.Module):
    """
    Residual block for CSP-ResNeXt.

    Main path: 1x1 conv -> grouped 3x3 conv (``cardinality`` groups, carries
    the stride) -> 1x1 conv -> BN. The result is added to a shortcut and
    passed through LeakyReLU. The shortcut projects channels with 1x1 conv +
    BN when the width changes, and downsamples with average pooling when
    ``stride != 1``.

    Args:
        in_channels: channels of the input tensor.
        cardinality: number of groups in the 3x3 convolution.
        group_width: channels per group; the block outputs
            ``cardinality * group_width`` channels.
        stride: spatial stride of the 3x3 convolution (and of the shortcut).
    """
    def __init__(self, in_channels, cardinality, group_width, stride=1):
        super(ResidualBlock, self).__init__()
        self.out_channels = cardinality * group_width
        self.conv1 = BN_Conv2d_Leaky(in_channels, self.out_channels, 1, 1, 0)
        self.conv2 = BN_Conv2d_Leaky(self.out_channels, self.out_channels, 3, stride, 1, groups=cardinality)
        # NOTE(review): conv3 keeps nn.Conv2d's default bias even though BN
        # follows; left as-is so existing checkpoints still load.
        self.conv3 = nn.Conv2d(self.out_channels, self.out_channels, 1, 1, 0)
        self.bn = nn.BatchNorm2d(self.out_channels)

        # make shortcut
        layers = []
        if in_channels != self.out_channels:
            layers.append(nn.Conv2d(in_channels, self.out_channels, 1, 1, 0))
            layers.append(nn.BatchNorm2d(self.out_channels))
        if stride != 1:
            # The strided 3x3/pad-1 conv yields ceil(H / stride) rows.
            # Floor-mode pooling would yield floor(H / stride), which breaks
            # the residual addition whenever H is not divisible by stride.
            # ceil_mode=True makes both sides agree. Pooling has no
            # parameters, so the state_dict is unaffected.
            layers.append(nn.AvgPool2d(stride, ceil_mode=True))
        self.shortcut = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv3(self.conv2(self.conv1(x)))
        out = self.bn(out)
        out += self.shortcut(x)
        return F.leaky_relu(out)


class Stem(nn.Module):
    """
    One cross-stage-partial stage of CSP-ResNeXt.

    The input is split along the channel axis into two parts:

    * the first ``in_channels // 2`` channels take a cheap path: a 1x1
      transition followed by average pooling with the stage stride;
    * the remaining channels pass through ``num_blocks`` ResidualBlocks (the
      first one strided) and a 1x1 transition.

    Both halves are concatenated and fused by a final 1x1 transition.

    Args:
        in_channels: channels of the input tensor.
        num_blocks: number of ResidualBlocks on the second path. Values < 1
            still build a single block.
        cardinality: groups in each block's 3x3 convolution.
        group_with: channels per group (name kept for compatibility); the
            hidden width is ``cardinality * group_with``.
        stride: spatial downsampling factor of the stage.

    The output has ``2 * cardinality * group_with`` channels.
    """
    def __init__(self, in_channels, num_blocks, cardinality, group_with, stride=2):
        super(Stem, self).__init__()
        # Channel split, e.g. in_channels=256 -> c0=128, c1=128.
        self.c0 = in_channels // 2
        self.c1 = in_channels - in_channels // 2
        # e.g. cardinality=32, group_with=4 -> hidden=128, out=256.
        self.hidden_channels = cardinality * group_with
        self.out_channels = self.hidden_channels * 2
        # Cheap path; outputs hidden_channels. ceil_mode=True makes its
        # spatial size ceil(H / stride), matching the strided first
        # ResidualBlock, so the concatenation works for sizes not divisible
        # by stride.
        self.trans_part0 = nn.Sequential(BN_Conv2d_Leaky(self.c0, self.hidden_channels, 1, 1, 0),
                                         nn.AvgPool2d(stride, ceil_mode=True))
        self.block = self.__make_block(num_blocks, self.c1, cardinality, group_with, stride)
        self.trans_part1 = BN_Conv2d_Leaky(self.hidden_channels, self.hidden_channels, 1, 1, 0)
        self.trans = BN_Conv2d_Leaky(self.out_channels, self.out_channels, 1, 1, 0)

    def __make_block(self, num_blocks, in_channels, cardinality, group_with, stride):
        # Only the first block downsamples and adapts the channel count;
        # the rest keep hidden_channels at stride 1.
        strides = [stride] + [1] * (num_blocks-1)
        channels = [in_channels] + [self.hidden_channels] * (num_blocks-1)
        return nn.Sequential(*[ResidualBlock(c, cardinality, group_with, s)
                               for c, s in zip(channels, strides)])

    def forward(self, x):
        x0 = x[:, :self.c0, :, :]
        x1 = x[:, self.c0:, :, :]
        out0 = self.trans_part0(x0)
        out1 = self.trans_part1(self.block(x1))
        out = torch.cat((out0, out1), 1)
        return self.trans(out)


class CSP_ResNeXt(nn.Module):
    """
    CSP-ResNeXt classifier.

    The network runs a 7x7/2 conv, a 3x3/2 max-pool and a 1x1 projection,
    followed by four Stem stages and then global average pooling and a
    linear head. Overall spatial downsampling before pooling is 32x
    (conv0, pool1, stem1, stem2, stem3 each stride 2; stem0 stride 1).

    Args:
        num_blocks: four ints, the number of residual blocks per stage.
        cadinality: cardinality (groups of the 3x3 convs); the misspelled
            name is kept for backward compatibility with keyword callers.
        group_width: channels per group in stage 0; doubled each stage.
        num_classes: size of the output logits.

    ``forward`` returns raw logits (no softmax).
    """
    def __init__(self, num_blocks, cadinality, group_width, num_classes):
        super(CSP_ResNeXt, self).__init__()
        base_channels = cadinality * group_width
        self.conv0 = BN_Conv2d_Leaky(3, 64, 7, 2, 3)
        self.pool1 = nn.MaxPool2d(3, 2, 1)
        # Project to exactly the width stem0 expects. A hard-coded 128 only
        # works when cadinality * group_width == 128 (e.g. 32x4d).
        self.conv1 = BN_Conv2d_Leaky(64, base_channels, 1, 1, 0)
        self.stem0 = Stem(base_channels, num_blocks[0], cadinality, group_width, stride=1)
        self.stem1 = Stem(base_channels * 2, num_blocks[1], cadinality, group_width * 2)
        self.stem2 = Stem(base_channels * 4, num_blocks[2], cadinality, group_width * 4)
        self.stem3 = Stem(base_channels * 8, num_blocks[3], cadinality, group_width * 8)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # stem3 outputs 2 * cadinality * (group_width * 8) channels.
        self.fc = nn.Linear(base_channels * 16, num_classes)

    def forward(self, x):
        out = self.conv0(x)
        out = self.pool1(out)
        out = self.conv1(out)
        out = self.stem0(out)
        out = self.stem1(out)
        out = self.stem2(out)
        out = self.stem3(out)
        out = self.global_pool(out)
        out = torch.flatten(out, 1)
        return self.fc(out)


def csp_resnext_50_32x4d(num_classes=1000):
    """Build CSP-ResNeXt-50 (32x4d): 3/4/6/3 blocks per stage, cardinality 32, group width 4."""
    stage_blocks = [3, 4, 6, 3]
    return CSP_ResNeXt(stage_blocks, cadinality=32, group_width=4, num_classes=num_classes)

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = csp_resnext_50_32x4d().to(device)
    input_size = (3, 512, 512)
    # torchsummary's summary() defaults to device="cuda". Pass the device
    # actually in use so the dummy forward pass also works on CPU-only hosts.
    summary(net, input_size, device=device.type)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值