Differences Between Adaptive Convolution and Deformable Convolution

Adaptive convolution and deformable convolution are both improvements on standard convolution, designed to address its weakness in handling geometrically deformed objects. They differ as follows:

  • Adaptive convolution introduces learnable weight parameters into the kernel to dynamically adjust the kernel's shape, so it can adapt to differently shaped objects. Concretely, it treats every element of the kernel as a learnable parameter and adjusts the kernel's shape and size through backpropagation. It can be viewed as a semi-parameterized convolution, since the kernel mixes fixed elements with learnable parameters.
  • Deformable convolution instead adapts to object deformation by sampling the input feature map with learned offsets: it treats each position on the input feature map as a control point and shifts the sampling location of each kernel element by a learnable offset, forming a deformable kernel (see the sampling formula just below this list). It can be viewed as a fully parameterized convolution, since every sampling position of the kernel is computed from offsets.
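
For reference, the sampling rule of deformable convolution (from the DCN papers by Dai et al.; the modulation term is the DCNv2 extension) at an output location $p_0$ is

$$y(p_0) = \sum_{k=1}^{K} w_k \cdot x\left(p_0 + p_k + \Delta p_k\right) \cdot \Delta m_k$$

where $p_k$ ranges over the fixed kernel grid (e.g. the 9 positions of a 3x3 kernel), $\Delta p_k$ is the learned offset, and $\Delta m_k \in [0, 1]$ is the learned modulation mask (fixed to 1 in the original DCNv1).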

Neither method is strictly better or faster; the right choice depends on the application and the model architecture. In general, adaptive convolution is easier to optimize, because it only learns the kernel's shape and size and does not have to learn sampling offsets. For highly deformed objects, however, deformable convolution usually performs better.

Below are code examples of adaptive convolution and deformable convolution (PyTorch):

import torch
import torch.nn as nn
import torchvision.ops


class AdaptiveConv2d(nn.Module):
    """Offset-and-mask variant: a learned (dy, dx) offset and a sigmoid
    modulation mask are predicted per kernel element and per position.
    With these ingredients the layer coincides with modulated deformable
    convolution; the sampling itself is delegated to
    torchvision.ops.deform_conv2d (its mask argument needs a recent
    torchvision)."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        super(AdaptiveConv2d, self).__init__()
        # Both auxiliary branches mirror the main conv's geometry so their
        # output resolution matches the main conv's output.
        self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size * kernel_size,
                                     kernel_size=kernel_size, stride=stride,
                                     padding=padding, dilation=dilation)
        self.mask_conv = nn.Conv2d(in_channels, kernel_size * kernel_size,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=padding, dilation=dilation)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, bias=bias)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation

    def forward(self, x):
        offset = self.offset_conv(x)             # (N, 2*k*k, Ho, Wo)
        mask = torch.sigmoid(self.mask_conv(x))  # (N, k*k, Ho, Wo), in (0, 1)
        # deform_conv2d bilinearly samples x at the offset positions and
        # applies the modulation mask before the weighted sum.
        return torchvision.ops.deform_conv2d(
            x, offset, self.conv.weight, bias=self.conv.bias,
            stride=self.stride, padding=self.padding,
            dilation=self.dilation, mask=mask)
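
A quick shape check for this layer (the sizes are arbitrary, for illustration only):

x = torch.randn(2, 16, 32, 32)
layer = AdaptiveConv2d(16, 32, kernel_size=3, stride=1, padding=1)
print(layer(x).shape)  # torch.Size([2, 32, 32, 32])
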
# Adaptive convolution example (two-branch kernel variant)
import math

import torch
import torch.nn as nn


class AdaptiveConv2d(nn.Module):
    """A two-branch kernel: a full KxK branch plus a 1x1 branch whose
    weight is the spatial sum of the KxK kernel; their outputs are added."""

    def __init__(self, in_channels, out_channels, kernel_size, padding=0):
        super(AdaptiveConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # The two branches only align spatially with "same" padding,
        # i.e. padding == (kernel_size - 1) // 2.
        self.padding = padding
        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # 1x1 branch: collapse each KxK kernel to a single tap.
        weight_1x1 = self.weight.sum(dim=[2, 3]).unsqueeze(-1).unsqueeze(-1)
        x_1x1 = nn.functional.conv2d(x, weight_1x1, bias=None, stride=1, padding=0)
        # KxK branch: the full kernel with the configured padding.
        x_kxk = nn.functional.conv2d(x, self.weight, bias=None, stride=1, padding=self.padding)
        return x_1x1 + x_kxk + self.bias.view(1, -1, 1, 1)
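
A design note on this two-branch form: because both branches are linear in the input, they can be folded into a single KxK kernel for inference, in the spirit of structural re-parameterization. A minimal sketch (the helper name fuse_adaptive_conv is ours, and it assumes same-padding, i.e. padding == kernel_size // 2, so the branches align):

def fuse_adaptive_conv(m: AdaptiveConv2d) -> torch.Tensor:
    # The 1x1 branch uses the spatial sum of the kernel; adding that sum
    # onto the center tap of the KxK kernel yields a single kernel whose
    # output equals x_1x1 + x_kxk from the forward pass above.
    k = m.kernel_size
    fused = m.weight.detach().clone()
    fused[:, :, k // 2, k // 2] += m.weight.detach().sum(dim=[2, 3])
    return fused
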
# Deformable convolution example
from torch.nn.modules.utils import _pair
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeformableConvolution(nn.Module):
    """Deformable convolution (DCNv1-style): an auxiliary conv predicts a
    (dy, dx) offset for every kernel element at every output position, the
    input is bilinearly resampled at the shifted locations, and a regular
    convolution is applied to the resampled values. Only a single offset
    group is implemented here, for clarity."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, deformable_groups=1, bias=False):
        super(DeformableConvolution, self).__init__()
        assert deformable_groups == 1, "only a single offset group is implemented"
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        kh, kw = self.kernel_size

        # Predicts 2 offsets per kernel element; mirrors the main conv's
        # geometry so its output resolution matches.
        self.conv_offset = nn.Conv2d(
            in_channels, 2 * kh * kw,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=True
        )
        # Zero-init the offsets so training starts from a regular convolution.
        nn.init.zeros_(self.conv_offset.weight)
        nn.init.zeros_(self.conv_offset.bias)

        # Holds the convolution weights (and optional bias).
        self.conv_value = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
        )

    def forward(self, x):
        n, c, h, w = x.shape
        kh, kw = self.kernel_size

        offset = self.conv_offset(x)                       # (N, 2*kh*kw, Ho, Wo)
        out_h, out_w = offset.shape[2:]
        offset = offset.view(n, kh * kw, 2, out_h, out_w)

        # Sampling grid of a regular convolution: per-output top-left corner
        # plus the fixed kernel grid (meshgrid's indexing argument needs
        # PyTorch >= 1.10).
        ys = torch.arange(out_h, device=x.device, dtype=x.dtype) * self.stride[0] - self.padding[0]
        xs = torch.arange(out_w, device=x.device, dtype=x.dtype) * self.stride[1] - self.padding[1]
        base_y, base_x = torch.meshgrid(ys, xs, indexing='ij')
        ky, kx = torch.meshgrid(
            torch.arange(kh, device=x.device, dtype=x.dtype),
            torch.arange(kw, device=x.device, dtype=x.dtype), indexing='ij')

        # Deformed sampling locations = regular grid + learned offsets.
        sample_y = base_y.view(1, 1, out_h, out_w) + ky.reshape(1, -1, 1, 1) + offset[:, :, 0]
        sample_x = base_x.view(1, 1, out_h, out_w) + kx.reshape(1, -1, 1, 1) + offset[:, :, 1]

        # Normalize to [-1, 1] and bilinearly sample; out-of-bounds points
        # read as zero, matching zero padding.
        grid = torch.stack((2 * sample_x / max(w - 1, 1) - 1,
                            2 * sample_y / max(h - 1, 1) - 1), dim=-1)
        grid = grid.view(n, kh * kw * out_h, out_w, 2)
        sampled = F.grid_sample(x, grid, mode='bilinear',
                                padding_mode='zeros', align_corners=True)
        sampled = sampled.view(n, c, kh * kw, out_h, out_w)

        # Convolution on the deformed samples: weighted sum over input
        # channels and kernel elements.
        weight = self.conv_value.weight.view(-1, c, kh * kw)
        out = torch.einsum('nckhw,ock->nohw', sampled, weight)
        if self.conv_value.bias is not None:
            out = out + self.conv_value.bias.view(1, -1, 1, 1)
        return out
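
A quick smoke test of the layer (shapes are illustrative):

x = torch.randn(2, 3, 32, 32)
layer = DeformableConvolution(3, 8, kernel_size=3, stride=1, padding=1)
print(layer(x).shape)  # torch.Size([2, 8, 32, 32]); with zero-initialized
                       # offsets this initially matches a regular 3x3 conv.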

# Using deformable convolution in a network
class DeformNet(nn.Module):
    def __init__(self):
        super(DeformNet, self).__init__()

        self.features = nn.Sequential(
            DeformableConvolution(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            DeformableConvolution(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            DeformableConvolution(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            DeformableConvolution(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            DeformableConvolution(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            DeformableConvolution(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            DeformableConvolution(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            DeformableConvolution(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # fifth 2x downsampling: 224 -> 7
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),

            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),

            nn.Linear(4096, 1000),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
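
The classifier's 512 * 7 * 7 input assumes 224x224 images (five 2x downsamplings: 224 / 32 = 7). For example:

model = DeformNet()
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])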