PyTorch 实现shuffleNet_v1在CIFAR10上图像分类

目录

一、前言

二、网络结构及原理

   (一)Group Convolution

   (二)Channel Shuffle

   (三)block

   (四)网络结构

三、代码

四、参数量

五、训练结果

六、完整代码


一、前言

        shuffleNet_v1是轻量级的网络,通过引入逐点组卷积以及通道重排技术,有效减少了参数量以及计算量。

二、网络结构及原理

   (一)Group Convolution

        Group Convolution早在AlexNet利用多GPU训练时被使用,之后ResNeX将中间3*3卷积采用组卷积的形式(这也是作者提出block的灵感吧,既然3*3利用了组卷积,为什么1*1的卷积核不使用呢)。组卷积是将输入channel分为多个group,同时将卷积核channel也进行分组,原理如下图所示。

        计算量:

        正常卷积:W*H*N*k*k*m       

        组卷积:(W*H*\frac{m}{g}*k*k*\frac{N}{g})*g,变为正常卷积的1/g倍;

        

        传统网络搭建通常将相同结构的block反复堆叠,由上图(a)可知,当采用多个组卷积堆叠到一起,该组的输出仅仅取决于该组的输入,这就导致不同group(即不同channel)之间的信息无法进行流动,因此作者提出了通道重排技术(即Channel Shuffle)。

   (二)Channel Shuffle

        如上图(b)所示,将组卷积生成的结果继续在各个Group内划分小的子group,之后将每个Group内部对应位置处的子group拼接成新的Group,拼接后的结果如图(c)所示。

   (三)block

         shuffleNet_v1的block相当于在残差结构上做了修改,图(a)是经典残差结构(将3*3卷积换成了dw卷积),图(b)和图(c)是ShuffleNet网络采用的block,当stride=1(图b)时候只是将1*1的卷积换成了组卷积(组卷积能减少计算量)并添加了Channel Shuffle部分;当stride=2(图c)时候,首先在shortcut部分采用了3*3的平均池化下采样,此外将输出通道的Add特征融合变成了Channel拼接。

   (四)网络结构

        值得注意的是,当输入channel为24时候,由于输入channel太小,因此第一个1*1卷积并没有采用组卷积。

三、代码

import torch.nn as nn
import torch
import torch.nn.functional as F
from collections import OrderedDict
from torchsummary import summary


def _make_divisible(ch, divisor=8, min_ch=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch

#基本的卷积BNrelu
class baseConv(nn.Module):
    def __init__(self,inchannels,outchannels,kernel_size,stride,groups,hasRelu=False):
        super(baseConv, self).__init__()
        if hasRelu:
            #判断是否有relu激活函数
            activate=nn.ReLU
        else:
            activate=nn.Identity
        pad=kernel_size//2
        self.baseconv=nn.Sequential(
            nn.Conv2d(in_channels=inchannels,out_channels=outchannels,kernel_size=kernel_size,stride=stride,padding=pad,groups=groups,bias=False),
            nn.BatchNorm2d(outchannels),
            activate()
        )

    def forward(self,x):
        out=self.baseconv(x)
        return out


#通道重排
def ChannelShuffle(x,groups):
    batch_size,channel,height,width=x.size()
    #获得每组的组内channel
    inner_channel=channel//groups
    #[batch,groups,inner_channel,height,width]
    x=x.view(batch_size,groups,inner_channel,height,width)
    x=torch.transpose(x,1,2).contiguous()
    x=x.view(batch_size,-1,height,width)
    return x


#stage结构
class Residual(nn.Module):
    def __init__(self,inchannels,outchannels,stride,groups):
        super(Residual, self).__init__()
        self.add_=True      #shortcut为相加操作
        self.groups=groups

        hidden_channel=inchannels//4
        #当输入channel不等于24时候才有第一个1*1conv
        self.has_conv1=True
        if inchannels!=24:
            self.channel1_first1=baseConv(inchannels=inchannels,outchannels=hidden_channel,kernel_size=1,stride=1,groups=groups,hasRelu=True)
        else:
            self.has_conv1=False
            self.channel1_first1=nn.Identity()
            hidden_channel=inchannels

        #channel1
        self.channel1=nn.Sequential(
            baseConv(inchannels=hidden_channel,outchannels=hidden_channel,kernel_size=3,stride=stride,groups=hidden_channel),
            baseConv(inchannels=hidden_channel,outchannels=outchannels,kernel_size=1,stride=1,groups=groups)
        )

        #channel2
        if stride==2:
            self.channel2=nn.AvgPool2d(kernel_size=3,stride=stride,padding=1)
            self.add_=False


    def forward(self,x):
        if self.has_conv1:
            x1=self.channel1_first1(x)
            x1=ChannelShuffle(x1,groups=self.groups)
            out=self.channel1(x1)
        else:
            out=self.channel1(x)
        if self.add_:
            out+=x
            return F.relu_(out)
        else:
            out2=self.channel2(x)
            out=torch.cat((out,out2),dim=1)
            return F.relu_(out)


#shuffleNet
class ShuffleNet(nn.Module):
    def __init__(self,groups,out_channel_list,num_classes,rate,init_weight=True):
        super(ShuffleNet, self).__init__()

        #定义有序字典存放网络结构
        self.Module_List=OrderedDict()

        self.Module_List.update({'Conv1':nn.Sequential(nn.Conv2d(3,_make_divisible(24*rate,divisor=4*groups),3,2,1,bias=False),nn.BatchNorm2d(_make_divisible(24*rate,4*groups)),nn.ReLU())})
        self.Module_List.update({'MaxPool1':nn.MaxPool2d(3,2,1)})

        #net_config [inchannels,outchannels,stride]
        net_config=[[out_channel_list[0],out_channel_list[0],1],
                    [out_channel_list[0],out_channel_list[1],2],
                    [out_channel_list[1],out_channel_list[2],1],
                    [out_channel_list[2],out_channel_list[3],2],
                    [out_channel_list[3],out_channel_list[4],1]]
        repeat_num=[3,1,7,1,3]

        #搭建stage部分
        self.Module_List.update({'stage0_0':Residual(_make_divisible(24*rate,4*groups),_make_divisible((out_channel_list[0]-_make_divisible(24*rate,4*groups))*rate,4*groups),stride=2,groups=groups)})
        for idx,item in enumerate(repeat_num):
            config_item=net_config[idx]
            for j in range(item):
                if j==0 and idx!=0 and config_item[-1]==2:
                    self.Module_List.update({'stage{}_{}'.format(idx,j+1):Residual(_make_divisible(config_item[0]*rate,4*groups),_make_divisible((config_item[1]-config_item[0])*rate,4*groups),config_item[2],groups)})
                else:
                    self.Module_List.update({'stage{}_{}'.format(idx,j+1):Residual(_make_divisible(config_item[0]*rate,4*groups),_make_divisible(config_item[1]*rate,4*groups),config_item[2],groups)})
                config_item[-1]=1       #重复stage的stride=1
                config_item[0]=config_item[1]

        self.Module_List.update({'GlobalPool':nn.AvgPool2d(kernel_size=7,stride=1)})

        self.Module_List=nn.Sequential(self.Module_List)

        self.linear=nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(_make_divisible(out_channel_list[-1]*rate,4*groups),num_classes)
        )

        if init_weight:
            self.init_weight()
    def forward(self,x):
        out=self.Module_List(x)
        out=out.view(out.size(0),-1)
        out=self.linear(out)
        return out

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)
            elif isinstance(w, nn.Linear):
                nn.init.normal_(w.weight, 0, 0.01)
                nn.init.zeros_(w.bias)

#定义shufflenet_
def shuffleNet_g1_(num_classes,rate=1.0):
    config=[144,288,288,576,576]
    return ShuffleNet(groups=1,out_channel_list=config,num_classes=num_classes,rate=rate)

def shuffleNet_g2_(num_classes,rate=1.0):       #
    config=[200,400,400,800,800]
    return ShuffleNet(groups=2,out_channel_list=config,num_classes=num_classes,rate=rate)

def shuffleNet_g3_(num_classes,rate=1.0):
    config=[240,480,480,960,960]
    return ShuffleNet(groups=3,out_channel_list=config,num_classes=num_classes,rate=rate)

def shuffleNet_g4_(num_classes,rate=1.0):
    config=[272,544,544,1088,1088]
    return ShuffleNet(groups=4,out_channel_list=config,num_classes=num_classes,rate=rate)

def shuffleNet_g8_(num_classes,rate=1.0):
    config=[384,768,768,1536,1536]
    return ShuffleNet(groups=8,out_channel_list=config,num_classes=num_classes,rate=rate)

if __name__ == '__main__':
    net=shuffleNet_g3_(10,rate=1.0).to('cuda')
    print(net)
    summary(net,(3,224,224))

         训练部分以及测试部分代码与前面类似,代码中rate其实是不起作用的(由于论文中提到了缩放因子,我在搭建时候以为可以通过rate来控制模型大小,实际发现论文最后只给出了3个版本1x、0.5x、0.25x,而通过连续的rate控制channel是有一定难度的),后面懒得对代码作修改了,即默认rate=1.0。

四、参数量

        这里我列出了,g分别为1、2、3、4、8时候的参数量:

        

五、训练结果

         由于没找到预训练权重,因此我这里从头训练只跑了8个epoch(训练时间太长了),准确率达到了68%(group=3);

    

六、完整代码

        代码地址:链接:百度网盘 请输入提取码 提取码:3mh3

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值