def __init__(self, in_planes, out_channels, stride=1, padding=0):
    """Build a 3x3 Conv -> BatchNorm -> ReLU stack.

    Args:
        in_planes: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride (default 1).
        padding: zero-padding on each side (default 0).
    """
    super().__init__()  # modern zero-arg form, equivalent to super(conv3x3, self).__init__()
    # NOTE(review): Conv2d keeps its default bias=True even though BatchNorm
    # follows (which makes the bias redundant); left unchanged so the
    # parameter layout / state_dict keys stay identical.
    self.conv3x3 = nn.Sequential(
        nn.Conv2d(in_planes, out_channels, kernel_size=3,
                  stride=stride, padding=padding),  # 3x3 convolution
        nn.BatchNorm2d(out_channels),  # batch norm to stabilize training
        nn.ReLU(),  # activation
    )
def forward(self, input):
    """Apply the Conv -> BN -> ReLU stack to *input* and return the result."""
    out = self.conv3x3(input)
    return out
class conv1x1(nn.Module):
    """1x1 Conv -> BatchNorm -> ReLU block (pointwise channel projection)."""

    def __init__(self, in_planes, out_channels, stride=1, padding=0):
        """Build the 1x1 conv stack.

        Args:
            in_planes: number of input channels.
            out_channels: number of output channels.
            stride: convolution stride (default 1).
            padding: zero-padding on each side (default 0).
        """
        # BUG FIX: the method was named ``init`` and called
        # ``super(conv1x1, self).init()``. Python only invokes ``__init__``
        # on construction, so the submodules were never created and any
        # forward pass failed with a missing-attribute error.
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_planes, out_channels, kernel_size=1,
                      stride=stride, padding=padding),  # 1x1 convolution
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, input):
        """Run the block on an (N, C, H, W) tensor; returns (N, out_channels, H', W')."""
        return self.conv1x1(input)
Stem module: input 299×299×3, output 35×35×256.
class StemV1(nn.Module):
    """Inception-ResNet v1 stem: maps (N, in_planes, 299, 299) to (N, 256, 35, 35).

    Spatial sizes follow from the kernel/stride/padding choices below:
    299 -> 149 -> 147 -> 147 -> 73 (maxpool) -> 73 -> 73 -> 71 -> 35.
    """

    def __init__(self, in_planes):
        """Args:
            in_planes: number of input channels (3 for RGB images).
        """
        # BUG FIX: the method was named ``init`` and called
        # ``super(StemV1, self).init()``. Python only invokes ``__init__``
        # on construction, so none of the layers below were ever created.
        super().__init__()
        self.conv1 = conv3x3(in_planes=in_planes, out_channels=32, stride=2, padding=0)
        self.conv2 = conv3x3(in_planes=32, out_channels=32, stride=1, padding=0)
        self.conv3 = conv3x3(in_planes=32, out_channels=64, stride=1, padding=1)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
        self.conv4 = conv3x3(in_planes=64, out_channels=64, stride=1, padding=1)
        self.conv5 = conv1x1(in_planes=64, out_channels=80, stride=1, padding=0)
        self.conv6 = conv3x3(in_planes=80, out_channels=192, stride=1, padding=0)
        self.conv7 = conv3x3(in_planes=192, out_channels=256, stride=2, padding=0)

    def forward(self, x):
        """Sequentially apply the stem layers to ``x``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        return x
Inception-ResNet-A (IR-A) module ×5: input 35×35×256, output 35×35×256.
class Inception_ResNet_A(nn.Module):
    """Inception-ResNet-A residual block: (N, 256, 35, 35) -> (N, 256, 35, 35).

    Three parallel branches (1x1; 1x1->3x3; 1x1->3x3->3x3), each producing 32
    channels, are concatenated to 96 channels, projected back to 256 channels
    by a linear 1x1 conv, added to the input (residual connection), and passed
    through ReLU.
    """

    def __init__(self, input):
        """Args:
            input: number of input channels (256 in this network).
        """
        # BUG FIX: the method was named ``init`` and called
        # ``super(Inception_ResNet_A, self).init()``. Python only invokes
        # ``__init__`` on construction, so the layers were never created.
        super().__init__()
        # NOTE(review): all three branches reuse the same ``conv1`` (and both
        # 3x3 steps reuse ``conv2``), i.e. branch weights are shared. The
        # reference architecture uses independent convolutions per branch;
        # kept as-is here to preserve this model's parameter layout — confirm
        # whether the sharing is intentional.
        self.conv1 = conv1x1(in_planes=input, out_channels=32, stride=1, padding=0)
        self.conv2 = conv3x3(in_planes=32, out_channels=32, stride=1, padding=1)
        # Linear (no BN/ReLU) 1x1 projection from 96 concatenated channels
        # back to 256 so the residual addition is shape-compatible.
        self.line = nn.Conv2d(96, 256, 1, stride=1, padding=0, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Compute the three branches, merge, project, and add the residual."""
        c1 = self.conv1(x)                     # branch 1: 1x1
        c2 = self.conv1(x)                     # branch 2 input: 1x1
        c3 = self.conv1(x)                     # branch 3 input: 1x1
        c2_1 = self.conv2(c2)                  # branch 2: 3x3
        c3_1 = self.conv2(c3)                  # branch 3: first 3x3
        c3_2 = self.conv2(c3_1)                # branch 3: second 3x3
        cat = torch.cat([c1, c2_1, c3_2], dim=1)  # 32*3 = 96 channels
        line = self.line(cat)                  # project back to 256 channels
        out = x + line                         # residual connection
        out = self.relu(out)
        return out
Reduction-