import torch.nn as nn
class SP_conv(nn.Module):
    """Depthwise-separable 2D convolution.

    A depthwise convolution (one ``kernel`` x ``kernel`` filter per input
    channel, via ``groups=in_channels``) followed by a 1x1 pointwise
    convolution that mixes the channels into ``out_channels``.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the pointwise conv.
        kernel: Spatial kernel size of the depthwise conv.
        stride: Stride of the depthwise conv.
        dilation: Dilation of the depthwise conv.
        bias: Whether both convolutions learn an additive bias.
        padding: Zero-padding of the depthwise conv. Defaults to 0 (the
            previously hard-coded value), in which case the spatial size
            shrinks by ``dilation * (kernel - 1)`` at stride 1. For odd
            ``kernel``, pass ``dilation * (kernel - 1) // 2`` to preserve it.
    """

    def __init__(self, in_channels, out_channels, kernel=3, stride=1,
                 dilation=1, bias=False, padding=0):
        super().__init__()
        # Depthwise stage: groups=in_channels gives every input channel its
        # own spatial filter and keeps the channel count unchanged.
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel, stride, padding,
            dilation, groups=in_channels, bias=bias,
        )
        # Pointwise stage: a 1x1 conv that mixes channels to out_channels.
        self.pixelwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0,
            dilation=1, groups=1, bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv, to ``x``.

        ``x`` must be an input accepted by ``nn.Conv2d`` with
        ``in_channels`` channels.
        """
        return self.pixelwise(self.conv(x))
sepconv (Separable Convolution) 代码复现
最新推荐文章于 2021-11-01 13:16:43 发布