本人觉得很多博客都没有解释清楚,或者存在或多或少的错误。
我个人觉得,以下这篇博文讲得很清晰易懂:
https://yinguobing.com/separable-convolution/#fn2。
以下是pytorch代码:
class depthwise_conv2d(nn.Module):
def __init__(self, n_in, n_out):
super(depthwise_separable_conv, self).__init__()
self.depth_wise = nn.Conv2d(n_in, n_in, kernel_size=3, padding=1, groups=n_in)
self.point_wise = nn.Conv2d(n_in, n_out, kernel_size=1)
def forward(self, x):
out = self.depth_wise(x)
out = self.point_wise(out)
return out