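Abstract: Traditional compressed sensing methods are usually slow at reconstruction time. Combining deep learning with compressed sensing can greatly speed up reconstruction. The Learned Iterative Shrinkage and Thresholding Algorithm (LISTA) is probably the earliest method that uses deep learning to solve compressed sensing. This post briefly summarizes LISTA, gives concrete implementations of ISTA and LISTA in Python (NumPy) and PyTorch respectively, and uses them to reconstruct simulated sparse signals.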
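References
[1] A. Beck and M. Teboulle, "A Fast Iterative Shrinkage-Thresholding Algorithm for Linear Inverse Problems," SIAM Journal on Imaging Sciences, 2009.
[2] K. Gregor and Y. LeCun, "Learning Fast Approximations of Sparse Coding," ICML 2010.
[3] "Proximal gradient descent with three simple examples: soft thresholding, hard thresholding, and ReLU" (blog post).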
Contents
- Reading notes: Learning Fast Approximations of Sparse Coding
- ISTA implementation in Python
- LISTA implementation in PyTorch
- A sparse reconstruction example
1. Reading notes: Learning Fast Approximations of Sparse Coding
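Abstract (of the paper): In sparse coding, an input vector is reconstructed as a linear combination of a sparse set of basis vectors. Sparse coding has become a popular way to extract features from data. For a given input vector, sparse coding minimizes the squared reconstruction error plus an $\ell_1$ penalty on the code. In practice this process is often too slow for applications such as real-time pattern recognition. The paper proposes two fast algorithms that produce approximate sparse codes, which can be used for feature extraction or to initialize exact iterative algorithms. The main idea is to train a nonlinear feed-forward predictor with a specific architecture and fixed depth to produce the best approximation of the sparse code. Here we focus only on LISTA and not on the other method (LCoD, learned coordinate descent).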
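Method: The ISTA pseudocode is shown below, where $X$ is the measurement, $W_d$ is the dictionary matrix, $Z$ is the sparse code, $\alpha$ is the sparsity weight (the coefficient of the $\ell_1$ penalty), and $L$ is the Lipschitz constant (an upper bound on the largest eigenvalue of $W_d^\top W_d$).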
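In formulas, the problem ISTA solves and the iteration in the pseudocode are as follows ($h_\theta$ is the shrinkage function, applied elementwise); this matches the `ista` function in Section 2:

```latex
\min_{Z}\; E(Z) = \tfrac{1}{2}\lVert X - W_d Z \rVert_2^2 + \alpha \lVert Z \rVert_1

Z^{(k+1)} = h_{\alpha/L}\!\left( Z^{(k)} - \tfrac{1}{L} W_d^{\top}\bigl(W_d Z^{(k)} - X\bigr) \right),
\qquad h_{\theta}(v) = \operatorname{sign}(v)\,\max\bigl(\lvert v \rvert - \theta,\; 0\bigr)
```

The term $W_d^{\top}(W_d Z - X)$ is the gradient of the quadratic part of $E$, and $h_{\alpha/L}$ is the proximal operator of $\tfrac{\alpha}{L}\lVert\cdot\rVert_1$.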
The algorithm can be derived from proximal gradient descent [3]. Adding momentum to ISTA readily extends it to FISTA [1].
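Following [1], FISTA keeps the same shrinkage step but applies it at an extrapolated point $Y$ whose momentum weight comes from the sequence $t_k$. Below is a minimal NumPy sketch for comparison with the `ista` function in Section 2; the names `fista` and `soft_threshold` are illustrative, and, unlike `ista`, it runs a fixed number of iterations instead of checking a tolerance.

```python
import numpy as np

def soft_threshold(x, theta):
    # proximal operator of theta * ||x||_1 (the shrinkage function)
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0.0)

def fista(X, W_d, alpha, L, max_iter=500):
    """Minimize 0.5*||X - W_d Z||^2 + alpha*||Z||_1; X has shape (n, 1)."""
    Z = np.zeros((W_d.shape[1], 1))   # previous iterate
    Y = Z.copy()                      # extrapolated point
    t = 1.0                           # momentum parameter t_k
    for _ in range(max_iter):
        grad = W_d.T @ (W_d @ Y - X)                       # gradient at Y
        Z_next = soft_threshold(Y - grad / L, alpha / L)   # ISTA step taken at Y
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        Y = Z_next + ((t - 1.0) / t_next) * (Z_next - Z)   # momentum extrapolation
        Z, t = Z_next, t_next
    return Z

# Usage: with an identity dictionary the LASSO solution is soft_threshold(X, alpha)
X = np.array([[3.0], [-0.05], [0.5]])
print(fista(X, np.eye(3), alpha=0.1, L=1.0, max_iter=10).ravel())
```

Setting the momentum term to zero turns each step back into plain ISTA, which is the sense in which FISTA "adds momentum".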
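The block diagrams of ISTA and LISTA are shown below. The upper figure is ISTA, using the same notation as the pseudocode; the lower figure is LISTA, which is built entirely from fully connected layers, uses ISTA's shrinkage function as its activation, and learns both $W$ and $S$ through training.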
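Expanding the ISTA update makes the correspondence explicit: ISTA is exactly this network with fixed weights, which is also how `weights_init` in Section 3 initializes LISTA (in that code the threshold stays fixed at $\theta = \alpha/L$):

```latex
Z^{(k+1)} = h_{\alpha/L}\!\left( Z^{(k)} - \tfrac{1}{L} W_d^{\top}\bigl(W_d Z^{(k)} - X\bigr) \right)
          = h_{\alpha/L}\!\left( \underbrace{\tfrac{1}{L} W_d^{\top}}_{W} X
            + \underbrace{\bigl(I - \tfrac{1}{L} W_d^{\top} W_d\bigr)}_{S} Z^{(k)} \right)

\text{LISTA:}\quad Z^{(1)} = h_{\theta}(W X), \qquad
Z^{(k+1)} = h_{\theta}\bigl(W X + S Z^{(k)}\bigr), \quad k = 1, \dots, T
```

Here $T$ corresponds to `max_iter`. ISTA started from $Z^{(0)} = 0$ gives $Z^{(1)} = h_{\alpha/L}(WX)$, so the two recurrences coincide when $W$ and $S$ take the values above; LISTA instead learns $W$ and $S$, shared across all $T$ steps.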
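Next, let's use code to understand ISTA and LISTA better.
2. ISTA implementation in Python
The notation is the same as in the pseudocode above.
We define the shrinkage function and the ISTA iteration; `ista` returns the sparse solution together with the squared reconstruction error recorded at each iteration. The measurement `X` is expected as a column vector of shape `(n, 1)`. The worked example is given at the end, together with LISTA.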
```python
import numpy as np

def shrinkage(x, theta):
    # soft-thresholding: sign(x) * max(|x| - theta, 0)
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0)

def ista(X, W_d, a, L, max_iter, eps):
    # L must exceed the largest eigenvalue of W_d^T W_d for convergence
    eig = np.linalg.eigvalsh(W_d.T @ W_d)
    assert L > np.max(eig)
    del eig
    W_e = W_d.T / L
    recon_errors = []
    Z_old = np.zeros((W_d.shape[1], 1))
    Z_new = Z_old
    for i in range(max_iter):
        # gradient of 0.5*||X - W_d Z||^2 is W_d^T (W_d Z - X)
        temp = W_d @ Z_old - X
        Z_new = shrinkage(Z_old - W_e @ temp, a / L)
        if np.sum(np.abs(Z_new - Z_old)) <= eps:
            break
        Z_old = Z_new
        recon_error = np.linalg.norm(X - W_d @ Z_new, 2) ** 2
        recon_errors.append(recon_error)
    return Z_new, recon_errors
```
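`ista` only asserts that `L` exceeds the largest eigenvalue of $W_d^\top W_d$; it does not compute a suitable value. A minimal sketch (the helper name `lipschitz_constant` is hypothetical, not part of the original code) uses the fact that this eigenvalue equals the squared spectral norm of $W_d$:

```python
import numpy as np

def lipschitz_constant(W_d):
    """Largest eigenvalue of W_d^T W_d, i.e. the squared spectral norm of W_d.

    Hypothetical helper: any L strictly above this value satisfies the assert in ista().
    """
    return np.linalg.norm(W_d, 2) ** 2   # ord=2 on a matrix = largest singular value

# A small dictionary with orthonormal rows, the same kind used in Section 4
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.standard_normal((100, 20)))  # Q: (100, 20), orthonormal columns
W_d = Q.T                                            # (20, 100), orthonormal rows
print(lipschitz_constant(W_d))  # close to 1.0
```

For a dictionary with orthonormal rows, $W_d W_d^\top = I$, so every nonzero eigenvalue of $W_d^\top W_d$ equals 1; this is why `L = 2` passes the assertion in the Section 4 example.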
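3. LISTA implementation in PyTorch
The implementation follows the block diagram above. The network has only two learnable weight matrices, W and S, but they are applied repeatedly within a single forward pass (that is, within one optimization step); the number of repetitions is set by `max_iter` in the code.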
```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class LISTA(nn.Module):
    def __init__(self, n, m, W_d, max_iter, L, theta):
        """
        # Arguments
            n: int, dimension of the measurement
            m: int, dimension of the sparse signal
            W_d: tensor of shape (n, m), the dictionary
            max_iter: int, number of unrolled iterations sharing W and S
            L: Lipschitz constant
            theta: shrinkage threshold
        """
        super(LISTA, self).__init__()
        self._W = nn.Linear(in_features=n, out_features=m, bias=False)
        self._S = nn.Linear(in_features=m, out_features=m, bias=False)
        self.shrinkage = nn.Softshrink(theta)
        self.theta = theta
        self.max_iter = max_iter
        self.A = W_d
        self.L = L

    # weights initialization based on the dictionary (the ISTA values)
    def weights_init(self):
        A = self.A
        L = self.L
        m = A.shape[1]
        # S = I - (1/L) A^T A,  W = (1/L) A^T
        S = torch.eye(m, device=A.device) - (1 / L) * (A.T @ A)
        W = (1 / L) * A.T
        with torch.no_grad():
            self._S.weight.copy_(S)
            self._W.weight.copy_(W)

    def forward(self, y):
        # y: (batch, n). W y does not change across iterations, so compute it once.
        Wy = self._W(y)
        x = self.shrinkage(Wy)
        # the same W and S are reused in every unrolled iteration
        for _ in range(self.max_iter):
            x = self.shrinkage(Wy + self._S(x))
        return x


def train_lista(Y, dictionary, a, L, max_iter=30):
    n, m = dictionary.shape
    n_samples = Y.shape[0]
    batch_size = 128
    steps_per_epoch = n_samples // batch_size
    # convert the data into tensors
    Y = torch.from_numpy(Y).float().to(device)
    W_d = torch.from_numpy(dictionary).float().to(device)

    net = LISTA(n, m, W_d, max_iter=max_iter, L=L, theta=a / L)
    net = net.float().to(device)
    net.weights_init()

    # build the optimizer and criterion
    learning_rate = 1e-2
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

    loss_list = []
    for epoch in range(100):
        # shuffle the training samples at the start of every epoch
        index_samples = torch.randperm(n_samples, device=device)
        Y_shuffle = Y[index_samples]
        for step in range(steps_per_epoch):
            Y_batch = Y_shuffle[step * batch_size:(step + 1) * batch_size]
            optimizer.zero_grad()
            # get the outputs: sparse codes and their reconstruction
            X_h = net(Y_batch)
            Y_h = X_h @ W_d.T
            # compute the loss: reconstruction error + a * mean |X_h|
            loss1 = criterion(Y_h, Y_batch)
            loss2 = a * X_h.abs().mean()
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
    return net, loss_list
```
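To make the weight sharing explicit without PyTorch, here is a NumPy sketch of the same forward pass (`lista_forward` and `ista_steps` are illustrative names, not from the original code). With $W$ and $S$ set as in `weights_init`, an untrained LISTA with `n_layers` shared layers reproduces `n_layers + 1` ISTA iterations started from zero; training then moves $W$ and $S$ away from these values.

```python
import numpy as np

def soft_threshold(x, theta):
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0.0)

def lista_forward(y, W, S, theta, n_layers):
    """Unrolled LISTA on row vectors y of shape (batch, n); W: (m, n), S: (m, m)."""
    b = y @ W.T                          # W y, computed once and reused in every layer
    z = soft_threshold(b, theta)
    for _ in range(n_layers):            # the same W and S in every layer
        z = soft_threshold(b + z @ S.T, theta)
    return z

def ista_steps(y, W_d, alpha, L, n_steps):
    """Plain ISTA from zero, in the same row-vector layout."""
    z = np.zeros((y.shape[0], W_d.shape[1]))
    for _ in range(n_steps):
        grad = (z @ W_d.T - y) @ W_d     # row form of W_d^T (W_d z - y)
        z = soft_threshold(z - grad / L, alpha / L)
    return z

# Usage: ISTA-initialized weights (as in weights_init) reproduce ISTA
rng = np.random.default_rng(1)
W_d = rng.standard_normal((8, 20))
L = 1.1 * np.linalg.norm(W_d, 2) ** 2
alpha = 0.1
W = W_d.T / L
S = np.eye(20) - W_d.T @ W_d / L
y = rng.standard_normal((4, 8))
z_lista = lista_forward(y, W, S, alpha / L, n_layers=5)
z_ista = ista_steps(y, W_d, alpha, L, n_steps=6)
print(np.max(np.abs(z_lista - z_ista)))  # floating-point round-off only
```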
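4. A sparse reconstruction example
LISTA needs one extra step: training.
We therefore train LISTA first.
The sparse signal has dimension 1000, the measurement has dimension 256, the sparsity is 5, and 5000 training samples are used.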
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orth

# dimensions of the sparse signal, the measurement, and the sparsity level
m, n, k = 1000, 256, 5
# number of training examples
N = 5000

# generate the dictionary: Phi has orthonormal rows, Psi is the identity basis
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = orth(Phi.T).T
W_d = Phi @ Psi
print(W_d.shape)   # (256, 1000)

# generate k-sparse signals Z and their measurements X (one sample per row)
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
    index_k = np.random.choice(a=m, size=k, replace=False)
    Z[i, index_k] = 5 * np.random.randn(k)
    X[i] = W_d @ Z[i]
print(X.shape)     # (5000, 256)
print(X[0].shape)  # (256,)

# train LISTA with alpha = 0.1 and L = 2
net, err_list = train_lista(X, W_d, 0.1, 2)
```
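Note that `train_lista` is unsupervised: it never sees the true codes `Z` and instead minimizes a mean-normalized LASSO energy, the MSE between $Y$ and $\hat Z W_d^\top$ plus $\alpha$ times the mean of $\lvert \hat Z \rvert$. Gregor and LeCun [2] instead train the encoder by squared-error regression onto optimal codes computed beforehand. A NumPy sketch contrasting the two objectives (function names are illustrative; `Z_star` stands for precomputed optimal codes, for example from running ISTA to convergence, and the supervised loss is shown only up to a constant factor):

```python
import numpy as np

def unsupervised_loss(Y, Z_hat, W_d, alpha):
    """Objective used by train_lista: MSE(Y, Z_hat W_d^T) + alpha * mean(|Z_hat|)."""
    recon = np.mean((Y - Z_hat @ W_d.T) ** 2)   # nn.MSELoss averages over all entries
    sparsity = alpha * np.mean(np.abs(Z_hat))   # same as nn.L1Loss against zeros
    return recon + sparsity

def supervised_loss(Z_star, Z_hat):
    """Squared distance to precomputed optimal codes, averaged over samples (rows)."""
    return np.mean(np.sum((Z_star - Z_hat) ** 2, axis=1))
```

The unsupervised version needs no extra solver run per training sample, at the cost of optimizing the LASSO energy directly rather than imitating its minimizer.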
Training yields a network that reconstructs sparse solutions; we compare its test result with that of ISTA.
```python
# Test stage
# generate one test sparse signal Z and its measurement X
Z = np.zeros((1, m))
X = np.zeros((1, n))
index_k = np.random.choice(a=m, size=k, replace=False)
Z[0, index_k] = 5 * np.random.randn(k)
X[0] = W_d @ Z[0]

# LISTA: a single forward pass through the trained network
with torch.no_grad():
    Z_recon = net(torch.from_numpy(X).float().to(device))
Z_recon = Z_recon.cpu().numpy()

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(X[0])  # measurement
plt.subplot(2, 1, 2)
plt.plot(Z[0], label='real')
plt.plot(Z_recon[0], '.-', label='LISTA')

# ISTA: X.T has shape (n, 1), as ista() expects
Z_ista, recon_errors = ista(X.T, W_d, 0.1, 2, 1000, 1e-5)
plt.plot(Z_ista, '--', label='ISTA')
plt.legend()
plt.show()
```
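The results are shown in the figure below: both methods recover the sparse solution well. The upper subplot shows the measurement signal and the lower subplot shows the sparse signal.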
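To put a number on "recovers well", one option is the reconstruction SNR in dB, $10\log_{10}\bigl(\lVert z\rVert_2^2 / \lVert z - \hat z\rVert_2^2\bigr)$, averaged over test samples. A minimal sketch assuming that common definition (the helper name `recon_snr_db` is hypothetical); pass the true codes and a reconstruction of the same shape, one sample per row:

```python
import numpy as np

def recon_snr_db(Z_true, Z_hat):
    """Average reconstruction SNR in dB over rows (one sample per row).

    A perfect reconstruction has zero error and therefore an infinite SNR.
    """
    Z_true = np.atleast_2d(Z_true)
    Z_hat = np.atleast_2d(Z_hat)
    signal = np.sum(Z_true ** 2, axis=1)
    error = np.sum((Z_true - Z_hat) ** 2, axis=1)
    return np.mean(10 * np.log10(signal / error))
```

An ISTA result of shape `(m, 1)` would need to be transposed to `(1, m)` first so that it lines up with a `(1, m)` ground truth.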