Implementing ResNet-50

The model structure follows the reference at https://zhuanlan.zhihu.com/p/353235794: a convolutional stem, four stages of bottleneck residual blocks (3, 4, 6 and 3 blocks), and a pooled classification head.


import torch
import torch.nn as nn
import torchvision
from torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class input_layer(nn.Module):  # stem: 7x7 conv, BN, ReLU, 3x3 max pool
    def __init__(self):
        super(input_layer, self).__init__()
        # padding=4 (instead of the canonical 3 plus a padded max pool) still
        # yields the expected 56x56 stem output: 224 -> 113 -> 56
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=4)
        self.BN1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.BN1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x
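
# Sanity check (sketch): the stem maps a 224x224 image to a 64x56x56 feature map:
#   x = torch.randn(1, 3, 224, 224)
#   print(input_layer()(x).shape)  # torch.Size([1, 64, 56, 56])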


class Resent(nn.Module):  # bottleneck residual block
    def __init__(self, input_inchannels, num_channels, conv_block=False, strides=1):  # num_channels = input_channels * 4
        super(Resent, self).__init__()
        self.input_channels = input_inchannels

        if conv_block:  # projection block: the shortcut is downsampled by a 1x1 conv
            # note: conv1 keeps the full input width instead of narrowing it,
            # which is wider than the canonical ResNet-50 bottleneck
            self.conv1 = nn.Conv2d(self.input_channels, input_inchannels, kernel_size=1, stride=strides)
            self.BN1 = nn.BatchNorm2d(input_inchannels)
            self.conv2 = nn.Conv2d(input_inchannels, input_inchannels, kernel_size=3, stride=1, padding=1)
            self.BN2 = nn.BatchNorm2d(input_inchannels)
            self.conv3 = nn.Conv2d(input_inchannels, num_channels, kernel_size=1, stride=1)
            self.BN3 = nn.BatchNorm2d(num_channels)
            self.relu1 = nn.ReLU()
            self.relu2 = nn.ReLU()
            self.relu3 = nn.ReLU()
            self.conv4 = nn.Conv2d(input_inchannels, num_channels,kernel_size=1 ,stride=strides)
            self.BN4 = nn.BatchNorm2d(num_channels)
        else:
            self.conv4 = None  # identity shortcut; both arguments give the block's output width

            # the 1x1 and 3x3 convs narrow the width to in/4 and in/16
            # (canonical ResNet-50 keeps the 3x3 conv at out/4, so these blocks are slimmer)
            output_center_channel = [self.input_channels // 4, self.input_channels // 16]
            self.conv1 = nn.Conv2d(self.input_channels, output_center_channel[0], kernel_size=1, stride=strides)
            self.BN1 = nn.BatchNorm2d(output_center_channel[0])
            self.conv2 = nn.Conv2d(output_center_channel[0], output_center_channel[1], kernel_size=3, stride=1, padding=1)
            self.BN2 = nn.BatchNorm2d(output_center_channel[1])
            self.conv3 = nn.Conv2d(output_center_channel[1], num_channels, kernel_size=1, stride=1)
            self.BN3 = nn.BatchNorm2d(num_channels)
            self.relu1 = nn.ReLU()
            self.relu2 = nn.ReLU()
            self.relu3 = nn.ReLU()

    def forward(self, x):
        Y = self.relu1(self.BN1(self.conv1(x)))
        Y = self.relu2(self.BN2(self.conv2(Y)))
        Y = self.BN3(self.conv3(Y))
        if self.conv4 is not None:  # projection shortcut
            x = self.BN4(self.conv4(x))

        return self.relu3(x + Y)
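
# Example (sketch): a projection block reshapes the shortcut; an identity block keeps it:
#   blk = Resent(256, 512, conv_block=True, strides=2)
#   print(blk(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 512, 28, 28])
#   blk = Resent(512, 512, conv_block=False)
#   print(blk(torch.randn(1, 512, 28, 28)).shape)  # torch.Size([1, 512, 28, 28])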


def Resent_num(input_channels, num_channels, block_nums, h_w_half=False, first_blocks=False):  # first_blocks controls whether the stage opens with a projection (conv_block) block
    block = []
    strides = 1
    if h_w_half:  # halve height/width with a stride-2 first block
        strides = 2
    for i in range(block_nums):
        if i == 0 and first_blocks:
            block.append(Resent(input_channels, num_channels, conv_block=True, strides=strides))
        else:
            block.append(Resent(num_channels, num_channels, conv_block=False))

    return block
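
# Example (sketch): a stage is one projection block plus identity blocks; the third
# ResNet-50 stage halves the spatial size and widens 256 -> 512 channels:
#   stage = nn.Sequential(*Resent_num(256, 512, 4, h_w_half=True, first_blocks=True))
#   print(stage(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 512, 28, 28])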

class output_layer(nn.Module):
    def __init__(self, num_class):
        super(output_layer, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(2048, num_class)  # num_class output logits

    def forward(self, x):
        x = self.fc(self.flatten(self.avgpool(x)))
        return x
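
# Sketch: AdaptiveAvgPool2d collapses any HxW to 1x1, so the head accepts the
# 2048-channel 7x7 map produced by the last stage:
#   print(output_layer(num_class=2)(torch.randn(1, 2048, 7, 7)).shape)  # torch.Size([1, 2])
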
class Resnet50(nn.Module):
    def __init__(self, num_class=10):
        super(Resnet50, self).__init__()

        b1 = input_layer()
        b2 = Resent_num(64, 256, 3, h_w_half=False, first_blocks=True)
        b3 = Resent_num(256, 512, 4, h_w_half=True, first_blocks=True)
        b4 = Resent_num(512, 1024, 6, h_w_half=True, first_blocks=True)
        b5 = Resent_num(1024, 2048, 3, h_w_half=True, first_blocks=True)
        b6 = output_layer(num_class=num_class)
        self.model = nn.Sequential(b1, *b2, *b3, *b4, *b5, b6)

    def forward(self, x):
        # per-layer shape check (uncomment to print each stage's output):
        # test = x
        # for layer in self.model:
        #     output = layer(test)
        #     test = output
        #     print(layer.__class__.__name__, output.shape)
        x = self.model(x)
        return x
if __name__ == '__main__':
    x = torch.randn(size=(1, 3, 224, 224), device=device)  # dummy 224x224 RGB batch

    wuze = Resnet50(num_class=2)
    wuze.to(device)
    # output = wuze(x)
    summary(wuze, (3, 224, 224))
    # print(wuze(x).shape)
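
As a cross-check, the parameter count can be compared against torchvision's reference ResNet-50. This is only a sketch, and the totals will not match exactly: the blocks above use bottleneck widths of in/4 and in/16 (and a full-width first block per stage), which deviates from the canonical out/4 design.

custom = Resnet50(num_class=2)
reference = torchvision.models.resnet50(num_classes=2)
print(sum(p.numel() for p in custom.parameters()))     # 26,026,050, matching the summary below
print(sum(p.numel() for p in reference.parameters()))  # smaller total with the canonical widths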

Inspecting the model with torchsummary:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 113, 113]           9,472
       BatchNorm2d-2         [-1, 64, 113, 113]             128
              ReLU-3         [-1, 64, 113, 113]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       input_layer-5           [-1, 64, 56, 56]               0
            Conv2d-6           [-1, 64, 56, 56]           4,160
       BatchNorm2d-7           [-1, 64, 56, 56]             128
              ReLU-8           [-1, 64, 56, 56]               0
            Conv2d-9           [-1, 64, 56, 56]          36,928
      BatchNorm2d-10           [-1, 64, 56, 56]             128
             ReLU-11           [-1, 64, 56, 56]               0
           Conv2d-12          [-1, 256, 56, 56]          16,640
      BatchNorm2d-13          [-1, 256, 56, 56]             512
           Conv2d-14          [-1, 256, 56, 56]          16,640
      BatchNorm2d-15          [-1, 256, 56, 56]             512
             ReLU-16          [-1, 256, 56, 56]               0
           Resent-17          [-1, 256, 56, 56]               0
           Conv2d-18           [-1, 64, 56, 56]          16,448
      BatchNorm2d-19           [-1, 64, 56, 56]             128
             ReLU-20           [-1, 64, 56, 56]               0
           Conv2d-21           [-1, 16, 56, 56]           9,232
      BatchNorm2d-22           [-1, 16, 56, 56]              32
             ReLU-23           [-1, 16, 56, 56]               0
           Conv2d-24          [-1, 256, 56, 56]           4,352
      BatchNorm2d-25          [-1, 256, 56, 56]             512
             ReLU-26          [-1, 256, 56, 56]               0
           Resent-27          [-1, 256, 56, 56]               0
           Conv2d-28           [-1, 64, 56, 56]          16,448
      BatchNorm2d-29           [-1, 64, 56, 56]             128
             ReLU-30           [-1, 64, 56, 56]               0
           Conv2d-31           [-1, 16, 56, 56]           9,232
      BatchNorm2d-32           [-1, 16, 56, 56]              32
             ReLU-33           [-1, 16, 56, 56]               0
           Conv2d-34          [-1, 256, 56, 56]           4,352
      BatchNorm2d-35          [-1, 256, 56, 56]             512
             ReLU-36          [-1, 256, 56, 56]               0
           Resent-37          [-1, 256, 56, 56]               0
           Conv2d-38          [-1, 256, 28, 28]          65,792
      BatchNorm2d-39          [-1, 256, 28, 28]             512
             ReLU-40          [-1, 256, 28, 28]               0
           Conv2d-41          [-1, 256, 28, 28]         590,080
      BatchNorm2d-42          [-1, 256, 28, 28]             512
             ReLU-43          [-1, 256, 28, 28]               0
           Conv2d-44          [-1, 512, 28, 28]         131,584
      BatchNorm2d-45          [-1, 512, 28, 28]           1,024
           Conv2d-46          [-1, 512, 28, 28]         131,584
      BatchNorm2d-47          [-1, 512, 28, 28]           1,024
             ReLU-48          [-1, 512, 28, 28]               0
           Resent-49          [-1, 512, 28, 28]               0
           Conv2d-50          [-1, 128, 28, 28]          65,664
      BatchNorm2d-51          [-1, 128, 28, 28]             256
             ReLU-52          [-1, 128, 28, 28]               0
           Conv2d-53           [-1, 32, 28, 28]          36,896
      BatchNorm2d-54           [-1, 32, 28, 28]              64
             ReLU-55           [-1, 32, 28, 28]               0
           Conv2d-56          [-1, 512, 28, 28]          16,896
      BatchNorm2d-57          [-1, 512, 28, 28]           1,024
             ReLU-58          [-1, 512, 28, 28]               0
           Resent-59          [-1, 512, 28, 28]               0
           Conv2d-60          [-1, 128, 28, 28]          65,664
      BatchNorm2d-61          [-1, 128, 28, 28]             256
             ReLU-62          [-1, 128, 28, 28]               0
           Conv2d-63           [-1, 32, 28, 28]          36,896
      BatchNorm2d-64           [-1, 32, 28, 28]              64
             ReLU-65           [-1, 32, 28, 28]               0
           Conv2d-66          [-1, 512, 28, 28]          16,896
      BatchNorm2d-67          [-1, 512, 28, 28]           1,024
             ReLU-68          [-1, 512, 28, 28]               0
           Resent-69          [-1, 512, 28, 28]               0
           Conv2d-70          [-1, 128, 28, 28]          65,664
      BatchNorm2d-71          [-1, 128, 28, 28]             256
             ReLU-72          [-1, 128, 28, 28]               0
           Conv2d-73           [-1, 32, 28, 28]          36,896
      BatchNorm2d-74           [-1, 32, 28, 28]              64
             ReLU-75           [-1, 32, 28, 28]               0
           Conv2d-76          [-1, 512, 28, 28]          16,896
      BatchNorm2d-77          [-1, 512, 28, 28]           1,024
             ReLU-78          [-1, 512, 28, 28]               0
           Resent-79          [-1, 512, 28, 28]               0
           Conv2d-80          [-1, 512, 14, 14]         262,656
      BatchNorm2d-81          [-1, 512, 14, 14]           1,024
             ReLU-82          [-1, 512, 14, 14]               0
           Conv2d-83          [-1, 512, 14, 14]       2,359,808
      BatchNorm2d-84          [-1, 512, 14, 14]           1,024
             ReLU-85          [-1, 512, 14, 14]               0
           Conv2d-86         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-87         [-1, 1024, 14, 14]           2,048
           Conv2d-88         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-89         [-1, 1024, 14, 14]           2,048
             ReLU-90         [-1, 1024, 14, 14]               0
           Resent-91         [-1, 1024, 14, 14]               0
           Conv2d-92          [-1, 256, 14, 14]         262,400
      BatchNorm2d-93          [-1, 256, 14, 14]             512
             ReLU-94          [-1, 256, 14, 14]               0
           Conv2d-95           [-1, 64, 14, 14]         147,520
      BatchNorm2d-96           [-1, 64, 14, 14]             128
             ReLU-97           [-1, 64, 14, 14]               0
           Conv2d-98         [-1, 1024, 14, 14]          66,560
      BatchNorm2d-99         [-1, 1024, 14, 14]           2,048
            ReLU-100         [-1, 1024, 14, 14]               0
          Resent-101         [-1, 1024, 14, 14]               0
          Conv2d-102          [-1, 256, 14, 14]         262,400
     BatchNorm2d-103          [-1, 256, 14, 14]             512
            ReLU-104          [-1, 256, 14, 14]               0
          Conv2d-105           [-1, 64, 14, 14]         147,520
     BatchNorm2d-106           [-1, 64, 14, 14]             128
            ReLU-107           [-1, 64, 14, 14]               0
          Conv2d-108         [-1, 1024, 14, 14]          66,560
     BatchNorm2d-109         [-1, 1024, 14, 14]           2,048
            ReLU-110         [-1, 1024, 14, 14]               0
          Resent-111         [-1, 1024, 14, 14]               0
          Conv2d-112          [-1, 256, 14, 14]         262,400
     BatchNorm2d-113          [-1, 256, 14, 14]             512
            ReLU-114          [-1, 256, 14, 14]               0
          Conv2d-115           [-1, 64, 14, 14]         147,520
     BatchNorm2d-116           [-1, 64, 14, 14]             128
            ReLU-117           [-1, 64, 14, 14]               0
          Conv2d-118         [-1, 1024, 14, 14]          66,560
     BatchNorm2d-119         [-1, 1024, 14, 14]           2,048
            ReLU-120         [-1, 1024, 14, 14]               0
          Resent-121         [-1, 1024, 14, 14]               0
          Conv2d-122          [-1, 256, 14, 14]         262,400
     BatchNorm2d-123          [-1, 256, 14, 14]             512
            ReLU-124          [-1, 256, 14, 14]               0
          Conv2d-125           [-1, 64, 14, 14]         147,520
     BatchNorm2d-126           [-1, 64, 14, 14]             128
            ReLU-127           [-1, 64, 14, 14]               0
          Conv2d-128         [-1, 1024, 14, 14]          66,560
     BatchNorm2d-129         [-1, 1024, 14, 14]           2,048
            ReLU-130         [-1, 1024, 14, 14]               0
          Resent-131         [-1, 1024, 14, 14]               0
          Conv2d-132          [-1, 256, 14, 14]         262,400
     BatchNorm2d-133          [-1, 256, 14, 14]             512
            ReLU-134          [-1, 256, 14, 14]               0
          Conv2d-135           [-1, 64, 14, 14]         147,520
     BatchNorm2d-136           [-1, 64, 14, 14]             128
            ReLU-137           [-1, 64, 14, 14]               0
          Conv2d-138         [-1, 1024, 14, 14]          66,560
     BatchNorm2d-139         [-1, 1024, 14, 14]           2,048
            ReLU-140         [-1, 1024, 14, 14]               0
          Resent-141         [-1, 1024, 14, 14]               0
          Conv2d-142           [-1, 1024, 7, 7]       1,049,600
     BatchNorm2d-143           [-1, 1024, 7, 7]           2,048
            ReLU-144           [-1, 1024, 7, 7]               0
          Conv2d-145           [-1, 1024, 7, 7]       9,438,208
     BatchNorm2d-146           [-1, 1024, 7, 7]           2,048
            ReLU-147           [-1, 1024, 7, 7]               0
          Conv2d-148           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-149           [-1, 2048, 7, 7]           4,096
          Conv2d-150           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-151           [-1, 2048, 7, 7]           4,096
            ReLU-152           [-1, 2048, 7, 7]               0
          Resent-153           [-1, 2048, 7, 7]               0
          Conv2d-154            [-1, 512, 7, 7]       1,049,088
     BatchNorm2d-155            [-1, 512, 7, 7]           1,024
            ReLU-156            [-1, 512, 7, 7]               0
          Conv2d-157            [-1, 128, 7, 7]         589,952
     BatchNorm2d-158            [-1, 128, 7, 7]             256
            ReLU-159            [-1, 128, 7, 7]               0
          Conv2d-160           [-1, 2048, 7, 7]         264,192
     BatchNorm2d-161           [-1, 2048, 7, 7]           4,096
            ReLU-162           [-1, 2048, 7, 7]               0
          Resent-163           [-1, 2048, 7, 7]               0
          Conv2d-164            [-1, 512, 7, 7]       1,049,088
     BatchNorm2d-165            [-1, 512, 7, 7]           1,024
            ReLU-166            [-1, 512, 7, 7]               0
          Conv2d-167            [-1, 128, 7, 7]         589,952
     BatchNorm2d-168            [-1, 128, 7, 7]             256
            ReLU-169            [-1, 128, 7, 7]               0
          Conv2d-170           [-1, 2048, 7, 7]         264,192
     BatchNorm2d-171           [-1, 2048, 7, 7]           4,096
            ReLU-172           [-1, 2048, 7, 7]               0
          Resent-173           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-174           [-1, 2048, 1, 1]               0
         Flatten-175                 [-1, 2048]               0
          Linear-176                    [-1, 2]           4,098
    output_layer-177                    [-1, 2]               0
================================================================
Total params: 26,026,050
Trainable params: 26,026,050
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 267.18
Params size (MB): 99.28
Estimated Total Size (MB): 367.04
----------------------------------------------------------------