pytorch技巧 二: 深度可分离卷积
1. 分组卷积
要想弄懂深度可分离卷积(depth-wise Separable convolution
),先要明白什么是分组卷积。用一个简单例子来说明:
import torch
from torchsummary import summary
class mymodel(torch.nn.Module):
def __init__(self):
super(mymodel, self).__init__()
self.conv2d = torch.nn.Conv2d(in_channels=4,
out_channels=8,
kernel_size=3,
stride=1,
padding=1,
groups=1)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv2d(x)
x = self.relu(x)
return x
device = torch.device("cuda" )
model = mymodel().to(device)
summary(model, (4, 3, 3))
上面代码中是一个一层卷积网络,torch.nn.Conv2d
有个参数groups
, 这个参数的含义是对输入和输出通道数分组,这也就要求输入特征图和输出特征图都是groups
的倍数。不太明白的不要紧,我们看代码,代码中groups
为1
,把输入特征图和输出特征图分为一组,这就是普通的卷积。如下图:
从图中看出,我们输入为4×3×3(channels,width,height),输出为8×3×3, 且kernel_size=3
,可以得到参数(忽略偏置)个数为3×3×4×8=388 (kernel_size
×kernel_size
×输入的通道数×输出通道数)。
当groups
为2
时,如下图:
参数为3×3×(4/2)×(8/2)×2 = 194(kernel_size
×kernel_size
×(输入通道数/groups
)×(输出通道数/groups
)×groups
) ,即分组卷积参数量=普通卷积参数量 /groups
.这样就达到减少参数的目的。
1. 深度可分离卷积
深度可分离卷积分为两部分:
第一部分:分组卷积,且groups
和输出通道数皆为输入通道数。
第二部分:利用1×1的卷积更改输出通道数。
import torch.nn as nn
import torch
from torchsummary import summary
class depthwise_separable_conv(nn.Module):
def __init__(self, ch_in, ch_out):
super(depthwise_separable_conv, self).__init__()
self.ch_in = ch_in
self.ch_out = ch_out
self.depth_conv = nn.Conv2d(ch_in, ch_in, kernel_size=3, padding=1, groups=ch_in)
self.point_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1)
def forward(self, x):
x = self.depth_conv(x)
x = self.point_conv(x)
return x
class mymodel(nn.Module):
def __init__(self):
super(mymodel, self).__init__()
self.conv2d = depthwise_separable_conv(4, 8)
self.relu = nn.ReLU()
def forward(self, x):
x = self.conv2d(x)
x = self.relu(x)
return x
device = torch.device("cuda" )
model = mymodel().to(device)
summary(model, (4, 3, 3))
上面代码中就是一层深度可分离卷积网络,是由两个卷积网络组合而成,第一个卷积网络为groups=输入通道数
,输出通道数=输入通道数
,的分组卷积。即每一个卷积核只在一个通道上进行卷积,其参数量=3×3×4=36(kernel_size1×kernel_size1×输入通道数
). 第二个卷积网络的kernel_size=1
, 其参数量=1×1×4×8=32(kernel_size2×kernel_size2×输入通道数×输出通道数
)。总参数量=36 + 32 = 68
可以得出深度可分离卷积网络极大的减少了参数量。像MobileNet的基本单元就是深度可分离卷积!