目录
一:Inception模块
1、需要解决的问题
在一个CNN网络中,到底设计多少卷积核,卷积路径设置多少为好,这个之前一直没有什么定论。所以inception模块的提出就是为了解决这类不确定参数的定量问题。
2、原理
其整体思路是,将图像按照n个维度(该例子中按照四维方式)设计n个并行卷积路线来提取特征值,每个路线的提取方式都尽量不同。如1*1、3*3、5*5、池化(注意:池化也是特征提取的一种)来把所有可能的路线都走一遍,最后在学习的过程中按照梯度的大小来分配各个路线的大小权重,进而解决上述问题。
需要注意的是在inception模块中我们不改变输入图像的长度和宽度,只对通道数做改变。其原因在于最后我们要按照固定尺寸重新拼接形成新的输出张量。所以在设计的时候应该做好padding、stride的设计,以防大小发生改变。
3、模块中1*1卷积核引入的特点
目的是简化计算量,提高计算效率。我们以192通道的28*28尺寸的输入经过5*5,padding=2,stride=1,in=192,out=32的卷积核为例子来计算一下电脑需要的计算量。
(5 ** 2) * (28**2 ) * (192) *(32)
而引入in = 192,out = 16,的1*1卷积核之后,计算量为
(1**2)* (28**2) *(192)* (16)+ (5**2)* (28**2) *16 *32
计算量减少了:107978752
二:整体设计思路
三:Inception模块设计思路
1、定义卷积核
卷积核设计时注意两点,第一点一定要考虑padding与stride,来保证输出的长宽不变。第二点,注意代码的简化行,不要重复设计,避免冗余。
2、拼接
拼接需要先将4个维度的输出放在一个表格中,再利用torch.cat(Tensor,dim)#dim=0,竖着拼接,dim=1横着拼接。
class InceptionMoudle(torch.nn.Module):
def __init__(self,in_channel):
super(InceptionMoudle, self).__init__()
self.cnn1_1 = torch.nn.Conv2d(in_channel, 16, 1)
self.cnn2_5 = torch.nn.Conv2d(16, 24, 5,padding=2)
self.cnn3_3_1 = torch.nn.Conv2d(16,24,3,padding=1)
self.cnn3_3_2 = torch.nn.Conv2d(24, 24, 3,padding=1)
self.pool1 = torch.nn.AvgPool2d(3,stride=1,padding=1)
self.pool2 = torch.nn.Conv2d(in_channel, 24, 1)
def forward(self,x):
y1 = self.cnn1_1(x)
y2 = self.cnn1_1(x)
y2 = self.cnn2_5(y2)
y3 = self.cnn1_1(x)
y3 = self.cnn3_3_1(y3)
y3 = self.cnn3_3_2(y3)
y4 = self.pool1(x)
y4 =self.pool2(y4)
input = [y1,y2,y3,y4]
# print(y1.shape,y2.shape,y3.shape,y4.shape)
output = torch.cat(input,1)
return output
四:模型设计
初始化过程中,将前面设计好的inception模块实例化,输入通道数。在前向函数中,在输入线性层之前注意利用x.view(-1,88*4*4)来吧张量拼接成一维。
五:整体代码
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
#导入数据
batchsize = 64
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
train_data = datasets.MNIST("../datasets/",train=True,transform=transforms,download=True)
train_data = DataLoader(train_data,batch_size=batchsize,shuffle=True)
test_data = datasets.MNIST("../datasets/",train=False,transform=transforms,download=True)
test_data = DataLoader(test_data,batch_size=batchsize,shuffle=True)
#定义Inception模块
class InceptionMoudle(torch.nn.Module):
def __init__(self,in_channel):
super(InceptionMoudle, self).__init__()
self.cnn1_1 = torch.nn.Conv2d(in_channel, 16, 1)
self.cnn2_5 = torch.nn.Conv2d(16, 24, 5,padding=2)
self.cnn3_3_1 = torch.nn.Conv2d(16,24,3,padding=1)
self.cnn3_3_2 = torch.nn.Conv2d(24, 24, 3,padding=1)
self.pool1 = torch.nn.AvgPool2d(3,stride=1,padding=1)
self.pool2 = torch.nn.Conv2d(in_channel, 24, 1)
def forward(self,x):
y1 = self.cnn1_1(x)
y2 = self.cnn1_1(x)
y2 = self.cnn2_5(y2)
y3 = self.cnn1_1(x)
y3 = self.cnn3_3_1(y3)
y3 = self.cnn3_3_2(y3)
y4 = self.pool1(x)
y4 =self.pool2(y4)
input = [y1,y2,y3,y4]
# print(y1.shape,y2.shape,y3.shape,y4.shape)
output = torch.cat(input,1)
return output
#定义模型
class Modle(torch.nn.Module):
def __init__(self):
super(Modle, self).__init__()
self.cnn5_1 = torch.nn.Conv2d(1,10,5)
self.cnn5_2 = torch.nn.Conv2d(88,20,5)
self.mp = torch.nn.MaxPool2d(2)
self.line = torch.nn.Linear(1408,10)
self.inception1 = InceptionMoudle(10)
self.inception2 = InceptionMoudle(20)
def forward(self,x):
# in_size = x.size(0)
# print(in_size)
x = torch.nn.functional.relu(self.mp(self.cnn5_1(x)))
x = self.inception1(x)
x = torch.nn.functional.relu(self.mp(self.cnn5_2(x)))
x = self.inception2(x)
x = x.view(-1,1408)
x = self.line(x)
return x
model = Modle()
MSE = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01)
def train(epoch):
total = 0
for i,data in enumerate(train_data,0):
image,index = data
output = model(image)
loss = MSE(output,index)
opt.zero_grad()
loss.backward()
opt.step()
total += loss.item()
if i % 300 == 299:
print("epoch = %d,i = %d" % (epoch,i))
print("loss = ", total / 300)
total = 0
def text():
a = 0
c = 0
with torch.no_grad():
for i, data in enumerate(test_data,0):
images,index = data
y =model(images)
_, b = torch.max(y.data,1)
a += (b == index).sum().item()
c += index.size(0)
correct = a / c
print("正确率=", correct)
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
text()