yolov10改进之轻量化的动态上采样算子DySample

一、本次改进是在yolov10的基础上添加一种轻量的动态上采样算子DySample

DySample(Dynamic Sampling for Efficient Upsampling)是一种轻量化的动态上采样算子,它用于图像处理、计算机视觉及深度学习领域,尤其是在图像生成、超分辨率重建、图像分割等任务中,旨在提高上采样操作的效率,同时减少计算量和内存占用。DySample的主要创新点在于其通过动态地选择和调整上采样的方式,避免了传统上采样方法中固定模式的局限性,从而实现了更灵活、更加高效的图像上采样操作。

论文地址:https://arxiv.org/pdf/2308.15085

DySample的主要特点:

1.动态选择采样位置

DySample通过动态地选择采样位置来进行上采样,而不是采用传统的固定采样模式(例如:最近邻插值、双线性插值等)。这种动态选择的方式允许算法根据输入图像的内容自适应地调整上采样的策略,增强了对不同图像特征的适应性。 

2.轻量化设计

DySample被设计为计算量较小、参数少的上采样算子,适用于资源受限的环境。它通过减少计算冗余和内存占用,实现了更高的效率,尤其适合移动端设备和边缘计算设备。

3.降低计算复杂度

传统的上采样方法(如反卷积、双线性插值)通常需要较大的计算量,尤其是在需要高分辨率输出的任务中。DySample通过动态调整采样位置和采用轻量化计算方式,减少了计算复杂度,使得上采样过程更加高效。

4.灵活的自适应性

DySample能够根据不同输入图像的特征进行自适应调整,这意味着它不仅能在不同场景中表现出色,还能根据任务需求调整上采样的精度和效率。通过这种方式,DySample在图像生成、分割、恢复等多个任务中表现出较强的适应能力。

5.适用于多种任务

DySample作为一种上采样算子,适用于多个深度学习任务,尤其是图像生成、超分辨率、语义分割等领域。它能够帮助模型在保持高效性的同时,生成高质量的输出。 

二、代码

import torch
import torch.nn as nn
import torch.nn.functional as F


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class DySample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)


if __name__ == '__main__':
    x = torch.rand(2, 64, 4, 7)
    dys = DySample(64)
    print(dys(x).shape)

三、大家改进完可以试一试,看看效果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值