This post is a code walkthrough of the RCLUT technique; for an analysis of the paper itself, see the RCLUT post.
Overview of the Paper
RCLUT improves super-resolution quality by enlarging the network's receptive field. It is an improvement over SRLUT built around one main innovation: the RC module.
[Figure: overall network architecture]
[Figure: RC module]
The code is organized as follows: the author adds the RC module on top of the MuLUT codebase, and the pipeline again consists of four steps — training, LUT conversion, fine-tuning, and testing — with the common directory holding the implementations of the shared network components.
1. Training
The training code lives in 1_train_model.py and follows the same structure as MuLUT. First, find the model configuration in common/option.py:
class BaseOptions():
    def __init__(self, debug=False):
        self.initialized = False
        self.debug = debug

    def initialize(self, parser):
        # experiment specifics
        parser.add_argument('--model', type=str, default='SRNets')
        parser.add_argument('--task', '-t', type=str, default='sr')
        parser.add_argument('--scale', '-r', type=int, default=4, help="up scale factor")
        parser.add_argument('--sigma', '-s', type=int, default=25, help="noise level")
        parser.add_argument('--qf', '-q', type=int, default=20, help="deblocking quality factor")
        parser.add_argument('--nf', type=int, default=64, help="number of filters of convolutional layers")
        parser.add_argument('--stages', type=int, default=2, help="stages of MuLUT")
        parser.add_argument('--modes', type=str, default='sdy', help="sampling modes to use in every stage")
        parser.add_argument('--interval', type=int, default=4, help='N bit uniform sampling')
        parser.add_argument('--modelRoot', type=str, default='../models')
        parser.add_argument('--expDir', '-e', type=str, default='', help="experiment folder")
        parser.add_argument('--load_from_opt_file', action='store_true', default=False)
        parser.add_argument('--debug', default=False, action='store_true')

        self.initialized = True
        return parser
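For orientation, here is a minimal sketch of my own showing how these options are consumed (standard argparse usage; the parse_known_args call and the argument values are illustrative, not from the repo):

import argparse

# Build the parser exactly as the training script would, then parse a sample command line.
parser = argparse.ArgumentParser()
parser = BaseOptions().initialize(parser)
opt, _ = parser.parse_known_args(['--stages', '2', '--modes', 'sdy'])
print(opt.model, opt.scale, opt.stages, opt.modes)  # -> SRNets 4 2 sdy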
The model name is SRNets, which is defined in sr/model.py:
class SRNets(nn.Module):
    """ A LUT-convertable SR network with configurable stages and patterns. """

    def __init__(self, nf=64, scale=4, modes=['s', 'd', 'y'], stages=2):
        super(SRNets, self).__init__()
        for s in range(stages):  # 2-stage
            if (s + 1) == stages:
                upscale = scale
                flag = "N"
            else:
                upscale = None
                flag = "1"
            for mode in modes:
                # print('mode', mode, "{}x{}".format(mode.upper(), flag))
                self.add_module("s{}_{}".format(str(s + 1), mode),
                                SRNet("{}x{}".format(mode.upper(), flag), nf=nf, upscale=upscale))
        # print_network(self)
        # self.add_module("s{}_{}".format(str(s+1), "Connect"), SRNet("Connect", nf=nf, upscale=upscale))

    def forward(self, x, stage, mode):
        key = "s{}_{}".format(str(stage), mode)
        module = getattr(self, key)
        return module(x)
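To see what this registration produces, a quick illustration of my own (using the default arguments shown above):

# With stages=2 and modes=['s', 'd', 'y'], six SRNet blocks are registered.
net = SRNets(nf=64, scale=4, modes=['s', 'd', 'y'], stages=2)
print([name for name, _ in net.named_children()])
# -> ['s1_s', 's1_d', 's1_y', 's2_s', 's2_d', 's2_y']
# forward(x, stage=1, mode='s') then dispatches to the block registered as 's1_s'.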
As the add_module loop (and the sketch above) shows, SRNets is simply a stack of SRNet blocks, one per stage and sampling mode. The concrete SRNet implementation is in common/network.py:
class SRNet(nn.Module):
    """ Wrapper of a generalized (spatial-wise) MuLUT block.
        By specifying the unfolding patch size and pixel indices,
        arbitrary sampling pattern can be implemented.
    """

    def __init__(self, mode, nf=64, upscale=None, dense=True):
        super(SRNet, self).__init__()
        self.mode = mode

        if 'x1' in mode:
            assert upscale is None
        if mode == 'Sx1':
            self.model = MuLUTUnit('2x2', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='s')
            self.K = 5
            self.S = 1
        elif mode == 'SxN':
            upscale = upscale
            self.model = MuLUTUnit('2x2', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='s')
            self.K = 5
            self.S = upscale
        elif mode == 'Hx1':
            self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='h')
            self.K = 2
            self.S = 1
        elif mode == 'HxN':
            self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='h')
            self.K = 2
            self.S = upscale
        elif mode == 'Jx1':
            self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='h')
            self.K = 2
            self.S = 1
        elif mode == 'JxN':
            self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='h')
            self.K = 2
            self.S = upscale
        elif mode == 'Fx1':
            self.model = MuLUTUnit('2x2d4', nf, upscale=1, dense=dense)
            self.K = 5
            self.S = 1
        elif mode == 'FxN':
            self.model = MuLUTUnit('2x2d4', nf, upscale=upscale, dense=dense)
            self.K = 5
            self.S = upscale
        elif mode == 'Dx1':
            self.model = MuLUTUnit('2x2d', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='d')
            self.K = 5
            self.S = 1
        elif mode == 'DxN':
            self.model = MuLUTUnit('2x2d', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='d')
            self.K = 5
            self.S = upscale
        elif mode == 'Yx1':
            self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense, stage=1)
            # self.model = MuLUTUnit('4x1', nf, upscale=1, dense=dense, deform_mode='y')
            self.K = 5
            self.S = 1
        elif mode == 'YxN':
            self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense, stage=2)
            # self.model = MuLUTUnit('4x1', nf, upscale=upscale, dense=dense, deform_mode='y')
            self.K = 5
            self.S = upscale
        elif mode == 'Ex1':
            self.model = MuLUTUnit('2x2d3', nf, upscale=1, dense=dense)
            self.K = 4
            self.S = 1
        elif mode == 'ExN':
            self.model = MuLUTUnit('2x2d3', nf, upscale=upscale, dense=dense)
            self.K = 4
            self.S = upscale
        # elif mode in ['Ox1', 'Hx1']:
        #     self.model = MuLUTUnit('1x4', nf, upscale=1, dense=dense)
        #     self.K = 4
        #     self.S = 1
        # elif mode == ['OxN', 'HxN']:
        #     self.model = MuLUTUnit('1x4', nf, upscale=upscale, dense=dense)
        #     self.K = 4
        #     self.S = upscale
        elif mode == 'Connect':
            self.model = MuLUTcUnit('1x1', nf)
            self.K = 3
        else:
            raise AttributeError
        self.P = self.K - 1

    def forward(self, x):
        if 'H' in self.mode:
            channel = x.size(1)
            x = x.reshape(-1, 1, x.size(2), x.size(3))
            x = self.model(x)
            x = x.reshape(-1, channel, x.size(2), x.size(3))
        elif self.mode == 'Connect':
            x = self.model(x)
        else:
            B, C, H, W = x.shape
            x_dense = x[:, :, :-4, :-4]

            # 7x7 windows, aligned with x_dense, for the 'd' (dilated) branch
            x_7x7 = F.pad(x, [2, 0, 2, 0], mode='replicate')
            B7, C7, H7, W7 = x_7x7.shape
            x_7x7 = F.unfold(x_7x7, 7)

            # 3x3 windows for the 'y' branch
            x_3x3 = x[:, :, :-2, :-2]
            B3, C3, H3, W3 = x_3x3.shape
            x_3x3 = F.unfold(x_3x3, 3)
            x_3x3 = x_3x3.view(B3, C3, 9, (H3 - 2) * (W3 - 2))
            x_3x3 = x_3x3.permute((0, 1, 3, 2))
            x_3x3 = x_3x3.reshape(B3 * C3 * (H3 - 2) * (W3 - 2), 3, 3)
            x_3x3 = x_3x3.unsqueeze(-1)

            x_7x7 = x_7x7.view(B7, C7, 49, (H7 - 6) * (W7 - 6))
            x_7x7 = x_7x7.permute((0, 1, 3, 2))
            x_7x7 = x_7x7.reshape(B7 * C7 * (H7 - 6) * (W7 - 6), 7, 7)
            x_7x7 = x_7x7.unsqueeze(-1)

            x = F.unfold(x, self.K)  # B,C*K*K,L
            x = x.view(B, C, self.K * self.K, (H - self.P) * (W - self.P))  # B,C,K*K,L
            r_H = H - self.P
            r_W = W - self.P
            x = x.permute((0, 1, 3, 2))  # B,C,L,K*K
            x = x.reshape(B * C * (H - self.P) * (W - self.P),
                          self.K, self.K)  # B*C*L,K,K
            # x = x.unsqueeze(1)  # B*C*L,l,K,K
            x = x.unsqueeze(-1)
            # if 'Y' in self.mode:
            #     x = torch.cat([x[:, :, 0, 0], x[:, :, 1, 1],
            #                    x[:, :, 1, 2], x[:, :, 2, 1]], dim=1)
            #     x = x.unsqueeze(1).unsqueeze(1)
            if 'H' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                               x[:, :, 2, 3], x[:, :, 3, 2]], dim=1)
                x = x.unsqueeze(1).unsqueeze(1)
            elif 'O' in self.mode:
                x = torch.cat([x[:, :, 0, 0], x[:, :, 2, 2],
                               x[:, :, 1, 3], x[:, :, 3, 1]], dim=1)
                x = x.unsqueeze(1).unsqueeze(1)

            x = self.model(x, r_H, r_W, x_dense, x_3x3, x_7x7)  # B*C*L,K,K
            x = x.squeeze(1)
            x = x.reshape(B, C, (H - self.P) * (W - self.P), -1)  # B,C,L,S*S
            x = x.permute((0, 1, 3, 2))  # B,C,S*S,L
            x = x.reshape(B, -1, (H - self.P) * (W - self.P))  # B,C*S*S,L
            x = F.fold(x, ((H - self.P) * self.S, (W - self.P) * self.S),
                       self.S, stride=self.S)
            # print('ll', x.size())
        return x
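The window bookkeeping in the else branch above is easier to follow with a toy shape check (a standalone sketch of mine, not code from the repo):

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 10, 10)                                # B, C, H, W
x3 = F.unfold(x[:, :, :-2, :-2], 3)                         # 3x3 windows ('y' branch)
x7 = F.unfold(F.pad(x, [2, 0, 2, 0], mode='replicate'), 7)  # 7x7 windows ('d' branch)
print(x3.shape, x7.shape)
# -> torch.Size([1, 9, 36]) torch.Size([1, 49, 36])
# Both produce (H-4)*(W-4) = 36 windows, one per output position,
# matching x_dense = x[:, :, :-4, :-4].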
As shown above, SRNet mainly delegates to the MuLUTUnit module. Its implementation, also in common/network.py, is as follows:
############### MuLUT Blocks ###############
class MuLUTUnit(nn.Module):
    """ Generalized (spatial-wise) MuLUT block. """

    def __init__(self, mode, nf, upscale=1, out_c=1, dense=True, deform_mode='s', patch_size=48, stage=1):
        super(MuLUTUnit, self).__init__()
        self.act = nn.ReLU()
        self.upscale = upscale
        self.conv_naive = Conv(1, nf, 2)
        self.mode = mode
        self.stage = stage

        if mode == '2x2':
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=5)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=5)
            self.s_conv = Conv(4, nf, 1)
            # self.conv1 = Conv_test(1, nf)
        elif mode == '2x2d':
            # self.conv1 = Conv(1, nf, 2, dilation=2)
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=7)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=7)
            self.d_conv = Conv(4, nf, 1)
        elif mode == '2x2d3':
            self.conv1 = Conv(1, nf, 2, dilation=3)
        elif mode == '2x2d4':
            self.conv1 = Conv(1, nf, 2, dilation=4)
        elif mode == '1x4':
            # self.conv1 = Conv(1, nf, (1, 4))
            if self.stage == 1:
                self.conv1 = RC_Module(1, nf, 4, mlp_field=3)
            else:
                self.conv1 = RC_Module(1, nf, 1, mlp_field=3)
            self.y_conv = Conv(4, nf, 1)
        elif mode == '4x1':
            self.conv1 = DeformConv2d(1, nf, mode=deform_mode)
        else:
            raise AttributeError

        if dense:
            self.conv2 = DenseConv(nf, nf)
            self.conv3 = DenseConv(nf + nf * 1, nf)
            self.conv4 = DenseConv(nf + nf * 2, nf)
            self.conv5 = DenseConv(nf + nf * 3, nf)
            self.conv6 = Conv(nf * 5, 1 * upscale * upscale, 1)
            # self.conv6 = Conv(nf * 5, nf, 1)
        else:
            self.conv2 = ActConv(nf, nf, 1)
            self.conv3 = ActConv(nf, nf, 1)
            self.conv4 = ActConv(nf, nf, 1)
            self.conv5 = ActConv(nf, nf, 1)
            self.conv6 = Conv(nf, upscale * upscale, 1)
        if self.upscale > 1:
            self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x, r_H, r_W, x_dense, x_3x3, x_7x7):
        B, C, H, W = x_dense.shape
        x_dense = x_dense.reshape(-1, 1, H, W)
        if self.mode == '2x2':
            x = x
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.s_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2 * 2, H * W)
                x = x.permute((0, 1, 3, 2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))
        elif self.mode == '2x2d':
            x = x_7x7
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.d_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2 * 2, H * W)
                x = x.permute((0, 1, 3, 2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))
        elif self.mode == '1x4':
            x = x_3x3
            x = torch.tanh(self.conv1(x))
            if self.stage == -1:
                x = self.y_conv(x)
            else:
                x = x.reshape(-1, 1, H, W)
                # x += x_dense
                x = F.pad(x, [0, 1, 0, 1], mode='replicate')
                x = F.unfold(x, 2)
                x = x.view(B, C, 2 * 2, H * W)
                x = x.permute((0, 1, 3, 2))
                x = x.reshape(B * C * H * W, 2, 2)
                x = x.unsqueeze(1)
                x = self.act(self.conv_naive(x))
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        # x += x_3x3
        x = torch.tanh(x)
        if self.upscale > 1:
            x = self.pixel_shuffle(x)
        # print(x.size())
        return x
Finally, inside MuLUTUnit we reach the RC module that this post is about. Several MuLUTUnit configurations use it; RC_Module is implemented in common/network.py as follows:
class RC_Module(nn.Module):
    def __init__(self, in_channels, out_channels, out_dim, kernel_size=4, stride=1, padding=0, dilation=1, bias=True, mlp_field=7):
        super(RC_Module, self).__init__()
        self.mlp_field = mlp_field
        self.conv = nn.Conv2d(in_channels, out_channels, 1,
                              stride=stride, padding=padding, dilation=dilation, bias=bias)
        self.out_layer = nn.Linear(out_channels, in_channels)
        nn.init.kaiming_normal_(self.conv.weight)
        nn.init.constant_(self.conv.bias, 0)
        # one up-projection and one down-projection per window position
        for i in range(mlp_field * mlp_field):
            setattr(self, 'linear{}'.format(i + 1), nn.Linear(in_channels, out_channels))
            setattr(self, 'out{}'.format(i + 1), nn.Linear(out_channels, out_dim))

    def forward(self, x):
        x_kv = {}
        # print('x', x.size())
        for i in range(self.mlp_field):
            for j in range(self.mlp_field):
                num = i * self.mlp_field + j + 1
                module1 = getattr(self, 'linear{}'.format(num))
                x_kv[str(num)] = module1(x[:, i, j, :]).unsqueeze(1)
        x_list = []
        temp = []
        for i in range(self.mlp_field * self.mlp_field):
            module = getattr(self, 'out{}'.format(i + 1))
            x_list.append(module(x_kv[str(i + 1)]))
        # for i in range(self.mlp_field * self.mlp_field):
        #     temp.append(x_kv[str(i+1)])
        out = torch.cat(x_list, dim=1)
        # out = torch.cat(temp, dim=1)
        out = out.mean(1)
        # out = self.out_layer(out)
        out = out.unsqueeze(-1).unsqueeze(-1)
        out = torch.tanh(out)
        out = round_func(out * 127)
        bias, norm = 127, 255.0
        out = round_func(torch.clamp(out + bias, 0, 255)) / norm
        return out
Here mlp_field is the window size "N" from the walkthrough of the paper. Note that self.conv and self.out_layer are never used in the forward pass below; this is most likely an oversight by the authors, leaving dead code in place.
The parts that matter are the N*N up-projection layers self.linear1 … self.linearN² and the N*N down-projection layers self.out1 … self.outN² that the author initializes.
In the forward pass, the code visits the N*N up-projection MLPs one by one, applies each to its own pixel of x, and stores the results in x_kv. The input x here has shape (B, N, N, C); because this module stands in for a convolution, B actually folds together the batch dimension with the H, W (and stride) positions of the sliding window. In other words, every window position gets its own pair of FC layers, applied only to that position's pixel; instantiating the model and inspecting tensor shapes is a good way to confirm this.
The N*N results are then down-projected, concatenated, and averaged; this mean corresponds to the avg_pool operation discussed earlier.
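To make the shapes concrete, a small check (my sketch; it needs the round_func helper shown just below):

import torch

rc = RC_Module(in_channels=1, out_channels=64, out_dim=4, mlp_field=3)
x = torch.rand(8, 3, 3, 1)   # 8 window positions, a 3x3 window, 1 channel
y = rc(x)
print(y.shape)               # -> torch.Size([8, 4, 1, 1])
# Each of the 9 positions goes through its own linear{k} (1 -> 64) and
# out{k} (64 -> 4); the 9 results are averaged, quantized, and reshaped.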
The last step simply constrains the range to [-1, 1] and quantizes. The round_func used for quantization was covered in the Re-index part of the MuLUT code walkthrough in this column; it is reproduced below.
def round_func(input):
    # Backward Pass Differentiable Approximation (BPDA).
    # This is equivalent to replacing the (non-differentiable) round function
    # with an identity function (differentiable) in the backward pass only.
    forward_value = torch.round(input)
    out = input.clone()
    out.data = forward_value.data
    return out
The forward pass rounds to the nearest value, while the backward pass behaves as if no rounding happened (an identity gradient); the result is finally renormalized to [0, 1] and returned.
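A two-line check of this straight-through behavior (my sketch):

x = torch.tensor([0.4, 0.6], requires_grad=True)
round_func(x).sum().backward()
print(x.grad)  # -> tensor([1., 1.]): the gradient of identity, not of round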
To fully understand the whole pipeline, I recommend working through it yourself: initialize an actual model, run the code, and inspect the tensor shapes at every step of the forward pass alongside this walkthrough.
2. LUT Conversion
This step's code is identical to MuLUT's; the author does not provide conversion code for the RC module. But since we now know the RC module is really just N*N 1D LUTs, the only differences from the other LUT conversions are the kernel_size used at initialization and the fact that N*N LUTs come out at the end. Below is my own implementation of the RCLUT conversion.
First, initialize the input tensor. Because each position in the RC window is a single 1D LUT, no multi-dimensional enumeration is needed; the input can be initialized as follows.
def get_rc_lut(interval):
    base = torch.arange(0, 257, interval).float()  # sampled values 0..256; float so nn.Linear accepts it
    base[-1] -= 1  # clamp the last sample from 256 to 255
    return base.unsqueeze(1)  # shape: [num_samples, 1]
With interval=1, for example, the output has shape [257, 1] (257 samples, the last one clamped to 255). Running it through the following code yields the input-output mapping of each of the N*N LUTs in the RC module; saving these tables completes the conversion.
import numpy as np
import torch

def inference(model, input_tensor, mlp_field=3):
    luts = {}
    # assumption: training inputs were normalized to [0, 1], so the sampled
    # pixel values are divided by 255 before the FC layers
    input_tensor = input_tensor / 255.0
    for i in range(mlp_field):
        for j in range(mlp_field):
            num = i * mlp_field + j + 1
            fc1 = getattr(model, 'linear{}'.format(num))
            fc2 = getattr(model, 'out{}'.format(num))
            # same as the torch model's forward pass for this position
            batch_output = torch.tanh(fc2(fc1(input_tensor)))
            luts[str(num)] = torch.round(torch.clamp(batch_output, -1, 1)
                                         * 127).cpu().data.numpy().astype(np.int8)
    return luts
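Usage might look like the following (a sketch of mine; the attribute path into the trained model follows the class definitions above, but the output file names are hypothetical):

# Export the stage-1 's' branch: its RC_Module sits at net.s1_s.model.conv1
# and was built with mlp_field=5 (see MuLUTUnit above).
rc_module = net.s1_s.model.conv1
luts = inference(rc_module, get_rc_lut(1), mlp_field=5)
for k, v in luts.items():
    np.save('rc_lut_s1_s_{}.npy'.format(k), v)  # hypothetical file naming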
3. Fine-tuning
Fine-tuning is the same as in MuLUT; see the fine-tuning section of the MuLUT column. In short, the LUT itself is treated as a trainable parameter and optimized against the loss.
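As a rough illustration of that idea (my own sketch, not the repo's fine-tuning code), the exported table becomes an nn.Parameter, and gradients flow through lookups thanks to the round_func straight-through trick:

import torch
import torch.nn as nn

class FinetunableLUT(nn.Module):
    # A 1D LUT stored as a trainable parameter (sketch).
    def __init__(self, lut_int8):
        super().__init__()
        self.table = nn.Parameter(torch.tensor(lut_int8, dtype=torch.float32))

    def forward(self, idx):
        out = self.table[idx]                           # differentiable w.r.t. the table
        return round_func(torch.clamp(out, -127, 127))  # STE re-quantization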
4. Testing
Testing also follows MuLUT, except that the RC-module path has to be reimplemented in numpy. Interested readers can try this themselves; the logic is unchanged.
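For reference, a rough numpy sketch of applying the exported N*N 1D LUTs at test time (my own illustration, assuming interval=1 so pixel values index the tables directly; larger intervals would need 1D interpolation):

import numpy as np

def apply_rc_luts(windows, luts, mlp_field=3):
    # windows: (L, N, N) uint8 pixel windows; luts[str(k)]: (257, out_dim) int8
    acc = None
    for i in range(mlp_field):
        for j in range(mlp_field):
            num = i * mlp_field + j + 1
            vals = luts[str(num)][windows[:, i, j]].astype(np.float32)  # (L, out_dim) lookup
            acc = vals if acc is None else acc + vals
    out = acc / (mlp_field * mlp_field)           # mean over the N*N positions
    # Same bias/clamp/normalize as the training forward. Note the tables bake
    # tanh in per position, while training applies tanh after the mean, so
    # this export is an approximation.
    return np.round(np.clip(out + 127, 0, 255)) / 255.0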
That concludes this walkthrough of the RCLUT code implementation. Questions are welcome in the comments.