Implementing depthwise convolution in PyTorch
Depthwise convolution (per-channel convolution) is an operation that greatly reduces both the parameter count and the computation cost. The idea is simple, but there is one small detail to watch out for when implementing it with PyTorch, so I'm recording it here.
Depthwise convolution
First, a quick review of what depthwise convolution is.
For input features of size $(c_{in}, h, w)$, a convolution produces output features of size $(c_{out}, h, w)$.

An ordinary convolution does this with a single kernel of size $(c_{out}, c_{in}, k, k)$.

A depthwise convolution instead processes the $c_{in}$ input sub-features of size $(1, h, w)$ separately. Each sub-feature is convolved with its own sub-kernel of size $(\frac{c_{out}}{c_{in}}, 1, k, k)$, producing an output sub-feature of size $(\frac{c_{out}}{c_{in}}, h, w)$. The $c_{in}$ output sub-features are then concatenated along the channel axis to form the final output features.

Note that depthwise convolution requires $c_{out}$ to be an integer multiple of $c_{in}$.
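To make the shape bookkeeping concrete, here is a minimal sketch (the values $c_{in}=4$, $c_{out}=8$, $k=3$ are just examples; `F.conv2d`'s `groups` argument, discussed below, is what gives the per-channel behavior):

```python
import torch
import torch.nn.functional as F

c_in, c_out, k, h, w = 4, 8, 3, 16, 16  # c_out is an integer multiple of c_in

x = torch.randn(1, c_in, h, w)
# c_in sub-kernels of shape (c_out / c_in, 1, k, k), stacked into (c_out, 1, k, k)
weight = torch.randn(c_out, 1, k, k)

y = F.conv2d(x, weight, padding=k // 2, groups=c_in)
print(y.shape)  # torch.Size([1, 8, 16, 16])
```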
Parameter count comparison:
Ordinary convolution:

$$c_{out} \times c_{in} \times k \times k$$
Depthwise convolution:

$$c_{in} \times \frac{c_{out}}{c_{in}} \times k \times k = c_{out} \times k \times k$$
The parameter count drops by a factor of $c_{in}$, which is quite substantial.
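The two counts can be checked directly against `nn.Conv2d`'s parameter tensors (the values $c_{in}=16$, $c_{out}=32$, $k=3$ are just examples; `bias=False` so only weights are counted):

```python
import torch.nn as nn

c_in, c_out, k = 16, 32, 3

regular = nn.Conv2d(c_in, c_out, k, bias=False)
depthwise = nn.Conv2d(c_in, c_out, k, groups=c_in, bias=False)

n_regular = sum(p.numel() for p in regular.parameters())      # c_out * c_in * k * k = 4608
n_depthwise = sum(p.numel() for p in depthwise.parameters())  # c_out * k * k = 288
print(n_regular // n_depthwise)  # 16, i.e. a factor of c_in
```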
PyTorch implementation
Using torch.nn.functional.conv2d
First, the relevant parts of the official documentation:
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor
input – input tensor of shape $(\text{minibatch}, \textbf{in\_channels}, iH, iW)$
weight – filters of shape $(\textbf{out\_channels}, \frac{\text{in\_channels}}{\text{groups}}, kH, kW)$
groups – split input into groups, both $\textbf{in\_channels}$ and $\textbf{out\_channels}$ should be divisible by the number of groups. Default: 1
Note in particular that the dimensions I have bolded are the ones that the `groups` parameter splits. In other words, it is not only `input` that gets split — `weight` gets split as well!
import torch
import torch.nn as nn
import torch.nn.functional as F

class DepthwiseConv(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size):
        """
        :param channel_in: int, number of channels of the input features
        :param channel_out: int, desired number of channels of the output features
        :param kernel_size: int (square receptive field)
        """
        super().__init__()
        k = kernel_size
        weight = torch.randn(channel_out, 1, k, k)
        self.weight = nn.Parameter(weight)
        self.kernel_size = kernel_size
        self.channel_in = channel_in

    def forward(self, x):
        x = F.conv2d(x, self.weight, padding=self.kernel_size // 2, groups=self.channel_in)
        # after the split:
        # x           : (b, cin, h, w)  -> (b, cin / cin = 1, h, w)
        # self.weight : (cout, 1, k, k) -> (cout / cin, 1, k, k)
        return x
Using nn.Conv2d
With `nn.Conv2d`, there is no need to work out the kernel shape by hand ^_^
class DepthwiseConv(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size):
        """
        :param channel_in: int, number of channels of the input features
        :param channel_out: int, desired number of channels of the output features
        :param kernel_size: int (square receptive field)
        """
        super().__init__()
        self.conv = nn.Conv2d(channel_in, channel_out, kernel_size,
                              padding=kernel_size // 2, groups=channel_in)

    def forward(self, x):
        # self.conv.weight.shape == (channel_out, 1, kernel_size, kernel_size)
        x = self.conv(x)
        return x
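A quick sanity check of that weight shape, using hypothetical values `channel_in=3`, `channel_out=6`, `kernel_size=3`:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 6, 3, padding=1, groups=3)
print(conv.weight.shape)  # torch.Size([6, 1, 3, 3])

x = torch.randn(2, 3, 32, 32)
y = conv(x)
print(y.shape)  # torch.Size([2, 6, 32, 32])
```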
An even more parameter-efficient variant
There is a variant with the same computation cost as depthwise convolution but with the parameter count reduced by a further factor of $c_{in}$: every channel uses the same shared sub-kernel. I haven't yet figured out how to express this with `nn.Conv2d`, so here is a version implemented with `torch.nn.functional.conv2d` for reference.
class DepthwiseLessParamsConv(nn.Module):
    def __init__(self, channel_in, channel_out, kernel_size):
        """
        :param channel_in: int, number of channels of the input features
        :param channel_out: int, desired number of channels of the output features
        :param kernel_size: int (square receptive field)
        """
        super().__init__()
        k = kernel_size
        # allocate the shared weight
        self.weight = nn.Parameter(torch.empty(channel_out // channel_in, 1, k, k))
        # initialize self.weight
        nn.init.normal_(self.weight, mean=0.0, std=0.01)
        self.bias = nn.Parameter(torch.empty(channel_out // channel_in))
        nn.init.constant_(self.bias, 0.0)
        self.kernel_size = kernel_size
        self.channel_in = channel_in

    def forward(self, x):
        # tile the shared sub-kernel across all channel_in groups; F.conv2d
        # expects a bias of length out_channels, so the shared bias is
        # repeated channel_in times as well
        x = F.conv2d(x, weight=torch.cat([self.weight] * self.channel_in, 0),
                     bias=self.bias.repeat(self.channel_in),
                     padding=self.kernel_size // 2, groups=self.channel_in)
        # after the split:
        # x      : (b, cin, h, w)  -> (b, cin / cin = 1, h, w)
        # weight : (cout, 1, k, k) -> (cout / cin, 1, k, k)
        return x
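A sketch of the saving from kernel sharing, stripped down to the bare tensors (hypothetical values $c_{in}=4$, $c_{out}=8$, $k=3$; bias omitted for brevity):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

c_in, c_out, k = 4, 8, 3
# one shared set of sub-kernels, reused by every input channel
weight = nn.Parameter(torch.randn(c_out // c_in, 1, k, k))

x = torch.randn(1, c_in, 16, 16)
y = F.conv2d(x, torch.cat([weight] * c_in, 0), padding=k // 2, groups=c_in)

print(weight.numel())  # 18 = (c_out / c_in) * k * k learnable parameters
print(c_out * k * k)   # 72 for plain depthwise, i.e. c_in times more
print(y.shape)         # torch.Size([1, 8, 16, 16])
```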