Convolution Module Re-parameterization

A re-parameterization implementation for a convolutional block. During training, SeqConv3x3 runs a 1x1 convolution followed by either a learnable 3x3 convolution or a fixed 3x3 edge filter (Sobel, Prewitt, or Laplacian) with a learnable per-channel scale; at inference time, rep_params() collapses the two-layer sequence into the weight and bias of a single, mathematically equivalent 3x3 convolution.
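
The fusion relies on the linearity of convolution. For a 1x1 convolution with weight K0 and bias b0 followed by a 3x3 convolution with weight K1 and bias b1, the identity that rep_params() implements is (notation is ours, introduced for this explanation):

$$
RK[o,i,u,v] = \sum_{m} K_1[o,m,u,v]\, K_0[m,i,0,0], \qquad
RB[o] = b_1[o] + \sum_{m,u,v} K_1[o,m,u,v]\, b_0[m]
$$

The first sum is computed by F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3)); the second by pushing b0 through the 3x3 kernel before adding b1. The edge-filter branches first expand their depthwise scale * mask filter into a full 3x3 kernel and then apply the same identity.
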
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqConv3x3(nn.Module):
    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier=1):
        super(SeqConv3x3, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.type == 'conv1x1-conv3x3':
            # expand to mid_planes channels with a 1x1 conv, then project back with a 3x3 conv
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.type == 'conv1x1-sobelx':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed Sobel-x mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 2.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-sobely':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed Sobel-y mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 2.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -2.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-laplacian':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed 4-neighbor Laplacian mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 1, 1] = -4.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-laplacian8':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed 8-neighbor Laplacian mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 1, 2] = 1.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 2, 1] = 1.0
                self.mask[i, 0, 2, 2] = 1.0
                self.mask[i, 0, 1, 1] = -8.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-prewittx':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed Prewitt-x mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 1, 0] = 1.0
                self.mask[i, 0, 2, 0] = 1.0
                self.mask[i, 0, 0, 2] = -1.0
                self.mask[i, 0, 1, 2] = -1.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        elif self.type == 'conv1x1-prewitty':
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)
            # fixed Prewitt-y mask, one copy per output channel (applied depthwise)
            self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
            for i in range(self.out_planes):
                self.mask[i, 0, 0, 0] = 1.0
                self.mask[i, 0, 0, 1] = 1.0
                self.mask[i, 0, 0, 2] = 1.0
                self.mask[i, 0, 2, 0] = -1.0
                self.mask[i, 0, 2, 1] = -1.0
                self.mask[i, 0, 2, 2] = -1.0
            self.mask = nn.Parameter(data=self.mask, requires_grad=False)

        else:
            raise ValueError('the type of seqconv is not supported!')
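
    # Note: every edge-filter branch freezes its 3x3 mask (requires_grad=False)
    # and learns only the per-channel scale and bias, so training tunes the
    # strength of the edge prior rather than its shape.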

    def forward(self, x):
        if self.type == 'conv1x1-conv3x3':
            # conv-1x1
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # pad explicitly with b0 instead of zeros: after fusion the single
            # 3x3 conv zero-pads x, and a zero input pixel passed through the
            # 1x1 conv would have produced exactly b0
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            # conv-1x1, followed by the same explicit bias padding as above
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # depthwise 3x3 conv with the fixed mask scaled per channel
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)
        return y1

    def rep_params(self):
        device = self.k0.get_device()
        if device < 0:  # CPU tensor
            device = None

        if self.type == 'conv1x1-conv3x3':
            # fuse the 1x1 into the 3x3 kernel; e.g. with inp=2, mid=4, out=2:
            # k0: [4, 2, 1, 1] -> permute(1, 0, 2, 3) -> [2, 4, 1, 1]
            # k1: [2, 4, 3, 3], so RK: [2, 2, 3, 3]
            RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # fused bias: push b0 through the 3x3 kernel, then add b1
            RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=self.k1).view(-1) + self.b1
        else:
            # expand the depthwise scale * mask filter into a full conv kernel
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
            for i in range(self.out_planes):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=k1).view(-1) + b1
        return RK, RB
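
A quick end-to-end equivalence check, as a minimal sketch (the __main__ guard, seed, tensor shapes, type list, and tolerance below are our additions for illustration, not part of the original code): the training-time forward pass should match a single 3x3 convolution built from rep_params(), up to floating-point error.

if __name__ == '__main__':
    torch.manual_seed(0)
    x = torch.randn(1, 8, 16, 16)
    with torch.no_grad():
        for seq_type in ['conv1x1-conv3x3', 'conv1x1-sobelx', 'conv1x1-laplacian8']:
            m = SeqConv3x3(seq_type, inp_planes=8, out_planes=8, depth_multiplier=2)
            y_train = m(x)                # two-op training-time path
            RK, RB = m.rep_params()       # fused 3x3 weight and bias
            y_rep = F.conv2d(x, weight=RK, bias=RB, stride=1, padding=1)
            print(seq_type, torch.allclose(y_train, y_rep, atol=1e-5))  # expect True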