【扒模块】DySample

逐行注释

import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings

# 忽略警告信息,这通常用于开发过程中,避免警告干扰输出结果
warnings.filterwarnings('ignore')

# 定义一个函数,用于对神经网络模块的权重进行正态分布初始化
def normal_init(module, mean=0, std=1, bias=0):
    # 检查模块是否有权重属性,并且权重不为None
    if hasattr(module, 'weight') and module.weight is not None:
        # 使用正态分布初始化权重,均值为mean,标准差为std
        nn.init.normal_(module.weight, mean, std)
    # 检查模块是否有偏置属性,并且偏置不为None
    if hasattr(module, 'bias') and module.bias is not None:
        # 将偏置初始化为bias指定的值
        nn.init.constant_(module.bias, bias)

# 定义一个函数,用于将神经网络模块的权重初始化为一个常数值
def constant_init(module, val, bias=0):
    # 检查模块是否有权重属性,并且权重不为None
    if hasattr(module, 'weight') and module.weight is not None:
        # 将权重初始化为val指定的常数值
        nn.init.constant_(module.weight, val)
    # 检查模块是否有偏置属性,并且偏置不为None
    if hasattr(module, 'bias') and module.bias is not None:
        # 将偏置初始化为bias指定的值
        nn.init.constant_(module.bias, bias)

功能解释:

  • normal_init 函数用于对神经网络中的权重进行正态分布初始化。这通常用于初始化卷积层或线性层的权重,以引入小的随机性,帮助模型学习。函数接受三个参数:mean(均值,默认为0),std(标准差,默认为1),bias(偏置初始化值,默认为0)。
  • constant_init 函数用于将权重初始化为一个固定的常数值。这在某些特定情况下可能有用,例如,当需要将权重设置为特定值以实现某种特定的行为时。函数接受两个参数:val(权重的常数值),bias(偏置初始化值,默认为0)。
class DySample_UP(nn.Module):
    # 构造函数初始化DySample_UP模块
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super(DySample_UP, self).__init__()  # 调用基类的构造函数
        self.scale = scale  # 上采样的尺度因子,默认为2
        self.style = style  # 上采样的风格,可以是'lp'或'pl'
        self.groups = groups  # 组数,用于分组卷积

        # 确保上采样风格是有效的
        assert style in ['lp', 'pl']
        # 如果风格是'pl',则输入通道数必须是scale的平方,并且是scale的倍数
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        # 输入通道数必须至少等于组数,并且是组数的倍数
        assert in_channels >= groups and in_channels % groups == 0

        # 根据风格设置输入和输出通道数
        if style == 'pl':
            in_channels = in_channels // scale ** 2  # 对于'pl'风格,调整输入通道数
            out_channels = 2 * groups  # 输出通道数为组数的两倍
        else:
            out_channels = 2 * groups * scale ** 2  # 对于'lp'风格,输出通道数为组数乘以scale的平方

        # 定义一个卷积层用于生成偏移量
        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)  # 使用标准差为0.001的正态分布初始化偏移量卷积层

        # 如果启用了dyscope(动态作用域),则添加一个额外的卷积层
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)  # 使用常数0初始化作用域卷积层

        # 注册一个缓冲区init_pos,用于存储初始化的偏移位置
        self.register_buffer('init_pos', self._init_pos())

    # 初始化偏移位置的方法
    def _init_pos(self):
        # 使用arange生成一个从-self.scale/2到self.scale/2的序列,然后除以scale进行归一化
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        # 使用meshgrid生成网格,然后stack和transpose组合成一个2D偏移量矩阵
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

功能解释:

  • DySample_UP 类是一个动态上采样模块,可以根据输入特征图动态地调整上采样的偏移量。
  • in_channels 参数指定了输入特征图的通道数。
  • scale 参数指定了上采样的尺度因子,默认为2,表示输出特征图的尺寸是输入的两倍。
  • style 参数定义了上采样的风格,可以是 'lp'(局部感知)或 'pl'(像素洗牌后局部感知)。
  • groups 参数用于分组卷积,可以增强特征图内的特征整合。
  • dyscope 参数是一个布尔值,用于确定是否使用动态作用域来调整偏移量。
  • self.offset 是一个卷积层,用于生成上采样的偏移量。
  • normal_init 函数用于初始化 self.offset 的权重。
  • self.scope 是一个可选的卷积层,仅在 dyscope 为 True 时使用,用于进一步调整偏移量。
  • _init_pos 方法生成了一个初始化的偏移位置矩阵,这个矩阵定义了上采样过程中每个像素点的参考位置。
class DySample_UP(nn.Module):
    # ...

    # sample 方法是上采样过程中对输入特征图 x 进行采样的核心函数
    def sample(self, x, offset):
        # 获取offset的尺寸,B是批次大小,H和W分别是特征图的高度和宽度
        B, _, H, W = offset.shape
        # 调整offset的视角,使其适用于后续的采样过程
        offset = offset.view(B, 2, -1, H, W)
        
        # 创建一个网格坐标,表示特征图中每个像素的位置
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])).\
            transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        # 归一化网格坐标,使其范围在[-1, 1],这是F.grid_sample所需的坐标范围
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        
        # 使用pixel_shuffle调整coords的维度,以匹配后续的采样操作
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        
        # 使用grid_sample根据调整后的coords对x进行采样
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    # forward_lp是局部感知(Local Perception)风格的上采样方法
    def forward_lp(self, x):
        # 如果定义了scope,则使用scope调整offset
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            # 否则,直接使用offset并加上初始化偏移
            offset = self.offset(x) * 0.25 + self.init_pos
        # 调用sample方法进行上采样
        return self.sample(x, offset)

    # forward_pl是像素洗牌后局部感知(Pixel Shuffle then Local Perception)风格的上采样方法
    def forward_pl(self, x):
        # 首先使用pixel_shuffle对x进行像素洗牌
        x_ = F.pixel_shuffle(x, self.scale)
        # 如果定义了scope,则使用scope调整offset
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            # 否则,直接使用offset并加上初始化偏移
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        # 调用sample方法进行上采样
        return self.sample(x, offset)

功能解释:

  • sample 方法是 DySample_UP 类的核心,它负责根据偏移量 offset 对输入特征图 x 进行采样。这个方法使用了 F.grid_sample 来实现上采样,通过调整采样坐标来实现动态上采样。
  • forward_lp 和 forward_pl 是两种不同的上采样风格。它们首先计算偏移量,然后调用 sample 方法来实现上采样。
  • 在 forward_lp 中,如果没有定义 scope,则偏移量是通过对 self.offset(x) 的输出进行缩放和加上初始化偏移量 self.init_pos 来得到的。
  • 在 forward_pl 中,首先对输入 x 使用 F.pixel_shuffle 进行像素洗牌,然后计算偏移量,再使用 F.pixel_unshuffle 对偏移量进行逆操作,以匹配像素洗牌后的维度。
  • 这两种方法都使用了 sample 方法来进行实际的上采样操作,其中 mode='bilinear' 指定了双线性插值作为采样方法,align_corners=False 和 padding_mode="border" 分别指定了坐标的对齐方式和填充模式。

通过这种方式,DySample_UP 类提供了一种灵活的动态上采样机制,可以根据不同的任务需求选择不同的上采样风格。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值