上一篇已经基于 PyTorch 实现了 VGG16 的图像分类任务，这次来写一下 Inception10 网络。目前很多开源代码都是基于 TensorFlow 实现的，所以我用 PyTorch 实现了一下。两者的基本思路差不多，只是函数的用法稍有不同。如果大家发现有什么问题，欢迎指正，一起交流学习。
首先引入需要的库，并创建一个 conv_bn_activ 类，把卷积、批归一化（BN）和 ReLU 激活函数封装在一起：
import torch.nn as nn
import torch
class conv_bn_activ(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU packaged as one module.

    Wrapping the three layers together keeps the network definitions short
    and easier to read.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the BatchNorm2d feature count).
        kenel_size: convolution kernel size. The misspelled name is kept on
            purpose so existing keyword callers keep working.
        stride: convolution stride.
        padding: zero-padding passed to nn.Conv2d.
    """

    def __init__(self, in_ch, out_ch, kenel_size=3, stride=1, padding=1):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kenel_size, stride, padding)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.ReLU()
        # The attribute name `model` and the layer order fix the state_dict
        # keys (model.0.*, model.1.*), so both are preserved.
        self.model = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Run the input through conv, batch norm and ReLU in sequence."""
        return self.model(x)
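然后利用这个类创建如上图所示的 InceptionNet 基本单元。它包含四个并行分支：1×1 卷积；1×1 卷积后接 3×3 卷积；1×1 卷积后接 5×5 卷积；3×3 最大池化后接 1×1 卷积。四个分支的输出在通道维度上拼接，因此输出通道数是 out_ch 的 4 倍：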
class inception_block(nn.Module):
    """Basic InceptionNet unit: four parallel branches concatenated on channels.

    Branches (every conv is a conv_bn_activ, i.e. Conv2d + BatchNorm2d + ReLU):
        1. 1x1 conv
        2. 1x1 conv -> 3x3 conv (padding 1)
        3. 1x1 conv -> 5x5 conv (padding 2)
        4. 3x3 max-pool (stride 1, padding 1) -> 1x1 conv

    Only the 1x1 convs use `stride`, so every branch ends at the same spatial
    size. The output therefore has 4 * out_ch channels.

    Args:
        in_ch: number of input channels.
        out_ch: channels produced by each branch.
        stride: stride of the 1x1 convs (2 downsamples, 1 keeps resolution).
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        def pointwise():
            # 1x1 conv reading the block input; carries the block's stride.
            return conv_bn_activ(in_ch, out_ch, 1, stride, padding=0)

        # Registration order matches the original so parameter init consumes
        # the RNG identically and state_dict keys are unchanged.
        self.c1 = pointwise()
        self.c2_1 = pointwise()
        self.c2_2 = conv_bn_activ(out_ch, out_ch, 3, stride=1, padding=1)
        self.c3_1 = pointwise()
        self.c3_2 = conv_bn_activ(out_ch, out_ch, 5, stride=1, padding=2)
        self.c4_1 = nn.MaxPool2d(3, stride=1, padding=1)
        self.c4_2 = pointwise()

    def forward(self, x):
        """Evaluate the four branches and concatenate them along dim=1 (channels)."""
        branches = (
            self.c1(x),
            self.c2_2(self.c2_1(x)),
            self.c3_2(self.c3_1(x)),
            self.c4_2(self.c4_1(x)),
        )
        return torch.cat(branches, dim=1)
最后利用基本单元搭建 Inception10 网络。网络共有 num_block 个 block，每个 block 由两个基本单元组成：第一个单元步长为 2，用于下采样；第二个单元步长为 1。每经过一个 block，out_ch 翻倍：
class Inception10(nn.Module):
    """Inception10: a stem conv, `num_block` stages of two Inception units, a classifier.

    Each stage holds two inception_block units. The first uses stride 2 to
    downsample; the second keeps the resolution and consumes the 4 * out_ch
    channels the first one concatenates. out_ch doubles after every stage.
    Global average pooling plus a linear layer (optionally followed by softmax)
    produce the class scores.

    Args:
        num_block: number of stages.
        num_class: number of output classes.
        in_channels: channels of the input images (default 3, as originally).
        init_ch: width of the stem conv and of each branch in the first stage
            (default 16, as originally).
        apply_softmax: if True (default, original behaviour) the output is
            passed through softmax over dim=1. Set False to return raw logits,
            which is what ``nn.CrossEntropyLoss`` expects as input.

    Attributes kept from the original API: ``in_ch`` (channel count entering
    the classifier once construction finishes), ``out_ch``, ``num_block``,
    ``num_class``.
    """

    def __init__(self, num_block, num_class, *, in_channels=3, init_ch=16,
                 apply_softmax=True):
        super().__init__()
        self.c1 = conv_bn_activ(in_channels, init_ch)
        self.in_ch = init_ch
        self.out_ch = init_ch
        self.num_block = num_block
        self.num_class = num_class
        self.blocks = nn.Sequential()
        for block_i in range(num_block):
            # First unit of the stage: stride-2 convolutions downsample.
            self.blocks.add_module(
                '{}_0'.format(block_i),
                inception_block(self.in_ch, self.out_ch, stride=2),
            )
            # Second unit: input is the 4-branch concatenation of the first.
            self.blocks.add_module(
                '{}_1'.format(block_i),
                inception_block(self.out_ch * 4, self.out_ch, stride=1),
            )
            self.in_ch = self.out_ch * 4
            self.out_ch *= 2  # widen the next stage
        self.p1 = nn.AdaptiveAvgPool2d(1)
        # self.in_ch is exactly the channel count leaving self.blocks (init_ch
        # when num_block == 0). The previous `out_ch * 2` only coincided with
        # it for num_block >= 1.
        self.fc1 = nn.Sequential(
            nn.Linear(self.in_ch, self.num_class),
            nn.Softmax(dim=1) if apply_softmax else nn.Identity(),
        )

    def forward(self, x):
        """Map a batch of images to per-class scores of shape (N, num_class)."""
        x = self.c1(x)
        x = self.blocks(x)
        # (N, C, 1, 1) -> (N, C) using the real batch size; the old
        # view(64, ...) failed for any batch whose size was not 64.
        x = torch.flatten(self.p1(x), 1)
        return self.fc1(x)
网络搭建好后，就可以开始在 CIFAR-10 数据集上进行训练和测试了（当然也可以使用其他分类数据集）。训练和测试的代码可以参考我的上一篇文章，只需要把引入的网络从 VGG16 改成 Inception10 即可。注意：如果训练时使用 nn.CrossEntropyLoss，它内部已经包含了 log-softmax 运算，此时网络末尾不应再接 nn.Softmax，而应直接输出 logits。