import torch
import torch.nn as nn
from torch.nn.functional import log_softmax
class ResBlock(nn.Module):
    """Residual unit: a two-convolution main path added to a shortcut path.

    The main path is Conv3x3(stride) -> BatchNorm -> ReLU -> Conv3x3 ->
    BatchNorm.  The shortcut is the identity when the block changes neither
    the stride nor the channel count; otherwise it is a
    Conv3x3(stride) -> BatchNorm -> ReLU projection so both paths produce
    tensors of the same shape.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by the block.
        stride: Stride of the first main-path convolution and of the
            projection shortcut.
    """

    def __init__(self, in_channel, out_channel, stride):
        super(ResBlock, self).__init__()

        main_path = [
            nn.Conv2d(in_channel, out_channel, (3, 3), stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel, (3, 3), (1, 1), padding=1),
            nn.BatchNorm2d(out_channel),
        ]
        self.layer = nn.Sequential(*main_path)

        # The attribute name `shutcut` is the registered submodule name (and
        # therefore part of every state_dict key), so its spelling is kept.
        keeps_shape = stride == 1 and in_channel == out_channel
        if keeps_shape:
            # Empty Sequential acts as the identity mapping.
            self.shutcut = nn.Sequential()
        else:
            self.shutcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, (3, 3), stride, padding=1),
                nn.BatchNorm2d(out_channel),
                nn.ReLU(),
            )

    def forward(self, x):
        """Return the element-wise sum of the main path and the shortcut."""
        return self.layer(x) + self.shutcut(x)
class ResNet(nn.Module):
    """Small ResNet-style image classifier built from a supplied block type.

    The network is a 3x3 stem convolution (32 channels), four stages of two
    residual blocks each (64, 128, 256 and 512 channels, every stage created
    with stride 2), global average pooling, and a linear classifier whose
    output is passed through ``log_softmax``.

    Args:
        resblock: Callable ``resblock(in_channel, out_channel, stride)``
            returning an ``nn.Module`` used as the residual unit.
        num_classes: Number of output classes (default 10).
        in_channels: Number of channels of the input images (default 3).
    """

    def make_layer(self, resblock, in_channel, out_channel, stride, num_block):
        """Build one stage made of ``num_block`` consecutive residual blocks.

        Args:
            resblock: Block factory, see the class docstring.
            in_channel: Channels entering the first block of the stage.
            out_channel: Channels produced by every block of the stage.
            stride: Stride given to the first block; later blocks use 1.
            num_block: Number of blocks in the stage.

        Returns:
            An ``nn.Sequential`` holding the blocks in order.
        """
        blocks = []
        for index in range(num_block):
            # Only the first block of a stage downsamples.
            block_stride = stride if index == 0 else 1
            blocks.append(resblock(in_channel, out_channel, block_stride))
            # Every block after the first consumes the stage's output width.
            in_channel = out_channel
        # Built from a list rather than nn.Sequential.append, which is not
        # available in older PyTorch releases; submodule keys are identical.
        return nn.Sequential(*blocks)

    def __init__(self, resblock, num_classes=10, in_channels=3):
        super(ResNet, self).__init__()
        self.inchannel = 32
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, self.inchannel, (3, 3), (1, 1), 1),
            nn.BatchNorm2d(self.inchannel),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(resblock, self.inchannel, 64, 2, 2)
        self.layer2 = self.make_layer(resblock, 64, 128, 2, 2)
        self.layer3 = self.make_layer(resblock, 128, 256, 2, 2)
        self.layer4 = self.make_layer(resblock, 256, 512, 2, 2)
        # Global average pooling always yields a 1x1 map, so the classifier
        # sees 512 features whatever the input resolution.  On a 2x2 feature
        # map this equals the former AvgPool2d(2, 1).
        self.avg = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, num_classes)."""
        out = self.init_conv(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avg(out)
        flat = pooled.view(pooled.size(0), -1)
        logits = self.linear(flat)
        return log_softmax(logits, dim=1)
def resnet():
    """Build and return a ResNet that uses ResBlock as its residual unit."""
    model = ResNet(ResBlock)
    return model
# Construct the model once whenever this module is executed; the returned
# instance is discarded.
# NOTE(review): this also runs on import -- consider moving it under an
# `if __name__ == "__main__":` guard if it is only meant as a smoke check.
resnet()