MOD算法学习

 

MOD 之"Hello World"

标签: dictionary learningKSVDMethod of DirectionOMPRegressionregression
  10170人阅读  评论(30)  收藏  举报
  分类:
 
   

首先声明,MOD不是取模函数!MOD是字典学习和sparse coding的一种方法… 最近在看KSVD,其简化版就是MOD(method of directions),这么说吧,KSVD和MOD的优化目标函数是相同的,MOD之所以可以称作KSVD的简化版是因为KSVD在MOD的基础上做了顺序更新列的优化。关于KSVD和MOD的理论知识请见下面我给出的一页note和referenc中的paper。本文主要给出其基本思想及我的代码,已经过测试,如有bug欢迎提出。


Reference

<<From Sparse Solutions of Systems of Equations to Sparse Modeling of Signals and Images>>, Page 68~70



KSVD & MOD's principle & objective function 

Principle:

简单来说,其优化就是一个OMP(orthogonal matching pursuit)与Regression的迭代过程,因此代码包括一个OMP.m, regression.m.


Objective Function & the variation from MOD to KSVD:




Code

CODE1. MOD

运行Main(Main中通过MOD)学习字典和稀疏表示,MOD迭代调用Regression学习字典,调用和OMP获得sparse representation.


Main.m

[cpp]  view plain  copy
  1. %% Main.m  
  2. clc;  
  3. clear;  
  4. P = 512;  
  5. N = 256;  
  6. M = 128;  
  7. K = 100;  
  8.   
  9. %% Data Generator Method 1  
  10. % sparsity_X = 0.4;  
  11. % Y = randi(10,M,P);  
  12. % X = floor(sprand(N,P,sparsity_X)*10);  
  13.   
  14. %% Data Generator Method 2  
  15. Y = randn(M,P);%Notice that Y should be full rank, that is, rank(Y) = N  
  16. X = randn(N,P);% initialization of X  
  17.   
  18. %% Main Iteration  
  19. [D,X] = MOD(Y,X,K,1e-4);  



MOD.m

[cpp]  view plain  copy
  1. %   @Function: Method Of Dirction of 2D signal  
  2. %   For dictionary and sparse representation learning  
  3. %   @CreateTime: 2013-2-22  
  4. %   @Author: Rachel Zhang  @  http://blog.csdn.net/abcjennifer  
  5. %     
  6. %   @Reference: From Sparse Solutions of Systems of Equations to   
  7. %   Sparse Modeling of Signals and Images  
  8.   
  9. function [ D , X ] = MOD( Y ,X ,K ,ErrorThreshold )  
  10. %MOD Summary of this function goes here  
  11. %   Detailed explanation goes here  
  12. %   Sample_Data is Y  
  13. %   Coefficient is X  
  14. %   Dictionary is D  
  15. %   sparsity is K  
  16.   
  17. disp('Run Method of directions');  
  18. iteration_time = 1;  
  19. error = ErrorThreshold+1;  
  20.   
  21.   
  22. while error>=ErrorThreshold;  
  23.     disp(['iteration time = ' num2str(iteration_time)]);  
  24.     D = Regression(Y,X);  
  25.     X = OMP(Y,D,K);  
  26.     iteration_time = iteration_time+1;  
  27.     error = sum(sum(abs(Y-D*X)))  
  28. end  
  29.   
  30. end  



OMP.m

[cpp]  view plain  copy
  1. %   @Function: Orthogonal Matching Pursuit of 2D signal  
  2. %   Learning Sparse Representation Given Dictionary  
  3. %   @CreateTime: 2013-2-21  
  4. %   @Author: Rachel Zhang  @  http://blog.csdn.net/abcjennifer  
  5. %     
  6. %   @Reference: http://www.eee.hku.hk/~wsha/Freecode/freecode.htm     
  7.   
  8. function [ X ] = OMP( Y,D,K )  
  9. % Y is the sample data to be recovered M*P  
  10. % D is the dictionary M*N  
  11. % X is the sparse coefficient N*P  
  12. % K is the sparsity  
  13.   
  14. if nargin==2  
  15.     K = size(D,2);  
  16. end;  
  17.   
  18. M = size(D,1);  
  19. P = size(Y,2);  
  20. N = size(D,2);  
  21. m = K*2;  % execute iterations  
  22.   
  23. for idx = 1:P  
  24.     % recover the idx-th column sample  
  25.     y = Y(:,idx);  
  26.     residual = y;  
  27.     Aug_D = [];  
  28.     D1 = D;  
  29.       
  30.     for times = 1:m;  
  31.         product = abs(D1'*residual);  
  32.         [~,pos] = max(product); %  最大投影系数对应的位置  
  33.         Aug_D = [Aug_D, D1(:,pos)];  
  34.         D1(:,pos) = zeros(M,1);    %去掉选中的列  
  35.         indx(times) = pos;  
  36.         Aug_x = (Aug_D'*Aug_D)^-1*Aug_D'*y; %  最小二乘,使残差最小,i.e. x = pinv(Aug_D)*y  
  37.         residual = y - Aug_D*Aug_x;  
  38.           
  39.         if sum(residual.^2)<1e-6  
  40.             break;  
  41.         end  
  42.     end  
  43.     temp = zeros(N,1);  
  44.     temp(indx(1:times)) = Aug_x;  
  45.     X(:,idx) = sparse(temp);  
  46. end  
  47. end  


Regression.m

[cpp]  view plain  copy
  1. %   @Function: Dictionary learning & Regression  
  2. %   Learning Dictionary Given Sparse Representation  
  3. %   @CreateTime: 2013-2-21  
  4. %   @Author: Rachel Zhang  @  http://blog.csdn.net/abcjennifer  
  5. %     
  6. function [ D ] = Regression( Y,X )  
  7. % Y is the sample data to be recovered M*P  
  8. % D is the dictionary M*N  
  9. % X is the sparse coefficient N*P  
  10. % P>N>M  
  11.   
  12. %由于X是扁矩阵,需要转置求D0 = min(D) ||Y^T-X^TD^T||  
  13. %这样就是N个未知数,P个方程去求解;  
  14. %每次解得D中的一列,共解M次  
  15.   
  16. Y = Y';  
  17. X = X';  
  18. P = size(Y,1);  
  19. N = size(X,2);  
  20. M = size(Y,2);  
  21. D = zeros(N,M);  
  22.   
  23. for i = 1:M;  
  24.     y = Y(:,i);  
  25.     D(:,i) = regress(y,X);  
  26. end  
  27. D = D';  
  28. end  



============================================================================


CODE2. KSVD

ksvd函数代码是国外的人写的,很规矩,这里贴过来。

[cpp]  view plain  copy
  1. function [Dictionary,output] = KSVD(...  
  2.     Data,... % an nXN matrix that contins N signals (Y), each of dimension n.  
  3.     param)  
  4. % =========================================================================  
  5. %                          K-SVD algorithm  
  6. % =========================================================================  
  7. % The K-SVD algorithm finds a dictionary for linear representation of  
  8. % signals. Given a set of signals, it searches for the best dictionary that  
  9. % can sparsely represent each signal. Detailed discussion on the algorithm  
  10. % and possible applications can be found in "The K-SVD: An Algorithm for   
  11. % Designing of Overcomplete Dictionaries for Sparse Representation", written  
  12. % by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.   
  13. % On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.   
  14. % =========================================================================  
  15. % INPUT ARGUMENTS:  
  16. % Data                         an nXN matrix that contins N signals (Y), each of dimension n.   
  17. % param                        structure that includes all required  
  18. %                                 parameters for the K-SVD execution.  
  19. %                                 Required fields are:  
  20. %    K, ...                    the number of dictionary elements to train  
  21. %    numIteration,...          number of iterations to perform.  
  22. %    errorFlag...              if =0, a fix number of coefficients is  
  23. %                                 used for representation of each signal. If so, param.L must be  
  24. %                                 specified as the number of representing atom. if =1, arbitrary number  
  25. %                                 of atoms represent each signal, until a specific representation error  
  26. %                                 is reached. If so, param.errorGoal must be specified as the allowed  
  27. %                                 error.  
  28. %    preserveDCAtom...         if =1 then the first atom in the dictionary  
  29. %                                 is set to be constant, and does not ever change. This  
  30. %                                 might be useful for working with natural  
  31. %                                 images (in this case, only param.K-1  
  32. %                                 atoms are trained).  
  33. %    (optional, see errorFlag) L,...                 % maximum coefficients to use in OMP coefficient calculations.  
  34. %    (optional, see errorFlag) errorGoal, ...        % allowed representation error in representing each signal.  
  35. %    InitializationMethod,...  mehtod to initialize the dictionary, can  
  36. %                                 be one of the following arguments:   
  37. %                                 * 'DataElements' (initialization by the signals themselves), or:   
  38. %                                 * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).  
  39. %    (optional, see InitializationMethod) initialDictionary,...      % if the initialization method   
  40. %                                 is 'GivenMatrix'this is the matrix that will be used.  
  41. %    (optional) TrueDictionary, ...        % if specified, in each  
  42. %                                 iteration the difference between this dictionary and the trained one  
  43. %                                 is measured and displayed.  
  44. %    displayProgress, ...      if =1 progress information is displyed. If param.errorFlag==0,   
  45. %                                 the average repersentation error (RMSE) is displayed, while if   
  46. %                                 param.errorFlag==1, the average number of required coefficients for   
  47. %                                 representation of each signal is displayed.  
  48. % =========================================================================  
  49. % OUTPUT ARGUMENTS:  
  50. %  Dictionary                  The extracted dictionary of size nX(param.K).  
  51. %  output                      Struct that contains information about the current run. It may include the following fields:  
  52. %    CoefMatrix                  The final coefficients matrix (it should hold that Data equals approximately Dictionary*output.CoefMatrix.  
  53. %    ratio                       If the true dictionary was defined (in  
  54. %                                synthetic experiments), this parameter holds a vector of length  
  55. %                                param.numIteration that includes the detection ratios in each  
  56. %                                iteration).  
  57. %    totalerr                    The total representation error after each  
  58. %                                iteration (defined only if  
  59. %                                param.displayProgress=1 and  
  60. %                                param.errorFlag = 0)  
  61. %    numCoef                     A vector of length param.numIteration that  
  62. %                                include the average number of coefficients required for representation  
  63. %                                of each signal (in each iteration) (defined only if  
  64. %                                param.displayProgress=1 and  
  65. %                                param.errorFlag = 1)  
  66. % =========================================================================  
  67.   
  68. if (~isfield(param,'displayProgress'))  
  69.     param.displayProgress = 0;  
  70. end  
  71. totalerr(1) = 99999;  
  72. if (isfield(param,'errorFlag')==0)  
  73.     param.errorFlag = 0;  
  74. end  
  75.   
  76. if (isfield(param,'TrueDictionary'))  
  77.     displayErrorWithTrueDictionary = 1;  
  78.     ErrorBetweenDictionaries = zeros(param.numIteration+1,1);  
  79.     ratio = zeros(param.numIteration+1,1);  
  80. else  
  81.     displayErrorWithTrueDictionary = 0;  
  82.     ratio = 0;  
  83. end  
  84. if (param.preserveDCAtom>0)  
  85.     FixedDictionaryElement(1:size(Data,1),1) = 1/sqrt(size(Data,1));  
  86. else  
  87.     FixedDictionaryElement = [];  
  88. end  
  89. % coefficient calculation method is OMP with fixed number of coefficients  
  90.   
  91. if (size(Data,2) < param.K)  
  92.     disp('Size of data is smaller than the dictionary size. Trivial solution...');  
  93.     Dictionary = Data(:,1:size(Data,2));  
  94.     return;  
  95. elseif (strcmp(param.InitializationMethod,'DataElements'))  
  96.     Dictionary(:,1:param.K-param.preserveDCAtom) = Data(:,1:param.K-param.preserveDCAtom);  
  97. elseif (strcmp(param.InitializationMethod,'GivenMatrix'))  
  98.     Dictionary(:,1:param.K-param.preserveDCAtom) = param.initialDictionary(:,1:param.K-param.preserveDCAtom);  
  99. end  
  100. % reduce the components in Dictionary that are spanned by the fixed  
  101. % elements  
  102. if (param.preserveDCAtom)  
  103.     tmpMat = FixedDictionaryElement \ Dictionary;  
  104.     Dictionary = Dictionary - FixedDictionaryElement*tmpMat;  
  105. end  
  106. %normalize the dictionary.  
  107. Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));  
  108. Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)),size(Dictionary,1),1); % multiply in the sign of the first element.  
  109. totalErr = zeros(1,param.numIteration);  
  110.   
  111. % the K-SVD algorithm starts here.  
  112.   
  113. for iterNum = 1:param.numIteration  
  114.     % find the coefficients  
  115.     if (param.errorFlag==0)  
  116.         %CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary],param.L);  
  117.         CoefMatrix = OMP([FixedDictionaryElement,Dictionary],Data, param.L);  
  118.     else   
  119.         %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary],param.errorGoal);  
  120.         CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary],Data, param.errorGoal);  
  121.         param.L = 1;  
  122.     end  
  123.       
  124.     replacedVectorCounter = 0;  
  125.     rPerm = randperm(size(Dictionary,2));  
  126.     for j = rPerm  
  127.         [betterDictionaryElement,CoefMatrix,addedNewVector] = I_findBetterDictionaryElement(Data,...  
  128.             [FixedDictionaryElement,Dictionary],j+size(FixedDictionaryElement,2),...  
  129.             CoefMatrix ,param.L);  
  130.         Dictionary(:,j) = betterDictionaryElement;  
  131.         if (param.preserveDCAtom)  
  132.             tmpCoef = FixedDictionaryElement\betterDictionaryElement;  
  133.             Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;  
  134.             Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));  
  135.         end  
  136.         replacedVectorCounter = replacedVectorCounter+addedNewVector;  
  137.     end  
  138.   
  139.     if (iterNum>1 & param.displayProgress)  
  140.         if (param.errorFlag==0)  
  141.             output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/prod(size(Data)));  
  142.             disp(['Iteration   ',num2str(iterNum),'   Total error is: ',num2str(output.totalerr(iterNum-1))]);  
  143.         else  
  144.             output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);  
  145.             disp(['Iteration   ',num2str(iterNum),'   Average number of coefficients: ',num2str(output.numCoef(iterNum-1))]);  
  146.         end  
  147.     end  
  148.     if (displayErrorWithTrueDictionary )   
  149.         [ratio(iterNum+1),ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary,Dictionary);  
  150.         disp(strcat(['Iteration  ', num2str(iterNum),' ratio of restored elements: ',num2str(ratio(iterNum+1))]));  
  151.         output.ratio = ratio;  
  152.     end  
  153.     Dictionary = I_clearDictionary(Dictionary,CoefMatrix(size(FixedDictionaryElement,2)+1:end,:),Data);  
  154.       
  155.     if (isfield(param,'waitBarHandle'))  
  156.         waitbar(iterNum/param.counterForWaitBar);  
  157.     end  
  158. end  
  159.   
  160. output.CoefMatrix = CoefMatrix;  
  161. Dictionary = [FixedDictionaryElement,Dictionary];  
  162. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  163. %  findBetterDictionaryElement  
  164. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  165.   
  166. function [betterDictionaryElement,CoefMatrix,NewVectorAdded] = I_findBetterDictionaryElement(Data,Dictionary,j,CoefMatrix,numCoefUsed)  
  167. if (length(who('numCoefUsed'))==0)  
  168.     numCoefUsed = 1;  
  169. end  
  170. relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that uses the j'th dictionary element.  
  171. if (length(relevantDataIndices)<1) %(length(relevantDataIndices)==0)  
  172.     ErrorMat = Data-Dictionary*CoefMatrix;  
  173.     ErrorNormVec = sum(ErrorMat.^2);  
  174.     [d,i] = max(ErrorNormVec);  
  175.     betterDictionaryElement = Data(:,i);%ErrorMat(:,i); %  
  176.     betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);  
  177.     betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));  
  178.     CoefMatrix(j,:) = 0;  
  179.     NewVectorAdded = 1;  
  180.     return;  
  181. end  
  182.   
  183. NewVectorAdded = 0;  
  184. tmpCoefMatrix = CoefMatrix(:,relevantDataIndices);   
  185. tmpCoefMatrix(j,:) = 0;% the coeffitients of the element we now improve are not relevant.  
  186. errors =(Data(:,relevantDataIndices) - Dictionary*tmpCoefMatrix); % vector of errors that we want to minimize with the new element  
  187. % % the better dictionary element and the values of beta are found using svd.  
  188. % % This is because we would like to minimize || errors - beta*element ||_F^2.   
  189. % % that is, to approximate the matrix 'errors' with a one-rank matrix. This  
  190. % % is done using the largest singular value.  
  191. [betterDictionaryElement,singularValue,betaVector] = svds(errors,1);  
  192. CoefMatrix(j,relevantDataIndices) = singularValue*betaVector';% *signOfFirstElem  
  193.   
  194. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  195. %  findDistanseBetweenDictionaries  
  196. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  197. function [ratio,totalDistances] = I_findDistanseBetweenDictionaries(original,new)  
  198. % first, all the column in oiginal starts with positive values.  
  199. catchCounter = 0;  
  200. totalDistances = 0;  
  201. for i = 1:size(new,2)  
  202.     new(:,i) = sign(new(1,i))*new(:,i);  
  203. end  
  204. for i = 1:size(original,2)  
  205.     d = sign(original(1,i))*original(:,i);  
  206.     distances =sum ( (new-repmat(d,1,size(new,2))).^2);  
  207.     [minValue,index] = min(distances);  
  208.     errorOfElement = 1-abs(new(:,index)'*d);  
  209.     totalDistances = totalDistances+errorOfElement;  
  210.     catchCounter = catchCounter+(errorOfElement<0.01);  
  211. end  
  212. ratio = 100*catchCounter/size(original,2);  
  213.   
  214. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  215. %  I_clearDictionary  
  216. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
  217. function Dictionary = I_clearDictionary(Dictionary,CoefMatrix,Data)  
  218. T2 = 0.99;  
  219. T1 = 3;  
  220. K=size(Dictionary,2);  
  221. Er=sum((Data-Dictionary*CoefMatrix).^2,1); % remove identical atoms  
  222. G=Dictionary'*Dictionary; G = G-diag(diag(G));  
  223. for jj=1:1:K,  
  224.     if max(G(jj,:))>T2 | length(find(abs(CoefMatrix(jj,:))>1e-7))<=T1 ,  
  225.         [val,pos]=max(Er);  
  226.         Er(pos(1))=0;  
  227.         Dictionary(:,jj)=Data(:,pos(1))/norm(Data(:,pos(1)));  
  228.         G=Dictionary'*Dictionary; G = G-diag(diag(G));  
  229.     end;  
  230. end;  
  • 2
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
RSA是一种非对称加密算法,由三位科学家Rivest、Shamir和Adleman共同发明,在加密和数字签名领域发挥着重要作用。RSA算法基于数论中的两个重要难题:大整数分解和模幂运算。 RSA算法的核心概念是公钥和私钥。在加密过程中,首先需要生成一对密钥,其中一个是公钥,可以公开给其他人使用,而另一个是私钥,必须保密。通过公钥可以将信息进行加密,而只有使用私钥才能解密。 RSA算法的加密过程如下:选择两个大素数p和q,并计算它们的乘积n=p*q作为所需的大整数。再选择一个与(p-1)*(q-1)互质的正整数e作为公钥,其中1 < e < (p-1)*(q-1)。然后计算d,满足(d*e) mod ((p-1)*(q-1)) = 1,并将d作为私钥。公钥对应着(n, e),私钥对应着(n, d)。 对于明文M,加密后得到密文C,加密过程为C = M^e mod n。解密过程为M = C^d mod n。由于大整数分解问题的复杂性,只有获得私钥才能成功解密,保护了通信的安全性。 RSA算法广泛应用于计算机网络和电子商务中,例如在网站上进行数据传输过程中,使用RSA加密算法保护数据的机密性和完整性,确保数据不被窃取或篡改。 需要注意的是,尽管RSA算法在安全性上相对较好,但其加解密过程消耗较大的计算资源,在处理大量数据时效率可能较低。因此,在实际应用中,常常将RSA与其他加密算法结合使用,以平衡安全性和效率的要求。 总之,RSA算法作为一种非对称加密算法,通过公钥和私钥的配对实现信息的加密和解密。它在数据安全领域的应用广泛,为保护通信和数据的安全做出了重要贡献。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值