MATLAB 使用少量样本的K-SVD过完备字典学习代码(中文备注)

MATLAB 使用少量样本的K-SVD过完备字典学习代码(中文备注)

最近在学习稀疏表示与压缩感知理论,在使用K-SVD方面我遇到了一个问题,就是手上样本数量很少,但是需要构建过完备字典,我看到好像很多人也有这个问题。
目前在论坛里看到流传的K-SVD的MATLAB代码主要是下面这两个,都需要超过样本长度的样本数才能训练过完备字典:

dictionary learning tools中的K-SVD算法代码(转自K-SVD算法

function [Dictionary,output] = KSVD(...
    Data,... % an nXN matrix that contins N signals (Y), each of dimension n.
    param)
% =========================================================================
%                          K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for 
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans. 
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006. 
% =========================================================================
% INPUT ARGUMENTS:
% Data                         an nXN matrix that contins N signals (Y), each of dimension n. 
% param                        structure that includes all required
%                                 parameters for the K-SVD execution.
%                                 Required fields are:
%    K, ...                    the number of dictionary elements to train
%    numIteration,...          number of iterations to perform.
%    errorFlag...              if =0, a fix number of coefficients is
%                                 used for representation of each signal. If so, param.L must be
%                                 specified as the number of representing atom. if =1, arbitrary number
%                                 of atoms represent each signal, until a specific representation error
%                                 is reached. If so, param.errorGoal must be specified as the allowed
%                                 error.
%    preserveDCAtom...         if =1 then the first atom in the dictionary
%                                 is set to be constant, and does not ever change. This
%                                 might be useful for working with natural
%                                 images (in this case, only param.K-1
%                                 atoms are trained).
%    (optional, see errorFlag) L,...                 % maximum coefficients to use in OMP coefficient calculations.
%    (optional, see errorFlag) errorGoal, ...        % allowed representation error in representing each signal.
%    InitializationMethod,...  mehtod to initialize the dictionary, can
%                                 be one of the following arguments: 
%                                 * 'DataElements' (initialization by the signals themselves), or: 
%                                 * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
%    (optional, see InitializationMethod) initialDictionary,...      % if the initialization method 
%                                 is 'GivenMatrix', this is the matrix that will be used.
%    (optional) TrueDictionary, ...        % if specified, in each
%                                 iteration the difference between this dictionary and the trained one
%                                 is measured and displayed.
%    displayProgress, ...      if =1 progress information is displyed. If param.errorFlag==0, 
%                                 the average repersentation error (RMSE) is displayed, while if 
%                                 param.errorFlag==1, the average number of required coefficients for 
%                                 representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
%  Dictionary                  The extracted dictionary of size nX(param.K).
%  output                      Struct that contains information about the current run. It may include the following fields:
%    CoefMatrix                  The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix.
%    ratio                       If the true dictionary was defined (in
%                                synthetic experiments), this parameter holds a vector of length
%                                param.numIteration that includes the detection ratios in each
%                                iteration).
%    totalerr                    The total representation error after each
%                                iteration (defined only if
%                                param.displayProgress=1 and
%                                param.errorFlag = 0)
%    numCoef                     A vector of length param.numIteration that
%                                include the average number of coefficients required for representation
%                                of each signal (in each iteration) (defined only if
%                                param.displayProgress=1 and
%                                param.errorFlag = 1)
% =========================================================================

%%********************* Dimension reduction ****************************
for i = 1:1:250
    for j = 1:1:250
        Data(i,j) = Data(i,j,1);
    end
end

if (~isfield(param,'displayProgress'))
    param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param,'errorFlag')==0)
    param.errorFlag = 0;
end
 
if (isfield(param,'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1,1); %产生零矩阵
    ratio = zeros(param.numIteration+1,1);
else
    displayErrorWithTrueDictionary = 0;
	ratio = 0;
end
if (param.preserveDCAtom>0)
    FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
else
    FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients
 
if (size(Data,2) < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:,1:size(Data,2));
    return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
    Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed
% elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
%normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
totalErr = zeros(1,param.numIteration);
 
% the K-SVD algorithm starts here.

for iterNum = 1:param.numIteration
    % find the coefficients
    if (param.errorFlag==0)
        % CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
    else 
        %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
        CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
        param.L = 1;
    end
    
    replacedVectorCounter = 0;
	rPerm = randperm(size(Dictionary,2));
    for j = rPerm
        [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
            [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
            CoefMatrix ,param.L);
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end
    if (iterNum>1 & param.displayProgress)
        if (param.errorFlag==0)
            output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/prod(size(Data)));
            disp(['Iteration   ',num2str(iterNum),'   Total error is: ',num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
            disp(['Iteration   ',num2str(iterNum),'   Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary ) 
        [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
        disp(strcat(['Iteration  ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end
    Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);
    
    if (isfield(param,'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end
output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement,Dictionary];

还有一个简化的版本
(转自用于训练系数字典的K_SVD算法

function [A,x]= K_SVD(y,codebook_size,errGoal) 
%============================== 
%input parameter 
%   y - input signal 
%   codebook_size - count of atoms 
%output parameter 
%   A - dictionary 
%   x - coefficent 
%reference:K-SVD:An Algorithm for Designing of Overcomplete Dictionaries 
%          for Sparse Representation,Aharon M.,Elad M.etc 
%============================== 
if(size(y,2)<codebook_size) 
    disp('codebook_size is too large or training samples is too small'); 
    return; 
end 
% initialization 
[rows,cols]=size(y); 
r=randperm(cols); 
A=y(:,r(1:codebook_size)); 
A=A./repmat(sqrt(sum(A.^2,1)),rows,1); 
ksvd_iter=10; 
for k=1:ksvd_iter 
    % sparse coding 
    if nargin==2 
        x=OMP(A,y,5.0/6*rows); 
    elseif nargin==3 
        x=OMPerr(A,y,errGoal); 
    end 
    % update dictionary 
    for m=1:codebook_size 
        mindex=find(x(m,:)); 
        if ~isempty(mindex) 
            mx=x(:,mindex); 
            mx(m,:)=0; 
            my=A*mx; 
            resy=y(:,mindex); 
            mE=resy-my; 
            [u,s,v]=svds(mE,1); 
            A(:,m)=u; 
            x(m,mindex)=s*v'; 
        end 
    end 
end 

这两个代码初始字典都是从样本中获得的,代码中都对样本长度和字典原子数关系有限制,字典原子数不能超过样本长度,换句话说,要构建过完备字典,就必须有超过样本长度的样本数。

但是很多人都有需要在仅有少量样本(样本数量小于样本长度)的情况下训练过完备字典,事实上,只需要改变初始字典的构建方式就可以了,初始字典可以构建一个指定大小的随机数矩阵,再进行L2范数归一化。

当然构建随机数字典是最普适的一种方式,但是要获得较好的效果也需要相对较高的迭代次数,各位也可以根据自己需求更改初始化字典,如DCT、FFT、小波等形式,初始字典构筑方式有许多,可以构建单个过完备字典,也可以通过几个字典拼接而成,K-SVD算法结果和初始字典的相关性很高,可以根据自己的目的选择合适的初始化字典,以下是我根据第二个代码改造的使用随机初始矩阵的代码,顺带进行了中文备注方便大家学习使用~

使用随机初始字典的K-SVD代码:

function [ D , x ] = myksvd( y , K , iter , L )
%------------------K-SVD过完备字典学习------------------
%    输入参数
%    y----------训练样本集(MxN维 M为一个样本长度,N为样本数)
%    K----------字典原子(列)数
%    iter-------迭代次数
%    L----------目标稀疏度
%-------------------------------
%    输出参数
%    D----------稀疏字典
%    x----------样本的稀疏系数
%------------------------------------------------------
[M,~]=size(y);
%========================初始化字典=========================
d=rand(M,K);        %随机数初始字典


for i=1:K           %L2范数归一化
    D(:,i)=d(:,i)./norm(d(:,i));
end
%==========================迭代============================
ff=waitbar(0,'K-SVD Dictionary Learning...');

for J=1:iter
    
%-----------稀疏编码----------
    x=OMP(D,y,L);
%     x=CS_OMP(y,D,L);
%--------------------------------------------------

%-----------字典更新---------
    fff=waitbar(0,'第k列字典更新中...');
    for k=1:K       
        wk=find(x(k,:));
        if ~isempty(wk)
            %-----计算残差-----
            
            mx=x(:,wk);    %消除x的k行零项对应列
            mx(k,:)=0;     %让x的k行清零,因此残差项不包含dkxk
            my=D*mx;       %为舍去零列及k行的乘积
            resy=y(:,wk);  %舍去非零列的k
            Ek=resy-my;    %直接得到Ek'
            
            %-----SVD求解优化变量-----
            [U,S,V]=svds(Ek,1);
            D(:,k)=U;      %k列字典更新为左奇异矩阵第一个列向量
            x(k,wk)=S*V';  %k行x更新为右奇异矩阵第一个行向量与第一个奇异值乘积
        end   
        str=[['第',num2str(k),'列字典更新中...'],num2str(k/K*100),'%'];
        waitbar(k/K,fff,str);
    end    
%-------------------------------------------------
    close(fff);
    str=[['K-SVD Dictionary Learning...   ',num2str(J),' / ',num2str(iter),'         '],num2str(J/iter*100),'%'];
    waitbar(J/iter,ff,str);
end

close(ff);

K-SVD中使用的OMP代码

function A=OMP(D,X,L)
% 输入参数:
% D - 过完备字典,注意:必须字典的各列必须经过了规范化
% X - 信号
% L - 系数中非零元个数的最大值(可选,默认为D的列数,速度可能慢)
% 输出参数:
% A - 稀疏系数
P=size(X,2);
K=size(D,2);
ffff=waitbar(0,'OMP Running...');
for k=1:1:P
    a=[];
    x=X(:,k);
    residual=x;
    indx=zeros(L,1);
    for j=1:1:L
        proj=D'*residual;
        [maxVal,pos]=max(abs(proj));
        pos=pos(1);
        indx(j)=pos;
        a=pinv(D(:,indx(1:j)))*x;
        residual=x-D(:,indx(1:j))*a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end
    temp=zeros(K,1);
    temp(indx(1:j))=a;
    A(:,k)=sparse(temp);
    str=['OMP Running...   ',num2str(k/P*100),'%'];
    waitbar(k/P,ffff,str);
end
close(ffff);
return;
  • 20
    点赞
  • 113
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值