pytorch【GoogLeNet v1复现】

这是我们的架构图:

在这里插入图片描述

导入相关的包

import torch
from torch import nn
from torchinfo import summary

定义我们基础的卷积层

我们的每一个基础卷积单元都是先经过一个卷积层,再经过BN层,最后用ReLU函数进行激活。

class BassicConv2d(nn.Module):
    """Basic GoogLeNet conv unit: Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``). ``bias`` is always fixed to ``False``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv bias is dropped because the following BatchNorm2d has its
        # own learnable shift; every other Conv2d option is passed through.
        layers = [
            nn.Conv2d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
# Smoke test: build one BassicConv2d unit and let the notebook display its repr.
BassicConv2d(in_channels=2, out_channels=10, kernel_size=3)
BassicConv2d(
  (conv): Sequential(
    (0): Conv2d(2, 10, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (1): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
)

定义inception层

在这里插入图片描述
这是inception层的具体定义,它由四个并行的分支构成。
在这里插入图片描述

class Inception(nn.Module):
    """GoogLeNet Inception module: four parallel branches concatenated on channels.

    Every branch preserves height and width, so the output has the input's
    spatial size and ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        ch1x1: output channels of the 1x1 branch.
        ch3x3red: 1x1 reduction channels before the 3x3 convolution.
        ch3x3: output channels of the 3x3 branch.
        ch5x5red: 1x1 reduction channels before the 5x5 convolution.
        ch5x5: output channels of the 5x5 branch.
        pool_proj: output channels of the 1x1 projection after max pooling.
    """

    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int,
    ):
        super().__init__()
        # Branch 1: a single 1x1 convolution.
        self.branch1 = BassicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 reduction, then 3x3 conv (padding=1 keeps H and W).
        reduce3 = BassicConv2d(in_channels, ch3x3red, kernel_size=1)
        conv3 = BassicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce3, conv3)

        # Branch 3: 1x1 reduction, then 5x5 conv (padding=2 keeps H and W).
        reduce5 = BassicConv2d(in_channels, ch5x5red, kernel_size=1)
        conv5 = BassicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce5, conv5)

        # Branch 4: stride-1 3x3 max pool (size preserving), then 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True)
        proj = BassicConv2d(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, proj)

    def forward(self, x):
        # Run the four branches on the same input (in order 1..4) and
        # concatenate the results along the channel dimension (dim=1).
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)
# Quick check: inception 3a keeps 28x28 and outputs 64+128+32+32 = 256 channels.
in3a = Inception(
    in_channels=192, ch1x1=64, ch3x3red=96, ch3x3=128,
    ch5x5red=16, ch5x5=32, pool_proj=32,
)
data = torch.ones((10, 192, 28, 28))
in3a(data).size()
torch.Size([10, 256, 28, 28])

定义辅助分类器

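| 层 | 核尺寸/步长 | 输入为4a时的输出尺寸 | 输入为4d时的输出尺寸 |
| --- | --- | --- | --- |
| 平均池化层 | 5x5/3 | 4x4x512 | 4x4x528 |
| 卷积层+ReLU | 1x1/1 | 4x4x128 | 4x4x128 |
| 全连接层+ReLU | — | 1024 | 1024 |
| Dropout(70%) | — | 1024 | 1024 |
| 全连接层+Softmax | — | 1000 | 1000 |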
# auxiliary classifier
class AuxClf(nn.Module):
    """Auxiliary classifier head attached to an intermediate GoogLeNet feature map.

    Pipeline: 5x5/3 average pooling -> 1x1 ``BassicConv2d`` to 128 channels ->
    FC(2048 -> 1024) + ReLU -> Dropout(0.7) -> FC(1024 -> ``num_classes``).

    The first FC layer is sized for a 4x4x128 pooled map. A 14x14 input
    produces exactly that, which is the size of the inception 4a / 4d output
    for a 224x224 image.

    Args:
        in_channels: channels of the feature map this head is attached to.
        num_classes: number of output scores.
        **kwargs: accepted for call-site compatibility; not used.

    ``forward`` returns unnormalised scores of shape ``(batch, num_classes)``;
    no softmax is applied here.
    """

    def __init__(self, in_channels: int, num_classes: int, **kwargs):
        super().__init__()
        self.feature_ = nn.Sequential(
            nn.AvgPool2d(kernel_size=5, stride=3),  # 14x14 -> 4x4
            BassicConv2d(in_channels, 128, kernel_size=1),
        )

        self.clf_ = nn.Sequential(
            nn.Linear(4 * 4 * 128, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.7),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        x = self.feature_(x)
        # Flatten each sample on its own (keep dim 0 as the batch). A
        # view(-1, 4*4*128) would silently fold extra spatial positions into
        # the batch when the input is not 14x14 (e.g. 26x26 -> 8x8x128 =
        # 4 rows of 2048 per sample); with flatten, a wrong input size makes
        # the first Linear raise instead.
        x = torch.flatten(x, 1)
        x = self.clf_(x)
        return x
# Auxiliary classifier placed after inception 4a (512 input channels).
AuxClf(in_channels=512, num_classes=1000)
AuxClf(
  (feature_): Sequential(
    (0): AvgPool2d(kernel_size=5, stride=3, padding=0)
    (1): BassicConv2d(
      (conv): Sequential(
        (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
      )
    )
  )
  (clf_): Sequential(
    (0): Linear(in_features=2048, out_features=1024, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.7, inplace=False)
    (3): Linear(in_features=1024, out_features=1000, bias=True)
  )
)

定义完整的GoogLeNet网络

在这里插入图片描述

class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) with BatchNorm conv units and two auxiliary heads.

    Args:
        num_classes: output size of the main and both auxiliary classifiers.
        blocks: optional ``[conv_block, inception_block, aux_clf_block]`` used
            to build the network; defaults to
            ``[BassicConv2d, Inception, AuxClf]``.

    ``forward`` always returns ``(logits, aux2, aux1)``: the main classifier
    output, the auxiliary head after inception 4d, and the auxiliary head
    after inception 4a. The auxiliary heads are evaluated in both training
    and eval mode.
    NOTE(review): the paper discards the auxiliary heads at inference; callers
    may want to ignore ``aux2``/``aux1`` in eval mode.
    """

    def __init__(self, num_classes: int = 1000, blocks=None):
        super().__init__()

        if blocks is None:
            blocks = [BassicConv2d, Inception, AuxClf]
        conv_block, inception_block, aux_clf_block = blocks[0], blocks[1], blocks[2]

        # Submodules are registered in this exact order so that
        # parameters() / state_dict() ordering stays stable.
        # Spatial sizes in the comments assume a 224x224 input.

        # Stage 1: 7x7/2 conv (224 -> 112), then 3x3/2 max pool.
        # ceil_mode=True rounds (112 - 3) / 2 + 1 = 55.5 up to 56.
        self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 2: 1x1 conv, 3x3 conv, then max pool 56 -> 28.
        self.conv2 = conv_block(64, 64, kernel_size=1)
        self.conv3 = conv_block(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 3 (28x28): inception 3a, 3b, then max pool 28 -> 14.
        self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 4 (14x14): inception 4a-4e, then max pool 14 -> 7.
        self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)

        # Stage 5 (7x7): inception 5a, 5b.
        self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)

        # Main head: global average pool down to 1x1 -> dropout -> FC.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

        # Auxiliary heads: after 4a (512 channels) and after 4d (528 channels).
        self.aux1 = aux_clf_block(512, num_classes)
        self.aux2 = aux_clf_block(528, num_classes)

    def forward(self, x):
        # Stem, stage 3 and inception 4a, applied strictly in sequence.
        for layer in (
            self.conv1, self.maxpool1,
            self.conv2, self.conv3, self.maxpool2,
            self.inception3a, self.inception3b, self.maxpool3,
            self.inception4a,
        ):
            x = layer(x)
        aux1 = self.aux1(x)  # auxiliary head on the inception 4a output

        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x)  # auxiliary head on the inception 4d output

        for layer in (
            self.inception4e, self.maxpool4,
            self.inception5a, self.inception5b,
            self.avgpool,  # feature map becomes 1x1 here
        ):
            x = layer(x)

        logits = self.fc(self.dropout(torch.flatten(x, 1)))
        return logits, aux2, aux1
# Forward a dummy batch through the full network and print every output shape
# (main logits, then the 4d and 4a auxiliary heads).
data = torch.ones(10, 3, 224, 224)
net = GoogLeNet(num_classes=1000)
fc2, fc1, fc0 = net(data)
for logits in (fc2, fc1, fc0):
    print(logits.shape)
torch.Size([10, 1000])
torch.Size([10, 1000])
torch.Size([10, 1000])
summary(net, input_size=(10, 3, 224, 224), device="cpu")
===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
GoogLeNet                                     [10, 1000]                --
├─BassicConv2d: 1-1                           [10, 64, 112, 112]        --
│    └─Sequential: 2-1                        [10, 64, 112, 112]        --
│    │    └─Conv2d: 3-1                       [10, 64, 112, 112]        9,408
│    │    └─BatchNorm2d: 3-2                  [10, 64, 112, 112]        128
│    │    └─ReLU: 3-3                         [10, 64, 112, 112]        --
├─MaxPool2d: 1-2                              [10, 64, 56, 56]          --
├─BassicConv2d: 1-3                           [10, 64, 56, 56]          --
│    └─Sequential: 2-2                        [10, 64, 56, 56]          --
│    │    └─Conv2d: 3-4                       [10, 64, 56, 56]          4,096
│    │    └─BatchNorm2d: 3-5                  [10, 64, 56, 56]          128
│    │    └─ReLU: 3-6                         [10, 64, 56, 56]          --
├─BassicConv2d: 1-4                           [10, 192, 56, 56]         --
│    └─Sequential: 2-3                        [10, 192, 56, 56]         --
│    │    └─Conv2d: 3-7                       [10, 192, 56, 56]         110,592
│    │    └─BatchNorm2d: 3-8                  [10, 192, 56, 56]         384
│    │    └─ReLU: 3-9                         [10, 192, 56, 56]         --
├─MaxPool2d: 1-5                              [10, 192, 28, 28]         --
├─Inception: 1-6                              [10, 256, 28, 28]         --
│    └─BassicConv2d: 2-4                      [10, 64, 28, 28]          --
│    │    └─Sequential: 3-10                  [10, 64, 28, 28]          12,416
│    └─Sequential: 2-5                        [10, 128, 28, 28]         --
│    │    └─BassicConv2d: 3-11                [10, 96, 28, 28]          18,624
│    │    └─BassicConv2d: 3-12                [10, 128, 28, 28]         110,848
│    └─Sequential: 2-6                        [10, 32, 28, 28]          --
│    │    └─BassicConv2d: 3-13                [10, 16, 28, 28]          3,104
│    │    └─BassicConv2d: 3-14                [10, 32, 28, 28]          12,864
│    └─Sequential: 2-7                        [10, 32, 28, 28]          --
│    │    └─MaxPool2d: 3-15                   [10, 192, 28, 28]         --
│    │    └─BassicConv2d: 3-16                [10, 32, 28, 28]          6,208
├─Inception: 1-7                              [10, 480, 28, 28]         --
│    └─BassicConv2d: 2-8                      [10, 128, 28, 28]         --
│    │    └─Sequential: 3-17                  [10, 128, 28, 28]         33,024
│    └─Sequential: 2-9                        [10, 192, 28, 28]         --
│    │    └─BassicConv2d: 3-18                [10, 128, 28, 28]         33,024
│    │    └─BassicConv2d: 3-19                [10, 192, 28, 28]         221,568
│    └─Sequential: 2-10                       [10, 96, 28, 28]          --
│    │    └─BassicConv2d: 3-20                [10, 32, 28, 28]          8,256
│    │    └─BassicConv2d: 3-21                [10, 96, 28, 28]          76,992
│    └─Sequential: 2-11                       [10, 64, 28, 28]          --
│    │    └─MaxPool2d: 3-22                   [10, 256, 28, 28]         --
│    │    └─BassicConv2d: 3-23                [10, 64, 28, 28]          16,512
├─MaxPool2d: 1-8                              [10, 480, 14, 14]         --
├─Inception: 1-9                              [10, 512, 14, 14]         --
│    └─BassicConv2d: 2-12                     [10, 192, 14, 14]         --
│    │    └─Sequential: 3-24                  [10, 192, 14, 14]         92,544
│    └─Sequential: 2-13                       [10, 208, 14, 14]         --
│    │    └─BassicConv2d: 3-25                [10, 96, 14, 14]          46,272
│    │    └─BassicConv2d: 3-26                [10, 208, 14, 14]         180,128
│    └─Sequential: 2-14                       [10, 48, 14, 14]          --
│    │    └─BassicConv2d: 3-27                [10, 16, 14, 14]          7,712
│    │    └─BassicConv2d: 3-28                [10, 48, 14, 14]          19,296
│    └─Sequential: 2-15                       [10, 64, 14, 14]          --
│    │    └─MaxPool2d: 3-29                   [10, 480, 14, 14]         --
│    │    └─BassicConv2d: 3-30                [10, 64, 14, 14]          30,848
├─AuxClf: 1-10                                [10, 1000]                --
│    └─Sequential: 2-16                       [10, 128, 4, 4]           --
│    │    └─AvgPool2d: 3-31                   [10, 512, 4, 4]           --
│    │    └─BassicConv2d: 3-32                [10, 128, 4, 4]           65,792
│    └─Sequential: 2-17                       [10, 1000]                --
│    │    └─Linear: 3-33                      [10, 1024]                2,098,176
│    │    └─ReLU: 3-34                        [10, 1024]                --
│    │    └─Dropout: 3-35                     [10, 1024]                --
│    │    └─Linear: 3-36                      [10, 1000]                1,025,000
├─Inception: 1-11                             [10, 512, 14, 14]         --
│    └─BassicConv2d: 2-18                     [10, 160, 14, 14]         --
│    │    └─Sequential: 3-37                  [10, 160, 14, 14]         82,240
│    └─Sequential: 2-19                       [10, 224, 14, 14]         --
│    │    └─BassicConv2d: 3-38                [10, 112, 14, 14]         57,568
│    │    └─BassicConv2d: 3-39                [10, 224, 14, 14]         226,240
│    └─Sequential: 2-20                       [10, 64, 14, 14]          --
│    │    └─BassicConv2d: 3-40                [10, 24, 14, 14]          12,336
│    │    └─BassicConv2d: 3-41                [10, 64, 14, 14]          38,528
│    └─Sequential: 2-21                       [10, 64, 14, 14]          --
│    │    └─MaxPool2d: 3-42                   [10, 512, 14, 14]         --
│    │    └─BassicConv2d: 3-43                [10, 64, 14, 14]          32,896
├─Inception: 1-12                             [10, 512, 14, 14]         --
│    └─BassicConv2d: 2-22                     [10, 128, 14, 14]         --
│    │    └─Sequential: 3-44                  [10, 128, 14, 14]         65,792
│    └─Sequential: 2-23                       [10, 256, 14, 14]         --
│    │    └─BassicConv2d: 3-45                [10, 128, 14, 14]         65,792
│    │    └─BassicConv2d: 3-46                [10, 256, 14, 14]         295,424
│    └─Sequential: 2-24                       [10, 64, 14, 14]          --
│    │    └─BassicConv2d: 3-47                [10, 24, 14, 14]          12,336
│    │    └─BassicConv2d: 3-48                [10, 64, 14, 14]          38,528
│    └─Sequential: 2-25                       [10, 64, 14, 14]          --
│    │    └─MaxPool2d: 3-49                   [10, 512, 14, 14]         --
│    │    └─BassicConv2d: 3-50                [10, 64, 14, 14]          32,896
├─Inception: 1-13                             [10, 528, 14, 14]         --
│    └─BassicConv2d: 2-26                     [10, 112, 14, 14]         --
│    │    └─Sequential: 3-51                  [10, 112, 14, 14]         57,568
│    └─Sequential: 2-27                       [10, 288, 14, 14]         --
│    │    └─BassicConv2d: 3-52                [10, 144, 14, 14]         74,016
│    │    └─BassicConv2d: 3-53                [10, 288, 14, 14]         373,824
│    └─Sequential: 2-28                       [10, 64, 14, 14]          --
│    │    └─BassicConv2d: 3-54                [10, 32, 14, 14]          16,448
│    │    └─BassicConv2d: 3-55                [10, 64, 14, 14]          51,328
│    └─Sequential: 2-29                       [10, 64, 14, 14]          --
│    │    └─MaxPool2d: 3-56                   [10, 512, 14, 14]         --
│    │    └─BassicConv2d: 3-57                [10, 64, 14, 14]          32,896
├─AuxClf: 1-14                                [10, 1000]                --
│    └─Sequential: 2-30                       [10, 128, 4, 4]           --
│    │    └─AvgPool2d: 3-58                   [10, 528, 4, 4]           --
│    │    └─BassicConv2d: 3-59                [10, 128, 4, 4]           67,840
│    └─Sequential: 2-31                       [10, 1000]                --
│    │    └─Linear: 3-60                      [10, 1024]                2,098,176
│    │    └─ReLU: 3-61                        [10, 1024]                --
│    │    └─Dropout: 3-62                     [10, 1024]                --
│    │    └─Linear: 3-63                      [10, 1000]                1,025,000
├─Inception: 1-15                             [10, 832, 14, 14]         --
│    └─BassicConv2d: 2-32                     [10, 256, 14, 14]         --
│    │    └─Sequential: 3-64                  [10, 256, 14, 14]         135,680
│    └─Sequential: 2-33                       [10, 320, 14, 14]         --
│    │    └─BassicConv2d: 3-65                [10, 160, 14, 14]         84,800
│    │    └─BassicConv2d: 3-66                [10, 320, 14, 14]         461,440
│    └─Sequential: 2-34                       [10, 128, 14, 14]         --
│    │    └─BassicConv2d: 3-67                [10, 32, 14, 14]          16,960
│    │    └─BassicConv2d: 3-68                [10, 128, 14, 14]         102,656
│    └─Sequential: 2-35                       [10, 128, 14, 14]         --
│    │    └─MaxPool2d: 3-69                   [10, 528, 14, 14]         --
│    │    └─BassicConv2d: 3-70                [10, 128, 14, 14]         67,840
├─MaxPool2d: 1-16                             [10, 832, 7, 7]           --
├─Inception: 1-17                             [10, 832, 7, 7]           --
│    └─BassicConv2d: 2-36                     [10, 256, 7, 7]           --
│    │    └─Sequential: 3-71                  [10, 256, 7, 7]           213,504
│    └─Sequential: 2-37                       [10, 320, 7, 7]           --
│    │    └─BassicConv2d: 3-72                [10, 160, 7, 7]           133,440
│    │    └─BassicConv2d: 3-73                [10, 320, 7, 7]           461,440
│    └─Sequential: 2-38                       [10, 128, 7, 7]           --
│    │    └─BassicConv2d: 3-74                [10, 32, 7, 7]            26,688
│    │    └─BassicConv2d: 3-75                [10, 128, 7, 7]           102,656
│    └─Sequential: 2-39                       [10, 128, 7, 7]           --
│    │    └─MaxPool2d: 3-76                   [10, 832, 7, 7]           --
│    │    └─BassicConv2d: 3-77                [10, 128, 7, 7]           106,752
├─Inception: 1-18                             [10, 1024, 7, 7]          --
│    └─BassicConv2d: 2-40                     [10, 384, 7, 7]           --
│    │    └─Sequential: 3-78                  [10, 384, 7, 7]           320,256
│    └─Sequential: 2-41                       [10, 384, 7, 7]           --
│    │    └─BassicConv2d: 3-79                [10, 192, 7, 7]           160,128
│    │    └─BassicConv2d: 3-80                [10, 384, 7, 7]           664,320
│    └─Sequential: 2-42                       [10, 128, 7, 7]           --
│    │    └─BassicConv2d: 3-81                [10, 48, 7, 7]            40,032
│    │    └─BassicConv2d: 3-82                [10, 128, 7, 7]           153,856
│    └─Sequential: 2-43                       [10, 128, 7, 7]           --
│    │    └─MaxPool2d: 3-83                   [10, 832, 7, 7]           --
│    │    └─BassicConv2d: 3-84                [10, 128, 7, 7]           106,752
├─AdaptiveAvgPool2d: 1-19                     [10, 1024, 1, 1]          --
├─Dropout: 1-20                               [10, 1024]                --
├─Linear: 1-21                                [10, 1000]                1,025,000
===============================================================================================
Total params: 13,385,816
Trainable params: 13,385,816
Non-trainable params: 0
Total mult-adds (G): 15.91
===============================================================================================
Input size (MB): 6.02
Forward/backward pass size (MB): 517.24
Params size (MB): 53.54
Estimated Total Size (MB): 576.81
===============================================================================================

当然,我们也可以只查看GoogLeNet最外层的结构,而不展开各模块的内部细节(设置 depth=1)。

# Show only the top-level children (depth=1). Use the CPU, as the full summary
# call does: the "mps" device exists only on macOS builds of PyTorch, and
# torchinfo places the model on the requested device.
summary(net, (10, 3, 224, 224), device="cpu", depth=1)
===============================================================================================
Layer (type:depth-idx)                        Output Shape              Param #
===============================================================================================
GoogLeNet                                     [10, 1000]                --
├─BassicConv2d: 1-1                           [10, 64, 112, 112]        9,536
├─MaxPool2d: 1-2                              [10, 64, 56, 56]          --
├─BassicConv2d: 1-3                           [10, 64, 56, 56]          4,224
├─BassicConv2d: 1-4                           [10, 192, 56, 56]         110,976
├─MaxPool2d: 1-5                              [10, 192, 28, 28]         --
├─Inception: 1-6                              [10, 256, 28, 28]         164,064
├─Inception: 1-7                              [10, 480, 28, 28]         389,376
├─MaxPool2d: 1-8                              [10, 480, 14, 14]         --
├─Inception: 1-9                              [10, 512, 14, 14]         376,800
├─AuxClf: 1-10                                [10, 1000]                3,188,968
├─Inception: 1-11                             [10, 512, 14, 14]         449,808
├─Inception: 1-12                             [10, 512, 14, 14]         510,768
├─Inception: 1-13                             [10, 528, 14, 14]         606,080
├─AuxClf: 1-14                                [10, 1000]                3,191,016
├─Inception: 1-15                             [10, 832, 14, 14]         869,376
├─MaxPool2d: 1-16                             [10, 832, 7, 7]           --
├─Inception: 1-17                             [10, 832, 7, 7]           1,044,480
├─Inception: 1-18                             [10, 1024, 7, 7]          1,445,344
├─AdaptiveAvgPool2d: 1-19                     [10, 1024, 1, 1]          --
├─Dropout: 1-20                               [10, 1024]                --
├─Linear: 1-21                                [10, 1000]                1,025,000
===============================================================================================
Total params: 13,385,816
Trainable params: 13,385,816
Non-trainable params: 0
Total mult-adds (G): 15.91
===============================================================================================
Input size (MB): 6.02
Forward/backward pass size (MB): 517.24
Params size (MB): 53.54
Estimated Total Size (MB): 576.81
===============================================================================================
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

桜キャンドル淵

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值