import torch
import torch.nn as nn
class BasicConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1), padding=(0, 0)):
super(BasicConv, self).__init__()
self.conv = nn.Conv2d(
in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size, stride=stride, padding=padding
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class StemModel(nn.Module):
def __init__(self):
super(StemModel, self).__init__()
self.conv_1 = BasicConv(in_channels=3, out_channels=32, kernel_size=3, stride=2)
self.conv_2 = BasicConv(in_channels=32, out_channels=32, kernel_size=3)
self.conv_3 = BasicConv(in_channels=32, out_channels=64, kernel_size=3, padding=1)
self.branch_1_1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.branch_1_2 = BasicConv(in_channels=64, out_channels=96, kernel_size=3, stride=2)
self.branch_2_1 = nn.Sequential(
BasicConv(in_channels=160, out_channels=64, kernel_size=1),
BasicConv(in_channels=64, out_channels=96, kernel_size=3)
)
self.branch_2_2 = nn.Sequential(
BasicConv(in_channels=160, out_channels=64, kernel_size=1),
            # The padding here was worked out by hand: without it the two branches end up with
            # different spatial sizes and cannot be concatenated along dim=1, so these two
            # layers pad to preserve the grid (see the sanity check after this class).
            BasicConv(in_channels=64, out_channels=64, kernel_size=(7, 1), padding=(3, 0)),  # 3 per side adds 6 = 7 - 1, so the size is unchanged
            BasicConv(in_channels=64, out_channels=64, kernel_size=(1, 7), padding=(0, 3)),
            # Figure 3 marks this 3x3 layer as reducing the spatial size, so no padding here.
            BasicConv(in_channels=64, out_channels=96, kernel_size=3)
)
self.branch_3_1 = BasicConv(in_channels=192, out_channels=192, kernel_size=3, stride=2)
self.branch_3_2 = nn.MaxPool2d(kernel_size=3, stride=2)
def forward(self, x):
x = self.conv_1(x)
x = self.conv_2(x)
x = self.conv_3(x)
x_1 = self.branch_1_1(x)
x_2 = self.branch_1_2(x)
x = torch.cat([x_1, x_2], dim=1)
x_1 = self.branch_2_1(x)
x_2 = self.branch_2_2(x)
x = torch.cat([x_1, x_2], dim=1)
x_1 = self.branch_3_1(x)
x_2 = self.branch_3_2(x)
x = torch.cat([x_1, x_2], dim=1)
return x
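# A minimal sanity check for the stem, assuming the 299x299 RGB input used below.
# Figure 3 of the Inception-v4 paper gives 35x35x384 at the stem output, which is
# what the hand-computed paddings in branch_2_2 are meant to guarantee.
def _check_stem():
    stem = StemModel()
    stem.eval()
    with torch.no_grad():
        out = stem(torch.randn(2, 3, 299, 299))
    assert out.shape == (2, 384, 35, 35), out.shape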
class InceptionA(nn.Module):
def __init__(self):
super(InceptionA, self).__init__()
self.branch_1 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=35),
BasicConv(in_channels=384, out_channels=96, kernel_size=1)
)
self.branch_2 = BasicConv(in_channels=384, out_channels=96, kernel_size=1)
self.branch_3 = nn.Sequential(
BasicConv(in_channels=384, out_channels=64, kernel_size=1),
BasicConv(in_channels=64, out_channels=96, kernel_size=3, padding=1)
)
self.branch_4 = nn.Sequential(
BasicConv(in_channels=384, out_channels=64, kernel_size=1),
BasicConv(in_channels=64, out_channels=96, kernel_size=3, padding=1),
BasicConv(in_channels=96, out_channels=96, kernel_size=3, padding=1)
)
def forward(self, x):
x_1 = self.branch_1(x)
x_2 = self.branch_2(x)
x_3 = self.branch_3(x)
x_4 = self.branch_4(x)
        x = torch.cat([x_1, x_2, x_3, x_4], dim=1)  # assign the concatenation so the branch outputs are actually returned
return x
class InceptionB(nn.Module):
def __init__(self):
super(InceptionB, self).__init__()
self.branch_1 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=17),
BasicConv(in_channels=1024, out_channels=128, kernel_size=1)
)
self.branch_2 = BasicConv(in_channels=1024, out_channels=384, kernel_size=1)
self.branch_3 = nn.Sequential(
BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
            # Padding guessed by analogy with the Inception-v2/v3 factorized convolutions;
            # the paper does not state it, but the check after this class confirms 17x17 is kept.
            BasicConv(in_channels=192, out_channels=224, kernel_size=(1, 7), padding=(0, 3)),
BasicConv(in_channels=224, out_channels=256, kernel_size=(7, 1), padding=(3, 0))
)
self.branch_4 = nn.Sequential(
BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
BasicConv(in_channels=192, out_channels=192, kernel_size=(1, 7), padding=(0, 3)),
BasicConv(in_channels=192, out_channels=224, kernel_size=(7, 1), padding=(3, 0)),
BasicConv(in_channels=224, out_channels=224, kernel_size=(1, 7), padding=(0, 3)),
BasicConv(in_channels=224, out_channels=256, kernel_size=(7, 1), padding=(3, 0)),
)
def forward(self, x):
x_1 = self.branch_1(x)
x_2 = self.branch_2(x)
x_3 = self.branch_3(x)
x_4 = self.branch_4(x)
        x = torch.cat([x_1, x_2, x_3, x_4], dim=1)  # assign the concatenation so the branch outputs are actually returned
return x
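# A minimal sanity check for the guessed (1, 7)/(7, 1) paddings, assuming the
# 17x17x1024 feature maps this block receives in the full network: padding 3 along
# the 7-tap direction keeps 17x17, and 128 + 384 + 256 + 256 = 1024 channels.
def _check_inception_b():
    block = InceptionB()
    block.eval()
    with torch.no_grad():
        out = block(torch.randn(2, 1024, 17, 17))
    assert out.shape == (2, 1024, 17, 17), out.shape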
class InceptionC(nn.Module):
def __init__(self):
super(InceptionC, self).__init__()
self.branch_1 = nn.Sequential(
nn.AdaptiveAvgPool2d(output_size=8),
BasicConv(in_channels=1536, out_channels=256, kernel_size=1)
)
self.branch_2 = BasicConv(in_channels=1536, out_channels=256, kernel_size=1)
self.branch_3_1 = BasicConv(in_channels=1536, out_channels=384, kernel_size=1)
self.branch_3_2_1 = BasicConv(in_channels=384, out_channels=256, kernel_size=(1, 3), padding=(0, 1))
self.branch_3_2_2 = BasicConv(in_channels=384, out_channels=256, kernel_size=(3, 1), padding=(1, 0))
self.branch_4_1 = nn.Sequential(
BasicConv(in_channels=1536, out_channels=384, kernel_size=1),
BasicConv(in_channels=384, out_channels=448, kernel_size=(1, 3), padding=(0, 1)),
BasicConv(in_channels=448, out_channels=512, kernel_size=(3, 1), padding=(1, 0))
)
self.branch_4_2_1 = BasicConv(in_channels=512, out_channels=256, kernel_size=(3, 1), padding=(1, 0))
self.branch_4_2_2 = BasicConv(in_channels=512, out_channels=256, kernel_size=(1, 3), padding=(0, 1))
def forward(self, x):
x_1 = self.branch_1(x)
x_2 = self.branch_2(x)
x_3 = self.branch_3_1(x)
x_3_1 = self.branch_3_2_1(x_3)
x_3_2 = self.branch_3_2_2(x_3)
x_3 = torch.cat([x_3_1, x_3_2], dim=1)
x_4 = self.branch_4_1(x)
x_4_1 = self.branch_4_2_1(x_4)
x_4_2 = self.branch_4_2_2(x_4)
x_4 = torch.cat([x_4_1, x_4_2], dim=1)
        x = torch.cat([x_1, x_2, x_3, x_4], dim=1)  # assign the concatenation so the branch outputs are actually returned
return x
class Reduction_A(nn.Module):
def __init__(self):
super(Reduction_A, self).__init__()
self.branch_1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.branch_2 = BasicConv(in_channels=384, out_channels=384, kernel_size=3, stride=2)
self.branch_3 = nn.Sequential(
BasicConv(in_channels=384, out_channels=192, kernel_size=1),
BasicConv(in_channels=192, out_channels=224, kernel_size=3),
            BasicConv(in_channels=224, out_channels=256, kernel_size=3, stride=2, padding=1)  # padding=1 so this branch also lands on the 17x17 size given in the paper (see the check after this class)
)
def forward(self, x):
x_1 = self.branch_1(x)
x_2 = self.branch_2(x)
x_3 = self.branch_3(x)
x = torch.cat([x_1, x_2, x_3], dim=1)
return x
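# A minimal sanity check for the padding=1 trick in branch_3, assuming the
# 35x35x384 maps produced by the Inception-A stage: 35 -> 33 (3x3, no pad)
# -> 17 (3x3, stride 2, pad 1), so all branches land on 17x17 and concatenate
# to 384 + 384 + 256 = 1024 channels, as in the paper.
def _check_reduction_a():
    block = Reduction_A()
    block.eval()
    with torch.no_grad():
        out = block(torch.randn(2, 384, 35, 35))
    assert out.shape == (2, 1024, 17, 17), out.shape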
class Reduction_B(nn.Module):
def __init__(self):
super(Reduction_B, self).__init__()
self.branch_1 = nn.MaxPool2d(kernel_size=3, stride=2)
self.branch_2 = nn.Sequential(
BasicConv(in_channels=1024, out_channels=192, kernel_size=1),
BasicConv(in_channels=192, out_channels=192, kernel_size=3, stride=2)
)
self.branch_3 = nn.Sequential(
BasicConv(in_channels=1024, out_channels=256, kernel_size=1),
            BasicConv(in_channels=256, out_channels=256, kernel_size=(1, 7), padding=(0, 3)),  # note: stride stays 1 here; only the final 3x3 conv downsamples
BasicConv(in_channels=256, out_channels=320, kernel_size=(7, 1), padding=(3, 0)),
BasicConv(in_channels=320, out_channels=320, kernel_size=3, stride=2)
)
def forward(self, x):
x_1 = self.branch_1(x)
x_2 = self.branch_2(x)
x_3 = self.branch_3(x)
x = torch.cat([x_1, x_2, x_3], dim=1)
return x
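# A minimal sanity check, assuming the 17x17x1024 maps from the Inception-B stage:
# the stride-1 (1, 7)/(7, 1) convs keep 17x17, only the final 3x3 stride-2 conv
# and the max-pool downsample to 8x8, and 1024 + 192 + 320 = 1536 channels result.
def _check_reduction_b():
    block = Reduction_B()
    block.eval()
    with torch.no_grad():
        out = block(torch.randn(2, 1024, 17, 17))
    assert out.shape == (2, 1536, 8, 8), out.shape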
class InceptionV4(nn.Module):
def __init__(self, num_classes):
super(InceptionV4, self).__init__()
self.stem = StemModel()
self.ModelA_1 = InceptionA()
self.ModelA_2 = InceptionA()
self.ModelA_3 = InceptionA()
self.ModelA_4 = InceptionA()
self.reduction_a = Reduction_A()
self.InceptionB_1 = InceptionB()
self.InceptionB_2 = InceptionB()
self.InceptionB_3 = InceptionB()
self.InceptionB_4 = InceptionB()
self.InceptionB_5 = InceptionB()
self.InceptionB_6 = InceptionB()
self.InceptionB_7 = InceptionB()
self.Reduction_B = Reduction_B()
self.InceptionC_1 = InceptionC()
self.InceptionC_2 = InceptionC()
self.InceptionC_3 = InceptionC()
self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
self.flatten = nn.Flatten() # multi-dim -> one-dim
self.fc = nn.Linear(in_features=1536, out_features=num_classes)
def forward(self, x):
x = self.stem(x)
x = self.ModelA_1(x)
x = self.ModelA_2(x)
x = self.ModelA_3(x)
x = self.ModelA_4(x)
x = self.reduction_a(x)
        # All seven Inception-B blocks are defined above (the paper uses 7), so pass through each of them.
        x = self.InceptionB_1(x)
        x = self.InceptionB_2(x)
        x = self.InceptionB_3(x)
        x = self.InceptionB_4(x)
        x = self.InceptionB_5(x)
        x = self.InceptionB_6(x)
        x = self.InceptionB_7(x)
x = self.Reduction_B(x)
x = self.InceptionC_1(x)
x = self.InceptionC_2(x)
x = self.InceptionC_3(x)
x = self.avg_pool(x)
x = self.flatten(x)
        x = nn.functional.dropout(x, p=0.2, training=self.training)  # keep_prob = 0.8 in the paper, i.e. drop rate 0.2; disabled in eval mode
x = self.fc(x)
        x = torch.softmax(x, dim=1)  # drop this and feed raw logits to nn.CrossEntropyLoss if training with that criterion
return x
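# A sketch of a module-based alternative to the functional dropout above (a style
# suggestion, not the author's code): the paper's keep_prob = 0.8 corresponds to
# p = 0.2, and registering the layer as a submodule lets model.eval() disable it
# automatically instead of passing a training flag by hand.
#
#     self.dropout = nn.Dropout(p=0.2)   # in InceptionV4.__init__
#     x = self.dropout(self.flatten(x))  # in InceptionV4.forward, before self.fc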
if __name__ == '__main__':
    # 299x299 RGB input, as the stem expects; the first dimension is the batch size (20)
    input = torch.ones([20, 3, 299, 299])
model = InceptionV4(num_classes=5)
output = model(input)
print(output.shape)
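    # Run the shape checks defined above; expected sizes (35x35x384, 17x17x1024,
    # 8x8x1536) are taken from the Inception-v4 paper.
    _check_stem()
    _check_inception_b()
    _check_reduction_a()
    _check_reduction_b()
    # The full model maps [20, 3, 299, 299] to [20, num_classes] = [20, 5].
    assert output.shape == (20, 5), output.shape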