[LUT Technical Series] A Code Walkthrough of DnLUT

Contents

Paper Overview

1. Training


This article is a code walkthrough of DnLUT; for an interpretation of the paper itself, see the earlier DnLUT post.

Paper Overview

With its PCM module and L-shaped convolution, DnLUT markedly improves the color-noise removal ability of earlier LUT-based methods, obtaining the same receptive field with minimal storage. There are two main innovations:

  • A PCM module (Pairwise Channel Mixer), whose converted LUTs avoid excessive dimensionality while still providing both channel and spatial interaction (see the sketch after this list).
  • An L-shaped convolution (Rotation Non-overlapping Kernel), which reduces kernel overlap under the rotation ensemble and thereby shrinks the LUT size for the same receptive field.
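
To make the pairing concrete before reading the code, here is a minimal sketch (the tensor names are hypothetical) of how PCM regroups an RGB image into channel pairs. Each pair is handled by its own 2-in/1-out unit, and since a 1x2 kernel over 2 channels reads 4 values, each unit converts into a 4D LUT:

import torch

x = torch.rand(1, 3, 8, 8)                  # hypothetical RGB input (B, C, H, W)
r, g, b = x[:, 0:1], x[:, 1:2], x[:, 2:3]
x_rg = torch.cat([r, g], dim=1)             # pair 1 -> 4D LUT
x_gb = torch.cat([g, b], dim=1)             # pair 2 -> 4D LUT
x_rb = torch.cat([b, r], dim=1)             # pair 3 -> 4D LUT (note the (b, r) order, matching the code below)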

The network architecture is shown in the paper's overview figure.

The Pairwise Channel Mixer, shown in panel (c) of the paper's figure, provides the spatial and channel interaction and converts into three 4D LUTs.

The L-shaped convolution (Rotation Non-overlapping Kernel), shown in panel (b), reaches the same receptive field after rotation as the S-shaped kernel, but removes the overlap everywhere except the center position, easing LUT storage pressure.

From the earlier articles in this series, we can anticipate that the full pipeline splits into training, LUT transfer, fine-tuning, and inference. The DnLUT authors open-sourced only the training part, but the remaining parts are largely the same as in MuLUT and related work. The overall code structure is as follows:

The core of the model implementation lives in the dn folder, and the individual layer implementations are in network.py under the common folder, as sketched below.
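
Based on the paths referenced throughout this article, the relevant layout looks roughly like this (only the files discussed here are listed):

dn/
  model_dnlut.py             # SRNets model definition
  1_train_model_dnlut.py     # training entry and forward pipeline
common/
  option_dnlut_sidd.py       # training options
  network.py                 # layer implementations (SRNet, MuLUT units)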


1. Training

First, looking at common/option_dnlut_sidd.py, we can see that the model DnLUT instantiates is SRNets.

class BaseOptions():
    def __init__(self, debug=False):
        self.initialized = False
        self.debug = debug

    def initialize(self, parser):
        # experiment specifics
        parser.add_argument('--model', type=str, default='SRNets')
        parser.add_argument('--task', '-t', type=str, default='sr')
        parser.add_argument('--scale', '-r', type=int, default=1, help="up scale factor")
        parser.add_argument('--sigma', '-s', type=int, default=50, help="noise level")
        parser.add_argument('--qf', '-q', type=int, default=20, help="deblocking quality factor")
        parser.add_argument('--nf', type=int, default=64, help="number of filters of convolutional layers")
        parser.add_argument('--stages', type=int, default=1, help="stages of MuLUT")
        parser.add_argument('--modes', type=str, default='sdy', help="sampling modes to use in every stage")
        parser.add_argument('--interval', type=int, default=4, help='N bit uniform sampling')
        parser.add_argument('--modelRoot', type=str, default='../models')

        parser.add_argument('--expDir', '-e', type=str, default='/home/styan/DNLUT/exp/mulut_gaussian_50', help="experiment folder")
        parser.add_argument('--load_from_opt_file', action='store_true', default=False)

        parser.add_argument('--debug', default=False, action='store_true')

        self.initialized = True
        return parser
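
To see which defaults DnLUT trains with, the parser can be exercised directly. This is a minimal sketch; how the repository actually consumes these options may differ:

import argparse

parser = argparse.ArgumentParser()
parser = BaseOptions().initialize(parser)
opt = parser.parse_args([])                 # take the defaults
print(opt.model, opt.stages, opt.modes)     # SRNets 1 sdy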

The SRNets model is implemented in dn/model_dnlut.py:

class SRNets(nn.Module):
    """ A LUT-convertable SR network with configurable stages and patterns. """

    def __init__(self, nf=64, scale=4, modes=['s', 'd', 'y'], stages=4, channel_mix=True):
        super(SRNets, self).__init__()
        if channel_mix:
            self.add_module("s1_m", SRNet("Mx1", nf=nf, upscale=None))
            self.add_module("s2_m", SRNet("Mx1", nf=nf, upscale=None))
            self.add_module("s1_q", SRNet("Qx1", nf=nf, upscale=None)) 
            self.add_module("s2_q", SRNet("Qx1", nf=nf, upscale=None)) 
            self.add_module("s3_q", SRNet("Qx1", nf=nf, upscale=None))
            
            

        for s in range(stages):  # 2-stage
            for mode in modes:
                self.add_module("s{}_{}".format(str(s + 1), mode),
                                SRNet("{}x1".format(mode.upper()), nf=nf))
        print_network(self)

    def forward(self, x, stage, mode):
        key = "s{}_{}".format(str(stage), mode)
        # print(key)
        module = getattr(self, key)
        return module(x)
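
The forward method is a pure dispatcher: it builds a key from stage and mode and forwards to the matching submodule. A minimal usage sketch (a hypothetical call, assuming the dn modules are importable; in the real pipeline the input is first rotated and padded, as shown later in mulut_predict):

import torch

net = SRNets(nf=64, modes=['s', 'd', 'y'], stages=5)
x = torch.rand(1, 3, 48, 48)
y = net(x, stage=1, mode='m')   # dispatches to "s1_m", the first PCM block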

From this we can see that every block is an SRNet, whose implementation is in common/network.py.

############### Image Super-Resolution ###############
class SRNet(nn.Module):
    """ Wrapper of a generalized (spatial-wise) MuLUT block. 
        By specifying the unfolding patch size and pixel indices,
        arbitrary sampling pattern can be implemented.
    """

    def __init__(self, mode, nf=64, upscale=None, dense=True):
        super(SRNet, self).__init__()
        self.mode = mode
        print(mode)
        if 'x1' in mode:
            assert upscale is None
        if mode == 'Sx1':
            self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = 1
        elif mode == 'SxN':
            self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = upscale
        elif mode == 'Cx1':
            self.model = MuLUTcUnit('1x1', nf)
            self.K = 2
            self.S = 1
        elif mode == 'Px1':
            self.model = PoolNear()
            self.K = 2
            self.S = 1
        elif mode == 'Qx1':
            self.model = MuLUTcUnit('1x1q', nf)
            self.K = 2
            self.S = 1
        elif mode == 'Mx1':
            self.model_rg = MuLUTmixUnit('2x2', nf)
            self.model_gb = MuLUTmixUnit('2x2', nf)
            self.model_rb = MuLUTmixUnit('2x2', nf)
            self.K = 2
            self.S = 1
        elif mode == 'MxN':
            self.model_rg = MuLUTmixUnit('2x2', nf)
            self.model_gb = MuLUTmixUnit('2x2', nf)
            self.model_rb = MuLUTmixUnit('2x2', nf)
            self.K = 2
            self.S = 1
        elif mode == 'Vx1':
            self.model = MuLUTUnit('1x3', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = 1
        elif mode == 'VxN':
            self.model = MuLUTUnit('1x3', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = 1
        elif mode == 'TMx1':
            self.model = MuLUTcUnit('1x1', nf)
            self.K = 1
            self.S = 1
        elif mode == 'Dx1':
            self.model = MuLUTUnit('2x2d', nf, upscale=1, dense=dense)
            self.K = 3
            self.S = 1
        elif mode == 'DxN':
            self.model = MuLUTUnit('2x2d', nf, upscale=upscale, dense=dense)
            self.K = 3
            self.S = upscale
        elif mode == 'Yx1':
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
            self.K = 3
            self.S = 1
        elif mode == 'YxN':
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
            self.K = 3
            self.S = upscale
        elif mode == 'Ex1':
            self.model = MuLUTUnit('2x2d3', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode == 'ExN':
            self.model = MuLUTUnit('2x2d3', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        elif mode in ['Ox1', 'Hx1']:
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode in ['OxN', 'HxN']:
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        else:
            raise AttributeError
        self.P = self.K - 1

    def forward(self, x):
        # 'TM' / 'C' / 'Q' / 'P' modes operate on the full tensor directly.
        if 'TM' in self.mode:
            return self.model(x)
        elif 'C' in self.mode:
            return self.model(x)
        elif 'Q' in self.mode:
            return self.model(x)
        elif 'P' in self.mode:
            return self.model(x)
        elif 'M' in self.mode:
            B, C, H, W = x.shape
            x_rg = x[:, :2, :, :]
            x_gb = x[:, 1:, :, :]
            # Note the pair order: rb is actually stacked as (b, r).
            x_rb = torch.stack((x[:, 2:, :, :], x[:, 0:1, :, :]), dim=1).squeeze(2)
            processed_tensors = []

            for x, im in zip([x_rg, x_gb, x_rb], ['rg', 'gb', 'rb']):
                if 'rg' in im:
                    x = self.model_rg(x)   # B*C*L,K,K
                elif 'gb' in im:
                    x = self.model_gb(x)   # B*C*L,K,K
                else:
                    x = self.model_rb(x)   # B*C*L,K,K
                processed_tensors.append(x)
            combined_x = torch.cat(processed_tensors, dim=1)
            return combined_x
        else:
            B, C, H, W = x.shape
            x = F.unfold(x, self.K)  # B,C*K*K,L
            x = x.view(B, C, self.K * self.K, (H - self.P) * (W - self.P))  # B,C,K*K,L
            x = x.permute((0, 1, 3, 2))  # B,C,L,K*K
            x = x.reshape(B * C * (H - self.P) * (W - self.P),
                        self.K, self.K)  # B*C*L,K,K
            x = x.unsqueeze(1)  # B*C*L,l,K,K

            if 'Y' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 1, 1],
                            x[:, :, 1, 2], x[:, :, 2, 1]], dim=1)

                x = x.unsqueeze(1).unsqueeze(1)
            elif 'V' in self.mode:
                # print(x.shape)
                x = torch.cat([x[:, :, 0, 0], x[:, :, 0, 1],
                            x[:, :, 1, 1]], dim=1)
                # print(x.shape)
                x = x.unsqueeze(1).unsqueeze(1)
            elif 'H' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                            x[:, :, 2, 3], x[:, :, 3, 2]], dim=1)

                x = x.unsqueeze(1).unsqueeze(1)
            elif 'O' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                            x[:, :, 1, 3], x[:, :, 3, 1]], dim=1)

                x = x.unsqueeze(1).unsqueeze(1)

            x = self.model(x)   # B*C*L,K,K
            x = x.squeeze(1)
            x = x.reshape(B, C, (H - self.P) * (W - self.P), -1)  # B,C,K*K,L
            x = x.permute((0, 1, 3, 2))  # B,C,K*K,L
            x = x.reshape(B, -1, (H - self.P) * (W - self.P))  # B,C*K*K,L
            x = F.fold(x, ((H - self.P) * self.S, (W - self.P) * self.S),
                    self.S, stride=self.S)
            return x
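
The final else branch is where the spatial sampling patterns live: F.unfold pulls out every K×K patch, and for the V/Y/H/O modes specific pixels are selected from each patch before the tiny network (and, after conversion, the LUT) consumes them. A minimal sketch of the V (L-shaped) selection on a hypothetical 4×4 image:

import torch
import torch.nn.functional as F

x = torch.arange(16.0).view(1, 1, 4, 4)      # hypothetical 4x4 single-channel image
K = 2
p = F.unfold(x, K)                           # (1, K*K, 9): all 2x2 patches
p = p.view(1, 1, K * K, 9).permute(0, 1, 3, 2).reshape(9, K, K).unsqueeze(1)
# V mode keeps (0,0), (0,1), (1,1) of each patch: an L-shaped pixel triple
v = torch.cat([p[:, :, 0, 0], p[:, :, 0, 1], p[:, :, 1, 1]], dim=1)
print(v.shape)                               # torch.Size([9, 3]); unsqueezed to (9,1,1,3) for the 1x3 conv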

Two implementations deserve special attention: the one behind PCM and the one behind the L-shaped convolution, corresponding to modes Mx1 and Vx1 respectively. Locating them among the SRNet mode branches above:

        elif mode == 'Mx1':
            self.model_rg = MuLUTmixUnit('2x2', nf)
            self.model_gb = MuLUTmixUnit('2x2', nf)
            self.model_rb = MuLUTmixUnit('2x2', nf)
            self.K = 2
            self.S = 1
        elif mode == 'MxN':
            self.model_rg = MuLUTmixUnit('2x2', nf)
            self.model_gb = MuLUTmixUnit('2x2', nf)
            self.model_rb = MuLUTmixUnit('2x2', nf)
            self.K = 2
            self.S = 1
        elif mode == 'Vx1':
            self.model = MuLUTUnit('1x3', nf, upscale=1, dense=dense)
            self.K = 2
            self.S = 1

Looking next at MuLUTmixUnit and MuLUTUnit makes the remaining DnLUT components clear; both are implemented in common/network.py.

class MuLUTmixUnit(nn.Module):
    """ Channel-wise MuLUT block [RGB(3D) to RGB(3D)]. """

    def __init__(self, mode, nf):
        super(MuLUTmixUnit, self).__init__()
        self.act = nn.ReLU()

        if mode == '2x2':
            self.conv1 = Conv(2, nf, [1,2])  # 2 input channels, 1x2 kernel: 4 inputs per entry, i.e. a 4D LUT
        else:
            raise AttributeError

        self.conv2 = DenseConv(nf, nf)
        self.conv3 = DenseConv(nf + nf * 1, nf)
        self.conv4 = DenseConv(nf + nf * 2, nf)
        self.conv5 = DenseConv(nf + nf * 3, nf)
        self.conv6 = Conv(nf * 5, 1, 1)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = torch.tanh(self.conv6(x))
        return x
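
A quick shape check of the mixer unit (hypothetical input; Conv and DenseConv are the repository's convolution wrappers, assumed here to behave like a plain nn.Conv2d followed by spatial-preserving 1x1 layers):

import torch

unit = MuLUTmixUnit('2x2', nf=64)
pair = torch.randn(1, 2, 8, 9)   # one channel pair; width padded by 1 for the 1x2 kernel
out = unit(pair)
print(out.shape)                 # expected: torch.Size([1, 1, 8, 8]) -- one mixed channel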



############### MuLUT Blocks ###############
class MuLUTUnit(nn.Module):
    """ Generalized (spatial-wise)  MuLUT block. """

    def __init__(self, mode, nf, upscale=1, out_c=1, dense=True):
        super(MuLUTUnit, self).__init__()
        self.act = nn.ReLU()
        self.upscale = upscale

        if mode == '2x2':
            self.conv1 = Conv(1, nf, 2)
        elif mode == '2x2d':
            self.conv1 = Conv(1, nf, 2, dilation=2)
        elif mode == '2x2d3':
            self.conv1 = Conv(1, nf, 2, dilation=3)
        elif mode == '1x4':
            self.conv1 = Conv(1, nf, (1, 4))
        elif mode == '1x3':
            self.conv1 = Conv(1, nf, (1, 3))
        elif mode == '1x1':
            self.conv1 = Conv(3, nf, (1, 1))
        else:
            raise AttributeError

        if dense:
            self.conv2 = DenseConv(nf, nf)
            self.conv3 = DenseConv(nf + nf * 1, nf)
            self.conv4 = DenseConv(nf + nf * 2, nf)
            self.conv5 = DenseConv(nf + nf * 3, nf)
            if mode == '1x1':
                self.conv6 = Conv(nf * 5, 3, 1)
            else:
                self.conv6 = Conv(nf * 5, 1 * upscale * upscale, 1)
        else:
            self.conv2 = ActConv(nf, nf, 1)
            self.conv3 = ActConv(nf, nf, 1)
            self.conv4 = ActConv(nf, nf, 1)
            self.conv5 = ActConv(nf, nf, 1)
            if mode == '1x1':
                self.conv6 = Conv(nf, 3 * upscale * upscale, 3)
            else:
                self.conv6 = Conv(nf, upscale * upscale, 1)
        if self.upscale > 1:
            self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = torch.tanh(self.conv6(x))
        if self.upscale > 1:
            x = self.pixel_shuffle(x)
        return x
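
Similarly, the V-mode unit consumes the unfolded 1×3 triples and emits one value per patch. A shape sketch under the same assumptions:

unit = MuLUTUnit('1x3', nf=64, upscale=1)
triples = torch.randn(9, 1, 1, 3)   # the V-pattern triples produced in SRNet.forward
y = unit(triples)
print(y.shape)                      # expected: torch.Size([9, 1, 1, 1])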

As we can see, PCM is really just three 1x2 convolutions with 2 input channels each, while the L-shaped convolution simply swaps the former 2x2 kernel for a 1x3 one, paired with the appropriate input sampling.
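
The storage argument is easy to quantify. With the default --interval 4 (following the MuLUT convention of 2^4 + 1 = 17 uniformly sampled levels per input dimension), a rough per-LUT comparison looks like this; the released LUTs may store several output values per entry, so this is only an estimate:

levels = 2 ** 4 + 1                  # interval=4 -> 17 sampled values per input dim
lut_s = levels ** 4                  # 2x2 S-kernel: a 4D LUT
lut_v = levels ** 3                  # 1x3 V/L-kernel: a 3D LUT
print(lut_s, lut_v, lut_s // lut_v)  # 83521 4913 17 -> each V LUT is ~17x smaller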

Finally, the overall forward pass is implemented in dn/1_train_model_dnlut.py.

mode_pad_dict = {"s": 1, "d": 2, "y": 2, "e": 3, "h": 3, "o": 3, 'm': 1, 'v': 1}


def mulut_predict(model_G, x, phase="train", opt=None):
        modes, stages = opt.modes, opt.stages
        pred = 0

        ## Pair-wise Mixer stage 1
        for r in [0, 1, 2, 3]:
            tmp = round_func(torch.rot90(
                model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, 1, 0, 0), mode='replicate'),
                        stage=1, mode='m'),
                (4 - r) % 4, [2, 3]) * 127)
            pred += tmp
        avg_factor, bias, norm = 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

        # L-conv Stage 1
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(
                    model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, pad, 0, pad), mode='replicate'),
                            stage=1, mode=mode),
                    (4 - r) % 4, [2, 3]) * 127)
        avg_factor, bias, norm = len(modes) * 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm
        x1 = x

        # L-conv Stage 2
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(
                    model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, pad, 0, pad), mode='replicate'),
                            stage=2, mode=mode),
                    (4 - r) % 4, [2, 3]) * 127)
        avg_factor, bias, norm = len(modes) * 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm
        x2 = x

        # L-conv Stage 3
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(
                    model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, pad, 0, pad), mode='replicate'),
                            stage=3, mode=mode),
                    (4 - r) % 4, [2, 3]) * 127)

        avg_factor, bias, norm = len(modes) * 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm
        x3 = x

        # concat
        x_r = torch.cat([x1[:, 0:1], x2[:, 0:1], x3[:, 0:1]], dim=1).to('cuda')
        x_g = torch.cat([x1[:, 1:2], x2[:, 1:2], x3[:, 1:2]], dim=1).to('cuda')
        x_b = torch.cat([x1[:, 2:], x2[:, 2:], x3[:, 2:]], dim=1).to('cuda')

        # R: 3 -> 1
        r = 0
        pred = round_func(torch.rot90(
            model_G(F.pad(torch.rot90(x_r, r, [2, 3]), (0, 0, 0, 0), mode='replicate'),
                    stage=1, mode='q'),
            (4 - r) % 4, [2, 3]) * 127)
        avg_factor, bias, norm = 1, 127, 255.0
        x_r = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

        # G: 3 -> 1
        r = 0
        pred = round_func(torch.rot90(
            model_G(F.pad(torch.rot90(x_g, r, [2, 3]), (0, 0, 0, 0), mode='replicate'),
                    stage=1, mode='q'),
            (4 - r) % 4, [2, 3]) * 127)
        avg_factor, bias, norm = 1, 127, 255.0
        x_g = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

        # B: 3 -> 1
        r = 0
        pred = round_func(torch.rot90(
            model_G(F.pad(torch.rot90(x_b, r, [2, 3]), (0, 0, 0, 0), mode='replicate'),
                    stage=1, mode='q'),
            (4 - r) % 4, [2, 3]) * 127)
        avg_factor, bias, norm = 1, 127, 255.0
        x_b = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

        x = torch.cat([x_r, x_g, x_b], dim=1).to('cuda')

        # L-conv Stage 4
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(
                    model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, pad, 0, pad), mode='replicate'),
                            stage=4, mode=mode),
                    (4 - r) % 4, [2, 3]) * 127)

        avg_factor, bias, norm = len(modes) * 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm

        # Pair-wise Mixer stage 2
        pred = 0
        for r in [0, 1, 2, 3]:
            tmp = round_func(torch.rot90(
                model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, 1, 0, 0), mode='replicate'),
                        stage=2, mode='m'),
                (4 - r) % 4, [2, 3]) * 127)
            pred += tmp
        avg_factor, bias, norm = 4, 127, 255.0
        x = round_func(torch.clamp((pred / avg_factor) + bias, 0, 255)) / norm
        
        # L-conv Stage 5
        pred = 0
        for mode in modes:
            pad = mode_pad_dict[mode]
            for r in [0, 1, 2, 3]:
                pred += round_func(torch.rot90(
                    model_G(F.pad(torch.rot90(x, r, [2, 3]), (0, pad, 0, pad), mode='replicate'),
                            stage=5, mode=mode),
                    (4 - r) % 4, [2, 3]) * 127)

        avg_factor, bias, norm = len(modes), 0, 1
        x = round_func((pred / avg_factor) + bias)
        if phase == "train":
            x = x / 255.0

        return x
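
A hypothetical training step wiring this together (model_G, opt, and gt stand in for objects created by the training script):

import torch.nn.functional as F

x = torch.rand(1, 3, 48, 48)     # noisy input in [0, 1]
out = mulut_predict(model_G, x, phase="train", opt=opt)
loss = F.mse_loss(out, gt)       # gt: the clean target patch
loss.backward()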

This computation matches the pipeline described in the earlier paper walkthrough. Every step calls the SRNets forward shown above with a different stage and mode: mode='m' invokes the PCM, i.e. the MuLUTmixUnit blocks, while the L-shaped convolution, as the options reveal, is realized by the V-shaped pattern, i.e. MuLUTUnit with a 1x3 kernel. Rotation ensembling and quantization/clamping are interleaved between steps, producing the final output.
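
One detail worth spelling out is round_func, which appears at every quantization point. In MuLUT-style codebases it is a straight-through rounding: the forward pass rounds, while the backward pass treats it as the identity so gradients survive quantization. A sketch of the usual implementation:

def round_func(input):
    # Round in the forward pass; keep the identity gradient in the backward
    # pass (straight-through estimator), so training can see quantized values.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out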


This concludes the walkthrough of DnLUT's training code; questions are welcome if anything is unclear. The LUT-transfer, fine-tuning, and testing parts are left for readers to implement, as their logic follows the earlier articles in this series.
