Inception-v4(GoogLeNet-v4)模型框架(PyTorch)

I. 前言

Inception-v4,又名GoogLeNet-v4,出自论文《Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning》。该论文在Inception-v3的基础上提出了结构更统一、更深的Inception-v4;此外还将Inception结构与ResNet的残差连接相结合,提出了Inception-ResNet-v1及Inception-ResNet-v2两个模型(其中用到Inception-ResNet-A/B/C等模块)。本文只针对Inception-v4部分的模型框架进行复现。

II. 模型构架图

在这里插入图片描述

III. 各部分构架图

1. Stem


【注】标“V”的卷积/池化即valid-padding(padding=0);未标“V”的为same-padding,需手动计算每处所需的padding数——对于stride为1、kernel_size为k的卷积,padding取(k-1)/2即可保持特征图尺寸不变(如3×3卷积取padding=1)。

2. Inception-A

在这里插入图片描述
【注】此处的Avg Pooling的kernel_size为3, padding为1, stride为1,下同.

3. Inception-B

在这里插入图片描述
【注】不对称卷积核需要使用不对称的padding才能保持特征图尺寸不变:1×7卷积(conv17)的padding为(0,3),7×1卷积(conv71)的padding为(3,0),1×3卷积(conv13)及3×1卷积(conv31)的padding分别为(0,1)及(1,0)。

4. Inception-C

在这里插入图片描述

5. Reduction-A

在这里插入图片描述
【注】此处根据论文中的表格,k=192、l=224、m=256、n=384.

6. Reduction-B

在这里插入图片描述

IV. 代码复现

import torch
import torch.nn as nn
import torch.nn.functional as F

定义一个卷积模块(带BatchNormalization及ReLU激活函数)

class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Every keyword argument (kernel_size, stride, padding, ...) is forwarded to
    ``nn.Conv2d``.

    The convolution bias defaults to ``bias=False``. The BatchNorm that follows
    subtracts a per-channel mean, which cancels any constant per-channel bias,
    and BatchNorm's own learnable shift already provides one. A conv bias would
    only add dead parameters. Pass ``bias=True`` explicitly to get it back.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels (also the BatchNorm width).
        **kwargs: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # Redundant before BatchNorm; the caller may still override it.
        kwargs.setdefault("bias", False)
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

InceptionA模块

class InceptionA(nn.Module):
    """Inception-A block of Inception-v4.

    Four parallel branches, concatenated along the channel axis in the order
    pool / 1x1 / 3x3 / double-3x3, giving 96 + 96 + 96 + 96 = 384 channels.
    Every branch uses stride 1 with "same" padding, so the spatial size is
    unchanged (35x35 when the network is fed a 299x299 image).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: unused; the output width is fixed at 384.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Branch 1: 3x3 average pool (stride 1, pad 1) then a 1x1 conv to 96.
        self.b1_1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)
        self.b1_2 = BasicConv2d(in_channels, 96, kernel_size=1)

        # Branch 2: a single 1x1 conv to 96.
        self.b2 = BasicConv2d(in_channels, 96, kernel_size=1)

        # Branch 3: 1x1 conv to 64, then 3x3 conv to 96.
        self.b3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.b3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)

        # Branch 4: 1x1 conv to 64, then two stacked 3x3 convs to 96.
        self.b4_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.b4_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.b4_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

    def forward(self, x):
        pooled = self.b1_2(self.b1_1(x))
        single = self.b2(x)

        medium = self.b3_1(x)
        medium = self.b3_2(medium)

        deep = x
        for layer in (self.b4_1, self.b4_2, self.b4_3):
            deep = layer(deep)

        return torch.cat((pooled, single, medium, deep), dim=1)

InceptionB模块

class InceptionB(nn.Module):
    """Inception-B block of Inception-v4.

    Four parallel branches, concatenated along the channel axis:
    128 + 384 + 256 + 256 = 1024 output channels. All convolutions use stride 1
    with "same" padding, so the spatial size is unchanged (17x17 when the
    network is fed a 299x299 image).

    The 7x7 receptive fields are factorized into a 1x7 conv followed by a 7x1
    conv. The asymmetric paddings (0, 3) / (3, 0) keep the size fixed.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: unused; the output width is fixed at 1024.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        #branch1: avgpool --> conv1*1(128)
        self.b1_1 = nn.AvgPool2d(kernel_size=3, padding=1, stride=1)
        self.b1_2 = BasicConv2d(in_channels, 128, kernel_size=1)

        #branch2: conv1*1(384)
        self.b2 = BasicConv2d(in_channels, 384, kernel_size=1)

        #branch3: conv1*1(192) --> conv1*7(224) --> conv7*1(256)
        self.b3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.b3_2 = BasicConv2d(192, 224, kernel_size=(1,7), padding=(0,3))
        # Must be 7x1: two 1x7 convs in a row would give this branch no
        # vertical receptive field at all.
        self.b3_3 = BasicConv2d(224, 256, kernel_size=(7,1), padding=(3,0))

        #branch4: conv1*1(192) --> conv1*7(192) --> conv7*1(224) --> conv1*7(224) --> conv7*1(256)
        self.b4_1 = BasicConv2d(in_channels, 192, kernel_size=1, stride=1)
        self.b4_2 = BasicConv2d(192, 192, kernel_size=(1,7), padding=(0,3))
        self.b4_3 = BasicConv2d(192, 224, kernel_size=(7,1), padding=(3,0))
        self.b4_4 = BasicConv2d(224, 224, kernel_size=(1,7), padding=(0,3))
        self.b4_5 = BasicConv2d(224, 256, kernel_size=(7,1), padding=(3,0))

    def forward(self, x):
        y1 = self.b1_2(self.b1_1(x))
        y2 = self.b2(x)
        y3 = self.b3_3(self.b3_2(self.b3_1(x)))
        y4 = self.b4_5(self.b4_4(self.b4_3(self.b4_2(self.b4_1(x)))))

        outputsB = [y1, y2, y3, y4]
        return torch.cat(outputsB, 1)

InceptionC模块

class InceptionC(nn.Module):
    """Inception-C block of Inception-v4.

    Six outputs are concatenated along the channel axis, each 256 wide, giving
    1536 channels in total. Branches 3 and 4 each end in a split into parallel
    1x3 and 3x1 heads. All convolutions use stride 1 with "same" padding, so the
    spatial size is unchanged (8x8 when the network is fed a 299x299 image).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: unused; the output width is fixed at 1536.
    """

    def __init__(self, in_channels, out_channels):
        super(InceptionC, self).__init__()
        #branch1: avgpool --> conv1*1(256)
        self.b1_1 = nn.AvgPool2d(kernel_size=3, padding=1, stride=1)
        self.b1_2 = BasicConv2d(in_channels, 256, kernel_size=1)

        #branch2: conv1*1(256)
        self.b2 = BasicConv2d(in_channels, 256, kernel_size=1)

        #branch3: conv1*1(384) --> conv1*3(256) & conv3*1(256)
        self.b3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.b3_2_1 = BasicConv2d(384, 256, kernel_size=(1,3), padding=(0,1))
        self.b3_2_2 = BasicConv2d(384, 256, kernel_size=(3,1), padding=(1,0))

        #branch4: conv1*1(384) --> conv1*3(448) --> conv3*1(512) --> conv3*1(256) & conv1*3(256)
        self.b4_1 = BasicConv2d(in_channels, 384, kernel_size=1, stride=1)
        self.b4_2 = BasicConv2d(384, 448, kernel_size=(1,3), padding=(0,1))
        self.b4_3 = BasicConv2d(448, 512, kernel_size=(3,1), padding=(1,0))
        self.b4_4_1 = BasicConv2d(512, 256, kernel_size=(3,1), padding=(1,0))
        self.b4_4_2 = BasicConv2d(512, 256, kernel_size=(1,3), padding=(0,1))

    def forward(self, x):
        y1 = self.b1_2(self.b1_1(x))
        y2 = self.b2(x)

        # The 1x1 reduction is shared by both heads of branch 3. Compute it
        # once: recomputing it per head doubles the cost and, in training mode,
        # updates that BatchNorm's running statistics twice per step.
        t3 = self.b3_1(x)
        y3_1 = self.b3_2_1(t3)
        y3_2 = self.b3_2_2(t3)

        # Likewise, the 1x1 -> 1x3 -> 3x1 trunk is shared by both heads of
        # branch 4.
        t4 = self.b4_3(self.b4_2(self.b4_1(x)))
        y4_1 = self.b4_4_1(t4)
        y4_2 = self.b4_4_2(t4)

        outputsC = [y1, y2, y3_1, y3_2, y4_1, y4_2]
        return torch.cat(outputsC, 1)

ReductionA模块

class ReductionA(nn.Module):
    """Reduction-A block: a stride-2 grid reduction (35x35 -> 17x17 for a 299x299 input).

    Three parallel branches are concatenated along the channel axis, giving
    ``in_channels + n + m`` output channels. With the Inception-v4 setting
    (in_channels=384, k=192, l=224, m=256, n=384) that is 1024.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: unused; the width follows from in_channels, n and m.
        k: filters of the 1x1 conv in branch 3.
        l: filters of the "same" 3x3 conv in branch 3.
        m: filters of the stride-2 3x3 conv in branch 3.
        n: filters of the stride-2 3x3 conv in branch 2.
    """

    def __init__(self, in_channels, out_channels, k, l, m, n):
        super().__init__()
        # Branch 1: 3x3 max pool, stride 2, no padding ("valid").
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Branch 2: 3x3 conv to n filters, stride 2, valid.
        self.b2 = BasicConv2d(in_channels, n, kernel_size=3, stride=2)

        # Branch 3: 1x1 conv (k) -> 3x3 same conv (l) -> 3x3 stride-2 valid conv (m).
        self.b3_1 = BasicConv2d(in_channels, k, kernel_size=1)
        self.b3_2 = BasicConv2d(k, l, kernel_size=3, padding=1)
        self.b3_3 = BasicConv2d(l, m, kernel_size=3, stride=2)

    def forward(self, x):
        pooled = self.b1(x)
        strided = self.b2(x)

        deep = x
        for layer in (self.b3_1, self.b3_2, self.b3_3):
            deep = layer(deep)

        return torch.cat((pooled, strided, deep), dim=1)

ReductionB模块

class ReductionB(nn.Module):
    """Reduction-B block: a stride-2 grid reduction (17x17 -> 8x8 for a 299x299 input).

    Three parallel branches are concatenated along the channel axis, giving
    ``in_channels + 192 + 320`` channels (1024 + 192 + 320 = 1536 in
    Inception-v4).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: unused; the width follows from in_channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Branch 1: 3x3 max pool, stride 2, valid.
        self.b1 = nn.MaxPool2d(kernel_size=3, stride=2)

        # Branch 2: 1x1 conv (192) -> 3x3 conv (192), stride 2, valid.
        self.b2_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.b2_2 = BasicConv2d(192, 192, kernel_size=3, stride=2)

        # Branch 3: 1x1 (256) -> 1x7 (256) -> 7x1 (320) -> 3x3 stride-2 valid (320).
        self.b3_1 = BasicConv2d(in_channels, 256, kernel_size=1)
        self.b3_2 = BasicConv2d(256, 256, kernel_size=(1, 7), padding=(0, 3))
        self.b3_3 = BasicConv2d(256, 320, kernel_size=(7, 1), padding=(3, 0))
        self.b3_4 = BasicConv2d(320, 320, kernel_size=3, stride=2)

    def forward(self, x):
        pooled = self.b1(x)

        narrow = self.b2_1(x)
        narrow = self.b2_2(narrow)

        wide = x
        for layer in (self.b3_1, self.b3_2, self.b3_3, self.b3_4):
            wide = layer(wide)

        return torch.cat((pooled, narrow, wide), dim=1)

Stem模块

class Stem(nn.Module):
    """Stem of Inception-v4: turns the input image into a 384-channel map.

    Shapes for a 3 x 299 x 299 input (channels x height x width):
        conv1 (3x3, stride 2, valid)           -> 32 x 149 x 149
        conv2 (3x3, valid)                     -> 32 x 147 x 147
        conv3 (3x3, same)                      -> 64 x 147 x 147
        maxpool4 || conv4 (3x3, s2, valid)     -> 64 + 96 = 160 x 73 x 73
        branch conv5_1 || branch conv5_2       -> 96 + 96 = 192 x 71 x 71
        conv6 (3x3, s2, valid) || maxpool6     -> 192 + 192 = 384 x 35 x 35

    Args:
        in_channels: channels of the input image (the network passes 3).
        out_channels: unused; the output width is fixed at 384.
    """

    def __init__(self, in_channels, out_channels):
        super(Stem, self).__init__()
        #conv3*3(32 stride2 valid)
        self.conv1 = BasicConv2d(in_channels, 32, kernel_size=3, stride=2)
        #conv3*3(32 valid)
        self.conv2 = BasicConv2d(32, 32, kernel_size=3)
        #conv3*3(64)
        self.conv3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        #maxpool3*3(stride2 valid) & conv3*3(96 stride2 valid)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv4 = BasicConv2d(64, 96, kernel_size=3, stride=2)

        #conv1*1(64) --> conv3*3(96 valid)
        self.conv5_1_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_1_2 = BasicConv2d(64, 96, kernel_size=3)
        #conv1*1(64) --> conv7*1(64) --> conv1*7(64) --> conv3*3(96 valid)
        self.conv5_2_1 = BasicConv2d(160, 64, kernel_size=1)
        self.conv5_2_2 = BasicConv2d(64, 64, kernel_size=(7,1), padding=(3,0))
        self.conv5_2_3 = BasicConv2d(64, 64, kernel_size=(1,7), padding=(0,3))
        self.conv5_2_4 = BasicConv2d(64, 96, kernel_size=3)

        #conv3*3(192 stride2 valid)
        self.conv6 = BasicConv2d(192, 192, kernel_size=3, stride=2)
        #maxpool3*3(stride2 valid)
        self.maxpool6 = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        # The conv1 -> conv2 -> conv3 trunk feeds both halves of the first split.
        # Compute it once: running it per branch doubles the cost and, in
        # training mode, updates its BatchNorm running statistics twice per step.
        trunk = self.conv3(self.conv2(self.conv1(x)))
        y1 = torch.cat([self.maxpool4(trunk), self.conv4(trunk)], 1)

        y2_1 = self.conv5_1_2(self.conv5_1_1(y1))
        y2_2 = self.conv5_2_4(self.conv5_2_3(self.conv5_2_2(self.conv5_2_1(y1))))
        y2 = torch.cat([y2_1, y2_2], 1)

        y3 = torch.cat([self.conv6(y2), self.maxpool6(y2)], 1)
        return y3

定义网络模型,将上述模块按构架图组装一起

class Googlenetv4(nn.Module):
    """Inception-v4 (GoogLeNet-v4) classifier assembled from the blocks above.

    Layout: Stem -> 4 x Inception-A -> Reduction-A -> 7 x Inception-B
    -> Reduction-B -> 3 x Inception-C -> global average pooling -> dropout
    -> fully connected layer.

    Args:
        num_classes: number of output logits (default 1000).
        dropout: probability of *zeroing* an activation, as ``nn.Dropout``
            defines it. The paper's head says "Dropout (keep 0.8)", i.e. a keep
            probability of 0.8, which is ``p=0.2`` here.

    Returns (forward):
        Raw logits of shape (batch, num_classes). No softmax is applied here.
    """

    def __init__(self, num_classes=1000, dropout=0.2):
        super(Googlenetv4, self).__init__()
        self.stem = Stem(3, 384)
        # Each repeated Inception block must have its own weights. Calling one
        # module instance several times in a row would tie all repeats together.
        self.icpA = nn.Sequential(*[InceptionA(384, 384) for _ in range(4)])
        self.redA = ReductionA(384, 1024, 192, 224, 256, 384)
        self.icpB = nn.Sequential(*[InceptionB(1024, 1024) for _ in range(7)])
        self.redB = ReductionB(1024, 1536)
        self.icpC = nn.Sequential(*[InceptionC(1536, 1536) for _ in range(3)])
        # For a 299x299 input the final map is 8x8, where this matches
        # AvgPool2d(kernel_size=8). Adaptive pooling also handles other input
        # resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(1536, num_classes)

    def forward(self, x):
        #Stem Module
        out = self.stem(x)
        #InceptionA Module * 4
        out = self.icpA(out)
        #ReductionA Module
        out = self.redA(out)
        #InceptionB Module * 7
        out = self.icpB(out)
        #ReductionB Module
        out = self.redB(out)
        #InceptionC Module * 3
        out = self.icpC(out)
        #Global average pooling -> (batch, 1536)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        #Dropout
        out = self.dropout(out)
        #Linear classifier (logits; apply softmax / use a softmax-based loss outside)
        out = self.linear(out)

        return out

V. 测试部分及结果

def test():
    """Smoke test: run one random 3x299x299 image through Googlenetv4 and
    print the size of the output tensor."""
    sample = torch.randn(1, 3, 299, 299)
    model = Googlenetv4()
    logits = model(sample)
    print(logits.size())
test()
torch.Size([1, 1000])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值