Deep Bilateral Learning for Real-Time Image Enhancement
ABSTRACT
对于移动端的图像处理,性能功耗是一个非常大的挑战,这篇文章提出了一种新的网络架构可以实现实时的图像处理,这种网络架构是基于 bilateral grid processin 和 local affine color transforms,利用成对的输入-输出图像,文章作者训练一个卷积神经网络去预测 bilateral space 中的局部仿射变换模型的系数,这个网络结构可以同时在局部以及全局层面,基于图像内容去估计一个想要的图像变换。在实际运行的时候,先将高分辨率的图像缩小送入网络,网络再产生一组在 bilateral space 的仿射变换系数,这些仿射变换通过一种保边的形式上采样回高分辨,这些高分辨率的仿射变换就可以直接作用于原始的高分辨率图像。文章作者指出,这种算法可以在移动端实现毫秒级的运算速度,同时可以灵活 cover 各种图像变换操作,甚至包括复杂的图像 PS 效果。
INTRODUCTION
现在消费级的终端产品拍摄的图像或者视频的分辨率越来越高,图像后处理的算法复杂度也随之升高,所以为了满足性能功耗要求,最终运行在终端产品上的代码都需要有经验的程序员仔细地去优化。图像增强是图像处理里面非常常用的一种方式,但是图像增强都有很强的主观性,很多时候,图像增强的效果都是基于人的美学感知去调试出来的,这种效果很难用传统的处理流程来复现,这篇文章提出用基于学习的方法来实现。
基于深度学习的图像处理已经有很多相关的工作了,之前的这些网络,在处理图像时候的复杂度基本和图像的分辨率成线性关系,因为之前的网络结构,其包含的算子都是要在全尺寸上运行。虽然这些网络的效果也不错,但是代价就是算法复杂度很高,无法满足很多应用场景下的实时性要求。这篇文章提出的网络结构可以运行地更快,与常见的网络相比,运行速度可以快上成百上千倍。
这个网络结构主要包括三个部分,1):网络绝大多数的预测都是在一个低维的 bilateral grid 里面进行,这个预测后面会详细介绍,为什么称为 bilateral grid,因为这有点类似 bilateral filter,同时考虑了像素的位置与像素值。2): 网络的输出是融合权重,而不是最终的图像,这个也是借鉴了之前的一些工作,发现预测融合权重的效果比直接预测输出的效果要好。3): 最终的 loss 是在全分辨率下进行计算的,这样可以保证网络去评估对全分辨率的图像的影响。
OUR ARCHITECTURE
这个模型包含两个 stream,一个是低分辨率的 stream,一个是全分辨率的 stream,低分辨率的 stream 负责预测局部仿射变换,文章中指出,图像增强任务需要同时考虑图像的局部信息和全局信息,所以低分辨率的 stream 又同时包含了一个 local path 和一个全局 path,这两个 path 在后端融合输出最终的仿射变换。全分辨率的 stream 利用低分辨率 stream 的仿射变换信息,对输入图像在全分辨率上进行仿射变换,文章利用一种 slicing 的方式,可以用很低的运算代价实现高分辨率的仿射变换。
Low-resolution prediction of bilateral coefficients
低分辨率的输入 I ~ \tilde{I} I~ 文章中是固定为 256 X 256,通过一系列的卷积提取 low-level 的特征同时压缩分辨率,然后在末端进行分流,一路用全卷积算子获取局部特征,另外一路用卷积和全连接去学习一个固定尺寸的全局特征,这两路的输出,最后再融合成一组特征,这组特征通过一些线性层最后成为包含仿射系数的 bilateral grid。
Low-level features
首先利用卷积算子获取 low level 的特征, S 0 : = I ~ S^{0}: = \tilde{I} S0:=I~
S c i [ x , y ] = σ ( b c i + ∑ x ′ , y ′ , c ′ w c c ′ i [ x ′ , y ′ ] S c ′ i − 1 [ s x + x ′ , s y + y ′ ] ) S_{c}^{i}[x, y] = \sigma \left( b_{c}^{i} + \sum_{x', y', c'} w_{cc'}^{i} [x', y']S_{c'}^{i-1}[sx+x', sy+y'] \right) Sci[x,y]=σ bci+x′,y′,c′∑wcc′i[x′,y′]Sc′i−1[sx+x′,sy+y′]
i = 1 , 2 , . . . , n S i=1, 2, ..., n_S i=1,2,...,nS 表示层数, c , c ′ c, c' c,c′ 表示每一层的通道数, w ′ w' w′ 表示卷积的权重, b i b^{i} bi 表示 biases, x ′ , y ′ x', y' x′,y′ 表示邻域的范围,激活函数用的是 RELU。文章中 n s n_s ns 是 4。
Local features path
提取的 low level 特征 S n s S^{n_s} Sns 在后端会分成两路处理,一路是利用 stride 为 1 的全卷积算子,进行卷积处理,在保持空间分辨率不变的同时,扩大了感受野,从而能够更好地提取语义特征。
Global features path
另外一路就是 global feature path,与 local feature path的输入一样,都是 low level feature S n s S^{n_s} Sns,这一路包含两个 stride 为 2 的卷积层,然后再加上三个全连接层,这一路最终输出一个 64 维的全局特征。
Fusion and linear prediction
接下来,就是将两路的输出进行融合,
F c [ x , y ] = σ ( b c + ∑ c ′ w c c ′ ′ G c ′ n G + ∑ c ′ w c c ′ L c ′ n L [ x , y ] ) F_c[x, y] = \sigma \left( b_c + \sum_{c'} w'_{cc'} G_{c'}^{n_G} + \sum_{c'}w_{cc'}L_{c'}^{n_L}[x, y] \right) Fc[x,y]=σ(bc+c′∑wcc′′Gc′nG+c′∑wcc′Lc′nL[x,y])
这个最终形成了一个 16 × 16 × 64 16 \times 16 \times 64 16×16×64 的 feature map,这个 feature map 再通过一些线性变换,得到 16 × 16 × 96 16 \times 16 \times 96 16×16×96 的 feature map:
A c [ x , y ] = b c + ∑ c ′ F c ′ [ x , y ] w c c ′ ′ A_{c}[x, y] = b_{c} + \sum_{c'} F_{c'}[x, y]w'_{cc'} Ac[x,y]=bc+c′∑Fc′[x,y]wcc′′
Image features as a bilateral grid
到目前为止,我们已经介绍了如何获取 feature map,接下来,我们介绍如何将 feature map 与 bilateral grid 联系起来,上面我们已经说过,我们获得了一个 16 × 16 × 96 16 \times 16 \times 96 16×16×96 的 feature map,我们可以把这个 feature map 看成是一个多通道的 bilateral grid,如下所示:
A d c + z [ x , y ] ⟷ A c [ x , y , z ] A_{dc+z}[x, y] \longleftrightarrow A_c[x, y, z] Adc+z[x,y]⟷Ac[x,y,z]
d = 8 d = 8 d=8 表示 grid 的深度,这样 A A A 可以看成是一个 16 × 16 × 8 16 \times 16 \times 8 16×16×8 的 bilateral grid,而每个 grid 包含 12 个系数,可以等价为一个 $ 3 \times 4$ 的仿射变换矩阵。
Upsampling with a trainable slicing layer
接下来,就是这篇文章最为创新的地方,通过一个 slicing 的操作,可以将低维的 feature map 和一个高维的 guidance map 结合,最后输出高维的 feature map,如下所示:
A ˉ c [ x , y ] = ∑ i , j , k τ ( s x x − i ) τ ( s y y − j ) τ ( d ⋅ g [ x , y ] − k ) A c [ i , j , k ] \bar{A}_c[x, y] = \sum_{i,j,k} \tau(s_x x - i) \tau(s_y y - j)\tau(d \cdot g[x, y] - k)A_c[i,j,k] Aˉc[x,y]=i,j,k∑τ(sxx−i)τ(syy−j)τ(d⋅g[x,y]−k)Ac[i,j,k]
其中, τ \tau τ 是一个线性插值函数, τ ( ⋅ ) = max ( 1 − ∣ ⋅ ∣ , 0 ) \tau(\cdot) = \max(1 - |\cdot|, 0) τ(⋅)=max(1−∣⋅∣,0), s x , s y s_x, s_y sx,sy 是 grid 的宽高相对于全尺寸图像的比例,文章中,grid 的分辨率固定为 16 × 16 16 \times 16 16×16 , d = 8 d = 8 d=8,如果全尺寸的分辨率为 H × W H \times W H×W, 那么, s x = 16 / W , s y = 16 / H s_x = 16/W, s_y = 16/H sx=16/W,sy=16/H, x ∈ [ 0 , W − 1 ] , y ∈ [ 0 , H − 1 ] x \in [0, W-1], y \in [0, H-1] x∈[0,W−1],y∈[0,H−1], s x x , s y y s_x x, s_y y sxx,syy 的范围都是在 0-16 左右, g [ x , y ] g[x,y] g[x,y] 的取值范围是 0-1, d ⋅ g [ x , y ] d \cdot g[x, y] d⋅g[x,y] 的范围大概是 0-8, 如果 i ∈ [ 0 , 15 ] , j ∈ [ 0 , 15 ] , k ∈ [ 0 , 7 ] i\in[0, 15],j \in[0, 15], k \in [0, 7] i∈[0,15],j∈[0,15],k∈[0,7],再结合线性插值函数的定义,里面的很多值都是 0, 这么算下来,这个 slicing 算子似乎很稀疏,可以理解成是将全分辨的大图分成了 16 × 16 16 \times 16 16×16 块,从 [0, 0] 到 [15,15],当处于 [0, 0] 块的时候, i , j i, j i,j 只有等于 0, 0 的时候,slicing 才有值,其它时候都是 0, 处于 [5, 5] 块的时候,那么 i , j i, j i,j 在 0-4 之间 slicing 的输出都是有值的,剩下的也都是 0,这个 slicing 操作看起来,似乎越往后面的块,算法的复杂度越高,最复杂的情况下, 一个点需要进行 16 × 16 × 8 16 \times 16 \times 8 16×16×8 这么多次的插值运算。最后的输出是一个 12 通道的全分辨率的 feature map A ˉ c , c ∈ [ 0 , 11 ] \bar A_c , c \in [0, 11] Aˉc,c∈[0,11]。
Assembling the full-resolution output
最后一部分,就是介绍如何获得最终的输出,首先是提取一个单通道的 guidance map g [ x , y ] g[x, y] g[x,y],这个 guidance map 文章中是利用一个线性变换,将原始全分辨率的 RGB 三通道的输入变换成一个单通道的 guidance map:
g [ x , y ] = b + ∑ c = 0 2 ρ c ( M c T ⋅ ϕ c [ x , y ] + b c ′ ) g[x, y] = b + \sum_{c=0}^{2} \rho_c(\mathbf{M}_{c}^{T} \cdot \phi_{c}[x, y] + b'_{c}) g[x,y]=b+c=0∑2ρc(McT⋅ϕc[x,y]+bc′)
其中, M c T \mathbf{M}_{c}^{T} McT 是一个 3 × 3 3 \times 3 3×3 的色彩变换矩阵, ϕ c \phi_c ϕc 表示输入的通道,
ρ c ( x ) = ∑ i = 0 15 a c , i max ( x − t c , i , 0 ) \rho_{c}(x) = \sum_{i=0}^{15} a_{c, i} \max(x - t_{c, i}, 0) ρc(x)=i=0∑15ac,imax(x−tc,i,0)
这些参数都是可以通过学习得到的,最后的输出就是将输入的三通道 ϕ c ′ \phi_{c'} ϕc′ 与 A ˉ \bar{A} Aˉ 进行融合。
O c [ x , y ] = A ˉ n ϕ + ( n ϕ + 1 ) c + ∑ c ′ = 0 n ϕ − 1 A ˉ c ′ + ( n ϕ + 1 ) c [ x , y ] ϕ c ′ [ x , y ] O_c[x, y] = \bar{A}_{n_{\phi} + (n_{\phi}+1)c} + \sum_{c'=0}^{n_{\phi} -1} \bar{A}_{c' + (n_{\phi}+1)c}[x, y]\phi_{c'}[x, y] Oc[x,y]=Aˉnϕ+(nϕ+1)c+c′=0∑nϕ−1Aˉc′+(nϕ+1)c[x,y]ϕc′[x,y]
文章中的输入是一个三通道的 RGB 图像,所以 n ϕ = 3 , c = 0 , 1 , 2 , c ′ = 0 , 1 , 2 n_{\phi} = 3, c=0,1,2, c'=0,1,2 nϕ=3,c=0,1,2,c′=0,1,2, 这个融合就类似下面的一个变换:
R ′ = a 1 R + a 2 G + a 3 B + b 1 G ′ = a 4 R + a 5 G + a 6 B + b 2 B ′ = a 7 R + a 8 G + a 9 B + b 3 R' = a_1 R + a_2 G + a_3 B + b1 \\ G' = a_4 R + a_5 G + a_6 B + b2 \\ B' = a_7 R + a_8 G + a_9 B + b3 R′=a1R+a2G+a3B+b1G′=a4R+a5G+a6B+b2B′=a7R+a8G+a9B+b3
R , G , B R, G, B R,G,B 就对应 ϕ c ′ [ x , y ] \phi_{c'}[x, y] ϕc′[x,y], R ′ , G ′ , B ′ R', G', B' R′,G′,B′ 就对应 O c [ x , y ] O_c[x, y] Oc[x,y], 这些系数就对应 A ˉ \bar{A} Aˉ 的 12 个通道。
最后的 loss 比较简单直接:
L = 1 D ∑ i ∣ ∣ O i − I i ∣ ∣ 2 \mathcal{L} = \frac{1}{\mathcal{D}} \sum_{i} || O_i - I_i ||^2 L=D1i∑∣∣Oi−Ii∣∣2
I i I_i Ii 表示目标效果, O i O_i Oi 表示网络的预测效果。
总得来说,这篇文章的创新在于引入了一个 bilateral grid 以及 slicing 的一个操作。但是看这个 slicing 的操作,算法的复杂度也不低,不知道为什么会这么快。