import torch
import torch.nn as nn
import torch.nn.functional as F


class KernelAttention(nn.Module):
    """Computes per-sample attention scores over a set of candidate convolution kernels.

    A squeeze-and-excitation style branch: global average pooling followed by two
    1x1 convolutions produces one sigmoid-scaled score per candidate kernel.
    """

    def __init__(self, channels, reduction=4, num_kernels=4, init_weight=True):
        super().__init__()
        # For very shallow inputs (e.g. 3-channel RGB) the reduced width would
        # collapse, so fall back to num_kernels hidden units instead.
        if channels != 3:
            mid_channels = channels // reduction
        else:
            mid_channels = num_kernels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(mid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(mid_channels, num_kernels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # (B, C, H, W) -> (B, C, 1, 1) -> (B, num_kernels), each score in (0, 1)
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv2(x).view(x.shape[0], -1)
        x = self.sigmoid(x)
        return x

class KernelAggregation(nn.Module):
    """Mixes num_kernels candidate kernels with per-sample attention scores and
    applies the resulting dynamic kernel via a single grouped convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, num_kernels,
                 init_weight=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.num_kernels = num_kernels
        # Candidate kernels: (num_kernels, out_channels, in_channels // groups, k, k)
        self.weight = nn.Parameter(
            torch.randn(num_kernels, out_channels, in_channels // groups, kernel_size, kernel_size),
            requires_grad=True)
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(num_kernels, out_channels))
        else:
            self.bias = None
        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.num_kernels):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self, x, attention):
        batch_size, in_channels, height, width = x.size()

        # Fold the batch into the channel dimension so every sample can be
        # convolved with its own aggregated kernel in one grouped conv call.
        x = x.contiguous().view(1, batch_size * self.in_channels, height, width)

        # Mix candidate kernels: (B, num_kernels) @ (num_kernels, out * in/g * k * k),
        # then reshape to per-sample kernels stacked along the output-channel axis.
        weight = self.weight.contiguous().view(self.num_kernels, -1)
        weight = torch.mm(attention, weight).contiguous().view(
            batch_size * self.out_channels,
            self.in_channels // self.groups,
            self.kernel_size,
            self.kernel_size)

        if self.bias is not None:
            bias = torch.mm(attention, self.bias).contiguous().view(-1)
        else:
            bias = None

        x = F.conv2d(
            x,
            weight=weight,
            bias=bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups * batch_size)

        # Unfold the batch dimension again: (B, out_channels, H', W')
        x = x.contiguous().view(batch_size, self.out_channels, x.shape[-2], x.shape[-1])
        return x

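# Hedged reference sketch (illustrative, not part of the original module): the
# batch-into-channels trick in KernelAggregation.forward is equivalent to
# convolving each sample with its own mixed kernel in a plain Python loop.
# Argument shapes are assumptions: x is (B, C_in, H, W), weights is
# (B, C_out, C_in // groups, k, k), biases is (B, C_out) or None.
def _per_sample_conv_reference(x, weights, biases=None, stride=1, padding=0, dilation=1, groups=1):
    outputs = []
    for i in range(x.shape[0]):
        outputs.append(F.conv2d(
            x[i:i + 1],
            weight=weights[i],
            bias=None if biases is None else biases[i],
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups))
    return torch.cat(outputs, dim=0)
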
class DynamicKernelAggregation(nn.Module):
    """Drop-in replacement for nn.Conv2d whose effective kernel is a per-sample,
    attention-weighted mixture of num_kernels candidate kernels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 num_kernels=4):
        super().__init__()
        assert in_channels % groups == 0
        self.attention = KernelAttention(
            in_channels,
            num_kernels=num_kernels)
        self.aggregation = KernelAggregation(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            num_kernels=num_kernels)

    def forward(self, x):
        # Per-sample kernel attention, then convolution with the mixed kernel.
        attention = self.attention(x)
        x = self.aggregation(x, attention)
        return x
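

# Hedged usage sketch (not part of the original module): a quick shape check of
# DynamicKernelAggregation used as a drop-in convolution. The tensor sizes and
# parameter values here are illustrative assumptions, not values from the source.
if __name__ == "__main__":
    conv = DynamicKernelAggregation(
        in_channels=64, out_channels=128, kernel_size=3, padding=1, num_kernels=4)
    x = torch.randn(8, 64, 32, 32)    # batch of 8 feature maps
    attn = conv.attention(x)          # per-sample kernel scores, shape (8, 4)
    out = conv(x)                     # spatial size preserved by padding=1
    print(attn.shape, out.shape)      # torch.Size([8, 4]) torch.Size([8, 128, 32, 32])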