由于工作中要用到SVR算法,项目组的系统是用java开发的,因此,为了能与项目组同步,算法需要用java来实现,还好台湾大学的林智仁教授推出了Libsvm的源代码,包括java、c++等语言的源代码,在此表示感谢!因此,算法的主体部分不用自己开发了,在源代码的基础上做一些修改就能够应用到自己的项目中了,开源真好!受益了无数人。。。为了弘扬开源的精神,开博记录学习Libsvmjava版源代码的过程。下面正式开始!先从SVR回归算法的代码开始,然后逐步扩展到分类算法。希望自己能够坚持下去。加油。
---------------------我是华丽的分割线-----------高手的小jj,我割割割--嘻嘻-------------------
一、初识Libsvm
Libsvmjava版本的源代码很容易下载到,为了使算法能够运行,要先把源代码复制到myeclipse中(貌似是废话),在这里,提供一个链接,里面有很好的说明,按照文档中的说明就能够将Libsvm运行起来。
链接地址:http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html
按照上述链接中的方式,自己新建一个main函数,来调用Libsvm算法的源代码,代码如下:
public static void main(String[] args) throws IOException {
String[]arg ={ "trainfile\\train1.txt", //存放SVM训练模型用的数据的路径
"trainfile\\model_r.txt"}; //存放SVM通过训练数据训练出来的模型的路径
String[]parg={"trainfile\\test2.txt", //这个是存放测试数据
"trainfile\\model_r.txt", //调用的是训练以后的模型
"trainfile\\out_r.txt"}; //生成的结果的文件的路径
System.out.println("........SVM运行开始..........");
//创建一个训练对象
svm_traint = new svm_train();
//创建一个预测或者分类的对象
svm_predictp= new svm_predict();
t.main(arg); //调用
p.main(parg); //调用
}
注意:
1. 该主函数是为了调用Libsvm源代码的,创建了两个字符数组:arg[]和parg[]。
其中arg[]数组存放了两个字符串,"trainfile\\train1.txt"和"trainfile\\model_r.txt",这两个字符串传入到Libsvm中的svm_train();
在trainfile\\train1.txt这个文件里面存储的是训练数据,训练数据的格式如下:
Y index1:特征1 index2:特征2 ....
<lable1> 1:特征1 2:特征2 3:特征3 ...
<lable2> 1:特征1 2:特征2 3:特征3 ...
其中如果是分类问题,lable为类标签。如果是回归问题,lable为具体的实数。为了生成这种格式的数据,可以采用FormatDataLibsvm.exl这个excel文件生成,网上可以下的到。当然,也可以自己写java代码,来生成这种格式的数据。代码随后奉上。而trainfile\\model_r.txt这个文件的作用是存储利用SVM算法训练好的模型。
而parg[]这个数组存放了3个字符串:trainfile\\test2.txt---用于存储测试样本的文件,数据格式与训练样本的格式一样、trainfile\\model_r.txt---前面生成的训练好的模型,在用新样本进行预测时直接使用前文训练好的模型即可、trainfile\\model_r.txt---用于存储模型的预测值,分类问题的话存储的是预测样本每条样本所属的类别,而回归问题的话,存储的是每条样本所对应的预测值。
2. svm_train()与svm_predict()。这两个函数是Libsvm程序包中的源代码,由此即进入到了Libsvm源代码的世界了。
二、Libsvm的真面目
由于本篇是介绍支持向量回归机--SVR的,所以仅从用到SVR算法的代码入手,来分析Libsvm的源代码。分析时采用逐层进入的方式。下面直接上程序代码目录:
+++++++目录+++++
前文是首先调用svm_train()然后调用svm_predict()。那就先从svm_train()说起。
classsvm_train{}包含以下几个变量以及函数:
private svm_parameterparam; // 用于设置svm模型的参数
private svm_problemprob; // 用来存储样本序号、样本的目标变量Y、样本自变量X.详看classsvm_problem
private svm_modelmodel;// ??
private Stringinput_file_name; // 输入文件名
private Stringmodel_file_name; // 模型文件名
private Stringerror_msg;//错误信息
private intcross_validation;//交叉验证
private intnr_fold;// ??
private staticsvm_print_interface svm_print_null = newsvm_print_interface() //
private static voidexit_with_help() //打印帮助信息
private voiddo_cross_validation() //交叉验证
private void run(Stringargv[]) throws IOException //运行svm训练程序
public static voidmain(String argv[]) throws IOException// 主函数
private static doubleatof(String s) //将字符串转化为浮点型
private static intatoi(String s) //将字符串转化为int型
private voidparse_command_line(String argv[]) //设置参数
private voidread_problem() throws IOException //读取错误信息
前文自己写的那个函数,调用的svm训练模型,即
svm_train t = new svm_train();
t.main(arg); //调用svm训练模型
程序进入到class svm_train{ }中的main函数,即 publicstatic void main(String argv[]) throws IOException该主函数的代码如下:
public static void main(String argv[])throws IOException
{
svm_train t = newsvm_train();
t.run(argv);//传进来一个数组,数组里面有两个字符串,一个是训练样本.txt,一个是训练好的模型.txt
}
继续执行run()函数,输入为数组argv[].
private void run(String argv[]) throwsIOException{
parse_command_line(argv); // 1.进入到该函数中,获取SVM参数
read_problem(); // 2.进入到该函数中,读取错误信息
error_msg =svm.svm_check_parameter(prob,param);// 3.检查参数
//检查参数,有错误则返回各种参数错误信息,无错误则返回null;
if(error_msg != null)
{
System.err.print("ERROR: "+error_msg+"\n");
System.exit(1);
}
if(cross_validation != 0)
{
do_cross_validation(); // 4.交叉验证
}
else
{
model = svm.svm_train(prob,param); //5.prob--训练样本,param--SVM模型参数
svm.svm_save_model(model_file_name,model); //6.保存训练好的模型
}
}
注:该函数一共调用了6个函数,下文一一说明。
首先进入函数1:parse_command_line(argv); // 1.进入到该函数中,获取SVM参数。该函数的输入为argv[],即两个字符串:一个是训练样本.txt,一个是训练好的模型.txt。该函数虽然无返回值,但在函数里面,已经将svm的一些参数存储在param中了,详细参数名称见classsvm_parameter,因此模型训练时已经有了所需要的各种参数。函数的详细代码如下:
private void _lineparse_command(Stringargv[])
{
int i;
svm_print_interface print_func= null; // default printing to stdout
param = newsvm_parameter();//开始设置SVM模型的各种参数
// default values
//param.svm_type =svm_parameter.C_SVC;
param.svm_type = svm_parameter.EPSILON_SVR;//此时运行的是SVR算法
param.kernel_type = svm_parameter.RBF;//核函数取径向基核函数
param.degree = 3; //??
param.gamma = 0.08;
//gamma为RBF核函数的参数,默认时=1/num_features 此时设置为0.08 gamma=1/2*sig^2sig=2.5
//RBF核函数:exp(-gamma*|Xi-Xj|^2)
param.coef0 = 0; //??
param.nu= 0.5; //??
param.cache_size = 100; //设置缓存的大小
param.C= 100; //惩罚参数
param.eps = 0.005; //??
param.p= 0.001; //此值为EPSILON_SVR中EPSILON
param.shrinking = 1; //??
param.probability = 0; //概率估计??
param.nr_weight = 0; //权重??
param.weight_label = new int[0]; //??
param.weight = new double[0];
cross_validation = 0;//交叉验证。0--不进行交叉验证。1--交叉验证
//获取输入参数
// parse options
for(i=0;i<argv.length;i++) //argv.length=2,即有两个字符串
{
if(argv[i].charAt(0) != '-') break;
//由于第一个字符(trainfile\\train1.txt)中的第一个字符不是‘-’,果断break!退出for循环。i=0
if(++i>=argv.length)
exit_with_help();
switch(argv[i-1].charAt(1))
{
case 's':
param.svm_type = atoi(argv[i]);
break;
case 't':
param.kernel_type = atoi(argv[i]);
break;
case 'd':
param.degree = atoi(argv[i]);
break;
case 'g':
param.gamma = atof(argv[i]);
break;
case 'r':
param.coef0 = atof(argv[i]);
break;
case 'n':
param.nu= atof(argv[i]);
break;
case 'm':
param.cache_size = atof(argv[i]);
break;
case'c':
param.C = atof(argv[i]);
break;
case'e':
param.eps = atof(argv[i]);
break;
case 'p':
param.p = atof(argv[i]);
break;
case 'h':
param.shrinking = atoi(argv[i]);
break;
case 'b':
param.probability= atoi(argv[i]);
break;
case 'q':
print_func = svm_print_null;
i--;
break;
case 'v':
cross_validation = 1;
nr_fold = atoi(argv[i]);
if(nr_fold < 2)
{
System.err.print("n-fold cross validation: n must >=2\n");
exit_with_help();
}
break;
case'w':
++param.nr_weight;
{
int[] old = param.weight_label;
param.weight_label = new int[param.nr_weight];
System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
}
{
double[] old = param.weight;
param.weight = new double[param.nr_weight];
System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
}
param.weight_label[param.nr_weight-1] =atoi(argv[i-1].substring(2));
param.weight[param.nr_weight-1] = atof(argv[i]);
break;
default:
System.err.print("Unknown option: " + argv[i-1] + "\n");
exit_with_help();
} //switch循环结束
}//for循环结束
svm.svm_set_print_string_function(print_func);//1.1打印,详见下文说明2;
// determine filenames
if(i>=argv.length) //argv.length=2,而i=0,不执行此语句
exit_with_help();
input_file_name =argv[i]; //训练样本的文件名,即trainfile\\data_train_svr.txt
if(i<argv.length-1)//i=0,argv.length-1=1,符合条件
model_file_name = argv[i+1];//模型文件名,即trainfile\\model_r.txt
else//此时不执行下面语句
{
int p = argv[i].lastIndexOf('/');
++p; //whew...
model_file_name = argv[i].substring(p)+".model";
}
}//函数parse_command_line结束
说明:
1.该函数的功能是初始化svm模型的各种参数,本篇用的是SVR算法,初始化了一些参数。
2.该函数调用了一个函数,即函数1.1,由于该函数的输入是print_func =null,经过调用,其输出也为空,即不打印任何信息,因此本文不予深入说明。
此时程序进入函数2:read_problem() ; // 2.进入到该函数中,读取错误信息
该函数无返回值,但在函数体内针对两个错误,用打印输出语句打印出相应的错误信息,其中一个错误信息为:核函数的第一列的标签必须从0开始编号。如果不是从0编号,则打印输出此错误信息。第二个错误为:样本格式有错误,如果样本的编号标签小于0或者样本的编号标签值大于样本的实际个数,则打印输出该错误信息。
private voidread_problem() throwsIOException
{
BufferedReaderfp = new BufferedReader(new FileReader(input_file_name));
Vector<Double>vy = new Vector<Double>();
Vector<svm_node[]> vx = newVector<svm_node[]>();
int max_index = 0;
while(true)
{
String line = fp.readLine();
if(line == null) break;
StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
vy.addElement(atof(st.nextToken()));//atof--将字符串转化为数字
int m = st.countTokens()/2;//训练样本的特征个数,Y,X1,X2...Xm-1.共m个
svm_node[] x = new svm_node[m];
for(int j=0;j<m;j++)
{
x[j] = new svm_node();
x[j].index = atoi(st.nextToken());
x[j].value = atof(st.nextToken());
}
if(m>0) max_index = Math.max(max_index,x[m-1].index);
vx.addElement(x);
}
prob = newsvm_problem();
prob.l = vy.size();
prob.x = newsvm_node[prob.l][];
for(inti=0;i<prob.l;i++)
prob.x[i] = vx.elementAt(i);
prob.y = newdouble[prob.l];
for(inti=0;i<prob.l;i++)
prob.y[i] = vy.elementAt(i);
if(param.gamma == 0&& max_index >0)
param.gamma = 1.0/max_index;
if(param.kernel_type ==svm_parameter.PRECOMPUTED)
for(int i=0;i<prob.l;i++)
{
if(prob.x[i][0].index != 0)
{
System.err.print("Wrong kernel matrix: first column must be0:sample_serial_number\n");
System.exit(1);
}
if ((int)prob.x[i][0].value <= 0 ||(int)prob.x[i][0].value > max_index)
{
System.err.print("Wrong input format:sample_serial_number out of range\n");
System.exit(1);
}
}
fp.close();
}
}
此时程序进入函数3:error_msg =svm.svm_check_parameter(prob,param);// 3.检查参数
该函数是在class svm中,功能是检查svm模型的参数是否正确。
public static Stringsvm_check_parameter(svm_problem prob, svm_parameterparam)
该函数的输入是svm_problemprob和svm_parameterparam两个类,其中类svm_problem如下:
public class svm_problemimplements java.io.Serializable
{
public intl;//训练样本中,样本的标签,即第l个训练样本
publicdouble[] y;//训练样本的目标变量Y
publicsvm_node[][] x;//训练样本的自变量X
}
svm_parameter则是svm所需要的各种参数。
输出则是一个字符串,如果匹配到相应的错误,则输出其错误信息,如果没有错误,则返回NULL.
此时程序进入函数4:do_cross_validation(); // 4.交叉验证
由于SVR算法不需要交叉验证,故不执行此函数。而对于分类而言,执行交叉验证操作可增强算法的推广能力。
这里留作以后详细研究。
此时程序进入函数5:model =svm.svm_train(prob,param); //prob--训练样本,param--SVM模型参数
此时进入到了SVM/SVR算法的关键环节--训练模型。欲知详情,请听下回分解。
-----格格-----------------
参考文献:
1.http://wenku.baidu.com/view/54cfa92b453610661ed9f4f6.html-----很好的libsvm安装使用说明