SR-LUT是2021年的一篇SISR领域的CVPR论文。SR-LUT以较快的执行速度,可脱离CNN在移动端也能快速实现超分的特点,此外其在重建表现力上也具有一定的能力。因此这篇文章很是有必要阅读一番关于SR-LUT的理论解析部分见我的另一篇超分之SR-LUT(建议先看论文解读部分,再来看源码解析)。
Practical Single-Image Super-Resolution Using Look-Up Table
写在前面:
下面我们分别对三部分做个简要解析。
1 训练部分
1.1 CNN模型结构
如上图所示是SR-LUT的CNN结构:
- 第0层是输入层:SR-LUT训练部分从DIV2K数据集中抠出 48 × 48 48\times 48 48×48的patch作为网络训练的输入size,输入通道数为3。
- 第1~6层是卷积层:常规的特征提取,其中除了第一层是用 2 × 2 2\times 2 2×2的卷积核提取的,其余几层都是使用了大小为 1 × 1 1\times 1 1×1的滤波器。需要注意的是最后一层卷积层是要输出通道数为 r 2 r^2 r2的feature map。
- 第7层是亚像素卷积层:由于上一层输出了
r
2
r^2
r2个feature map,根据ESPCN论文提出的亚像素卷积层来作为SR中的上采样部分可以减少模型训练复杂度的同时提高了效率,并且PyTorch中有关于亚像素卷积曾的实现——
torch.nn.PixelShuffle(r)
,具体参考我的另一篇PyTorch之PixelShuffle,其作用就是将输入feature map扩展成高和宽 r r r倍的输出feature map: ( r 2 , H , W ) → ( 1 , r H , r W ) (r^2, H,W)\to (1, rH, rW) (r2,H,W)→(1,rH,rW) - 最后一层是输出层。输出图像的格式为 ( b a t c h , 3 , 48 r , 48 r ) (batch, 3, 48r, 48r) (batch,3,48r,48r),其中 r r r为SR缩放倍率。
Note:
- 论文中写的 2 × 2 2\times 2 2×2小块感受野是在第二部分存表部分展现的,而训练部分还是常规的图像输入,旨在训练一个常规的CNN超分模型,这和大部分SR网络是类似的。但由于这个网络的参数要在存表部分用于较小的感知野,所以相较以往的SR结构,SR-LUT的网络结构较简单,即深度较浅,宽度较短。
- 卷积层参数例如(3,64,2,2,1,0)表示输入通道数3,输出通道数64,卷积核 2 × 2 2\times 2 2×2,stride=1,padding=0。
- 输入层
(
48
+
1
)
×
(
48
+
1
)
(48+1)\times (48+1)
(48+1)×(48+1)的输入,是为了保持输出为
48
×
48
48\times 48
48×48的大小,因为有一个
2
×
2
2\times 2
2×2的卷积核存在,作者的处理方法就是在输入前对图像进行pad填充成
49
×
49
49\times 49
49×49,使用的是Pytorch的
torch.pad
函数(填充模式是镜像模式)。 - 网络的输入图像是做了归一化的。
1.2 训练过程解析
为了扩大感受野而不增加LUT存储量,作者采用了自集成(self-ensemble),不同于EDSR中将自集成应用于测试中来提高重建表现力,SR-LUT作者将此技巧用于训练中,实验证明该方式确实有助于提升图像整体表现力。
文中采用了4种方式,分别是原图、旋转90°、旋转180°、旋转270°来增强图像,对每一种增强都将输入图像
x
i
x_i
xi按照先变换
R
R
R再输入网络再逆变换
R
−
1
R^{-1}
R−1成放大后的
H
R
HR
HR图像的顺序去训练。
用公式来表达:
y
i
^
=
1
4
∑
j
=
0
3
R
j
−
1
(
f
(
R
j
(
x
i
)
)
)
\hat{y_i} = \frac{1}{4}\sum^3_{j=0}R_j^{-1}(f(R_j(x_i)))
yi^=41j=0∑3Rj−1(f(Rj(xi)))
然后将自集成的结果与Ground做MSE-Loss,然后梯度下降更新模型参数:
L
o
s
s
=
∑
i
l
(
y
i
^
,
y
i
)
Loss = \sum_il(\hat{y_i}, y_i)
Loss=i∑l(yi^,yi)
Note:
- 我们再来总结一下这部分的训练:从DIV2K数据集中取出batch张图片,每一张图片都要进行4次的增强操作,并进行pad填充,之后再输入SR-LUT网络,将输出的结果进行之前增强的逆操作输出 H R HR HR图片,将这些图片取平均得到 y ^ \hat{y} y^,然后和Ground Truth(标签)做loss,从而可以更新模型参数,让模型学会如何重建 L R LR LR图片。
- 关于填充,PyTorch采用torch.pad函数来处理,关于这个函数的解析,可参考PyTorch碎片:F.pad的图文透彻理解。
- 关于旋转,PyTorch采用
torch.rot90
函数来处理,关于这个函数的解析,可借鉴我的另一篇PyTorch之rot函数。
总结一下:
- 这一部分只是和之前的SR算法一样,去训练一个可以将 L R LR LR重建成 H R HR HR的超分网络,其输入是一张图片或者图片的patch,训练的目的就是找到一个函数 f f f,它可以实现 L R → H R LR\to HR LR→HR。
- 训练的结果就是得到一个模型,我们将他保存下来,供下一环节的存表部分使用。
2 存表部分
这部分为了方便讨论,令SR缩放系数 r = 4 , W = 2 4 r=4,W=2^4 r=4,W=24。
2.1 表格构建
在SR-LUT论文中,作者设置了3个SR-LUT变体:Ours-V、Ours-F、Ours-S,分别代表感受野为2D、3D以及4D时候的SR-LUT。为了方便讨论,接下来我们只讨论4D感受野(即 2 × 2 2\times 2 2×2)下的SR-LUT结构。
SR-LUT表格的构建是以输入像素的像素值为索引,以SR-LUT网络输出结果(不做自集成)为内容构建的。
理论情况下,对于
2
×
2
2\times 2
2×2感受野,表格的索引一共有
(
2
8
)
(
2
×
2
)
(2^8)^{(2\times 2)}
(28)(2×2)种可能,且每一种可能都需要存
r
2
=
16
r^2=16
r2=16个8bits的值,因此存储量为(
1
G
B
=
2
30
B
1GB=2^{30}B
1GB=230B):
(
2
8
)
(
2
×
2
)
×
r
2
=
64
G
B
(2^8)^{(2\times 2)} \times r^2 = 64GB
(28)(2×2)×r2=64GB这样的存储量是非常大的,当我们在手机端执行的时候,要从这么大的表里去查找
L
R
LR
LR对应的
H
R
HR
HR像素值,显然是不可能的,因此为了减小Full-LUT的存储量,作者引入Sampled-LUT,即我们引入采样间隔
W
W
W,将
[
0
−
255
]
[0-255]
[0−255]这个区间按采样间隔分开来,这样就只剩下了
2
4
+
1
2^4+1
24+1种像素值,分别是
{
0
,
16
,
32
,
64
,
80
,
96
,
112
,
128
,
144
,
160
,
176
,
192
,
208
,
224
,
240
,
255
}
\{0, 16, 32,64,80,96,112,128,144,160,176,192,208,224,240,255\}
{0,16,32,64,80,96,112,128,144,160,176,192,208,224,240,255}。这种思想就类似于直方图统计,对于RGB一共
2
24
2^{24}
224位的图像,如果统计每个颜色的像素个数,那就完蛋了,因此常用的做法是,规定一个区间,将颜色相近的放到一个区间之内,凡是在同一个区间的都当成是一个颜色来看待,然后进行像素值统计。那么这里也是一样,比如我们将像素值在
[
0
,
15
]
[0,15]
[0,15]之内的像素都当成是一个像素值。因此Sampled-LUT的内存消耗为(
1
M
B
=
2
20
B
1MB=2^{20}B
1MB=220B):
(
2
4
+
1
)
(
2
×
2
)
×
r
2
=
1.274
M
B
(2^4+1)^{(2\times 2)} \times r^2 = 1.274MB
(24+1)(2×2)×r2=1.274MB
那么接下来我们看看源码是如何创建这个表的,接下来的2.2节介绍创建之后,如何存表。
base = torch.arange(0, 257, 2**SAMPLING_INTERVAL) # [0, 16, 32, ..., 255] 1D感受野
base[-1] -= 1
L = base.size(0) # 17
first = base.cuda().unsqueeze(1).repeat(1, L).reshape(-1) # 17*17 0 0 0... |16 16 16... |...|255 255 255...
second = base.cuda().repeat(L) # 17*17 0 16 32 .. 255|0 16 32 ... 255|...|0 16 32 ... 255
onebytwo = torch.stack([first, second], 1) # [17*17, 2] 2D感受野
third = base.cuda().unsqueeze(1).repeat(1, L*L).reshape(-1) # 17*17*17
onebytwo = onebytwo.repeat(L, 1)
onebythree = torch.cat([third.unsqueeze(1), onebytwo], 1) # [17*17*17, 3] 2D感受野
fourth = base.cuda().unsqueeze(1).repeat(1, L*L*L).reshape(-1) # 17*17*17*17
onebythree = onebythree.repeat(L, 1)
onebyfourth = torch.cat([fourth.unsqueeze(1), onebythree], 1) # [17*17*17*17, 4] 4D感受野
其实也很简单,就是先调用torch.arange(),然后创建相同值组成的数组和以
W
W
W为间隔的像素值数组,通过不断堆叠产生1D、2D、3D、4D感受野下的LUT,为了更直观体现这一过程,我画了一张图来可视化,如下图所示:
从上图可知,我们将
(
2
4
+
1
)
2
×
2
=
83521
(2^4+1)^{2\times 2}=83521
(24+1)2×2=83521种像素组合可能按顺序罗列了出来,每一种可能都可以在上图中找到。
2.2 存表过程解析
存表过程分三部分:
- 我们需要将上图reshape成 ( 1 7 4 , 1 , 2 , 2 ) (17^4, 1, 2, 2) (174,1,2,2)的图片格式,来表示将83521张size为 2 × 2 2\times 2 2×2的像素输入,并进行归一化处理。
- 我们将所有的83521张小图片进行batch=835的划分,然后输入到我们上一节花了一段时间训练好的SR-LUT网络模型中去,所得结果就是每个 2 × 2 2\times 2 2×2小图片对应的 H R HR HR图片。
- 最后将所有batch的输出结果合并并保存起来,这个结果就是我们最终所需要LUT,也就是供给测试用的LUT,它的大小为 ( 1 7 4 , 1 , r , r ) = ( 83521 , 1 , 4 , 4 ) (17^4, 1, r, r)=(83521, 1, 4, 4) (174,1,r,r)=(83521,1,4,4),从这个结果可以看出,将每一个 2 × 2 2\times 2 2×2感受野看成一个整体,其对应了一个 r × r r\times r r×r的 H R HR HR像素块。
对于每一个
2
×
2
2\times 2
2×2的小图片经过网络之后输出
r
×
r
r\times r
r×r的
H
R
HR
HR块,其前向如下所示:
故网络的输入输出格式为:
b
a
t
c
h
×
1
×
2
×
2
→
b
a
t
c
h
×
1
×
r
×
r
batch\times 1\times 2\times 2 \to batch \times 1 \times r\times r
batch×1×2×2→batch×1×r×r
我们将所有batch个
r
×
r
r\times r
r×r堆叠起来存到np.array中,这个数组就是我们的LUT,我们reshape一下格式,故最终的LUT的shape为:
(
1
7
4
,
r
2
)
(17^4, r^2)
(174,r2)。
3 测试(读表)部分
这部分为了方便讨论,我们只讨论4D感受野(即
2
×
2
2\times 2
2×2)下的SR-LUT插值部分,且
r
=
4
r=4
r=4。
由于为了节约LUT的存储消耗,作者采用了Sampled-LUT,即缩小图像采样的范围:从
[
0
,
255
]
[0,255]
[0,255]一共256个值的全LUT降到以
W
=
2
4
W=2^4
W=24为采样间隔形成的共17个采样值。这样虽然带来了资源优化,但也使得在测试的时候, 对于
0
255
0~255
0 255任意一个像素值,可能无法找到相对应的
H
R
HR
HR像素块。因此我们就需要插值来解决这个问题,对于非采样点(即不在17个采样点之中),我们会利用采样点的像素,通过一些插值算法来求得非采样点对应的高分辨率
r
×
r
r\times r
r×r块。
3.1 插值过程解析
在正式分析之前,我们先理一下几个点”
- 查表部分的核心思想:对于测试集(比如Set5)每一张 H × W H\times W H×W图片,当我遇到了坐标为 ( x , y ) (x,y) (x,y)处的像素值,我先遍历一遍以 ( x , y ) (x,y) (x,y)为左上像素点,其右边1格、下面1格、右下1个格共4个像素点的值,然后利用某种插值办法来找到sampled-LUT中的对应的某个 r × r r\times r r×r块来作为当前像素点 ( x , y ) (x,y) (x,y)重建之后 H R HR HR图像,这样就完成 ( 1 × 1 ) → ( r × r ) (1\times 1)\to (r\times r) (1×1)→(r×r)的放大;同理遍历每一个 ( x , y ) (x,y) (x,y)之后,就获得了 r H × r W rH\times rW rH×rW图像。
- 查表过程的输入是从存表步骤里我们最终获得的sampled-LUT,这个LUT有 ( 2 4 + 1 ) 2 × 2 (2^4+1)^{2\times 2} (24+1)2×2行,有 r 2 r^2 r2列,即里面每一行存的都是16个 H R HR HR像素。
- 关于插值的选用问题,根据论文中的Table 2所示:
作者在源码中对于2D、3D、4D分别采用了三角插值(三个顶点)、四面体插值(四个顶点)、4-单形插值(需要5个顶点),由于本文只讨论4D的感受野,所以接下来介绍4-simplex插值。 - 关于数据流问题,输入是一整张测试图片,比如Set5的第一张就是个小孩子:,这是一张
3
×
128
×
128
3\times 128\times 128
3×128×128的图片。①我们需要先利用自集成技巧,这里和ESDN做法一样,将self-ensemble用在测试环节提高表现力,这里和训练部分一样,对每一张图片使用4种自增强旋转操作;
②然后对它进行pad填充成 ( 128 + 1 ) × ( 128 + 1 ) (128+1)\times (128+1) (128+1)×(128+1),这和SR-LUT模型训练环节是一样的,为的是配合模型在亚像素采样之前保持在 128 × 128 128\times 128 128×128的大小(其他图片也是一样);
③不同于网络前向传输,之后是进行单形插值操作(实现 L R → H R LR\to HR LR→HR),并做逆增强操作;
④之后将4种增强的结果进行平均之后就是我们输出的高分辨率图像了。
接下来我们重点分析下插值部分,其涉及到三个难点:
- 如何找到一个 2 × 2 2\times 2 2×2的感受野。
- 如何查找到LUT中对应的 r × r r\times r r×r块问题(即索引是如何设置的)。
- 单形插值是如何进行的。
首先是Q1:如何找到一个
2
×
2
2\times 2
2×2感受野
我们看看源码:
img_a1 = img_in[:, 0:0+h, 0:0+w] // q
img_b1 = img_in[:, 0:0+h, 1:1+w] // q
img_c1 = img_in[:, 1:1+h, 0:0+w] // q
img_d1 = img_in[:, 1:1+h, 1:1+w] // q
img_in是输入的图像;img_b1、img_c1、img_d1分别是原图img_a1进行平移得到:
之所以这样做的原因是:
- 我们的LUT是根据输入的一个 2 × 2 2\times 2 2×2感受野,根据4个像素值情况来决定LUT的索引,然后找到对应的 r × r r\times r r×r像素块,所以每一次查表都必须知道 2 × 2 2\times 2 2×2块的像素值。
- 插值算法中需要知道每一个 2 × 2 2\times 2 2×2块中的4个像素值。
鉴于上述原因,我们利用pad的优势得到平移后的3张图像,那么每次只要取4张图像中相同坐标的值,就相当于对原图中一个 2 × 2 2\times 2 2×2感受野进行操作了,比如上图中红色虚线框中,我们要对这个框进行查表,只要取出四张表在 ( 0 , 0 ) (0,0) (0,0)处的值(如右边3张图的第0个像素值所示)就行了,因为这就相当于得到了原图中相邻的 2 × 2 2\times 2 2×2像素块。
其次是Q2:如何索引到LUT中的
H
R
HR
HR像素块
在测试环节我们拿到的LUT有
1
7
4
17^4
174行,每一行的内容都是经过网络处理过的
H
R
HR
HR像素,那么如何根据输入的
2
×
2
2\times 2
2×2来找到索引,从而
1
7
4
17^4
174个索引值中取出相应的
r
×
r
r\times r
r×r块呢?
参照本文2.1节表格构建部分,我们可以通过如下方式取得索引:(设采样间隔
W
=
2
4
W=2^4
W=24且像素都是采样点像素)
对于一张图像中的某一个坐标对
(
x
,
y
)
(x, y)
(x,y),可以通过其余三个平移图像获取
2
×
2
2\times 2
2×2感受野的信息,其对应的LUT索引由输入感受野像素组成:
img_a1[x,y].astype(np.int_)*L*L*L
+ img_b1[x,y].astype(np.int_)*L*L
+ img_c1[x,y].astype(np.int_)*L
+ img_d1[x,y].astype(np.int_) # L=2**(8-4) + 1
简单说明一下,这里每一个像素乘以不同个数的 L L L是由像素在 2 × 2 2\times 2 2×2中的位置决定的,因为LUT表中,从左到右的四列经过源码中
input_tensor = onebyfourth.unsqueeze(1).unsqueeze(1).reshape(-1, 1, 2, 2).float() / 255.0
这样的变型之后,会产生一种对应关系: 2 × 2 2\times 2 2×2的感受野,左上像素对应于四列中的第0列;右上像素对应于四列中的第1列;左下像素对应于四列中的第2列;右下像素对应于四列中的第3列(这种关系是PyTorch中数组低层存储关系决定的)。因此决定左上像素索引的就是 i m g _ a 1 [ x , y ] . a s t y p e ( n p . i n t _ ) ∗ L ∗ L ∗ L img\_a1[x,y].astype(np.int\_)*L*L*L img_a1[x,y].astype(np.int_)∗L∗L∗L;决定右上像素像素索引的就是 i m g _ b 1 [ x , y ] . a s t y p e ( n p . i n t _ ) ∗ L ∗ L img\_b1[x,y].astype(np.int\_)*L*L img_b1[x,y].astype(np.int_)∗L∗L;决定左下像素索引的就是 i m g _ c 1 [ x , y ] . a s t y p e ( n p . i n t _ ) ∗ L img\_c1[x,y].astype(np.int\_)*L img_c1[x,y].astype(np.int_)∗L;决定右下像素索引的是 i m g _ d 1 [ x , y ] . a s t y p e ( n p . i n t _ ) img\_d1[x,y].astype(np.int\_) img_d1[x,y].astype(np.int_)。
最后是Q3:单形插值
我们遍历输入图像中每一个像素点 ( x , y ) (x,y) (x,y),其中 x , y x,y x,y是坐标。对于每一个像素,我们通过插值算法获取其对应的一个 r × r r\times r r×r的像素块。遍历完 H H H行 W W W列之后,就获取了 H × W H\times W H×W个 r × r r\times r r×r块,经过reshape之后,就得到了 r H × r W rH\times rW rH×rW的高分辨率图像。
单行插值的步骤:
- 提取每个8bits像素的MSB(高4位)和LSB(低4位)。
- 获取顶点(读LUT),对于4D单形插值需要获取5个顶点。
- 获取权值以及找到5个最佳顶点。
- 根据权值和顶点获得最终的插值结果。
如下图所示就是三角插值的顶点和权值:
紫色点就是非采样点,灰色点就是采样格点,
w
0
,
w
1
,
w
2
w_0,w_1,w_2
w0,w1,w2就是权值。
3.1.1 提取MSBs和LSBs
# 获取MSBs
img_a1 = img_in[:, 0:0+h, 0:0+w] // q=W
img_b1 = img_in[:, 0:0+h, 1:1+w] // q
img_c1 = img_in[:, 1:1+h, 0:0+w] // q
img_d1 = img_in[:, 1:1+h, 1:1+w] // q
img_a2 = img_a1 + 1
img_b2 = img_b1 + 1
img_c2 = img_c1 + 1
img_d2 = img_d1 + 1
# 获取LSBs
fa_ = img_in[:, 0:0+h, 0:0+w] % q
fb_ = img_in[:, 0:0+h, 1:1+w] % q
fc_ = img_in[:, 1:1+h, 0:0+w] % q
fd_ = img_in[:, 1:1+h, 1:1+w] % q
Note:
- 对于一个8bits输入数据,前四位取商,后四位取余,相信学过大学C语言课的很熟悉。
img_a2、img_b2、img_c2、img_d2
用于后续求顶点。
3.1.2 查表求顶点
对于一个4D输入,一共由
2
4
2^4
24个顶点,但在单形插值中实际只需要5个顶点,至于要哪几个顶点要取决于
2
×
2
2\times 2
2×2中四个像素值LSBs(即步骤一求得的fa_、fb_、fc_、fd_)之间的大小关系。
Note:
- 上图所示就是求16个顶点的源码,规则是如果顶点的下标是二进制0,就采用
img_a1、img_b1、img_c1、img_d1
;反之为二进制1,就采用img_a1、img_b1、img_c1、img_d1
。此外顶点的四位中,从左到右分别代表 2 × 2 2\times 2 2×2感受野的左上、右上、左下、右下像素。 - weight就是LUT(np数组),求顶点的过程就是查表的过程,两者是等价的,查表需要 2 × 2 2\times 2 2×2输入块的所有像素值信息。
- 顶点就是输入图像中每个点 ( x , y ) (x,y) (x,y)对应LUT中的一个 r × r r\times r r×r高分辨率像素块。
- 顶点reshape之后的格式为: C × H × W × r × r C\times H\times W\times r\times r C×H×W×r×r。
3.1.3 求权值
第三步就是求出每个顶点对应的权值,至于权值的大小以及选用哪5个顶点取决于
2
×
2
2\times 2
2×2感受野的4个像素值的LSBs,具体实现见上述代码。
Note:
- 最后对于每一个像素
(
x
,
y
)
(x,y)
(x,y),求得一个插值结果
out[c,y,x]
,这是一个 r × r r\times r r×r的像素块,其中 ( x , y ) (x,y) (x,y)是这个 2 × 2 2\times 2 2×2感受野的左上像素。
3.1.4 输出HR图像
根据公式:
o
u
t
[
c
,
y
,
x
]
=
1
W
∑
i
=
0
4
w
i
O
i
.
out[c,y,x] = \frac{1}{W}\sum^4_{i=0}w_i O_i.
out[c,y,x]=W1i=0∑4wiOi.求得非采样点对应的高分辨率
r
×
r
r\times r
r×r块,那么对于图像的所有像素点都采用上述公式就可以获取整幅
L
R
LR
LR图像对应的
H
R
HR
HR图像:
out = np.transpose(out, (0, 1, 3, 2, 4)).reshape((C, H*upscale, img_a1.shape[2]*upscale))
然后进行逆增强就可以输出出去了,一共4次的逆增强求平均之后就是我们最后可以保存下来的缩放倍率 r = 4 r=4 r=4的 H R HR HR图像了!!!