Building a ResNet Model by Hand (PyTorch)


1. ResNet Model Structure

  • The family comes in five depths: ResNet-18, 34, 50, 101, and 152.

[Figure: the per-network layer-configuration table for ResNet-18/34/50/101/152]

  • ResNet-18 and ResNet-34 use the two-layer residual block shown on the left below (the BasicBlock); ResNet-50, 101, and 152 use the three-layer block on the right (the Bottleneck). A minimal BasicBlock sketch follows the figure; the code in the next section implements the Bottleneck variants.

[Figure: the two residual block designs — the two-layer BasicBlock (left) and the three-layer Bottleneck (right)]
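
For contrast with the Bottleneck implemented in the next section, here is a minimal sketch of the two-layer BasicBlock used by ResNet-18/34. This sketch is my own illustration and is not part of the code below (the class and argument names mirror common PyTorch conventions); a projection shortcut is added only when the stride or channel count changes:

import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1  # BasicBlock keeps the channel count; Bottleneck expands it 4x

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Identity shortcut by default; 1x1 projection when the shape changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)  # residual connection
        return self.relu(out)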

2. Code Example

import torch
import torch.nn as nn

__all__ = ['ResNet50','ResNet101','ResNet152']

def Conv1(in_planes, out_planes, stride=2):
    # Stem: 7x7 conv (stride 2) -> BN -> ReLU -> 3x3 max-pool (stride 2).
    # Reduces the spatial resolution by 4x and maps 3 input channels to 64.
    return nn.Sequential(
        nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=7, stride=stride, padding=3, bias=False),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )
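
# With a 224x224 input the stem halves the resolution twice: 224 -> 112 (conv) -> 56 (pool).
# With the 244x244 input used in the __main__ block below, the same arithmetic gives
# 244 -> 122 -> 61, matching the Conv2d-1 / MaxPool2d-4 rows in the summary output.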

class Bottleneck(nn.Module):
    def __init__(self, in_planes, out_planes, stride=1, downsampling=False, expansion=4):
        super(Bottleneck, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling

        # 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand to out_planes * expansion.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_planes, out_channels=out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_planes, out_channels=out_planes * self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_planes * self.expansion)
        )

        # Projection shortcut: matches the residual's shape when stride or width changes.
        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_planes, out_channels=out_planes * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes * self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)

        if self.downsampling:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, blocks, num_classes=1000, expansion=4):
        super(ResNet, self).__init__()
        self.expansion = expansion

        self.conv1 = Conv1(in_planes=3, out_planes=64)

        # Stage output widths follow the paper: 256 -> 512 -> 1024 -> 2048 channels.
        self.layer1 = self.make_layer(in_planes=64, out_planes=64, block=blocks[0], stride=1)
        self.layer2 = self.make_layer(in_planes=256, out_planes=128, block=blocks[1], stride=2)
        self.layer3 = self.make_layer(in_planes=512, out_planes=256, block=blocks[2], stride=2)
        self.layer4 = self.make_layer(in_planes=1024, out_planes=512, block=blocks[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * self.expansion, num_classes)

        # Initialization: Kaiming normal for convolutions, constants for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, in_planes, out_planes, block, stride):
        # The first block of each stage projects the shortcut (and applies the stride);
        # the remaining blocks keep the shape, taking out_planes * expansion channels in.
        layers = []
        layers.append(Bottleneck(in_planes, out_planes, stride, downsampling=True))
        for i in range(1, block):
            layers.append(Bottleneck(out_planes * self.expansion, out_planes))

        return nn.Sequential(*layers)

    def forward(self,x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten (batch, 2048, 1, 1) -> (batch, 2048)
        x = self.fc(x)
        return x


def ResNet50():
    # [3, 4, 6, 3] bottleneck blocks in stages conv2_x..conv5_x
    return ResNet([3, 4, 6, 3])

def ResNet101():
    return ResNet([3, 4, 23, 3])

def ResNet152():
    return ResNet([3, 8, 36, 3])

if __name__ == '__main__':
    from torchsummary import summary
    model = ResNet50()
    # print(model)

    # input = torch.randn(1, 3, 244, 244)
    # out = model(input)
    # print(out.shape)

    # Note: 224x224 is the standard ImageNet input size; the 244x244 used here
    # matches the summary printed below.
    summary(model, (3, 244, 244))
  • Output
D:\Anaconda3\python.exe C:/Users/夏戈/Desktop/DeepNet/classification/resnet.py
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 122, 122]           9,408
       BatchNorm2d-2         [-1, 64, 122, 122]             128
              ReLU-3         [-1, 64, 122, 122]               0
         MaxPool2d-4           [-1, 64, 61, 61]               0
            Conv2d-5           [-1, 64, 61, 61]           4,096
       BatchNorm2d-6           [-1, 64, 61, 61]             128
              ReLU-7           [-1, 64, 61, 61]               0
            Conv2d-8           [-1, 64, 61, 61]          36,864
       BatchNorm2d-9           [-1, 64, 61, 61]             128
             ReLU-10           [-1, 64, 61, 61]               0
           Conv2d-11          [-1, 256, 61, 61]          16,384
      BatchNorm2d-12          [-1, 256, 61, 61]             512
           Conv2d-13          [-1, 256, 61, 61]          16,384
      BatchNorm2d-14          [-1, 256, 61, 61]             512
             ReLU-15          [-1, 256, 61, 61]               0
       Bottleneck-16          [-1, 256, 61, 61]               0
           Conv2d-17           [-1, 64, 61, 61]          16,384
      BatchNorm2d-18           [-1, 64, 61, 61]             128
             ReLU-19           [-1, 64, 61, 61]               0
           Conv2d-20           [-1, 64, 61, 61]          36,864
      BatchNorm2d-21           [-1, 64, 61, 61]             128
             ReLU-22           [-1, 64, 61, 61]               0
           Conv2d-23          [-1, 256, 61, 61]          16,384
      BatchNorm2d-24          [-1, 256, 61, 61]             512
             ReLU-25          [-1, 256, 61, 61]               0
       Bottleneck-26          [-1, 256, 61, 61]               0
           Conv2d-27           [-1, 64, 61, 61]          16,384
      BatchNorm2d-28           [-1, 64, 61, 61]             128
             ReLU-29           [-1, 64, 61, 61]               0
           Conv2d-30           [-1, 64, 61, 61]          36,864
      BatchNorm2d-31           [-1, 64, 61, 61]             128
             ReLU-32           [-1, 64, 61, 61]               0
           Conv2d-33          [-1, 256, 61, 61]          16,384
      BatchNorm2d-34          [-1, 256, 61, 61]             512
             ReLU-35          [-1, 256, 61, 61]               0
       Bottleneck-36          [-1, 256, 61, 61]               0
           Conv2d-37          [-1, 128, 61, 61]          32,768
      BatchNorm2d-38          [-1, 128, 61, 61]             256
             ReLU-39          [-1, 128, 61, 61]               0
           Conv2d-40          [-1, 128, 31, 31]         147,456
      BatchNorm2d-41          [-1, 128, 31, 31]             256
             ReLU-42          [-1, 128, 31, 31]               0
           Conv2d-43          [-1, 512, 31, 31]          65,536
      BatchNorm2d-44          [-1, 512, 31, 31]           1,024
           Conv2d-45          [-1, 512, 31, 31]         131,072
      BatchNorm2d-46          [-1, 512, 31, 31]           1,024
             ReLU-47          [-1, 512, 31, 31]               0
       Bottleneck-48          [-1, 512, 31, 31]               0
           Conv2d-49          [-1, 128, 31, 31]          65,536
      BatchNorm2d-50          [-1, 128, 31, 31]             256
             ReLU-51          [-1, 128, 31, 31]               0
           Conv2d-52          [-1, 128, 31, 31]         147,456
      BatchNorm2d-53          [-1, 128, 31, 31]             256
             ReLU-54          [-1, 128, 31, 31]               0
           Conv2d-55          [-1, 512, 31, 31]          65,536
      BatchNorm2d-56          [-1, 512, 31, 31]           1,024
             ReLU-57          [-1, 512, 31, 31]               0
       Bottleneck-58          [-1, 512, 31, 31]               0
           Conv2d-59          [-1, 128, 31, 31]          65,536
      BatchNorm2d-60          [-1, 128, 31, 31]             256
             ReLU-61          [-1, 128, 31, 31]               0
           Conv2d-62          [-1, 128, 31, 31]         147,456
      BatchNorm2d-63          [-1, 128, 31, 31]             256
             ReLU-64          [-1, 128, 31, 31]               0
           Conv2d-65          [-1, 512, 31, 31]          65,536
      BatchNorm2d-66          [-1, 512, 31, 31]           1,024
             ReLU-67          [-1, 512, 31, 31]               0
       Bottleneck-68          [-1, 512, 31, 31]               0
           Conv2d-69          [-1, 128, 31, 31]          65,536
      BatchNorm2d-70          [-1, 128, 31, 31]             256
             ReLU-71          [-1, 128, 31, 31]               0
           Conv2d-72          [-1, 128, 31, 31]         147,456
      BatchNorm2d-73          [-1, 128, 31, 31]             256
             ReLU-74          [-1, 128, 31, 31]               0
           Conv2d-75          [-1, 512, 31, 31]          65,536
      BatchNorm2d-76          [-1, 512, 31, 31]           1,024
             ReLU-77          [-1, 512, 31, 31]               0
       Bottleneck-78          [-1, 512, 31, 31]               0
           Conv2d-79          [-1, 256, 31, 31]         131,072
      BatchNorm2d-80          [-1, 256, 31, 31]             512
             ReLU-81          [-1, 256, 31, 31]               0
           Conv2d-82          [-1, 256, 16, 16]         589,824
      BatchNorm2d-83          [-1, 256, 16, 16]             512
             ReLU-84          [-1, 256, 16, 16]               0
           Conv2d-85         [-1, 1024, 16, 16]         262,144
      BatchNorm2d-86         [-1, 1024, 16, 16]           2,048
           Conv2d-87         [-1, 1024, 16, 16]         524,288
      BatchNorm2d-88         [-1, 1024, 16, 16]           2,048
             ReLU-89         [-1, 1024, 16, 16]               0
       Bottleneck-90         [-1, 1024, 16, 16]               0
           Conv2d-91          [-1, 256, 16, 16]         262,144
      BatchNorm2d-92          [-1, 256, 16, 16]             512
             ReLU-93          [-1, 256, 16, 16]               0
           Conv2d-94          [-1, 256, 16, 16]         589,824
      BatchNorm2d-95          [-1, 256, 16, 16]             512
             ReLU-96          [-1, 256, 16, 16]               0
           Conv2d-97         [-1, 1024, 16, 16]         262,144
      BatchNorm2d-98         [-1, 1024, 16, 16]           2,048
             ReLU-99         [-1, 1024, 16, 16]               0
      Bottleneck-100         [-1, 1024, 16, 16]               0
          Conv2d-101          [-1, 256, 16, 16]         262,144
     BatchNorm2d-102          [-1, 256, 16, 16]             512
            ReLU-103          [-1, 256, 16, 16]               0
          Conv2d-104          [-1, 256, 16, 16]         589,824
     BatchNorm2d-105          [-1, 256, 16, 16]             512
            ReLU-106          [-1, 256, 16, 16]               0
          Conv2d-107         [-1, 1024, 16, 16]         262,144
     BatchNorm2d-108         [-1, 1024, 16, 16]           2,048
            ReLU-109         [-1, 1024, 16, 16]               0
      Bottleneck-110         [-1, 1024, 16, 16]               0
          Conv2d-111          [-1, 256, 16, 16]         262,144
     BatchNorm2d-112          [-1, 256, 16, 16]             512
            ReLU-113          [-1, 256, 16, 16]               0
          Conv2d-114          [-1, 256, 16, 16]         589,824
     BatchNorm2d-115          [-1, 256, 16, 16]             512
            ReLU-116          [-1, 256, 16, 16]               0
          Conv2d-117         [-1, 1024, 16, 16]         262,144
     BatchNorm2d-118         [-1, 1024, 16, 16]           2,048
            ReLU-119         [-1, 1024, 16, 16]               0
      Bottleneck-120         [-1, 1024, 16, 16]               0
          Conv2d-121          [-1, 256, 16, 16]         262,144
     BatchNorm2d-122          [-1, 256, 16, 16]             512
            ReLU-123          [-1, 256, 16, 16]               0
          Conv2d-124          [-1, 256, 16, 16]         589,824
     BatchNorm2d-125          [-1, 256, 16, 16]             512
            ReLU-126          [-1, 256, 16, 16]               0
          Conv2d-127         [-1, 1024, 16, 16]         262,144
     BatchNorm2d-128         [-1, 1024, 16, 16]           2,048
            ReLU-129         [-1, 1024, 16, 16]               0
      Bottleneck-130         [-1, 1024, 16, 16]               0
          Conv2d-131          [-1, 256, 16, 16]         262,144
     BatchNorm2d-132          [-1, 256, 16, 16]             512
            ReLU-133          [-1, 256, 16, 16]               0
          Conv2d-134          [-1, 256, 16, 16]         589,824
     BatchNorm2d-135          [-1, 256, 16, 16]             512
            ReLU-136          [-1, 256, 16, 16]               0
          Conv2d-137         [-1, 1024, 16, 16]         262,144
     BatchNorm2d-138         [-1, 1024, 16, 16]           2,048
            ReLU-139         [-1, 1024, 16, 16]               0
      Bottleneck-140         [-1, 1024, 16, 16]               0
          Conv2d-141          [-1, 512, 16, 16]         524,288
     BatchNorm2d-142          [-1, 512, 16, 16]           1,024
            ReLU-143          [-1, 512, 16, 16]               0
          Conv2d-144            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-145            [-1, 512, 8, 8]           1,024
            ReLU-146            [-1, 512, 8, 8]               0
          Conv2d-147           [-1, 2048, 8, 8]       1,048,576
     BatchNorm2d-148           [-1, 2048, 8, 8]           4,096
          Conv2d-149           [-1, 2048, 8, 8]       2,097,152
     BatchNorm2d-150           [-1, 2048, 8, 8]           4,096
            ReLU-151           [-1, 2048, 8, 8]               0
      Bottleneck-152           [-1, 2048, 8, 8]               0
          Conv2d-153            [-1, 512, 8, 8]       1,048,576
     BatchNorm2d-154            [-1, 512, 8, 8]           1,024
            ReLU-155            [-1, 512, 8, 8]               0
          Conv2d-156            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-157            [-1, 512, 8, 8]           1,024
            ReLU-158            [-1, 512, 8, 8]               0
          Conv2d-159           [-1, 2048, 8, 8]       1,048,576
     BatchNorm2d-160           [-1, 2048, 8, 8]           4,096
            ReLU-161           [-1, 2048, 8, 8]               0
      Bottleneck-162           [-1, 2048, 8, 8]               0
          Conv2d-163            [-1, 512, 8, 8]       1,048,576
     BatchNorm2d-164            [-1, 512, 8, 8]           1,024
            ReLU-165            [-1, 512, 8, 8]               0
          Conv2d-166            [-1, 512, 8, 8]       2,359,296
     BatchNorm2d-167            [-1, 512, 8, 8]           1,024
            ReLU-168            [-1, 512, 8, 8]               0
          Conv2d-169           [-1, 2048, 8, 8]       1,048,576
     BatchNorm2d-170           [-1, 2048, 8, 8]           4,096
            ReLU-171           [-1, 2048, 8, 8]               0
      Bottleneck-172           [-1, 2048, 8, 8]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                 [-1, 1000]       2,049,000
================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.68
Forward/backward pass size (MB): 351.15
Params size (MB): 97.49
Estimated Total Size (MB): 449.33
----------------------------------------------------------------

Process finished with exit code 0
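
The total of 25,557,032 parameters matches torchvision's stock resnet50, which makes for a quick sanity check on the hand-built model; individual rows can be verified by hand as well, e.g. the stem convolution has 7 x 7 x 3 x 64 = 9,408 weights (no bias term), matching Conv2d-1 above. A small verification sketch (my own addition; it assumes torchvision is installed and that the ResNet50 defined above is in scope):

import torch
import torchvision

def count_params(m):
    # Sum the element counts over all parameter tensors.
    return sum(p.numel() for p in m.parameters())

model = ResNet50()                          # the hand-built model above
reference = torchvision.models.resnet50()   # torchvision's implementation
print(count_params(model))       # 25557032
print(count_params(reference))   # 25557032

# A forward pass also confirms the classifier output shape:
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])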
