# conv_transpose2d
# Code from: mmflow
# Bilinear interpolation implemented as a transposed-convolution (deconvolution) operation.
class Upsample(nn.Module):
    """Upsampling module built on a fixed-weight transposed convolution.

    A bilinear kernel is generated once at construction time, stored as a
    buffer named ``weight`` (so it is not a trainable parameter), and applied
    per channel through ``F.conv_transpose2d`` with ``groups=channels``.

    With the kernel size, stride and padding derived below, a spatial size
    ``H`` maps to ``(H - 1) * stride - 2 * pad + kernel_size``, which equals
    ``H * scale_factor`` for both even and odd scale factors.

    Args:
        scale_factor (int): Scale factor of upsampling.
        channels (int): Number of channels of conv_transpose2d.
    """

    def __init__(self, scale_factor: int, channels: int) -> None:
        super().__init__()
        # Even factors get a 2*s kernel, odd factors a 2*s - 1 kernel.
        self.kernel_size = 2 * scale_factor - scale_factor % 2
        self.stride = scale_factor
        self.pad = math.ceil((scale_factor - 1) / 2.)
        self.channels = channels
        self.register_buffer('weight', self.bilinear_upsampling_filter())

    # caffe::BilinearFilter
    def bilinear_upsampling_filter(self) -> torch.Tensor:
        """Generate the weights for caffe::BilinearFilter.

        The 2D kernel is the outer product of a 1D triangular profile with
        itself, so it is symmetric in rows and columns.

        Returns:
            Tensor: The weights for caffe::BilinearFilter, of shape
                ``(channels, 1, kernel_size, kernel_size)``; the same kernel
                is repeated for every channel.
        """
        size = self.kernel_size
        f = math.ceil(size / 2.)
        c = (2 * f - 1 - f % 2) / 2. / f
        # 1D tap value at each integer position. The previous implementation
        # derived the row index from a flat index with true division
        # (``i / kernel_size``), which yields a fractional row in Python 3
        # and therefore a skewed, non-separable kernel. Using integer
        # positions for both axes restores the intended filter.
        taps = [1 - abs(pos / f - c) for pos in range(size)]
        kernel = torch.tensor([[row * col for col in taps] for row in taps])
        return kernel.view(1, 1, size, size).repeat(self.channels, 1, 1, 1)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """Forward function for upsample.

        Args:
            data (Tensor): The input data. Its channel dimension must match
                ``channels``, since the convolution is grouped per channel.

        Returns:
            Tensor: The upsampled data.
        """
        return F.conv_transpose2d(
            data,
            self.weight,
            stride=self.stride,
            padding=self.pad,
            groups=self.channels)