import torch
from torch import nn
import torch.nn.functional as F
# First, define a basic convolution block combining Conv2d and ReLU.
class BasicConv2d(nn.Module):
    """Conv2d followed by an in-place ReLU activation."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)

    def forward(self, x):
        # Convolve, then apply ReLU in place to avoid an extra allocation.
        return F.relu(self.conv(x), inplace=True)
# Inception v1 block; the constructor takes the channel sizes of each sub-branch.
class Inceptionv1(nn.Module):
    """GoogLeNet (Inception v1) block with four parallel branches.

    The branch outputs are concatenated along the channel dimension, so the
    output channel count is hid_1_1 + hid_2_3 + out_3_5 + out_4_1.

    Args:
        in_dim:  number of input channels.
        hid_1_1: output channels of the 1x1 branch.
        hid_2_1: bottleneck channels of the 3x3 branch.
        hid_2_3: output channels of the 3x3 branch.
        hid_3_1: bottleneck channels of the 5x5 branch.
        out_3_5: output channels of the 5x5 branch.
        out_4_1: output channels of the pooling branch.
    """

    def __init__(self, in_dim, hid_1_1, hid_2_1, hid_2_3, hid_3_1, out_3_5, out_4_1):
        super(Inceptionv1, self).__init__()
        # 1x1 convolution branch.
        self.branch1x1 = BasicConv2d(in_dim, hid_1_1, 1)
        # 1x1 bottleneck followed by a 3x3 convolution.
        self.branch3x3 = nn.Sequential(
            BasicConv2d(in_dim, hid_2_1, 1),
            BasicConv2d(hid_2_1, hid_2_3, 3, padding=1),
        )
        # 1x1 bottleneck followed by a 5x5 convolution.
        # NOTE: attribute name keeps the original "brach" typo so existing
        # state_dict checkpoints remain loadable.
        self.brach5x5 = nn.Sequential(
            BasicConv2d(in_dim, hid_3_1, 1),
            BasicConv2d(hid_3_1, out_3_5, 5, padding=2),
        )
        # 3x3 max pooling followed by a 1x1 convolution.
        self.brach_pool = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            BasicConv2d(in_dim, out_4_1, 1),
        )

    def forward(self, x):
        b1 = self.branch1x1(x)
        b2 = self.branch3x3(x)
        # BUG FIX: the original ran branch3x3 twice here and never used the
        # 5x5 branch, so its parameters were trained but unused.
        b3 = self.brach5x5(x)
        b4 = self.brach_pool(x)
        # Concatenate the four branch outputs along the channel dimension.
        return torch.cat((b1, b2, b3, b4), dim=1)
View the network structure:
# Instantiate an Inception v1 block and show its structure.
module = Inceptionv1(192, 64, 96, 128, 16, 32, 32)
# A bare expression only echoes in a REPL/notebook; print explicitly so the
# structure is also shown when this runs as a script.
print(module)
Inceptionv1(
(branch1x1): BasicConv2d(
(conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
)
(branch3x3): Sequential(
(0): BasicConv2d(
(conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicConv2d(
(conv): Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
(brach5x5): Sequential(
(0): BasicConv2d(
(conv): Conv2d(192, 16, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicConv2d(
(conv): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
)
)
(brach_pool): Sequential(
(0): MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(1): BasicConv2d(
(conv): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
)
import torch
from torch import nn
import torch.nn.functional as F
# Basic convolution block; compared with the Inception v1 basic module it adds a BN layer.
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, the basic unit of the Inception v2 block."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        # eps=0.001 matches the value used by torchvision's Inception models.
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        out = self.bn(self.conv(x))
        return F.relu(out, inplace=True)
class Inception2(nn.Module):
    """Inception v2 block with four parallel branches built from BN conv units.

    Output channels are 96 + 64 + 96 + 64 = 320, concatenated along dim 1.

    Args:
        in_dim: number of input channels. Defaults to 192, the value that was
            hard-coded in the original implementation, so Inception2() behaves
            exactly as before.
    """

    def __init__(self, in_dim=192):
        super(Inception2, self).__init__()
        # 1x1 convolution branch.
        self.branch1 = BasicConv2d(in_dim, 96, 1, 0)
        # 1x1 bottleneck followed by a 3x3 convolution.
        self.branch2 = nn.Sequential(
            BasicConv2d(in_dim, 48, 1, 0),
            BasicConv2d(48, 64, 3, 1),
        )
        # 1x1 bottleneck followed by two stacked 3x3 convolutions
        # (the factorized replacement for a single 5x5 convolution).
        self.branch3 = nn.Sequential(
            BasicConv2d(in_dim, 64, 1, 0),
            BasicConv2d(64, 96, 3, 1),
            BasicConv2d(96, 96, 3, 1),
        )
        # 3x3 average pooling followed by a 1x1 convolution.
        self.branch4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(in_dim, 64, 1, 0),
        )

    def forward(self, x):
        x0 = self.branch1(x)
        x1 = self.branch2(x)
        x2 = self.branch3(x)
        x3 = self.branch4(x)
        # Concatenate the four branch outputs along the channel dimension.
        return torch.cat((x0, x1, x2, x3), 1)
# Instantiate the Inception v2 block and show its structure.
module = Inception2()
# A bare expression only echoes in a REPL/notebook; print explicitly so the
# structure is also shown when this runs as a script.
print(module)
Network structure:
Inception2(
(branch1): BasicConv2d(
(conv): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(branch2): Sequential(
(0): BasicConv2d(
(conv): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicConv2d(
(conv): Conv2d(48, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
)
(branch3): Sequential(
(0): BasicConv2d(
(conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicConv2d(
(conv): Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
(2): BasicConv2d(
(conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
)
(branch4): Sequential(
(0): AvgPool2d(kernel_size=3, stride=1, padding=1)
(1): BasicConv2d(
(conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
)
)
)