一、ResNet 模型结构
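- ResNet18、ResNet34 使用基本残差块（BasicBlock：两个 3×3 卷积，对应结构图左下方的结构）；ResNet50、ResNet101、ResNet152 使用瓶颈残差块（Bottleneck：1×1 降维 → 3×3 → 1×1 升维，输出通道数为中间通道数的 4 倍，对应结构图右下方的结构）。下面的代码只实现了瓶颈残差块这一种结构。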
二、代码示例
import torchvision
import torch
import torch.nn as nn
# Names exported by `from <module> import *`: the three model factory functions.
__all__ = [
    "ResNet50",
    "ResNet101",
    "ResNet152",
]
def Conv1(in_planes, out_planes, stride=2):
    """Build the ResNet stem: 7x7 conv -> BatchNorm -> ReLU -> 3x3 max-pool.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels of the 7x7 convolution.
        stride: stride of the 7x7 convolution (the max-pool always uses stride 2).

    Returns:
        An ``nn.Sequential`` holding the four stem layers in that order.
    """
    stem_conv = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=7,
        stride=stride,
        padding=3,
        bias=False,  # no bias: the BatchNorm right after it has its own shift
    )
    stages = [
        stem_conv,
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]
    return nn.Sequential(*stages)
class Bottleneck(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1) used by ResNet50/101/152.

    The block narrows the input to ``out_places`` channels with a 1x1 conv,
    applies a 3x3 conv that carries the spatial ``stride``, then widens to
    ``out_places * expansion`` channels with a final 1x1 conv. The result is
    added to the shortcut branch and passed through a ReLU.

    Args:
        in_places: number of input channels.
        out_places: width of the inner 1x1 / 3x3 convolutions.
        stride: stride of the 3x3 conv and of the projection shortcut.
        downsampling: if True, the shortcut is a 1x1 conv + BatchNorm
            projection; if False the input is added unchanged, which only
            works when ``in_places == out_places * expansion`` and ``stride == 1``.
        expansion: channel multiplier applied by the last 1x1 conv.
    """

    def __init__(self, in_places, out_places, stride=1, downsampling=False, expansion=4):
        super().__init__()
        self.expansion = expansion
        self.downsampling = downsampling
        wide = out_places * expansion
        self.bottleneck = nn.Sequential(
            # 1x1: reduce channels
            nn.Conv2d(in_places, out_places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(out_places),
            nn.ReLU(inplace=True),
            # 3x3: spatial mixing; this conv applies the block's stride
            nn.Conv2d(out_places, out_places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_places),
            nn.ReLU(inplace=True),
            # 1x1: expand channels (no ReLU here; it comes after the residual add)
            nn.Conv2d(out_places, wide, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(wide),
        )
        if downsampling:
            # Projection shortcut so the residual matches the main branch's shape.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_places, wide, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(wide),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        out = self.bottleneck(x)
        shortcut = self.downsample(x) if self.downsampling else x
        out += shortcut
        return self.relu(out)
class ResNet(nn.Module):
    """Bottleneck-based ResNet (the ResNet50/101/152 family).

    Args:
        blocks: number of Bottleneck blocks in each of the four stages,
            e.g. ``[3, 4, 6, 3]`` for ResNet50. Must have at least 4 entries.
        num_classes: output size of the final linear classifier.
        expansion: channel multiplier of every Bottleneck's last 1x1 conv.
            The inner stage widths are 64/128/256/512, so the stage outputs
            are those widths times ``expansion`` (256/512/1024/2048 for the
            default of 4).
    """

    def __init__(self, blocks, num_classes=1000, expansion=4):
        super(ResNet, self).__init__()
        self.expansion = expansion
        # Stem: 3 input channels -> 64 channels.
        self.conv1 = Conv1(in_planes=3, out_planes=64)
        # Each later stage consumes the previous stage's output width
        # (inner width * expansion). Derive it from `expansion` instead of
        # hard-coding 256/512/1024/2048, so non-default expansions line up.
        self.layer1 = self.make_layer(in_places=64, out_places=64, block=blocks[0], stride=1)
        self.layer2 = self.make_layer(in_places=64 * expansion, out_places=128, block=blocks[1], stride=2)
        self.layer3 = self.make_layer(in_places=128 * expansion, out_places=256, block=blocks[2], stride=2)
        self.layer4 = self.make_layer(in_places=256 * expansion, out_places=512, block=blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * expansion, num_classes)

        # Kaiming init for convs; BatchNorm starts as the identity (scale 1, shift 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def make_layer(self, in_places, out_places, block, stride):
        """Build one stage of ``block`` Bottleneck blocks.

        The first block applies ``stride`` and always uses a projection
        shortcut so the residual matches its stride and output width
        (``out_places * expansion``). The remaining blocks keep that shape and
        use identity shortcuts. A ``block`` value below 1 still yields one block.
        """
        layers = [
            Bottleneck(in_places, out_places, stride, downsampling=True, expansion=self.expansion)
        ]
        for _ in range(1, block):
            layers.append(
                Bottleneck(out_places * self.expansion, out_places, expansion=self.expansion)
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``x`` (3 input channels) to logits of shape (N, num_classes)."""
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def ResNet50(num_classes=1000):
    """Build a ResNet-50: four Bottleneck stages of 3, 4, 6 and 3 blocks.

    Args:
        num_classes: output size of the final classifier (default 1000,
            the value previously fixed by this factory).
    """
    return ResNet([3, 4, 6, 3], num_classes=num_classes)
def ResNet101(num_classes=1000):
    """Build a ResNet-101: four Bottleneck stages of 3, 4, 23 and 3 blocks.

    Args:
        num_classes: output size of the final classifier (default 1000,
            the value previously fixed by this factory).
    """
    return ResNet([3, 4, 23, 3], num_classes=num_classes)
def ResNet152(num_classes=1000):
    """Build a ResNet-152: four Bottleneck stages of 3, 8, 36 and 3 blocks.

    Args:
        num_classes: output size of the final classifier (default 1000,
            the value previously fixed by this factory).
    """
    return ResNet([3, 8, 36, 3], num_classes=num_classes)
if __name__ == '__main__':
    # torchsummary is a third-party package used only by this demo.
    from torchsummary import summary

    model = ResNet50()
    # torchsummary's summary() defaults to device="cuda": on a machine with CUDA
    # it would feed CUDA tensors to this CPU-resident model and fail with a
    # device/type mismatch. Run the summary on the CPU, where the model lives.
    # NOTE(review): (3, 244, 244) is kept to match the recorded output; the usual
    # ImageNet input is (3, 224, 224), so 244 may be a typo.
    summary(model, (3, 244, 244), device="cpu")
D:\Anaconda3\python.exe C:/Users/夏戈/Desktop/DeepNet/classification/resnet.py
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 122, 122] 9,408
BatchNorm2d-2 [-1, 64, 122, 122] 128
ReLU-3 [-1, 64, 122, 122] 0
MaxPool2d-4 [-1, 64, 61, 61] 0
Conv2d-5 [-1, 64, 61, 61] 4,096
BatchNorm2d-6 [-1, 64, 61, 61] 128
ReLU-7 [-1, 64, 61, 61] 0
Conv2d-8 [-1, 64, 61, 61] 36,864
BatchNorm2d-9 [-1, 64, 61, 61] 128
ReLU-10 [-1, 64, 61, 61] 0
Conv2d-11 [-1, 256, 61, 61] 16,384
BatchNorm2d-12 [-1, 256, 61, 61] 512
Conv2d-13 [-1, 256, 61, 61] 16,384
BatchNorm2d-14 [-1, 256, 61, 61] 512
ReLU-15 [-1, 256, 61, 61] 0
Bottleneck-16 [-1, 256, 61, 61] 0
Conv2d-17 [-1, 64, 61, 61] 16,384
BatchNorm2d-18 [-1, 64, 61, 61] 128
ReLU-19 [-1, 64, 61, 61] 0
Conv2d-20 [-1, 64, 61, 61] 36,864
BatchNorm2d-21 [-1, 64, 61, 61] 128
ReLU-22 [-1, 64, 61, 61] 0
Conv2d-23 [-1, 256, 61, 61] 16,384
BatchNorm2d-24 [-1, 256, 61, 61] 512
ReLU-25 [-1, 256, 61, 61] 0
Bottleneck-26 [-1, 256, 61, 61] 0
Conv2d-27 [-1, 64, 61, 61] 16,384
BatchNorm2d-28 [-1, 64, 61, 61] 128
ReLU-29 [-1, 64, 61, 61] 0
Conv2d-30 [-1, 64, 61, 61] 36,864
BatchNorm2d-31 [-1, 64, 61, 61] 128
ReLU-32 [-1, 64, 61, 61] 0
Conv2d-33 [-1, 256, 61, 61] 16,384
BatchNorm2d-34 [-1, 256, 61, 61] 512
ReLU-35 [-1, 256, 61, 61] 0
Bottleneck-36 [-1, 256, 61, 61] 0
Conv2d-37 [-1, 128, 61, 61] 32,768
BatchNorm2d-38 [-1, 128, 61, 61] 256
ReLU-39 [-1, 128, 61, 61] 0
Conv2d-40 [-1, 128, 31, 31] 147,456
BatchNorm2d-41 [-1, 128, 31, 31] 256
ReLU-42 [-1, 128, 31, 31] 0
Conv2d-43 [-1, 512, 31, 31] 65,536
BatchNorm2d-44 [-1, 512, 31, 31] 1,024
Conv2d-45 [-1, 512, 31, 31] 131,072
BatchNorm2d-46 [-1, 512, 31, 31] 1,024
ReLU-47 [-1, 512, 31, 31] 0
Bottleneck-48 [-1, 512, 31, 31] 0
Conv2d-49 [-1, 128, 31, 31] 65,536
BatchNorm2d-50 [-1, 128, 31, 31] 256
ReLU-51 [-1, 128, 31, 31] 0
Conv2d-52 [-1, 128, 31, 31] 147,456
BatchNorm2d-53 [-1, 128, 31, 31] 256
ReLU-54 [-1, 128, 31, 31] 0
Conv2d-55 [-1, 512, 31, 31] 65,536
BatchNorm2d-56 [-1, 512, 31, 31] 1,024
ReLU-57 [-1, 512, 31, 31] 0
Bottleneck-58 [-1, 512, 31, 31] 0
Conv2d-59 [-1, 128, 31, 31] 65,536
BatchNorm2d-60 [-1, 128, 31, 31] 256
ReLU-61 [-1, 128, 31, 31] 0
Conv2d-62 [-1, 128, 31, 31] 147,456
BatchNorm2d-63 [-1, 128, 31, 31] 256
ReLU-64 [-1, 128, 31, 31] 0
Conv2d-65 [-1, 512, 31, 31] 65,536
BatchNorm2d-66 [-1, 512, 31, 31] 1,024
ReLU-67 [-1, 512, 31, 31] 0
Bottleneck-68 [-1, 512, 31, 31] 0
Conv2d-69 [-1, 128, 31, 31] 65,536
BatchNorm2d-70 [-1, 128, 31, 31] 256
ReLU-71 [-1, 128, 31, 31] 0
Conv2d-72 [-1, 128, 31, 31] 147,456
BatchNorm2d-73 [-1, 128, 31, 31] 256
ReLU-74 [-1, 128, 31, 31] 0
Conv2d-75 [-1, 512, 31, 31] 65,536
BatchNorm2d-76 [-1, 512, 31, 31] 1,024
ReLU-77 [-1, 512, 31, 31] 0
Bottleneck-78 [-1, 512, 31, 31] 0
Conv2d-79 [-1, 256, 31, 31] 131,072
BatchNorm2d-80 [-1, 256, 31, 31] 512
ReLU-81 [-1, 256, 31, 31] 0
Conv2d-82 [-1, 256, 16, 16] 589,824
BatchNorm2d-83 [-1, 256, 16, 16] 512
ReLU-84 [-1, 256, 16, 16] 0
Conv2d-85 [-1, 1024, 16, 16] 262,144
BatchNorm2d-86 [-1, 1024, 16, 16] 2,048
Conv2d-87 [-1, 1024, 16, 16] 524,288
BatchNorm2d-88 [-1, 1024, 16, 16] 2,048
ReLU-89 [-1, 1024, 16, 16] 0
Bottleneck-90 [-1, 1024, 16, 16] 0
Conv2d-91 [-1, 256, 16, 16] 262,144
BatchNorm2d-92 [-1, 256, 16, 16] 512
ReLU-93 [-1, 256, 16, 16] 0
Conv2d-94 [-1, 256, 16, 16] 589,824
BatchNorm2d-95 [-1, 256, 16, 16] 512
ReLU-96 [-1, 256, 16, 16] 0
Conv2d-97 [-1, 1024, 16, 16] 262,144
BatchNorm2d-98 [-1, 1024, 16, 16] 2,048
ReLU-99 [-1, 1024, 16, 16] 0
Bottleneck-100 [-1, 1024, 16, 16] 0
Conv2d-101 [-1, 256, 16, 16] 262,144
BatchNorm2d-102 [-1, 256, 16, 16] 512
ReLU-103 [-1, 256, 16, 16] 0
Conv2d-104 [-1, 256, 16, 16] 589,824
BatchNorm2d-105 [-1, 256, 16, 16] 512
ReLU-106 [-1, 256, 16, 16] 0
Conv2d-107 [-1, 1024, 16, 16] 262,144
BatchNorm2d-108 [-1, 1024, 16, 16] 2,048
ReLU-109 [-1, 1024, 16, 16] 0
Bottleneck-110 [-1, 1024, 16, 16] 0
Conv2d-111 [-1, 256, 16, 16] 262,144
BatchNorm2d-112 [-1, 256, 16, 16] 512
ReLU-113 [-1, 256, 16, 16] 0
Conv2d-114 [-1, 256, 16, 16] 589,824
BatchNorm2d-115 [-1, 256, 16, 16] 512
ReLU-116 [-1, 256, 16, 16] 0
Conv2d-117 [-1, 1024, 16, 16] 262,144
BatchNorm2d-118 [-1, 1024, 16, 16] 2,048
ReLU-119 [-1, 1024, 16, 16] 0
Bottleneck-120 [-1, 1024, 16, 16] 0
Conv2d-121 [-1, 256, 16, 16] 262,144
BatchNorm2d-122 [-1, 256, 16, 16] 512
ReLU-123 [-1, 256, 16, 16] 0
Conv2d-124 [-1, 256, 16, 16] 589,824
BatchNorm2d-125 [-1, 256, 16, 16] 512
ReLU-126 [-1, 256, 16, 16] 0
Conv2d-127 [-1, 1024, 16, 16] 262,144
BatchNorm2d-128 [-1, 1024, 16, 16] 2,048
ReLU-129 [-1, 1024, 16, 16] 0
Bottleneck-130 [-1, 1024, 16, 16] 0
Conv2d-131 [-1, 256, 16, 16] 262,144
BatchNorm2d-132 [-1, 256, 16, 16] 512
ReLU-133 [-1, 256, 16, 16] 0
Conv2d-134 [-1, 256, 16, 16] 589,824
BatchNorm2d-135 [-1, 256, 16, 16] 512
ReLU-136 [-1, 256, 16, 16] 0
Conv2d-137 [-1, 1024, 16, 16] 262,144
BatchNorm2d-138 [-1, 1024, 16, 16] 2,048
ReLU-139 [-1, 1024, 16, 16] 0
Bottleneck-140 [-1, 1024, 16, 16] 0
Conv2d-141 [-1, 512, 16, 16] 524,288
BatchNorm2d-142 [-1, 512, 16, 16] 1,024
ReLU-143 [-1, 512, 16, 16] 0
Conv2d-144 [-1, 512, 8, 8] 2,359,296
BatchNorm2d-145 [-1, 512, 8, 8] 1,024
ReLU-146 [-1, 512, 8, 8] 0
Conv2d-147 [-1, 2048, 8, 8] 1,048,576
BatchNorm2d-148 [-1, 2048, 8, 8] 4,096
Conv2d-149 [-1, 2048, 8, 8] 2,097,152
BatchNorm2d-150 [-1, 2048, 8, 8] 4,096
ReLU-151 [-1, 2048, 8, 8] 0
Bottleneck-152 [-1, 2048, 8, 8] 0
Conv2d-153 [-1, 512, 8, 8] 1,048,576
BatchNorm2d-154 [-1, 512, 8, 8] 1,024
ReLU-155 [-1, 512, 8, 8] 0
Conv2d-156 [-1, 512, 8, 8] 2,359,296
BatchNorm2d-157 [-1, 512, 8, 8] 1,024
ReLU-158 [-1, 512, 8, 8] 0
Conv2d-159 [-1, 2048, 8, 8] 1,048,576
BatchNorm2d-160 [-1, 2048, 8, 8] 4,096
ReLU-161 [-1, 2048, 8, 8] 0
Bottleneck-162 [-1, 2048, 8, 8] 0
Conv2d-163 [-1, 512, 8, 8] 1,048,576
BatchNorm2d-164 [-1, 512, 8, 8] 1,024
ReLU-165 [-1, 512, 8, 8] 0
Conv2d-166 [-1, 512, 8, 8] 2,359,296
BatchNorm2d-167 [-1, 512, 8, 8] 1,024
ReLU-168 [-1, 512, 8, 8] 0
Conv2d-169 [-1, 2048, 8, 8] 1,048,576
BatchNorm2d-170 [-1, 2048, 8, 8] 4,096
ReLU-171 [-1, 2048, 8, 8] 0
Bottleneck-172 [-1, 2048, 8, 8] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
Linear-174 [-1, 1000] 2,049,000
================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.68
Forward/backward pass size (MB): 351.15
Params size (MB): 97.49
Estimated Total Size (MB): 449.33
----------------------------------------------------------------
Process finished with exit code 0
三、参考链接