深度学习:深度压缩感知-从ISTA到LISTA及其pytorch实现方法

摘要:传统的压缩感知方法在重构时的速度通常比较慢。通过将深度学习和压缩感知结合,可以大大提高重构速度。Learned Iterative Shrinkage and Thresholding Algorithm (LISTA)应该是用深度学习方法求解压缩感知最早的方法,本文简单总结一下LISTA,并给出ISTA和LISTA的具体实现,分别使用Python和Pytorch,并对仿真的稀疏信号进行重构。

参考文献

【1】A Fast Iterative Shrinkage-Thresholding Algorithm

【2】Learning Fast Approximations of Sparse Coding

【3】近端梯度下降方法及三个简单例子,软阈值,硬阈值和ReLU。

目录

  1. Learning Fast Approximations of Sparse Coding 论文阅读
  2. ISTA算法Python实现
  3. LISTA算法Pytorch实现
  4. 一个稀疏重构的算例

1. Learning Fast Approximations of Sparse Coding 论文阅读

摘要:稀疏编码中,输入向量通过稀疏基向量的线性组合进行重构。稀疏编码已经成为从数据中获取特征的流行方法。对于给定的输入向量,稀疏编码最小化二次重构误差以及编码的一范数约束。这个过程在实际应用中通常很慢,例如实时的模式识别。本文我们给出两种能给出稀疏编码的快速算法,可以用于特征提取或用于初始化特定的迭代算法。主要思想是训练一个具有特定结构和固定深度的非线性前向估计器,用于估计稀疏编码的最佳近似。这里只关注LISTA,不关注另一种方法。

方法:ISTA算法伪代码如下,其中 X X X是测量值, W d W_d Wd是字典矩阵, Z Z Z是稀疏编码, α \alpha α是稀疏相关系数, L L L是Lipschitz常量。

算法推导可以通过近端梯度下降方法【3】得到。将ISTA加入动量,很容易就能扩展为FISTA方法【1】。
在这里插入图片描述
ISTA和LISTA的算法框图如下。其中,上图是ISTA的框图,其中的符号和算法伪代码中一致;下图为LISTA,全部是全连接结构,激活函数采用ISTA的shrinkage function, W W W S S S都是通过训练学习到的。
在这里插入图片描述
下面通过代码进一步理解ISTA和LISTA。

2. ISTA算法python实现

使用的符号和上面的伪代码一致。

定义shrinkage function和ISTA的迭代,迭代结果返回稀疏解以及重构误差。算例和LISTA一起在最后给出。

import numpy as np

def shrinkage(x, theta):
    return np.multiply(np.sign(x), np.maximum(np.abs(x) - theta, 0))

def ista(X, W_d, a, L, max_iter, eps):

    eig, eig_vector = np.linalg.eig(W_d.T * W_d)
    assert L > np.max(eig)
    del eig, eig_vector
    
    W_e = W_d.T / L

    recon_errors = []
    Z_old = np.zeros((W_d.shape[1], 1))
    for i in range(max_iter):
        temp = W_d * Z_old - X
        Z_new = shrinkage(Z_old - W_e * temp, a / L)
        if np.sum(np.abs(Z_new - Z_old)) <= eps: break
        Z_old = Z_new
        recon_error = np.linalg.norm(X - W_d * Z_new, 2) ** 2
        recon_errors.append(recon_error)
        
    return Z_new, recon_errors

3. LISTA算法Pytorch实现

按照上面的算法框图进行,网络只有两个学习参数W和S,但需要在一步优化中重复迭代多次,有代码中的max_iter确定。

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class LISTA(nn.Module):
    def __init__(self, n, m, W_e, max_iter, L, theta):
        """
        # Arguments
            n: int, dimensions of the measurement
            m: int, dimensions of the sparse signal
            W_e: array, dictionary
            max_iter:int, max number of internal iteration
            L: Lipschitz const 
            theta: Thresholding
        """
        
        super(LISTA, self).__init__()
        self._W = nn.Linear(in_features=n, out_features=m, bias=False)
        self._S = nn.Linear(in_features=m, out_features=m,
                            bias=False)
        self.shrinkage = nn.Softshrink(theta)
        self.theta = theta
        self.max_iter = max_iter
        self.A = W_e
        self.L = L
        
    # weights initialization based on the dictionary
    def weights_init(self):
        A = self.A.cpu().numpy()
        L = self.L
        S = torch.from_numpy(np.eye(A.shape[1]) - (1/L)*np.matmul(A.T, A))
        S = S.float().to(device)
        W = torch.from_numpy((1/L)*A.T)
        W = W.float().to(device)
        
        self._S.weight = nn.Parameter(S)
        self._W.weight = nn.Parameter(W)


    def forward(self, y):
        x = self.shrinkage(self._W(y))

        if self.max_iter == 1 :
            return x

        for iter in range(self.max_iter):
            x = self.shrinkage(self._W(y) + self._S(x))

        return x

def train_lista(Y, dictionary, a, L, max_iter=30):
    
    n, m = dictionary.shape
    n_samples = Y.shape[0]
    batch_size = 128
    steps_per_epoch = n_samples // batch_size
    
    # convert the data into tensors
    Y = torch.from_numpy(Y)
    Y = Y.float().to(device)
    
    W_d = torch.from_numpy(dictionary)
    W_d = W_d.float().to(device)

    net = LISTA(n, m, W_d, max_iter=30, L=L, theta=a/L)
    net = net.float().to(device)
    net.weights_init()

    # build the optimizer and criterion
    learning_rate = 1e-2
    criterion1 = nn.MSELoss()
    criterion2 = nn.L1Loss()
    all_zeros = torch.zeros(batch_size, m).to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate,  momentum=0.9)

    loss_list = []
    for epoch in range(100):
        index_samples = np.random.choice(a=n_samples, size=n_samples, replace=False, p=None)
        Y_shuffle = Y[index_samples]
        for step in range(steps_per_epoch):
            Y_batch = Y_shuffle[step*batch_size:(step+1)*batch_size]
            optimizer.zero_grad()
    
            # get the outputs
            X_h = net(Y_batch)
            Y_h = torch.mm(X_h, W_d.T)
    
            # compute the losss
            loss1 = criterion1(Y_batch.float(), Y_h.float()) 
            loss2 = a * criterion2(X_h.float(), all_zeros.float())
            loss = loss1 + loss2
            
            loss.backward()
            optimizer.step()  
    
            with torch.no_grad():
                loss_list.append(loss.detach().data)     
            
    return net, loss_list

4. 一个稀疏重构的算例

对于LISTA,需要额外多一个训练的步骤。

下面先进行LISTA的训练。

稀疏信号维度为1000,测量信号维度为256,信号稀疏度为5,使用5000个训练样本。

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orth

# dimensions of the sparse signal, measurement and sparsity
m, n, k = 1000, 256, 5
# number of test examples
N = 5000

# generate dictionary
Psi = np.eye(m)
Phi = np.random.randn(n, m)
Phi = np.transpose(orth(np.transpose(Phi)))
W_d = np.dot(Phi, Psi)
print(W_d.shape)

# generate sparse signal Z and measurement X
Z = np.zeros((N, m))
X = np.zeros((N, n))
for i in range(N):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1,])
    X[i] = np.dot(W_d, Z[i, :])

print(X.shape)
print(X[0].shape)

# computing average reconstruction-SNR
net, err_list = train_lista(X, W_d, 0.1, 2)

通过训练,得到用于重构稀疏解的网络,将其和ISTA算法的测试结果进行对比。

# Test stage
# generate sparse signal Z and measurement X
Z = np.zeros((1, m))
X = np.zeros((1, n))
for i in range(1):
    index_k = np.random.choice(a=m, size=k, replace=False, p=None)
    Z[i, index_k] = 5 * np.random.randn(k, 1).reshape([-1,])
    X[i] = np.dot(W_d, Z[i, :])

Z_recon = net(torch.from_numpy(X).float().to(device))
Z_recon = Z_recon.detach().cpu().numpy()
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(X[0])
plt.subplot(2,1,2)
plt.plot(Z[0], label='real')
plt.subplot(2,1,2)
plt.plot(Z_recon[0], '.-', label='LISTA')

# ISTA
Z_recon, recon_errors = ista(np.mat(X).T, np.mat(W_d), 0.1, 2, 1000, 0.00001)
plt.subplot(2, 1, 2)
plt.plot(Z_recon, '--', label='ISTA')
plt.legend()

结果如下图所示,这里两种方法都很好的算出稀疏解。上子图是观测信号,下子图为稀疏信号。
在这里插入图片描述

  • 35
    点赞
  • 191
    收藏
    觉得还不错? 一键收藏
  • 97
    评论
ISTA(迭代收缩阈值算法)是一种常用于稀疏表示的优化算法,以下是一个简单的 Matlab 代码示例: ``` % 定义稀疏表示问题的矩阵和向量 A = randn(100, 200); % 稀疏矩阵 x = sprandn(200, 1, 0.1); % 稀疏向量 b = A*x; % 观测值 % 定义ISTA算法参数 lambda = 0.1; % 正则化参数 alpha = max(eig(A'*A)); % 步长参数 max_iter = 100; % 最大迭代次数 % 初始化ISTA算法的参数 x0 = zeros(size(x)); % 初始值 % 迭代过程 for i=1:max_iter % 计算梯度 grad = A'*(A*x0 - b); % 更新参数 x1 = soft_threshold(x0 - alpha*grad, lambda*alpha); % 计算收敛误差 err = norm(x1 - x)/norm(x); % 打印当前迭代结果 fprintf('Iteration %d: error = %f\n', i, err); % 更新迭代参数 x0 = x1; end % 定义软阈值函数 function y = soft_threshold(x, lambda) y = sign(x).*max(abs(x) - lambda, 0); end ``` 在这个示例中,我们首先定义了一个稀疏表示问题,其中 $A$ 是一个 $100 \times 200$ 的稀疏矩阵,$x$ 是一个稀疏向量,$b$ 是观测值。我们使用 ISTA 算法来求解这个问题。 我们定义了 ISTA 算法的参数,包括正则化参数 $\lambda$、步长参数 $\alpha$ 和最大迭代次数。然后,我们初始化 ISTA 算法的参数 $x_0$ 为全零向量,并开始迭代。 在每次迭代中,我们首先计算梯度 $\nabla f(x)$,然后更新参数 $x$。在更新参数之后,我们计算收敛误差,并打印出当前迭代的结果。最后,我们更新迭代参数 $x_0$。 在这个示例中,我们使用了一个软阈值函数来实现 ISTA 算法的阈值操作。这个函数接受两个参数 $x$ 和 $\lambda$,并返回一个软阈值后的结果 $y$。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 97
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值