Dite-HRNet adopts dynamic convolution kernel operations

This post introduces two key network modules: KernelAttention and KernelAggregation. The KernelAttention module generates attention weights through adaptive average pooling, 1x1 convolutions, and a sigmoid activation. The KernelAggregation module uses these attention weights to dynamically aggregate a set of candidate convolution kernels for feature extraction. The DynamicKernelAggregation module combines the two into a convolution layer whose kernel adapts to the input, suitable for tasks such as image recognition.


import torch
import torch.nn as nn
import torch.nn.functional as F


# Produces one attention weight per candidate kernel for each sample in the batch.
class KernelAttention(nn.Module):
    def __init__(self, channels, reduction=4, num_kernels=4, init_weight=True):
        super().__init__()

        if channels != 3:
            mid_channels = channels // reduction
        else:
            mid_channels = num_kernels

        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(mid_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(mid_channels, num_kernels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.avg_pool(x)

        x = self.conv1(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.conv2(x).view(x.shape[0], -1)
        x = self.sigmoid(x)

        return x
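
Given a feature map of shape (batch, channels, H, W), KernelAttention returns a (batch, num_kernels) tensor of sigmoid-gated weights, one per candidate kernel per sample. A quick sketch of this (channel count and input size below are assumptions for illustration):

# Illustrative usage with hypothetical shapes.
attn = KernelAttention(channels=32, num_kernels=4)
dummy = torch.randn(2, 32, 64, 64)     # (batch, channels, H, W)
weights = attn(dummy)
print(weights.shape)                   # torch.Size([2, 4]): one weight per kernel per sample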

# Mixes the candidate kernels with the attention weights and applies the resulting
# per-sample kernel through a single grouped convolution.
class KernelAggregation(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, num_kernels,
                 init_weight=True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.num_kernels = num_kernels

        self.weight = nn.Parameter(
            torch.randn(num_kernels, out_channels, in_channels // groups, kernel_size, kernel_size),
            requires_grad=True)
        if bias:
            self.bias = nn.Parameter(
                torch.zeros(num_kernels, out_channels))
        else:
            self.bias = None

        if init_weight:
            self._initialize_weights()

    def _initialize_weights(self):
        for i in range(self.num_kernels):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self, x, attention):
        batch_size, in_channels, height, width = x.size()

        x = x.contiguous().view(1, batch_size * self.in_channels, height, width)

        weight = self.weight.contiguous().view(self.num_kernels, -1)
        weight = torch.mm(attention, weight).contiguous().view(
            batch_size * self.out_channels,
            self.in_channels // self.groups,
            self.kernel_size,
            self.kernel_size)

        if self.bias is not None:
            bias = torch.mm(attention, self.bias).contiguous().view(-1)
            x = F.conv2d(
                x,
                weight=weight,
                bias=bias,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups * batch_size)
        else:
            x = F.conv2d(
                x,
                weight=weight,
                bias=None,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups * batch_size)

        x = x.contiguous().view(batch_size, self.out_channels, x.shape[-2], x.shape[-1])

        return x
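
The forward pass above folds the batch into the group dimension of a single F.conv2d call, so every sample is convolved with its own attention-weighted mixture of the num_kernels candidate kernels. A minimal sanity-check sketch of that equivalence (all hyperparameters below are illustrative assumptions):

# Compare the grouped-conv trick against an explicit per-sample loop.
agg = KernelAggregation(in_channels=8, out_channels=16, kernel_size=3, stride=1,
                        padding=1, dilation=1, groups=1, bias=True, num_kernels=4)
x = torch.randn(2, 8, 10, 10)
attention = torch.sigmoid(torch.randn(2, 4))    # stand-in for the KernelAttention output

out = agg(x, attention)                         # (2, 16, 10, 10)

ref = []
for b in range(x.size(0)):
    w = torch.einsum('k,koihw->oihw', attention[b], agg.weight)   # mix kernels for sample b
    bias = attention[b] @ agg.bias                                # mix biases for sample b
    ref.append(F.conv2d(x[b:b+1], w, bias, stride=1, padding=1))
print(torch.allclose(out, torch.cat(ref), atol=1e-5))             # True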



# Wraps KernelAttention and KernelAggregation into a single dynamic convolution layer.
class DynamicKernelAggregation(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 num_kernels=4):
        super().__init__()
        assert in_channels % groups == 0

        self.attention = KernelAttention(
            in_channels,
            num_kernels=num_kernels)
        self.aggregation = KernelAggregation(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            num_kernels=num_kernels)

    def forward(self, x):
        attention = x

        attention = self.attention(attention)
        x = self.aggregation(x, attention)

        return x
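
Putting it together, DynamicKernelAggregation can be used like an ordinary convolution layer; the kernel it applies is re-weighted for every input. A minimal usage sketch (channel counts and input size are assumptions, not values from Dite-HRNet):

# Hypothetical drop-in use in place of nn.Conv2d.
layer = DynamicKernelAggregation(in_channels=32, out_channels=64, kernel_size=3,
                                 stride=1, padding=1, num_kernels=4)
x = torch.randn(4, 32, 56, 56)    # (batch, channels, H, W)
y = layer(x)
print(y.shape)                    # torch.Size([4, 64, 56, 56])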

 
