关闭

SVM算法实现(一)

491人阅读 评论(0) 收藏 举报
分类:

关键字(keywords):SVM 支持向量机 SMO算法 实现 机器学习    

 

      如果对SVM原理不是很懂的,可以先看一下入门的视频,对帮助理解很有用的,然后再深入一点可以看看这几篇入门文章,作者写得挺详细,看完以后SVM的基础就了解得差不多了,再然后买本《支持向量机导论》作者是Nello Cristianini 和 John Shawe-Taylor,电子工业出版社的。然后把书本后面的那个SMO算法实现就基本上弄懂了SVM是怎么一回事,最后再编写一个SVM库出来,比如说像libsvm等工具使用,呵呵,差不多就这样。这些是我学习SVM的整个过程,也算是经验吧。

      下面是SVM的简化版SMO算法,我将结合Java代码来解释一下整个SVM的学习训练过程,即所谓的train训练过程。那么什么是SMO算法呢?

 SMO算法的目的无非是找出一个函数f(x),这个函数能让我们把输入的数据x进行分类。既然是分类肯定需要一个评判的标准,比如分出来有两种情况A和B,那么怎么样才能说x是属于A类的,或不是B类的呢?就是需要有个边界,就好像两个国家一样有边界,如果边界越明显,则就越容易区分,因此,我们的目标是最大化边界的宽度,使得非常容易的区分是A类还是B类。

 在SVM中,要最大化边界则需要最小化这个数值:

 

 

w:是参量,值越大边界越明显 
C代表惩罚系数,即如果某个x是属于某一类,但是它偏离了该类,跑到边界上后者其他类的地方去了,C越大表明越不想放弃这个点,边界就会缩小
代表:松散变量
但问题似乎还不好解,又因为SVM是一个凸二次规划问题,凸二次规划问题有最优解,于是问题转换成下列形式(KKT条件):

 …………(1)

这里的ai是拉格朗日乘子(问题通过拉格朗日乘法数来求解)
对于(a)的情况,表明ai是正常分类,在边界内部(我们知道正确分类的点yi*f(xi)>=0)
对于(b)的情况,表明了ai是支持向量,在边界上
对于(c)的情况,表明了ai是在两条边界之间
而最优解需要满足KKT条件,即满足(a)(b)(c)条件都满足
以下几种情况出现将会出现不满足:

yiui<=1但是ai<C则是不满足的,而原本ai=C
yiui>=1但是ai>0则是不满足的而原本ai=0
yiui=1但是ai=0或者ai=C则表明不满足的,而原本应该是0<ai<C
所以要找出不满足KKT的这些ai,并更新这些ai,但这些ai又受到另外一个约束,即

 

 

因此,我们通过另一个方法,即同时更新ai和aj,满足以下等式

 

就能保证和为0的约束。

 

利用yiai+yjaj=常数,消去ai,可得到一个关于单变量aj的一个凸二次规划问题,不考虑其约束0<=aj<=C,可以得其解为:

 

 ………………………………………(2)

这里………………(3)

表示旧值,然后考虑约束0<=aj<=C可得到a的解析解为:

…………(4)

对于

那么如何求得ai和aj呢?

对于ai,即第一个乘子,可以通过刚刚说的那几种不满足KKT的条件来找,第二个乘子aj可以找满足条件 

 

…………………………………………………………………………(5)

b的更新:

在满足条件:下更新b。……………(6)

 

最后更新所有ai,y和b,这样模型就出来了,然后通过函数:

 ……………………………………………………(7)

输入是x,是一个数组,组中每一个值表示一个特征。

输出是A类还是B类。(正类还是负类)

 

以下是主要的代码段:

 

  1. /* 
  2.  * 默认输入参数值 
  3.  * C: regularization parameter 
  4.  * tol: numerical tolerance 
  5.  * max passes 
  6.  */  
  7. double C = 1//对不在界内的惩罚因子  
  8. double tol = 0.01;//容忍极限值  
  9. int maxPasses = 5//表示没有改变拉格朗日乘子的最多迭代次数  
  10.   
  11. /* 
  12.  * 初始化a[], b, passes  
  13.  */  
  14.   
  15. double a[] = new double[x.length];//拉格朗日乘子  
  16. this.a = a;  
  17.   
  18. //将乘子初始化为0  
  19. for (int i = 0; i < x.length; i++) {  
  20.     a[i] = 0;  
  21. }  
  22. int passes = 0;  
  23.   
  24.   
  25. while (passes < maxPasses) {  
  26.     //表示改变乘子的次数(基本上是成对改变的)  
  27.     int num_changed_alphas = 0;  
  28.     for (int i = 0; i < x.length; i++) {  
  29.         //表示特定阶段由a和b所决定的输出与真实yi的误差  
  30.         //参照公式(7)  
  31.         double Ei = getE(i);  
  32.         /* 
  33.          * 把违背KKT条件的ai作为第一个 
  34.          * 满足KKT条件的情况是: 
  35.          * yi*f(i) >= 1 and alpha == 0 (正确分类) 
  36.          * yi*f(i) == 1 and 0<alpha < C (在边界上的支持向量) 
  37.          * yi*f(i) <= 1 and alpha == C (在边界之间) 
  38.          *  
  39.          *  
  40.          *  
  41.          * ri = y[i] * Ei = y[i] * f(i) - y[i]^2 >= 0 
  42.          * 如果ri < 0并且alpha < C 则违反了KKT条件 
  43.          * 因为原本ri < 0 应该对应的是alpha = C 
  44.          * 同理,ri > 0并且alpha > 0则违反了KKT条件 
  45.          * 因为原本ri > 0对应的应该是alpha =0 
  46.          */  
  47.         if ((y[i] * Ei < -tol && a[i] < C) ||  
  48.             (y[i] * Ei > tol && a[i] > 0))   
  49.         {  
  50.             /* 
  51.              * ui*yi=1边界上的点 0 < a[i] < C 
  52.              * 找MAX|E1 - E2| 
  53.              */  
  54.             int j;  
  55.             /* 
  56.              * boundAlpha表示x点处于边界上所对应的 
  57.              * 拉格朗日乘子a的集合 
  58.              */  
  59.             if (this.boundAlpha.size() > 0) {  
  60.                 //参照公式(5)  
  61.                 j = findMax(Ei, this.boundAlpha);  
  62.             } else   
  63.                 //如果边界上没有,就随便选一个j != i的aj  
  64.                 j = RandomSelect(i);  
  65.               
  66.             double Ej = getE(j);  
  67.               
  68.             //保存当前的ai和aj  
  69.             double oldAi = a[i];  
  70.             double oldAj = a[j];  
  71.               
  72.             /* 
  73.              * 计算乘子的范围U, V 
  74.              * 参考公式(4) 
  75.              */  
  76.             double L, H;  
  77.             if (y[i] != y[j]) {  
  78.                 L = Math.max(0, a[j] - a[i]);  
  79.                 H = Math.min(C, C - a[i] + a[j]);  
  80.             } else {  
  81.                 L = Math.max(0, a[i] + a[j] - C);  
  82.                 H = Math.min(0, a[i] + a[j]);  
  83.             }  
  84.               
  85.               
  86.             /* 
  87.              * 如果eta等于0或者大于0 则表明a最优值应该在L或者U上 
  88.              */  
  89.             double eta = 2 * k(i, j) - k(i, i) - k(j, j);//公式(3)  
  90.               
  91.             if (eta >= 0)  
  92.                 continue;  
  93.               
  94.             a[j] = a[j] - y[j] * (Ei - Ej)/ eta;//公式(2)  
  95.             if (0 < a[j] && a[j] < C)  
  96.                 this.boundAlpha.add(j);  
  97.               
  98.             if (a[j] < L)   
  99.                 a[j] = L;  
  100.             else if (a[j] > H)   
  101.                 a[j] = H;  
  102.               
  103.             if (Math.abs(a[j] - oldAj) < 1e-5)  
  104.                 continue;  
  105.             a[i] = a[i] + y[i] * y[j] * (oldAj - a[j]);  
  106.             if (0 < a[i] && a[i] < C)  
  107.                 this.boundAlpha.add(i);  
  108.               
  109.             /* 
  110.              * 计算b1, b2 
  111.              * 参照公式(6) 
  112.              */  
  113.             double b1 = b - Ei - y[i] * (a[i] - oldAi) * k(i, i) - y[j] * (a[j] - oldAj) * k(i, j);  
  114.             double b2 = b - Ej - y[i] * (a[i] - oldAi) * k(i, j) - y[j] * (a[j] - oldAj) * k(j, j);  
  115.               
  116.             if (0 < a[i] && a[i] < C)  
  117.                 b = b1;  
  118.             else if (0 < a[j] && a[j] < C)  
  119.                 b = b2;  
  120.             else   
  121.                 b = (b1 + b2) / 2;  
  122.               
  123.             num_changed_alphas = num_changed_alphas + 1;  
  124.         }  
  125.     }  
  126.     if (num_changed_alphas == 0) {  
  127.         passes++;  
  128.     } else   
  129.         passes = 0;  
  130. }  
  131.   
  132. return new SVMModel(a, y, b);  

运行后的结果还算可以吧,测试数据主要是用了libsvm的heart_scale的数据。

预测的正确率达到73%以上。

如果我把核函数从线性的改为基于RBF将会更好点。

最后,说到SVM算法实现包,应该有很多,包括svm light,libsvm,有matlab本身自带的svm工具包等。

 

 

 

另外,完整的代码,我将上传到CSDN下载地址上提供下载。

点击这里下载

 

如理解有误敬请指正!谢谢!

我的邮箱:chen-hongqin@163.com

我的其他博客:

百度:http://hi.baidu.com/futrueboy/home

javaeye:http://futrueboy.javaeye.com/

CSDN: http://blog.csdn.net/techq

 


网址:http://blog.csdn.net/techq/article/details/6171688

0
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:196900次
    • 积分:2524
    • 等级:
    • 排名:第15091名
    • 原创:3篇
    • 转载:325篇
    • 译文:0篇
    • 评论:8条
    最新评论