K-SVD算法matlab实现

K-SVD算法
转载自:K-SVDMatlab实现
自己改了之后就可以运行.
1.show.m

%***************************** read in the image **************************
img=imread('C:\Users\xxx\Desktop\x.jpg');
img0 = img;
img=double(img);
[N,n]=size(img); 
% keep an original copy of the input signal

%****************form the measurement matrix and Dictionary ***************
Phi=randn(N,n/3);   
Phi = Phi./repmat(sqrt(sum(Phi.^2,1)),[N,1]); % normalization for each column

%fix the parameters
param.L =20;   % number of elements in each linear combination.
param.K =150;  %number of dictionary elements
param.numIteration = 100; % number of iteration to execute the K-SVD algorithm.
% decompose signals until a certain error is reached. 
%do not use fix number of coefficients. 
%param.errorGoal = sigma;
param.errorFlag = 0; 

param.preserveDCAtom = 0;
param.InitializationMethod ='DataElements'; %initialization by the signals themselves
param.displayProgress = 1; % progress information is displyed.

for i = 1:250
    for j = 1:250
        Img(i,j) = img(i,j,1);
    end
end
disp(size(Img))

[Dictionary,output]= KSVD(Img,param); %Dictionary is N*param.K 

%************************ projection **************************************
y=Phi*Img;          % treat each column as a independent signal
y0=y;               % keep an original copy of the measurements

%********************* recover using OMP *********************************
D=Phi*Dictionary;
A=OMP(D,y,20);
imgr=Dictionary*A;  

%***********************  show the results  ******************************** 
figure(1)
subplot(2,2,1),imagesc(img0),title('original image')
subplot(2,2,2),imagesc(y0),title('measurement image')
subplot(2,2,3),imagesc(Dictionary),title('Dictionary')
psnr=20*log10(255/sqrt(mean((Img(:)-imgr(:)).^2)));
subplot(2,2,4),imagesc(imgr),title(strcat('recover image (',num2str(psnr),'dB)'))
disp('over')

2.ksvd.m

function [Dictionary,output] = KSVD(...
    Data,... % an nXN matrix that contins N signals (Y), each of dimension n.
    param)
% =========================================================================
%                          K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for 
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans. 
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006. 
% =========================================================================
% INPUT ARGUMENTS:
% Data                         an nXN matrix that contins N signals (Y), each of dimension n. 
% param                        structure that includes all required
%                                 parameters for the K-SVD execution.
%                                 Required fields are:
%    K, ...                    the number of dictionary elements to train
%    numIteration,...          number of iterations to perform.
%    errorFlag...              if =0, a fix number of coefficients is
%                                 used for representation of each signal. If so, param.L must be
%                                 specified as the number of representing atom. if =1, arbitrary number
%                                 of atoms represent each signal, until a specific representation error
%                                 is reached. If so, param.errorGoal must be specified as the allowed
%                                 error.
%    preserveDCAtom...         if =1 then the first atom in the dictionary
%                                 is set to be constant, and does not ever change. This
%                                 might be useful for working with natural
%                                 images (in this case, only param.K-1
%                                 atoms are trained).
%    (optional, see errorFlag) L,...                 % maximum coefficients to use in OMP coefficient calculations.
%    (optional, see errorFlag) errorGoal, ...        % allowed representation error in representing each signal.
%    InitializationMethod,...  mehtod to initialize the dictionary, can
%                                 be one of the following arguments: 
%                                 * 'DataElements' (initialization by the signals themselves), or: 
%                                 * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
%    (optional, see InitializationMethod) initialDictionary,...      % if the initialization method 
%                                 is 'GivenMatrix', this is the matrix that will be used.
%    (optional) TrueDictionary, ...        % if specified, in each
%                                 iteration the difference between this dictionary and the trained one
%                                 is measured and displayed.
%    displayProgress, ...      if =1 progress information is displyed. If param.errorFlag==0, 
%                                 the average repersentation error (RMSE) is displayed, while if 
%                                 param.errorFlag==1, the average number of required coefficients for 
%                                 representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
%  Dictionary                  The extracted dictionary of size nX(param.K).
%  output                      Struct that contains information about the current run. It may include the following fields:
%    CoefMatrix                  The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix.
%    ratio                       If the true dictionary was defined (in
%                                synthetic experiments), this parameter holds a vector of length
%                                param.numIteration that includes the detection ratios in each
%                                iteration).
%    totalerr                    The total representation error after each
%                                iteration (defined only if
%                                param.displayProgress=1 and
%                                param.errorFlag = 0)
%    numCoef                     A vector of length param.numIteration that
%                                include the average number of coefficients required for representation
%                                of each signal (in each iteration) (defined only if
%                                param.displayProgress=1 and
%                                param.errorFlag = 1)
% =========================================================================

%%********************* Dimension reduction ****************************
for i = 1:1:250
    for j = 1:1:250
        Data(i,j) = Data(i,j,1);
    end
end

if (~isfield(param,'displayProgress'))
    param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param,'errorFlag')==0)
    param.errorFlag = 0;
end
 
if (isfield(param,'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1,1); %产生零矩阵
    ratio = zeros(param.numIteration+1,1);
else
    displayErrorWithTrueDictionary = 0;
	ratio = 0;
end
if (param.preserveDCAtom>0)
    FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
else
    FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients
 
if (size(Data,2) < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:,1:size(Data,2));
    return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed
% elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
%normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
totalErr = zeros(1,param.numIteration);
 
% the K-SVD algorithm starts here.

for iterNum = 1:param.numIteration
    % find the coefficients
    if (param.errorFlag==0)
        % CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
    else 
        %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
        CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
        param.L = 1;
    end
    
    replacedVectorCounter = 0;
	rPerm = randperm(size(Dictionary,2));
    for j = rPerm
        [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
            [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
            CoefMatrix ,param.L);
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end
    if (iterNum>1 & param.displayProgress)
        if (param.errorFlag==0)
            output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/prod(size(Data)));
            disp(['Iteration   ',num2str(iterNum),'   Total error is: ',num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
            disp(['Iteration   ',num2str(iterNum),'   Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary ) 
        [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
        disp(strcat(['Iteration  ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end
    Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);
    
    if (isfield(param,'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end
output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement,Dictionary];

3.OMP.m

function [A]=OMP(D,X,L); 
%=============================================
% Sparse coding of a group of signals based on a given 
% dictionary and specified number of atoms to use. 
% input arguments: 
%       D - the dictionary (its columns MUST be normalized).
%       X - the signals to represent
%       L - the max. number of coefficients for each signal.
% output arguments: 
%       A - sparse coefficient matrix.
%=============================================
[n,K]=size(D);
[n,P]=size(X);
for k=1:1:P,
    a=[];
    x=X(:,k);%令向量x等于矩阵X的第K列的元素长度为n*1
    residual=x;%n*1
    indx=zeros(L,1);%L*1的0矩阵
    for j=1:1:L,
        proj=D'*residual;%K*n n*1 变成K*1
        [maxVal,pos]=max(abs(proj));%  最大投影系数对应的位置
        pos=pos(1);
        indx(j)=pos; 
        a=pinv(D(:,indx(1:j)))*x;
        residual=x-D(:,indx(1:j))*a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end;
    temp=zeros(K,1);
    temp(indx(1:j))=a;
    A(:,k)=sparse(temp);%A为返回为K*P的矩阵
end;
return;

4.I_findDistanseBetweenDictionaries.m

function [A]=OMP(D,X,L); 
%=============================================
% Sparse coding of a group of signals based on a given 
% dictionary and specified number of atoms to use. 
% input arguments: 
%       D - the dictionary (its columns MUST be normalized).
%       X - the signals to represent
%       L - the max. number of coefficients for each signal.
% output arguments: 
%       A - sparse coefficient matrix.
%=============================================
[n,K]=size(D);
[n,P]=size(X);
for k=1:1:P,
    a=[];
    x=X(:,k);%令向量x等于矩阵X的第K列的元素长度为n*1
    residual=x;%n*1
    indx=zeros(L,1);%L*1的0矩阵
    for j=1:1:L,
        proj=D'*residual;%K*n n*1 变成K*1
        [maxVal,pos]=max(abs(proj));%  最大投影系数对应的位置
        pos=pos(1);
        indx(j)=pos; 
        a=pinv(D(:,indx(1:j)))*x;
        residual=x-D(:,indx(1:j))*a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end;
    temp=zeros(K,1);
    temp(indx(1:j))=a;
    A(:,k)=sparse(temp);%A为返回为K*P的矩阵
end;
return;

5.I_findBetterDictionaryElement.m

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  findBetterDictionaryElement
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [betterDictionaryElement,CoefMatrix,NewVectorAdded] = I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed)
if (length(who('numCoefUsed'))==0)
    numCoefUsed = 1;
end
relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that uses the j'th dictionary element.
if (length(relevantDataIndices)<1) %(length(relevantDataIndices)==0)
    ErrorMat = Data-Dictionary*CoefMatrix;
    ErrorNormVec = sum(ErrorMat.^2);
    [d,i] = max(ErrorNormVec);
    betterDictionaryElement = Data(:,i);%ErrorMat(:,i); %
    betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
    betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));
    CoefMatrix(j,:) = 0;
    NewVectorAdded = 1;
    return;
end
NewVectorAdded = 0;
tmpCoefMatrix = CoefMatrix(:,relevantDataIndices); 
tmpCoefMatrix(j,:) = 0;% the coeffitients of the element we now improve are not relevant.
errors =(Data(:,relevantDataIndices) - Dictionary*tmpCoefMatrix); % vector of errors that we want to minimize with the new element
% % the better dictionary element and the values of beta are found using svd.
% % This is because we would like to minimize || errors - beta*element ||_F^2. 
% % that is, to approximate the matrix 'errors' with a one-rank matrix. This
% % is done using the largest singular value.
[betterDictionaryElement,singularValue,betaVector] = svds(errors,1);
CoefMatrix(j,relevantDataIndices) = singularValue*betaVector';% *signOfFirstElem

6.I_clearDictionary.m

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  I_clearDictionary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)
T2 = 0.99;
T1 = 3;
K=size(Dictionary,2);

Er=sum((Data-Dictionary*CoefMatrix).^2,1);   % remove identical atoms
G=Dictionary'*Dictionary; G = G-diag(diag(G));
for jj=1:1:K,
    if max(G(jj,:))>T2 || length(find(abs(CoefMatrix(jj,:))>1e-7))<=T1 ,
        [val,pos]=max(Er);
        Er(pos(1))=0;
        Dictionary(:,jj)=Data(:,pos(1))/norm(Data(:,pos(1)));
        G=Dictionary'*Dictionary; G = G-diag(diag(G));
    end;
end;
  • 5
    点赞
  • 102
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
### 回答1: K-SVD字典学习算法是一种基于稀疏表示的字典学习算法,它可以用于信号处理、图像处理、语音识别等领域。该算法的核心思想是通过迭代更新字典和稀疏表示系数,使得字典能够更好地表示原始信号。在MATLAB中,可以使用K-SVD工具箱来实现算法。 ### 回答2: K-SVD字典学习算法是一种基于稀疏编码思想的字典学习算法,通过学习数据集中的原子信息,构建出一个由原子组成的字典,使得该字典能够最好地表示数据集。在该算法中,将待处理的数据进行稀疏表示,通过迭代优化更新字典,直到收敛为止。 K-SVD算法的步骤如下:首先将待处理的数据进行预处理,通过去除均值和进行归一化,然后将数据进行单位化。接着,确定字典的初始值,可以使用随机矩阵或者先验知识初始化,然后对于每一个样本,使用一个稀疏编码方式求得其系数向量,并且根据系数向量进行字典的更新。 在K-SVD字典学习算法中,字典的更新使用乘法更新方法,通过逐个原子的迭代更新字典矩阵,重新计算每个训练样本的系数向量和重构误差。当重构误差收敛时,迭代结束,得到最终的字典矩阵。 K-SVD字典学习算法在图像压缩、人脸识别、图像修复等领域得到了广泛应用。在MATLAB中,该算法可通过调用spams包中的ksvd方法进行实现。通过调用相应的函数接口,即可实现K-SVD字典学习算法。 ### 回答3: K-SVD字典学习算法是一种基于稀疏表示的字典学习方法,其目的是通过学习稀疏表示的整合来生成能够最好地表达信号的基本元素集合。此算法最初由Aharon等人于2006年提出,被证明是一种具有高准确性和鲁棒性的字典学习方法。 K-SVD算法的核心是基于块稀疏表示方法的矩阵分解,该方法能够将高维信号表示成一些高度凝聚的字典元素。通过重复更新字典和系数矩阵,不断优化整个分解过程,以提高字典表示能力。 在Matlab中,可以利用SparseLab工具箱轻松地实现K-SVD算法。对于K-SVD字典学习的首要步骤是字典初始化。通常将字典元素初始化为信号数据中的随机样本,并通过k-means算法进行聚类,确立成字典的原子(字典元素)。之后,通过矩阵的分解,进行稀疏表示并通过迭代优化过程,持续更新字典和系数矩阵,最终生成一个可以完美表示信号的字典。 总之,K-SVD算法是一种行之有效的字典学习技术,在很多领域都有广泛的应用,例如语音处理,图像处理和信号处理等。在Matlab上,我们可以利用基于SparseLab工具箱的K-SVD实现高效、灵活且精确的字典学习。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值