libsvm 代码分析

原创 2013年12月03日 13:34:30
void Solve(int l, QMatrix Q, double[] p_, byte[] y_,
  double[] alpha_, double Cp, double Cn, double eps, SolutionInfo si, int shrinking)
{

this.l = l;
this.Q = Q;
QD = Q.get_QD();
p = (double[])p_.clone();
y = (byte[])y_.clone();
alpha = (double[])alpha_.clone();
this.Cp = Cp;
this.Cn = Cn;
this.eps = eps;//停止迭代可以容忍的误差
this.unshrink = false;//对内存的处理


// initialize alpha_status
{
alpha_status = new byte[l];//新建一个长度为L的byte数组
for(int i=0;i<l;i++)
update_alpha_status(i);//调用改方法进行初始化,根据alpha的值,判定该样本属于是否属于支持向量support vector,分对的点,和分错的点
}


// initialize active set (for shrinking)
{
active_set = new int[l];
for(int i=0;i<l;i++)
active_set[i] = i;
active_size = l;
}


// initialize gradient
{
G = new double[l];
G_bar = new double[l];
int i;
for(i=0;i<l;i++)
{
G[i] = p[i];
G_bar[i] = 0;
}
for(i=0;i<l;i++)
if(!is_lower_bound(i))
{
float[] Q_i = Q.get_Q(i,l);
double alpha_i = alpha[i];
int j;
for(j=0;j<l;j++)
G[j] += alpha_i*Q_i[j];
if(is_upper_bound(i))
for(j=0;j<l;j++)
G_bar[j] += get_C(i) * Q_i[j];
}
}


// optimization step


int iter = 0;
int max_iter = Math.max(10000000, l>Integer.MAX_VALUE/100 ? Integer.MAX_VALUE : 100*l);
int counter = Math.min(l,1000)+1;
int[] working_set = new int[2];


while(iter < max_iter)
{
// show progress and do shrinking


if(--counter == 0)
{
counter = Math.min(l,1000);
if(shrinking!=0) do_shrinking();
svm.info(".");
}


if(select_working_set(working_set)!=0)
{
// reconstruct the whole gradient
reconstruct_gradient();
// reset active set size and check
active_size = l;
svm.info("*");
if(select_working_set(working_set)!=0)
break;
else
counter = 1; // do shrinking next iteration
}

int i = working_set[0];
int j = working_set[1];


++iter;


// update alpha[i] and alpha[j], handle bounds carefully


float[] Q_i = Q.get_Q(i,active_size);
float[] Q_j = Q.get_Q(j,active_size);


double C_i = get_C(i);
double C_j = get_C(j);


double old_alpha_i = alpha[i];
double old_alpha_j = alpha[j];


if(y[i]!=y[j])
{
double quad_coef = QD[i]+QD[j]+2*Q_i[j];
if (quad_coef <= 0)
quad_coef = 1e-12;
double delta = (-G[i]-G[j])/quad_coef;
double diff = alpha[i] - alpha[j];
alpha[i] += delta;
alpha[j] += delta;

if(diff > 0)
{
if(alpha[j] < 0)
{
alpha[j] = 0;
alpha[i] = diff;
}
}
else
{
if(alpha[i] < 0)
{
alpha[i] = 0;
alpha[j] = -diff;
}
}
if(diff > C_i - C_j)
{
if(alpha[i] > C_i)
{
alpha[i] = C_i;
alpha[j] = C_i - diff;
}
}
else
{
if(alpha[j] > C_j)
{
alpha[j] = C_j;
alpha[i] = C_j + diff;
}
}
}
else
{
double quad_coef = QD[i]+QD[j]-2*Q_i[j];
if (quad_coef <= 0)
quad_coef = 1e-12;
double delta = (G[i]-G[j])/quad_coef;
double sum = alpha[i] + alpha[j];
alpha[i] -= delta;
alpha[j] += delta;


if(sum > C_i)
{
if(alpha[i] > C_i)
{
alpha[i] = C_i;
alpha[j] = sum - C_i;
}
}
else
{
if(alpha[j] < 0)
{
alpha[j] = 0;
alpha[i] = sum;
}
}
if(sum > C_j)
{
if(alpha[j] > C_j)
{
alpha[j] = C_j;
alpha[i] = sum - C_j;
}
}
else
{
if(alpha[i] < 0)
{
alpha[i] = 0;
alpha[j] = sum;
}
}
}


// update G


double delta_alpha_i = alpha[i] - old_alpha_i;
double delta_alpha_j = alpha[j] - old_alpha_j;


for(int k=0;k<active_size;k++)
{
G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
}


// update alpha_status and G_bar


{
boolean ui = is_upper_bound(i);
boolean uj = is_upper_bound(j);
update_alpha_status(i);
update_alpha_status(j);
int k;
if(ui != is_upper_bound(i))
{
Q_i = Q.get_Q(i,l);
if(ui)
for(k=0;k<l;k++)
G_bar[k] -= C_i * Q_i[k];
else
for(k=0;k<l;k++)
G_bar[k] += C_i * Q_i[k];
}


if(uj != is_upper_bound(j))
{
Q_j = Q.get_Q(j,l);
if(uj)
for(k=0;k<l;k++)
G_bar[k] -= C_j * Q_j[k];
else
for(k=0;k<l;k++)
G_bar[k] += C_j * Q_j[k];
}
}


}

if(iter >= max_iter)
{
if(active_size < l)
{
// reconstruct the whole gradient to calculate objective value
reconstruct_gradient();
active_size = l;
svm.info("*");
}
System.err.print("\nWARNING: reaching max number of iterations\n");
}


// calculate rho


si.rho = calculate_rho();


// calculate objective value
{
double v = 0;
int i;
for(i=0;i<l;i++)
v += alpha[i] * (G[i] + p[i]);


si.obj = v/2;
}


// put back the solution
{
for(int i=0;i<l;i++)
alpha_[active_set[i]] = alpha[i];
}


si.upper_bound_p = Cp;
si.upper_bound_n = Cn;


svm.info("\noptimization finished, #iter = "+iter+"\n");

}

相关文章推荐

java学习--Libsvm java版代码注释及详解(一)

由于工作中要用到SVR算法,项目组的系统是用java开发的,因此,为了能与项目组同步,算法需要用java来实现,还好台湾大学的林智仁教授推出了Libsvm的源代码,包括java、c++等语言的源代码,...

libsvm源代码注释+算法描述:svm_train

(I will try my best to make this note clearer.)We mainly focus on solve_c_svc in this note.Our goal:...

libsvm核心代码分析

  • 2016年11月22日 09:47
  • 701KB
  • 下载

蓝屏代码分析

  • 2016年04月30日 21:26
  • 15KB
  • 下载

libsvm代码阅读:关于Kernel类分析

这一篇博文来分析下Kernel类,代码上很简单,一般都能看懂。Kernel类主要是为SVM的核函数服务的,里面实现了SVM常用的核函数,通过函数指针来使用这些核函数。 其中几个常用核函数如下所示:(一...
  • Linoi
  • Linoi
  • 2014年02月21日 12:16
  • 2466

ucoss中os-tmr.c中的代码分析

  • 2017年07月21日 14:40
  • 51KB
  • 下载

主机代码分析COPYBOOK长度计算

  • 2013年09月17日 22:22
  • 543KB
  • 下载

libsvm代码阅读:关于Solver类分析(二)

如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了。 // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918 // ...
  • Linoi
  • Linoi
  • 2014年02月22日 22:15
  • 3198

ARM启动代码分析

  • 2013年12月20日 09:48
  • 99KB
  • 下载

Lucene 3.0 原理与代码分析

  • 2014年03月02日 22:42
  • 10.61MB
  • 下载
内容举报
返回顶部
收藏助手
不良信息举报
您举报文章:libsvm 代码分析
举报原因:
原因补充:

(最多只允许输入30个字)