【LUT技术专题】RCLUT代码解读

目录

原文概要

1. 训练

2. 转表

3. 微调

4. 测试

本文是对RCLUT技术的代码解读,原文解读请看RCLUT。 

原文概要

RCLUT通过增大网络感受野,来提升超分效果,实现SRLUT的改进,主要是1个创新点:

  • 提出了一个RC模块,中文名为重构卷积,重构卷积可以使用极小的代价来大幅提升RF,这个RC模块集成到前面讲到的SRLUTMuLUT上有效果的提升。

其网络结构图如下:

RC Module:

代码结构如下,作者是在MuLUT代码的基础上加入了RC模块,整体同样分成了4个步骤,分别是训练、转表、微调、测试,common中放着其他网络结构的实现:


1. 训练

代码位于1_train_model.py文件中,跟MuLUT结构一样,首先我们找到模型设置部分,在common/option.py中:

class BaseOptions():
    def __init__(self, debug=False):
        self.initialized = False
        self.debug = debug

    def initialize(self, parser):
        # experiment specifics
        parser.add_argument('--model', type=str, default='SRNets')
        parser.add_argument('--task', '-t', type=str, default='sr')
        parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
        parser.add_argument('--sigma', '-s', type=int, default=25, help="noise level")
        parser.add_argument('--qf', '-q', type=int, default=20, help="deblocking quality factor")
        parser.add_argument('--nf', type=int, default=64, help="number of filters of convolutional layers")
        parser.add_argument('--stages', type=int, default=2, help="stages of MuLUT")
        parser.add_argument('--modes', type=str, default='sdy', help="sampling modes to use in every stage")
        parser.add_argument('--interval', type=int, default=4, help='N bit uniform sampling')
        parser.add_argument('--modelRoot', type=str, default='../models')

        parser.add_argument('--expDir', '-e', type=str, default='', help="experiment folder")
        parser.add_argument('--load_from_opt_file', action='store_true', default=False)

        parser.add_argument('--debug', default=False, action='store_true')

        self.initialized = True
        return parser

模型的名称是SRNets,SRNets的定义在sr/model.py中,如下:

class SRNets(nn.Module):
    """ A LUT-convertable SR network with configurable stages and patterns. """

    def __init__(self, nf=64, scale=4, modes=['s', 'd', 'y'], stages=2):
        super(SRNets, self).__init__()

        
        for s in range(stages):  # 2-stage
            if (s + 1) == stages:
                upscale = scale
                flag = "N"
            else:
                upscale = None
                flag = "1"
            for mode in modes:
                # print('mode', mode, "{}x{}".format(mode.upper(), flag))
                self.add_module("s{}_{}".format(str(s + 1), mode),
                                SRNet("{}x{}".format(mode.upper(), flag), nf=nf, upscale=upscale))
        # print_network(self)
            # self.add_module("s{}_{}".format(str(s+1), "Connect"), SRNet("Connect", nf=nf, upscale=upscale))
    def forward(self, x, stage, mode):
        key = "s{}_{}".format(str(stage), mode)
        module = getattr(self, key)
        return module(x)

可以看到SRNets中有一个add_module部分,其实就是多个SRNet的叠加。SRNet的具体实现在common/network.py中。

class SRNet(nn.Module):
    """ Wrapper of a generalized (spatial-wise) MuLUT block. 
        By specifying the unfolding patch size and pixel indices,
        arbitrary sampling pattern can be implemented.
    """

    def __init__(self, mode, nf=64, upscale=None, dense=True):
        super(SRNet, self).__init__()
        self.mode = mode
        
        if 'x1' in mode:
            assert upscale is None
        if mode == 'Sx1':
            self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='s')
            self.K = 5
            self.S = 1
        elif mode == 'SxN':
            upscale = upscale
            self.model = MuLUTUnit('2x2', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='s')
            self.K = 5
            self.S = upscale
        elif mode == 'Hx1':
            self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='h')
            self.K = 2
            self.S = 1
        elif mode == 'HxN':
            self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='h')
            self.K = 2
            self.S = upscale
        elif mode == 'Jx1':
            self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='h')
            self.K = 2
            self.S = 1
        elif mode == 'JxN':
            self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='h')
            self.K = 2
            self.S = upscale
        elif mode == 'Fx1':
            self.model = MuLUTUnit('2x2d4', nf, upscale=1, dense=dense)
            self.K = 5
            self.S = 1
        elif mode == 'FxN':
            self.model = MuLUTUnit('2x2d4', nf, upscale=upscale, dense=dense)
            self.K = 5
            self.S = upscale
        elif mode == 'Dx1':
            self.model = MuLUTUnit('2x2d', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='d')
            self.K = 5
            self.S = 1
        elif mode == 'DxN':
            self.model = MuLUTUnit('2x2d', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='d')
            self.K = 5
            self.S = upscale
        elif mode == 'Yx1':
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='y')
            self.K = 5
            self.S = 1
        elif mode == 'YxN':
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='y')
            self.K = 5
            self.S = upscale
        elif mode == 'Ex1':
            self.model = MuLUTUnit('2x2d3', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode == 'ExN':
            self.model = MuLUTUnit('2x2d3', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        # elif mode in ['Ox1', 'Hx1']:
        #     self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
        #     self.K = 4
        #     self.S = 1
        # elif mode == ['OxN', 'HxN']:
        #     self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
        #     self.K = 4
        #     self.S = upscale
        elif mode == 'Connect':
            self.model = MuLUTcUnit('1x1', nf)
            self.K = 3
        else:
            raise AttributeError
        self.P = self.K - 1

    def forward(self, x):
        if 'H' in self.mode:
            channel = x.size(1)
            x = x.reshape(-1, 1, x.size(2), x.size(3))
            x = self.model(x)
            x = x.reshape(-1, channel, x.size(2), x.size(3))
        elif self.mode == 'Connect':
            x = self.model(x)
        else:
            B, C, H, W = x.shape
            x_dense = x[:, :, :-4, :-4]
            x_7x7 = F.pad(x, [2, 0, 2, 0], mode='replicate')
            B7, C7, H7, W7 = x_7x7.shape
            x_7x7 = F.unfold(x_7x7, 7) 
            x_3x3 = x[:, :, :-2, :-2]
            B3, C3, H3, W3 = x_3x3.shape
            x_3x3 = F.unfold(x_3x3, 3)
            
            x_3x3 = x_3x3.view(B3, C3, 9, (H3-2)*(W3-2))
            x_3x3 = x_3x3.permute((0, 1, 3, 2))
            x_3x3 = x_3x3.reshape(B3 * C3 * (H3-2)*(W3-2), 3, 3)
            x_3x3 = x_3x3.unsqueeze(-1)

            x_7x7 = x_7x7.view(B7, C7, 49, (H7-6)*(W7-6))
            x_7x7 = x_7x7.permute((0, 1, 3, 2))
            x_7x7 = x_7x7.reshape(B7 * C7 * (H7-6)*(W7-6), 7, 7)
            x_7x7 = x_7x7.unsqueeze(-1)

            x = F.unfold(x, self.K)  # B,C*K*K,L
            x = x.view(B, C, self.K * self.K, (H - self.P) * (W - self.P))  # B,C,K*K,L
            r_H = H - self.P
            r_W = W - self.P
            x = x.permute((0, 1, 3, 2))  # B,C,L,K*K
            x = x.reshape(B * C * (H - self.P) * (W - self.P),
                        self.K, self.K)  # B*C*L,K,K
            # x = x.unsqueeze(1)  # B*C*L,l,K,K
            x = x.unsqueeze(-1)

            # if 'Y' in self.mode:
            #     x = torch.cat([x[:, :, 0, 0], x[:, :, 1, 1],
            #                 x[:, :, 1, 2], x[:, :, 2, 1]], dim=1)

            #     x = x.unsqueeze(1).unsqueeze(1)
            if 'H' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                            x[:, :, 2, 3], x[:, :, 3, 2]], dim=1)

                x = x.unsqueeze(1).unsqueeze(1)
            elif 'O' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                            x[:, :, 1, 3], x[:, :, 3, 1]], dim=1)

                x = x.unsqueeze(1).unsqueeze(1)

            x = self.model(x, r_H, r_W, x_dense, x_3x3, x_7x7)   # B*C*L,K,K
            x = x.squeeze(1)
            x = x.reshape(B, C, (H - self.P) * (W - self.P), -1)  # B,C,K*K,L
            x = x.permute((0, 1, 3, 2))  # B,C,K*K,L
            x = x.reshape(B, -1, (H - self.P) * (W - self.P))  # B,C*K*K,L
            x = F.fold(x, ((H - self.P) * self.S, (W - self.P) * self.S),
                    self.S, stride=self.S)
            # print('ll', x.size())
        return x

可以看到作者这里的SRNet,主要引用的是MuLUTUnit模块。在common/network.py中,这个模块的实现如下:

############### MuLUT Blocks ###############
class MuLUTUnit(nn.Module):
    """ Generalized (spatial-wise)  MuLUT block. """

    def __init__(self, mode, nf, upscale=1, out_c=1, dense=True, deform_mode='s', patch_size=48, stage=1):
        super(MuLUTUnit, self).__init__()
        self.act = nn.ReLU()
        self.upscale = upscale
        self.conv_naive = Conv(1, nf, 2)
        self.mode = mode
        self.stage = stage

        if mode == '2x2':
            
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=5)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=5)
            self.s_conv = Conv(4, nf, 1)
            # self.conv1 = Conv_test(1, nf)
        elif mode == '2x2d':
            # self.conv1 = Conv(1, nf, 2, dilation=2)
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=7)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=7)
            self.d_conv = Conv(4, nf, 1)
        elif mode == '2x2d3':
            self.conv1 = Conv(1, nf, 2, dilation=3)
        elif mode == '2x2d4':
            self.conv1 = Conv(1, nf, 2, dilation=4)
        elif mode == '1x4':
            # self.conv1 = Conv(1, nf, (1, 4))
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=3)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=3)
            self.y_conv = Conv(4, nf, 1)
        elif mode == '4x1':
            self.conv1 = DeformConv2d(1, nf, mode=deform_mode)
        else:
            raise AttributeError

        if dense:
            self.conv2 = DenseConv(nf, nf)
            self.conv3 = DenseConv(nf + nf * 1, nf)
            self.conv4 = DenseConv(nf + nf * 2, nf)
            self.conv5 = DenseConv(nf + nf * 3, nf)
            self.conv6 = Conv(nf * 5, 1 * upscale * upscale, 1)
            # self.conv6 = Conv(nf * 5, nf, 1)
        else:
            self.conv2 = ActConv(nf, nf, 1)
            self.conv3 = ActConv(nf, nf, 1)
            self.conv4 = ActConv(nf, nf, 1)
            self.conv5 = ActConv(nf, nf, 1)
            self.conv6 = Conv(nf, upscale * upscale, 1)
        if self.upscale > 1:
            self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x, r_H, r_W, x_dense, x_3x3, x_7x7):
        B, C, H, W = x_dense.shape
        x_dense = x_dense.reshape(-1, 1, H, W)
        if self.mode == '2x2':
            x = x
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.s_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2*2, H*W)
                x = x.permute((0,1,3,2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))
            
        elif self.mode == '2x2d':
            x = x_7x7
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.d_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2*2, H*W)
                x = x.permute((0,1,3,2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))
        elif self.mode == '1x4':
            x = x_3x3
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.y_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2*2, H*W)
                x = x.permute((0,1,3,2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        # x += x_3x3
        x = torch.tanh(x)
        if self.upscale > 1:
            x = self.pixel_shuffle(x)
        # print(x.size())
        return x

终于在MuLUTUnit中,我们看到了与本文相关的RC模块。在common/network.py中,部分MuLUTUnit会用到这个RC模块,RC模块的实现如下:

class RC_Module(nn.Module):
    def __init__(self, in_channels, out_channels, out_dim, kernel_size=4, stride=1, padding=0, dilation=1, bias=True, mlp_field=7):
        super(RC_Module, self).__init__()
        self.mlp_field = mlp_field
        self.conv = nn.Conv2d(in_channels, out_channels, 1,
                              stride=stride, padding=padding, dilation=dilation, bias=bias)
        self.out_layer = nn.Linear(out_channels, in_channels)
        nn.init.kaiming_normal_(self.conv.weight)
        nn.init.constant_(self.conv.bias, 0)
        for i in range(mlp_field * mlp_field):
            setattr(self, 'linear{}'.format(i+1), nn.Linear(in_channels, out_channels))
            setattr(self, 'out{}'.format(i+1), nn.Linear(out_channels, out_dim))
        
    def forward(self, x):
        x_kv = {}
        # print('x', x.size())
        for i in range(self.mlp_field):
            for j in range(self.mlp_field):
                num = i * self.mlp_field + j + 1
                module1 = getattr(self, 'linear{}'.format(num))
                x_kv[str(num)] = module1(x[:, i, j, :]).unsqueeze(1)
        x_list = []
        temp = []
        for i in range(self.mlp_field * self.mlp_field):
            module = getattr(self, 'out{}'.format(i+1))
            x_list.append(module(x_kv[str(i+1)]))
        # for i in range(self.mlp_field * self.mlp_field):
        #     temp.append(x_kv[str(i+1)])
        out = torch.cat(x_list, dim=1)
        # out = torch.cat(temp, dim=1)
        out = out.mean(1)
        # out = self.out_layer(out)
        out = out.unsqueeze(-1).unsqueeze(-1)
        
        out = torch.tanh(out)
        out = round_func(out * 127)
        bias, norm = 127, 255.0
        out = round_func(torch.clamp(out + bias, 0, 255)) / norm
        return out

这里的mlp_field就是我们讲解流程中的窗口大小“N”,self.conv和out_conv在下面的前向过程中没有使用到(博主认为这里可能是作者的一个失误,把无意义的代码留在这里)。

我们主要关注的是作者初始化了N*N个用于升维的self.linear[i+1]以及用于降维的self.out[i+1]

在前向的过程中,首先作者依次访问了N*N个升维度的MLP,将x进行处理,保存至x_kv中,这里的x的输入维度博主认为是B,N,N,C,由于他是需要进行卷积的模块,B应当是与H、W以及stride相关的结果,只不过每一个位置用了2个FC,且只作用于当前访问的点,读者可以初始化模型来观察tensor的shape进一步熟悉。

然后再对这N*N个结果进行一个降维,并且将他们cat在一起,求mean,相当于我们前面讲到的avg_pool操作。

最后就是一个简单的约束范围到[-1,1]并且做量化,量化使用的round_func在专栏MuLUT的代码讲解的Re-index中有讲到,如下所示。

def round_func(input):
    # Backward Pass Differentiable Approximation (BPDA)
    # This is equivalent to replacing round function (non-differentiable)
    # with an identity function (differentiable) only when backward,
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out

前向做一个四舍五入,反向的过程跟原来一致,最后重新规整到0-1,输出结果。

这里想要完全明白整个流程的详情,建议读者自己实践一次,包括实际模型的初始化,运行代码,以及每一步去看模型前向中tensor的shape,对照着讲解可以更熟悉这整个过程。


2. 转表

这部分代码跟MuLUT代码一致,作者没有给出与RC模块相关的转换代码,不过我们已经清楚RC模块其实就是N*N个1D LUT,在初始化的部分跟其他的LUT转换不一样的点在于kernel_size的大小以及最终得到的LUT个数是N*N个。博主给出自己的RCLUT转换实现。

首先是初始化输入的tensor,因为RC模块窗口中每一个只是一个1D LUT,所以我们不需要做多维的输入,输入可以用下列代码初始化。

def get_rc_lut(interval):
    base = torch.arange(0, 257, interval)  # 0-256
    base[-1] -= 1
    return base.unsqueeze(1)

这样以interval=1为例,则输出是[256, 1],再经过下列代码进行推理,就可以得出RC模块中NxN个lut的一个对应关系,将他们保存即可。

def inference(model, input_tensor, mlp_field=3):
    luts = {}
    for i in range(mlp_field):
        for j in range(mlp_field):
            num = i * mlp_field + j + 1
            fc1 = getattr(model, 'linear{}'.format(num))
            fc2 = getattr(model, 'out{}'.format(num))
            # same as torch model inference
            batch_output = torch.tanh(fc2(fc1(input_tensor)))
            
            luts[str(num)] = torch.round(torch.clamp(batch_output, -1, 1)
                                          * 127).cpu().data.astype(np.int8)


    return luts  

3. 微调

微调的部分跟MuLUT是一样的,这个也可以查看MuLUT一栏对于微调部分的讲解,即将LUT作为一个可训练的参数去进行损失的优化。


4. 测试

测试的部分与MuLUT也是一致的,只不过针对于RC模块的测试代码需要修改为numpy实现的版本,感兴趣的读者可以自己尝试实现,博主认为逻辑是一样的。

以上针对于RCLUT代码实现的部分讲解完毕,如果有不清楚的问题欢迎大家提出。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值