"""复现InceptionV4 — a PyTorch reproduction of Inception-v4.

论文链接 (paper): https://arxiv.org/abs/1602.07261
"""

import torch
import torch.nn as nn


class BasicConv(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic unit used by every block below.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size: convolution kernel size, an int or an (h, w) pair.
        stride: convolution stride, default (1, 1).
        padding: zero padding, default (0, 0), i.e. an unpadded ("valid") conv.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1), padding=(0, 0)):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        # NOTE: Conv2d keeps its default bias. It is redundant in front of
        # BatchNorm, but removing it would change the parameter/state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        # inplace=True lets ReLU overwrite the BatchNorm output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class StemModel(nn.Module):
    """Stem of Inception-v4: maps a 3-channel image to a 384-channel feature map.

    For a 299x299 input the spatial size evolves as
    299 -> 149 -> 147 -> 147 -> 73 -> 71 -> 35, and the channel count after
    the three branch concatenations is 160, 192 and 384 respectively.
    """

    def __init__(self):
        super().__init__()
        # Plain convolution chain: 3 -> 32 (stride 2) -> 32 -> 64 (same padding).
        self.conv_1 = BasicConv(in_channels=3, out_channels=32, kernel_size=3, stride=2)
        self.conv_2 = BasicConv(in_channels=32, out_channels=32, kernel_size=3)
        self.conv_3 = BasicConv(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        # Split 1: max-pool (64 ch) || stride-2 conv (96 ch) -> 160 channels.
        self.branch_1_1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch_1_2 = BasicConv(in_channels=64, out_channels=96, kernel_size=3, stride=2)

        # Split 2: two conv towers of 96 channels each -> 192 channels.
        self.branch_2_1 = nn.Sequential(
            BasicConv(in_channels=160, out_channels=64, kernel_size=1),
            BasicConv(in_channels=64, out_channels=96, kernel_size=3)
        )
        self.branch_2_2 = nn.Sequential(
            BasicConv(in_channels=160, out_channels=64, kernel_size=1),
            # These paddings were derived by hand: without them this tower
            # would come out smaller than branch_2_1 and the two outputs could
            # not be concatenated. Padding 3 on each side restores the 6 pixels
            # that each 7-tap factorized conv would otherwise remove.
            BasicConv(in_channels=64, out_channels=64, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv(in_channels=64, out_channels=64, kernel_size=(1, 7), padding=(0, 3)),
            # Figure 3 of the paper marks this 3x3 layer as size-reducing, so it
            # stays unpadded, matching the 3x3 conv in branch_2_1.
            BasicConv(in_channels=64, out_channels=96, kernel_size=3)
        )

        # Split 3: stride-2 conv (192 ch) || max-pool (192 ch) -> 384 channels.
        self.branch_3_1 = BasicConv(in_channels=192, out_channels=192, kernel_size=3, stride=2)
        self.branch_3_2 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x = self.conv_3(self.conv_2(self.conv_1(x)))
        # Each stage feeds the same tensor to two branches and concatenates
        # their outputs along dim 1 (channels).
        stages = (
            (self.branch_1_1, self.branch_1_2),
            (self.branch_2_1, self.branch_2_2),
            (self.branch_3_1, self.branch_3_2),
        )
        for left, right in stages:
            x = torch.cat([left(x), right(x)], dim=1)
        return x


class InceptionA(nn.Module):
    """Inception-A block of Inception-v4.

    Four parallel branches over a 384-channel input, each producing 96
    channels. Their concatenation (4 * 96 = 384 channels) keeps the input's
    spatial size (35x35 when fed from StemModel on a 299x299 image), because
    every conv and the pooling layer are "same"-padded with stride 1.
    """

    def __init__(self):
        super(InceptionA, self).__init__()
        # Branch 1: 3x3 average pooling (stride 1, same padding) + 1x1 conv.
        # AdaptiveAvgPool2d(output_size=35) would be an identity on a 35x35
        # map (no pooling at all) and would hard-code the input size.
        # Pooling has no parameters, so state_dict keys are unaffected.
        self.branch_1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1, count_include_pad=False),
            BasicConv(in_channels=384, out_channels=96, kernel_size=1)
        )
        # Branch 2: a single 1x1 conv.
        self.branch_2 = BasicConv(in_channels=384, out_channels=96, kernel_size=1)
        # Branch 3: 1x1 reduction followed by one padded 3x3 conv.
        self.branch_3 = nn.Sequential(
            BasicConv(in_channels=384, out_channels=64, kernel_size=1),
            BasicConv(in_channels=64, out_channels=96, kernel_size=3, padding=1)
        )
        # Branch 4: 1x1 reduction followed by two stacked padded 3x3 convs.
        self.branch_4 = nn.Sequential(
            BasicConv(in_channels=384, out_channels=64, kernel_size=1),
            BasicConv(in_channels=64, out_channels=96, kernel_size=3, padding=1),
            BasicConv(in_channels=96, out_channels=96, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x_1 = self.branch_1(x)
        x_2 = self.branch_2(x)
        x_3 = self.branch_3(x)
        x_4 = self.branch_4(x)
        # The concatenation IS the block output; previously it was computed
        # and discarded, so the block returned its input unchanged.
        return torch.cat([x_1, x_2, x_3, x_4], dim=1)


class InceptionB(nn.Module):
    """Inception-B block of Inception-v4.

    Takes a 1024-channel feature map and returns 128 + 384 + 256 + 256 = 1024
    channels at the same spatial size (17x17 when fed from Reduction_A on a
    299x299 image). The 1x7 / 7x1 convs are padded to preserve the size.
    """

    def __init__(self):
        super(InceptionB, self).__init__()
        # Branch 1: 3x3 average pooling (stride 1, same padding) + 1x1 conv.
        # AdaptiveAvgPool2d(output_size=17) would be an identity on a 17x17
        # map (no pooling at all) and would hard-code the input size.
        self.branch_1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1, count_include_pad=False),
            BasicConv(in_channels=1024, out_channels=128, kernel_size=1)
        )
        # Branch 2: a single 1x1 conv.
        self.branch_2 = BasicConv(in_channels=1024, out_channels=384, kernel_size=1)
        # Branch 3: 1x1 reduction, then a factorized 7x7 (1x7 followed by 7x1).
        # The paddings follow the Inception-v2/v3 convention of keeping the
        # spatial size unchanged.
        self.branch_3 = nn.Sequential(
            BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
            BasicConv(in_channels=192, out_channels=224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv(in_channels=224, out_channels=256, kernel_size=(7, 1), padding=(3, 0))
        )
        # Branch 4: 1x1 reduction, then two stacked factorized 7x7 convs.
        self.branch_4 = nn.Sequential(
            BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
            BasicConv(in_channels=192, out_channels=192, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv(in_channels=192, out_channels=224, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv(in_channels=224, out_channels=224, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv(in_channels=224, out_channels=256, kernel_size=(7, 1), padding=(3, 0)),
        )

    def forward(self, x):
        x_1 = self.branch_1(x)
        x_2 = self.branch_2(x)
        x_3 = self.branch_3(x)
        x_4 = self.branch_4(x)
        # Return the concatenation; previously it was discarded and the input
        # was returned unchanged.
        return torch.cat([x_1, x_2, x_3, x_4], dim=1)


class InceptionC(nn.Module):
    """Inception-C block of Inception-v4.

    Takes a 1536-channel feature map and returns 256 + 256 + 512 + 512 = 1536
    channels at the same spatial size (8x8 when fed from Reduction_B on a
    299x299 image). Branches 3 and 4 each end in a parallel 1x3 / 3x1 split
    whose two outputs are concatenated.
    """

    def __init__(self):
        super(InceptionC, self).__init__()
        # Branch 1: 3x3 average pooling (stride 1, same padding) + 1x1 conv.
        # AdaptiveAvgPool2d(output_size=8) would be an identity on an 8x8 map
        # (no pooling at all) and would hard-code the input size.
        self.branch_1 = nn.Sequential(
            nn.AvgPool2d(kernel_size=3, stride=1, padding=1, count_include_pad=False),
            BasicConv(in_channels=1536, out_channels=256, kernel_size=1)
        )
        # Branch 2: a single 1x1 conv.
        self.branch_2 = BasicConv(in_channels=1536, out_channels=256, kernel_size=1)

        # Branch 3: 1x1 conv, then parallel 1x3 and 3x1 convs (256 + 256 = 512).
        self.branch_3_1 = BasicConv(in_channels=1536, out_channels=384, kernel_size=1)
        self.branch_3_2_1 = BasicConv(in_channels=384, out_channels=256, kernel_size=(1, 3), padding=(0, 1))
        self.branch_3_2_2 = BasicConv(in_channels=384, out_channels=256, kernel_size=(3, 1), padding=(1, 0))

        # Branch 4: 1x1 -> 1x3 -> 3x1, then parallel 3x1 and 1x3 (256 + 256 = 512).
        self.branch_4_1 = nn.Sequential(
            BasicConv(in_channels=1536, out_channels=384, kernel_size=1),
            BasicConv(in_channels=384, out_channels=448, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv(in_channels=448, out_channels=512, kernel_size=(3, 1), padding=(1, 0))
        )
        self.branch_4_2_1 = BasicConv(in_channels=512, out_channels=256, kernel_size=(3, 1), padding=(1, 0))
        self.branch_4_2_2 = BasicConv(in_channels=512, out_channels=256, kernel_size=(1, 3), padding=(0, 1))

    def forward(self, x):
        x_1 = self.branch_1(x)
        x_2 = self.branch_2(x)
        x_3 = self.branch_3_1(x)
        x_3_1 = self.branch_3_2_1(x_3)
        x_3_2 = self.branch_3_2_2(x_3)
        x_3 = torch.cat([x_3_1, x_3_2], dim=1)
        x_4 = self.branch_4_1(x)
        x_4_1 = self.branch_4_2_1(x_4)
        x_4_2 = self.branch_4_2_2(x_4)
        x_4 = torch.cat([x_4_1, x_4_2], dim=1)
        # Return the concatenation; previously it was discarded and the input
        # was returned unchanged.
        return torch.cat([x_1, x_2, x_3, x_4], dim=1)


class Reduction_A(nn.Module):
    """Reduction-A block: halves the spatial size (35x35 -> 17x17 in the full net).

    Input has 384 channels; output has 384 (max-pool) + 384 (stride-2 conv)
    + 256 (conv tower) = 1024 channels.
    """

    def __init__(self):
        super().__init__()
        self.branch_1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch_2 = BasicConv(in_channels=384, out_channels=384, kernel_size=3, stride=2)
        self.branch_3 = nn.Sequential(
            BasicConv(in_channels=384, out_channels=192, kernel_size=1),
            BasicConv(in_channels=192, out_channels=224, kernel_size=3),
            # padding=1 makes this tower match the other branches' output size
            # (e.g. 35 -> 33 -> 17) and the 17x17 size given in the paper.
            BasicConv(in_channels=224, out_channels=256, kernel_size=3, stride=2, padding=1)
        )

    def forward(self, x):
        branches = (self.branch_1, self.branch_2, self.branch_3)
        return torch.cat([branch(x) for branch in branches], dim=1)


class Reduction_B(nn.Module):
    """Reduction-B block: halves the spatial size (17x17 -> 8x8 in the full net).

    Input has 1024 channels; output has 1024 (max-pool) + 192 + 320 = 1536
    channels.
    """

    def __init__(self):
        super().__init__()
        self.branch_1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.branch_2 = nn.Sequential(
            BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
            BasicConv(in_channels=192, out_channels=192, kernel_size=3, stride=2)
        )
        self.branch_3 = nn.Sequential(
            BasicConv(in_channels=1024, out_channels=256, kernel_size=1),
            # Note: stride is 1 for the padded 1x7 / 7x1 pair; only the final
            # 3x3 conv downsamples.
            BasicConv(in_channels=256, out_channels=256, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv(in_channels=256, out_channels=320, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv(in_channels=320, out_channels=320, kernel_size=3, stride=2)
        )

    def forward(self, x):
        outputs = [self.branch_1(x), self.branch_2(x), self.branch_3(x)]
        return torch.cat(outputs, dim=1)


class InceptionV4(nn.Module):
    """Inception-v4 image classifier.

    Pipeline: stem -> 4 x Inception-A -> Reduction-A -> 7 x Inception-B
    -> Reduction-B -> 3 x Inception-C -> global average pool -> dropout
    -> linear -> softmax.

    Args:
        num_classes: number of output classes (width of the final Linear).

    forward() expects a (batch, 3, H, W) image tensor and returns softmax
    probabilities of shape (batch, num_classes).
    """

    def __init__(self, num_classes):
        super(InceptionV4, self).__init__()
        self.stem = StemModel()
        self.ModelA_1 = InceptionA()
        self.ModelA_2 = InceptionA()
        self.ModelA_3 = InceptionA()
        self.ModelA_4 = InceptionA()
        self.reduction_a = Reduction_A()

        self.InceptionB_1 = InceptionB()
        self.InceptionB_2 = InceptionB()
        self.InceptionB_3 = InceptionB()
        self.InceptionB_4 = InceptionB()
        self.InceptionB_5 = InceptionB()
        self.InceptionB_6 = InceptionB()
        self.InceptionB_7 = InceptionB()

        self.Reduction_B = Reduction_B()

        self.InceptionC_1 = InceptionC()
        self.InceptionC_2 = InceptionC()
        self.InceptionC_3 = InceptionC()

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten()  # (N, C, 1, 1) -> (N, C)
        self.fc = nn.Linear(in_features=1536, out_features=num_classes)

    def forward(self, x):
        x = self.stem(x)
        for block in (self.ModelA_1, self.ModelA_2, self.ModelA_3, self.ModelA_4):
            x = block(x)
        x = self.reduction_a(x)
        # All seven Inception-B blocks must run; previously only the first
        # was applied and InceptionB_2..7 were constructed but never used.
        for block in (self.InceptionB_1, self.InceptionB_2, self.InceptionB_3,
                      self.InceptionB_4, self.InceptionB_5, self.InceptionB_6,
                      self.InceptionB_7):
            x = block(x)
        x = self.Reduction_B(x)
        for block in (self.InceptionC_1, self.InceptionC_2, self.InceptionC_3):
            x = block(x)

        x = self.avg_pool(x)
        x = self.flatten(x)
        # Drop probability 0.2 (the paper's keep probability 0.8). It is tied
        # to self.training so that model.eval() disables dropout; it was
        # previously hard-coded to train=True and active at inference too.
        x = torch.dropout(x, 0.2, train=self.training)
        x = self.fc(x)
        # NOTE(review): the model returns probabilities. nn.CrossEntropyLoss
        # expects raw logits, so training with it on this output would apply
        # softmax twice. Softmax is kept here for backward compatibility.
        x = torch.softmax(x, dim=1)
        return x


if __name__ == '__main__':
    # Smoke test: a dummy batch shaped (batch_size, channels, height, width).
    # The first dimension (20) is the batch size, i.e. 20 RGB 299x299 images.
    dummy_input = torch.ones([20, 3, 299, 299])
    model = InceptionV4(num_classes=5)
    # Inference mode: BatchNorm uses running statistics and dropout is off.
    model.eval()
    # Only the output shape is checked, so skip building the autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print(output.shape)  # expected: torch.Size([20, 5])

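# --- Residue from the web page this code was copied from (not part of the program) ---
#   • 0
#     点赞
#   • 2
#     收藏
#     觉得还不错? 一键收藏
#   • 0
#     评论
# 评论
# 添加红包
#
# 请填写红包祝福语或标题
#
# 红包个数最小为10个
#
# 红包金额最低5元
#
# 当前余额3.43前往充值 >
# 需支付:10.00
# 成就一亿技术人!
# 领取后你会自动成为博主和红包主的粉丝 规则
# hope_wisdom
# 发出的红包
# 实付
# 使用余额支付
# 点击重新获取
# 扫码支付
# 钱包余额 0
#
# 抵扣说明:
#
# 1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
# 2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。
#
# 余额充值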