Pytorch中构建InceptionNet网络结构

直接看代码: 

import torch
import torch.nn as nn

'''
input: A

resnetV2: B = g(A) + f(A)

Inception:
B1 = f1(A)
B2 = f2(A)
B3 = f3(A)
B3 = f4(A)
concat([B1, B2, B3, B4])
'''


def ConvBNRelu(in_channel, out_channel, kernel_size):
    return nn.Sequential(
        nn.Conv2d(in_channel, out_channel,
                  kernel_size=kernel_size,
                  stride=1,
                  padding=kernel_size // 2),
        nn.BatchNorm2d(out_channel),
        nn.ReLU()
    )


class BaseInception(nn.Module):
    def __init__(self,
                 in_channel,
                 out_channel_list,
                 reduce_channel_list):
        super(BaseInception, self).__init__()

        self.branch1_conv = ConvBNRelu(in_channel,
                                       out_channel_list[0],
                                       1)

        self.branch2_conv1 = ConvBNRelu(in_channel,
                                        reduce_channel_list[0],
                                        1)
        self.branch2_conv2 = ConvBNRelu(reduce_channel_list[0],
                                        out_channel_list[1],
                                        3)

        self.branch3_conv1 = ConvBNRelu(in_channel,
                                        reduce_channel_list[1],
                                        1)
        self.branch3_conv2 = ConvBNRelu(reduce_channel_list[1],
                                        out_channel_list[2],
                                        5)

        self.branch4_pool = nn.MaxPool2d(kernel_size=3,
                                         stride=1,
                                         padding=1)
        self.branch4_conv = ConvBNRelu(in_channel,
                                       out_channel_list[3],
                                       3)

    def forward(self, x):
        out1 = self.branch1_conv(x)

        out2 = self.branch2_conv1(x)
        out2 = self.branch2_conv2(out2)

        out3 = self.branch3_conv1(x)
        out3 = self.branch3_conv2(out3)

        out4 = self.branch4_pool(x)
        out4 = self.branch4_conv(out4)
        out = torch.cat([out1, out2, out3, out4], dim=1)
        return out


class InceptionNet(nn.Module):
    def __init__(self):
        super(InceptionNet, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64,
                      kernel_size=7,
                      stride=2,
                      padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128,
                      kernel_size=3,
                      stride=2,
                      padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.block3 = nn.Sequential(
            BaseInception(in_channel=128,
                          out_channel_list=[64, 64,
                                            64, 64],
                          reduce_channel_list=[16, 16]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.block4 = nn.Sequential(
            BaseInception(in_channel=256,
                          out_channel_list=[96, 96,
                                            96, 96],
                          reduce_channel_list=[32, 32]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.fc = nn.Linear(384, 10)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = torch.nn.functional.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def InceptionNetSmall():
    return InceptionNet()

     

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浅蓝的风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值