K-SVD overcomplete dictionary learning from a small number of samples in MATLAB (with commented code)
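I have recently been studying sparse representation and compressed sensing, and I ran into a problem with K-SVD: I had very few samples on hand but needed to build an overcomplete dictionary. It looks like many other people have the same problem.
The K-SVD MATLAB code I have seen circulating on forums is mainly the following two versions. Both need more training samples than the sample length before they can train an overcomplete dictionary:
K-SVD code from the dictionary learning tools (reposted from the post "K-SVD algorithm"):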
function [Dictionary,output] = KSVD(...
Data,... % an nXN matrix that contains N signals (Y), each of dimension n.
param)
% =========================================================================
% K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
% =========================================================================
% INPUT ARGUMENTS:
% Data an nXN matrix that contains N signals (Y), each of dimension n.
% param structure that includes all required
% parameters for the K-SVD execution.
% Required fields are:
% K, ... the number of dictionary elements to train
% numIteration,... number of iterations to perform.
% errorFlag... if =0, a fixed number of coefficients is
% used for representation of each signal. If so, param.L must be
% specified as the number of representing atoms. if =1, arbitrary number
% of atoms represent each signal, until a specific representation error
% is reached. If so, param.errorGoal must be specified as the allowed
% error.
% preserveDCAtom... if =1 then the first atom in the dictionary
% is set to be constant, and does not ever change. This
% might be useful for working with natural
% images (in this case, only param.K-1
% atoms are trained).
% (optional, see errorFlag) L,... % maximum coefficients to use in OMP coefficient calculations.
% (optional, see errorFlag) errorGoal, ... % allowed representation error in representing each signal.
% InitializationMethod,... method to initialize the dictionary, can
% be one of the following arguments:
% * 'DataElements' (initialization by the signals themselves), or:
% * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
% (optional, see InitializationMethod) initialDictionary,... % if the initialization method
% is 'GivenMatrix', this is the matrix that will be used.
% (optional) TrueDictionary, ... % if specified, in each
% iteration the difference between this dictionary and the trained one
% is measured and displayed.
% displayProgress, ... if =1 progress information is displayed. If param.errorFlag==0,
% the average representation error (RMSE) is displayed, while if
% param.errorFlag==1, the average number of required coefficients for
% representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
% Dictionary The extracted dictionary of size nX(param.K).
% output Struct that contains information about the current run. It may include the following fields:
% CoefMatrix The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix).
% ratio If the true dictionary was defined (in
% synthetic experiments), this parameter holds a vector of length
% param.numIteration that includes the detection ratios in each
% iteration).
% totalerr The total representation error after each
% iteration (defined only if
% param.displayProgress=1 and
% param.errorFlag = 0)
% numCoef A vector of length param.numIteration that
% includes the average number of coefficients required for representation
% of each signal (in each iteration) (defined only if
% param.displayProgress=1 and
% param.errorFlag = 1)
% =========================================================================
%%********************* Dimension reduction ****************************
% Keep only the first page if Data is 3-D (e.g. an RGB image); 2-D input is unchanged.
Data = Data(:,:,1);
if (~isfield(param,'displayProgress'))
param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param,'errorFlag')==0)
param.errorFlag = 0;
end
if (isfield(param,'TrueDictionary'))
displayErrorWithTrueDictionary = 1;
ErrorBetweenDictionaries = zeros(param.numIteration+1,1); % preallocate with zeros
ratio = zeros(param.numIteration+1,1);
else
displayErrorWithTrueDictionary = 0;
ratio = 0;
end
if (param.preserveDCAtom>0)
FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
else
FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients
if (size(Data,2) < param.K)
disp('Size of data is smaller than the dictionary size. Trivial solution...');
Dictionary = Data(:,1:size(Data,2));
output = []; % assign the second output so the early return does not error
return;
elseif (strcmp(param.InitializationMethod,'DataElements'))
Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed
% elements
if (param.preserveDCAtom)
tmpMat = FixedDictionaryElement \ Dictionary;
Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
%normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
totalErr = zeros(1,param.numIteration);
% the K-SVD algorithm starts here.
for iterNum = 1:param.numIteration
% find the coefficients
if (param.errorFlag==0)
% CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
else
%CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
param.L = 1;
end
replacedVectorCounter = 0;
rPerm = randperm(size(Dictionary,2));
for j = rPerm
[betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
[FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
CoefMatrix ,param.L);
Dictionary(:,j) = betterDictionaryElement;
if (param.preserveDCAtom)
tmpCoef = FixedDictionaryElement\betterDictionaryElement;
Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
end
replacedVectorCounter = replacedVectorCounter+addedNewVector;
end
if (iterNum>1 && param.displayProgress)
if (param.errorFlag==0)
output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/numel(Data));
disp(['Iteration ',num2str(iterNum),' Total error is: ',num2str(output.totalerr(iterNum-1))]);
else
output.numCoef(iterNum-1) = nnz(CoefMatrix)/size(Data,2);
disp(['Iteration ',num2str(iterNum),' Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
end
end
if (displayErrorWithTrueDictionary )
[ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
disp(strcat(['Iteration ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
output.ratio = ratio;
end
Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);
if (isfield(param,'waitBarHandle'))
waitbar(iterNum/param.counterForWaitBar);
end
end
output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement,Dictionary];
There is also a simplified version
(reposted from the post "K_SVD algorithm for training sparse dictionaries"):
function [A,x]= K_SVD(y,codebook_size,errGoal)
%==============================
%input parameter
% y - input signal
% codebook_size - count of atoms
%output parameter
% A - dictionary
% x - coefficient
%reference:K-SVD:An Algorithm for Designing of Overcomplete Dictionaries
% for Sparse Representation,Aharon M.,Elad M.etc
%==============================
if(size(y,2)<codebook_size)
disp('codebook_size is too large or there are too few training samples');
A=[]; x=[]; % assign outputs so the early return does not error
return;
end
% initialization
[rows,cols]=size(y);
r=randperm(cols);
A=y(:,r(1:codebook_size));
A=A./repmat(sqrt(sum(A.^2,1)),rows,1);
ksvd_iter=10;
for k=1:ksvd_iter
% sparse coding
if nargin==2
x=OMP(A,y,floor(5.0/6*rows)); % the sparsity level must be an integer
elseif nargin==3
x=OMPerr(A,y,errGoal);
end
% update dictionary
for m=1:codebook_size
mindex=find(x(m,:));
if ~isempty(mindex)
mx=x(:,mindex);
mx(m,:)=0;
my=A*mx;
resy=y(:,mindex);
mE=resy-my;
[u,s,v]=svds(mE,1);
A(:,m)=u;
x(m,mindex)=s*v';
end
end
end
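Both codes take the initial dictionary from the samples themselves, and both check the number of atoms against the number of samples: the dictionary cannot have more atoms than there are samples. Since an overcomplete dictionary has more atoms than the sample length, training one this way requires more samples than the sample length.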
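However, many people need to train an overcomplete dictionary with only a few samples (fewer samples than the sample length). In fact, the only thing that has to change is how the initial dictionary is built: generate a random matrix of the required size, then normalize each column to unit L2 norm.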
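As a minimal sketch of that initialization, the random matrix and its column normalization can be written without a loop. The sizes M and K below are assumptions chosen only for illustration; they are not from any particular data set.

```matlab
% Illustrative sizes (assumptions):
M = 64;    % sample length
K = 256;   % number of atoms; K > M makes the dictionary overcomplete

% Random initial dictionary. rand draws uniform values in (0,1);
% randn (zero-mean Gaussian) is a common alternative.
D0 = rand(M, K);

% Divide every column by its L2 norm. bsxfun is used so that the line
% also runs on MATLAB releases without implicit expansion.
colNorms = sqrt(sum(D0.^2, 1));
D0 = bsxfun(@rdivide, D0, colNorms);
```

Note that the number of samples does not appear anywhere here, which is why this initialization works even when there are fewer samples than atoms.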
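A random dictionary is the most general-purpose choice, but it needs a relatively large number of iterations to give good results. You can also change the initial dictionary to suit your needs, for example to a DCT, FFT, or wavelet dictionary. There are many ways to construct it: you can build a single overcomplete dictionary or concatenate several dictionaries. The K-SVD result depends strongly on the initial dictionary, so choose one that fits your purpose. Below is my adaptation of the second code to use a random initial matrix, with comments added to make it easier to study and use.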
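For the DCT option, one common construction of an overcomplete DCT dictionary samples cosines at K frequencies instead of M. The sketch below is only an illustration under assumed sizes, and the names Ddct and Dinit are hypothetical. It also shows concatenation by joining an identity (spike) dictionary with the DCT one.

```matlab
M = 64;    % sample length (assumed)
K = 128;   % number of DCT atoms (assumed), K > M

% Overcomplete DCT: cosines at K frequencies, each of length M.
Ddct = zeros(M, K);
n = (0:M-1)';
for k = 0:K-1
    atom = cos(n * k * pi / K);
    if k > 0
        atom = atom - mean(atom);   % remove the mean from the non-constant atoms
    end
    Ddct(:, k+1) = atom / norm(atom);   % unit L2 norm
end

% Concatenation of two dictionaries: identity plus DCT, size M x (M+K).
Dinit = [eye(M), Ddct];
```

To use such a matrix, the random initialization in the K-SVD routine would be replaced by it, with the atom count set to its number of columns.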
K-SVD code with a random initial dictionary:
function [ D , x ] = myksvd( y , K , iter , L )
%------------------K-SVD overcomplete dictionary learning------------------
% Inputs
%   y----------training set (M x N; M is the sample length, N the number of samples)
%   K----------number of dictionary atoms (columns)
%   iter-------number of iterations
%   L----------target sparsity
%-------------------------------
% Outputs
%   D----------learned dictionary
%   x----------sparse coefficients of the samples
%------------------------------------------------------
M=size(y,1);
%========================Initialize the dictionary=========================
d=rand(M,K);                     % random initial dictionary
D=zeros(M,K);
for i=1:K                        % normalize each column to unit L2 norm
    D(:,i)=d(:,i)./norm(d(:,i));
end
%==========================Iterate============================
ff=waitbar(0,'K-SVD Dictionary Learning...');
for J=1:iter
    %-----------Sparse coding----------
    x=OMP(D,y,L);
    % x=CS_OMP(y,D,L);
    %-----------Dictionary update---------
    fff=waitbar(0,'Updating dictionary column k...');
    for k=1:K
        wk=find(x(k,:));         % indices of the samples that use atom k
        if ~isempty(wk)
            %-----Compute the residual-----
            mx=x(:,wk);          % keep only the columns where row k of x is nonzero
            mx(k,:)=0;           % zero row k so the product excludes d_k*x_k
            my=D*mx;             % reconstruction of those samples without atom k
            resy=y(:,wk);        % the samples that use atom k
            Ek=resy-my;          % restricted residual E_k
            %-----Solve for the update by SVD-----
            [U,S,V]=svds(Ek,1);
            D(:,k)=U;            % atom k becomes the first left singular vector
            x(k,wk)=S*V';        % row k of x becomes the first singular value times the first right singular vector
        end
        str=['Updating dictionary column ',num2str(k),'... ',num2str(k/K*100),'%'];
        waitbar(k/K,fff,str);
    end
    close(fff);
    str=['K-SVD Dictionary Learning... ',num2str(J),' / ',num2str(iter),' ',num2str(J/iter*100),'%'];
    waitbar(J/iter,ff,str);
end
close(ff);
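The dictionary-update step in the listing above can be checked in isolation. The following sketch runs a single atom update on a small hand-made problem and compares the reconstruction error before and after. All sizes, values, and the chosen atom index are illustrative assumptions. Since the rank-1 SVD gives the best rank-1 fit of the restricted residual, the error should not increase.

```matlab
% Small synthetic problem (illustrative assumptions only).
M = 8; N = 5; K = 12;
D = rand(M, K);
D = bsxfun(@rdivide, D, sqrt(sum(D.^2, 1)));   % unit-norm atoms
y = rand(M, N);

% A hand-made sparse code: atom k is used by samples 1, 3 and 4.
x = zeros(K, N);
k = 2;
x(k, [1 3 4]) = [0.5 -1.0 2.0];
x(7, :) = 0.3;                                 % a second atom used by every sample

errBefore = norm(y - D*x, 'fro');

% Restrict to the samples that use atom k and remove its contribution.
wk = find(x(k, :));
mx = x(:, wk);
mx(k, :) = 0;
Ek = y(:, wk) - D*mx;

% The best rank-1 fit of Ek gives the new atom and its coefficients.
[U, S, V] = svds(Ek, 1);
D(:, k) = U;
x(k, wk) = S * V';

errAfter = norm(y - D*x, 'fro');
fprintf('error before: %.4f, after: %.4f\n', errBefore, errAfter);
```

Only the columns in wk are touched, so the update cannot add nonzeros to row k of x; the sparsity pattern found by OMP is preserved.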
OMP code used by K-SVD:
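function A=OMP(D,X,L)
% Inputs:
%   D - overcomplete dictionary; note that its columns must already be normalized
%   X - signals, one per column
%   L - maximum number of nonzero coefficients per signal (optional; defaults to the number of columns of D, which can be slow)
% Output:
%   A - sparse coefficients
P=size(X,2);
K=size(D,2);
if nargin<3
    L=K;                                 % default described in the header
end
A=sparse(K,P);                           % preallocate the output
ffff=waitbar(0,'OMP Running...');
for k=1:1:P
    x=X(:,k);
    residual=x;
    indx=zeros(L,1);
    for j=1:1:L
        proj=D'*residual;
        [~,pos]=max(abs(proj));          % atom most correlated with the residual
        indx(j)=pos;
        a=pinv(D(:,indx(1:j)))*x;        % least-squares fit on the selected atoms
        residual=x-D(:,indx(1:j))*a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end
    temp=zeros(K,1);
    temp(indx(1:j))=a;
    A(:,k)=sparse(temp);
    str=['OMP Running... ',num2str(k/P*100),'%'];
    waitbar(k/P,ffff,str);
end
close(ffff);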
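To see what the inner OMP loop does for a single signal, here is a small self-contained sketch. It assumes an orthonormal dictionary, chosen only because OMP then recovers an s-sparse code exactly in s steps; a real overcomplete dictionary gives no such guarantee. The names aTrue and aHat are hypothetical.

```matlab
% Orthonormal dictionary (an assumption that makes the check simple).
M = 8;
[Q, ~] = qr(randn(M));
D = Q;

% A 2-sparse code and the signal it generates.
aTrue = zeros(M, 1);
aTrue([2 5]) = [3; -1.5];
x = D * aTrue;

% Same greedy loop as in the OMP function, for one signal.
L = 2;
residual = x;
indx = zeros(L, 1);
for j = 1:L
    [~, pos] = max(abs(D' * residual));   % atom most correlated with the residual
    indx(j) = pos;
    a = pinv(D(:, indx(1:j))) * x;        % least-squares refit on the chosen atoms
    residual = x - D(:, indx(1:j)) * a;
end
aHat = zeros(M, 1);
aHat(indx) = a;
```

The larger coefficient is picked first, and the refit after each step keeps the residual orthogonal to every atom selected so far.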