倒残差与线性瓶颈浅析 - MobileNetV2_晓野豬的博客-CSDN博客_倒残差结构
首先理解倒置残差要先了解残差形式
残差结构:
1.采用1*1卷积降维,比如输入是256,降维到64
2.采用卷积核为3*3形式
3.采用1*1卷积升维。,比如64变成256
了解完残差结构后,现在开始学习倒置残差结构
1.采用1*1卷积升维,比如输入是64,降维到256
2.采用卷积核为深度可分离的3*3形式
3.采用1*1卷积 降维。比如256,降维到64
这里的激活函数采用的是relu6
# _*_coding:utf-8_*_
import torch
import torch.nn as nn
class InvertedResidualsBlock(nn.Module):
def __init__(self, in_channels, out_channels, expansion, stride):
super(InvertedResidualsBlock, self).__init__()
channels = expansion * in_channels
self.stride = stride
self.basic_block = nn.Sequential(
nn.Conv2d(in_channels, channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(channels),
nn.ReLU6(inplace=True),
nn.Conv2d(channels, channels, kernel_size=3, stride=stride, padding=1, groups=channels, bias=False),
nn.BatchNorm2d(channels),
nn.ReLU6(inplace=True),
nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels)
)
# The shortcut operation does not affect the number of channels
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.basic_block(x)
if self.stride == 1:
print("With shortcut!")
out = out + self.shortcut(x)
else:
print("No shortcut!")
print(out.size())
return out
if __name__ == "__main__":
x = torch.randn(16, 3, 32, 32)
# no shortcut
net1 = InvertedResidualsBlock(3, 6, 6, 2)
# with shortcut
net2 = InvertedResidualsBlock(3, 6, 6, 1)
y1, y2 = net1(x), net2(x)