J9 - Inception v3 Algorithm



Theory

Inception V3, released in 2015, is the third version of the Inception architecture that began with Inception V1 (GoogLeNet).
Its main features are:

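  • Deeper network: about 48 layers deep, which lets the network extract features at more levels of abstraction.
  • Factorized convolution: large kernels are replaced by several smaller ones. This further reduces the parameter count and the computational cost while keeping accuracy high.
  • Batch Normalization: every convolution is followed by a BN layer (see BasicConv2d below), which makes the model easier to train and improves generalization. BN is meant to reduce internal covariate shift, which speeds up training and makes the model more robust.
  • Auxiliary classifier: V3 keeps an auxiliary classifier (first introduced in V1) attached to an intermediate feature map. During training, its loss is added to the main classifier's loss with a smaller weight. At inference time only the main classifier's output is used, as the forward method below shows.
  • RMSProp optimizer: V3 is trained with RMSProp, which adapts the step size of each parameter individually. This makes training more stable and convergence faster.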
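The last two points can be sketched together as one training loop. This is a minimal sketch that assumes PyTorch. `TinyAuxNet` is a hypothetical stand-in for `InceptionV3(aux_logits=True)`; like the model below, it returns `(main, aux)` logits in training mode.

The RMSProp values (learning rate 0.045, decay 0.9, ε = 1.0, learning rate multiplied by 0.94 every two epochs) are the ones reported in the Inception v3 paper. The auxiliary loss weight of 0.4 is an assumption. PyTorch adds ε outside the square root, while TensorFlow adds it inside, so the same numbers do not behave identically in the two frameworks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAuxNet(nn.Module):
    """Hypothetical stand-in: returns (main, aux) logits while training."""

    def __init__(self, num_classes=3):
        super().__init__()
        self.body = nn.Linear(8, 16)
        self.main_head = nn.Linear(16, num_classes)
        self.aux_head = nn.Linear(16, num_classes)

    def forward(self, x):
        h = F.relu(self.body(x))
        if self.training:
            return self.main_head(h), self.aux_head(h)
        return self.main_head(h)  # inference: the auxiliary head is ignored


torch.manual_seed(0)
model = TinyAuxNet()
# alpha is PyTorch's name for the RMSProp decay (moving-average) factor.
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.045, alpha=0.9, eps=1.0)
# Multiply the learning rate by 0.94 every 2 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.94)
criterion = nn.CrossEntropyLoss()
AUX_WEIGHT = 0.4  # assumed weight of the auxiliary loss

x = torch.randn(4, 8)
y = torch.tensor([0, 1, 2, 1])

for epoch in range(2):
    model.train()
    main_out, aux_out = model(x)
    # Weighted sum: the auxiliary loss only shapes the gradients during training.
    loss = criterion(main_out, y) + AUX_WEIGHT * criterion(aux_out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch

model.eval()
with torch.no_grad():
    pred = model(x).argmax(dim=1)  # only the main classifier is used
```

With the real model, `TinyAuxNet()` would be replaced by the article's `InceptionV3(num_classes=..., aux_logits=True)`, which also returns only the main logits in eval mode.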

Model Structure

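In Inception V1, a 1x1 convolution first reduces the number of channels before the larger kernels are applied, which lowers the computational cost. Inception V3 refines this idea further.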
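A quick weight count shows how much the 1x1 reduction saves. The sketch below is plain Python and assumes bias-free convolutions, as in `BasicConv2d` further down.

The channel numbers come from the 5x5 branch of `Mixed_5d`: 288 input channels are reduced to 48, and the 5x5 convolution then expands them to 64. The two bottleneck terms match the 13,824 and 76,800 entries in the model summary.

```python
def conv_weights(c_in, c_out, kh, kw):
    """Number of weights in a bias-free kh x kw convolution."""
    return c_in * c_out * kh * kw


# Direct 5x5 convolution from 288 to 64 channels.
direct = conv_weights(288, 64, 5, 5)
# 1x1 reduction to 48 channels, then the 5x5 convolution (as in InceptionA).
bottleneck = conv_weights(288, 48, 1, 1) + conv_weights(48, 64, 5, 5)

print(direct, bottleneck, f"{bottleneck / direct:.1%}")
```

The bottleneck version needs roughly a fifth of the weights. Both variants produce 35x35 outputs here, so multiply-adds shrink by the same factor.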

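  • First, a large kernel is replaced by a stack of smaller ones; for example, one 5x5 convolution becomes two 3x3 convolutions. One layer becomes two, but with the same number of channels a 5x5 convolution costs 25/9 ≈ 2.78 times as much as a 3x3 convolution. The two 3x3 layers together therefore cost only 18/25 = 72% of the original while covering the same 5x5 receptive field.
    [Figure: replacing a 5x5 convolution with two 3x3 convolutions]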

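  • Second, an NxN convolution is factorized into a 1xN convolution followed by an Nx1 convolution. Replacing a 3x3 convolution with 1x3 + 3x1 cuts the weights per input/output channel pair from 9 to 6, a 33% saving in computation.
    [Figure: serial 1xN + Nx1 factorization]
    Stacking the factorized convolutions in series makes the network much deeper and may lose information. Inception V3 therefore also uses a parallel decomposition: the 1x3 and 3x1 convolutions run side by side and their outputs are concatenated (InceptionC below). The serial form is still used for the 7x7 case in InceptionB.
    [Figure: parallel 1xN and Nx1 factorization]
    The overall model structure is:
    [Figure: overall Inception V3 architecture]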
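The ratios quoted in the two points above follow directly from counting weights. Here is a minimal sketch in plain Python. It assumes the same channel count for the input, the intermediate layer and the output, and an unchanged output size, so multiply-adds scale the same way as weights.

```python
def conv_weights(c_in, c_out, kh, kw):
    """Number of weights in a bias-free kh x kw convolution."""
    return c_in * c_out * kh * kw


C = 64  # assumed channel count, the same for input, intermediate and output

one_3x3 = conv_weights(C, C, 3, 3)
one_5x5 = conv_weights(C, C, 5, 5)
two_3x3 = 2 * one_3x3
asym_1x3_3x1 = conv_weights(C, C, 1, 3) + conv_weights(C, C, 3, 1)

print(f"5x5 / 3x3 cost ratio: {one_5x5 / one_3x3:.2f}")                     # 2.78
print(f"two 3x3 instead of one 5x5 saves {1 - two_3x3 / one_5x5:.0%}")      # 28%
print(f"1x3 + 3x1 instead of 3x3 saves {1 - asym_1x3_3x1 / one_3x3:.0%}")  # 33%
```

Because C appears squared in every term, it cancels out of the ratios. Any channel count gives the same percentages.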

Model Implementation

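  • First is the InceptionA module. It follows the Inception V1 module layout (1x1, 1x1→5x5 and pooling→1x1 branches), except that its 3x3 branch stacks two 3x3 convolutions.
    [Figure: InceptionA]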
import torch
import torch.nn as nn
import torch.nn.functional as F

class InceptionA(nn.Module):
    def __init__(self, in_channels, pool_features):
        super().__init__()
        
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)
    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, 1)
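`torch.cat(outputs, 1)` concatenates along the channel dimension, so the block's output width is the sum of the branch widths. Only `pool_features` varies between instances. Here is a small plain-Python check that uses the constructor arguments from `InceptionV3.__init__` further down:

```python
def inception_a_out_channels(pool_features):
    """Channels produced by InceptionA: the four branch outputs are concatenated."""
    branch1x1 = 64
    branch5x5 = 64
    branch3x3 = 96  # two stacked 3x3 convolutions
    return branch1x1 + branch5x5 + branch3x3 + pool_features


# (block, in_channels, pool_features) as used in InceptionV3.__init__
blocks = [("Mixed_5b", 192, 32), ("Mixed_5c", 256, 64), ("Mixed_5d", 288, 64)]
for name, in_ch, pool in blocks:
    print(f"{name}: {in_ch} -> {inception_a_out_channels(pool)} channels")
```

Each block's output width equals the next block's `in_channels`. The final 288 is what ReductionA (`Mixed_6a`) receives.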
  • Next is the InceptionB module, which factorizes the 7x7 convolution serially into 1x7 and 7x1 convolutions.
    [Figure: InceptionB]
class InceptionB(nn.Module):
    def __init__(self, in_channels, channels_7x7):
        super().__init__()

        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)
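To see what the serial 1x7/7x1 factorization buys, compare its weight count with a single 7x7 convolution that has the same input and output channels. This plain-Python sketch assumes `channels_7x7=128`, as in `Mixed_6b`. The two factorized terms reproduce the 114,688 and 172,032 entries of the model summary.

```python
def conv_weights(c_in, c_out, kh, kw):
    """Number of weights in a bias-free kh x kw convolution."""
    return c_in * c_out * kh * kw


c7 = 128  # channels_7x7 of Mixed_6b

# branch7x7 after its 1x1 reduction: 1x7 (c7 -> c7), then 7x1 (c7 -> 192)
factorized = conv_weights(c7, c7, 1, 7) + conv_weights(c7, 192, 7, 1)
# hypothetical single 7x7 convolution mapping c7 -> 192 channels
full_7x7 = conv_weights(c7, 192, 7, 7)

out_channels = 192 * 4  # four branches of 192 channels each are concatenated
print(factorized, full_7x7, f"{factorized / full_7x7:.0%}", out_channels)
```

The factorized pair needs roughly a quarter of the weights of the full 7x7 kernel. Every InceptionB instance outputs 768 channels regardless of `channels_7x7`.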
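  • Then comes InceptionC, which splits the convolution into parallel 1x3 and 3x1 branches whose outputs are concatenated.
    [Figure: InceptionC]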
class InceptionC(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
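The parallel split widens the output instead of deepening the block: each 1x3/3x1 pair contributes two 384-channel maps to the concatenation. Here is a plain-Python tally of the output channels, with the constants read off the constructor above:

```python
def inception_c_out_channels():
    """Channels produced by InceptionC; independent of in_channels."""
    branch1x1 = 320
    branch3x3 = 384 + 384     # parallel 1x3 and 3x1 outputs, concatenated
    branch3x3dbl = 384 + 384  # same parallel split after the 3x3 convolution
    branch_pool = 192
    return branch1x1 + branch3x3 + branch3x3dbl + branch_pool


# weights of branch3x3dbl_2 (448 -> 384, 3x3), the largest conv in the summary
dbl_2_weights = 448 * 384 * 3 * 3

print(inception_c_out_channels(), dbl_2_weights)  # 2048 1548288
```

The 2048-channel output is why `Mixed_7c` takes 2048 input channels and the final `nn.Linear` has 2048 inputs.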
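  • There are also Reduction modules, which roughly halve the spatial resolution of the feature maps (with stride-2 convolutions and max pooling) while increasing the number of channels.
    [Figure: ReductionA]
    [Figure: ReductionB]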
class ReductionA(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)

class ReductionB(nn.Module):
    def __init__(self, in_channels):
        super().__init__()

        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(outputs, 1)
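The spatial sizes follow the usual output-size formula floor((n + 2p - k) / s) + 1. Without padding, a 3x3 stride-2 convolution and a 3x3 stride-2 max pool shrink the map identically, which is what allows their outputs to be concatenated. The pooling branch keeps the input's channel count.

Here is a plain-Python check. It assumes 35x35 and 17x17 inputs, which are the sizes produced for 299x299 images.

```python
def out_size(n, k, s=1, p=0):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * p - k) // s + 1


# ReductionA: 35x35x288 in
red_a_size = out_size(35, 3, s=2)   # branch3x3, last conv of branch3x3dbl, max pool
red_a_channels = 384 + 96 + 288     # conv branch + double-3x3 branch + pooled input
# ReductionB: 17x17x768 in
red_b_size = out_size(17, 3, s=2)
red_b_channels = 320 + 192 + 768

print(red_a_size, red_a_channels)  # 17 768
print(red_b_size, red_b_channels)  # 8 1280
```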
  • Finally, the auxiliary classifier.
    [Figure: auxiliary classifier]
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01  # init hint only: nothing in this article's code reads it
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001    # init hint only, like conv1.stddev

    def forward(self, x):
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        x = self.conv0(x)
        x = self.conv1(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
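`x.view(x.size(0), -1)` only lines up with `nn.Linear(768, num_classes)` if the pooling and the two convolutions shrink the map to 1x1. Here is a plain-Python trace, assuming the 17x17 input that `Mixed_6e` produces for 299x299 images:

```python
def out_size(n, k, s=1, p=0):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * p - k) // s + 1


n = 17                   # output of Mixed_6e
n = out_size(n, 5, s=3)  # avg_pool2d(kernel_size=5, stride=3): 17 -> 5
n = out_size(n, 1)       # conv0, 1x1: 5 -> 5
n = out_size(n, 5)       # conv1, 5x5 without padding: 5 -> 1
flat_features = 768 * n * n
print(n, flat_features)  # 1 768
```

With other input sizes `n` would not be 1, and `fc` would receive the wrong number of features. One way to remove that dependency is to pool to 1x1 with `F.adaptive_avg_pool2d(x, 1)` before flattening; this is an option, not part of the code above.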

Model Assembly

class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()

        # bias=False: the BatchNorm layer that follows provides its own shift term
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

class InceptionV3(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=False, transform_input=False):
        super().__init__()

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = ReductionA(288)
        self.Mixed_6b = InceptionB(768, channels_7x7=128)
        self.Mixed_6c = InceptionB(768, channels_7x7=160)
        self.Mixed_6d = InceptionB(768, channels_7x7=160)
        self.Mixed_6e = InceptionB(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = ReductionB(768)
        self.Mixed_7b = InceptionC(1280)
        self.Mixed_7c = InceptionC(2048)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
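        if self.transform_input:
            # Undo the ImageNet mean/std normalization and re-normalize each channel
            # to (x - 0.5) / 0.5, i.e. the [-1, 1] input range
            x = x.clone()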
            x[:, 0] = x[:, 0]*(0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1]*(0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2]*(0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x = self.Mixed_7c(x)
        x = F.avg_pool2d(x, kernel_size=8)
        x = F.dropout(x, training=self.training)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        if self.training and self.aux_logits:
            return x, aux
        return x
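Tracing the spatial size through `forward` shows where the 149, 147, 73, 71, 35, 17 and 8 in the model summary come from. It also shows why `F.avg_pool2d(x, kernel_size=8)` assumes 299x299 inputs.

The sketch below is plain Python. The Inception blocks (`Mixed_5b` to `5d`, `6b` to `6e`, `7b`, `7c`) pad their convolutions to preserve the size, so they are left out. When no stride is passed, `F.avg_pool2d` defaults the stride to the kernel size.

```python
def out_size(n, k, s=1, p=0):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * p - k) // s + 1


# (layer, kernel, stride, padding) in the order used by InceptionV3.forward
layers = [
    ("Conv2d_1a_3x3", 3, 2, 0),
    ("Conv2d_2a_3x3", 3, 1, 0),
    ("Conv2d_2b_3x3", 3, 1, 1),
    ("max_pool2d", 3, 2, 0),
    ("Conv2d_3b_1x1", 1, 1, 0),
    ("Conv2d_4a_3x3", 3, 1, 0),
    ("max_pool2d", 3, 2, 0),
    ("Mixed_6a (ReductionA)", 3, 2, 0),
    ("Mixed_7a (ReductionB)", 3, 2, 0),
    ("avg_pool2d", 8, 8, 0),  # stride defaults to kernel_size
]

n = 299
sizes = []
for name, k, s, p in layers:
    n = out_size(n, k, s, p)
    sizes.append(n)
    print(f"{name:24s} {n}x{n}")
```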

Model Structure (torchinfo summary)

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
InceptionV3                              [32, 3]                   --
├─BasicConv2d: 1-1                       [32, 32, 149, 149]        --
│    └─Conv2d: 2-1                       [32, 32, 149, 149]        864
│    └─BatchNorm2d: 2-2                  [32, 32, 149, 149]        64
├─BasicConv2d: 1-2                       [32, 32, 147, 147]        --
│    └─Conv2d: 2-3                       [32, 32, 147, 147]        9,216
│    └─BatchNorm2d: 2-4                  [32, 32, 147, 147]        64
├─BasicConv2d: 1-3                       [32, 64, 147, 147]        --
│    └─Conv2d: 2-5                       [32, 64, 147, 147]        18,432
│    └─BatchNorm2d: 2-6                  [32, 64, 147, 147]        128
├─BasicConv2d: 1-4                       [32, 80, 73, 73]          --
│    └─Conv2d: 2-7                       [32, 80, 73, 73]          5,120
│    └─BatchNorm2d: 2-8                  [32, 80, 73, 73]          160
├─BasicConv2d: 1-5                       [32, 192, 71, 71]         --
│    └─Conv2d: 2-9                       [32, 192, 71, 71]         138,240
│    └─BatchNorm2d: 2-10                 [32, 192, 71, 71]         384
├─InceptionA: 1-6                        [32, 256, 35, 35]         --
│    └─BasicConv2d: 2-11                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-1                  [32, 64, 35, 35]          12,288
│    │    └─BatchNorm2d: 3-2             [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-12                 [32, 48, 35, 35]          --
│    │    └─Conv2d: 3-3                  [32, 48, 35, 35]          9,216
│    │    └─BatchNorm2d: 3-4             [32, 48, 35, 35]          96
│    └─BasicConv2d: 2-13                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-5                  [32, 64, 35, 35]          76,800
│    │    └─BatchNorm2d: 3-6             [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-14                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-7                  [32, 64, 35, 35]          12,288
│    │    └─BatchNorm2d: 3-8             [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-15                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-9                  [32, 96, 35, 35]          55,296
│    │    └─BatchNorm2d: 3-10            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-16                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-11                 [32, 96, 35, 35]          82,944
│    │    └─BatchNorm2d: 3-12            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-17                 [32, 32, 35, 35]          --
│    │    └─Conv2d: 3-13                 [32, 32, 35, 35]          6,144
│    │    └─BatchNorm2d: 3-14            [32, 32, 35, 35]          64
├─InceptionA: 1-7                        [32, 288, 35, 35]         --
│    └─BasicConv2d: 2-18                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-15                 [32, 64, 35, 35]          16,384
│    │    └─BatchNorm2d: 3-16            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-19                 [32, 48, 35, 35]          --
│    │    └─Conv2d: 3-17                 [32, 48, 35, 35]          12,288
│    │    └─BatchNorm2d: 3-18            [32, 48, 35, 35]          96
│    └─BasicConv2d: 2-20                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-19                 [32, 64, 35, 35]          76,800
│    │    └─BatchNorm2d: 3-20            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-21                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-21                 [32, 64, 35, 35]          16,384
│    │    └─BatchNorm2d: 3-22            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-22                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-23                 [32, 96, 35, 35]          55,296
│    │    └─BatchNorm2d: 3-24            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-23                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-25                 [32, 96, 35, 35]          82,944
│    │    └─BatchNorm2d: 3-26            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-24                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-27                 [32, 64, 35, 35]          16,384
│    │    └─BatchNorm2d: 3-28            [32, 64, 35, 35]          128
├─InceptionA: 1-8                        [32, 288, 35, 35]         --
│    └─BasicConv2d: 2-25                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-29                 [32, 64, 35, 35]          18,432
│    │    └─BatchNorm2d: 3-30            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-26                 [32, 48, 35, 35]          --
│    │    └─Conv2d: 3-31                 [32, 48, 35, 35]          13,824
│    │    └─BatchNorm2d: 3-32            [32, 48, 35, 35]          96
│    └─BasicConv2d: 2-27                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-33                 [32, 64, 35, 35]          76,800
│    │    └─BatchNorm2d: 3-34            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-28                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-35                 [32, 64, 35, 35]          18,432
│    │    └─BatchNorm2d: 3-36            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-29                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-37                 [32, 96, 35, 35]          55,296
│    │    └─BatchNorm2d: 3-38            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-30                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-39                 [32, 96, 35, 35]          82,944
│    │    └─BatchNorm2d: 3-40            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-31                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-41                 [32, 64, 35, 35]          18,432
│    │    └─BatchNorm2d: 3-42            [32, 64, 35, 35]          128
├─ReductionA: 1-9                        [32, 768, 17, 17]         --
│    └─BasicConv2d: 2-32                 [32, 384, 17, 17]         --
│    │    └─Conv2d: 3-43                 [32, 384, 17, 17]         995,328
│    │    └─BatchNorm2d: 3-44            [32, 384, 17, 17]         768
│    └─BasicConv2d: 2-33                 [32, 64, 35, 35]          --
│    │    └─Conv2d: 3-45                 [32, 64, 35, 35]          18,432
│    │    └─BatchNorm2d: 3-46            [32, 64, 35, 35]          128
│    └─BasicConv2d: 2-34                 [32, 96, 35, 35]          --
│    │    └─Conv2d: 3-47                 [32, 96, 35, 35]          55,296
│    │    └─BatchNorm2d: 3-48            [32, 96, 35, 35]          192
│    └─BasicConv2d: 2-35                 [32, 96, 17, 17]          --
│    │    └─Conv2d: 3-49                 [32, 96, 17, 17]          82,944
│    │    └─BatchNorm2d: 3-50            [32, 96, 17, 17]          192
├─InceptionB: 1-10                       [32, 768, 17, 17]         --
│    └─BasicConv2d: 2-36                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-51                 [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-52            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-37                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-53                 [32, 128, 17, 17]         98,304
│    │    └─BatchNorm2d: 3-54            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-38                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-55                 [32, 128, 17, 17]         114,688
│    │    └─BatchNorm2d: 3-56            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-39                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-57                 [32, 192, 17, 17]         172,032
│    │    └─BatchNorm2d: 3-58            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-40                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-59                 [32, 128, 17, 17]         98,304
│    │    └─BatchNorm2d: 3-60            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-41                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-61                 [32, 128, 17, 17]         114,688
│    │    └─BatchNorm2d: 3-62            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-42                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-63                 [32, 128, 17, 17]         114,688
│    │    └─BatchNorm2d: 3-64            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-43                 [32, 128, 17, 17]         --
│    │    └─Conv2d: 3-65                 [32, 128, 17, 17]         114,688
│    │    └─BatchNorm2d: 3-66            [32, 128, 17, 17]         256
│    └─BasicConv2d: 2-44                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-67                 [32, 192, 17, 17]         172,032
│    │    └─BatchNorm2d: 3-68            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-45                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-69                 [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-70            [32, 192, 17, 17]         384
├─InceptionB: 1-11                       [32, 768, 17, 17]         --
│    └─BasicConv2d: 2-46                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-71                 [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-72            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-47                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-73                 [32, 160, 17, 17]         122,880
│    │    └─BatchNorm2d: 3-74            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-48                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-75                 [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-76            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-49                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-77                 [32, 192, 17, 17]         215,040
│    │    └─BatchNorm2d: 3-78            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-50                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-79                 [32, 160, 17, 17]         122,880
│    │    └─BatchNorm2d: 3-80            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-51                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-81                 [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-82            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-52                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-83                 [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-84            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-53                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-85                 [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-86            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-54                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-87                 [32, 192, 17, 17]         215,040
│    │    └─BatchNorm2d: 3-88            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-55                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-89                 [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-90            [32, 192, 17, 17]         384
├─InceptionB: 1-12                       [32, 768, 17, 17]         --
│    └─BasicConv2d: 2-56                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-91                 [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-92            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-57                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-93                 [32, 160, 17, 17]         122,880
│    │    └─BatchNorm2d: 3-94            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-58                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-95                 [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-96            [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-59                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-97                 [32, 192, 17, 17]         215,040
│    │    └─BatchNorm2d: 3-98            [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-60                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-99                 [32, 160, 17, 17]         122,880
│    │    └─BatchNorm2d: 3-100           [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-61                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-101                [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-102           [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-62                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-103                [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-104           [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-63                 [32, 160, 17, 17]         --
│    │    └─Conv2d: 3-105                [32, 160, 17, 17]         179,200
│    │    └─BatchNorm2d: 3-106           [32, 160, 17, 17]         320
│    └─BasicConv2d: 2-64                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-107                [32, 192, 17, 17]         215,040
│    │    └─BatchNorm2d: 3-108           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-65                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-109                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-110           [32, 192, 17, 17]         384
├─InceptionB: 1-13                       [32, 768, 17, 17]         --
│    └─BasicConv2d: 2-66                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-111                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-112           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-67                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-113                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-114           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-68                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-115                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-116           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-69                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-117                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-118           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-70                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-119                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-120           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-71                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-121                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-122           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-72                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-123                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-124           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-73                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-125                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-126           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-74                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-127                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-128           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-75                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-129                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-130           [32, 192, 17, 17]         384
├─ReductionB: 1-14                       [32, 1280, 8, 8]          --
│    └─BasicConv2d: 2-76                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-131                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-132           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-77                 [32, 320, 8, 8]           --
│    │    └─Conv2d: 3-133                [32, 320, 8, 8]           552,960
│    │    └─BatchNorm2d: 3-134           [32, 320, 8, 8]           640
│    └─BasicConv2d: 2-78                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-135                [32, 192, 17, 17]         147,456
│    │    └─BatchNorm2d: 3-136           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-79                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-137                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-138           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-80                 [32, 192, 17, 17]         --
│    │    └─Conv2d: 3-139                [32, 192, 17, 17]         258,048
│    │    └─BatchNorm2d: 3-140           [32, 192, 17, 17]         384
│    └─BasicConv2d: 2-81                 [32, 192, 8, 8]           --
│    │    └─Conv2d: 3-141                [32, 192, 8, 8]           331,776
│    │    └─BatchNorm2d: 3-142           [32, 192, 8, 8]           384
├─InceptionC: 1-15                       [32, 2048, 8, 8]          --
│    └─BasicConv2d: 2-82                 [32, 320, 8, 8]           --
│    │    └─Conv2d: 3-143                [32, 320, 8, 8]           409,600
│    │    └─BatchNorm2d: 3-144           [32, 320, 8, 8]           640
│    └─BasicConv2d: 2-83                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-145                [32, 384, 8, 8]           491,520
│    │    └─BatchNorm2d: 3-146           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-84                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-147                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-148           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-85                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-149                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-150           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-86                 [32, 448, 8, 8]           --
│    │    └─Conv2d: 3-151                [32, 448, 8, 8]           573,440
│    │    └─BatchNorm2d: 3-152           [32, 448, 8, 8]           896
│    └─BasicConv2d: 2-87                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-153                [32, 384, 8, 8]           1,548,288
│    │    └─BatchNorm2d: 3-154           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-88                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-155                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-156           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-89                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-157                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-158           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-90                 [32, 192, 8, 8]           --
│    │    └─Conv2d: 3-159                [32, 192, 8, 8]           245,760
│    │    └─BatchNorm2d: 3-160           [32, 192, 8, 8]           384
├─InceptionC: 1-16                       [32, 2048, 8, 8]          --
│    └─BasicConv2d: 2-91                 [32, 320, 8, 8]           --
│    │    └─Conv2d: 3-161                [32, 320, 8, 8]           655,360
│    │    └─BatchNorm2d: 3-162           [32, 320, 8, 8]           640
│    └─BasicConv2d: 2-92                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-163                [32, 384, 8, 8]           786,432
│    │    └─BatchNorm2d: 3-164           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-93                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-165                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-166           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-94                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-167                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-168           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-95                 [32, 448, 8, 8]           --
│    │    └─Conv2d: 3-169                [32, 448, 8, 8]           917,504
│    │    └─BatchNorm2d: 3-170           [32, 448, 8, 8]           896
│    └─BasicConv2d: 2-96                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-171                [32, 384, 8, 8]           1,548,288
│    │    └─BatchNorm2d: 3-172           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-97                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-173                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-174           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-98                 [32, 384, 8, 8]           --
│    │    └─Conv2d: 3-175                [32, 384, 8, 8]           442,368
│    │    └─BatchNorm2d: 3-176           [32, 384, 8, 8]           768
│    └─BasicConv2d: 2-99                 [32, 192, 8, 8]           --
│    │    └─Conv2d: 3-177                [32, 192, 8, 8]           393,216
│    │    └─BatchNorm2d: 3-178           [32, 192, 8, 8]           384
├─Linear: 1-17                           [32, 3]                   6,147
==========================================================================================
Total params: 21,791,715
Trainable params: 21,791,715
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 182.76
==========================================================================================
Input size (MB): 34.33
Forward/backward pass size (MB): 4591.35
Params size (MB): 87.17
Estimated Total Size (MB): 4712.85
==========================================================================================
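The Param # column can be reproduced by hand, which is a useful sanity check when reading a summary like this. The sketch below assumes, as in torchvision's Inception v3, that `BasicConv2d` is a bias-free `Conv2d` followed by an affine `BatchNorm2d`. The input channel counts (for example 2048 for the final block), kernel shapes and the 3x299x299 input size are inferred from the printed numbers rather than shown in the summary, and the helper names are hypothetical.

```python
# Parameter counts for the layer types in the summary above.
# Assumption: BasicConv2d = Conv2d(bias=False) + affine BatchNorm2d (as in torchvision).

def conv2d_params(in_ch, out_ch, kh, kw, bias=False):
    """Number of learnable weights in a Conv2d layer."""
    return in_ch * out_ch * kh * kw + (out_ch if bias else 0)


def batchnorm2d_params(ch):
    """Affine BatchNorm2d: one scale and one shift per channel.
    running_mean / running_var are buffers and are not counted."""
    return 2 * ch


def linear_params(in_features, out_features):
    """Fully connected layer: weight matrix plus bias vector."""
    return in_features * out_features + out_features


# (row description, computed value, value printed in the summary)
checks = [
    ("Conv2d 1x1, 2048->320", conv2d_params(2048, 320, 1, 1), 655_360),
    ("BatchNorm2d(320)", batchnorm2d_params(320), 640),
    ("Conv2d 1x3, 384->384", conv2d_params(384, 384, 1, 3), 442_368),
    ("Conv2d 3x3, 448->384", conv2d_params(448, 384, 3, 3), 1_548_288),
    ("Linear 2048->3", linear_params(2048, 3), 6_147),
]
for name, computed, printed in checks:
    print(f"{name:<24}{computed:>12,}  matches: {computed == printed}")

# Sizes below assume 4-byte float32 values and 1 MB = 1e6 bytes.
total_params = 21_791_715
input_numel = 32 * 3 * 299 * 299  # inferred input: batch 32, 3x299x299
print(f"Params size (MB): {total_params * 4 / 1e6:.2f}")  # 87.17
print(f"Input size (MB): {input_numel * 4 / 1e6:.2f}")    # 34.33
```

The same reasoning explains why BatchNorm rows are small (2 x channels) while 3x3 convolutions on wide feature maps dominate the total.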

Model Performance

Training Process

Epoch:  1, Train_acc: 61.1%, Train_loss: 0.737, Test_acc: 68.5%, Test_loss:0.653
Epoch:  2, Train_acc: 61.2%, Train_loss: 0.681, Test_acc: 66.9%, Test_loss:0.724
Epoch:  3, Train_acc: 63.3%, Train_loss: 0.665, Test_acc: 70.4%, Test_loss:0.573
Epoch:  4, Train_acc: 66.4%, Train_loss: 0.626, Test_acc: 66.2%, Test_loss:0.942
Epoch:  5, Train_acc: 63.3%, Train_loss: 0.656, Test_acc: 73.2%, Test_loss:0.584
Epoch:  6, Train_acc: 66.1%, Train_loss: 0.630, Test_acc: 70.9%, Test_loss:0.595
Epoch:  7, Train_acc: 65.4%, Train_loss: 0.640, Test_acc: 73.2%, Test_loss:0.553
Epoch:  8, Train_acc: 66.0%, Train_loss: 0.616, Test_acc: 69.7%, Test_loss:0.562
Epoch:  9, Train_acc: 66.4%, Train_loss: 0.622, Test_acc: 71.8%, Test_loss:0.576
Epoch: 10, Train_acc: 67.8%, Train_loss: 0.599, Test_acc: 71.6%, Test_loss:0.595
Epoch: 11, Train_acc: 66.0%, Train_loss: 0.603, Test_acc: 66.2%, Test_loss:0.604
Epoch: 12, Train_acc: 67.4%, Train_loss: 0.594, Test_acc: 69.5%, Test_loss:0.589
Epoch: 13, Train_acc: 69.1%, Train_loss: 0.579, Test_acc: 74.6%, Test_loss:0.528
Epoch: 14, Train_acc: 71.1%, Train_loss: 0.567, Test_acc: 66.9%, Test_loss:0.603
Epoch: 15, Train_acc: 71.9%, Train_loss: 0.557, Test_acc: 71.8%, Test_loss:0.553
Epoch: 16, Train_acc: 71.5%, Train_loss: 0.555, Test_acc: 77.2%, Test_loss:0.505
Epoch: 17, Train_acc: 74.4%, Train_loss: 0.504, Test_acc: 76.7%, Test_loss:0.499
Epoch: 18, Train_acc: 76.9%, Train_loss: 0.504, Test_acc: 78.6%, Test_loss:0.494
Epoch: 19, Train_acc: 75.1%, Train_loss: 0.500, Test_acc: 73.0%, Test_loss:0.530
Epoch: 20, Train_acc: 77.2%, Train_loss: 0.488, Test_acc: 80.4%, Test_loss:0.443
Epoch: 21, Train_acc: 81.1%, Train_loss: 0.430, Test_acc: 76.9%, Test_loss:0.439
Epoch: 22, Train_acc: 82.3%, Train_loss: 0.410, Test_acc: 81.1%, Test_loss:0.464
Epoch: 23, Train_acc: 84.5%, Train_loss: 0.368, Test_acc: 85.3%, Test_loss:0.384
Epoch: 24, Train_acc: 84.6%, Train_loss: 0.369, Test_acc: 88.8%, Test_loss:0.331
Epoch: 25, Train_acc: 87.0%, Train_loss: 0.320, Test_acc: 90.0%, Test_loss:0.255
Epoch: 26, Train_acc: 86.6%, Train_loss: 0.327, Test_acc: 88.1%, Test_loss:0.272
Epoch: 27, Train_acc: 87.4%, Train_loss: 0.322, Test_acc: 87.2%, Test_loss:0.309
Epoch: 28, Train_acc: 89.0%, Train_loss: 0.284, Test_acc: 88.3%, Test_loss:0.334
Epoch: 29, Train_acc: 88.4%, Train_loss: 0.275, Test_acc: 89.7%, Test_loss:0.259
Epoch: 30, Train_acc: 88.3%, Train_loss: 0.285, Test_acc: 90.2%, Test_loss:0.260
Epoch: 31, Train_acc: 91.7%, Train_loss: 0.221, Test_acc: 88.6%, Test_loss:0.280
Epoch: 32, Train_acc: 91.2%, Train_loss: 0.226, Test_acc: 91.4%, Test_loss:0.252
Epoch: 33, Train_acc: 91.8%, Train_loss: 0.204, Test_acc: 90.9%, Test_loss:0.229
Epoch: 34, Train_acc: 89.2%, Train_loss: 0.262, Test_acc: 88.3%, Test_loss:0.366
Epoch: 35, Train_acc: 91.0%, Train_loss: 0.204, Test_acc: 86.5%, Test_loss:0.349
Epoch: 36, Train_acc: 92.0%, Train_loss: 0.212, Test_acc: 90.9%, Test_loss:0.275
Epoch: 37, Train_acc: 94.2%, Train_loss: 0.159, Test_acc: 92.3%, Test_loss:0.244
Epoch: 38, Train_acc: 94.7%, Train_loss: 0.141, Test_acc: 89.7%, Test_loss:0.273
Epoch: 39, Train_acc: 93.3%, Train_loss: 0.166, Test_acc: 90.7%, Test_loss:0.255
Epoch: 40, Train_acc: 94.2%, Train_loss: 0.157, Test_acc: 90.4%, Test_loss:0.233
Epoch: 41, Train_acc: 96.0%, Train_loss: 0.121, Test_acc: 92.8%, Test_loss:0.244
Epoch: 42, Train_acc: 94.5%, Train_loss: 0.143, Test_acc: 88.1%, Test_loss:0.389
Epoch: 43, Train_acc: 94.6%, Train_loss: 0.143, Test_acc: 92.1%, Test_loss:0.241
Epoch: 44, Train_acc: 94.2%, Train_loss: 0.146, Test_acc: 92.5%, Test_loss:0.226
Epoch: 45, Train_acc: 95.4%, Train_loss: 0.122, Test_acc: 86.0%, Test_loss:0.383
Epoch: 46, Train_acc: 94.0%, Train_loss: 0.158, Test_acc: 92.8%, Test_loss:0.239
Epoch: 47, Train_acc: 95.1%, Train_loss: 0.137, Test_acc: 89.5%, Test_loss:0.320
Epoch: 48, Train_acc: 97.0%, Train_loss: 0.086, Test_acc: 92.5%, Test_loss:0.274
Epoch: 49, Train_acc: 97.3%, Train_loss: 0.078, Test_acc: 87.9%, Test_loss:0.349
Epoch: 50, Train_acc: 95.0%, Train_loss: 0.137, Test_acc: 92.3%, Test_loss:0.239
Done
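Test accuracy in this log fluctuates from epoch to epoch (for example 92.8% at epoch 41, then 88.1% at epoch 42), so the final epoch is not necessarily the best one; over the full log the highest test accuracy is 92.8%, reached at epochs 41 and 46. A stdlib-only sketch for extracting these numbers from the printed log, assuming the exact line format shown above (the function names are hypothetical):

```python
import re

# Matches lines such as:
# "Epoch:  1, Train_acc: 61.1%, Train_loss: 0.737, Test_acc: 68.5%, Test_loss:0.653"
LOG_PATTERN = re.compile(
    r"Epoch:\s*(\d+),\s*Train_acc:\s*([\d.]+)%,\s*Train_loss:\s*([\d.]+),"
    r"\s*Test_acc:\s*([\d.]+)%,\s*Test_loss:\s*([\d.]+)"
)


def parse_log(text):
    """Return one dict per epoch line; other lines (e.g. "Done") are skipped."""
    records = []
    for m in LOG_PATTERN.finditer(text):
        epoch, tr_acc, tr_loss, te_acc, te_loss = m.groups()
        records.append({
            "epoch": int(epoch),
            "train_acc": float(tr_acc),
            "train_loss": float(tr_loss),
            "test_acc": float(te_acc),
            "test_loss": float(te_loss),
        })
    return records


def best_epoch(records):
    """Record with the highest test accuracy (the earliest epoch on ties)."""
    return max(records, key=lambda r: (r["test_acc"], -r["epoch"]))


log = """Epoch: 41, Train_acc: 96.0%, Train_loss: 0.121, Test_acc: 92.8%, Test_loss:0.244
Epoch: 42, Train_acc: 94.5%, Train_loss: 0.143, Test_acc: 88.1%, Test_loss:0.389
Epoch: 46, Train_acc: 94.0%, Train_loss: 0.158, Test_acc: 92.8%, Test_loss:0.239
Done"""
records = parse_log(log)
best = best_epoch(records)
print(best["epoch"], best["test_acc"])  # 41 92.8
```

The parsed per-epoch values can then be passed to a plotting library to draw the training curves.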

Training Curves

[Figure: training curves]

Summary and Takeaways

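Comparing InceptionV1 with InceptionV3 shows that InceptionV1 has a fairly simple structure, while InceptionV3 is considerably more complex. However, InceptionV3 performs only slightly better than InceptionV1, and the improvement is not very pronounced.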

Studying InceptionV3 introduced two structures I had not encountered before:

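  1. Replacing an nxn convolution with parallel 1xn and nx1 convolutions, which greatly reduces the parameter count and computation while keeping accuracy essentially unchanged.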
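  2. Auxiliary classifiers: combining the auxiliary classifier and the main classifier with weights (in practice, their losses during training) gives the shallower modules of the model a stronger training signal. Like the Residual block, this is another way to address the difficulty of training deep networks.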
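Both points can be made concrete with plain arithmetic. The sketch below (pure Python, no framework) counts weights for a 3x3 convolution versus its 1x3/3x1 factorizations, using the 384-channel width from the 8x8 stage of the summary, where the 442,368-parameter rows are consistent with 384->384 1x3 and 3x1 convolutions. In the parallel form each branch sees only a 1x3 or 3x1 window, so that comparison is about parameter count at equal output width, not an exact equivalence, and it says nothing about accuracy. For the auxiliary classifier, the sketch follows the common practice of adding a down-weighted auxiliary loss during training; `AUX_WEIGHT = 0.4` and all helper names are assumptions for illustration.

```python
import math


def conv_weights(in_ch, out_ch, kh, kw):
    """Weight count of a bias-free Conv2d (the BasicConv2d convention)."""
    return in_ch * out_ch * kh * kw


# 1) Factorizing a 3x3 convolution, with C = 384 as in the 8x8 stage above.
C = 384
full_3x3 = conv_weights(C, C, 3, 3)

# Serial: 1x3 followed by 3x1 covers the same 3x3 receptive field.
serial = conv_weights(C, C, 1, 3) + conv_weights(C, C, 3, 1)
print(f"serial 1x3+3x1: {serial:,} vs 3x3: {full_3x3:,} "
      f"({1 - serial / full_3x3:.0%} fewer weights)")

# Parallel (InceptionE-style): 1x3 and 3x1 read the same input and their outputs
# are concatenated into 2*C channels. Compare with a 3x3 conv of equal output width.
parallel = conv_weights(C, C, 1, 3) + conv_weights(C, C, 3, 1)
full_3x3_wide = conv_weights(C, 2 * C, 3, 3)
print(f"parallel 1x3|3x1: {parallel:,} vs 3x3 with {2 * C} outputs: {full_3x3_wide:,}")


# 2) Auxiliary classifier: weighted loss during training.
def cross_entropy(logits, target):
    """Softmax cross-entropy for a single sample (log-sum-exp for stability)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


AUX_WEIGHT = 0.4  # assumed hyperparameter, not taken from the original post


def training_loss(main_logits, aux_logits, target, aux_weight=AUX_WEIGHT):
    """Main loss plus a down-weighted auxiliary loss; used only while training.
    At inference time the prediction would come from main_logits alone."""
    return (cross_entropy(main_logits, target)
            + aux_weight * cross_entropy(aux_logits, target))


print(training_loss([2.0, 0.5, -1.0], [1.0, 0.8, 0.1], target=0))
```

Keeping the auxiliary term out of inference means it acts as a training aid (extra gradient to earlier layers) without changing how predictions are made.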