Copyright notice: this is an original article by the blog author and may not be reproduced without the author's permission.
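First, a clarification: MOD here is not the modulo function! MOD is a method for dictionary learning and sparse coding. I have recently been reading about K-SVD, and its simplified version is MOD (Method of Optimal Directions). Put this way: K-SVD and MOD optimize the same objective function, and MOD can be called a simplified version of K-SVD because K-SVD, on top of what MOD does, updates the dictionary columns sequentially, one at a time. For the theory behind K-SVD and MOD, see the one-page note below and the paper listed under Reference. This post mainly presents the basic idea and my code. The code has been tested; if you find a bug, please report it.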
Reference
"From Sparse Solutions of Systems of Equations to Sparse Modeling of Signals and Images", pp. 68-70
KSVD & MOD's principle & objective function
Principle:
Put simply, the optimization alternates between OMP (orthogonal matching pursuit) and regression, so the code consists of OMP.m and Regression.m.
Objective Function & the variation from MOD to KSVD:
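In the notation of the code below (Y is the M×P data matrix, D the M×N dictionary, X the N×P coefficient matrix, K the sparsity), the shared objective and the two dictionary updates can be written roughly as follows. This is a summary of the standard formulation, not a copy of the original note. Here $x_i$ denotes the i-th column of X and $x^{j}$ its j-th row; the symbols $E_j$ and $\omega_j$ are introduced only for this illustration, and the MOD update assumes $XX^{T}$ is invertible.

```latex
% Shared objective of MOD and K-SVD
\min_{D,\,X}\ \|Y - DX\|_F^2 \quad \text{s.t.}\quad \|x_i\|_0 \le K,\quad i = 1,\dots,P

% MOD: with X fixed, update the whole dictionary at once by least squares
D \leftarrow Y X^{T}\,(X X^{T})^{-1}

% K-SVD: update one atom d_j at a time, together with its coefficient row x^{j}
E_j = Y - \sum_{l \ne j} d_l\, x^{l}, \qquad
\omega_j = \{\, i : x^{j}(i) \ne 0 \,\}, \qquad
E_j^{\omega_j} = U \Sigma V^{T}, \qquad
d_j \leftarrow u_1,\quad x^{j}_{\omega_j} \leftarrow \sigma_1 v_1^{T}
```

Both methods use the same sparse coding step (OMP in this post); they differ only in how D is refreshed between sparse coding passes.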
Code
CODE1. MOD
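Run Main, which uses MOD to learn the dictionary and the sparse representation. MOD iteratively calls Regression to learn the dictionary and OMP to obtain the sparse representation.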
Main.m
- %% Main.m
- clc;
- clear;
- P = 512;
- N = 256;
- M = 128;
- K = 100;
- %% Data Generator Method 1
- % sparsity_X = 0.4;
- % Y = randi(10,M,P);
- % X = floor(sprand(N,P,sparsity_X)*10);
- %% Data Generator Method 2
- Y = randn(M,P); % Note: Y (M x P, M < P) should have full row rank, that is, rank(Y) = M
- X = randn(N,P);% initialization of X
- %% Main Iteration
- [D,X] = MOD(Y,X,K,1e-4);
MOD.m
- % @Function: Method of Optimal Directions (MOD) for 2D signals
- % For dictionary and sparse representation learning
- % @CreateTime: 2013-2-22
- % @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
- %
- % @Reference: From Sparse Solutions of Systems of Equations to
- % Sparse Modeling of Signals and Images
- function [ D , X ] = MOD( Y ,X ,K ,ErrorThreshold )
- % MOD Alternate the dictionary update (Regression) and sparse coding (OMP)
- % until the total absolute reconstruction error drops below ErrorThreshold.
- % Sample_Data is Y
- % Coefficient is X
- % Dictionary is D
- % sparsity is K
- disp('Run Method of directions');
- iteration_time = 1;
- err = ErrorThreshold+1; % named err to avoid shadowing the built-in error()
- while err>=ErrorThreshold
- disp(['iteration time = ' num2str(iteration_time)]);
- D = Regression(Y,X);
- X = OMP(Y,D,K);
- iteration_time = iteration_time+1;
- err = sum(sum(abs(Y-D*X))) % no semicolon: print the error after every iteration
- end
- end
OMP.m
- % @Function: Orthogonal Matching Pursuit of 2D signal
- % Learning Sparse Representation Given Dictionary
- % @CreateTime: 2013-2-21
- % @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
- %
- % @Reference: http://www.eee.hku.hk/~wsha/Freecode/freecode.htm
- function [ X ] = OMP( Y,D,K )
- % Y is the sample data to be recovered M*P
- % D is the dictionary M*N
- % X is the sparse coefficient N*P
- % K is the sparsity
- if nargin==2
- K = size(D,2);
- end;
- M = size(D,1);
- P = size(Y,2);
- N = size(D,2);
- m = K*2; % maximum number of iterations; note that up to 2*K atoms may be selected per sample
- for idx = 1:P
- % recover the idx-th column sample
- y = Y(:,idx);
- residual = y;
- Aug_D = [];
- D1 = D;
- for times = 1:m
- product = abs(D1'*residual);
- [~,pos] = max(product); % position of the largest projection coefficient
- Aug_D = [Aug_D, D1(:,pos)];
- D1(:,pos) = zeros(M,1); % zero out the selected column so it is not chosen again
- indx(times) = pos;
- Aug_x = Aug_D\y; % least squares that minimizes the residual, i.e. x = pinv(Aug_D)*y for full column rank
- residual = y - Aug_D*Aug_x;
- if sum(residual.^2)<1e-6
- break;
- end
- end
- temp = zeros(N,1);
- temp(indx(1:times)) = Aug_x;
- X(:,idx) = sparse(temp);
- end
- end
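To make the greedy selection easier to follow, here is a toy run of the same idea on a single signal. It is an illustrative sketch and not the OMP.m above: the dictionary is assumed to be orthonormal (built with `qr`) so that exact recovery is guaranteed, and the 3-sparse vector `x_true` is a made-up example. With an overcomplete dictionary, recovery depends on how correlated the atoms are.

```matlab
% Toy OMP run on one signal (illustrative sketch, not the author's OMP.m)
rng(1);                               % fixed seed for repeatability
n = 16;
[D,~] = qr(randn(n));                 % orthonormal atoms (assumption: makes recovery exact)
x_true = zeros(n,1);
x_true([2 7 11]) = [1.5; -2; 0.7];    % hypothetical 3-sparse coefficient vector
y = D*x_true;                         % observed signal
K = 3;                                % number of atoms to select

residual = y;
support = zeros(1,K);
for t = 1:K
    proj = abs(D'*residual);          % correlation of every atom with the residual
    proj(support(1:t-1)) = 0;         % never reselect an atom
    [~,pos] = max(proj);
    support(t) = pos;
    coef = D(:,support(1:t)) \ y;     % least squares on the atoms selected so far
    residual = y - D(:,support(1:t))*coef;
end

x_hat = zeros(n,1);
x_hat(support) = coef;                % scatter the coefficients back to full length
disp(sort(support));
disp(norm(x_hat - x_true));
```

Each pass picks the atom most correlated with the current residual and then refits all selected coefficients jointly, which is what distinguishes OMP from plain matching pursuit.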
Regression.m
- % @Function: Dictionary learning & Regression
- % Learning Dictionary Given Sparse Representation
- % @CreateTime: 2013-2-21
- % @Author: Rachel Zhang @ http://blog.csdn.net/abcjennifer
- %
- function [ D ] = Regression( Y,X )
- % Y is the sample data to be recovered M*P
- % D is the dictionary M*N
- % X is the sparse coefficient N*P
- % P>N>M
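- % Because X is a wide matrix, transpose the problem and solve D0 = argmin_D ||Y^T - X^T D^T||
- % This gives P equations in N unknowns for each solve;
- % each solve yields one column of D^T (one row of D), and there are M solves in total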
- Y = Y';
- X = X';
- P = size(Y,1);
- N = size(X,2);
- M = size(Y,2);
- D = zeros(N,M);
- for i = 1:M;
- y = Y(:,i);
- D(:,i) = regress(y,X); % regress is part of the Statistics and Machine Learning Toolbox
- end
- D = D';
- end
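The row-by-row regression above is equivalent to a single matrix expression, D = Y*X'*(X*X')^(-1). The sketch below checks this on small random data using only built-in operators instead of `regress`. The sizes, the seed, and the dense X are assumptions chosen for illustration, and the closed form assumes X*X' is invertible.

```matlab
% Sketch (not the author's code): the MOD dictionary update in closed form
rng(0);                        % fixed seed for repeatability
M = 8; N = 12; P = 40;         % small illustrative sizes with P > N > M
Y = randn(M,P);                % sample data, M x P
X = randn(N,P);                % coefficients, N x P (dense here for simplicity)

% Row-by-row least squares, mirroring the loop in Regression.m
D_loop = zeros(M,N);
for i = 1:M
    D_loop(i,:) = (X' \ Y(i,:)')';   % solve X' * d = Y(i,:)' in the least-squares sense
end

% The same update as one matrix expression: D = Y*X'*(X*X')^(-1)
D_closed = (Y*X') / (X*X');

% Both results should agree up to round-off
max_diff = max(abs(D_loop(:) - D_closed(:)));
disp(max_diff);
```

In practice `D = Y / X` gives the same least-squares solution without forming X*X' explicitly, which tends to be better conditioned.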
============================================================================
CODE2. KSVD
The KSVD function below was written by researchers abroad. It is cleanly organized, so I reproduce it here.
- function [Dictionary,output] = KSVD(...
- Data,... % an nXN matrix that contains N signals (Y), each of dimension n.
- param)
- % =========================================================================
- % K-SVD algorithm
- % =========================================================================
- % The K-SVD algorithm finds a dictionary for linear representation of
- % signals. Given a set of signals, it searches for the best dictionary that
- % can sparsely represent each signal. Detailed discussion on the algorithm
- % and possible applications can be found in "The K-SVD: An Algorithm for
- % Designing of Overcomplete Dictionaries for Sparse Representation", written
- % by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
- % On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
- % =========================================================================
- % INPUT ARGUMENTS:
- % Data an nXN matrix that contains N signals (Y), each of dimension n.
- % param structure that includes all required
- % parameters for the K-SVD execution.
- % Required fields are:
- % K, ... the number of dictionary elements to train
- % numIteration,... number of iterations to perform.
- % errorFlag... if =0, a fixed number of coefficients is
- % used for representation of each signal. If so, param.L must be
- % specified as the number of representing atom. if =1, arbitrary number
- % of atoms represent each signal, until a specific representation error
- % is reached. If so, param.errorGoal must be specified as the allowed
- % error.
- % preserveDCAtom... if =1 then the first atom in the dictionary
- % is set to be constant, and does not ever change. This
- % might be useful for working with natural
- % images (in this case, only param.K-1
- % atoms are trained).
- % (optional, see errorFlag) L,... % maximum coefficients to use in OMP coefficient calculations.
- % (optional, see errorFlag) errorGoal, ... % allowed representation error in representing each signal.
- % InitializationMethod,... method to initialize the dictionary, can
- % be one of the following arguments:
- % * 'DataElements' (initialization by the signals themselves), or:
- % * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
- % (optional, see InitializationMethod) initialDictionary,... % if the initialization method
- % is 'GivenMatrix', this is the matrix that will be used.
- % (optional) TrueDictionary, ... % if specified, in each
- % iteration the difference between this dictionary and the trained one
- % is measured and displayed.
- % displayProgress, ... if =1 progress information is displayed. If param.errorFlag==0,
- % the average representation error (RMSE) is displayed, while if
- % param.errorFlag==1, the average number of required coefficients for
- % representation of each signal is displayed.
- % =========================================================================
- % OUTPUT ARGUMENTS:
- % Dictionary The extracted dictionary of size nX(param.K).
- % output Struct that contains information about the current run. It may include the following fields:
- % CoefMatrix The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix).
- % ratio If the true dictionary was defined (in
- % synthetic experiments), this parameter holds a vector of length
- % param.numIteration that includes the detection ratios in each
- % iteration).
- % totalerr The total representation error after each
- % iteration (defined only if
- % param.displayProgress=1 and
- % param.errorFlag = 0)
- % numCoef A vector of length param.numIteration that
- % include the average number of coefficients required for representation
- % of each signal (in each iteration) (defined only if
- % param.displayProgress=1 and
- % param.errorFlag = 1)
- % =========================================================================
- if (~isfield(param,'displayProgress'))
- param.displayProgress = 0;
- end
- totalerr(1) = 99999;
- if (isfield(param,'errorFlag')==0)
- param.errorFlag = 0;
- end
- if (isfield(param,'TrueDictionary'))
- displayErrorWithTrueDictionary = 1;
- ErrorBetweenDictionaries = zeros(param.numIteration+1,1);
- ratio = zeros(param.numIteration+1,1);
- else
- displayErrorWithTrueDictionary = 0;
- ratio = 0;
- end
- if (param.preserveDCAtom>0)
- FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));
- else
- FixedDictionaryElement = [];
- end
- % coefficient calculation method is OMP with fixed number of coefficients
- if (size(Data,2) < param.K)
- disp('Size of data is smaller than the dictionary size. Trivial solution...');
- Dictionary = Data(:,1:size(Data,2));
- output = struct(); % assign the second output so the early return does not raise an error
- return;
- elseif (strcmp(param.InitializationMethod,'DataElements'))
- Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);
- elseif (strcmp(param.InitializationMethod,'GivenMatrix'))
- Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);
- end
- % reduce the components in Dictionary that are spanned by the fixed
- % elements
- if (param.preserveDCAtom)
- tmpMat = FixedDictionaryElement \ Dictionary;
- Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
- end
- %normalize the dictionary.
- Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
- Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.
- totalErr = zeros(1,param.numIteration);
- % the K-SVD algorithm starts here.
- for iterNum = 1:param.numIteration
- % find the coefficients
- if (param.errorFlag==0)
- %CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);
- CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);
- else
- %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);
- CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);
- param.L = 1;
- end
- replacedVectorCounter = 0;
- rPerm = randperm(size(Dictionary,2));
- for j = rPerm
- [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...
- [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...
- CoefMatrix ,param.L);
- Dictionary(:,j) = betterDictionaryElement;
- if (param.preserveDCAtom)
- tmpCoef = FixedDictionaryElement\betterDictionaryElement;
- Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
- Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
- end
- replacedVectorCounter = replacedVectorCounter+addedNewVector;
- end
- if (iterNum>1 && param.displayProgress)
- if (param.errorFlag==0)
- output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/prod(size(Data)));
- disp(['Iteration ',num2str(iterNum),' Total error is: ',num2str(output.totalerr(iterNum-1))]);
- else
- output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
- disp(['Iteration ',num2str(iterNum),' Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);
- end
- end
- if (displayErrorWithTrueDictionary )
- [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);
- disp(strcat(['Iteration ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));
- output.ratio = ratio;
- end
- Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);
- if (isfield(param,'waitBarHandle'))
- waitbar(iterNum/param.counterForWaitBar);
- end
- end
- output.CoefMatrix = CoefMatrix;
- Dictionary = [FixedDictionaryElement,Dictionary];
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- % findBetterDictionaryElement
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- function [betterDictionaryElement,CoefMatrix,NewVectorAdded] = I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed)
- if (length(who('numCoefUsed'))==0)
- numCoefUsed = 1;
- end
- relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that uses the j'th dictionary element.
- if (length(relevantDataIndices)<1) %(length(relevantDataIndices)==0)
- ErrorMat = Data-Dictionary*CoefMatrix;
- ErrorNormVec = sum(ErrorMat.^2);
- [d,i] = max(ErrorNormVec);
- betterDictionaryElement = Data(:,i);%ErrorMat(:,i); %
- betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
- betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));
- CoefMatrix(j,:) = 0;
- NewVectorAdded = 1;
- return;
- end
- NewVectorAdded = 0;
- tmpCoefMatrix = CoefMatrix(:,relevantDataIndices);
- tmpCoefMatrix(j,:) = 0;% the coefficients of the element we now improve are not relevant.
- errors =(Data(:,relevantDataIndices) - Dictionary*tmpCoefMatrix); % vector of errors that we want to minimize with the new element
- % % the better dictionary element and the values of beta are found using svd.
- % % This is because we would like to minimize || errors - beta*element ||_F^2.
- % % that is, to approximate the matrix 'errors' with a rank-one matrix. This
- % % is done using the largest singular value.
- [betterDictionaryElement,singularValue,betaVector] = svds(errors,1);
- CoefMatrix(j,relevantDataIndices) = singularValue*betaVector';% *signOfFirstElem
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- % findDistanseBetweenDictionaries
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- function [ratio,totalDistances] = I_findDistanseBetweenDictionaries(original,new)
- % first, flip signs so that every column starts with a positive value.
- catchCounter = 0;
- totalDistances = 0;
- for i = 1:size(new,2)
- new(:,i) = sign(new(1,i))*new(:,i);
- end
- for i = 1:size(original,2)
- d = sign(original(1,i))*original(:,i);
- distances =sum ( (new-repmat(d,1,size(new,2))).^2);
- [minValue,index] = min(distances);
- errorOfElement = 1-abs(new(:,index)'*d);
- totalDistances = totalDistances+errorOfElement;
- catchCounter = catchCounter+(errorOfElement<0.01);
- end
- ratio = 100*catchCounter/size(original,2);
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- % I_clearDictionary
- %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
- function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)
- T2 = 0.99;
- T1 = 3;
- K=size(Dictionary,2);
- Er=sum((Data-Dictionary*CoefMatrix).^2,1); % remove identical atoms
- G=Dictionary'*Dictionary; G = G-diag(diag(G));
- for jj=1:1:K,
- if max(G(jj,:))>T2 || length(find(abs(CoefMatrix(jj,:))>1e-7))<=T1 ,
- [val,pos]=max(Er);
- Er(pos(1))=0;
- Dictionary(:,jj)=Data(:,pos(1))/norm(Data(:,pos(1)));
- G=Dictionary'*Dictionary; G = G-diag(diag(G));
- end;
- end;
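The core of I_findBetterDictionaryElement is a rank-one approximation of the restricted error matrix. The sketch below isolates that step on a random matrix standing in for `errors`. It uses the full `svd` rather than `svds(errors,1)` so that the leftover energy can be compared against the remaining singular values; the matrix size and seed are arbitrary assumptions.

```matlab
% Rank-one approximation step of the K-SVD atom update (illustrative sketch)
rng(2);                              % fixed seed for repeatability
E = randn(10,6);                     % stand-in for the restricted error matrix 'errors'
[U,S,V] = svd(E,'econ');
d = U(:,1);                          % updated atom, unit norm by construction
beta = S(1,1)*V(:,1)';               % updated coefficient row for the signals using this atom
err_rank1 = norm(E - d*beta,'fro')^2;
s = diag(S);
err_expected = sum(s(2:end).^2);     % energy left in the other singular values
disp([err_rank1, err_expected]);
```

Because the best rank-one fit is given by the leading singular pair, the atom and its coefficients are updated together, and the support of the coefficient row is unchanged.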