shufflenetv2代码解读
概述
shufflenetv2是发表在ECCV 2018上的一篇关于模型压缩和模型加速的文章,其中用到的主要技巧有两点:深度可分离卷积、通道交互(channel shuffle)。其中,深度可分离卷积是为了减少参数量、提高运算速度;通道交互是为了让不同通道的特征之间可以产生信息交互,从而获取更加丰富的语义信息。
这个系列的文章把主要精力放在代码的分析上,如果想要进一步了解shufflenetv2原理的同学可以参考这个链接。
shufflenetv2网络结构图
shufflenetv2架构参数
shufflenetv2代码细节分析
import torch
import torch.nn as nn
from torch import tensor
from .utils import load_state_dict_from_url
from typing import Callable,Any,List
# Public names exported by this module: the model class and its factory functions.
__all__ = [
'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]
# Download URLs of the pretrained ShuffleNetV2 weights, keyed by architecture name.
# A value of None means no checkpoint URL is listed for that variant.
model_urls = {
'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
'shufflenetv2_x1.5': None,
'shufflenetv2_x2.0': None,
}
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Interleave the channels of ``x`` across ``groups`` channel groups.

    The channel axis is viewed as ``(groups, channels_per_group)``, the two
    axes are swapped, and the result is flattened back, so that channels that
    were adjacent inside one group end up spread across all groups. This is
    a pure permutation of channels; no convolution is performed here.

    Args:
        x: 4-D tensor unpacked as (batch, channels, height, width).
        groups: Number of groups to interleave. ``channels`` is expected to
            be divisible by ``groups``; otherwise the reshape cannot succeed.

    Returns:
        A tensor with the same shape as ``x`` and permuted channels.
    """
    batchsize, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # (B, C, H, W) -> (B, G, C // G, H, W)
    x = x.view(batchsize, groups, channels_per_group, height, width)

    # Swap the group axis with the per-group channel axis. transpose() only
    # changes strides, so make the memory contiguous before the next view().
    x = torch.transpose(x, 1, 2).contiguous()

    # Flatten back to (B, C, H, W).
    return x.view(batchsize, -1, height, width)
class InvertedResidual(nn.Module):
    """ShuffleNetV2 unit: two branches concatenated, then channel-shuffled.

    With ``stride == 1`` the input is split in half along the channel axis;
    the first half is passed through unchanged and the second half goes
    through ``branch2``. With ``stride > 1`` the whole input is fed to both
    ``branch1`` and ``branch2``. In either case the two results are
    concatenated on the channel axis and shuffled with two groups.

    Args:
        inp: Number of input channels.
        oup: Number of output channels; each branch produces ``oup // 2``.
        stride: Stride of the 3x3 depthwise convolutions, in ``[1, 3]``.

    Raises:
        ValueError: If ``stride`` is outside ``[1, 3]``.
        AssertionError: If ``stride == 1`` and ``inp != 2 * (oup // 2)``.
    """

    def __init__(self, inp: int, oup: int, stride: int) -> None:
        super(InvertedResidual, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # For stride 1 the input is split in two and only one half is
        # transformed, so the input must hold exactly 2 * branch_features
        # channels (``<< 1`` doubles the value).
        assert (self.stride != 1) or (inp == branch_features << 1)

        # Left branch: only present in the strided unit.
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                # Depthwise conv keeps the channel count at ``inp`` so that
                # the BatchNorm and 1x1 conv below see ``inp`` channels.
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                # 1x1 convs use padding 0 so the spatial size is preserved.
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Empty Sequential: acts as an identity placeholder.
            self.branch1 = nn.Sequential()

        # Right branch: 1x1 conv -> 3x3 depthwise conv -> 1x1 conv.
        # In the strided unit it receives all ``inp`` channels; otherwise it
        # receives only the second half of the split input.
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i: int, o: int, kernel_size: int, stride: int = 1,
                       padding: int = 0, bias: bool = False) -> nn.Conv2d:
        """Return a depthwise ``nn.Conv2d`` (``groups`` equal to ``i``).

        ``bias`` is passed by keyword: the sixth positional parameter of
        ``nn.Conv2d`` is ``dilation``, not ``bias``.
        """
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the unit to ``x`` and return the channel-shuffled result."""
        if self.stride == 1:
            # Split the channel axis (dim 1) into two halves; the first half
            # is a shortcut, the second half is transformed by branch2.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            # Strided unit: both branches process the full input.
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Mix channels between the two concatenated halves.
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: stem conv, max-pool, three stages, head conv, FC.

    Args:
        stages_repeats: Number of units in stage2, stage3 and stage4
            (exactly three entries).
        stages_out_channels: Output channels of conv1, stage2, stage3,
            stage4 and conv5 (exactly five entries).
        num_classes: Size of the final linear layer's output.
        inverted_residual: Callable invoked as
            ``inverted_residual(in_channels, out_channels, stride)`` to build
            each unit.

    Raises:
        ValueError: If ``stages_repeats`` does not have 3 entries or
            ``stages_out_channels`` does not have 5 entries.
    """

    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual) -> None:
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        # Stem: 3x3 conv with stride 2 on a 3-channel input.
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy (the attributes are set via setattr below).
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            # Each stage opens with one stride-2 unit that changes the channel
            # count, followed by (repeats - 1) stride-1 units.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            seq.extend(
                inverted_residual(output_channels, output_channels, 1)
                for _ in range(repeats - 1)
            )
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        # Head: 1x1 conv to the final feature width, then the classifier.
        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        """Run the layers in order and return the output of ``fc``."""
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global average pool over dims 2 and 3
        x = self.fc(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass; delegates to ``_forward_impl``."""
        return self._forward_impl(x)
def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
    """Build a ShuffleNetV2 and optionally load its pretrained weights.

    Args:
        arch: Key into ``model_urls`` identifying the variant.
        pretrained: When true, download and load the weights for ``arch``.
        progress: Forwarded to ``load_state_dict_from_url``.
        *args: Positional arguments forwarded to ``ShuffleNetV2``.
        **kwargs: Keyword arguments forwarded to ``ShuffleNetV2``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: If ``pretrained`` is true and the URL registered
            for ``arch`` is ``None``.
    """
    net = ShuffleNetV2(*args, **kwargs)
    if not pretrained:
        return net

    url = model_urls[arch]
    if url is None:
        raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))

    # Fetch the checkpoint and copy its parameters into the fresh model.
    weights = load_state_dict_from_url(url, progress=progress)
    net.load_state_dict(weights)
    return net
# The factory functions differ in the per-stage output channel counts they pass.
def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """Build the ``shufflenetv2_x0.5`` configuration.

    Args:
        pretrained: Load pretrained weights registered under this name.
        progress: Forwarded to the weight loader.
        **kwargs: Extra keyword arguments for ``ShuffleNetV2``.
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 48, 96, 192, 1024]
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """Build the ``shufflenetv2_x1.0`` configuration.

    Args:
        pretrained: Load pretrained weights registered under this name.
        progress: Forwarded to the weight loader.
        **kwargs: Extra keyword arguments for ``ShuffleNetV2``.
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 116, 232, 464, 1024]
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """Build the ``shufflenetv2_x1.5`` configuration.

    Args:
        pretrained: Load pretrained weights registered under this name.
        progress: Forwarded to the weight loader.
        **kwargs: Extra keyword arguments for ``ShuffleNetV2``.
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 176, 352, 704, 1024]
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)
def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    """Build the ``shufflenetv2_x2.0`` configuration.

    Args:
        pretrained: Load pretrained weights registered under this name.
        progress: Forwarded to the weight loader.
        **kwargs: Extra keyword arguments for ``ShuffleNetV2``.
    """
    stage_repeats = [4, 8, 4]
    stage_channels = [24, 244, 488, 976, 2048]
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         stage_repeats, stage_channels, **kwargs)