ISTA-Net: Interpretable Optimization-Inspired Deep Network for Image Compressive Sensing
关于论文《ISTA-Net》的研究
本篇内容为关于2018年CVPR论文《ISTA-Net: Interpretable Optimization-Inspired Deep Network for Image Compressive Sensing》的个人理解。该篇论文第一作者为北京大学副教授张建。该论文为deep unfolding方法在图像压缩感知方面的成功应用,搭建出了在“层”一级可解释的神经网络ISTA-Net。该网络的每一层对应于迭代软阈值算法(ISTA-Net)的一次迭代运算。该方法很好的结合了传统迭代算法和深度网络算法的优势,既大大提高了计算效率,又赋予了网络明确的可解读性。该文行文流畅、用词精准而不晦涩,实为一篇佳作。
要想很好地理解这篇论文,首先要从“迭代软阈值算法”的概念入手。
1. 迭代软阈值算法(ISTA)
我们常见的优化问题的目标函数可以表示为:
X
^
=
a
r
g
m
i
n
∣
∣
X
−
B
∣
∣
2
2
+
λ
∣
∣
X
∣
∣
1
(1)
\hat{X}=arg~ min||X-B||^{2}_{2}+\lambda||X||_{1}\tag{1}
X^=arg min∣∣X−B∣∣22+λ∣∣X∣∣1(1)
这里
X
^
\hat{X}
X^只最优解,
∣
∣
⋅
∣
∣
2
||\cdot||_{2}
∣∣⋅∣∣2代表二范数,
∣
∣
⋅
∣
∣
1
||\cdot||_{1}
∣∣⋅∣∣1代表一范数,
λ
\lambda
λ为一个常数。(1)式中等号的右边两项,前者
a
r
g
m
i
n
∣
∣
X
−
B
∣
∣
2
2
arg~ min||X-B||^{2}_{2}
arg min∣∣X−B∣∣22为保真项, 用来衡量
X
X
X和
A
A
A的相似程度,后者
λ
∣
∣
X
∣
∣
1
\lambda||X||_{1}
λ∣∣X∣∣1为正则项。保真项的意义很好理解,因为我们的目的就要找到和
B
B
B最接近的
X
X
X,所以要是这一项尽可能得小。而正则项的意义对于初学者往往不太容易理解,在这个问题里可以这样思考:
正则项相当于是给最小化保真项这一过程中添加了额外一个约束,即仅仅“满足保真项足够小”还不够,还要在"正则项不能太大"的约束下完成,即如果找到了一组解
X
^
\hat{X}
X^, 使得
∣
∣
X
^
−
B
∣
∣
2
2
||\hat{X}-B||^{2}_{2}
∣∣X^−B∣∣22很小,但
λ
∣
∣
X
^
∣
∣
1
\lambda||\hat{X}||_{1}
λ∣∣X^∣∣1很大,这样的
X
^
\hat{X}
X^是不可取的,必须二者的和一起约束到最小。
那为什么要这么干呢?答:有好几个作用。 主要有(1)缓解过拟合; (2) 满足
X
X
X的稀疏性 (这一点在压缩感知里尤为重要)
再回到优化问题本身,究竟如何求解满足(1)式的 X ^ \hat{X} X^呢? 迭代软阈值算法(ISTA)就是一个迭代求解 X ^ \hat{X} X^的算法。迭代很好理解,那么“软阈值”是什么意思呢?下面首先解释软阈值函数的概念
- 软阈值函数(soft-threshold)
软阈值函数的表达式为
s o f t ( x , T ) = { x + T , x ≤ − T 0 , ∣ x ∣ < T x − T , x ≥ T (2) soft(x,T)=\left\{ \begin{array}{c} x+T, x\leq-T \\ 0, |x|<T \\ x-T, x\geq T\end{array}\right. \tag{2} soft(x,T)=⎩⎨⎧x+T,x≤−T0,∣x∣<Tx−T,x≥T(2)
这个长相略微奇怪的函数跟求解(1)式有什么关系呢?我们再回到(1)式,把这些向量都写开, 设 X = [ x 1 , x 2 , . . . , x N ] T X= [x_1,x_2,...,x_N]^{T} X=[x1,x2,...,xN]T, B = [ b 1 , b 2 , . . . , b N ] T B= [b_1,b_2,...,b_N]^{T} B=[b1,b2,...,bN]T,优化函数 F ( X ) F(X) F(X), 则
F ( X ) = ∣ ∣ X − B ∣ ∣ 2 2 + λ ∣ ∣ X ∣ ∣ 1 = ∑ n = 1 N ( x n − b n ) 2 + λ ∣ x n ∣ (3) F(X) = ||X-B||^{2}_{2}+\lambda||X||_{1} \\ =\sum_{n=1}^{N}(x_n-b_n)^{2}+\lambda|x_n| \tag{3} F(X)=∣∣X−B∣∣22+λ∣∣X∣∣1=n=1∑N(xn−bn)2+λ∣xn∣(3)
即求 N N N个形如 f ( x ) = ( x − b ) 2 + λ ∣ x ∣ f(x)=(x-b)^{2}+\lambda|x| f(x)=(x−b)2+λ∣x∣的函数的极小值。求这个函数的极值是不难的,直接对 x x x求导,可得
∂ f ( x ) ∂ x = 2 ( x − b ) + λ s g n ( x ) (4) \frac{\partial f(x)}{\partial x}=2(x-b)+\lambda sgn(x) \tag{4} ∂x∂f(x)=2(x−b)+λsgn(x)(4)
令其为0,即
2 ( x − b ) + λ s g n ( x ) = 0 (5) 2(x-b)+\lambda sgn(x)=0 \tag{5} 2(x−b)+λsgn(x)=0(5)
解得, x = b − λ 2 s g n ( x ) (6) x=b-\frac{\lambda}{2}sgn(x) \tag{6} x=b−2λsgn(x)(6)
s g n ( x ) sgn(x) sgn(x)的值取决于 x x x和0的大小关系,因此要分情况讨论求极值点。
a). 当 x > 0 时 x>0时 x>0时, s g n ( x ) = 1 , → sgn(x)=1,\rightarrow sgn(x)=1,→ x = b − λ 2 > 0 , → b > λ 2 x = b-\frac{\lambda}{2}>0,\rightarrow b>\frac{\lambda}{2} x=b−2λ>0,→b>2λ
b). 当 x < 0 时 x<0时 x<0时, s g n ( x ) = − 1 , → sgn(x)=-1,\rightarrow sgn(x)=−1,→ x = b + λ 2 < 0 , → b < − λ 2 x = b+\frac{\lambda}{2}<0,\rightarrow b<-\frac{\lambda}{2} x=b+2λ<0,→b<−2λ
因此, f ( x ) f(x) f(x)的极值点为
x ^ = { b + λ 2 , b ≤ − λ 2 0 , ∣ x ∣ < λ 2 b + λ 2 , b ≥ λ 2 (7) \hat{x}=\left\{ \begin{array}{c} b+\frac{\lambda}{2}, b\leq-\frac{\lambda}{2} \\ 0, |x|<\frac{\lambda}{2} \\ b+\frac{\lambda}{2}, b\geq \frac{\lambda}{2}\end{array}\right. \tag{7} x^=⎩⎨⎧b+2λ,b≤−2λ0,∣x∣<2λb+2λ,b≥2λ(7)
现在再来仔细看一眼式(7), 是不是觉得这个形式有点眼熟? 对啦! 式(7)和软阈值函数的表达式的形式(式2)是相同的。
所以,(1)式的解
X
^
\hat{X}
X^可以写成:
X
^
=
s
o
f
t
(
B
,
λ
2
)
.
(8)
\hat{X}= soft(B,\frac{\lambda}{2}). \tag{8}
X^=soft(B,2λ).(8)
到这里,迭代软阈值算法就介绍完毕了。不过这里面貌似没有迭代的操作,因此还需要结合压缩感知信号恢复的具体应用背景来研究。下面将介绍压缩感知恢复的概念,并将ISTA贯穿其中。
2. 压缩感知回复(Compressed Sensing Reconstruction)
压缩感知理论的基础知识这里不再赘述,仅从恢复压缩后的信号开始介绍。设一个线性量测
y
\mathbf{y}
y, 传统CS算法通过以下优化问题来恢复原始信号
x
\mathbf{x}
x:
a
r
g
m
i
n
1
2
∣
∣
Φ
x
−
y
∣
∣
2
2
+
λ
∣
∣
Ψ
x
∣
∣
1
(9)
arg~min \frac{1}{2}||\mathbf{\Phi x-y}||^{2}_{2}+\lambda||\mathbf{\Psi} x||_{1} \tag{9}
arg min21∣∣Φx−y∣∣22+λ∣∣Ψx∣∣1(9)
Ψ
\mathbf{\Psi}
Ψ为稀疏基矩阵,
Φ
\mathbf{\Phi}
Φ为传感矩阵, 即
Ψ
\mathbf{\Psi}
Ψ和观测矩阵(文中未显式表示)的矩阵乘积。注意,这里的符号表示和一般的压缩感知文献有区别, 一般的压缩感知文献
Φ
\mathbf{\Phi}
Φ往往单纯指观测矩阵, 而论文中明确提到
x
\mathbf{x}
x为original image, 而观测矩阵并不直接作用于原数据,而是原数据在稀疏域的映射,因此这里
Φ
\mathbf{\Phi}
Φ为传感矩阵。同样,(9)式的右边第一项为保真项,第二项为正则项。
下面的论文直接给出了两个迭代公式:
r
(
k
)
=
x
(
k
−
1
)
−
ρ
Φ
T
(
Φ
x
(
k
−
1
)
−
y
)
(10)
\mathbf{r}^{(k)}=\mathbf{x}^{(k-1)}-\rho\mathbf{\Phi}^{T}(\mathbf{\Phi} \mathbf{x}^{(k-1)}-\mathbf{y})\tag{10}
r(k)=x(k−1)−ρΦT(Φx(k−1)−y)(10)
x
(
k
)
=
a
r
g
x
m
i
n
1
2
∣
∣
x
−
r
(
k
)
∣
∣
2
2
+
λ
∣
∣
Ψ
x
∣
∣
1
(11)
\mathbf{x}^{(k)}=arg_{\mathbf{x}}~min\frac{1}{2}||\mathbf{x}-\mathbf{r}^{(k)}||^{2}_{2}+\lambda||\mathbf{\Psi x}||_{1}\tag{11}
x(k)=argx min21∣∣x−r(k)∣∣22+λ∣∣Ψx∣∣1(11)
这下坏了,对于优化问题不是非常了解的读者读到这里开始二和尚摸不着头了。这都什么鬼,怎么这么突兀?式(11)的形式跟式(9)很像, 但是式(10)中的
ρ
Φ
T
(
Φ
x
(
k
−
1
)
−
y
)
\rho\mathbf{\Phi}^{T}(\mathbf{\Phi} \mathbf{x}^{(k-1)}-\mathbf{y})
ρΦT(Φx(k−1)−y)这一项直接冒出来。论文由于篇幅所限没有详细解释这两个最重要的式子是怎么来的(默认读者是懂的),下面我来解释一下:
这里实际上用到了梯度下降的思想。式(10)中的 Φ T ( Φ x ( k − 1 ) − y ) \mathbf{\Phi}^{T}(\mathbf{\Phi} \mathbf{x}^{(k-1)}-\mathbf{y}) ΦT(Φx(k−1)−y)这个令人感到最突兀的一项其实就是保真项 1 2 ∣ ∣ x − r ( k ) ∣ ∣ 2 2 \frac{1}{2}||\mathbf{x}-\mathbf{r}^{(k)}||^{2}_{2} 21∣∣x−r(k)∣∣22在 x ( k − 1 ) \mathbf{x}^{(k-1)} x(k−1)处对 x \mathbf{x} x的梯度(后有证明)。式(10)的含义是:在第 k − 1 k-1 k−1次迭代中, x ( k − 1 ) \mathbf{x}^{(k-1)} x(k−1)向保真项的负梯度方向移动,步长为 ρ \rho ρ,得到的结果命名为 r ( k ) \mathbf{r}^{(k)} r(k)。这样操作后, r ( k ) \mathbf{r}^{(k)} r(k)比 x ( k − 1 ) \mathbf{x}^{(k-1)} x(k−1)向着保真项极小值点的方向更近了一步。式(11)的含义是: 寻找一个新的 x ( k ) \mathbf{x}^{(k)} x(k),使其逼近于上一次用式(9)算出来的 r ( k ) \mathbf{r}^{(k)} r(k)。算法的具体操作步骤为:首先初始化 x ( 0 ) \mathbf{x}^{(0)} x(0), 带入式(10), 计算处 r ( 1 ) \mathbf{r}^{(1)} r(1), 再将 r ( 1 ) \mathbf{r}^{(1)} r(1)带入式(11), 计算出 x ( 1 ) \mathbf{x}^{(1)} x(1), 以此类推。
最后,这里附上保真项梯度的计算证明:
设
x
=
[
x
1
,
x
2
,
.
.
.
,
x
N
]
\mathbf{x}=[x_1,x_2,...,x_N]
x=[x1,x2,...,xN],
y
=
[
y
1
,
y
2
,
.
.
.
,
y
M
]
\mathbf{y}=[y_1,y_2,...,y_M]
y=[y1,y2,...,yM],
Φ
=
ϕ
i
j
,
i
=
1
,
2
,
.
.
.
,
M
,
j
=
1
,
2
,
.
.
.
,
N
,
\mathbf{\Phi}=\phi_{ij},i=1,2,...,M,j=1,2,...,N,
Φ=ϕij,i=1,2,...,M,j=1,2,...,N,
M
≪
N
M\ll N
M≪N.
f
(
x
)
=
1
2
∣
∣
Φ
x
−
y
∣
∣
2
2
=
1
2
∑
i
=
1
M
[
(
∑
j
=
1
N
ϕ
i
j
x
j
)
−
y
i
]
2
(12)
f(\mathbf{x})=\frac{1}{2}||\mathbf{\Phi x-y}||^{2}_{2}\\=\frac{1}{2}\sum_{i=1}^{M}[(\sum_{j=1}^{N}\phi_{ij}x_j)-y_i]^{2}\tag{12}
f(x)=21∣∣Φx−y∣∣22=21i=1∑M[(j=1∑Nϕijxj)−yi]2(12)
设
x
t
∈
x
x_t \in \mathbf{x}
xt∈x, 则保真项式(12)对
x
t
x_t
xt的导数为:
∇
x
t
f
(
x
)
=
∂
f
(
x
)
∂
x
t
=
1
2
∑
i
=
1
M
[
(
∑
j
=
1
N
ϕ
i
j
x
j
)
−
y
i
]
2
∂
x
t
=
1
2
⋅
2
∑
i
=
1
M
[
(
∑
j
=
1
N
ϕ
i
j
x
j
)
−
y
i
]
⋅
∂
[
∑
j
=
1
N
ϕ
i
j
x
j
−
y
i
]
∂
x
t
=
∑
i
=
1
M
[
(
∑
j
=
1
N
ϕ
i
j
x
j
)
−
y
i
]
ϕ
i
t
=
∑
i
=
1
M
ϕ
i
t
[
Φ
x
−
y
]
i
=
∑
i
=
1
M
Φ
t
,
:
T
[
Φ
x
−
y
]
i
(13)
\nabla_{x_t}f(\mathbf{x}) = \frac{\partial f(\mathbf{x})}{\partial x_t}=\frac{1}{2}\frac{\sum_{i=1}^{M}[(\sum_{j=1}^{N}\phi_{ij}x_j)-y_i]^{2}}{\partial x_t}\\=\frac{1}{2}\cdot2 \sum_{i=1}^{M}[(\sum_{j=1}^{N}\phi_{ij}x_j)-y_i]\cdot\frac{\partial[\sum_{j=1}^{N}\phi_{ij}x_j-y_i]}{\partial x_t}\\=\sum_{i=1}^{M}[(\sum_{j=1}^{N}\phi_{ij}x_j)-y_i]\phi_{it}\\=\sum_{i=1}^{M}\phi_{it}[\mathbf{\Phi x}-\mathbf{y}]_{i} \\= \sum_{i=1}^{M}\mathbf{\Phi}^{T}_{t,:}[\mathbf{\Phi x}-\mathbf{y}]_i \tag{13}
∇xtf(x)=∂xt∂f(x)=21∂xt∑i=1M[(∑j=1Nϕijxj)−yi]2=21⋅2i=1∑M[(j=1∑Nϕijxj)−yi]⋅∂xt∂[∑j=1Nϕijxj−yi]=i=1∑M[(j=1∑Nϕijxj)−yi]ϕit=i=1∑Mϕit[Φx−y]i=i=1∑MΦt,:T[Φx−y]i(13)
∇
x
t
f
(
x
)
\nabla_{x_t}f(\mathbf{x})
∇xtf(x)是一个标量数字,为保真项关于
x
\mathbf{x}
x中的任意一项的梯度。而保真项对于整个
x
\mathbf{x}
x的梯度是一个
N
N
N维向量,即
∇
x
f
(
x
)
=
[
∇
x
1
f
(
x
)
,
∇
x
2
f
(
x
)
,
.
.
.
∇
x
N
f
(
x
)
]
T
=
Φ
T
(
Φ
x
−
y
)
(14)
\nabla_{\mathbf{x}}f(\mathbf{x}) = [\nabla_{x_1}f(\mathbf{x}) , \nabla_{x_2}f(\mathbf{x}) ,...\nabla_{x_N}f(\mathbf{x}) ]^{T}=\mathbf{\Phi}^{T}(\mathbf{\Phi x-y})\tag{14}
∇xf(x)=[∇x1f(x),∇x2f(x),...∇xNf(x)]T=ΦT(Φx−y)(14)
3. 软阈值函数在迭代算法中的作用
讲到这里,相信大部分读者对于这些概念已经明晰。我们在回过头来看一下上面两部分,有的读者可能会有这样的疑问: 第一部分解释了软阈值函数,第二部分提出了迭代算法,也没看出来这个迭代算法里用到了软阈值函数啊?这两部分有什么关系?这里就来解释一下这个问题。其实是在迭代算法中是用到了软阈值函数,但是没有写得那么直白。
式(11)其实还是一个优化问题,我们可以看到这个优化问题的形式其实跟式(1)是类似的,式(11)的解就是靠第一部分提到的软阈值函数来求解的。 注意啊,这只是类似,但是不一样!不能直接套用。所以作者在论文中的Equation.5-10, 都是通过各种手段把式(11)变成跟式(1)完全一样的形式,即:
x
(
k
)
=
a
r
g
x
m
i
n
1
2
∣
∣
ϝ
(
x
)
−
ϝ
(
r
(
k
)
)
∣
∣
2
2
+
θ
∣
∣
ϝ
(
x
)
∣
∣
1
(15)
\mathbf{x}^{(k)} = arg_{\mathbf{x}}~min\frac{1}{2}||\digamma(\mathbf{x})-\digamma(\mathbf{r}^{(k)})||^{2}_{2}+\theta||\digamma(\mathbf{x})||_{1}\tag{15}
x(k)=argx min21∣∣ϝ(x)−ϝ(r(k))∣∣22+θ∣∣ϝ(x)∣∣1(15)
的解为:
ϝ
(
x
(
k
)
)
=
s
o
f
t
(
ϝ
(
x
)
,
θ
)
(16)
\digamma{(\mathbf{x}^{(k)})} =soft(\digamma(\mathbf{x}),\theta)\tag{16}
ϝ(x(k))=soft(ϝ(x),θ)(16)
这里
ϝ
(
⋅
)
\digamma(\cdot)
ϝ(⋅)为两层卷积神经网络,作用等效于稀疏变换
Ψ
x
\mathbf{\Psi x}
Ψx的作用。
4. 对称结构 ϝ ~ ( x ( k ) ) \tilde{\digamma}{(\mathbf{x}^{(k)})} ϝ~(x(k))的作用
下面再简要解释一下原文中为什么要有用一个对称结构 ϝ ~ ( x ( k ) ) \tilde{\digamma}{(\mathbf{x}^{(k)})} ϝ~(x(k))?
这是因为式(16)得到的式 ϝ ( x ( k ) ) \digamma{(\mathbf{x}^{(k)})} ϝ(x(k)), 而我们真正需要的是 ( x ( k ) ) (\mathbf{x}^{(k)}) (x(k)),因此需要一个 ϝ ~ ( x ( k ) ) \tilde{\digamma}{(\mathbf{x}^{(k)})} ϝ~(x(k))起到 ϝ ( x ( k ) ) \digamma{(\mathbf{x}^{(k)})} ϝ(x(k))逆变换的作用,从 ϝ ( x ( k ) ) \digamma{(\mathbf{x}^{(k)})} ϝ(x(k))中反解出 ( x ( k ) ) (\mathbf{x}^{(k)}) (x(k))。