# Gated convolution
class Gated_Conv(nn.Module):
    """Gated 2-D convolution.

    One ``nn.Conv2d`` produces ``2 * out_ch`` channels. They are split in half
    along the channel dimension (dim 1):

    * the first half goes through a sigmoid and acts as a soft gate;
    * the second half goes through ``activation``.

    The output is their element-wise product, with ``out_ch`` channels.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        ksize: Convolution kernel size.
        stride: Convolution stride.
        rate: Convolution dilation rate.
        activation: Module/callable applied to the feature half. ``None`` (the
            default) creates a fresh ``nn.ELU()`` for this layer. Pass
            ``nn.Identity()`` for a linear feature path.
    """

    def __init__(self, in_ch, out_ch, ksize=3, stride=1, rate=1, activation=None):
        super().__init__()
        # The dilated kernel spans rate * (ksize - 1) + 1 pixels. For odd ksize
        # and stride 1, this padding keeps the spatial size unchanged.
        padding = rate * (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_ch,
            2 * out_ch,  # gate channels followed by feature channels
            kernel_size=ksize,
            stride=stride,
            padding=padding,
            dilation=rate,
        )
        # Build the default per instance rather than sharing one module object
        # that was created once, at function-definition time.
        self.activation = nn.ELU() if activation is None else activation

    def forward(self, x):
        raw = self.conv(x)
        # Conv output has exactly 2 * out_ch channels, so chunk(2) yields two
        # halves of out_ch channels each: gate first, features second.
        gate_logits, features = raw.chunk(2, dim=1)
        gate = gate_logits.sigmoid()
        return self.activation(features) * gate
# Gated deconvolution (nearest-neighbour upsampling followed by a gated convolution)
class Gated_Deconv(nn.Module):
    """Upsample-then-gated-convolution block.

    The input is enlarged with nearest-neighbour upsampling and then passed
    through a ``Gated_Conv``. Despite the name, no transposed convolution is
    used.

    Args:
        in_ch: Number of input channels (forwarded to ``Gated_Conv``).
        out_ch: Number of output channels (forwarded to ``Gated_Conv``).
        ksize: Convolution kernel size (forwarded to ``Gated_Conv``).
        stride: Convolution stride (forwarded to ``Gated_Conv``).
        rate: Convolution dilation rate (forwarded to ``Gated_Conv``).
        activation: Activation for the feature half of the gated convolution.
            ``None`` (the default) creates a fresh ``nn.ELU()`` for this layer.
        scale_factor: Upsampling factor passed to ``nn.Upsample``. Defaults to
            2, the value previously hard-coded.
    """

    def __init__(self, in_ch, out_ch, ksize=3, stride=1, rate=1, activation=None, scale_factor=2):
        super().__init__()
        # Resolve the default here, so that a concrete module is always handed
        # to Gated_Conv and no module object is shared between instances.
        if activation is None:
            activation = nn.ELU()
        self.up_sample = nn.Upsample(scale_factor=scale_factor, mode='nearest')
        self.conv = Gated_Conv(in_ch=in_ch, out_ch=out_ch, ksize=ksize, stride=stride, rate=rate, activation=activation)

    def forward(self, x):
        # Upsample first, then apply the gated convolution.
        return self.conv(self.up_sample(x))