- 🍨 This post is a study log from the 🔗365-Day Deep Learning Training Camp
- 🍖 Original author: K同学啊
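Theory
InceptionV3, released in 2015, is the third version of the Inception architecture that started with InceptionV1 (GoogLeNet).
Its main features are:
- Deeper network: the network is 48 layers deep, which lets it extract features at more levels of abstraction.
- Factorized convolution: large kernels are replaced by several smaller ones, which further reduces the parameter count and computational cost while keeping accuracy high.
- Batch normalization: every convolution is followed by a BN layer, which makes the model easier to train and improves generalization. BN reduces internal covariate shift, speeds up training and makes the model more robust.
- Auxiliary classifier: an auxiliary classifier is attached to an intermediate feature map. During training its loss is added, with a smaller weight, to the main classifier's loss; at inference only the main classifier's output is used. (Auxiliary classifiers already appeared in InceptionV1; V3 keeps a single one.)
- RMSProp optimizer: InceptionV3 was trained with RMSProp, which adapts the learning rate per parameter and tends to make training more stable and converge faster.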
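To make the auxiliary-classifier and RMSProp points concrete, here is a minimal PyTorch sketch. `TinyAuxNet`, `training_step` and the loss weight `aux_weight=0.4` are hypothetical names and values chosen for illustration, not part of the original post, and the optimizer hyperparameters are illustrative assumptions rather than a recommended setting.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyAuxNet(nn.Module):
    """Hypothetical toy model with a main head and an auxiliary head on an intermediate feature."""

    def __init__(self, in_features=16, hidden=32, num_classes=3):
        super().__init__()
        self.body1 = nn.Linear(in_features, hidden)
        self.body2 = nn.Linear(hidden, hidden)
        self.aux_head = nn.Linear(hidden, num_classes)   # auxiliary classifier (intermediate feature)
        self.main_head = nn.Linear(hidden, num_classes)  # main classifier (final feature)

    def forward(self, x):
        h = F.relu(self.body1(x))
        main = self.main_head(F.relu(self.body2(h)))
        if self.training:
            return main, self.aux_head(h)  # training: return both outputs
        return main                        # inference: auxiliary head is not used


def training_step(model, optimizer, x, y, aux_weight=0.4):
    """One step with loss = main_loss + aux_weight * aux_loss (aux_weight is an assumed value)."""
    model.train()
    main_out, aux_out = model(x)
    loss = F.cross_entropy(main_out, y) + aux_weight * F.cross_entropy(aux_out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = TinyAuxNet()
# RMSProp: alpha is the moving-average decay of squared gradients; values are illustrative
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.045, alpha=0.9, eps=1.0)
x, y = torch.randn(8, 16), torch.randint(0, 3, (8,))
loss = training_step(model, optimizer, x, y)

model.eval()
print(model(x).shape)  # torch.Size([8, 3])
```

In `eval()` mode only the main logits come back, mirroring how `InceptionV3.forward` returns `(x, aux)` only when `self.training and self.aux_logits`.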
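Model structure
InceptionV1 already lowered the computational cost by reducing the channel dimension with a 1x1 convolution before applying the larger kernels. InceptionV3 improves on this in two ways: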
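A quick parameter count shows why the 1x1 reduction helps. The channel sizes (192 in, 16 after the 1x1, 32 out) are assumed for illustration; at equal spatial size the multiply-add counts shrink by the same factor.

```python
def conv_params(c_in, c_out, k):
    """Weight count of a k x k convolution without bias."""
    return c_in * c_out * k * k


# Assumed channel sizes for a single 5x5 branch
direct = conv_params(192, 32, 5)                               # 5x5 applied directly
bottleneck = conv_params(192, 16, 1) + conv_params(16, 32, 5)  # 1x1 down to 16, then 5x5
print(direct, bottleneck, round(direct / bottleneck, 1))       # 153600 15872 9.7
```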
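- First, large kernels are replaced by stacks of smaller ones, e.g. one 5x5 convolution becomes two 3x3 convolutions with the same 5x5 receptive field. Although one layer becomes two, a 5x5 convolution costs 25/9 ≈ 2.78 times as much as a 3x3 one, so the two 3x3 layers need only 18/25 = 72% of the computation, a 28% saving.
- Second, an NxN convolution is factorized into a 1xN and an Nx1 convolution. Replacing a 3x3 convolution with 1x3 + 3x1 needs 6 instead of 9 weights per input/output channel pair, saving 33% of the computation.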
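The cost ratios quoted above can be checked with a few lines of arithmetic. This sketch assumes every layer has the same number of input and output channels, so the channel count cancels out of the ratios.

```python
def cost(kh, kw, c=1):
    """Multiply-adds per output position for a kh x kw conv with c input and c output channels."""
    return kh * kw * c * c


five = cost(5, 5)
three = cost(3, 3)
two_threes = 2 * cost(3, 3)        # two stacked 3x3 layers
asym = cost(1, 3) + cost(3, 1)     # 1x3 followed by 3x1

print(f"5x5 / 3x3 = {five / three:.2f}")                           # 2.78
print(f"two 3x3 vs one 5x5: {1 - two_threes / five:.0%} cheaper")  # 28%
print(f"1x3 + 3x1 vs 3x3: {1 - asym / three:.0%} cheaper")         # 33%
```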
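Stacking these factorized convolutions in series makes the network much deeper, which may cause information loss, so InceptionV3 also factorizes in parallel: the 1xN and Nx1 convolutions are applied side by side and their outputs concatenated.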
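The difference between serial and parallel factorization is easiest to see from the output shapes. This is an illustrative sketch with an arbitrary channel count of 16 and plain `nn.Conv2d` layers (no BN or ReLU).

```python
import torch
import torch.nn as nn

c = 16
x = torch.randn(1, c, 8, 8)

# Serial factorization: 1x3 followed by 3x1 (3x3 receptive field, c output channels)
serial = nn.Sequential(
    nn.Conv2d(c, c, kernel_size=(1, 3), padding=(0, 1)),
    nn.Conv2d(c, c, kernel_size=(3, 1), padding=(1, 0)),
)

# Parallel factorization: 1x3 and 3x1 read the same input, outputs concatenated on channels
conv_1x3 = nn.Conv2d(c, c, kernel_size=(1, 3), padding=(0, 1))
conv_3x1 = nn.Conv2d(c, c, kernel_size=(3, 1), padding=(1, 0))

y_serial = serial(x)
y_parallel = torch.cat([conv_1x3(x), conv_3x1(x)], dim=1)
print(y_serial.shape)    # torch.Size([1, 16, 8, 8])
print(y_parallel.shape)  # torch.Size([1, 32, 8, 8])
```

The serial version adds depth but keeps the channel count; the parallel version adds no depth and doubles the channels, which is the pattern `InceptionC` uses at the 8x8 stage.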
The overall structure of the model is:
Model implementation
- First, the InceptionA module, which keeps the same four-branch layout as the InceptionV1 module:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# BasicConv2d (Conv2d + BatchNorm + ReLU) is defined in the "Model assembly" section.

class InceptionA(nn.Module):
    """35x35 Inception module: four parallel branches concatenated along channels."""
    def __init__(self, in_channels, pool_features):
        super().__init__()
        # Branch 1: 1x1 conv
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)
        # Branch 2: 1x1 reduction, then 5x5
        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)
        # Branch 3: 1x1 reduction, then two stacked 3x3 convs (a factorized 5x5)
        self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)
        # Branch 4: 3x3 average pooling, then 1x1 projection
        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)
        branch3x3 = self.branch3x3_3(branch3x3)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # Output channels: 64 + 64 + 96 + pool_features
        outputs = [branch1x1, branch5x5, branch3x3, branch_pool]
        return torch.cat(outputs, 1)
```
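The four branch widths determine the module's output channels, which in turn fix the `in_channels` of the next module. A small arithmetic sketch (the helper name is hypothetical) reproduces the 192 -> 256 -> 288 -> 288 progression of `Mixed_5b`, `Mixed_5c` and `Mixed_5d`:

```python
def inception_a_out_channels(pool_features):
    """1x1 branch (64) + 5x5 branch (64) + double-3x3 branch (96) + pooling branch."""
    return 64 + 64 + 96 + pool_features


# (name, in_channels, pool_features) as used in InceptionV3.__init__
for name, in_ch, pf in [("Mixed_5b", 192, 32), ("Mixed_5c", 256, 64), ("Mixed_5d", 288, 64)]:
    print(name, in_ch, "->", inception_a_out_channels(pf))
# Mixed_5b 192 -> 256
# Mixed_5c 256 -> 288
# Mixed_5d 288 -> 288
```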
- Next, the InceptionB module, which factorizes the large 7x7 kernels in series (1x7 followed by 7x1):
```python
class InceptionB(nn.Module):
    """17x17 Inception module: 7x7 kernels factorized in series into 1x7 and 7x1."""
    def __init__(self, in_channels, channels_7x7):
        super().__init__()
        # Branch 1: 1x1 conv
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)
        c7 = channels_7x7
        # Branch 2: 1x1 reduction -> 1x7 -> 7x1 (one factorized 7x7)
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))
        # Branch 3: 1x1 reduction -> 7x1 -> 1x7 -> 7x1 -> 1x7 (two factorized 7x7)
        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        # Branch 4: 3x3 average pooling, then 1x1 projection
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # Output channels: 4 x 192 = 768, independent of channels_7x7
        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)
```
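To see what the serial 1x7 + 7x1 factorization saves inside `InceptionB`, the sketch below compares parameter counts against a full 7x7 convolution with `channels_7x7=128` (the `Mixed_6b` setting). The 1x7 layer alone has 114,688 weights, which matches the corresponding row of the model summary. `n_params` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def n_params(module):
    """Number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters())


c7 = 128  # channels_7x7 used by Mixed_6b
full = nn.Conv2d(c7, c7, kernel_size=7, padding=3, bias=False)
factorized = nn.Sequential(
    nn.Conv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3), bias=False),
    nn.Conv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0), bias=False),
)

x = torch.randn(1, c7, 17, 17)
print(full(x).shape, factorized(x).shape)    # both torch.Size([1, 128, 17, 17])
print(n_params(full), n_params(factorized))  # 802816 229376
```

The factorized pair needs 2/7 (about 29%) of the weights while producing a feature map of the same shape.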
- Then InceptionC, which factorizes the kernels in parallel (1x3 and 3x1 side by side):
```python
class InceptionC(nn.Module):
    """8x8 Inception module: 1x3 and 3x1 applied in parallel and concatenated."""
    def __init__(self, in_channels):
        super().__init__()
        # Branch 1: 1x1 conv
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)
        # Branch 2: 1x1 reduction, then parallel 1x3 / 3x1
        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
        # Branch 3: 1x1 reduction -> 3x3, then parallel 1x3 / 3x1
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))
        # Branch 4: 3x3 average pooling, then 1x1 projection
        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        # Parallel split: both convs read the same input, outputs concatenated (384 + 384)
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # Output channels: 320 + 768 + 768 + 192 = 2048
        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)
```
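Adding up the branch widths explains why `Mixed_7c` takes `in_channels=2048` and why the final `nn.Linear` has 2048 inputs. A trivial check of that sum:

```python
# Output channels of each InceptionC branch (they do not depend on in_channels)
branches = {
    "branch1x1": 320,
    "branch3x3 (1x3 || 3x1)": 384 + 384,     # parallel split, concatenated
    "branch3x3dbl (1x3 || 3x1)": 384 + 384,  # parallel split, concatenated
    "branch_pool": 192,
}
total = sum(branches.values())
print(total)  # 2048
```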
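- There are also Reduction modules, which downsample the feature maps (35x35 to 17x17, then 17x17 to 8x8) while increasing the number of channels: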
```python
class ReductionA(nn.Module):
    """Grid reduction 35x35 -> 17x17: stride-2 conv branches plus stride-2 max pooling."""
    def __init__(self, in_channels):
        super().__init__()
        # Branch 1: single 3x3 conv with stride 2
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2)
        # Branch 2: 1x1 -> 3x3 -> 3x3 with stride 2
        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Branch 3: max pooling has no parameters and keeps in_channels
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        # Output channels: 384 + 96 + in_channels
        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class ReductionB(nn.Module):
    """Grid reduction 17x17 -> 8x8."""
    def __init__(self, in_channels):
        super().__init__()
        # Branch 1: 1x1 -> 3x3 with stride 2
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2)
        # Branch 2: 1x1 -> 1x7 -> 7x1 -> 3x3 with stride 2
        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        # Branch 3: max pooling keeps in_channels
        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        # Output channels: 320 + 192 + in_channels
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(outputs, 1)
```
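Both Reduction modules downsample with 3x3 kernels, stride 2 and no padding, so the standard output-size formula floor((H + 2p - k) / s) + 1 gives the new grid sizes, and the max-pool branch passes `in_channels` through unchanged. A small sketch (the helper `conv_out` is hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


# ReductionA: 35x35 -> 17x17; ReductionB: 17x17 -> 8x8
print(conv_out(35, 3, stride=2), conv_out(17, 3, stride=2))  # 17 8

# Channels: conv branches plus the max-pool branch, which keeps in_channels
reduction_a_out = 384 + 96 + 288   # Mixed_6a, in_channels = 288
reduction_b_out = 320 + 192 + 768  # Mixed_7a, in_channels = 768
print(reduction_a_out, reduction_b_out)  # 768 1280
```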
- Finally, the auxiliary classifier:
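```python
class InceptionAux(nn.Module):
    """Auxiliary classifier on the 768 x 17 x 17 feature map; used only during training."""
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        # The stddev attributes are hints for a custom weight initializer;
        # nothing in this post reads them, so they have no effect here.
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        x = F.avg_pool2d(x, kernel_size=5, stride=3)  # 768 x 17 x 17 -> 768 x 5 x 5
        x = self.conv0(x)                              # -> 128 x 5 x 5
        x = self.conv1(x)                              # -> 768 x 1 x 1
        x = x.view(x.size(0), -1)                      # -> 768
        x = self.fc(x)
        return x
```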
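The auxiliary head only flattens to a 768-dimensional vector because its input is 17x17. Tracing the sizes with the same output-size formula (hypothetical helper `conv_out`):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


s = 17                        # Mixed_6e output: 768 x 17 x 17
s = conv_out(s, 5, stride=3)  # avg_pool2d(kernel_size=5, stride=3) -> 5
s = conv_out(s, 1)            # conv0, 1x1                          -> 5
s = conv_out(s, 5)            # conv1, 5x5 without padding          -> 1
print(s, 768 * s * s)         # 1 768, matching nn.Linear(768, num_classes)
```

With a different input resolution the flattened size would change and `nn.Linear(768, num_classes)` would no longer fit.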
Model assembly
```python
class BasicConv2d(nn.Module):
    """Conv2d (no bias) + BatchNorm + ReLU, the building block used by every module above."""
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv bias is redundant because BatchNorm adds its own learnable shift
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class InceptionV3(nn.Module):
    def __init__(self, num_classes=1000, aux_logits=False, transform_input=False):
        super().__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # Stem: 3 x 299 x 299 -> 192 x 35 x 35 (the two max pools are applied in forward)
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        # 35x35 stage
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        # 35x35 -> 17x17
        self.Mixed_6a = ReductionA(288)
        # 17x17 stage
        self.Mixed_6b = InceptionB(768, channels_7x7=128)
        self.Mixed_6c = InceptionB(768, channels_7x7=160)
        self.Mixed_6d = InceptionB(768, channels_7x7=160)
        self.Mixed_6e = InceptionB(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        # 17x17 -> 8x8, then 8x8 stage
        self.Mixed_7a = ReductionB(768)
        self.Mixed_7b = InceptionC(1280)
        self.Mixed_7c = InceptionC(2048)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        if self.transform_input:
            # Convert ImageNet mean/std-normalized input to (pixel - 0.5) / 0.5 scaling
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)
        x = self.Mixed_6e(x)
        # The auxiliary head branches off the 17x17 features, in training mode only
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        x = self.Mixed_7a(x)
        x = self.Mixed_7b(x)
        x = self.Mixed_7c(x)
        # Global average pooling; kernel_size=8 assumes a 299x299 input (8x8 final map)
        x = F.avg_pool2d(x, kernel_size=8)
        x = F.dropout(x, training=self.training)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        if self.training and self.aux_logits:
            return x, aux  # training with aux_logits: (main logits, auxiliary logits)
        return x
```
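The `transform_input` branch turns input normalized with the ImageNet mean/std into the `(pixel - 0.5) / 0.5` scaling, presumably for pretrained weights that expect that range. Algebraically, ((p - mean) / std) * (std / 0.5) + (mean - 0.5) / 0.5 = (p - 0.5) / 0.5. The sketch below checks the identity numerically on random pixels, applying the same per-channel formula with broadcasting instead of channel indexing.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

torch.manual_seed(0)
pixels = torch.rand(2, 3, 4, 4)        # raw pixel values in [0, 1]
imagenet_norm = (pixels - mean) / std  # typical Normalize(mean, std) output

# Same arithmetic as the transform_input branch of InceptionV3.forward
transformed = imagenet_norm * (std / 0.5) + (mean - 0.5) / 0.5

print(torch.allclose(transformed, (pixels - 0.5) / 0.5, atol=1e-6))  # True
```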
Model structure summary
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
InceptionV3 [32, 3] --
├─BasicConv2d: 1-1 [32, 32, 149, 149] --
│ └─Conv2d: 2-1 [32, 32, 149, 149] 864
│ └─BatchNorm2d: 2-2 [32, 32, 149, 149] 64
├─BasicConv2d: 1-2 [32, 32, 147, 147] --
│ └─Conv2d: 2-3 [32, 32, 147, 147] 9,216
│ └─BatchNorm2d: 2-4 [32, 32, 147, 147] 64
├─BasicConv2d: 1-3 [32, 64, 147, 147] --
│ └─Conv2d: 2-5 [32, 64, 147, 147] 18,432
│ └─BatchNorm2d: 2-6 [32, 64, 147, 147] 128
├─BasicConv2d: 1-4 [32, 80, 73, 73] --
│ └─Conv2d: 2-7 [32, 80, 73, 73] 5,120
│ └─BatchNorm2d: 2-8 [32, 80, 73, 73] 160
├─BasicConv2d: 1-5 [32, 192, 71, 71] --
│ └─Conv2d: 2-9 [32, 192, 71, 71] 138,240
│ └─BatchNorm2d: 2-10 [32, 192, 71, 71] 384
├─InceptionA: 1-6 [32, 256, 35, 35] --
│ └─BasicConv2d: 2-11 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-1 [32, 64, 35, 35] 12,288
│ │ └─BatchNorm2d: 3-2 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-12 [32, 48, 35, 35] --
│ │ └─Conv2d: 3-3 [32, 48, 35, 35] 9,216
│ │ └─BatchNorm2d: 3-4 [32, 48, 35, 35] 96
│ └─BasicConv2d: 2-13 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-5 [32, 64, 35, 35] 76,800
│ │ └─BatchNorm2d: 3-6 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-14 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-7 [32, 64, 35, 35] 12,288
│ │ └─BatchNorm2d: 3-8 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-15 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-9 [32, 96, 35, 35] 55,296
│ │ └─BatchNorm2d: 3-10 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-16 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-11 [32, 96, 35, 35] 82,944
│ │ └─BatchNorm2d: 3-12 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-17 [32, 32, 35, 35] --
│ │ └─Conv2d: 3-13 [32, 32, 35, 35] 6,144
│ │ └─BatchNorm2d: 3-14 [32, 32, 35, 35] 64
├─InceptionA: 1-7 [32, 288, 35, 35] --
│ └─BasicConv2d: 2-18 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-15 [32, 64, 35, 35] 16,384
│ │ └─BatchNorm2d: 3-16 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-19 [32, 48, 35, 35] --
│ │ └─Conv2d: 3-17 [32, 48, 35, 35] 12,288
│ │ └─BatchNorm2d: 3-18 [32, 48, 35, 35] 96
│ └─BasicConv2d: 2-20 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-19 [32, 64, 35, 35] 76,800
│ │ └─BatchNorm2d: 3-20 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-21 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-21 [32, 64, 35, 35] 16,384
│ │ └─BatchNorm2d: 3-22 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-22 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-23 [32, 96, 35, 35] 55,296
│ │ └─BatchNorm2d: 3-24 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-23 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-25 [32, 96, 35, 35] 82,944
│ │ └─BatchNorm2d: 3-26 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-24 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-27 [32, 64, 35, 35] 16,384
│ │ └─BatchNorm2d: 3-28 [32, 64, 35, 35] 128
├─InceptionA: 1-8 [32, 288, 35, 35] --
│ └─BasicConv2d: 2-25 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-29 [32, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-30 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-26 [32, 48, 35, 35] --
│ │ └─Conv2d: 3-31 [32, 48, 35, 35] 13,824
│ │ └─BatchNorm2d: 3-32 [32, 48, 35, 35] 96
│ └─BasicConv2d: 2-27 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-33 [32, 64, 35, 35] 76,800
│ │ └─BatchNorm2d: 3-34 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-28 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-35 [32, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-36 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-29 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-37 [32, 96, 35, 35] 55,296
│ │ └─BatchNorm2d: 3-38 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-30 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-39 [32, 96, 35, 35] 82,944
│ │ └─BatchNorm2d: 3-40 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-31 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-41 [32, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-42 [32, 64, 35, 35] 128
├─ReductionA: 1-9 [32, 768, 17, 17] --
│ └─BasicConv2d: 2-32 [32, 384, 17, 17] --
│ │ └─Conv2d: 3-43 [32, 384, 17, 17] 995,328
│ │ └─BatchNorm2d: 3-44 [32, 384, 17, 17] 768
│ └─BasicConv2d: 2-33 [32, 64, 35, 35] --
│ │ └─Conv2d: 3-45 [32, 64, 35, 35] 18,432
│ │ └─BatchNorm2d: 3-46 [32, 64, 35, 35] 128
│ └─BasicConv2d: 2-34 [32, 96, 35, 35] --
│ │ └─Conv2d: 3-47 [32, 96, 35, 35] 55,296
│ │ └─BatchNorm2d: 3-48 [32, 96, 35, 35] 192
│ └─BasicConv2d: 2-35 [32, 96, 17, 17] --
│ │ └─Conv2d: 3-49 [32, 96, 17, 17] 82,944
│ │ └─BatchNorm2d: 3-50 [32, 96, 17, 17] 192
├─InceptionB: 1-10 [32, 768, 17, 17] --
│ └─BasicConv2d: 2-36 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-51 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-52 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-37 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-53 [32, 128, 17, 17] 98,304
│ │ └─BatchNorm2d: 3-54 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-38 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-55 [32, 128, 17, 17] 114,688
│ │ └─BatchNorm2d: 3-56 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-39 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-57 [32, 192, 17, 17] 172,032
│ │ └─BatchNorm2d: 3-58 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-40 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-59 [32, 128, 17, 17] 98,304
│ │ └─BatchNorm2d: 3-60 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-41 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-61 [32, 128, 17, 17] 114,688
│ │ └─BatchNorm2d: 3-62 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-42 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-63 [32, 128, 17, 17] 114,688
│ │ └─BatchNorm2d: 3-64 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-43 [32, 128, 17, 17] --
│ │ └─Conv2d: 3-65 [32, 128, 17, 17] 114,688
│ │ └─BatchNorm2d: 3-66 [32, 128, 17, 17] 256
│ └─BasicConv2d: 2-44 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-67 [32, 192, 17, 17] 172,032
│ │ └─BatchNorm2d: 3-68 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-45 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-69 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-70 [32, 192, 17, 17] 384
├─InceptionB: 1-11 [32, 768, 17, 17] --
│ └─BasicConv2d: 2-46 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-71 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-72 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-47 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-73 [32, 160, 17, 17] 122,880
│ │ └─BatchNorm2d: 3-74 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-48 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-75 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-76 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-49 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-77 [32, 192, 17, 17] 215,040
│ │ └─BatchNorm2d: 3-78 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-50 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-79 [32, 160, 17, 17] 122,880
│ │ └─BatchNorm2d: 3-80 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-51 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-81 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-82 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-52 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-83 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-84 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-53 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-85 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-86 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-54 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-87 [32, 192, 17, 17] 215,040
│ │ └─BatchNorm2d: 3-88 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-55 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-89 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-90 [32, 192, 17, 17] 384
├─InceptionB: 1-12 [32, 768, 17, 17] --
│ └─BasicConv2d: 2-56 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-91 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-92 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-57 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-93 [32, 160, 17, 17] 122,880
│ │ └─BatchNorm2d: 3-94 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-58 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-95 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-96 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-59 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-97 [32, 192, 17, 17] 215,040
│ │ └─BatchNorm2d: 3-98 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-60 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-99 [32, 160, 17, 17] 122,880
│ │ └─BatchNorm2d: 3-100 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-61 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-101 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-102 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-62 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-103 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-104 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-63 [32, 160, 17, 17] --
│ │ └─Conv2d: 3-105 [32, 160, 17, 17] 179,200
│ │ └─BatchNorm2d: 3-106 [32, 160, 17, 17] 320
│ └─BasicConv2d: 2-64 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-107 [32, 192, 17, 17] 215,040
│ │ └─BatchNorm2d: 3-108 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-65 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-109 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-110 [32, 192, 17, 17] 384
├─InceptionB: 1-13 [32, 768, 17, 17] --
│ └─BasicConv2d: 2-66 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-111 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-112 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-67 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-113 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-114 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-68 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-115 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-116 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-69 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-117 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-118 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-70 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-119 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-120 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-71 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-121 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-122 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-72 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-123 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-124 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-73 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-125 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-126 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-74 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-127 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-128 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-75 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-129 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-130 [32, 192, 17, 17] 384
├─ReductionB: 1-14 [32, 1280, 8, 8] --
│ └─BasicConv2d: 2-76 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-131 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-132 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-77 [32, 320, 8, 8] --
│ │ └─Conv2d: 3-133 [32, 320, 8, 8] 552,960
│ │ └─BatchNorm2d: 3-134 [32, 320, 8, 8] 640
│ └─BasicConv2d: 2-78 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-135 [32, 192, 17, 17] 147,456
│ │ └─BatchNorm2d: 3-136 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-79 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-137 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-138 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-80 [32, 192, 17, 17] --
│ │ └─Conv2d: 3-139 [32, 192, 17, 17] 258,048
│ │ └─BatchNorm2d: 3-140 [32, 192, 17, 17] 384
│ └─BasicConv2d: 2-81 [32, 192, 8, 8] --
│ │ └─Conv2d: 3-141 [32, 192, 8, 8] 331,776
│ │ └─BatchNorm2d: 3-142 [32, 192, 8, 8] 384
├─InceptionC: 1-15 [32, 2048, 8, 8] --
│ └─BasicConv2d: 2-82 [32, 320, 8, 8] --
│ │ └─Conv2d: 3-143 [32, 320, 8, 8] 409,600
│ │ └─BatchNorm2d: 3-144 [32, 320, 8, 8] 640
│ └─BasicConv2d: 2-83 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-145 [32, 384, 8, 8] 491,520
│ │ └─BatchNorm2d: 3-146 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-84 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-147 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-148 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-85 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-149 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-150 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-86 [32, 448, 8, 8] --
│ │ └─Conv2d: 3-151 [32, 448, 8, 8] 573,440
│ │ └─BatchNorm2d: 3-152 [32, 448, 8, 8] 896
│ └─BasicConv2d: 2-87 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-153 [32, 384, 8, 8] 1,548,288
│ │ └─BatchNorm2d: 3-154 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-88 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-155 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-156 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-89 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-157 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-158 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-90 [32, 192, 8, 8] --
│ │ └─Conv2d: 3-159 [32, 192, 8, 8] 245,760
│ │ └─BatchNorm2d: 3-160 [32, 192, 8, 8] 384
├─InceptionC: 1-16 [32, 2048, 8, 8] --
│ └─BasicConv2d: 2-91 [32, 320, 8, 8] --
│ │ └─Conv2d: 3-161 [32, 320, 8, 8] 655,360
│ │ └─BatchNorm2d: 3-162 [32, 320, 8, 8] 640
│ └─BasicConv2d: 2-92 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-163 [32, 384, 8, 8] 786,432
│ │ └─BatchNorm2d: 3-164 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-93 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-165 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-166 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-94 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-167 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-168 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-95 [32, 448, 8, 8] --
│ │ └─Conv2d: 3-169 [32, 448, 8, 8] 917,504
│ │ └─BatchNorm2d: 3-170 [32, 448, 8, 8] 896
│ └─BasicConv2d: 2-96 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-171 [32, 384, 8, 8] 1,548,288
│ │ └─BatchNorm2d: 3-172 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-97 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-173 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-174 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-98 [32, 384, 8, 8] --
│ │ └─Conv2d: 3-175 [32, 384, 8, 8] 442,368
│ │ └─BatchNorm2d: 3-176 [32, 384, 8, 8] 768
│ └─BasicConv2d: 2-99 [32, 192, 8, 8] --
│ │ └─Conv2d: 3-177 [32, 192, 8, 8] 393,216
│ │ └─BatchNorm2d: 3-178 [32, 192, 8, 8] 384
├─Linear: 1-17 [32, 3] 6,147
==========================================================================================
Total params: 21,791,715
Trainable params: 21,791,715
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 182.76
==========================================================================================
Input size (MB): 34.33
Forward/backward pass size (MB): 4591.35
Params size (MB): 87.17
Estimated Total Size (MB): 4712.85
==========================================================================================
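The parameter column follows directly from the layer shapes: a bias-free `Conv2d` has `c_in * c_out * kh * kw` weights, `BatchNorm2d` has 2 * C trainable parameters (its running mean and variance are buffers, not parameters), and the final `Linear` has `2048 * num_classes + num_classes`. A few rows checked by hand (the helper name is hypothetical):

```python
def conv_bn_params(c_in, c_out, kh, kw):
    """(Conv2d weights, BatchNorm2d weight + bias) for one BasicConv2d."""
    return c_in * c_out * kh * kw, 2 * c_out


print(conv_bn_params(3, 32, 3, 3))     # (864, 64)      Conv2d_1a_3x3
print(conv_bn_params(80, 192, 3, 3))   # (138240, 384)  Conv2d_4a_3x3
print(conv_bn_params(128, 128, 1, 7))  # (114688, 256)  1x7 conv in Mixed_6b
print(2048 * 3 + 3)                    # 6147           final Linear with 3 classes
```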
Results
Training log
Epoch: 1, Train_acc: 61.1%, Train_loss: 0.737, Test_acc: 68.5%, Test_loss:0.653
Epoch: 2, Train_acc: 61.2%, Train_loss: 0.681, Test_acc: 66.9%, Test_loss:0.724
Epoch: 3, Train_acc: 63.3%, Train_loss: 0.665, Test_acc: 70.4%, Test_loss:0.573
Epoch: 4, Train_acc: 66.4%, Train_loss: 0.626, Test_acc: 66.2%, Test_loss:0.942
Epoch: 5, Train_acc: 63.3%, Train_loss: 0.656, Test_acc: 73.2%, Test_loss:0.584
Epoch: 6, Train_acc: 66.1%, Train_loss: 0.630, Test_acc: 70.9%, Test_loss:0.595
Epoch: 7, Train_acc: 65.4%, Train_loss: 0.640, Test_acc: 73.2%, Test_loss:0.553
Epoch: 8, Train_acc: 66.0%, Train_loss: 0.616, Test_acc: 69.7%, Test_loss:0.562
Epoch: 9, Train_acc: 66.4%, Train_loss: 0.622, Test_acc: 71.8%, Test_loss:0.576
Epoch: 10, Train_acc: 67.8%, Train_loss: 0.599, Test_acc: 71.6%, Test_loss:0.595
Epoch: 11, Train_acc: 66.0%, Train_loss: 0.603, Test_acc: 66.2%, Test_loss:0.604
Epoch: 12, Train_acc: 67.4%, Train_loss: 0.594, Test_acc: 69.5%, Test_loss:0.589
Epoch: 13, Train_acc: 69.1%, Train_loss: 0.579, Test_acc: 74.6%, Test_loss:0.528
Epoch: 14, Train_acc: 71.1%, Train_loss: 0.567, Test_acc: 66.9%, Test_loss:0.603
Epoch: 15, Train_acc: 71.9%, Train_loss: 0.557, Test_acc: 71.8%, Test_loss:0.553
Epoch: 16, Train_acc: 71.5%, Train_loss: 0.555, Test_acc: 77.2%, Test_loss:0.505
Epoch: 17, Train_acc: 74.4%, Train_loss: 0.504, Test_acc: 76.7%, Test_loss:0.499
Epoch: 18, Train_acc: 76.9%, Train_loss: 0.504, Test_acc: 78.6%, Test_loss:0.494
Epoch: 19, Train_acc: 75.1%, Train_loss: 0.500, Test_acc: 73.0%, Test_loss:0.530
Epoch: 20, Train_acc: 77.2%, Train_loss: 0.488, Test_acc: 80.4%, Test_loss:0.443
Epoch: 21, Train_acc: 81.1%, Train_loss: 0.430, Test_acc: 76.9%, Test_loss:0.439
Epoch: 22, Train_acc: 82.3%, Train_loss: 0.410, Test_acc: 81.1%, Test_loss:0.464
Epoch: 23, Train_acc: 84.5%, Train_loss: 0.368, Test_acc: 85.3%, Test_loss:0.384
Epoch: 24, Train_acc: 84.6%, Train_loss: 0.369, Test_acc: 88.8%, Test_loss:0.331
Epoch: 25, Train_acc: 87.0%, Train_loss: 0.320, Test_acc: 90.0%, Test_loss:0.255
Epoch: 26, Train_acc: 86.6%, Train_loss: 0.327, Test_acc: 88.1%, Test_loss:0.272
Epoch: 27, Train_acc: 87.4%, Train_loss: 0.322, Test_acc: 87.2%, Test_loss:0.309
Epoch: 28, Train_acc: 89.0%, Train_loss: 0.284, Test_acc: 88.3%, Test_loss:0.334
Epoch: 29, Train_acc: 88.4%, Train_loss: 0.275, Test_acc: 89.7%, Test_loss:0.259
Epoch: 30, Train_acc: 88.3%, Train_loss: 0.285, Test_acc: 90.2%, Test_loss:0.260
Epoch: 31, Train_acc: 91.7%, Train_loss: 0.221, Test_acc: 88.6%, Test_loss:0.280
Epoch: 32, Train_acc: 91.2%, Train_loss: 0.226, Test_acc: 91.4%, Test_loss:0.252
Epoch: 33, Train_acc: 91.8%, Train_loss: 0.204, Test_acc: 90.9%, Test_loss:0.229
Epoch: 34, Train_acc: 89.2%, Train_loss: 0.262, Test_acc: 88.3%, Test_loss:0.366
Epoch: 35, Train_acc: 91.0%, Train_loss: 0.204, Test_acc: 86.5%, Test_loss:0.349
Epoch: 36, Train_acc: 92.0%, Train_loss: 0.212, Test_acc: 90.9%, Test_loss:0.275
Epoch: 37, Train_acc: 94.2%, Train_loss: 0.159, Test_acc: 92.3%, Test_loss:0.244
Epoch: 38, Train_acc: 94.7%, Train_loss: 0.141, Test_acc: 89.7%, Test_loss:0.273
Epoch: 39, Train_acc: 93.3%, Train_loss: 0.166, Test_acc: 90.7%, Test_loss:0.255
Epoch: 40, Train_acc: 94.2%, Train_loss: 0.157, Test_acc: 90.4%, Test_loss:0.233
Epoch: 41, Train_acc: 96.0%, Train_loss: 0.121, Test_acc: 92.8%, Test_loss:0.244
Epoch: 42, Train_acc: 94.5%, Train_loss: 0.143, Test_acc: 88.1%, Test_loss:0.389
Epoch: 43, Train_acc: 94.6%, Train_loss: 0.143, Test_acc: 92.1%, Test_loss:0.241
Epoch: 44, Train_acc: 94.2%, Train_loss: 0.146, Test_acc: 92.5%, Test_loss:0.226
Epoch: 45, Train_acc: 95.4%, Train_loss: 0.122, Test_acc: 86.0%, Test_loss:0.383
Epoch: 46, Train_acc: 94.0%, Train_loss: 0.158, Test_acc: 92.8%, Test_loss:0.239
Epoch: 47, Train_acc: 95.1%, Train_loss: 0.137, Test_acc: 89.5%, Test_loss:0.320
Epoch: 48, Train_acc: 97.0%, Train_loss: 0.086, Test_acc: 92.5%, Test_loss:0.274
Epoch: 49, Train_acc: 97.3%, Train_loss: 0.078, Test_acc: 87.9%, Test_loss:0.349
Epoch: 50, Train_acc: 95.0%, Train_loss: 0.137, Test_acc: 92.3%, Test_loss:0.239
Done
Training curves
Summary and takeaways
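Comparing InceptionV1 and InceptionV3, the V1 architecture is quite simple while V3 is considerably more complex. Even so, InceptionV3 performs only slightly better than InceptionV1 here; the improvement is not very pronounced.
Studying InceptionV3, I came across two structures I had not encountered before:
- Replacing an nxn convolution with 1xn and nx1 convolutions (in series or in parallel), which greatly reduces parameters and computation with little loss in accuracy.
- The auxiliary classifier: during training its weighted loss is added to the main classifier's loss, which sends extra gradient to the shallower modules so they also train well. Like the Residual module, it is one approach to the difficulty of training deep networks.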